# Distillation 

Ce a pour but d'entrainer un MPS multicouche à l'aide d'un réseau complètement connecté

In [12]:
%pip install "git+https://github.com/jemisjoky/TorchMPS.git"
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
%pip install torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from torchmps import MPS

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/jemisjoky/TorchMPS.git
  Cloning https://github.com/jemisjoky/TorchMPS.git to /tmp/pip-req-build-g19ez2yt
  Running command git clone --filter=blob:none --quiet https://github.com/jemisjoky/TorchMPS.git /tmp/pip-req-build-g19ez2yt
  Resolved https://github.com/jemisjoky/TorchMPS.git to commit 6c0bc1a8e2c15acba8570ca9ffe2b4a0c7135165
  Preparing metadata (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [13]:
# Hardware hyperparameters
chosen_device = torch.device('cuda' 
if torch.cuda.is_available() else 'cpu')

# Data hyperparameters
nb_train = 2
nb_test = 50
chosen_bs = 100
chosen_bs = min(nb_train, chosen_bs)
input_size = 28*28
nb_classes = 10

# Teacher hyperparameters
chosen_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_epochs_fcnn = 20
hidden_size = 70
teacher_loss = nn.CrossEntropyLoss()
#Optimizer parameters
chosen_lr = 0.001
chosen_momentum = 0.9

# Student hyperparameters
# MPS parameters
bond_dim = 20
adaptive_mode = False
periodic_bc = False
feature_map = lambda x : torch.tensor([1, x]).to(chosen_device)
feature_dim = 2
# Training parameters
gauss_epochs = 5 # number of epochs with added gaussian noise
gn_var = 0.3 #added gaussian noise variance
gn_mean = 0 #added gaussian noise mean
n_epochs_lmps = 15
mps_learn_rate = 0.0001
mps_reg = 0.0
mps1_distill_loss = nn.MSELoss()
# We choose this loss 
mps2_distill_loss = nn.KLDivLoss(reduction = "batchmean", log_target = True)

In [14]:
# Import the mnist dataset
train_set = torchvision.datasets.MNIST(root = './datasets', train = True,   
    transform = transforms.ToTensor(),  download = True )

train_subset = torch.utils.data.SubsetRandomSampler(range(nb_train))

train_iterator = torch.utils.data.DataLoader(dataset = train_set, 
    sampler = train_subset, batch_size=chosen_bs)


test_set = torchvision.datasets.MNIST(root = './datasets',
 train = False, transform = transforms.ToTensor(),  download = True)

test_subset = torch.utils.data.SubsetRandomSampler(range(nb_test))

test_iterator = torch.utils.data.DataLoader(dataset = test_set, 
    sampler = test_subset, batch_size = chosen_bs)

In [15]:

# Create the fcnn class
class FCNN(nn.Module):
    def __init__(self):
        super(FCNN, self).__init__()
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(784, hidden_size)
        self.lin2 = nn.Linear(hidden_size, hidden_size)
        self.lin3 = nn.Linear(hidden_size, hidden_size)
        self.lin4 = nn.Linear(hidden_size, hidden_size)
        self.lin5 = nn.Linear(hidden_size, hidden_size)
        self.lin6 = nn.Linear(hidden_size, 10)

    def middleforward(self, x):
        y = self.lin1(x)
        y = self.relu(y)
        y = self.lin2(y)
        y = self.relu(y)
        y = self.lin3(y)
        y = self.relu(y)
        return y

    def forward(self, x):
        y = self.lin1(x)
        y = self.relu(y)
        y = self.lin2(y)
        y = self.relu(y)
        y = self.lin3(y)
        y = self.relu(y)
        y = self.lin4(y)
        y = self.relu(y)
        y = self.lin5(y)
        y = self.relu(y)
        y = self.lin6(y)
        y = self.relu(y)
        return y

#Instantiate and put the model on the chosen device
fcnn_teacher = FCNN().to(chosen_device)

#Instantiate the optimizer
optimizer = torch.optim.Adam(fcnn_teacher.parameters())

#Training loop
for epoch in range(n_epochs_fcnn):
    for (x_mb, y_mb) in train_iterator:
        # Reshape the train_tuple and put on the chosen device
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        # Foward propagation
        y_hat_mb = fcnn_teacher(x_mb)
        # We use soft-cross-entropy
        # Softmax is precomputed in the loss already in PyTorch
        # Ground truth is scalar tensor
        loss = teacher_loss(y_hat_mb, y_mb)
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    print(loss.item())

# get accuracy
# Get the validation set classification accuracy
teacher_acc_score = 0
teacher_acc_metric = MulticlassAccuracy(num_classes=nb_classes).to(chosen_device)
for (x_mb, y_mb) in test_iterator:
    x_mb = x_mb.reshape(-1, 784).to(chosen_device)
    y_mb = y_mb.to(chosen_device)
    # add the number of datapoints we classified right
    teacher_acc_score += x_mb.size()[0] * teacher_acc_metric( fcnn_teacher(x_mb), y_mb )
print("The teacher's accuracy score is:")
print(teacher_acc_score / nb_test) #divide by total size



2.2214975357055664
2.209477186203003
2.1985538005828857
2.187264919281006
2.1749372482299805
2.161470413208008
2.1465814113616943
2.129547595977783
2.109717607498169
2.0863301753997803
2.0580015182495117
2.0244317054748535
1.986724853515625
1.9402668476104736
1.8847534656524658
1.8179469108581543
1.7379167079925537
1.6422667503356934
1.5310845375061035
1.4052988290786743
The teacher's accuracy score is:
tensor(0.0160, device='cuda:0')


In [20]:

# Initialize the MPS modules
student_mps1 = MPS(
    input_dim = input_size,
    feature_dim = feature_dim,
    output_dim = hidden_size,
    bond_dim = bond_dim,
    adaptive_mode = adaptive_mode,
    periodic_bc = periodic_bc,
).to(chosen_device)
student_mps1.register_feature_map(feature_map)

student_mps2 = MPS(
    input_dim = hidden_size, #crucial setting
    feature_dim = feature_dim,
    output_dim = nb_classes,
    bond_dim = bond_dim,
    adaptive_mode = adaptive_mode,
    periodic_bc = periodic_bc,
).to(chosen_device)
student_mps2.register_feature_map(feature_map)

def student(x):
    return student_mps2( student_mps1(x) )


# Instantiate the optimizer and softmax
lmps1_optimizer = torch.optim.Adam(student_mps1.parameters(), 
    lr = mps_learn_rate, weight_decay = mps_reg)

lmps2_optimizer = torch.optim.Adam(student_mps2.parameters(), 
    lr = mps_learn_rate, weight_decay = mps_reg)

print("turn to the second mps")
# Training loop for the student_mps1
for epoch in range(n_epochs_lmps):
    for (x_mb, _) in train_iterator:

        # Reshape put on the chosen device
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)

        # Added gaussian noise for the last epochs
        if epoch > (n_epochs_lmps - gauss_epochs):
          x_mb = x_mb + torch.randn(size=x_mb.size()).to(chosen_device)

        # Get output of teacher in middlelayer
        # Softmax is precomputed in the loss already
        y_mb = fcnn_teacher.middleforward(x_mb).to(chosen_device)

        # Foward propagation
        y_hat_mb = student_mps1(x_mb)
        # We use mean squared error loss
        loss = mps1_distill_loss(y_hat_mb, y_mb)

        # Backpropagation
        loss.backward()
        lmps1_optimizer.step()
        lmps1_optimizer.zero_grad()

    print(loss.item())

