In [None]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="1"
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

# Custom packages
import sys; sys.path.append(os.path.dirname(os.getcwd()))
from models import CNN_MNIST
from utils import test
from imagenet_c import corrupt

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Using {} device".format(device))

In [None]:
class MNIST_noise(Dataset):
    """MNIST images corrupted on the fly and labelled with the corruption severity.

    Every call to ``__getitem__`` draws a fresh random severity in
    ``[0, MAX_SEVERITY]``. Severity 0 returns the clean image; any other value
    passes the image through ``imagenet_c.corrupt`` with that severity. The
    returned label is the severity itself (not the digit class), so the same
    index can produce different (image, label) pairs on different accesses.

    Args:
        root: Directory containing the MNIST files. The data must already be
            present there, since the underlying dataset is opened with
            ``download=False``.
        train: Forwarded to ``datasets.MNIST`` to select the train or test split.
        transform: Optional callable applied to the (possibly corrupted) image.
        target_transform: Optional callable applied to the severity label.
        corruption: ``corruption_name`` forwarded to ``imagenet_c.corrupt``.
    """

    # Highest severity that can be drawn; labels therefore span 0..MAX_SEVERITY.
    MAX_SEVERITY = 5

    def __init__(self, root, train, transform=None, target_transform=None,
                 corruption='gaussian_noise'):
        self.corruption = corruption
        mnist = datasets.MNIST(
            root=root,  # honour the caller's path instead of a hard-coded '../data'
            train=train,
            download=False,
        )
        # Only the raw images are kept; the digit labels are not used.
        self.images = mnist.data
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, severity)`` for the image at ``idx`` with a random severity."""
        # Load MNIST image
        image = self.images[idx, :, :]

        # Corrupt image and define label accordingly
        severity = np.random.randint(self.MAX_SEVERITY + 1)
        image = np.uint8(image)
        if severity > 0:
            image = corrupt(image, severity=severity, corruption_name=self.corruption)
        label = severity

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [None]:
# Build the train and test splits of the noise-severity dataset; both use
# ToTensor() as the image transform.
train_data, test_data = (
    MNIST_noise(root='../data', train=is_train, transform=ToTensor())
    for is_train in (True, False)
)


In [None]:
# Display a 3x3 grid of randomly chosen training samples, each titled with its label.
fig = plt.figure(figsize=(8, 8))
rows, cols = 3, 3
for position in range(rows * cols):
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[sample_idx]
    ax = fig.add_subplot(rows, cols, position + 1)  # subplot indices are 1-based
    ax.set_title(label)
    ax.imshow(img.squeeze(), cmap='gray')
    ax.axis('off')

In [None]:
# Mini-batch size and number of passes over the training set.
BATCH_SIZE, NB_EPOCHS = 64, 3

# One loader per split, both with the same batch size and otherwise default options.
train_dataloader, test_dataloader = (
    DataLoader(split, batch_size=BATCH_SIZE) for split in (train_data, test_data)
)

In [None]:
def train(dataloader, model, loss_function, optimizer, device=None):
    """Run one optimisation pass (one epoch) over ``dataloader``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing a ``.dataset``
            attribute whose length is used for progress reporting.
        model: Module to train; switched to training mode before the loop.
        loss_function: Callable ``loss_function(y_pred, y)`` returning a scalar
            tensor that supports ``.backward()``.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called once
            per batch.
        device: Device the batches are moved to. When omitted, the device of the
            model's first parameter is used (CPU if the model has no
            parameters), so the function does not depend on a notebook global.

    Returns:
        None. Progress is printed every 100 batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()

    size = len(dataloader.dataset)

    for batch, (X, y) in enumerate(dataloader):

        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        y_pred = model(X)
        loss = loss_function(y_pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # .item() prints a plain number rather than the tensor repr
            # (which would include device and grad_fn noise).
            print('[{}/{}] loss: {}'.format(batch*len(X), size, loss.item()))


In [None]:
# Cross-entropy loss over the network's six outputs.
loss_function = nn.CrossEntropyLoss()

# Build the network with six outputs, then move it to the selected device.
model = CNN_MNIST(output_dim=6)
model = model.to(device)

# Adam with its default hyper-parameters.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# Evaluate once before any training, then alternate one training pass with
# one evaluation pass for each epoch.
test(test_dataloader, model, loss_function, device)
epoch_numbers = range(1, NB_EPOCHS + 1)
for e in epoch_numbers:
    banner = f'Epoch {e}/{NB_EPOCHS}\n-------------------'
    print(banner)
    train(train_dataloader, model, loss_function, optimizer)
    test(test_dataloader, model, loss_function, device)

In [None]:
# Put the model in evaluation mode before running predictions.
model.eval()

# Display a 3x3 grid of random test samples, titled with the predicted label
# and the label returned by the dataset for that draw.
fig = plt.figure(figsize=(8, 8))
rows, cols = 3, 3
for i in range(1, rows*cols+1):
    sample_idx = torch.randint(len(test_data), size=(1,)).item()
    img, label = test_data[sample_idx]
    with torch.no_grad():
        # Add a leading batch dimension before the forward pass; no .detach()
        # is needed because gradients are already disabled here.
        prediction = model(img[None, :, :, :].to(device)).cpu()
    # Use the tensor's own argmax (over all elements, like np.argmax without an
    # axis) and convert to a plain Python int instead of routing through NumPy.
    predicted_label = prediction.argmax().item()
    ax = fig.add_subplot(rows, cols, i)
    ax.set_title(f'predicted: {predicted_label}; real: {label}')
    ax.imshow(img.squeeze(), cmap='gray')
    ax.axis('off')

In [None]:
# Save only the model's state_dict, with the current date and time (to the
# minute) embedded in the file name.
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
weights_path = f'CNN_noise_MNIST_weights_{timestamp}.pth'
torch.save(model.state_dict(), weights_path)