In [1]:
%pip install "git+https://github.com/ScierKnave/TorchMPS.git"
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
%pip install torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from torchmps import MPS
import math

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/ScierKnave/TorchMPS.git
  Cloning https://github.com/ScierKnave/TorchMPS.git to /tmp/pip-req-build-03wp9zf7
  Running command git clone --filter=blob:none --quiet https://github.com/ScierKnave/TorchMPS.git /tmp/pip-req-build-03wp9zf7
  Resolved https://github.com/ScierKnave/TorchMPS.git to commit f716a08e15d0af50dbfdfc435ab9604e82562ea3
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torchmps
  Building wheel for torchmps (setup.py) ... [?25l[?25hdone
  Created wheel for torchmps: filename=torchmps-0.1.0-py3-none-any.whl size=59680 sha256=9bdef27a4da62bead27461d2a96bb29c02c24f51af34664a3e74d6a1e56083d6
  Stored in directory: /tmp/pip-ephem-wheel-cache-6iicxqgc/wheels/7c/97/ed/e7cfb0c73fec613f09c365db5a9b684875fdb85e8944240efc
Successfully built torchmps
Installing collected packages: torchmps
Successfully ins

# Hyperparameters

In [2]:
# FC to MPS

# Hardware hyperparameters
chosen_device = torch.device('cuda' 
if torch.cuda.is_available() else 'cpu')

# Data hyperparameters
nb_train_HP = 2000
nb_test_HP = 500
batch_sz_HP = 150
batch_sz_HP = min(batch_sz_HP, nb_train_HP)
nb_classes_HP = 10

# Teacher hyperparameters
nepochs_teacher_HP = 25
teacher_loss_HP = nn.CrossEntropyLoss()
teacher_lr_HP = 1e-2
teacher_reg_HP = 0.01
teacher_hidden_size_HP = 70
# Student hyperparameters
# MPS parameters
bond_dim_HP = 40
adaptive_mode_HP = False
periodic_bc_HP = False

# USING CUSTOM FEATURE MAP WITH FORK?: YES

# Training parameters
nepochs_student_HP = 25 
student_lr_HP = 1e-4
student_reg_HP = 0.01
student_loss_HP = nn.KLDivLoss(reduction = "batchmean", log_target = True)

# Gaussian parameters
gauss_epochs_HP = 0 # number of epochs with added gaussian noise
gn_var_HP = 0.3 #added gaussian noise variance
gn_mean_HP = 0 #added gaussian noise mean
nepochs_student_HP = 25 + gauss_epochs_HP


# Premilinaries: Importing the data and utils subroutines

In [3]:
# Import the mnist train dataset
train_set = torchvision.datasets.MNIST(
    root = './datasets', train = True,   
    transform = transforms.ToTensor(),  download = True )

# Create a training batch iterator
train_subset = torch.utils.data.SubsetRandomSampler(range(nb_train_HP))
train_iterator = torch.utils.data.DataLoader(
    dataset = train_set, 
    sampler = train_subset, batch_size=batch_sz_HP
    )

# Import the mnist test set
test_set = torchvision.datasets.MNIST(
    root = './datasets',
    train = False, transform = transforms.ToTensor(),  download = True
    )
# Create a testing batch iterator
test_subset = torch.utils.data.SubsetRandomSampler(range(nb_test_HP))
test_iterator = torch.utils.data.DataLoader(
    dataset = test_set, 
    sampler = test_subset, batch_size = batch_sz_HP
    )

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 196332256.23it/s]

Extracting ./datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 42728639.80it/s]


Extracting ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 66214995.28it/s]

Extracting ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6755506.66it/s]


Extracting ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw



In [4]:
# Returns the validation set classification accuracy
# of the given input model (this is a higher order function)
def get_acc(model, iterator, dataset_size):
    # Get the validation set classification accuracy
    total_good_classifications = 0
    acc_metric = MulticlassAccuracy(num_classes=nb_classes_HP).to(chosen_device)
    for (x_mb, y_mb) in iterator:
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        # Add the number of datapoints we classified right to the total
        batch_size = x_mb.size()[0]
        y_hat = model(x_mb)
        batch_good_classifications = batch_size * acc_metric(y_hat, y_mb)
        total_good_classifications += batch_good_classifications
    return total_good_classifications / dataset_size # divide by total size

