In [20]:
%pip install "git+https://github.com/jemisjoky/TorchMPS.git"
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
%pip install torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from torchmps import MPS

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/jemisjoky/TorchMPS.git
  Cloning https://github.com/jemisjoky/TorchMPS.git to /tmp/pip-req-build-dvn2ub3y
  Running command git clone --filter=blob:none --quiet https://github.com/jemisjoky/TorchMPS.git /tmp/pip-req-build-dvn2ub3y
  Resolved https://github.com/jemisjoky/TorchMPS.git to commit 6c0bc1a8e2c15acba8570ca9ffe2b4a0c7135165
  Preparing metadata (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [21]:
# Hardware hyperparameters
chosen_device = torch.device('cuda' 
if torch.cuda.is_available() else 'cpu')

# Data hyperparameters
nb_train = 400
nb_test = 50
chosen_bs = 100
input_size = 28*28
nb_classes = 10


# Student hyperparameters
# MPS parameters
hidden_size = 70
bond_dim = 20
adaptive_mode = False
periodic_bc = False
feature_map = lambda x : torch.tensor([1, x]).to(chosen_device)
feature_dim = 2
# Training parameters
gn_var = 0.3 #added gaussian noise variance
gn_mean = 0 #added gaussian noise mean
n_epochs_lmps = 20
mps_learn_rate = 0.0001
mps_reg = 0.0
mps_chosen_loss = nn.CrossEntropyLoss().to(chosen_device)

In [22]:
# Import the mnist dataset
train_set = torchvision.datasets.MNIST(root = './datasets', train = True,   
    transform = transforms.ToTensor(),  download = True )

train_subset = torch.utils.data.SubsetRandomSampler(range(nb_train))

train_iterator = torch.utils.data.DataLoader(dataset = train_set, 
    sampler = train_subset, batch_size=chosen_bs)


test_set = torchvision.datasets.MNIST(root = './datasets',
 train = False, transform = transforms.ToTensor(),  download = True)

test_subset = torch.utils.data.SubsetRandomSampler(range(nb_test))

test_iterator = torch.utils.data.DataLoader(dataset = test_set, 
    sampler = test_subset, batch_size = chosen_bs)

In [None]:
class Student(nn.Module):
    def __init__(self):
        super(Student, self).__init__()
        # Initialize the MPS modules
        self.student_mps1 = MPS(
            input_dim = input_size,
            feature_dim = feature_dim,
            output_dim = hidden_size,
            bond_dim = bond_dim,
            adaptive_mode = adaptive_mode,
            periodic_bc = periodic_bc,
        )
        self.student_mps1.register_feature_map(feature_map)

        self.student_mps2 = MPS(
            input_dim = hidden_size, #crucial setting
            feature_dim = feature_dim,
            output_dim = nb_classes,
            bond_dim = bond_dim,
            adaptive_mode = adaptive_mode,
            periodic_bc = periodic_bc,
        )
        self.student_mps2.register_feature_map(feature_map)
        # Make sure that Pytorch finds the parameters of the model
        #self.stud_parameters = nn.ParameterList([self.student_mps1.parameters(), 
                                            #self.student_mps2.parameters()])

    def forward(self, x):
        y = self.student_mps1(x)
        #y = self.student_mps2(y)
        return y


student = Student().to(chosen_device)

# Instantiate the optimizer and softmax
student_optimizer = torch.optim.Adam(student.parameters(), lr = mps_learn_rate,
                                  weight_decay = mps_reg)


softmax = nn.Softmax(dim=1)

# Training loop for the student
for epoch in range(n_epochs_lmps):
    for (x_mb, y_mb) in train_iterator:
        # Reshape, add gaussian noise, put on the chosen device
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        y_mb = y_mb.to(chosen_device)


        # Foward propagation
        y_hat_mb = student(x_mb)
        loss = mps_chosen_loss(y_hat_mb, y_mb)

        # Backpropagation
        loss.backward()
        student_optimizer.step()
        student_optimizer.zero_grad()

    print(loss.item())



# Get the validation set classification accuracy
student_acc_score = 0
student_acc_metric = MulticlassAccuracy(num_classes=nb_classes).to(chosen_device)
for (x_mb, y_mb) in test_iterator:
    x_mb = x_mb.reshape(-1, 784).to(chosen_device)
    y_mb = y_mb.to(chosen_device)
    # add the number of datapoints we classified right
    student_acc_score += x_mb.size()[0] * student_acc_metric( student(x_mb), y_mb )
print("Student_acc_score:")
print(student_acc_score / nb_test) #divide by total size


2.986597776412964
2.715536117553711
2.3862602710723877
2.2160656452178955
2.3184027671813965
2.235750198364258
2.3320345878601074
1.916836142539978
1.6303497552871704
1.6536632776260376
1.0061269998550415
