In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.dataloader import MiniImagenet
from proto.protonet import ConvNet, distance, accuracy
from tqdm import tqdm

In [2]:
# Pin every random number generator used here (NumPy, PyTorch CPU, and all
# PyTorch CUDA devices) to the same fixed seed so runs are repeatable.
np.random.seed(777)
torch.manual_seed(777)
torch.cuda.manual_seed_all(777)

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
def pairwise_distances_logits(a, b):
    """Score every row of ``a`` against every row of ``b``.

    The score is the negated squared Euclidean distance, so a larger
    (less negative) value means the two rows are closer together.

    Args:
        a: 2-D tensor holding one vector per row.
        b: 2-D tensor holding one vector per row; its row width must be
            compatible with that of ``a`` for element-wise subtraction.

    Returns:
        Tensor of shape ``(a.shape[0], b.shape[0])`` whose ``[i, j]`` entry
        is ``-sum((a[i] - b[j]) ** 2)``.
    """
    rows_a, rows_b = a.shape[0], b.shape[0]
    # View both operands on a common (rows_a, rows_b, width) grid so that a
    # single subtraction covers every (a[i], b[j]) pair.
    a_grid = a.unsqueeze(1).expand(rows_a, rows_b, -1)
    b_grid = b.unsqueeze(0).expand(rows_a, rows_b, -1)
    squared_gap = (a_grid - b_grid).pow(2)
    return squared_gap.sum(dim=2).neg()


def accuracy(predictions, targets):
    """Return the fraction of samples whose top-scoring class matches the target.

    Note: this definition shadows the ``accuracy`` imported from
    ``proto.protonet`` at the top of the notebook.

    Args:
        predictions: Score tensor with classes along dimension 1.
        targets: Tensor of class indices; the arg-max of ``predictions`` is
            reshaped to ``targets.shape`` before comparison.

    Returns:
        0-dim float tensor: number of matches divided by ``targets.size(0)``.
    """
    predicted_classes = predictions.argmax(dim=1).view(targets.shape)
    hits = predicted_classes.eq(targets).sum().float()
    return hits / targets.size(0)

def proto_train(model, batch, ways, shot, query_num, metric=None, device=None):
    """Run one prototypical-network episode and return its loss and accuracy.

    The samples are sorted by label, embedded with ``model``, and split per
    class into ``shot`` support samples (averaged into one prototype per
    class) and the remaining query samples. Queries are scored against the
    prototypes with ``metric`` and evaluated with cross-entropy.

    Despite the name, this function performs no optimizer step; the caller
    decides whether to back-propagate the returned loss.

    Args:
        model: Module mapping the sample tensor to embeddings.
        batch: ``(data, labels)`` pair. A leading dimension of size 1 is
            squeezed away from both, so a DataLoader batch size of 1 is
            expected.
        ways: Number of classes in the episode.
        shot: Number of support samples per class.
        query_num: Number of query samples per class. The support/query
            split assumes every class contributes exactly
            ``shot + query_num`` samples.
        metric: Callable ``metric(query, support) -> logits``. Defaults to
            ``pairwise_distances_logits``.
        device: Device to run on. Defaults to the device of the model's first
            parameter, or the data's device if the model has no parameters.

    Returns:
        Tuple ``(loss, acc)``: the cross-entropy loss tensor and the value
        returned by ``accuracy`` for the query samples.
    """
    if metric is None:
        metric = pairwise_distances_logits
    data, labels = batch
    if device is None:
        # nn.Module has no ``device`` attribute or method, so look at the
        # first parameter; a parameter-free model stays where the data is.
        first_param = next(model.parameters(), None)
        device = data.device if first_param is None else first_param.device
    data = data.to(device)
    labels = labels.to(device)

    # Sort data samples by labels
    # TODO: Can this be replaced by ConsecutiveLabels ?
    sort = torch.sort(labels)
    data = data.squeeze(0)[sort.indices].squeeze(0)
    labels = labels.squeeze(0)[sort.indices].squeeze(0)

    # Compute support and query embeddings
    embeddings = model(data)
    # After sorting, each class occupies a contiguous run of
    # (shot + query_num) rows; the first `shot` rows of every run are support.
    support_indices = np.zeros(data.size(0), dtype=bool)
    selection = np.arange(ways) * (shot + query_num)
    for offset in range(shot):
        support_indices[selection + offset] = True
    query_indices = torch.from_numpy(~support_indices)
    support_indices = torch.from_numpy(support_indices)
    support = embeddings[support_indices]
    # One prototype per class: the mean of that class's support embeddings.
    support = support.reshape(ways, shot, -1).mean(dim=1)
    query = embeddings[query_indices]
    # NOTE(review): cross_entropy needs labels in [0, ways) - confirm the
    # dataset emits episode-local class indices.
    labels = labels[query_indices].long()

    # Honour the caller-supplied metric instead of a hard-coded one.
    logits = metric(query, support)
    loss = F.cross_entropy(logits, labels)
    acc = accuracy(logits, labels)
    return loss, acc


