In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from sklearn.mixture import GaussianMixture
from torch.utils.data import DataLoader
from utils import MNIST_partial, accuracy, plot_training_metrics


In [2]:

# Hyperparameters
batch_size = 128  # mini-batch size used by both DataLoaders below
learning_rate = 1e-4
num_epochs = 50
shared_rows = 10  # number of unique trainable weight rows in the weight-shared FC layer

# dataset from csv file, to use for the challenge
train_dataset = MNIST_partial(split='train')
val_dataset = MNIST_partial(split='val')

# definition of the dataloaders, to process the data in the model
# NOTE(review): an earlier comment here said a batch size of 1 is needed for the
# boson sampler; no boson sampler appears in the visible code, and the loaders
# have always been built with the batch size set above.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


class SharedWeightFC(nn.Module):
    """Fully connected layer whose weight matrix is built from a few shared rows.

    Only ``shared_rows`` weight rows (each of length ``in_features``) are
    trainable. The full ``(out_features, in_features)`` weight matrix is
    obtained by tiling that block, so output unit ``j`` uses shared row
    ``j % shared_rows``. Every output unit keeps its own bias.
    """

    def __init__(self, in_features, out_features, shared_rows):
        """
        A fully connected layer with static weight sharing.

        Parameters:
        - in_features: Size of the input.
        - out_features: Size of the output.
        - shared_rows: Number of unique rows to use for sharing weights.
          Must be a positive divisor of ``out_features``.

        Raises:
        - ValueError: if ``shared_rows`` is not a positive divisor of
          ``out_features`` (the tiled matrix could not have exactly
          ``out_features`` rows).
        """
        super().__init__()

        # Fail fast: without this check an invalid value only surfaces later in
        # forward() as a shape mismatch or a division-by-zero error.
        if shared_rows < 1 or out_features % shared_rows != 0:
            raise ValueError(
                f"shared_rows ({shared_rows}) must be a positive divisor of "
                f"out_features ({out_features})"
            )

        # Initialize a small set of trainable rows
        self.shared_weights = nn.Parameter(torch.randn(shared_rows, in_features))

        # Bias term for the fully connected layer (one entry per output unit)
        self.bias = nn.Parameter(torch.randn(out_features))

        # Sizes needed to rebuild the full weight matrix in forward()
        self.out_features = out_features
        self.shared_rows = shared_rows

    def forward(self, x):
        """Apply the layer to ``x`` (last dimension ``in_features``).

        Returns a tensor whose last dimension is ``out_features``.
        """
        # Tile the shared rows to build the full weight matrix
        repeats = self.out_features // self.shared_rows
        weight_matrix = self.shared_weights.repeat(repeats, 1)

        # Matrix multiplication followed by bias addition
        return torch.matmul(x, weight_matrix.t()) + self.bias

# CNN classifier whose first fully connected layer uses shared weight rows
class CNNModel(nn.Module):
    """Two conv + max-pool stages followed by two fully connected layers.

    ``fc1`` is a ``SharedWeightFC`` built with ``shared_rows`` unique rows;
    ``fc2`` is a regular ``nn.Linear`` producing 10 outputs. No activation
    function is applied between the layers.
    """

    def __init__(self, shared_rows):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=12, kernel_size=5)

        # Flattened feature size expected by fc1: 12 channels of 4x4
        # (consistent with 1x28x28 inputs — TODO confirm the dataset's image size)
        flat_features = 12 * 4 * 4
        self.fc1 = SharedWeightFC(
            in_features=flat_features,
            out_features=20,
            shared_rows=shared_rows,
        )
        # Plain linear head; a weight-shared variant of this layer is currently disabled
        self.fc2 = nn.Linear(20, 10)

    def forward(self, x):
        """Return the 10 class scores for a batch of images ``x``."""
        stage1 = self.pool(self.conv1(x))
        stage2 = self.pool(self.conv2(stage1))
        flat = torch.flatten(stage2, start_dim=1)  # [N, 12 * 4 * 4]
        hidden = self.fc1(flat)
        return self.fc2(hidden)


In [3]:

# Use the first GPU when CUDA is available, otherwise run on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Instantiate the model and move its parameters to the selected device
model = CNNModel(shared_rows=shared_rows).to(device)

