In [1187]:
import os
import sys
import torch
import random
import hashlib
import dataclasses
from abc import ABC
from tqdm import tqdm
from functools import partial
from typing import Optional, Any, Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import seaborn as sns
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, Dataset

In [952]:
parent = os.path.dirname(os.getcwd())
if "notebooks" in parent:
    parent = os.path.dirname(parent)
sys.path.append(parent)

In [953]:
from common import example

# test run an import to ensure we can import
example.memoed_fib(25)

75025

# Baseline MNIST classifier (to be attacked)

In [958]:
KERAS_DIRECTORY = os.path.expanduser("~/.keras/datasets")
MNIST_FILEPATH = os.path.join(KERAS_DIRECTORY, "mnist.npz")

In [1378]:
class MNISTSubset(Dataset):
    """Original data is already split so that's the view of dataset we're represent"""

    def __init__(self, x, y):
        self.x = x
        self.y = y
    
    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx: int):
        return self.x[idx, :, :], self.y[idx]

    def loader(self, batch_size: int = 32, shuffle: bool = True) -> DataLoader:
        return DataLoader(self, shuffle=shuffle, batch_size=batch_size)

    # if time polish this
    def display(self, images=2): 
        labels = len(sorted(np.unique(self.y)))
        fig, axs = plt.subplots(labels, images, tight_layout=True)
        
        for label_i in range(labels):
            for image_j in range(images):
                # not pretty
                image, label = self[random.randint(0, len(self.x))]
                if label != label_i:
                    image, label = self[random.randint(0, len(self.x))]
                axs[label, image_j].imshow(image.reshape(28, 28))
                axs[label, image_j].set_title(int(label))
        plt.show()
        

class InMemoryMNIST:

    data = {}
    test: MNISTSubset
    train: MNISTSubset

    def __init__(self, filepath: str):
        for key, data in dict(np.load(filepath)).items():
            # labels need to be encoded as longs, rest floats
            data = torch.as_tensor(data, dtype=torch.float32 if "x" in key else torch.long)
            if "x" in key:
                # torch conv layers require images of shape (idx, channels, width, height);
                data = data / np.linalg.norm(data)
                data = data.reshape(data.shape[0], 1, 28, 28)
            
            self.data[key] = data
        
        self.train = MNISTSubset(self.data["x_train"], self.data["y_train"])
        self.test = MNISTSubset(self.data["x_train"], self.data["y_test"])

    
    def loaders(self) -> (DataLoader, DataLoader):
        return self.train.loader(), self.test.loader()

In [1379]:
mnist = InMemoryMNIST(MNIST_FILEPATH)

In [1380]:
def accuracy(ground_truth: DataLoader, model: nn.Module, *args, **kwargs) -> float:

    examples = 0
    correctly_labeled = 0
    
    with torch.no_grad():
        correctly_labeled = 0
        for batch in ground_truth:
            images, true_labels = batch
            pred_labels = model(images).argmax(dim=1)

            examples += len(pred_labels)
            correctly_labeled += torch.sum(true_labels == pred_labels)
    
    return int(correctly_labeled) / examples


@dataclasses.dataclass
class Experiment:
    description: str = ""
    epochs: int = 100
    loss: Callable = torch.nn.CrossEntropyLoss
    optimizer: Callable = partial(torch.optim.Adam, lr=0.01, weight_decay=5e-4)
    
    metrics: Dict[str, Callable] = dataclasses.field(default_factory=dict)

In [1381]:
class BaseCNNClassifier(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=(2, 2), stride=2)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=1)
        self.conv2 = nn.Conv2d(10, 1, kernel_size=(2, 2), stride=2)
        
        self.fc1 = nn.Linear(25, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool1(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), 1)

        return x

In [1386]:
def train(
    model: nn.Module,
    test_set: DataLoader,
    training_set: DataLoader, 
    params: Experiment
):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_func = torch.nn.CrossEntropyLoss()
    
    with tqdm(range(1, params.epochs + 1), desc="epoch training loss", leave=True) as epochs:
        for epoch in epochs:
            model.train()
            for batch in training_set:
                optimizer.zero_grad()
                images, true_labels = batch
                prob = model(images)
                loss = loss_func(prob, true_labels)
                loss.backward()
                optimizer.step()

            # evaluate on test set
            model.eval()
            metrics = { metric: f(test_set, model) for metric, f in params.metrics.items() }
            epochs.set_description(f"{epoch=}, {loss=:.4f}, {metrics=}")
    
    return model, metrics

In [None]:
mnist_train, mnist_test = mnist.loaders()

model = BaseCNNClassifier()
train(model, mnist_test, mnist_train, Experiment(epochs=200, metrics={"accuracy": accuracy}))

epoch=14, loss=2.2857, metrics={'accuracy': 0.1135}:   7%| | 14/200 [02:02<25:21

## TODO: repurpose classifier and attack adverially by noise injection