In [5]:
# Create the fcnn class
class FCNN(nn.Module):
    def __init__(self):
        super(FCNN, self).__init__()
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(784, teacher_hidden_size_HP)
        self.lin2 = nn.Linear(teacher_hidden_size_HP, teacher_hidden_size_HP)
        self.lin3 = nn.Linear(teacher_hidden_size_HP, teacher_hidden_size_HP)
        self.lin4 = nn.Linear(teacher_hidden_size_HP, teacher_hidden_size_HP)
        self.lin5 = nn.Linear(teacher_hidden_size_HP, teacher_hidden_size_HP)
        self.lin6 = nn.Linear(teacher_hidden_size_HP, 10)

    def middleforward(self, x):
        y = self.lin1(x)
        y = self.relu(y)
        y = self.lin2(y)
        y = self.relu(y)
        y = self.lin3(y)
        y = self.relu(y)
        return y

    def forward(self, x):
        y = self.lin1(x)
        y = self.relu(y)
        y = self.lin2(y)
        y = self.relu(y)
        y = self.lin3(y)
        y = self.relu(y)
        y = self.lin4(y)
        y = self.relu(y)
        y = self.lin5(y)
        y = self.relu(y)
        y = self.lin6(y)
        y = self.relu(y)
        return y

#Instantiate and put the model on the chosen device
teacher = FCNN().to(chosen_device)

#Instantiate the optimizer
teacher_optimizer = torch.optim.Adam(teacher.parameters())

# Create an array to store the val loss
# of the student at each epoch
teacher_test_loss = []
teacher_train_loss = []

# Training loop 
for epoch in range(nepochs_teacher_HP):
    for (x_mb, y_mb) in train_iterator:
        # Flatten the MNIST images, which come in matrix form
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        y_mb = y_mb.to(chosen_device)

        teacher_output = teacher(x_mb) 

        # Backpropagation
        loss = teacher_loss_HP(teacher_output, y_mb)
        loss.backward()
        teacher_optimizer.step()
        teacher_optimizer.zero_grad()

    #teacher_train_loss.append( round(get_acc(teacher, train_iterator).item(), 3) )
    teacher_test_loss.append( round(get_acc(teacher, test_iterator, nb_test_HP).item(), 5) )

print("Teacher results:")
print("Epochs: ", np.arange(1, nepochs_teacher_HP+1).tolist())
print("Train loss: ", teacher_train_loss)
print("Test loss: ", teacher_test_loss)


Teacher results:
Epochs:  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
Train loss:  []
Test loss:  [0.205, 0.26762, 0.34435, 0.45653, 0.50299, 0.48948, 0.56028, 0.54766, 0.59326, 0.6071, 0.59979, 0.62026, 0.60682, 0.63139, 0.62401, 0.62037, 0.62014, 0.60877, 0.62892, 0.62471, 0.62431, 0.63479, 0.62603, 0.62664, 0.62532]


# Training the student model

In [6]:
def train_a_student():
  '''
  Trains a student model.
  '''
  # Initialize the MPS module
  student = MPS(
      input_dim = 28 ** 2,
      output_dim = 10,
      bond_dim = bond_dim_HP
  ).to(chosen_device)
  #student.register_feature_map(feature_map_HP)

  # Instantiate the optimizer and softmax
  student_optimizer = torch.optim.Adam(
      student.parameters(), lr = student_lr_HP, weight_decay = student_reg_HP
  )

  # Used on the inputs before the loss function
  LogSoftmax = nn.LogSoftmax(dim=1)

  # Create an array to store the val loss
  # of the student at each epoch
  stud_test_loss = np.array([])
  stud_train_loss = np.array([])

  # Training loop 
  for epoch in range(nepochs_student_HP):
      for (x_mb, y_mb) in train_iterator:
          # Flatten the MNIST images, which come in matrix form
          x_mb = x_mb.reshape(-1, 784).to(chosen_device)
          y_mb = y_mb.to(chosen_device)

          # Add Gaussian noise for the gaussian epochs
          if (epoch >= nepochs_student_HP - gauss_epochs_HP):
            x_mb = x_mb + torch.randn(size=x_mb.size()).to(chosen_device)

          student_output = LogSoftmax( student(x_mb) )
          teacher_output = LogSoftmax( teacher(x_mb) )

          # Backpropagation
          loss = student_loss_HP(student_output, teacher_output)
          loss.backward()
          student_optimizer.step()
          student_optimizer.zero_grad()

      # Get accuracy over all test and training data for current epoch
      train_current_accuracy = round( get_acc(student, train_iterator, nb_train_HP).item(), 4)
      test_current_accuracy = round( get_acc(student, test_iterator, nb_test_HP).item(), 4)
      stud_train_loss = np.append(stud_train_loss, train_current_accuracy)
      stud_test_loss = np.append(stud_test_loss, test_current_accuracy)
  return(stud_train_loss, stud_test_loss)

