In [9]:
import os
from glob import glob

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision.datasets as dset
from torch.utils.data import DataLoader
from torchvision import transforms
from configurate import get_config
from protonets import ProtoNet
from utils import AverageMeter
from tqdm.notebook import tqdm

from dataloader import get_dataloader, PrototypicalBatchSampler
from prototypical_loss import PrototypicalLoss, prototypical_evaluator, euclidean_dist
from torch.nn import functional as F
from easydict import EasyDict

In [10]:
args = EasyDict({
    "dataset_dir" : 'data',
    "exp_name" : 'TEST',
    "epochs" : 100,
    "lr" : 0.001,
    "lr_scheduler_step" : 20,
    "lr_scheduler_gamma" : 0.5,
    "manual_seed" : 7,
    'log_dir' : 'runs',
    'resume' : False,
    'iterations' : 100,
    'classes_per_it_tr' : 50,
    'num_support_tr' : 3,
    'num_query_tr' : 3,
    'classes_per_it_val' : 50,
    'num_support_val' : 5,
    'num_query_val' : 5,
})
args.log_dir = os.path.join(args.log_dir,args.exp_name)
device = 'cpu'

In [11]:
train_loader, test_loader = get_dataloader(args)

Loading data...Files already downloaded and verified
done


In [12]:
torch.cuda.cudnn_enabled = False
np.random.seed(args.manual_seed)
torch.manual_seed(args.manual_seed)
torch.cuda.manual_seed(args.manual_seed)

model = ProtoNet().to(device)

criterion = PrototypicalLoss().to(device)

optimizer = torch.optim.Adam(model.parameters(), args.lr)

writer = SummaryWriter(args.log_dir)
cudnn.benchmark = True

In [13]:
train_data = next(iter(train_loader))
test_data = next(iter(test_loader))

In [14]:
output = model(train_data[0])
query_sample = model(test_data[0])

loss, _, prototypes = criterion(output, train_data[1], 2)

In [33]:
def prototypical_evaluator2(prototype, input, target):
    dists = euclidean_dist(input, prototype)
    log_p_y = F.log_softmax(-dists,dim=1)
    y_hat = log_p_y.argmax(1)

    loss = -log_p_y.squeeze().view(-1).mean()
    loss2 = torch.nn.NLLLoss(log_p_y, target)
    acc = y_hat.eq(target.squeeze()).float().mean()

    return loss, loss2, acc, y_hat

loss, loss2, acc, yhat = prototypical_evaluator2(prototypes,query_sample,test_data[1])

print(loss,loss2)

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [81]:
def prototypical_loss2(input, target, n_support, device):
    '''
    Inspired by https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py

    Compute the barycentres by averaging the features of n_support
    samples for each class in target, computes then the distances from each
    samples' features to each one of the barycentres, computes the
    log_probability for each n_query samples for each one of the current
    classes, of appartaining to a class c, loss and accuracy are then computed
    and returned
    Args:
    - input: the model output for a batch of samples
    - target: ground truth for the above batch of samples
    - n_support: number of samples to keep in account when computing
      barycentres, for each one of the current classes
    '''

    classes = torch.unique(target)
    n_classes = len(classes)

    # Make prototypes
    support_idxs = torch.stack(list(map(lambda c: target.eq(c).nonzero()[:n_support].squeeze(1), classes)))
    prototypes = torch.stack([input[idx_list].mean(0) for idx_list in support_idxs])

    # Make query samples
    n_query = target.eq(classes[0].item()).sum().item() - n_support
    query_idxs = torch.stack(list(map(lambda c: target.eq(c).nonzero()[n_support:], classes))).view(-1)
    query_samples = input[query_idxs]

    dists = euclidean_dist(query_samples, prototypes)

    log_p_y = F.log_softmax(-dists, dim=1)
    y_hat = log_p_y.argmax(1)
    target_idxs = torch.LongTensor([x for x in range(n_classes) for _ in range(n_query)]).to(device)

    loss_val = torch.nn.NLLLoss()(log_p_y, target_idxs)
    acc_val = y_hat.eq(target_idxs).float().mean()

    return loss_val, acc_val

if __name__ == '__main__':
    loss,acc = prototypical_loss2(output,train_data[1] , 3, 'cpu')
    print(loss.item(),acc)

tensor([[  2],
        [  3],
        [ 29],
        [ 31],
        [ 37],
        [ 70],
        [ 72],
        [ 79],
        [ 82],
        [ 91],
        [103],
        [104],
        [111],
        [119],
        [146]])
150
42.1596565246582 tensor(0.1000)


In [None]:
test_label = torch.unique(test_data[1])
yhat_label = torch.unique(yhat)
yhat_label

In [21]:
train_input, train_label = train_data[0], train_data[1]

train_class = torch.unique(train_label)
tc = train_class[0]

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
        37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49])

6