In [25]:
import os
from glob import glob

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import numpy as np

from configurate import get_config
from dataloader import get_dataloader
from protonets import ProtoNet
from one_cycle_policy import OneCyclePolicy
# from prototypical_loss import PrototypicalLoss
from prototypical_loss2 import PrototypicalLoss
from utils import AverageMeter
from tqdm.notebook import tqdm

from easydict import EasyDict

In [4]:
args = EasyDict({
    "dataset_dir" : 'data',
    "exp_name" : 'TEST',
    "epochs" : 100,
    "lr" : 0.001,
    "lr_scheduler_step" : 20,
    "lr_scheduler_gamma" : 0.5,
    "manual_seed" : 7,
    'log_dir' : 'runs',
    'resume' : False,
    'iterations' : 100,
    'classes_per_it_tr' : 60,
    'num_support_tr' : 5,
    'num_query_tr' : 5,
    'classes_per_it_val' : 5,
    'num_support_val' : 5,
    'num_query_val' : 15,
})

args.log_dir = os.path.join(args.log_dir,args.exp_name)
device = 'cpu'

In [5]:
val_loader = get_dataloader(args, 'train')

model = ProtoNet().to(device)

data = next(iter(val_loader))

Loading data...done


In [6]:
x, y = data[0], data[1]

In [20]:
print(len(x))
print(len(y))
print(x.shape)
print(y.shape)
print(x[0].shape)

600
600
torch.Size([600, 1, 28, 28])
torch.Size([600])
torch.Size([1, 28, 28])
60


In [21]:
output = model(x)


In [73]:
import torch
from torch.nn import functional as F
from torch.nn.modules import Module
def euclidean_dist(x, y):
    '''
    Compute euclidean distance between two tensors
    '''
    # x: N x D
    # y: M x D
    n = x.size(0)
    m = y.size(0)
    d = x.size(-1)
    if d != y.size(-1):
        raise Exception

    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)

    print(x.shape)
    print(y.shape)

    return torch.pow(x - y, 2).sum(2)

class PrototypicalLoss2(Module):
    '''
    Loss class deriving from Module for the prototypical loss
    '''

    def __init__(self, n_support, n_query, n_class):
        super(PrototypicalLoss2, self).__init__()
        self.n_support = n_support
        self.n_query = n_query
        self.n_class = n_class
        self.n_one_dataset = n_class * n_support
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def forward(self, input, target):
        classes = torch.unique(target)
        n_classes = len(classes)

        # Make prototypes
        support_idxs = torch.stack(list(map(lambda c: target.eq(c).nonzero()[:self.n_support].squeeze(1), classes)))
        prototypes = torch.stack([input[idx_list].mean(0) for idx_list in support_idxs])

        # Make query samples
        n_query = target.eq(classes[0].item()).sum().item() - self.n_support
        query_idxs = torch.stack(list(map(lambda c: target.eq(c).nonzero()[self.n_support:], classes))).view(-1)
        query_samples = input[query_idxs]

        print(query_samples[:1])
        print(prototypes[:1])
        print(query_samples.shape)
        print(prototypes.shape)

        dists = euclidean_dist(query_samples, prototypes)

        log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1)
        y_hat = log_p_y.argmax(2)

        target_inds = torch.arange(0, n_classes).to(self.device).view(n_classes, 1, 1).expand(n_classes, n_query, 1).long()

        loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()

        acc_val = y_hat.eq(target_inds.squeeze()).float().mean()

        return loss_val, acc_val

criterion = PrototypicalLoss2(args.num_support_tr,args.num_query_tr,args.classes_per_it_tr)
loss, acc = criterion(output,y)

# print(loss,acc)

tensor([[0.7022, 0.0000, 1.0635, 0.9263, 1.2506, 1.4464, 0.2688, 0.7035, 0.9000,
         1.7206, 0.5364, 1.1941, 0.1021, 1.4347, 0.4625, 0.0650, 0.1614, 0.2412,
         0.6972, 2.0165, 1.8409, 0.0552, 0.8492, 0.1855, 0.0000, 0.0000, 1.2591,
         1.0737, 0.2296, 1.0826, 2.1406, 0.2337, 1.3483, 0.9681, 0.8804, 0.7693,
         0.0250, 0.4393, 0.3813, 0.5563, 0.0000, 1.5188, 0.6951, 1.2831, 1.9443,
         1.4142, 1.0503, 1.3200, 0.6807, 0.3002, 1.3096, 0.1353, 1.3005, 1.1397,
         1.3880, 2.3217, 0.8153, 1.0755, 1.4852, 0.0000, 0.3125, 0.4396, 1.4121,
         1.6893]], grad_fn=<SliceBackward>)
tensor([[0.7368, 0.4754, 0.5043, 0.7442, 0.8143, 1.2315, 0.0651, 1.6028, 0.8586,
         1.3559, 0.7544, 1.2305, 0.1563, 1.2188, 1.2568, 0.3338, 0.5958, 0.4568,
         0.0000, 1.1136, 1.8422, 0.2678, 0.8315, 0.3272, 0.0143, 0.0000, 0.9105,
         1.6440, 0.2926, 1.1777, 0.9825, 0.0237, 1.5522, 0.5521, 0.5109, 0.8916,
         0.2939, 0.3058, 0.1418, 1.5322, 0.1213, 0.1949, 0.3409, 