In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from utils.dataloader import MiniImagenet
from proto.protonet import ConvNet, distance, accuracy
from tqdm import tqdm

In [2]:
N = 5
K = 1
Q = 15
batch_size = 8
meta_lr=0.003
fast_lr=0.5
adaptation_steps=1

In [3]:
torch.manual_seed(777)
torch.cuda.manual_seed_all(777)
np.random.seed(777)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
root_path = './datasets/miniimagenet/pkl_file/' 
train_dataset = MiniImagenet(path=root_path, N=N, K=K, Q=Q, mode='train')
train_loader = DataLoader(train_dataset, batch_size=batch_size,\
                          shuffle=True, num_workers=1)
val_dataset = MiniImagenet(path=root_path, N=5, K=1, Q=Q,\
                           mode='validation', total_iter=1000)
val_loader = DataLoader(val_dataset, batch_size=batch_size,\
                        shuffle=True, num_workers=1)
test_dataset = MiniImagenet(path=root_path, N=5, K=1, Q=Q,\
                            mode='test', total_iter=5000)
test_loader = DataLoader(test_dataset, batch_size=batch_size,\
                         shuffle=True, num_workers=1)

100%|██████████| 60000/60000 [00:22<00:00, 2620.83it/s]
100%|██████████| 1000/1000 [00:00<00:00, 2827.42it/s]
100%|██████████| 5000/5000 [00:01<00:00, 2871.85it/s]


In [5]:
task_batch = next(iter(train_loader))

In [6]:
batch = list(zip(*task_batch))

In [7]:
device = 'cpu'
model = ConvNet().to(device)

In [8]:
sx, sy, qx, qy = batch[0]
sx.size(0)
sx.size()

torch.Size([5, 3, 84, 84])

In [9]:
import torch.nn as nn

In [10]:
criterion = nn.CrossEntropyLoss()

In [11]:
sx, sy, qx, qy = batch[0]
sx, sy, qx, qy = sx.to(device), sy.to(device), qx.to(device), qy.to(device)
NK = sx.size(0)
support_indices = torch.sort(sy)
query_indices = torch.sort(qy)
sy = sy[support_indices.indices]
qx = qx[query_indices.indices]
qy = qy[query_indices.indices]
data = torch.cat((sx,qx),dim=0)
labels = qy.long()
embeddings = model(data)
support = embeddings[:NK]
proto = support.reshape(5, 1, -1).mean(dim=1)
query = embeddings[NK:]
logits = distance(query, proto)
loss = criterion(logits, labels)
acc = accuracy(logits, labels)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [12]:
embeddings.size()

torch.Size([80, 1600])

In [13]:
targets = qy.long() 
predictions = logits.argmax(dim=1).view(targets.shape)

In [14]:
predictions

tensor([2, 0, 4, 0, 0, 4, 4, 1, 4, 2, 1, 2, 2, 2, 1, 4, 1, 3, 2, 1, 1, 3, 3, 1,
        0, 3, 2, 1, 2, 3, 2, 0, 4, 1, 1, 1, 1, 4, 2, 2, 1, 0, 1, 4, 4, 3, 3, 1,
        1, 3, 3, 1, 1, 4, 1, 1, 2, 1, 1, 1, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 1, 2,
        2, 0, 2])

In [15]:
(predictions == targets).sum().float() / targets.size(0)

tensor(0.2000)

In [16]:
qy.long()

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4])

In [17]:
qy.long() == predictions

tensor([False,  True, False,  True,  True, False, False, False, False, False,
        False, False, False, False, False, False,  True, False, False,  True,
         True, False, False,  True, False, False, False,  True, False, False,
         True, False, False, False, False, False, False, False,  True,  True,
        False, False, False, False, False,  True,  True, False, False,  True,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False])

In [75]:
qy.long().shape

torch.Size([75])

In [61]:
proto.shape

torch.Size([5, 1600])

In [34]:
def protnet_train(batch, model, metric, N, K, device):
    sx, sy, qx, qy = batch
    sx, sy, qx, qy = sx.to(device), sy.to(device), qx.to(device), qy.to(device)
    NK = N * K
    support_indices = torch.sort(sy)
    query_indices = torch.sort(qy)
    sx = sx[support_indices.indices]
    sy = sy[support_indices.indices]
    qx = qx[query_indices.indices]
    qy = qy[query_indices.indices]
    data = torch.cat((sx,qx),dim=0)
    labels = qy.long()
    embeddings = model(data)
    support = embeddings[:NK]
    query = embeddings[NK:]
    proto = support.reshape(N, K, -1).mean(dim=1)
    logits = metric(query, proto)
    loss = F.cross_entropy(logits, labels)
    acc = accuracy(logits, labels)
    return loss, acc
    

In [37]:
loss, acc = protnet_train(batch, model, metric=distance, device='cpu')

In [38]:
acc.item()

0.2133333384990692

In [9]:
model = ConvNet().to(device)

In [48]:
embeddings = model(data)
support = embeddings[:N*K]
query = embeddings[N*K:]
proto = support.reshape(N, K, -1).mean(dim=1)

In [49]:
proto

tensor([[0.5249, 1.0267, 0.5028,  ..., 0.8321, 1.7107, 2.2130],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.8166, 0.1553],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [1.1699, 1.1649, 0.0000,  ..., 0.4147, 1.3676, 0.9837],
        [0.0810, 0.0000, 0.0000,  ..., 0.0673, 0.7102, 0.3721]],
       grad_fn=<MeanBackward1>)

In [16]:
logits = distance(query, support)
labels = qy.long()

In [18]:
loss = F.cross_entropy(logits, labels)
acc = accuracy(logits, labels)

In [None]:
if metric is None:
    metric = pairwise_distances_logits
if device is None:
    device = model.device()
data, labels = batch
data = data.to(device)
labels = labels.to(device)
n_items = shot * ways

# Sort data samples by labels
# TODO: Can this be replaced by ConsecutiveLabels ?
sort = torch.sort(labels)
data = data.squeeze(0)[sort.indices].squeeze(0)
labels = labels.squeeze(0)[sort.indices].squeeze(0)

# Compute support and query embeddings
embeddings = model(data)
support_indices = np.zeros(data.size(0), dtype=bool)
selection = np.arange(ways) * (shot + query_num)
for offset in range(shot):
    support_indices[selection + offset] = True
query_indices = torch.from_numpy(~support_indices)
support_indices = torch.from_numpy(support_indices)
support = embeddings[support_indices]
support = support.reshape(ways, shot, -1).mean(dim=1)
query = embeddings[query_indices]
labels = labels[query_indices].long()

logits = pairwise_distances_logits(query, support)
loss = F.cross_entropy(logits, labels)
acc = accuracy(logits, labels)