In [1]:
import os
from glob import glob

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision.datasets as dset
from torch.utils.data import DataLoader
from torchvision import transforms
from configurate import get_config
from protonets import ProtoNet
from utils import AverageMeter
from tqdm.notebook import tqdm

from dataloader import get_dataloader, PrototypicalBatchSampler
from prototypical_loss import PrototypicalLoss, prototypical_evaluator, euclidean_dist
from torch.nn import functional as F
from easydict import EasyDict

In [2]:
args = EasyDict({
    "dataset_dir" : 'data',
    "exp_name" : 'TEST',
    "epochs" : 100,
    "lr" : 0.001,
    "lr_scheduler_step" : 20,
    "lr_scheduler_gamma" : 0.5,
    "manual_seed" : 7,
    'log_dir' : 'runs',
    'resume' : False,
    'iterations' : 100,
    'classes_per_it_tr' : 100,
    'num_support_tr' : 3,
    'num_query_tr' : 3,
    'classes_per_it_val' : 100,
    'num_support_val' : 5,
    'num_query_val' : 5,
})
args.log_dir = os.path.join(args.log_dir,args.exp_name)
device = 'cpu'

In [3]:
train_loader, test_loader = get_dataloader(args)

Loading data...Files already downloaded and verified
done


In [4]:
torch.cuda.cudnn_enabled = False
np.random.seed(args.manual_seed)
torch.manual_seed(args.manual_seed)
torch.cuda.manual_seed(args.manual_seed)

model = ProtoNet().to(device)

criterion = PrototypicalLoss().to(device)

optimizer = torch.optim.Adam(model.parameters(), args.lr)

writer = SummaryWriter(args.log_dir)
cudnn.benchmark = True

In [5]:
train_data = next(iter(train_loader))
test_data = next(iter(test_loader))

In [6]:
output = model(train_data[0])
query_sample = model(test_data[0])

loss, _, prototypes = criterion(output, train_data[1], 2)

In [82]:
def prototypical_evaluator2(prototype, input, target):
    dists = euclidean_dist(input, prototype)
    log_p_y = F.log_softmax(-dists,dim=1)
    y_hat = log_p_y.argmax(1)

    loss = -log_p_y.squeeze().view(-1).mean()
    acc = y_hat.eq(target.squeeze()).float().mean()

    return loss, acc, y_hat

loss, acc, yhat = prototypical_evaluator2(prototypes,query_sample,test_data[1])

print(loss,acc,yhat)

tensor(75.6080, grad_fn=<NegBackward>) tensor(0.1000) tensor([76, 72, 81, 51, 19, 24, 31, 52, 60, 66])


In [71]:
a = torch.arange(4,9)
b = torch.arange(10,20)
print(a)
print(b)
gat = b.gather(0,a)
print(gat)

tensor([4, 5, 6, 7, 8])
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([14, 15, 16, 17, 18])
