In [61]:
%pip install "git+https://github.com/jemisjoky/TorchMPS.git"
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
%pip install torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from torchmps import MPS

Defaulting to user installation because normal site-packages is not writeable
Collecting git+https://github.com/jemisjoky/TorchMPS.git
  Cloning https://github.com/jemisjoky/TorchMPS.git to c:\users\piche\appdata\local\temp\pip-req-build-rwogwnt7
  Resolved https://github.com/jemisjoky/TorchMPS.git to commit 6c0bc1a8e2c15acba8570ca9ffe2b4a0c7135165
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Note: you may need to restart the kernel to use updated packages.


  Running command git clone --filter=blob:none --quiet https://github.com/jemisjoky/TorchMPS.git 'C:\Users\piche\AppData\Local\Temp\pip-req-build-rwogwnt7'


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


# Hyperparameters

In [62]:
# Hardware hyperparameters
chosen_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: seed PyTorch's RNGs (subset sampling order, weight
# initialization and the gaussian input noise all draw from them)
seed_HP = 0
torch.manual_seed(seed_HP)

# Data hyperparameters
# NOTE(review): these sizes are tiny (2 train / 5 test images), presumably for a
# quick smoke test; increase them for meaningful accuracy numbers.
nb_train_HP = 2
nb_test_HP = 5
# Requested minibatch size, capped so it never exceeds the training-set size
batch_sz_HP = 150
batch_sz_HP = min(batch_sz_HP, nb_train_HP)
nb_classes_HP = 10

# Teacher hyperparameters
nteacher_epochs_HP = 15
hidden_size_HP = 70
teacher_loss_HP = nn.CrossEntropyLoss()
# Optimizer parameters
# NOTE(review): the teacher training cell builds Adam without passing this value,
# so Adam currently runs with its default learning rate.
teacher_lr_HP = 0.0001

# Student hyperparameters
# MPS parameters
bond_dim_HP = 20
adaptive_mode_HP = False
periodic_bc_HP = False
# Maps an input value x to the 2-vector [1, x]
# (presumably called by TorchMPS on one scalar pixel at a time — confirm)
feature_map_HP = lambda x : torch.tensor([1, x]).to(chosen_device)
# Training parameters
nepochs_student_HP = 15
student_lr_HP = 0.0001
student_reg_HP = 0.0  # passed to Adam as weight_decay
# KL divergence is a well-suited loss for comparing the student's and the
# teacher's output distributions; log_target=True means the target is supplied
# as log-probabilities, just like the input
student_loss_HP = nn.KLDivLoss(reduction="batchmean", log_target=True)
# Gaussian parameters
ngauss_epochs_HP = 5  # number of final student epochs trained with added gaussian noise
gn_var_HP = 0.3  # added gaussian noise variance
gn_mean_HP = 0  # added gaussian noise mean

# Preliminaries: importing the data and utility subroutines

In [63]:
# Load the MNIST training split (downloaded into ./datasets on first use);
# ToTensor converts each image to a tensor
train_set = torchvision.datasets.MNIST(
    root='./datasets',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Restrict training to the first nb_train_HP images, visited in random order
train_subset = torch.utils.data.SubsetRandomSampler(range(nb_train_HP))
train_iterator = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_sz_HP,
    sampler=train_subset,
)

# Same setup for the held-out test split, restricted to its first nb_test_HP images
test_set = torchvision.datasets.MNIST(
    root='./datasets',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
test_subset = torch.utils.data.SubsetRandomSampler(range(nb_test_HP))
test_iterator = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_sz_HP,
    sampler=test_subset,
)

In [64]:
# Returns the validation set classification accuracy
# of the given input model (this is a higher order function)
def get_val_acc(model):
    # Get the validation set classification accuracy
    total_good_classifications = 0
    acc_metric = MulticlassAccuracy(num_classes=nb_classes_HP).to(chosen_device)
    for (x_mb, y_mb) in test_iterator:
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        # Add the number of datapoints we classified right to the total
        batch_good_classifications = x_mb.size()[0] * acc_metric( model(x_mb), y_mb)
        total_good_classifications += batch_good_classifications
    return total_good_classifications / nb_test_HP # divide by total size

# Training the teacher model

In [65]:

