In [1]:
# %load_ext tensorboard
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import copy
import random
import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
# from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, utils, datasets
# from torchsummary import summary


# Check assigned GPU
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

# set manual seed for reproducibility
seed = 42

# general reproducibility
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# gpu training specific
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Sat Mar  4 12:48:20 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P0    23W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [94]:
# Preprocessing pipelines. The paper mentions no special augmentation, so we
# only convert to tensors and normalise; MNIST is additionally resized to 32
# pixels so both datasets feed the same 32x32 LeNet input.
transforms_mnist = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

transforms_cifar10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Download (if needed) and wrap the train split first, then the test split.
mnist_data_train, mnist_data_test = (
    datasets.MNIST('./data/mnist/', train=split_is_train, download=True, transform=transforms_mnist)
    for split_is_train in (True, False)
)

cifar10_data_train, cifar10_data_test = (
    datasets.CIFAR10('./data/cifar10/', train=split_is_train, download=True, transform=transforms_cifar10)
    for split_is_train in (True, False)
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar10/cifar-10-python.tar.gz to ./data/cifar10/
Files already downloaded and verified


## Models

In [95]:
class LeNet(nn.Module):
    """LeNet-5-style CNN for 32x32 inputs and 10 output classes.

    Layout: conv(5x5, 6) -> ReLU -> maxpool(2) -> conv(5x5, 16) -> ReLU
    -> maxpool(2) -> flatten (16 * 5 * 5 = 400) -> fc(120) -> ReLU
    -> fc(84) -> ReLU -> fc(10).

    Args:
        data: ``'mnist'`` for 1-channel input or ``'cifar10'`` for 3-channel
            input. Inputs must be 32x32 so the flattened feature map has the
            400 features ``linear1`` expects.

    Raises:
        ValueError: if ``data`` is not a supported dataset name.

    ``forward`` returns raw, unnormalised logits of shape (batch, 10), which
    is the input ``nn.CrossEntropyLoss`` expects.
    """

    # Number of input channels for each supported dataset.
    _IN_CHANNELS = {'mnist': 1, 'cifar10': 3}

    def __init__(self, data):
        super().__init__()
        if data not in self._IN_CHANNELS:
            # Previously an unknown name silently left conv1 undefined and
            # failed later with an AttributeError in forward().
            raise ValueError(f"unsupported dataset {data!r}; expected 'mnist' or 'cifar10'")
        self.conv1 = nn.Conv2d(self._IN_CHANNELS[data], 6, kernel_size=5, stride=1, padding=0)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(400, 120)
        self.linear2 = nn.Linear(120, 84)
        self.linear3 = nn.Linear(84, 10)
        # Parameter-free activation shared by every hidden layer. Without it
        # the stacked Linear layers collapse into a single affine map.
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.maxpool1(self.relu(self.conv1(x)))
        x = self.maxpool2(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        # Return logits directly: a sigmoid here would squash them into (0, 1),
        # making CrossEntropyLoss's softmax nearly uniform and gradients tiny.
        return self.linear3(x)

In [97]:
def train(model, epochs, train_dl, optimizer, loss_fn, lr):
    """Train ``model`` for ``epochs`` passes over ``train_dl``.

    Args:
        model: module to optimise; switched to training mode every epoch.
        epochs: number of full passes over ``train_dl``.
        train_dl: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        optimizer: optimiser over ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, labels)`` returning a scalar
            loss tensor.
        lr: unused; the learning rate comes from ``optimizer``. Kept so
            existing calls keep working.

    Returns:
        A list with the mean per-batch training loss of each epoch.

    Raises:
        ValueError: if ``train_dl`` yields no batches during an epoch.
    """
    # Send batches to wherever the model's parameters live instead of
    # assuming CUDA whenever it is available (a CPU model would then crash).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    train_losses = []
    start = time.time()

    for _ in tqdm(range(epochs)):
        model.train()
        running_loss = 0.0
        n_batches = 0
        for image, label in train_dl:
            image, label = image.to(device), label.to(device)

            optimizer.zero_grad()
            y_pred = model(image)
            loss = loss_fn(y_pred, label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            # The old code hit an UnboundLocalError on the loop index here.
            raise ValueError('train_dl yielded no batches')
        train_losses.append(running_loss / n_batches)

    end = time.time()
    print('Training Done!')
    print(f'Time taken to train: {end - start}')

    return train_losses

In [103]:
# MNIST training setup. Fall back to the CPU when no GPU is present instead of
# crashing on an unconditional .cuda().
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mnist = LeNet('mnist').to(device)
epochs = 100
lr = 0.01
opt_mnist = torch.optim.SGD(model_mnist.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

train_dl = DataLoader(mnist_data_train, batch_size=64, shuffle=True)
# Evaluation does not depend on sample order, so the test split is not shuffled.
test_dl = DataLoader(mnist_data_test, batch_size=128, shuffle=False)

In [104]:
# Run the MNIST training loop; the result holds one mean loss per epoch.
train_loss_mnist = train(
    model_mnist,
    epochs,
    train_dl,
    optimizer=opt_mnist,
    loss_fn=loss_fn,
    lr=lr,
)

  9%|▉         | 9/100 [03:29<35:21, 23.31s/it]


KeyboardInterrupt: ignored

In [99]:
# CIFAR-10 training setup. Fall back to the CPU when no GPU is present instead
# of crashing on an unconditional .cuda().
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_cifar10 = LeNet('cifar10').to(device)
epochs = 100
lr = 0.01
opt_cifar10 = torch.optim.SGD(model_cifar10.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

train_dl = DataLoader(cifar10_data_train, batch_size=10, shuffle=True)
# Evaluation does not depend on sample order, so the test split is not shuffled.
test_dl = DataLoader(cifar10_data_test, batch_size=10, shuffle=False)

In [100]:
# Run the CIFAR-10 training loop; the result holds one mean loss per epoch.
train_loss_cifar10 = train(
    model_cifar10,
    epochs,
    train_dl,
    optimizer=opt_cifar10,
    loss_fn=loss_fn,
    lr=lr,
)

  3%|▎         | 3/100 [01:18<42:21, 26.20s/it]


KeyboardInterrupt: ignored

In [None]:
def test(model, test_dl):
    
    model.eval()
    
    with torch.no_grad():
        accuracy = 0
        for i, batch in enumerate(test_dl):
            image, label = batch

            if torch.cuda.is_available():
                image, label = image.cuda(), label.cuda()
                
            y_pred = model(image)
            label = F.one_hot(label, 10)
            a, preds = torch.max(y_pred, dim=1)
            accuracy += torch.tensor(torch.sum(preds == label).item()).item()
        total_acc = accuracy / (i + 1)
            # break
    return total_acc

In [None]:
# Evaluate each trained model on its own test split. `model` was never defined,
# and `test_dl` is reassigned by both setup cells, so build the loaders here.
acc_mnist = test(model_mnist, DataLoader(mnist_data_test, batch_size=128))
acc_cifar10 = test(model_cifar10, DataLoader(cifar10_data_test, batch_size=10))
{'mnist': acc_mnist, 'cifar10': acc_cifar10}

In [None]:
# Federated-learning hyperparameters. The names appear to follow FedAvg
# notation (McMahan et al.) — TODO confirm against the paper being reproduced.
rounds = 100       # presumably the number of communication rounds
C = 0.1            # presumably the fraction of clients sampled per round
K = 100            # presumably the total number of clients
batch_size = 10    # presumably the local minibatch size (B)
lr = 0.01          # SGD learning rate