In [72]:
%pip install "git+https://github.com/ScierKnave/TorchMPS.git"
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
%pip install torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from torchmps import MPS
import math

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/ScierKnave/TorchMPS.git
  Cloning https://github.com/ScierKnave/TorchMPS.git to /tmp/pip-req-build-zhn8pj2k
  Running command git clone --filter=blob:none --quiet https://github.com/ScierKnave/TorchMPS.git /tmp/pip-req-build-zhn8pj2k
  Resolved https://github.com/ScierKnave/TorchMPS.git to commit f716a08e15d0af50dbfdfc435ab9604e82562ea3
  Preparing metadata (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Hyperparameters

In [73]:
# FC to 2-MPS

# Hardware hyperparameters
chosen_device = torch.device('cuda' 
if torch.cuda.is_available() else 'cpu')

# Data hyperparameters
nb_train_HP = 2000
nb_test_HP = 500
batch_sz_HP = 150
batch_sz_HP = min(batch_sz_HP, nb_train_HP)
nb_classes_HP = 10

# Teacher hyperparameters
nepochs_teacher_HP = 40
teacher_loss_HP = nn.CrossEntropyLoss()
teacher_lr_HP = 1e-4
teacher_reg_HP = 0.1
teacher_hidden_size_HP = 18432
# Student hyperparameters
# MPS parameters
bond_dim_HP = 10
adaptive_mode_HP = False
periodic_bc_HP = False

# Training parameters
nepochs_student_HP = 25 
student_lr_HP = 1e-4
student_reg_HP = 0.01
student_loss_HP = nn.KLDivLoss(reduction = "batchmean", log_target = True)

# Gaussian parameters
gauss_epochs_HP = 0 # number of epochs with added gaussian noise
gn_var_HP = 0.3 #added gaussian noise variance
gn_mean_HP = 0 #added gaussian noise mean
#nepochs_student_HP = 25 + gauss_epochs_HP


# Premilinaries: Importing the data and utils subroutines

In [74]:
# Import the mnist train dataset
train_set = torchvision.datasets.MNIST(
    root = './datasets', train = True,   
    transform = transforms.ToTensor(),  download = True )

# Create a training batch iterator
train_subset = torch.utils.data.SubsetRandomSampler(range(nb_train_HP))
train_iterator = torch.utils.data.DataLoader(
    dataset = train_set, 
    sampler = train_subset, batch_size=batch_sz_HP
    )

# Import the mnist test set
test_set = torchvision.datasets.MNIST(
    root = './datasets',
    train = False, transform = transforms.ToTensor(),  download = True
    )
# Create a testing batch iterator
test_subset = torch.utils.data.SubsetRandomSampler(range(nb_test_HP))
test_iterator = torch.utils.data.DataLoader(
    dataset = test_set, 
    sampler = test_subset, batch_size = batch_sz_HP
    )

In [75]:
# Returns the validation set classification accuracy
# of the given input model (this is a higher order function)
def get_acc(model, iterator):
    # Get the validation set classification accuracy
    total_good_classifications = 0
    acc_metric = MulticlassAccuracy(num_classes=nb_classes_HP).to(chosen_device)
    for (x_mb, y_mb) in iterator:
        x_mb = x_mb.to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        # Add the number of datapoints we classified right to the total
        batch_size = x_mb.size()[0]
        y_hat = model(x_mb)
        batch_good_classifications = batch_size * acc_metric(y_hat, y_mb)
        total_good_classifications += batch_good_classifications
    return total_good_classifications / nb_test_HP # divide by total size

In [76]:
# Create the fcnn class
import torch.nn as nn
class FCNN(nn.Module):
    def __init__(self):
        super(FCNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 128, 5),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(128, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.lin1= nn.Linear(1600, 10)
    def middleforward(self, x):
      x = self.conv1(x)
      x = torch.flatten(x, start_dim=1)
      return x
    def forward(self, x):
      x = self.conv1(x)
      x = self.conv2(x)
      x = torch.flatten(x, start_dim=1)
      x = self.lin1(x)
      return x

#Instantiate and put the model on the chosen device
teacher = FCNN().to(chosen_device)

#Instantiate the optimizer
teacher_optimizer = torch.optim.Adam(teacher.parameters())

# Create an array to store the val loss
# of the student at each epoch
teacher_test_loss = []
teacher_train_loss = []

# Training loop 
for epoch in range(nepochs_teacher_HP):
    for (x_mb, y_mb) in train_iterator:
        # Flatten the MNIST images, which come in matrix form
        x_mb = x_mb.to(chosen_device)
        y_mb = y_mb.to(chosen_device)

        teacher_output = teacher(x_mb) 

        # Backpropagation
        loss = teacher_loss_HP(teacher_output, y_mb)
        loss.backward()
        teacher_optimizer.step()
        teacher_optimizer.zero_grad()

    #teacher_train_loss.append( round(get_acc(teacher, train_iterator).item(), 3) )
    teacher_test_loss.append( round(get_acc(teacher, test_iterator).item(), 5) )

print("Teacher results:")
print("Epochs: ", np.arange(1, nepochs_teacher_HP+1).tolist())
print("Train loss: ", teacher_train_loss)
print("Test loss: ", teacher_test_loss)


Teacher results:
Epochs:  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
Train loss:  []
Test loss:  [0.70517, 0.84256, 0.88623, 0.90287, 0.92864, 0.94642, 0.95169, 0.94947, 0.96408, 0.95141, 0.96894, 0.96202, 0.97315, 0.96537, 0.97036, 0.97069, 0.96812, 0.96699, 0.96767, 0.96589, 0.97197, 0.96946, 0.96875, 0.96889, 0.96826, 0.97941, 0.9645, 0.97137, 0.97188, 0.97057, 0.97159, 0.96885, 0.97211, 0.9691, 0.97243, 0.97081, 0.97197, 0.96658, 0.97066, 0.96931]


# Training the student model

In [None]:
class Student(nn.Module):
    ''' 
    Our student model is a composition of two MPS.
    '''
    def __init__(self):
        super(Student, self).__init__()
        # Initialize the MPS modules
        self.mps1 = MPS(
            input_dim = 28 ** 2,
            feature_dim = 2,
            output_dim = teacher_hidden_size_HP,
            bond_dim = bond_dim_HP,
            init_std = 0.01
          )

        self.mps2 = MPS(
          input_dim = teacher_hidden_size_HP, 
          feature_dim = 2,
          output_dim = nb_classes_HP,
          bond_dim = bond_dim_HP,
          init_std = 0.01
        )
        
    def forward(self, x):
        x = x.reshape(-1, 784)
        y = self.mps1(x)
        y = self.mps2(y)
        return y

student = Student().to(chosen_device)


# Create the optimizers for the training
mps1_optimizer = torch.optim.Adam(
    student.mps1.parameters(), lr = student_lr_HP, weight_decay = student_reg_HP
)

mps2_optimizer = torch.optim.Adam(
    student.mps2.parameters(), lr = student_lr_HP, weight_decay = student_reg_HP
)

student_optimizer = torch.optim.Adam(
    student.parameters(), lr = student_lr_HP, weight_decay = student_reg_HP
)

# Used on the inputs before the loss function
LogSoftmax = nn.LogSoftmax(dim=1)

# Create an array to store the val loss
# of the student at each epoch
stud_test_loss = []
stud_train_loss = []



mse_loss = nn.MSELoss()

# Mps2 training loop 
for epoch in range(nepochs_student_HP * 4):
    for (x_mb, y_mb) in train_iterator:
        # Flatten the MNIST images, which come in matrix form
        x_mb = x_mb.to(chosen_device)
        y_mb = y_mb.to(chosen_device)

        # Train mps1
        if (epoch < nepochs_student_HP*2):
          teacher_middle_output = teacher.middleforward(x_mb) 
          mps1_logits = student.mps1(x_mb.reshape(-1, 784)) 
          # Backpropagation
          loss = mse_loss(mps1_logits, teacher_middle_output)
          loss.backward()
          mps1_optimizer.step()
          mps1_optimizer.zero_grad()

        # Train mps2
        else:
          student_output = LogSoftmax( student(x_mb.reshape(-1, 784)) )
          teacher_output = LogSoftmax( teacher(x_mb.reshape(-1, 784)) )
          # Backpropagation
          loss = student_loss_HP(student_output, teacher_output)
          loss.backward()
          mps2_optimizer.step()
          mps2_optimizer.zero_grad()


    #stud_train_loss.append( round(get_acc(student, train_iterator).item(), 3) )
    stud_test_loss.append( round(get_acc(student, test_iterator).item(), 5) )
    print(stud_test_loss)
    print(stud_train_loss)

print("Student results:")
print("Epochs: ", np.arange(1, nepochs_student_HP+1).tolist())
print("Train loss: ", stud_train_loss)
print("Test loss: ", stud_test_loss)


[0.102]
[]
[0.102, 0.1]
[]
[0.102, 0.1, 0.10753]
[]
[0.102, 0.1, 0.10753, 0.1051]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128, 0.1115]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128, 0.1115, 0.11146]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128, 0.1115, 0.11146, 0.10684]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128, 0.1115, 0.11146, 0.10684, 0.10659]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128, 0.1115, 0.11146, 0.10684, 0.10659, 0.10845]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128, 0.1115, 0.11146, 0.10684, 0.10659, 0.10845, 0.10831]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128, 0.1115, 0.11146, 0.10684, 0.10659, 0.10845, 0.10831, 0.10766]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128, 0.1115, 0.11146, 0.10684, 0.10659, 0.10845, 0.10831, 0.10766, 0.10948]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128, 0.1115, 0.11146, 0.10684, 0.10659, 0.10845, 0.10831, 0.10766, 0.10948, 0.1075]
[]
[0.102, 0.1, 0.10753, 0.1051, 0.1128, 0.1115, 0.11146, 0.10684, 0.10659, 0.10845, 0.10831, 0.10766, 0.10948, 0