In [4]:
from easydict import EasyDict

In [5]:
# Episode and schedule settings read by the cells below.
args = EasyDict(
    # Number of passes of the outer training loop.
    max_epoch=250,
    # Training episodes: classes per episode, support samples per class,
    # and query samples per class.
    train_ways=20,
    shot=1,
    train_query=15,
    # Validation/test episodes use their own way/shot/query settings.
    test_ways=5,
    test_shot=1,
    test_query=15,
    # DataLoader batch size; proto_train squeezes the leading batch axis,
    # so this is expected to stay at 1.
    batch_size=1,
)

In [6]:
root_path = './datasets/miniimagenet/pkl_file/'


def _episode_loader(mode, ways, shot, query, total_iter):
    """Build the MiniImagenet dataset for ``mode`` plus a shuffled loader over it.

    Returns:
        Tuple ``(dataset, loader)``.
    """
    dataset = MiniImagenet(path=root_path, N=ways, K=shot, Q=query,
                           mode=mode, total_iter=total_iter,
                           transform=True)
    loader = DataLoader(dataset, batch_size=args.batch_size,
                        shuffle=True, num_workers=1)
    return dataset, loader


# Training episodes use the train_* settings; validation and test share the
# test_* settings and differ only in split and `total_iter`.
train_dataset, train_loader = _episode_loader(
    'train', args.train_ways, args.shot, args.train_query, total_iter=25000)
val_dataset, val_loader = _episode_loader(
    'validation', args.test_ways, args.test_shot, args.test_query, total_iter=200)
test_dataset, test_loader = _episode_loader(
    'test', args.test_ways, args.test_shot, args.test_query, total_iter=2000)

100%|██████████| 25000/25000 [00:35<00:00, 711.70it/s]
100%|██████████| 200/200 [00:00<00:00, 2001.36it/s]
100%|██████████| 2000/2000 [00:00<00:00, 2691.19it/s]


In [7]:
# Use the GPU only when one is actually available; a hard-coded 'cuda' makes
# ``ConvNet().to(device)`` fail on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ConvNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate every 20 scheduler steps (the training loop steps
# the scheduler once per epoch).
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=20, gamma=0.5)

In [8]:
# sx, sy, qx, qy = task_batch

In [9]:
# Number of training episodes drawn in each epoch.
episodes_per_epoch = 100

