In [63]:
class PrototypeSampler(torch.utils.data.Sampler):

    def __init__(self, target, n_support, n_query, n_way, iterations) -> None:
        self.target = target
        self.classes = torch.unique(target)
        self.num_samples = n_support + n_query
        self.num_cls_per_it = n_way
        self.iterations = iterations

    def __iter__(self):

        for it in range(self.iterations):
            batch = []
            K = np.random.choice(self.classes, self.num_cls_per_it, replace=False)
            for i, cls in enumerate(K):
                idxs = self.target.eq(cls).nonzero().squeeze()
                sample_idxs = torch.randperm(len(idxs))[:self.num_samples]
                batch.append(idxs[sample_idxs])
            batch = torch.cat(batch)
            batch = batch[torch.randperm(len(batch))]
            yield batch

    def __len__(self):
        return self.iterations  

In [83]:
from model import *

In [84]:
proto = Proto((1,105,105), 30)

In [99]:
from torchvision.transforms import ToTensor
data = Omniglot(root="C:\\Users\\abdul\\Projects\\Self-Supervised-Learning\\Datasets\\", background=False, transform=ToTensor())
data

Dataset Omniglot
    Number of datapoints: 13180
    Root location: C:\Users\abdul\Projects\Self-Supervised-Learning\Datasets\omniglot-py
    StandardTransform
Transform: ToTensor()

In [100]:
targets = torch.tensor(list(map(lambda a: a[1], data._flat_character_images)))

In [101]:
target.shape

torch.Size([131])

In [103]:
sam = PrototypeSampler(targets, 5,5,3, 10)

In [104]:

loader = torch.utils.data.DataLoader(data, batch_sampler=sam)

In [105]:
l = iter(loader)
x,y = next(l)

In [106]:
x.shape, y.shape

(torch.Size([30, 1, 105, 105]), torch.Size([30]))

In [107]:
x_proto = proto.prototyper(x)

In [108]:
x_proto.shape

torch.Size([30, 32])

In [109]:
loss, accuracy = prototypical_loss(x_proto, y, 5, 5)

In [110]:
loss

tensor(-0.3900, grad_fn=<NegBackward0>)

In [111]:
accuracy

tensor(0.8667)

In [118]:
n_q, selected_classes = [], []
for cx in classes:
    _nq_x = y.eq(cx.item()).sum().item() - 5
    if _nq_x >= 0:
        n_q.append(_nq_x)
        selected_classes.append(cx.item())
    # n_query = min(n_q + [n_query])
n_query = min(n_q)

In [119]:
n_q

[5, 5, 5]

In [120]:
n_query

5

In [121]:
support_idxs = list(map(lambda c: y.eq(c).nonzero()[:5].squeeze(1), classes))

In [123]:
classes = torch.unique(y)
classes

tensor([136, 505, 541])

In [124]:
n_classes = len(classes)
n_classes

3

In [125]:
support_idxs = list(map(lambda c: y.eq(c).nonzero()[:5].squeeze(1), classes))

In [126]:
support_idxs

[tensor([ 2,  4,  8, 11, 17]),
 tensor([ 0,  1,  3,  9, 12]),
 tensor([ 5,  6,  7, 10, 13])]

In [127]:
prototypes = torch.stack([x_proto[idx_list].mean(0) for idx_list in support_idxs])

In [128]:
prototypes