# Create the fcnn class
class FCNN(nn.Module):
    def __init__(self):
        super(FCNN, self).__init__()
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(784, hidden_size_HP)
        self.lin2 = nn.Linear(hidden_size_HP, hidden_size_HP)
        self.lin3 = nn.Linear(hidden_size_HP, hidden_size_HP)
        self.lin4 = nn.Linear(hidden_size_HP, hidden_size_HP)
        self.lin5 = nn.Linear(hidden_size_HP, hidden_size_HP)
        self.lin6 = nn.Linear(hidden_size_HP, 10)

    def forward(self, x):
        y = self.lin1(x)
        y = self.relu(y)
        y = self.lin2(y)
        y = self.relu(y)
        y = self.lin3(y)
        y = self.relu(y)
        y = self.lin4(y)
        y = self.relu(y)
        y = self.lin5(y)
        y = self.relu(y)
        y = self.lin6(y)
        y = self.relu(y)
        return y

# Build the teacher network on the selected device
teacher = FCNN().to(chosen_device)

# Adam over every teacher parameter.
# NOTE(review): no lr is passed, so Adam uses its default learning rate and
# teacher_lr_HP is not consulted here — confirm this is intended.
teacher_optimizer = torch.optim.Adam(teacher.parameters())

# Rounded validation accuracy after each epoch (the list is named "loss",
# but get_val_acc returns an accuracy)
teacher_val_loss = []

for epoch in range(nteacher_epochs_HP):
    for x_mb, y_mb in train_iterator:
        # Flatten images to 784-vectors; move inputs and labels to the device
        x_mb, y_mb = x_mb.reshape(-1, 784).to(chosen_device), y_mb.to(chosen_device)
        # Forward pass, backprop of the loss, one optimizer step, then clear gradients
        y_hat_mb = teacher(x_mb)
        teacher_loss_HP(y_hat_mb, y_mb).backward()
        teacher_optimizer.step()
        teacher_optimizer.zero_grad()

    val_score = round(get_val_acc(teacher).item(), 3)
    teacher_val_loss.append(val_score)

print("Teacher accuracy through epochs")
print(list(range(1, nteacher_epochs_HP + 1)))
print(teacher_val_loss)


Teacher accuracy through epochs
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


# Training the student model

In [66]:

# Initialize the MPS student module (one input site per pixel of a 28x28 image)
student = MPS(
    input_dim = 28 ** 2,
    output_dim = nb_classes_HP,
    bond_dim = bond_dim_HP,
    adaptive_mode = adaptive_mode_HP,
    periodic_bc = periodic_bc_HP,
).to(chosen_device)
student.register_feature_map(feature_map_HP)

# Adam optimizer; weight_decay applies the student regularization strength
student_optimizer = torch.optim.Adam(
    student.parameters(), lr = student_lr_HP, weight_decay = student_reg_HP
    )

# nn.KLDivLoss expects its input as log-probabilities, and (log_target=True)
# its target as log-probabilities too, so both logit sets go through LogSoftmax
LogSoftmax = nn.LogSoftmax(dim=1)

# Rounded validation accuracy of the student after each epoch
stud_val_loss = []

# The last ngauss_epochs_HP epochs (0-based epoch >= this index) see noisy inputs
first_noisy_epoch = nepochs_student_HP - ngauss_epochs_HP

# Training loop
for epoch in range(nepochs_student_HP):

    for (x_mb, _) in train_iterator:

        # Flatten the MNIST images, which come in matrix form
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)

        # Additive N(gn_mean_HP, gn_var_HP) noise for the final noisy epochs.
        # randn_like keeps the noise on x_mb's device/dtype (CPU noise added to
        # a CUDA tensor would raise); std = sqrt(variance).
        if epoch >= first_noisy_epoch:
            x_mb = x_mb + gn_mean_HP + (gn_var_HP ** 0.5) * torch.randn_like(x_mb)

        # The teacher is frozen: don't build its autograd graph, so no gradients
        # flow into (or accumulate on) its parameters
        with torch.no_grad():
            teacher_output = LogSoftmax(teacher(x_mb))
        student_output = LogSoftmax(student(x_mb))

        # Backpropagation: KL(teacher || student) on the minibatch
        student_loss_HP(student_output, teacher_output).backward()
        student_optimizer.step()
        student_optimizer.zero_grad()

    val_score = get_val_acc(student).item()
    val_score = round(val_score, 3)
    stud_val_loss.append( val_score )

print("Student accuracy through epochs")
print(np.arange(1, nepochs_student_HP+1).tolist())
print(stud_val_loss)


Student accuracy through epochs
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04, 0.04, 0.04]