# Loss function (cross-entropy on the model outputs) and Adam optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Training loop
def train(model, train_loader, optimizer, criterion, epochs=None, device=None):
    """Train ``model`` in place and print the loss every 10 steps.

    Parameters:
    - model: the ``nn.Module`` to optimise.
    - train_loader: iterable of ``(images, labels)`` batches; must support ``len()``
      for the progress message.
    - optimizer: optimizer already bound to ``model``'s parameters.
    - criterion: loss callable invoked as ``criterion(outputs, labels)``.
    - epochs: number of passes over ``train_loader``. Defaults to the
      notebook-level ``num_epochs`` when omitted.
    - device: device the batches are moved to. Defaults to the device of the
      model's first parameter (CPU for a parameter-free model), so the function
      no longer depends on a global ``device`` variable.

    Returns ``None``.
    """
    total_epochs = num_epochs if epochs is None else epochs

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(total_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)  # Move data to the model's device

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if (i+1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{total_epochs}], Step [{(i+1)}/{len(train_loader)}], Loss: {loss.item():.4f}')


# Total number of trainable parameters in the model (shown as the cell output)
sum(param.numel() for param in model.parameters() if param.requires_grad)

4770

In [4]:

# Train the model for num_epochs passes over train_loader (loss is printed every 10 steps)
train(model, train_loader, optimizer, criterion)


Epoch [1/50], Step [10/47], Loss: 3.1089
Epoch [1/50], Step [20/47], Loss: 2.8428
Epoch [1/50], Step [30/47], Loss: 2.5481
Epoch [1/50], Step [40/47], Loss: 2.4872
Epoch [2/50], Step [10/47], Loss: 2.4965
Epoch [2/50], Step [20/47], Loss: 2.4975
Epoch [2/50], Step [30/47], Loss: 2.3268
Epoch [2/50], Step [40/47], Loss: 2.3264
Epoch [3/50], Step [10/47], Loss: 2.1255
Epoch [3/50], Step [20/47], Loss: 2.1246
Epoch [3/50], Step [30/47], Loss: 2.0580
Epoch [3/50], Step [40/47], Loss: 1.9900
Epoch [4/50], Step [10/47], Loss: 2.0556
Epoch [4/50], Step [20/47], Loss: 1.8985
Epoch [4/50], Step [30/47], Loss: 1.8953
Epoch [4/50], Step [40/47], Loss: 1.8339
Epoch [5/50], Step [10/47], Loss: 1.5766
Epoch [5/50], Step [20/47], Loss: 1.7068
Epoch [5/50], Step [30/47], Loss: 1.6474
Epoch [5/50], Step [40/47], Loss: 1.7643
Epoch [6/50], Step [10/47], Loss: 1.5845
Epoch [6/50], Step [20/47], Loss: 1.5667
Epoch [6/50], Step [30/47], Loss: 1.5361
Epoch [6/50], Step [40/47], Loss: 1.5619
Epoch [7/50], St

In [5]:

# Test the model
def test(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)  # Move data to GPU

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')

# Evaluate the trained model on the validation loader and print its accuracy
test(model, val_loader)

Test Accuracy: 90.33%


In [6]:
# Total number of trainable parameters in the model (shown as the cell output)
sum(param.numel() for param in model.parameters() if param.requires_grad)

4770

In [7]:
# for i in model.fc1.parameters():
#     print(i)

In [8]:
# shared_rows settings swept: 1, 2, 4, 5, 10 (the parameter counts below are listed in the same order)
# weight_sharing_parameter_num_list = [3042, 3234, 3618, 3810, 4770] 

# weight_sharing_test_acc_list = [
#                                 [56.25, 89.43, 95.58, 96.95, 97.61],
#                                 [57.55, 86.67, 95.21, 96.27, 97.86],
#                                 [59.63, 88.32, 95.71, 96.51, 97.92]
#                                 ]

# Trainable-parameter counts of the model, one per weight-sharing setting
weight_sharing_parameter_num_list = [
    3042,
    3234,
    3618,
    3810,
    4770,
]

# Recorded test accuracies (%); presumably one inner list per repeated run.
# NOTE(review): each inner list holds a single value while five parameter
# counts are listed above — confirm whether the remaining entries are pending.
weight_sharing_test_acc_list = [
    [37.00],
    [38.33],
    [38.83],
]