# Distillation 

Ce notebook a pour but d'entraîner, par distillation, un MPS multicouche à l'aide d'un réseau complètement connecté servant de professeur.

In [None]:
%pip install "git+https://github.com/jemisjoky/TorchMPS.git"
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
%pip install torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from torchmps import MPS

In [None]:
# Hardware hyperparameters
chosen_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data hyperparameters
# NOTE(review): nb_train = nb_test = 1 look like debug settings; raise them for a real run.
nb_train = 1  # number of MNIST training images sampled
nb_test = 1  # number of MNIST test images sampled
chosen_bs = 1  # mini-batch size for both loaders
input_size = 28*28  # flattened MNIST image
nb_classes = 10

# Teacher hyperparameters
n_epochs_fcnn = 15
hidden_size = 70
chosen_loss = nn.CrossEntropyLoss()
# Optimizer parameters
chosen_lr = 0.001
chosen_momentum = 0.9

# Student hyperparameters
# MPS parameters
bond_dim = 20
adaptive_mode = False
periodic_bc = False


def feature_map(x):
    """Embed a single scalar input as the 2-vector ``[1, x]``.

    The vector is built with ``torch.stack`` so that it keeps ``x``'s dtype and
    autograd history. ``torch.tensor([1, x])`` copied the value into a brand-new
    leaf tensor, which cut gradients and forced a host sync per CUDA scalar.
    ``as_tensor`` keeps plain Python numbers working as before.
    """
    x = torch.as_tensor(x)
    return torch.stack([torch.ones_like(x), x]).to(chosen_device)


feature_dim = 2  # length of the vector returned by feature_map
# Training parameters
gn_var = 0.3  # added gaussian noise variance
gn_mean = 0  # added gaussian noise mean
n_epochs_lmps = 1
mps_learn_rate = 0.0001
mps_reg = 0.0  # weight decay passed to the MPS optimizers
mps_chosen_loss = nn.CrossEntropyLoss().to(chosen_device)

In [None]:
# Import the mnist dataset: both splits are built the same way, so share a helper.
def _mnist_split(is_train, n_samples):
    """Return (dataset, sampler, loader) for one MNIST split.

    The sampler draws indices 0..n_samples-1 in random order, and the loader
    batches them with ``chosen_bs``.
    """
    dataset = torchvision.datasets.MNIST(
        root='./datasets',
        train=is_train,
        transform=transforms.ToTensor(),
        download=True,
    )
    sampler = torch.utils.data.SubsetRandomSampler(range(n_samples))
    loader = torch.utils.data.DataLoader(
        dataset=dataset, sampler=sampler, batch_size=chosen_bs
    )
    return dataset, sampler, loader


train_set, train_subset, train_iterator = _mnist_split(True, nb_train)
test_set, test_subset, test_iterator = _mnist_split(False, nb_test)

In [None]:

# Create the fcnn class
class FCNN(nn.Module):
    """Fully connected teacher: six linear layers with ReLU activations between them.

    ``middleforward`` returns the activations after the third layer, which is the
    target the first student MPS is distilled against. ``forward`` returns
    unnormalised class logits.
    """

    def __init__(self, in_features=784, hidden_dim=None, out_features=10):
        """Create the layers.

        Args:
            in_features: size of the flattened input (784 = 28*28 for MNIST).
            hidden_dim: width of every hidden layer. When None, the module-level
                ``hidden_size`` hyperparameter is read at construction time
                (the original behaviour).
            out_features: number of output logits (classes).
        """
        super(FCNN, self).__init__()
        if hidden_dim is None:
            hidden_dim = hidden_size
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(in_features, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
        self.lin3 = nn.Linear(hidden_dim, hidden_dim)
        self.lin4 = nn.Linear(hidden_dim, hidden_dim)
        self.lin5 = nn.Linear(hidden_dim, hidden_dim)
        self.lin6 = nn.Linear(hidden_dim, out_features)

    def middleforward(self, x):
        """Return the ReLU activations after ``lin3`` (the distillation layer)."""
        y = self.relu(self.lin1(x))
        y = self.relu(self.lin2(y))
        return self.relu(self.lin3(y))

    def forward(self, x):
        """Return class logits for the flattened input ``x``."""
        y = self.middleforward(x)
        y = self.relu(self.lin4(y))
        y = self.relu(self.lin5(y))
        # No ReLU on the output: these are logits for CrossEntropyLoss/softmax.
        # Clamping them at 0 would prevent the net from scoring classes below 0.
        return self.lin6(y)

# Instantiate and put the model on the chosen device
fcnn_teacher = FCNN().to(chosen_device)

# Instantiate the optimizer. chosen_momentum is used as Adam's first-moment
# decay (beta1). With lr=0.001 and beta1=0.9 this equals Adam's defaults.
optimizer = torch.optim.Adam(
    fcnn_teacher.parameters(), lr=chosen_lr, betas=(chosen_momentum, 0.999)
)

# Training loop
for epoch in range(n_epochs_fcnn):
    loss = None  # stays None if the loader yields nothing
    for (x_mb, y_mb) in train_iterator:
        # Flatten the images and put the batch on the chosen device
        x_mb = x_mb.reshape(-1, input_size).to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        # Forward propagation
        y_hat_mb = fcnn_teacher(x_mb)
        loss = chosen_loss(y_hat_mb, y_mb)
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    if loss is not None:
        print(loss.item())

# Test accuracy. The inputs must live on the same device as the model, and
# evaluation needs no autograd graph. A single metric accumulates over all
# batches; "micro" averaging gives plain accuracy (fraction correct).
print("testing teacher")
teacher_acc_metric = MulticlassAccuracy(num_classes=nb_classes, average="micro").to(chosen_device)
with torch.no_grad():
    for (x_mb, y_mb) in test_iterator:
        x_mb = x_mb.reshape(-1, input_size).to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        teacher_acc_metric.update(fcnn_teacher(x_mb), y_mb)
print(teacher_acc_metric.compute())



In [None]:

# Initialize the MPS modules
def _build_student_mps(in_dim, out_dim):
    """Create one MPS student layer on chosen_device and attach feature_map to it."""
    layer = MPS(
        input_dim=in_dim,
        feature_dim=feature_dim,
        output_dim=out_dim,
        bond_dim=bond_dim,
        adaptive_mode=adaptive_mode,
        periodic_bc=periodic_bc,
    ).to(chosen_device)
    layer.register_feature_map(feature_map)
    return layer


# Layer 1 maps the flattened image (input_size values) to hidden_size outputs.
student_mps1 = _build_student_mps(input_size, hidden_size)
# Layer 2 consumes student_mps1's outputs, so its input_dim must equal
# student_mps1's output_dim (hidden_size). It emits nb_classes scores.
student_mps2 = _build_student_mps(hidden_size, nb_classes)

def student(x):
    """Run the full two-layer student: feed student_mps1's output into student_mps2."""
    hidden = student_mps1(x)
    return student_mps2(hidden)


# One Adam optimizer per student layer, so each layer can be trained in its own
# phase. Both share the same learning rate and weight decay.
lmps1_optimizer, lmps2_optimizer = (
    torch.optim.Adam(layer.parameters(), lr=mps_learn_rate, weight_decay=mps_reg)
    for layer in (student_mps1, student_mps2)
)

# Converts teacher outputs into probability targets, normalised along dim 1.
softmax = nn.Softmax(dim=1)

# Training loop for the student_mps1: distil the softmax of the teacher's
# middle-layer activations (fcnn_teacher.middleforward) into the first MPS.
for epoch in range(n_epochs_lmps):
    loss = None  # stays None if the loader yields nothing
    for (x_mb, _) in train_iterator:
        # Flatten and move to the device, then add Gaussian noise with mean
        # gn_mean and variance gn_var. randn_like allocates the noise on x_mb's
        # device; a CPU randn tensor cannot be added to a CUDA tensor.
        x_mb = x_mb.reshape(-1, input_size).to(chosen_device)
        x_mb = x_mb + gn_mean + (gn_var ** 0.5) * torch.randn_like(x_mb)

        # Teacher targets. The teacher is frozen here, so no graph is built and
        # loss.backward() cannot accumulate gradients into its parameters.
        with torch.no_grad():
            y_mb = softmax(fcnn_teacher.middleforward(x_mb))

        # Forward propagation
        y_hat_mb = student_mps1(x_mb)
        loss = mps_chosen_loss(y_hat_mb, y_mb)

        # Backpropagation
        loss.backward()
        lmps1_optimizer.step()
        lmps1_optimizer.zero_grad()

    if loss is not None:
        print(loss.item())

# Training loop for the student_mps2: distil the softmax of the teacher's final
# outputs into the second MPS, fed by the (already trained) student_mps1.
for epoch in range(n_epochs_lmps):
    loss = None  # stays None if the loader yields nothing
    for (x_mb, _) in train_iterator:
        # Flatten and move to the device, then add Gaussian noise with mean
        # gn_mean and variance gn_var. randn_like allocates the noise on x_mb's
        # device; a CPU randn tensor cannot be added to a CUDA tensor.
        x_mb = x_mb.reshape(-1, input_size).to(chosen_device)
        x_mb = x_mb + gn_mean + (gn_var ** 0.5) * torch.randn_like(x_mb)

        with torch.no_grad():
            # Teacher targets (teacher is frozen)
            y_mb = softmax(fcnn_teacher(x_mb))
            # Only lmps2_optimizer steps in this phase, so student_mps1 is
            # treated as frozen. Backpropagating into it would only pile up
            # unused gradients.
            hidden_mb = student_mps1(x_mb)

        # Second forward propagation on the first layer's output
        y_hat_mb = student_mps2(hidden_mb)
        loss = mps_chosen_loss(y_hat_mb, y_mb)

        # Backpropagation
        loss.backward()
        lmps2_optimizer.step()
        lmps2_optimizer.zero_grad()

    if loss is not None:
        print(loss.item())



# Get the test-set classification accuracy of the full student (mps1 -> mps2).
# A single metric accumulates correct predictions over all batches. "micro"
# averaging is plain accuracy (correct / total), which is what the previous
# size-weighted sum of per-batch macro scores was trying to compute.
student_acc_metric = MulticlassAccuracy(num_classes=nb_classes, average="micro").to(chosen_device)
with torch.no_grad():
    for (x_mb, y_mb) in test_iterator:
        x_mb = x_mb.reshape(-1, input_size).to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        student_acc_metric.update(student(x_mb), y_mb)
student_acc_score = student_acc_metric.compute()
print("Student_acc_score:")
print(student_acc_score)
