In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import os
import cv2
import multiprocessing as mp
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import import_ipynb
import distance
import Protonet
import data
from tqdm import tqdm

importing Jupyter notebook from distance.ipynb
importing Jupyter notebook from Protonet.ipynb
importing Jupyter notebook from data.ipynb


In [3]:
def train(model, optimizer, train_x, train_y, n_way, n_support, n_query, max_epoch, epoch_size, PATH="model/protonet.pt"):
    """
    Trains the protonet episodically and saves the trained model
    Args:
        model: model to train; must provide set_forward_loss(sample) returning
            (loss, output) where output has 'loss' and 'acc' entries
        optimizer: optimizer that updates the model's parameters
        train_x: images of training set
        train_y: labels of training set
        n_way (int): number of classes in a classification task
        n_support (int): number of labeled examples per class in the support set
        n_query (int): number of labeled examples per class in the query set
        max_epoch (int): number of epochs to train for
        epoch_size (int): number of episodes per epoch
        PATH: destination handed to torch.save; the whole model object is saved there
    """
    # Create the checkpoint directory before training starts, so a missing
    # folder is not discovered only when torch.save runs at the very end.
    # Non-path destinations (e.g. file-like objects) are passed through as-is.
    if isinstance(PATH, (str, os.PathLike)):
        save_dir = os.path.dirname(os.fspath(PATH))
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    # step_size=1 with gamma=0.5: the learning rate is halved after every epoch.
    scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.5, last_epoch=-1)

    for epoch in range(max_epoch):
        running_loss = 0.0
        running_acc = 0.0

        for episode in tqdm(range(epoch_size)):
            sample = data.extract_sample(n_way, n_support, n_query, train_x, train_y)
            optimizer.zero_grad()
            loss, output = model.set_forward_loss(sample)
            # float() keeps the running sums plain Python numbers; if the model
            # reports tensors here, this avoids accumulating tensor objects.
            running_loss += float(output['loss'])
            running_acc += float(output['acc'])
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / epoch_size
        epoch_acc = running_acc / epoch_size
        print('Epoch {:d} -- Loss: {:.4f} Acc: {:.4f}'.format(epoch+1,epoch_loss, epoch_acc))
        scheduler.step()

    # Save the model (the full object, not just its state_dict).
    torch.save(model, PATH)
    

In [4]:
%%time 
# Load both Omniglot splits up front (this is the slow, disk-bound step that the cell timing measures).
# Each call is unpacked into two values; presumably (images, labels) — verify against data.read_images.
# The "background" folder feeds training further down; the "evaluation" folder is only used for testing.
trainx, trainy = data.read_images('omniglot/images_background')
testx, testy = data.read_images('omniglot/images_evaluation')

CPU times: user 1min 8s, sys: 8.94 s, total: 1min 17s
Wall time: 1min 21s


In [6]:
%%time
# Build the network and its optimizer.
model = Protonet.load_protonet_conv(
    x_dim=(3, 28, 28),
    hid_dim=64,
    z_dim=64,
)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Episode layout used during training: 60 classes, 5 support + 5 query examples each.
n_way, n_support, n_query = 60, 5, 5

# Training pool comes from the background split loaded earlier.
train_x, train_y = trainx, trainy

# Schedule: 5 epochs of 500 episodes.
max_epoch, epoch_size = 5, 500

train(
    model, optimizer,
    train_x, train_y,
    n_way, n_support, n_query,
    max_epoch, epoch_size,
)

  0%|          | 0/500 [00:00<?, ?it/s]          

Epoch 1 -- Loss: 0.1509 Acc: 0.9540


  0%|          | 0/500 [00:00<?, ?it/s]          

Epoch 2 -- Loss: 0.0528 Acc: 0.9828


  0%|          | 0/500 [00:00<?, ?it/s]          

Epoch 3 -- Loss: 0.0430 Acc: 0.9858


  0%|          | 0/500 [00:00<?, ?it/s]          

Epoch 4 -- Loss: 0.0333 Acc: 0.9882


                                                 

Epoch 5 -- Loss: 0.0324 Acc: 0.9884
CPU times: user 1h 10min 52s, sys: 7min 46s, total: 1h 18min 38s
Wall time: 1h 11min 48s




In [7]:
def test(model, test_x, test_y, n_way, n_support, n_query, test_episode):
    """
    Tests the protonet
    Args:
        model: trained model
        test_x (np.array): images of testing set
        test_y (np.array): labels of testing set
        n_way (int): number of classes in a classification task
        n_support (int): number of labeled examples per class in the support set
        n_query (int): number of labeled examples per class in the query set
        test_episode (int): number of episodes to test on
      """
    running_loss = 0.0
    running_acc = 0.0
    for episode in tqdm(range(test_episode)):
        sample = data.extract_sample(n_way, n_support, n_query, test_x, test_y)
        loss, output = model.set_forward_loss(sample)
        running_loss += output['loss']
        running_acc += output['acc']
    avg_loss = running_loss / test_episode
    avg_acc = running_acc / test_episode
    print('Test results -- Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, avg_acc))

In [8]:
# Evaluation episodes: 5 classes, 5 support + 5 query examples each.
n_way, n_support, n_query = 5, 5, 5

# Evaluate on the held-out evaluation split loaded earlier.
test_x, test_y = testx, testy

# Number of episodes to average over.
test_episode = 1000

test(
    model,
    test_x, test_y,
    n_way, n_support, n_query,
    test_episode,
)

                                                   

Test results -- Loss: 0.0147 Acc: 0.9968


