In [11]:
!pip install "git+https://github.com/jemisjoky/TorchMPS.git"
!pip install "git+https://github.com/rballester/tntorch.git"
import tntorch as tn
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
!pip install torchmetrics
from torchmetrics.classification import MulticlassAccuracy
from torchmps import MPS

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/jemisjoky/TorchMPS.git
  Cloning https://github.com/jemisjoky/TorchMPS.git to /tmp/pip-req-build-xtsvoqzn
  Running command git clone --filter=blob:none --quiet https://github.com/jemisjoky/TorchMPS.git /tmp/pip-req-build-xtsvoqzn
  Resolved https://github.com/jemisjoky/TorchMPS.git to commit 6c0bc1a8e2c15acba8570ca9ffe2b4a0c7135165
  Preparing metadata (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/rballester/tntorch.git
  Cloning https://github.com/rballester/tntorch.git to /tmp/pip-req-build-dpkuuba0
  Running command git clone --filter=blob:none --quiet https://github.com/rballester/tntorch.git /tmp/pip-req-build-dpkuuba0
  Resolved https://github.com/rballester/tntorch.git to commit 241bf7ad2b806f6677a5e23534247f35f3a70f10
  Preparing

In [12]:
# Get the data

# Data hyperparameters
nb_train = 2000
nb_test = 500
chosen_bs = 150

# Import the mnist dataset
train_set = torchvision.datasets.MNIST(root = './datasets', train = True,    transform = transforms.ToTensor(),  download = True )

train_subset = torch.utils.data.SubsetRandomSampler(range(nb_train))

train_iterator = torch.utils.data.DataLoader(dataset = train_set, sampler = train_subset, batch_size=chosen_bs)


test_set = torchvision.datasets.MNIST(root = './datasets', train = False, transform = transforms.ToTensor(),  download = True)

test_subset = torch.utils.data.SubsetRandomSampler(range(nb_test))

test_iterator = torch.utils.data.DataLoader(dataset = test_set, sampler = test_subset, batch_size = chosen_bs)

In [13]:
# Create and train the FCNN

# Pick the hyperparameters for the FCNN
chosen_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_epochs_fcnn = 15
hidden_size = 70
chosen_loss = nn.CrossEntropyLoss()
#Optimizer parameters
chosen_lr = 0.001
chosen_momentum = 0.9


# Instantiate the fcnn
class FCNN(nn.Module):
    def __init__(self):
        super(FCNN, self).__init__()
        self.lin1 = nn.Linear(784, hidden_size)
        self.relu1 = nn.ReLU()

        self.lin2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()

        self.lin3 = nn.Linear(hidden_size, hidden_size)
        self.relu3 = nn.ReLU()

        self.lin4 = nn.Linear(hidden_size, hidden_size)
        self.relu4 = nn.ReLU()

        self.lin5 = nn.Linear(hidden_size, hidden_size)
        self.relu5 = nn.ReLU()

        self.lin6 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        y = self.lin1(x)
        y = self.relu1(y)

        y = self.lin2(y)
        y = self.relu2(y)

        y = self.lin3(y)
        y = self.relu3(y)

        y = self.lin4(y)
        y = self.relu4(y)

        y = self.lin5(y)
        y = self.relu5(y)

        y = self.lin6(y)
        return y

#Instantiate and put the model on the chosen device
fcnn_teacher = FCNN().to(chosen_device)

#Instantiate the optimizer
optimizer = torch.optim.Adam(fcnn_teacher.parameters())

#Training loop
for epoch in range(n_epochs_fcnn):
    for (x_mb, y_mb) in train_iterator:

        # Reshape the train_tuple and put on the chosen device
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        y_mb = y_mb.to(chosen_device)

        # Foward propagation
        y_hat_mb = fcnn_teacher(x_mb)
        loss = chosen_loss(y_hat_mb, y_mb)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    print(loss.item())

# get accuracy
for (x_mb, y_mb) in test_iterator:
    print("testing teacher")
    print(MulticlassAccuracy(num_classes=10)(fcnn_teacher(x_mb.reshape(-1, 784)), y_mb))



2.273707151412964
2.1048662662506104
1.2480018138885498
0.5959521532058716
0.7049328684806824
0.47792014479637146
0.336759477853775
0.5006894469261169
0.3463068902492523
0.21562650799751282
0.34996670484542847
0.2992750406265259
0.4503014087677002
0.3733038008213043
0.27063241600990295
testing teacher
tensor(0.8276)
testing teacher
tensor(0.7757)
testing teacher
tensor(0.7625)
testing teacher
tensor(0.8464)


In [15]:
# Create and train the MPS
# MPS parameters
bond_dim = 20
adaptive_mode = False
periodic_bc = False

# Training parameters
gn_var = 0.3 #added gaussian noise variance
gn_mean = 0 #added gaussian noise mean
n_epochs_lmps = 18
mps_learn_rate = 0.0001
mps_reg = 0.0
mps_chosen_loss = nn.CrossEntropyLoss().to(chosen_device)

# Initialize the MPS module
mps_student = MPS(
    input_dim=28 ** 2,
    output_dim=10,
    bond_dim=bond_dim,
    adaptive_mode=adaptive_mode,
    periodic_bc=periodic_bc,
).to(chosen_device)

# Instantiate the optimizer and softmax
lmps_optimizer = torch.optim.Adam(mps_student.parameters(), lr = mps_learn_rate,
                                  weight_decay = mps_reg)
softmax = nn.Softmax(dim=1)

# Training loop 
for epoch in range(n_epochs_lmps):
    for (x_mb, _) in train_iterator:
        # Reshape, add gaussian noise, put on the chosen device
        x_mb = x_mb.reshape(-1, 784).to(chosen_device)
        x_mb = (x_mb + torch.randn(size=x_mb.size())).to(chosen_device)

        # Get softmax of output of teacher
        y_mb = (softmax(fcnn_teacher(x_mb))).to(chosen_device)

        # Foward propagation
        y_hat_mb = mps_student(x_mb)
        loss = mps_chosen_loss(y_hat_mb, y_mb)

        # Backpropagation
        loss.backward()
        lmps_optimizer.step()
        lmps_optimizer.zero_grad()

    print(loss.item())




# get accuracy
for (x_mb, y_mb) in test_iterator:
    print("testing teacher")
    print(MulticlassAccuracy(num_classes=10)(mps_student(x_mb.reshape(-1, 784)), y_mb))


2.293426036834717
2.2562286853790283
2.1127593517303467
2.03153133392334
1.8916183710098267
1.6308128833770752
1.4745482206344604
1.3745629787445068
1.2541496753692627
1.0769885778427124
1.0406547784805298
0.9398740530014038
1.0772169828414917
0.8232111930847168
1.0660088062286377
0.8927194476127625
0.7936847805976868
0.7322091460227966
testing teacher
tensor(0.7400)
testing teacher
tensor(0.7869)
testing teacher
tensor(0.7254)
testing teacher
tensor(0.6192)