for epoch in range(1, args.max_epoch + 1):
    torch.cuda.empty_cache()
    model.train()

    loss_ctr = 0
    n_loss = 0
    n_acc = 0

    # Create the loader iterator once per epoch rather than once per episode;
    # rebuilding it for every batch restarts the loader's worker each time.
    # `range` comes first in `zip`, so no extra batch is fetched when the
    # episode budget runs out.
    for _, (sx, sy, qx, qy) in zip(range(episodes_per_epoch), train_loader):
        # Support and query samples are concatenated along dim 1;
        # proto_train re-splits them after sorting by label.
        data = torch.cat((sx, qx), dim=1)
        labels = torch.cat((sy, qy), dim=1)
        batch = (data, labels)
        loss, acc = proto_train(model,
                                batch,
                                args.train_ways,
                                args.shot,
                                args.train_query,
                                metric=pairwise_distances_logits,
                                device=device)

        loss_ctr += 1
        n_loss += loss.item()
        # Accumulate a Python float so no tensor is kept alive per episode.
        n_acc += acc.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    lr_scheduler.step()

    print('epoch {}, train, loss={:.4f} acc={:.4f}'.format(
        epoch, n_loss/loss_ctr, n_acc/loss_ctr))

    model.eval()
    loss_ctr = 0
    n_loss = 0
    n_acc = 0
    # Validation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for sx, sy, qx, qy in val_loader:
            data = torch.cat((sx, qx), dim=1)
            labels = torch.cat((sy, qy), dim=1)
            batch = (data, labels)
            loss, acc = proto_train(model,
                                    batch,
                                    args.test_ways,
                                    args.test_shot,
                                    args.test_query,
                                    metric=pairwise_distances_logits,
                                    device=device)

            loss_ctr += 1
            n_loss += loss.item()
            n_acc += acc.item()

    print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(
        epoch, n_loss/loss_ctr, n_acc/loss_ctr))

loss_ctr = 0
n_acc = 0

# Put the model in inference mode explicitly instead of relying on whatever
# mode the previous loop left it in, and skip autograd during evaluation.
model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader, 1):
        sx, sy, qx, qy = batch
        # Support and query samples are concatenated along dim 1;
        # proto_train re-splits them after sorting by label.
        data = torch.cat((sx, qx), dim=1)
        labels = torch.cat((sy, qy), dim=1)
        batch = (data, labels)
        _, acc = proto_train(model,
                             batch,
                             args.test_ways,
                             args.test_shot,
                             args.test_query,
                             metric=pairwise_distances_logits,
                             device=device)
        acc = acc.item()
        loss_ctr += 1
        n_acc += acc
        # Running mean accuracy so far, followed by this episode's accuracy.
        print('batch {}: {:.2f}({:.2f})'.format(
            i, n_acc/loss_ctr * 100, acc * 100))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


epoch 1, train, loss=21.4716 acc=0.1090
epoch 1, val, loss=2.2749 acc=0.2695
epoch 2, train, loss=3.2944 acc=0.1064
epoch 2, val, loss=1.7410 acc=0.2777
epoch 3, train, loss=3.0188 acc=0.1142
epoch 3, val, loss=1.6169 acc=0.3013
epoch 4, train, loss=2.9225 acc=0.1297
epoch 4, val, loss=1.5767 acc=0.3173
epoch 5, train, loss=2.8752 acc=0.1376
epoch 5, val, loss=1.5513 acc=0.3249
epoch 6, train, loss=2.8512 acc=0.1419
epoch 6, val, loss=1.5514 acc=0.3328
epoch 7, train, loss=2.8202 acc=0.1502
epoch 7, val, loss=1.5160 acc=0.3433
epoch 8, train, loss=2.8233 acc=0.1440
epoch 8, val, loss=1.5025 acc=0.3522
epoch 9, train, loss=2.7951 acc=0.1533
epoch 9, val, loss=1.5004 acc=0.3583
epoch 10, train, loss=2.7790 acc=0.1551
epoch 10, val, loss=1.4783 acc=0.3631
epoch 11, train, loss=2.7474 acc=0.1742
epoch 11, val, loss=1.4786 acc=0.3661
epoch 12, train, loss=2.7601 acc=0.1666
epoch 12, val, loss=1.4873 acc=0.3693
epoch 13, train, loss=2.7089 acc=0.1810
epoch 13, val, loss=1.4572 acc=0.3813
epo