tensor([[-0.0996,  0.4030,  0.3146, -1.3302, -0.0789,  0.4372,  0.0467,  0.9667,
         -0.4025, -0.3075,  0.1530,  0.3286,  0.1445, -0.6578, -0.0880,  0.8212,
         -0.2657, -0.7124, -0.1041, -0.6374, -0.0478,  0.0821, -0.4149,  0.2195,
         -0.8004,  0.9381,  0.4414,  0.2266, -0.6047, -0.0306, -1.2042,  0.4144],
        [-0.0846,  0.4042,  0.3182, -1.6517, -0.1027,  0.4310, -0.0366,  1.1268,
         -0.4266, -0.2413,  0.3230,  0.4847,  0.2782, -0.7807, -0.1066,  0.8597,
         -0.2857, -0.7635, -0.2026, -0.6301, -0.1435,  0.0330, -0.3067,  0.4153,
         -0.9599,  1.1893,  0.5344,  0.1853, -0.7977, -0.1586, -1.3060,  0.4588],
        [-0.0709,  0.3467,  0.2652, -1.6565, -0.0706,  0.4056, -0.0216,  1.1193,
         -0.4536, -0.2727,  0.2180,  0.4795,  0.2470, -0.8360, -0.1332,  0.9431,
         -0.3035, -0.7747, -0.2230, -0.6434, -0.1229,  0.0860, -0.2990,  0.3889,
         -0.9903,  1.2481,  0.5349,  0.2242, -0.8561, -0.1442, -1.2412,  0.4092]],
       grad_fn=<StackBac

In [129]:
prototypes.shape

torch.Size([3, 32])

In [131]:
query_idxs = torch.stack(list(map(lambda c: y.eq(c).nonzero()[5:], classes))).view(-1)
query_idxs

tensor([18, 26, 27, 28, 29, 14, 19, 22, 23, 25, 15, 16, 20, 21, 24])

In [132]:
query_idxs.shape

torch.Size([15])

In [134]:
query_samples = x_proto[query_idxs]
query_samples.shape

torch.Size([15, 32])

In [135]:
dists = euclidean_dist(query_samples, prototypes)

In [136]:
dists.shape

torch.Size([15, 3])

In [137]:
dists

tensor([[0.1905, 0.9696, 0.9681],
        [0.1411, 0.1980, 0.2551],
        [0.0820, 0.7769, 0.8305],
        [0.1365, 0.8987, 0.9146],
        [0.0662, 0.6543, 0.6948],
        [0.3661, 0.1109, 0.1767],
        [0.4857, 0.1620, 0.1723],
        [0.2910, 0.1346, 0.1793],
        [0.2240, 0.1319, 0.1757],
        [0.4685, 0.1464, 0.2019],
        [0.1203, 0.3598, 0.3304],
        [0.3266, 0.1344, 0.1274],
        [0.5145, 0.1121, 0.0846],
        [0.1650, 0.7946, 0.7341],
        [0.2051, 0.1829, 0.1821]], grad_fn=<SumBackward1>)

In [138]:
dists[0,:]

tensor([0.1905, 0.9696, 0.9681], grad_fn=<SliceBackward0>)

In [139]:
query_idxs[0]

tensor(18)

In [140]:
classes[0]

tensor(136)

In [141]:
classes

tensor([136, 505, 541])

In [142]:
y[18]

tensor(136)

In [144]:
F.softmax(-dists[0,:])

  F.softmax(-dists[0,:])


tensor([0.5213, 0.2392, 0.2395], grad_fn=<SoftmaxBackward0>)

In [145]:
F.softmax(-dists, dim=1)

tensor([[0.5213, 0.2392, 0.2395],
        [0.3525, 0.3330, 0.3145],
        [0.5070, 0.2531, 0.2399],
        [0.5192, 0.2423, 0.2385],
        [0.4788, 0.2659, 0.2553],
        [0.2858, 0.3689, 0.3454],
        [0.2667, 0.3686, 0.3648],
        [0.3042, 0.3557, 0.3401],
        [0.3179, 0.3485, 0.3336],
        [0.2713, 0.3744, 0.3542],
        [0.3850, 0.3030, 0.3120],
        [0.2913, 0.3531, 0.3556],
        [0.2480, 0.3708, 0.3812],
        [0.4765, 0.2538, 0.2697],
        [0.3283, 0.3357, 0.3360]], grad_fn=<SoftmaxBackward0>)

In [146]:
F.softmax(-dists, dim=1).shape

torch.Size([15, 3])

In [147]:
F.softmax(-dists, dim=1).view(3, n_query, -1)

tensor([[[0.5213, 0.2392, 0.2395],
         [0.3525, 0.3330, 0.3145],
         [0.5070, 0.2531, 0.2399],
         [0.5192, 0.2423, 0.2385],
         [0.4788, 0.2659, 0.2553]],

        [[0.2858, 0.3689, 0.3454],
         [0.2667, 0.3686, 0.3648],
         [0.3042, 0.3557, 0.3401],
         [0.3179, 0.3485, 0.3336],
         [0.2713, 0.3744, 0.3542]],

        [[0.3850, 0.3030, 0.3120],
         [0.2913, 0.3531, 0.3556],
         [0.2480, 0.3708, 0.3812],
         [0.4765, 0.2538, 0.2697],
         [0.3283, 0.3357, 0.3360]]], grad_fn=<ViewBackward0>)

In [148]:
log_py = F.softmax(-dists, dim=1).view(3, n_query, -1)

In [149]:
target_inds = torch.arange(0, n_classes).view(n_classes, 1, 1).expand(n_classes, n_query, 1).long()

In [150]:
target_inds

tensor([[[0],
         [0],
         [0],
         [0],
         [0]],

        [[1],
         [1],
         [1],
         [1],
         [1]],

        [[2],
         [2],
         [2],
         [2],
         [2]]])

In [152]:
log_py.gather(2, target_inds)

tensor([[[0.5213],
         [0.3525],
         [0.5070],
         [0.5192],
         [0.4788]],

        [[0.3689],
         [0.3686],
         [0.3557],
         [0.3485],
         [0.3744]],

        [[0.3120],
         [0.3556],
         [0.3812],
         [0.2697],
         [0.3360]]], grad_fn=<GatherBackward0>)

In [153]:
log_py

tensor([[[0.5213, 0.2392, 0.2395],
         [0.3525, 0.3330, 0.3145],
         [0.5070, 0.2531, 0.2399],
         [0.5192, 0.2423, 0.2385],
         [0.4788, 0.2659, 0.2553]],

        [[0.2858, 0.3689, 0.3454],
         [0.2667, 0.3686, 0.3648],
         [0.3042, 0.3557, 0.3401],
         [0.3179, 0.3485, 0.3336],
         [0.2713, 0.3744, 0.3542]],

        [[0.3850, 0.3030, 0.3120],
         [0.2913, 0.3531, 0.3556],
         [0.2480, 0.3708, 0.3812],
         [0.4765, 0.2538, 0.2697],
         [0.3283, 0.3357, 0.3360]]], grad_fn=<ViewBackward0>)

In [154]:
log_py.gather(2, target_inds).squeeze()

tensor([[0.5213, 0.3525, 0.5070, 0.5192, 0.4788],
        [0.3689, 0.3686, 0.3557, 0.3485, 0.3744],
        [0.3120, 0.3556, 0.3812, 0.2697, 0.3360]], grad_fn=<SqueezeBackward0>)

In [155]:
log_py.gather(2, target_inds).squeeze().view(-1)

tensor([0.5213, 0.3525, 0.5070, 0.5192, 0.4788, 0.3689, 0.3686, 0.3557, 0.3485,
        0.3744, 0.3120, 0.3556, 0.3812, 0.2697, 0.3360],
       grad_fn=<ViewBackward0>)

In [156]:
log_py.gather(2, target_inds).squeeze().view(-1).mean()

tensor(0.3900, grad_fn=<MeanBackward0>)

In [157]:
log_py.shape

torch.Size([3, 5, 3])

In [158]:
log_py.max(2)

torch.return_types.max(
values=tensor([[0.5213, 0.3525, 0.5070, 0.5192, 0.4788],
        [0.3689, 0.3686, 0.3557, 0.3485, 0.3744],
        [0.3850, 0.3556, 0.3812, 0.4765, 0.3360]], grad_fn=<MaxBackward0>),
indices=tensor([[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [0, 2, 2, 0, 2]]))

In [159]:
_, y_hat = log_py.max(2)

In [162]:
y_hat.shape, target_inds.shape

(torch.Size([3, 5]), torch.Size([3, 5, 1]))

In [163]:
y_hat.eq(target_inds.squeeze())

tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [False,  True,  True, False,  True]])

In [165]:
y_hat.eq(target_inds.squeeze()).float().mean()

tensor(0.8667)

In [167]:
acc= y_hat.eq(target_inds.squeeze()).float().mean()
acc

tensor(0.8667)