In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import os
import cv2
import multiprocessing as mp
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import import_ipynb
import distance
import Protonet
import data
from tqdm import tqdm

In [None]:
def train(model, optimizer, train_x, train_y, n_way, n_support, n_query, max_epoch, epoch_size, PATH="model/protonet.pt"):
    """
    Trains the protonet episodically and saves the trained model
    Args:
        model: model exposing set_forward_loss(sample) -> (loss, output); loss must
            support .backward() and output must provide the 'loss' and 'acc' entries
        optimizer: torch optimizer that updates the model parameters
        train_x: images of training set (forwarded to data.extract_sample)
        train_y: labels of training set (forwarded to data.extract_sample)
        n_way (int): number of classes in a classification task
        n_support (int): number of labeled examples per class in the support set
        n_query (int): number of labeled examples per class in the query set
        max_epoch (int): number of epochs to train for
        epoch_size (int): number of episodes per epoch (must be > 0, it is used
            as the divisor for the per-epoch averages)
        PATH: destination handed to torch.save for the whole model object
    """
    # Make sure the checkpoint directory exists *before* training starts, so a
    # bad location fails immediately instead of after every epoch has run.
    # Only filesystem paths are handled; file-like objects are passed through.
    if isinstance(PATH, (str, os.PathLike)):
        save_dir = os.path.dirname(os.fspath(PATH))
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    # step_size=1 with gamma=0.5: the learning rate is halved after every epoch.
    scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.5, last_epoch=-1)

    for epoch in range(max_epoch):
        running_loss = 0.0
        running_acc = 0.0

        for episode in tqdm(range(epoch_size)):
            # Each episode draws a fresh n_way / n_support / n_query task.
            sample = data.extract_sample(n_way, n_support, n_query, train_x, train_y)
            optimizer.zero_grad()
            loss, output = model.set_forward_loss(sample)
            # NOTE(review): assumes output['loss'] / output['acc'] are plain
            # numbers (not graph-attached tensors) -- confirm in Protonet.
            running_loss += output['loss']
            running_acc += output['acc']
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / epoch_size
        epoch_acc = running_acc / epoch_size
        print('Epoch {:d} -- Loss: {:.4f} Acc: {:.4f}'.format(epoch+1,epoch_loss, epoch_acc))
        scheduler.step()

    # Save the model (the full object, not just its state_dict).
    torch.save(model, PATH)
    

In [None]:
def test(model, test_x, test_y, n_way, n_support, n_query, test_episode):
    """
    Tests the protonet
    Args:
        model: trained model
        test_x (np.array): images of testing set
        test_y (np.array): labels of testing set
        n_way (int): number of classes in a classification task
        n_support (int): number of labeled examples per class in the support set
        n_query (int): number of labeled examples per class in the query set
        test_episode (int): number of episodes to test on
      """
    running_loss = 0.0
    running_acc = 0.0
    for episode in tqdm(range(test_episode)):
        sample = data.extract_sample(n_way, n_support, n_query, test_x, test_y)
        loss, output = model.set_forward_loss(sample)
        running_loss += output['loss']
        running_acc += output['acc']
    avg_loss = running_loss / test_episode
    avg_acc = running_acc / test_episode
    print('Test results -- Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, avg_acc))