In [6]:
%pip install "git+https://github.com/jemisjoky/TorchMPS.git"
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
%pip install torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from torchmps import MPS

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/jemisjoky/TorchMPS.git
  Cloning https://github.com/jemisjoky/TorchMPS.git to /tmp/pip-req-build-s5lxgzb8
  Running command git clone --filter=blob:none --quiet https://github.com/jemisjoky/TorchMPS.git /tmp/pip-req-build-s5lxgzb8
  Resolved https://github.com/jemisjoky/TorchMPS.git to commit 6c0bc1a8e2c15acba8570ca9ffe2b4a0c7135165
  Preparing metadata (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [7]:
# Hardware hyperparameters
chosen_device = torch.device('cuda' 
if torch.cuda.is_available() else 'cpu')

# Data hyperparameters
nb_train = 2000
nb_test = 500
chosen_bs = 150
nb_classes = 10

# Teacher hyperparameters
chosen_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_epochs_fcnn = 15
hidden_size = 70
chosen_loss = nn.CrossEntropyLoss()
#Optimizer parameters
chosen_lr = 0.0001
chosen_momentum = 0.9

# Student hyperparameters
# MPS parameters
bond_dim = 20
adaptive_mode = False
periodic_bc = False
feature_map = lambda x : torch.tensor([1, x]).to(chosen_device)
# Training parameters
gauss_epochs = 5 # number of epochs with added gaussian noise
gn_var = 0.3 #added gaussian noise variance
gn_mean = 0 #added gaussian noise mean
n_epochs_lmps = 15
mps_learn_rate = 0.0001
mps_reg = 0.0
mps_chosen_loss = nn.CrossEntropyLoss().to(chosen_device)

In [8]:
# Import the mnist dataset
train_set = torchvision.datasets.MNIST(root = './datasets', train = True,   
    transform = transforms.ToTensor(),  download = True )

train_subset = torch.utils.data.SubsetRandomSampler(range(nb_train))

train_iterator = torch.utils.data.DataLoader(dataset = train_set, 
    sampler = train_subset, batch_size=chosen_bs)


test_set = torchvision.datasets.MNIST(root = './datasets',
 train = False, transform = transforms.ToTensor(),  download = True)

test_subset = torch.utils.data.SubsetRandomSampler(range(nb_test))

test_iterator = torch.utils.data.DataLoader(dataset = test_set, 
    sampler = test_subset, batch_size = chosen_bs)

In [9]:

# Create the fcnn class
class FCNN(nn.Module):
    def __init__(self):
        super(FCNN, self).__init__()
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(784, hidden_size)
        self.lin2 = nn.Linear(hidden_size, hidden_size)
        self.lin3 = nn.Linear(hidden_size, hidden_size)
        self.lin4 = nn.Linear(hidden_size, hidden_size)
        self.lin5 = nn.Linear(hidden_size, hidden_size)
        self.lin6 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        y = self.lin1(x)
        y = self.relu(y)
        y = self.lin2(y)
        y = self.relu(y)
        y = self.lin3(y)
        y = self.relu(y)
        y = self.lin4(y)
        y = self.relu(y)
        y = self.lin5(y)
        y = self.relu(y)
        y = self.lin6(y)
        y = self.relu(y)
        return y

#Instantiate and put the model on the chosen device
fcnn_teacher = FCNN().to(chosen_device)

#Instantiate the optimizer
optimizer = torch.optim.Adam(fcnn_teacher.parameters())

#Training loop
for epoch in range(n_epochs_fcnn):
    for (x_mb, y_mb) in train_iterator:
        # Reshape the train_tuple and put on the chosen device
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        # Foward propagation
        y_hat_mb = fcnn_teacher(x_mb)
        loss = chosen_loss(y_hat_mb, y_mb)
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    print(loss.item())

# Get the validation set classification accuracy
teacher_acc_score = 0
teacher_acc_metric = MulticlassAccuracy(num_classes=nb_classes).to(chosen_device)
for (x_mb, y_mb) in test_iterator:
    x_mb = x_mb.reshape(-1, 784).to(chosen_device)
    y_mb = y_mb.to(chosen_device)
    # add the number of datapoints we classified right
    teacher_acc_score += x_mb.size()[0] * teacher_acc_metric( fcnn_teacher(x_mb), y_mb )
print("The teacher's accuracy score is:")
print(teacher_acc_score / nb_test) #divide by total size



2.290205955505371
2.2128565311431885
1.8400448560714722
1.2573946714401245
The teacher's accuracy score is:
tensor(0.3920)


In [10]:

# Initialize the MPS module
mps_student = MPS(
    input_dim=28 ** 2,
    output_dim=10,
    bond_dim=bond_dim,
    adaptive_mode=adaptive_mode,
    periodic_bc=periodic_bc,
).to(chosen_device)
mps_student.register_feature_map(feature_map)

# Instantiate the optimizer and softmax
lmps_optimizer = torch.optim.Adam(mps_student.parameters(), lr = mps_learn_rate,
                                  weight_decay = mps_reg)
softmax = nn.Softmax(dim=1)

# Training loop 
for epoch in range(n_epochs_lmps):
    for (x_mb, _) in train_iterator:
        # Reshape, add gaussian noise, put on the chosen device
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        if epoch > (n_epochs_lmps - gauss_epochs):
          x_mb = (x_mb + torch.randn(size=x_mb.size())).to(chosen_device)

        # Get softmax of output of teacher
        y_mb = (softmax(fcnn_teacher(x_mb))).to(chosen_device)

        # Foward propagation
        y_hat_mb = mps_student(x_mb)
        loss = mps_chosen_loss(y_hat_mb, y_mb)

        # Backpropagation
        loss.backward()
        lmps_optimizer.step()
        lmps_optimizer.zero_grad()

    print(loss.item())


student = mps_student

# Get the validation set classification accuracy
student_acc_score = 0
student_acc_metric = MulticlassAccuracy(num_classes=nb_classes).to(chosen_device)
for (x_mb, y_mb) in test_iterator:
    x_mb = x_mb.reshape(-1, 784).to(chosen_device)
    y_mb = y_mb.to(chosen_device)
    # add the number of datapoints we classified right
    student_acc_score += x_mb.size()[0] * student_acc_metric( student(x_mb), y_mb )
print("Student_acc_score:")
print(student_acc_score / nb_test) #divide by total size


2.297011375427246
2.301151752471924
2.2917494773864746
2.204847574234009


KeyboardInterrupt: ignored