In [18]:
%pip install "git+https://github.com/jemisjoky/TorchMPS.git"
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
%pip install torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from torchmps import MPS
import math

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/jemisjoky/TorchMPS.git
  Cloning https://github.com/jemisjoky/TorchMPS.git to /tmp/pip-req-build-wwl_14hw
  Running command git clone --filter=blob:none --quiet https://github.com/jemisjoky/TorchMPS.git /tmp/pip-req-build-wwl_14hw
  Resolved https://github.com/jemisjoky/TorchMPS.git to commit 6c0bc1a8e2c15acba8570ca9ffe2b4a0c7135165
  Preparing metadata (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Hyperparameters

In [19]:
# Hardware hyperparameters
chosen_device = torch.device('cuda' 
if torch.cuda.is_available() else 'cpu')

# Data hyperparameters
nb_train_HP = 2000
nb_test_HP = 500
batch_sz_HP = 150
batch_sz_HP = min(batch_sz_HP, nb_train_HP)
nb_classes_HP = 10

# Student hyperparameters
# MPS parameters
bond_dim_HP = 20
#feature_map_HP = lambda x : torch.tensor([math.cos(1.57079*x), 
                        #math.cos(1.57079*x)]).to(chosen_device)
#feature_map_HP = lambda x : torch.tensor([1, x]).to(chosen_device)

# Training parameters
nepochs_student_HP = 25 
student_lr_HP = 1e-3
student_reg_HP = 0
student_loss_HP = nn.CrossEntropyLoss()


# Premilinaries: Importing the data and utils subroutines

In [20]:
# Import the mnist train dataset
train_set = torchvision.datasets.MNIST(
    root = './datasets', train = True,   
    transform = transforms.ToTensor(),  download = True )

# Create a training batch iterator
train_subset = torch.utils.data.SubsetRandomSampler(range(nb_train_HP))
train_iterator = torch.utils.data.DataLoader(
    dataset = train_set, 
    sampler = train_subset, batch_size=batch_sz_HP
    )

# Import the mnist test set
test_set = torchvision.datasets.MNIST(
    root = './datasets',
    train = False, transform = transforms.ToTensor(),  download = True
    )
# Create a testing batch iterator
test_subset = torch.utils.data.SubsetRandomSampler(range(nb_test_HP))
test_iterator = torch.utils.data.DataLoader(
    dataset = test_set, 
    sampler = test_subset, batch_size = batch_sz_HP
    )

In [21]:
# Returns the validation set classification accuracy
# of the given input model (this is a higher order function)
def get_acc(model, iterator):
    # Get the validation set classification accuracy
    total_good_classifications = 0
    acc_metric = MulticlassAccuracy(num_classes=nb_classes_HP).to(chosen_device)
    for (x_mb, y_mb) in iterator:
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        # Add the number of datapoints we classified right to the total
        batch_size = x_mb.size()[0]
        y_hat = model(x_mb)
        batch_good_classifications = batch_size * acc_metric(y_hat, y_mb)
        total_good_classifications += batch_good_classifications
    return total_good_classifications / nb_test_HP # divide by total size

# Training the student model

In [23]:

# Initialize the MPS module
student = MPS(
    input_dim = 28 ** 2,
    output_dim = 10,
    bond_dim = bond_dim_HP
).to(chosen_device)
student.register_feature_map(feature_map_HP)

# Instantiate the optimizer and softmax
student_optimizer = torch.optim.Adam(
    student.parameters(), lr = student_lr_HP, weight_decay = student_reg_HP
)

# Used on the inputs before the loss function
LogSoftmax = nn.LogSoftmax(dim=1)

# Create an array to store the val loss
# of the student at each epoch
stud_test_loss = []
stud_train_loss = []

# Training loop 
for epoch in range(nepochs_student_HP):
    for (x_mb, y_mb) in train_iterator:
        # Flatten the MNIST images, which come in matrix form
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        student_output = nn.Softmax(dim=1)( student(x_mb) )

        # Backpropagation
        loss = student_loss_HP(student_output, y_mb)
        loss.backward()
        student_optimizer.step()
        student_optimizer.zero_grad()

    #stud_train_loss.append( round(get_acc(student, train_iterator).item(), 3) )
    stud_test_loss.append( round(get_acc(student, test_iterator).item(), 5) )
    print(stud_test_loss)
    print(stud_train_loss)

print("Epochs: ", np.arange(1, nepochs_student_HP+1).tolist())
print("Train loss: ", stud_train_loss)
print("Test loss: ", stud_test_loss)


[0.10408]
[]
[0.10408, 0.10176]
[]


KeyboardInterrupt: ignored