In [1]:
import numpy as np
import os
import torch 
import torch.nn as nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.nn.functional import one_hot
from torchvision.utils import save_image
from preprocessing import batch_elastic_transform
from model import PrototypeModel
from train import save_images, run_epoch

In [2]:
# Global parameters for device and reproducibility.
# Seed every RNG this notebook has in scope: torch drives weight init, and the
# NumPy global RNG is seeded too so any NumPy-based randomness is repeatable.
# NOTE(review): presumably `batch_elastic_transform` draws from NumPy's global
# RNG -- confirm against preprocessing.py.
random_seed = 7
torch.manual_seed(random_seed)
np.random.seed(random_seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optimisation settings
learning_rate = 0.0001
training_epoch = 1500     # number of passes over the training set
batch_size = 250

# Data-augmentation settings
# NOTE(review): sigma/alpha look like the elastic-transform parameters -- confirm.
sigma = 4
alpha = 20

# Model dimensions
n_prototypes = 15
latent_size = 40
n_classes = 10

# Loss-term weights
lambda_class = 20
lambda_ae = 1
lambda_1 = 1              # 1 and 2 here corresponds to the notation we used in the paper
lambda_2 = 1

# Output locations (trailing slash is required: paths are joined by concatenation)
model_path = 'models/'
prototype_path = 'images/prototypes/'
decoding_path = 'images/decoding/'

In [4]:
# Load data
# Both MNIST splits use the same preprocessing pipeline (ToTensor only), so it
# is built once and shared. The dataset is downloaded into ./data if missing.
to_tensor = transforms.Compose([transforms.ToTensor()])
train_data = MNIST('./data', train=True, download=True, transform=to_tensor)
test_data = MNIST('./data', train=False, download=True, transform=to_tensor)


### Initialize the model and the optimizer.
# Build the model from the hyperparameters declared in the configuration cell
# (n_prototypes, latent_size, n_classes) so that editing them there takes effect.
proto = PrototypeModel(n_prototypes, latent_size, n_classes).to(device)
optim = torch.optim.Adam(proto.parameters(), lr=learning_rate)
# NOTE(review): the loader iterates in dataset order (no shuffle) -- confirm
# this is intended for training.
dataloader = DataLoader(train_data, batch_size=batch_size)

# Make sure the checkpoint directory exists once, before training starts.
os.makedirs(model_path, exist_ok=True)

# Run for a number of epochs
for epoch in range(training_epoch):
    # Per-epoch accumulators; run_epoch receives them and returns updated values.
    epoch_loss = 0.0
    epoch_acc = 0.0
    it = 0

    it, epoch_loss, epoch_acc, dec = run_epoch(proto, dataloader, optim, it, epoch_loss, epoch_acc)

    # Get prototypes and decode them to display. This is visualisation only,
    # so no autograd graph is needed.
    with torch.no_grad():
        prototypes = proto.prototype.get_prototypes()
        # 10 * 2 * 2 == 40, the configured latent_size.
        # NOTE(review): assumes the decoder expects a (N, 10, 2, 2) input -- confirm.
        prototypes = prototypes.view(-1, 10, 2, 2)
        imgs = proto.decoder(prototypes)

    # Save images
    save_images(prototype_path, decoding_path, imgs, dec, epoch)

    # Save model (the whole module object; the file is overwritten every epoch)
    torch.save(proto, model_path+"proto.pth")

    # Print statement to check on progress
    print("Epoch: ", epoch, "Loss: ", epoch_loss / it, "Acc: ", epoch_acc/it)

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch:  0 Loss:  79.56553010940551 Acc:  0.09919999999999989
Epoch:  1 Loss:  63.805376068751016 Acc:  0.09929999999999987
Epoch:  2 Loss:  58.93859675725301 Acc:  0.09876666666666652
Epoch:  3 Loss:  57.97762449582418 Acc:  0.11093333333333333
Epoch:  4 Loss:  57.65849153200785 Acc:  0.1122
Epoch:  5 Loss:  57.417175404230754 Acc:  0.11213333333333332
Epoch:  6 Loss:  57.180316893259686 Acc:  0.1135333333333333
Epoch:  7 Loss:  56.852639547983806 Acc:  0.11956666666666659
Epoch:  8 Loss:  55.75808108647664 Acc:  0.21913333333333332
Epoch:  9 Loss:  51.828138494491576 Acc:  0.3889333333333337
Epoch:  10 Loss:  46.827057441075645 Acc:  0.47896666666666715
Epoch:  11 Loss:  42.97631827990214 Acc:  0.5275666666666665
Epoch:  12 Loss:  40.16992999712626 Acc:  0.5681499999999994
Epoch:  13 Loss:  37.93062192598979 Acc:  0.6067999999999996
Epoch:  14 Loss:  36.005694301923114 Acc:  0.642333333333333
Epoch:  15 Loss:  34.2927796681722 Acc:  0.6699333333333329
Epoch:  16 Loss:  32.635850707689

KeyboardInterrupt: 