In [None]:
!pip install "git+https://github.com/jemisjoky/TorchMPS.git"
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
!pip install torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from torchmps import MPS

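We first choose the hyperparameters of our setup: hardware, data, the patch-extraction kernel, and the student MPS together with its training settings.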

In [53]:
# Hardware hyperparameters
chosen_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: torch.manual_seed seeds torch's global RNG (CPU and CUDA),
# which drives the SubsetRandomSampler shuffling and the MPS/Linear initialisation.
seed = 0
torch.manual_seed(seed)

# Data hyperparameters
data_dim = [1,28,28] # size of the individual input images (channel size, dims)
nb_train = 2000   # number of training images used (indices 0 .. nb_train-1)
nb_test = 500     # number of test images used (indices 0 .. nb_test-1)
chosen_bs = 100   # mini-batch size
nb_classes = 10

# Kernel parameters (patch extraction with nn.Unfold)
# For 28x28 images with kernel_size=10, stride=4, padding=0 there are
# floor((28 - 10) / 4) + 1 = 5 patch positions per axis (25 patches), and the
# last (28 - 10) % 4 = 2 rows/columns are not covered by any patch.
kernel_size = 10
padding = 0
stride = 4

# Student training hyperparameters
# general settings
n_epochs_lmps = 5
mps_learn_rate = 0.0001
mps_reg = 0.0   # passed to Adam as weight_decay
student_loss = nn.CrossEntropyLoss().to(chosen_device)
# mps parameters
# one MPS site per value of a flattened patch (channels * kernel_size**2)
input_dim = data_dim[0] * kernel_size**2
output_dim = 1     # one scalar score per patch
feature_dim = 2    # length of the per-pixel feature vector Φ(x) = [1, x]
bond_dim = 20
adaptive_mode = False
periodic_bc = False
# gaussian noise parameters (not used by any of the cells in this notebook section)
gn_var = 0.3
gn_mean = 0


Here, we load the MNIST dataset and restrict training and testing to small subsets of it.

In [54]:
# Load MNIST into ./datasets (downloaded on first use); transforms.ToTensor
# turns every image into a tensor of shape data_dim = [1, 28, 28].
train_set = torchvision.datasets.MNIST(
    root='./datasets',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
# Only training indices 0 .. nb_train-1 are used, drawn in random order.
train_subset = torch.utils.data.SubsetRandomSampler(range(nb_train))
train_iterator = torch.utils.data.DataLoader(
    dataset=train_set,
    sampler=train_subset,
    batch_size=chosen_bs,
)

test_set = torchvision.datasets.MNIST(
    root='./datasets',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
# Likewise, only test indices 0 .. nb_test-1 are evaluated.
test_subset = torch.utils.data.SubsetRandomSampler(range(nb_test))
test_iterator = torch.utils.data.DataLoader(
    dataset=test_set,
    sampler=test_subset,
    batch_size=chosen_bs,
)

In [None]:

class LMPS_patcher(nn.Module):
    """Patch-based MPS classifier.

    Each input image is cut into square patches with ``nn.Unfold``; every patch
    is flattened and scored by one shared MPS (``self.mps_student``), giving one
    scalar per patch. A linear layer maps the ``nb_patches`` scores to
    ``nb_classes`` logits.

    All sizes come from the notebook-level hyperparameters (``input_dim``,
    ``feature_dim``, ``output_dim``, ``bond_dim``, ``adaptive_mode``,
    ``periodic_bc``, ``kernel_size``, ``padding``, ``stride``, ``data_dim``,
    ``nb_classes``, ``chosen_device``), which must be defined before the class
    is instantiated.

    Raises:
        ValueError: if the flattened patch length differs from ``input_dim``.
    """

    def __init__(self):
        super(LMPS_patcher, self).__init__()

        # One MPS, shared by all patches, mapping a flattened patch to a score
        self.mps_student = MPS(
            input_dim=input_dim,
            feature_dim=feature_dim,
            output_dim=output_dim,
            bond_dim=bond_dim,
            adaptive_mode=adaptive_mode,
            periodic_bc=periodic_bc,
        ).to(chosen_device)
        # Our feature map Φ(x) = [1, x], which creates the multilinear feature
        # space. torch.stack keeps the result on x's device and dtype instead of
        # building a fresh tensor from Python values for every pixel.
        self.feature_map = lambda x: torch.stack([torch.ones_like(x), x])
        # Make the MPS embed its inputs with our feature map
        self.mps_student.register_feature_map(self.feature_map)

        # Unfolder that extracts the patches. Its output has shape
        # [batch_size, channels * kernel_size**2, nb_patches].
        self.folding_params = dict(kernel_size=kernel_size, dilation=1,
                                   padding=padding, stride=stride)
        self.unfolder = torch.nn.Unfold(**self.folding_params)

        # Read the patch layout off a dummy image rather than re-deriving the
        # unfold arithmetic by hand.
        fake_image = torch.zeros([1] + list(data_dim))
        unfold_output = self.unfolder(fake_image)
        self.patch_size = unfold_output.size(1)
        self.nb_patches = unfold_output.size(2)
        # Fail early on a mismatched configuration (e.g. multi-channel images)
        # instead of crashing inside the MPS during the first forward pass.
        if self.patch_size != input_dim:
            raise ValueError(
                f"flattened patch length {self.patch_size} "
                f"(channels * kernel_size**2) does not match MPS input_dim {input_dim}"
            )

        # Fully connected layer: per-patch scores -> class logits
        self.lin = nn.Linear(self.nb_patches, nb_classes)

    def forward(self, x):
        """Score every patch of ``x`` with the MPS and map the scores to logits.

        Args:
            x: image batch; assumed shape [batch_size, *data_dim].

        Returns:
            Tensor of shape [batch_size, nb_classes] with raw logits. No
            activation is applied: ``nn.CrossEntropyLoss`` expects unnormalised
            logits, and a ReLU here would clamp every negative logit to 0 and
            block its gradient.
        """
        batch_size = x.size(0)
        # [batch, patch_size, nb_patches] -> [batch, nb_patches, patch_size]
        # -> [batch * nb_patches, patch_size]; row b * nb_patches + p holds
        # patch p of image b.
        patches = self.unfolder(x).transpose(1, 2).reshape(-1, self.patch_size)
        # One batched MPS call instead of one call per (image, patch) pair.
        # reshape accepts an MPS output of either [N, 1] or [N] for output_dim=1.
        mps_outputs = self.mps_student(patches).reshape(batch_size, self.nb_patches)
        return self.lin(mps_outputs)


# Build the student model on the chosen device.
student = LMPS_patcher().to(chosen_device)

# Adam over every trainable parameter (MPS cores and the final linear layer);
# mps_reg acts as Adam's weight_decay.
student_optimizer = torch.optim.Adam(
    params=student.parameters(),
    lr=mps_learn_rate,
    weight_decay=mps_reg,
)

# Softmax over the class dimension (not used by the training loop below).
softmax = nn.Softmax(dim=1)

# Training loop: one pass over train_iterator per epoch, reporting the mean
# mini-batch loss of each epoch rather than only the final batch's loss.
student.train()
for epoch in range(n_epochs_lmps):
    running_loss = 0.0
    nb_batches = 0
    for x_mb, y_mb in train_iterator:
        # Put the mini-batch on the chosen device
        x_mb = x_mb.to(chosen_device)
        y_mb = y_mb.to(chosen_device)

        # Clear gradients first so that re-running this cell after an
        # interrupted run never accumulates stale gradients.
        student_optimizer.zero_grad()
        # Forward propagation
        y_hat_mb = student(x_mb)
        loss = student_loss(y_hat_mb, y_mb)
        # Backpropagation
        loss.backward()
        student_optimizer.step()

        running_loss += loss.item()
        nb_batches += 1

    if nb_batches == 0:
        raise RuntimeError("train_iterator yielded no batches; check nb_train and chosen_bs")
    print(f"epoch {epoch + 1}/{n_epochs_lmps}: mean train loss {running_loss / nb_batches:.4f}")


# Test accuracy: accumulate predictions over the whole test subset and report a
# single value (MulticlassAccuracy's default averaging) instead of one value per
# batch. No gradients are needed for evaluation.
student.eval()
metric = MulticlassAccuracy(num_classes=nb_classes).to(chosen_device)
with torch.no_grad():
    for x_mb, y_mb in test_iterator:
        x_mb = x_mb.to(chosen_device)
        y_mb = y_mb.to(chosen_device)
        # logits of shape [batch, nb_classes]; the metric takes their argmax
        metric.update(student(x_mb), y_mb)
print(metric.compute())
