In [1]:
import numpy as np
import os
import torch 
import torch.nn as nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.nn.functional import one_hot
from torchvision.utils import save_image
from preprocessing import batch_elastic_transform
from model import PrototypeModel
from train import save_images, run_epoch

ModuleNotFoundError: No module named 'torch'

In [None]:
# Global parameters for device and reproducibility.
# Seed every RNG this notebook can reach: torch.manual_seed covers torch's
# generators on all devices, and np.random.seed covers the legacy global numpy
# RNG, which preprocessing helpers may draw from
# (NOTE(review): confirm whether batch_elastic_transform uses numpy randomness).
seed = 7
torch.manual_seed(seed)
np.random.seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Globals: optimisation
learning_rate = 0.0001
training_epoch = 1500     # number of passes over the training set
batch_size = 250

# Presumably the elastic-deformation settings for batch_elastic_transform;
# they are not referenced anywhere in this notebook - TODO confirm where they are consumed.
sigma = 4
alpha = 20

# Model architecture (passed to PrototypeModel).
n_prototypes = 15
latent_size = 40
n_classes = 10

# Loss-term weights. 1 and 2 here correspond to the notation we used in the paper.
# NOTE(review): none of these are passed to run_epoch below, so they only take
# effect if train.py reads them itself - confirm.
lambda_class = 20
lambda_ae = 1
lambda_1 = 1
lambda_2 = 1

# Output locations (trailing slashes kept for existing callers).
model_path = 'models/'
prototype_path = 'images/prototypes/'
decoding_path = 'images/decoding/'

In [None]:
# Load data
# ToTensor converts each PIL image to a float tensor scaled to [0, 1];
# no further normalisation is applied. The transform is stateless, so both
# splits can share one instance.
to_tensor = transforms.Compose([
    transforms.ToTensor(),
])
train_data = MNIST('./data', train=True, download=True, transform=to_tensor)
test_data = MNIST('./data', train=False, download=True, transform=to_tensor)  # not used in this cell


### Initialize the model and the optimizer.
# Build the model from the config-cell hyperparameters rather than repeating
# literals that can silently drift out of sync with them.
proto = PrototypeModel(n_prototypes, latent_size, n_classes).to(device)
optim = torch.optim.Adam(proto.parameters(), lr=learning_rate)
# Reshuffle the training set every epoch. The order stays reproducible because
# torch is seeded in the config cell.
dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Create the checkpoint directory once, not on every epoch.
os.makedirs(model_path, exist_ok=True)
checkpoint_file = os.path.join(model_path, "proto.pth")

# Run for a number of epochs (`training_epoch` is the name defined in the config cell).
for epoch in range(training_epoch):
    epoch_loss = 0.0
    epoch_acc = 0.0
    it = 0

    # run_epoch returns the updated iteration count, loss and accuracy totals
    # (presumably accumulated per batch, since they are averaged by `it` below)
    # plus decoded images `dec` for saving.
    it, epoch_loss, epoch_acc, dec = run_epoch(proto, dataloader, optim, it, epoch_loss, epoch_acc)

    # Get prototypes and decode them to display. This is visualisation only,
    # so skip building an autograd graph.
    with torch.no_grad():
        prototypes = proto.prototype.get_prototypes()
        # 10 * 2 * 2 == latent_size (40). NOTE(review): presumably this is the
        # decoder's expected input shape - confirm in model.py before changing latent_size.
        prototypes = prototypes.view(-1, 10, 2, 2)
        imgs = proto.decoder(prototypes)

    # Save images
    save_images(prototype_path, decoding_path, imgs, dec, epoch)

    # Save model. This pickles the whole module and overwrites the same file each epoch.
    torch.save(proto, checkpoint_file)

    # Print statement to check on progress
    print("Epoch: ", epoch, "Loss: ", epoch_loss / it, "Acc: ", epoch_acc/it)