softmax = torch.nn.Softmax(dim=0)

# Training loop for the student_mps2
for epoch in range(n_epochs_lmps):
    for (x_mb, _) in train_iterator:
        # Reshape put on the chosen device
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)

        # Added gaussian noise
        if epoch > (n_epochs_lmps - gauss_epochs):
          x_mb = x_mb + torch.randn(size=x_mb.size()).to(chosen_device)

        # Get log_softmax of logit outputs of teacher
        teacher_output = fcnn_teacher(x_mb)
        #teacher_output = torch.logit(fcnn_teacher(x_mb))
        teacher_output = nn.functional.log_softmax( teacher_output, dim=1 ) 
        # Get Log_softmax of outputs of student
        student_output = student(x_mb)
        #student_output = torch.logit( student(x_mb) )
        student_output = nn.functional.log_softmax( student_output, dim=1 ) 

        # Chosen loss is Kull divergence
        loss = mps2_distill_loss(student_output, teacher_output)
        

        # Backpropagation
        loss.backward()
        lmps2_optimizer.step()
        lmps2_optimizer.zero_grad()

    print(loss.item())



# Get the validation set classification accuracy
student_acc_score = 0
student_acc_metric = MulticlassAccuracy(num_classes=nb_classes).to(chosen_device)
for (x_mb, y_mb) in test_iterator:
    x_mb = x_mb.reshape(-1, 784).to(chosen_device)
    y_mb = y_mb.to(chosen_device)
    # add the number of datapoints we classified right
    student_acc_score += x_mb.size()[0] * student_acc_metric( student(x_mb), y_mb )
print("Student_acc_score:")
print(student_acc_score / nb_test) #divide by total size


turn to the second mps
0.6938904523849487
0.5102578401565552
0.37086889147758484
0.37950247526168823
0.36000439524650574
0.36592897772789
0.36413097381591797
0.35697558522224426
0.3576754629611969
0.35902753472328186
0.3585229516029358
0.2780931890010834
0.34702497720718384
0.3985869586467743
0.34828054904937744
0.2905879020690918
0.29051995277404785
0.2904491126537323
0.29035505652427673
0.2902480959892273
0.29010796546936035
0.2899174392223358
0.28965723514556885
0.28929564356803894
0.2887907028198242
0.28808218240737915
0.23274168372154236
0.2906225621700287
0.3276994228363037
0.3096701204776764
Student_acc_score:
tensor(0.0160, device='cuda:0')