# Repeat the training process in order to get the variance
global_stud_test_loss = np.array([])
global_stud_train_loss = np.array([])
for i in range(20):
  print("\n Training and testing of model ", i+1, " in progress...")
  (stud_train_loss, stud_test_loss) = train_a_student()
  print("Train loss: ", stud_train_loss)
  print("Test loss: ", stud_test_loss)
  if (i == 0):
    global_stud_train_loss = stud_train_loss
    global_stud_test_loss = stud_test_loss
  else:
    global_stud_train_loss = np.vstack((global_stud_train_loss, stud_train_loss))
    global_stud_test_loss = np.vstack((global_stud_test_loss, stud_test_loss))



 Training and testing of model  1  in progress...
Train loss:  [0.1    0.1661 0.4116 0.4846 0.5132 0.5831 0.5956 0.5897 0.6158 0.6241
 0.6378 0.6462 0.6536 0.6482 0.6559 0.6636 0.663  0.6629 0.6602 0.6693
 0.6628 0.6686 0.6674 0.6742 0.6728]
Test loss:  [0.1    0.1655 0.4001 0.429  0.4702 0.5413 0.574  0.5739 0.5837 0.5808
 0.6085 0.6093 0.6184 0.6305 0.6207 0.6401 0.6455 0.6443 0.6186 0.6367
 0.6461 0.645  0.6499 0.6544 0.6391]

 Training and testing of model  2  in progress...
Train loss:  [0.1015 0.119  0.3797 0.4389 0.5075 0.5534 0.5873 0.5957 0.6089 0.628
 0.6298 0.6476 0.6561 0.6587 0.655  0.6627 0.6634 0.6624 0.6648 0.6666
 0.6652 0.6689 0.6718 0.673  0.6745]
Test loss:  [0.1062 0.1071 0.3541 0.3488 0.4645 0.4736 0.5515 0.5282 0.5779 0.5849
 0.5936 0.6208 0.6339 0.6379 0.6384 0.6383 0.6412 0.6505 0.6376 0.6445
 0.6342 0.6474 0.6454 0.6511 0.6529]

 Training and testing of model  3  in progress...
Train loss:  [0.1    0.1    0.1    0.1    0.3046 0.4579 0.5559 0.5874 0.6209 0.631

In [7]:
# Print the final results and save them for the report.
print("Final results")
print("Train loss: ", global_stud_train_loss)
print("Test loss: ", global_stud_test_loss)

np.save('20x_fc_to_mps_trainloss_40bd', stud_train_loss)
np.save('20x_fc_to_mps_testloss_40bd', stud_test_loss)


Final results
Train loss:  [[0.1    0.1661 0.4116 0.4846 0.5132 0.5831 0.5956 0.5897 0.6158 0.6241
  0.6378 0.6462 0.6536 0.6482 0.6559 0.6636 0.663  0.6629 0.6602 0.6693
  0.6628 0.6686 0.6674 0.6742 0.6728]
 [0.1015 0.119  0.3797 0.4389 0.5075 0.5534 0.5873 0.5957 0.6089 0.628
  0.6298 0.6476 0.6561 0.6587 0.655  0.6627 0.6634 0.6624 0.6648 0.6666
  0.6652 0.6689 0.6718 0.673  0.6745]
 [0.1    0.1    0.1    0.1    0.3046 0.4579 0.5559 0.5874 0.6209 0.631
  0.6378 0.6427 0.6422 0.6587 0.6419 0.6505 0.6566 0.6534 0.6635 0.6707
  0.6617 0.6712 0.6717 0.6667 0.6721]
 [0.1195 0.1    0.2273 0.4746 0.5111 0.5704 0.6041 0.5993 0.6085 0.6221
  0.6218 0.6375 0.641  0.6485 0.6427 0.6498 0.6549 0.6628 0.6645 0.6692
  0.6631 0.6715 0.6678 0.6707 0.6717]
 [0.1    0.1    0.1423 0.2196 0.327  0.4411 0.5246 0.5699 0.5893 0.6066
  0.6238 0.6212 0.6427 0.6482 0.6521 0.6557 0.658  0.6571 0.6619 0.6637
  0.669  0.6657 0.667  0.6719 0.6673]
 [0.1    0.1    0.1    0.1    0.1    0.1595 0.3257 0.4644 0.5438 