# Deep Learning - Project 1
Wojciech Kutak

---

### Cross Attention Network for few-shot learning problem
#### 1. Code

In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
import torchvision
from torchinfo import summary
from tqdm import tqdm


DATA_PATH = os.path.join("..", "data")
FEW_SHOT_PATH = os.path.join("..", "data_few_shot")

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1, downsampler = None):
        super().__init__()
        self.downsampler = downsampler

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)
        self.out_channels = out_channels


    def forward(self, x):
        identity = x
        if self.downsampler is not None:
            identity = self.downsampler(identity)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out = out + identity
        out = self.relu(out)
        return out


class ResNet_32x32(nn.Module):
    def __init__(self, block: nn.Module = ResidualBlock):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, (3, 3), stride=1, padding=1)
        self.bnorm = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.res_layer1 = self._residual_layer(block, 64, 128, 2)

        self.max_pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.res_layer2 = self._residual_layer(block, 128, 256, 2)

        self.max_pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.res_layer3 = self._residual_layer(block, 256, 512, 2)

        self.output_shape = (512, 3, 3)


    def _residual_layer(self, block, in_channels, out_channels, blocks_num, stride=1):
        """Creates a residual layer consisting out of residual blocks"""
        downsampler = None
        if stride != 1 or in_channels != out_channels:
            downsampler = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        layers = []
        layers.append(block(in_channels, out_channels, stride, downsampler))
        for _ in range(blocks_num - 1):
            layers.append(block(out_channels, out_channels, stride))

        return nn.Sequential(*layers)


    def forward(self, x):
        x = self.relu(self.bnorm(self.conv(x)))

        x = self.res_layer1(self.max_pool1(x))
        x = self.res_layer2(self.max_pool2(x))
        x = self.res_layer3(self.max_pool3(x))

        return x

In [150]:
summary(ResNet_32x32(), (30, 3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet_32x32                             [30, 512, 3, 3]           --
├─Conv2d: 1-1                            [30, 64, 32, 32]          1,792
├─BatchNorm2d: 1-2                       [30, 64, 32, 32]          128
├─ReLU: 1-3                              [30, 64, 32, 32]          --
├─MaxPool2d: 1-4                         [30, 64, 15, 15]          --
├─Sequential: 1-5                        [30, 128, 15, 15]         --
│    └─ResidualBlock: 2-1                [30, 128, 15, 15]         --
│    │    └─Sequential: 3-1              [30, 128, 15, 15]         8,448
│    │    └─Conv2d: 3-2                  [30, 128, 15, 15]         73,856
│    │    └─BatchNorm2d: 3-3             [30, 128, 15, 15]         256
│    │    └─ReLU: 3-4                    [30, 128, 15, 15]         --
│    │    └─Conv2d: 3-5                  [30, 128, 15, 15]         147,584
│    │    └─BatchNorm2d: 3-6             [30, 128, 15, 15]         2

In [240]:
import math

class FusionLayer(nn.Module):
    def __init__(self, m: int, bottleneck_size: int):
        super(FusionLayer, self).__init__()

        self.temperature = 1.0
        self.m = m
        self.bottleneck_size = bottleneck_size
        self.conv1 = nn.Conv2d(self.m, self.bottleneck_size, kernel_size=1)
        self.bn = nn.BatchNorm2d(self.bottleneck_size)
        self.conv2 = nn.Conv2d(self.bottleneck_size, self.m, kernel_size=1)
        self.relu = nn.ReLU()

        # TODO:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))


    def forward(self, R: torch.Tensor):
        # print("\nFusionLayer forward")
        # print("R:", R.shape)

        # If R = R_p
        # Dim=3 corresponds to all queries and dim=4 to all support examples
        # We are averaging over all queries to extract info about the mean relevance of
        # certain pixel from support to all other queries
        # Otherwise it's the other way around
        w = torch.mean(R, dim=3)
        # print("w", w.shape)
        w = w.transpose(1, 3)
        # print("w", w.shape)

        w = self.conv1(w)
        # print("w", w.shape)
        w = self.bn(w)
        w = self.relu(w)
        w = self.conv2(w)
        # print("w", w.shape)
        w = w.transpose(1, 3).unsqueeze(3)
        # print("w", w.shape)

        # I guess... Unfortunately not well described in paper
        # I suppose it's a weighted average over query features
        A = torch.mean(w * R, dim=-1)
        A = F.softmax(A / self.temperature, dim=-1) + 1

        # print("A:", A.shape)
        return A

        # print("R shape:", R.shape)

        # b, M, m1, m2 = R.shape
        # assert m1 == m2

        # # Global average pooling
        # w: torch.Tensor = R.mean(dim=-2)
        # print("spatial w:", w.shape)
        # w = w.transpose(1, 3)
        # print("spatial w after:", w.shape)

        # # w = w.unsqueeze(-2)
        # # Meta learner
        # w = self.conv1(w)
        # w = self.bn(w)
        # print("conv1 w:", w.shape)
        # w = self.relu(w)
        # print("relu w:",w.shape)
        # w = self.conv2(w)
        # print("conv2 w:",w.shape)

        # w = w.transpose(1, 3)
        # print("w tra", w.shape)
        # w = w.squeeze((-2, -1))
        # print("w squeeze", w.shape)

        # # Convolution operation
        # # w is now a vector of average class
        # A = self.attention(w, R)
        # print("Attention:", A.shape)


        # return A


    # def attention(self, weights: torch.Tensor, R: torch.Tensor) -> torch.Tensor:

    #     weights_t = weights.transpose(-3, -1).squeeze(1)

    #     b, M, m, h, w = R.shape
    #     R = R.view(b * M, m, h*w)
    #     # print("weights:", weights.shape)
    #     # print("weights_t:", weights_t.shape)
    #     # print("R:", R.shape)

    #     R_mean = torch.bmm(weights_t, R) / self.temperature
    #     # print("R_mean:", R_mean.shape)
    #     R_mean = R_mean.view(b, M, 1, m)
    #     A = F.softmax(R_mean, dim=-1)
    #     return A

In [133]:
width, height = 5, 5
m = width*height
M = 5
batch_size = 3
summary(FusionLayer(m=m, bottleneck_size=15), (batch_size, M, m, height, width))

FusionLayer forward
R shape: torch.Size([3, 5, 25, 5, 5])
spatial w: torch.Size([3, 5, 25, 1, 1])
spatial w after: torch.Size([15, 25, 1, 1])
conv1 w: torch.Size([15, 15, 1, 1])
relu w: torch.Size([15, 15, 1, 1])
conv2 w: torch.Size([15, 25, 1, 1])
weights: torch.Size([15, 25, 1, 1])
weights_t: torch.Size([15, 1, 25])
R: torch.Size([15, 25, 25])
R_mean: torch.Size([15, 1, 25])
Attention: torch.Size([3, 5, 1, 25])


Layer (type:depth-idx)                   Output Shape              Param #
FusionLayer                              [3, 5, 1, 25]             --
├─AdaptiveAvgPool2d: 1-1                 [3, 5, 25, 1, 1]          --
├─Conv2d: 1-2                            [15, 15, 1, 1]            390
├─ReLU: 1-3                              [15, 15, 1, 1]            --
├─Conv2d: 1-4                            [15, 25, 1, 1]            400
Total params: 790
Trainable params: 790
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.01
Input size (MB): 0.04
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.05

In [59]:
a = torch.tensor([1.0, 2.0]).view(1, 2)
b = torch.tensor([[2.0, 3.0, 4.0], [10.0, 11.0, 112.0]]).view(1, 2, 3, 1)
a*b

tensor([[[[  2.,   4.],
          [  3.,   6.],
          [  4.,   8.]],

         [[ 10.,  20.],
          [ 11.,  22.],
          [112., 224.]]]])

In [239]:
class CrossAttentionModule(nn.Module):
    def __init__(self, input_shape: tuple[int,int,int]):
        super(CrossAttentionModule, self).__init__()

        # print("input shape", *input_shape)
        c, h, w = input_shape
        self.fusion_layer = FusionLayer(m=h*w, bottleneck_size=int(h*w*2/3))


    def forward(self, P: torch.Tensor, Q: torch.Tensor):
        # print("\nCross Attention forward")
        b, n_p, c, h, w = P.shape
        n_q = Q.shape[1]
        # print("P", P.shape)
        # print("Q", Q.shape)

        P = P.view(b, n_p, c, h*w)
        Q = Q.view(b, n_q, c, h*w)
        # print("P", P.shape)
        # print("Q", Q.shape)

        P_norm = F.normalize(P, p=2, dim=2)
        Q_norm = F.normalize(Q, p=2, dim=2)
        # print("P_norm", P_norm.shape)
        # print("Q_norm", Q_norm.shape)

        P_norm = P.transpose(2, 3).view(b, n_p, 1, h*w, c)
        Q_norm = Q.view(b, 1, n_q, c, h*w)
        # print("P_norm", P_norm.shape)
        # print("Q_norm", Q_norm.shape)

        # Now it will be broadcasted so all pairs of per batch will be miltiplied
        R_p = torch.matmul(P_norm, Q_norm)
        R_q = R_p.transpose(3, 4)
        # print("R_p", R_p.shape)
        # print("R_q", R_q.shape)

        A_p: torch.Tensor = self.fusion_layer(R_p)
        A_q: torch.Tensor = self.fusion_layer(R_q)

        # print("P", P.shape, "A_p", A_p.shape)
        P = P.unsqueeze(2)
        A_p = A_p.unsqueeze(3)
        # print("P", P.shape, "A_p", A_p.shape)
        P = P * A_p
        # print("P", P.shape)
        P = P.view(b, n_p, n_q, c, h, w)
        # print("P", P.shape)

        # print("Q", Q.shape, "A_q", A_q.shape)
        Q = Q.unsqueeze(1)
        A_q = A_q.unsqueeze(3)
        Q = Q * A_q
        Q = Q.view(b, n_p, n_q, c, h, w)
        # print("Q", Q.shape)

        return P.transpose(1, 2), Q.transpose(1, 2)

    # def forward(self, S: torch.Tensor):
        # P, Q = S[0], S[1, :, 0, :, :, :]
        # print("CrossAttentionModule forward")
        # print(P.shape, Q.shape)
        # assert P.shape == Q.shape
        # b, M, c, h, w = P.shape
        # assert (b, c, h, w) == Q.shape
        # m = h*w
        # # Change representation from tensor c*h*w to c*m (2 dims)
        # P = P.view(b, M, c, m)
        # Q = Q.view(b, 1, c, m)

        # P_norm = F.normalize(P, p=2, dim=2)
        # Q_norm = F.normalize(Q, p=2, dim=2)
        # print("P norm:", P_norm.shape)
        # P_norm_t = P_norm.transpose(-2, -1).unsqueeze(2)
        # Q_norm = Q_norm.unsqueeze(1)
        # print("P norm t:", P_norm_t.shape)
        # print("Q norm:", Q_norm.shape)
        # R_q = torch.matmul(P_norm_t, Q_norm)
        # print("R_q:", R_q.shape)
        # R_p = R_q.transpose(-2, -1)
        # R_q = R_q
        # print("R_q", R_q.shape)
        # print("R_p", R_p.shape)



        # A_p: torch.Tensor = self.fusion_layer(R_p)
        # print()
        # A_q: torch.Tensor = self.fusion_layer(R_q)
        # print()

        # print("P:", P.shape, "A_p:", A_p.shape)
        # P_feat = torch.mul(P, A_p) + P
        # print("P_feat:", P_feat.shape)

        # print("Q:", Q.shape, "A_q:", A_q.shape)
        # Q = Q.expand(b, M, c, m)
        # print("Q exp:", Q.shape, "A_q:", A_q.shape)
        # Q_feat = torch.mul(Q, A_q) + Q
        # print("Q_feat:", Q_feat.shape)
        # return P_feat, Q_feat

In [71]:
b = 1
M = 6
K = 5
c = 512
height, width = 5, 5
can = CrossAttentionModule((c, height, width))

support = torch.ones(M, K, c, height, width)
query = torch.ones(M, c, height, width)
can(support, query)

P norm: torch.Size([6, 5, 512, 25])
P norm t: torch.Size([6, 5, 1, 25, 512])
Q norm: torch.Size([6, 1, 1, 512, 25])
R_q: torch.Size([6, 5, 1, 25, 25])
R_q torch.Size([6, 5, 1, 25, 25])
R_p torch.Size([6, 5, 1, 25, 25])


ValueError: too many values to unpack (expected 4)

In [None]:
b = 3
m = 5



summary(CrossAttentionModule((512, 5, 5)), (2, b, m, 512, 5, 5))

CrossAttentionModule forward
torch.Size([3, 5, 512, 5, 5]) torch.Size([3, 512, 5, 5])
P norm: torch.Size([3, 5, 512, 25])
P norm t: torch.Size([3, 5, 25, 512])
Q norm: torch.Size([3, 1, 512, 25])
R_q torch.Size([3, 5, 25, 5, 5])
R_p torch.Size([3, 5, 25, 5, 5])
FusionLayer forward
R shape: torch.Size([3, 5, 25, 5, 5])
spatial w: torch.Size([3, 5, 25, 1, 1])
spatial w after: torch.Size([15, 25, 1, 1])
conv1 w: torch.Size([15, 15, 1, 1])
relu w: torch.Size([15, 15, 1, 1])
conv2 w: torch.Size([15, 25, 1, 1])
weights: torch.Size([15, 25, 1, 1])
weights_t: torch.Size([15, 1, 25])
R: torch.Size([15, 25, 25])
R_mean: torch.Size([15, 1, 25])
Attention: torch.Size([3, 5, 1, 25])

FusionLayer forward
R shape: torch.Size([3, 5, 25, 5, 5])
spatial w: torch.Size([3, 5, 25, 1, 1])
spatial w after: torch.Size([15, 25, 1, 1])
conv1 w: torch.Size([15, 15, 1, 1])
relu w: torch.Size([15, 15, 1, 1])
conv2 w: torch.Size([15, 25, 1, 1])
weights: torch.Size([15, 25, 1, 1])
weights_t: torch.Size([15, 1, 25])


Layer (type:depth-idx)                   Output Shape              Param #
CrossAttentionModule                     [3, 5, 512, 25]           --
├─FusionLayer: 1-1                       [3, 5, 1, 25]             --
│    └─AdaptiveAvgPool2d: 2-1            [3, 5, 25, 1, 1]          --
│    └─Conv2d: 2-2                       [15, 15, 1, 1]            390
│    └─ReLU: 2-3                         [15, 15, 1, 1]            --
│    └─Conv2d: 2-4                       [15, 25, 1, 1]            400
├─FusionLayer: 1-2                       [3, 5, 1, 25]             (recursive)
│    └─AdaptiveAvgPool2d: 2-5            [3, 5, 25, 1, 1]          --
│    └─Conv2d: 2-6                       [15, 15, 1, 1]            (recursive)
│    └─ReLU: 2-7                         [15, 15, 1, 1]            --
│    └─Conv2d: 2-8                       [15, 25, 1, 1]            (recursive)
Total params: 790
Trainable params: 790
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.02
Input size (MB): 1.54


In [36]:
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
torch.mul(A, B), torch.matmul(A, B)

(tensor([[ 1.,  4.],
         [ 9., 16.]]),
 tensor([[ 7., 10.],
         [15., 22.]]))

In [108]:
A = torch.tensor([[1.0, 2.0], [3., 4.]])
B = torch.tensor([[7., 11.]])
A*B

tensor([[ 7., 22.],
        [21., 44.]])

In [256]:
class CrossAttentionNetwork(nn.Module):
    def __init__(self, cam: CrossAttentionModule = None, n_classes: int = 6):
        super(CrossAttentionNetwork, self).__init__()

        self.embedding = ResNet_32x32()
        self.cam = cam if cam is not None else CrossAttentionModule(self.embedding.output_shape)
        self.classifer = nn.Conv2d(self.embedding.output_shape[0], n_classes, kernel_size=1)
        self.n_classes = n_classes


    def forward(self,
                support: torch.Tensor, query: torch.Tensor,
                y_support: torch.Tensor, y_query: torch.Tensor
    ):
        # n_support = number of classes * number of examples per class
        b, n_support, c, h, w = support.shape

        n_y_support_classes = y_support.shape[-1]

        # n_query = number of queries to classify
        n_query = query.size(1)

        assert support.shape[0] == query.shape[0] and \
            support.shape[2:] == query.shape[2:]

        # Embed all support and query images at once
        support = support.view(-1, c, h, w)
        query = query.view(-1, c, h, w)
        X = torch.cat((support, query), dim=0)
        # E is of shape (b*(n_support + n_query), c, h, w)
        E: torch.Tensor = self.embedding(X)
        # print("E", E.shape)

        E_support = E[:b*n_support]
        # Flattening images from (c,h,w) to (c*h*w)
        E_support = E_support.view(b, n_support, -1)
        # print("E_support", E_support.shape)
        y_support = y_support.transpose(1, 2)
        # print("y_supp", y_support.shape)

        # Smart way of averaging embedding vectors, generalized for unbalanced number of
        # support images wrt to class
        E_support_mean = torch.bmm(y_support, E_support)
        # print("y sum", torch.sum(y_support, dim=2, keepdim=True).shape)
        y_support_class_count = torch.sum(y_support, dim=2, keepdim=True).expand_as(E_support_mean)
        # print("y_support_class_count", y_support_class_count.shape)
        E_support_mean = torch.div(E_support_mean, y_support_class_count)
        # print("E_support mean", E_support_mean.shape)
        # print(*E.shape[1:])
        E_support_mean = E_support_mean.view(b, n_y_support_classes, *E.shape[1:])

        E_query = E[b*n_support:]
        E_query = E_query.view(b, n_query, *E.shape[1:])

        # Embedding with saturated features which are common between
        # each support class and query sample
        # E_support_mean/E_query: (b, n_query, n_support, c, h, w)
        E_support_mean, E_query = self.cam(E_support_mean, E_query)
        # print("\nE_support_mean:", E_support_mean.shape)
        # print("E_query:", E_query.shape)

        # Computing mean activations across height and width of image (means of active maps)
        E_support_mean = torch.mean(E_support_mean, dim=(-2, -1))
        # E_query = torch.mean(E_query, dim=(-2, -1))
        # print("E_support_mean:", E_support_mean.shape)
        # print("E_query:", E_query.shape)

        if self.training:
            return self._classification(E_support_mean, E_query, y_query)
        else:
            return self._test(E_support_mean, E_query)


    def _classification(self, P: torch.Tensor, Q: torch.Tensor, y_query: torch.Tensor):
        # print("\n_classification")
        # Nearest neighbor classification
        # Compute cosine distance
        P_norm = F.normalize(P, p=2, dim=3)
        Q_norm  = F.normalize(Q, p=2, dim=3)
        # print("P_norm", P_norm.shape)
        # print("Q", Q.shape)

        P_norm = P_norm.unsqueeze(-1)
        P_norm = P_norm.unsqueeze(-1)
        # print("P_norm", P_norm.shape)
        # print("Q_norm", Q_norm.shape)


        similarity_score = torch.sum(P_norm * Q_norm, dim=3)
        # similarity_score = similarity_score.view(-1, *similarity_score.shape[2:])
        # print("similarity_score", similarity_score.shape)

        # Global classification
        b, n_query, n_support, c, h, w = Q.shape
        # print("\nQ", Q.size())
        Q = Q.reshape(b, n_query, self.n_classes, -1)
        Q = Q.transpose(2, 3)
        # print("Q", Q.size())
        y_query = y_query.unsqueeze(-1)
        # print("y_query:", y_query.shape)
        y_pred = torch.matmul(Q, y_query)
        # print("y_pred:", y_pred.shape)
        y_pred = y_pred.view(b*n_query, c, h, w)
        # print("y_pred:", y_pred.shape)
        y_pred = self.classifer(y_pred)
        # print("y_pred:", y_pred.shape)
        y_pred = y_pred.view(b, n_query, self.n_classes, h, w)

        return y_pred, similarity_score


    def _test(self, P: torch.Tensor, Q: torch.Tensor):
        # Global average pooling over spatial dimensions
        Q = Q.mean(-1)
        Q = Q.mean(-1)

        P_norm = F.normalize(P, p=2, dim=-1)
        Q_norm = F.normalize(Q, p=2, dim=-1)
        similarity_scores = torch.sum(P_norm*Q_norm, dim=-1)
        return similarity_scores


In [257]:
b = 1
n_way, n_shot = 5, 5
c, height, width = 3, 32, 32

n_query = 15

support = torch.ones((b, n_way*n_shot, c, height, width))
y_support = torch.randint(0, n_way, (b, n_way*n_shot))
y_support_ohe = F.one_hot(y_support, n_way).type(torch.float)

query = torch.randn((b, n_query, c, height, width))
y_query = torch.randint(0, n_way, (b, n_query))
y_query_ohe = F.one_hot(y_query, n_way).type(torch.float)


can2test = CrossAttentionNetwork(n_classes=n_way)
y_pred2, sim_score2 = can2test(support, query, y_support_ohe, y_query_ohe)
y_pred2.shape, sim_score2.shape

(torch.Size([1, 15, 5, 3, 3]), torch.Size([1, 15, 5, 3, 3]))

In [223]:
b = 1
n_way, n_shot = 5, 5
c, height, width = 3, 32, 32

n_query = 15

support = torch.ones((b, n_way*n_shot, c, height, width))
y_support = torch.randint(0, n_way, (b, n_way*n_shot))
y_support_ohe = F.one_hot(y_support, n_way).type(torch.float)

query = torch.randn((b, n_query, c, height, width))
y_query = torch.randint(0, n_way, (b, n_query))
y_query_ohe = F.one_hot(y_query, n_way).type(torch.float)
# y_support = torch.randint(0, n_way, (b, n_query))
# y_support_ohe = F.one_hot(y_support, n_way)


can2test = CrossAttentionNetwork(n_classes=n_way)
can2test(support, query, y_support_ohe, y_query_ohe)

E torch.Size([40, 512, 3, 3])
E_support torch.Size([1, 25, 4608])
y_supp torch.Size([1, 5, 25])
y sum torch.Size([1, 5, 1])
y_support_class_count torch.Size([1, 5, 4608])
E_support mean torch.Size([1, 5, 4608])
512 3 3

Cross Attention forward
P torch.Size([1, 5, 512, 3, 3])
Q torch.Size([1, 15, 512, 3, 3])
P torch.Size([1, 5, 512, 9])
Q torch.Size([1, 15, 512, 9])
P_norm torch.Size([1, 5, 512, 9])
Q_norm torch.Size([1, 15, 512, 9])
P_norm torch.Size([1, 5, 1, 9, 512])
Q_norm torch.Size([1, 1, 15, 512, 9])
R_p torch.Size([1, 5, 15, 9, 9])
R_q torch.Size([1, 5, 15, 9, 9])

FusionLayer forward
R: torch.Size([1, 5, 15, 9, 9])
w torch.Size([1, 5, 15, 9])
w torch.Size([1, 9, 15, 5])
w torch.Size([1, 6, 15, 5])
w torch.Size([1, 9, 15, 5])
w torch.Size([1, 5, 15, 1, 9])
A: torch.Size([1, 5, 15, 9])

FusionLayer forward
R: torch.Size([1, 5, 15, 9, 9])
w torch.Size([1, 5, 15, 9])
w torch.Size([1, 9, 15, 5])
w torch.Size([1, 6, 15, 5])
w torch.Size([1, 9, 15, 5])
w torch.Size([1, 5, 15, 1, 9])
A

(tensor([[[[-4.5964e-01, -6.7759e-01, -1.9571e+00],
           [-1.5642e+00, -1.6503e+00, -2.4744e+00],
           [-4.6587e+00, -3.3465e+00, -1.6853e+00]],
 
          [[-1.8017e+00, -1.5339e+00, -4.1703e-02],
           [-1.7650e+00, -2.0061e+00, -1.2024e+00],
           [-3.2324e+00, -2.1887e+00, -9.3838e-01]],
 
          [[ 3.9101e-01,  5.7364e-01, -2.7385e-02],
           [-9.0061e-01, -1.0135e+00,  7.5178e-01],
           [-5.4327e-01, -1.7640e-01, -6.1134e-01]],
 
          [[ 9.8394e-01,  7.5506e-01,  9.8741e-01],
           [ 6.5876e-01,  4.0617e-01, -1.1633e+00],
           [-1.6457e+00,  5.9972e-02,  1.9820e-01]],
 
          [[-4.3297e-01,  1.4053e+00,  5.0144e-01],
           [-3.4192e-01, -3.5208e-02,  1.0980e+00],
           [-7.9379e-01, -1.0404e+00,  8.9025e-01]]],
 
 
         [[[-7.4148e-01, -1.0890e+00, -2.5846e-01],
           [-4.8078e-01, -8.7759e-01, -2.4776e+00],
           [-7.5128e-01, -6.8779e-01, -1.2025e+00]],
 
          [[-2.8457e+00, -1.1334e+00, -8.65

In [84]:
a = torch.ones(2, 1, 2, 2)
b = torch.ones(1, 2, 2, 2)
a[0] = a[0] * 2
a[1] = a[1] * 3
b[:,0] = b[:,0] * 5
b[:,1] = b[:,1] * 7
print(a,b)
print(torch.matmul(a,b).shape)
torch.matmul(a,b)

tensor([[[[2., 2.],
          [2., 2.]]],


        [[[3., 3.],
          [3., 3.]]]]) tensor([[[[5., 5.],
          [5., 5.]],

         [[7., 7.],
          [7., 7.]]]])
torch.Size([2, 2, 2, 2])


tensor([[[[20., 20.],
          [20., 20.]],

         [[28., 28.],
          [28., 28.]]],


        [[[30., 30.],
          [30., 30.]],

         [[42., 42.],
          [42., 42.]]]])

In [163]:
summary(CrossAttentionNetwork(), (2, 32, 5, 5, 3, 32, 32))

input shape 512 3 3
Support:  torch.Size([32, 5, 5, 3, 32, 32])
query:  torch.Size([32, 3, 32, 32])
CrossAttentionModule forward
torch.Size([32, 5, 512, 3, 3]) torch.Size([32, 512, 3, 3])
P norm: torch.Size([32, 5, 512, 9])
P norm t: torch.Size([32, 5, 9, 512])
Q norm: torch.Size([32, 1, 512, 9])
R_q torch.Size([32, 5, 9, 3, 3])
R_p torch.Size([32, 5, 9, 3, 3])
FusionLayer forward
R shape: torch.Size([32, 5, 9, 3, 3])
spatial w: torch.Size([32, 5, 9, 1, 1])
spatial w after: torch.Size([160, 9, 1, 1])
conv1 w: torch.Size([160, 6, 1, 1])
relu w: torch.Size([160, 6, 1, 1])
conv2 w: torch.Size([160, 9, 1, 1])
weights: torch.Size([160, 9, 1, 1])
weights_t: torch.Size([160, 1, 9])
R: torch.Size([160, 9, 9])
R_mean: torch.Size([160, 1, 9])
Attention: torch.Size([32, 5, 1, 9])

FusionLayer forward
R shape: torch.Size([32, 5, 9, 3, 3])
spatial w: torch.Size([32, 5, 9, 1, 1])
spatial w after: torch.Size([160, 9, 1, 1])
conv1 w: torch.Size([160, 6, 1, 1])
relu w: torch.Size([160, 6, 1, 1])
conv2 

Layer (type:depth-idx)                        Output Shape              Param #
CrossAttentionNetwork                         [32, 5, 512, 9]           --
├─ResNet_32x32: 1-1                           [800, 512, 3, 3]          --
│    └─Conv2d: 2-1                            [800, 64, 32, 32]         1,792
│    └─BatchNorm2d: 2-2                       [800, 64, 32, 32]         128
│    └─ReLU: 2-3                              [800, 64, 32, 32]         --
│    └─MaxPool2d: 2-4                         [800, 64, 15, 15]         --
│    └─Sequential: 2-5                        [800, 128, 15, 15]        --
│    │    └─ResidualBlock: 3-1                [800, 128, 15, 15]        230,400
│    │    └─ResidualBlock: 3-2                [800, 128, 15, 15]        295,680
│    └─MaxPool2d: 2-6                         [800, 128, 7, 7]          --
│    └─Sequential: 2-7                        [800, 256, 7, 7]          --
│    │    └─ResidualBlock: 3-3                [800, 256, 7, 7]          919,552
│

In [31]:
class CANLoss(nn.Module):
    def __init__(self, can: CrossAttentionNetwork, lamb: float = 0.5, n_classes: int = 5):
        super(CANLoss, self).__init__()

        self.can = can
        self.lamb = lamb
        self.classifier = nn.LazyConv1d(out_channels=n_classes, kernel_size=1)


    def cosine_dist(self, P_gap: torch.Tensor, Q_features: torch.Tensor) -> torch.Tensor:
        # print("\ncosine_dist")
        # print("Q_features",  Q_features.shape)
        # Q_features: (b, c, m)
        b, c, m = Q_features.shape
        # print("P_gap",  P_gap.shape)
        # P_gap: (b, c)
        P_gap_exp = P_gap.expand(m, b, c).transpose(0, 1).transpose(1, 2)
        # print("P_gap_exp", P_gap_exp.shape)
        cos_dist = F.cosine_similarity(P_gap_exp, Q_features, dim=-2)
        # print("cos_dist", cos_dist.shape)
        return cos_dist


    def L1_loss(self, P_sig: torch.Tensor, Q_sig: torch.Tensor, y_true: torch.Tensor):
        # P_features: (b, c, m)
        # Q_features: (b, c, m)
        # print("l1 loss")
        # print("P_features:", P_sig.shape)
        # print("Q_features:", Q_sig.shape)


        P_gap = torch.mean(P_sig, dim=-1)
        # P_gap: (b, c)
        distances = self.cosine_dist(P_gap, Q_sig)
        # distances: (b, m)
        # print("Distances:", distances.shape)

        # print("distances:", distances)
        likelihoods = -torch.log(F.softmax(-distances, dim=-1))
        # print("likelihoods", likelihoods)
        # print("likelihoods", likelihoods.shape)
        l1 = torch.sum(likelihoods)
        # print("l1", l1)
        return l1


    def L2_loss(self, Q_sig: torch.Tensor, y_true: torch.Tensor):
        # print("l2 loss")
        # print("Q_sig:", Q_sig.shape)
        # print()

        Z: torch.Tensor = self.classifier(Q_sig)
        # print("Z", Z.shape)
        Z_significant = []
        for y, z in zip(y_true, Z):
            Z_significant.append(z[y])
        Z_significant = torch.concat(Z_significant)
        # print("Z_significant:", Z_significant.shape)

        y_pred = -torch.log(F.softmax(Z_significant, dim=-1))
        # print("y_pred", y_pred.shape)
        l2 = torch.sum(y_pred)
        # print("l2:", l2.shape)
        return l2


    def forward(self, support: torch.Tensor, query: torch.Tensor, y_true: torch.Tensor):
        # print("support:", support.shape, "query:", query.shape, "y_true:", y_true.shape)
        P_features, Q_features = self.can(support, query)
        # print("P_features:", P_features.shape, "Q_features:", Q_features.shape)
        Q_significant = []
        P_significant = []
        for y, q, p in zip(y_true, Q_features, P_features):
            Q_significant.append(q[y])
            P_significant.append(p[y])
        Q_significant = torch.concatenate(Q_significant, dim=0)
        P_significant = torch.concatenate(P_significant, dim=0)
        # print("P_significant:", P_significant.shape, "Q_significant:", Q_significant.shape)
        l1 = self.L1_loss(P_significant, Q_significant, y_true)
        # print("L1 loss", l1.item())
        l2 = self.L2_loss(Q_significant, y_true)
        # print("L2 loss", l2.item())
        loss = self.lamb*l1 + l2
        return loss

In [17]:
b = 1
M = 6
K = 5
c = 3
height, width = 32, 32
can = CrossAttentionNetwork()
criterion = CANLoss(can, n_classes=M)

support = torch.ones(M, M, K, c, height, width)
query = torch.ones(M, c, height, width)
y_true = torch.randint(0, 5, (M, 1))
loss = criterion(support, query, y_true)
loss.backward()

support: torch.Size([6, 6, 5, 3, 32, 32]) query: torch.Size([6, 3, 32, 32]) y_true: torch.Size([6, 1])
P_features: torch.Size([6, 6, 512, 9]) Q_features: torch.Size([6, 6, 512, 9])
P_significant: torch.Size([6, 512, 9]) Q_significant: torch.Size([6, 512, 9])


In [214]:
a = torch.arange(30).view(3, 5, 2)
idx = torch.tensor([0, 1])
a, a[idx].shape, idx.shape

(tensor([[[ 0,  1],
          [ 2,  3],
          [ 4,  5],
          [ 6,  7],
          [ 8,  9]],
 
         [[10, 11],
          [12, 13],
          [14, 15],
          [16, 17],
          [18, 19]],
 
         [[20, 21],
          [22, 23],
          [24, 25],
          [26, 27],
          [28, 29]]]),
 torch.Size([2, 5, 2]),
 torch.Size([2]))

In [257]:
b = 32
M = 5
K = 5
c = 3
height, width = 32, 32
can = CrossAttentionNetwork().cuda()
criterion = CANLoss(can, n_classes=M).cuda()

support = torch.ones(b, M, K, c, height, width).cuda()
query = torch.ones(b, c, height, width).cuda()
y_true = torch.randint(0, 5, (b, 1)).cuda()
loss = criterion(support, query, y_true)
loss.backward()

Utils

In [224]:
import numpy as np
import torch
import random


def set_seed(seed: int):
    random.seed(seed)
    np.random.RandomState(seed)
    torch.manual_seed(seed)

def get_device():
    return "cuda:0" if torch.cuda.device_count() > 0 else 'cpu'


Data prep

In [225]:
from torch.utils.data import Sampler
import random

class FewShotSampler(Sampler):
    def __init__(self, data_source, n_shots: int, shuffle: bool = True):
        self.data_source = data_source
        self.n_shots = n_shots
        self.classes = self._get_classes()
        self.shuffle = shuffle


    def _get_classes(self):
        classes = dict()
        for i, (_, label) in enumerate(self.data_source):
            if label not in classes:
                classes[label] = []
            classes[label].append(i)

        return classes


    def __iter__(self):
        indices = []
        n = len(self.data_source)
        n_classes = len(self.classes)
        episodes = n // ((self.n_shots + 1) * n_classes)

        for i in range(episodes):
            query = []
            for c in self.classes:
                class_ind = self.classes[c]
                lower = i*(self.n_shots + 1)
                upper = (i+1)*(self.n_shots + 1) - 1
                support_one_class = class_ind[lower : upper]
                query.append(class_ind[upper])
                if self.shuffle:
                    random.shuffle(support_one_class)
                indices.extend(support_one_class)
            if self.shuffle:
                random.shuffle(query)
            indices.extend(query)
        return iter(indices)

    def __len__(self):
        return len(self.data_source)

In [226]:
import os
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
import torchvision
from torchvision import transforms
import shutil


CINIC_MEAN = [0.47889522, 0.47227842, 0.43047404]
CINIC_STD = [0.24205776, 0.23828046, 0.25874835]


def prepare_data_folder(data_path: str, few_shot_path: str):
    subsets = ["train", "valid", "test"]
    classes = {"train": [ 'bird', 'cat', 'deer', 'dog', 'frog', 'horse'],
               "valid": ['airplane','automobile'],
               "test": [ 'ship','truck']
    }
    [os.makedirs(os.path.join(few_shot_path, sub), exist_ok=True) for sub in subsets]
    [[os.makedirs(os.path.join(few_shot_path, sub, c), exist_ok=True) for c in classes[sub]]
     for sub in subsets]
    for sub in subsets:
        old_sub_path = os.path.join(data_path, sub)
        new_sub_path = os.path.join(few_shot_path, sub)
        for c in classes[sub]:
            old_dir = os.path.join(old_sub_path, c)
            new_dir = os.path.join(new_sub_path, c)
            imgs = os.listdir(old_dir)
            random.shuffle(imgs)
            # 600 support imgs -> 120 batches of 5-way
            # 120 imgs for query
            # 720 in total
            new_imgs = imgs[:720]
            [shutil.copyfile(os.path.join(old_dir, img), os.path.join(new_dir, img))
             for img in new_imgs]



def get_dataset(
        path: str,
        batch_size: int,
        n_shots: int,
        shuffle: bool,
        use_augmentations: bool,
) -> DataLoader:
    augmentations = ([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomErasing()
    ] if use_augmentations else [])
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=CINIC_MEAN,std=CINIC_STD),
                                    *augmentations])

    ds = torchvision.datasets.ImageFolder(path, transform=transform)
    sampler = FewShotSampler(ds, n_shots, shuffle=shuffle)
    n_way = len(ds.classes)
    episode_size = batch_size * n_way * (n_shots + 1)
    loader = DataLoader(ds, sampler=sampler, batch_size=episode_size, num_workers=2, pin_memory=True)
    return loader


def get_cinic_few(
        data_path: str,
        batch_size: int,
        n_shots: int,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    train_path = os.path.join(data_path, "train")
    valid_path = os.path.join(data_path, "valid")
    test_path = os.path.join(data_path, "test")

    cinic_train = get_dataset(train_path, batch_size, n_shots, True, True)
    cinic_validation = get_dataset(valid_path, batch_size, n_shots, False, False)
    cinic_test = get_dataset(test_path, batch_size, n_shots, False, False)

    return cinic_train, cinic_validation, cinic_test


In [None]:
# prepare_data_folder(DATA_PATH, FEW_SHOT_PATH)

In [41]:
n_way = 2
n_shots = 5
ds = torchvision.datasets.ImageFolder(os.path.join(FEW_SHOT_PATH, "test"), transform=transforms.ToTensor())
print(len(ds))
sampler = FewShotSampler(ds, n_shots)
loader = DataLoader(ds, sampler=sampler, batch_size=(n_way * n_shots + n_way))

for samples, labels in loader:
    support = samples[:n_way * n_shots]
    support_labels = labels[:n_way * n_shots]
    support_labels = support_labels.view(1, n_way, n_shots)
    print(support_labels)
    break
    # print(samples.shape, label.shape)


1440
0 : 720
1 : 720
1440
1440 120
0 0 (l, u): 0 5
support: 5
0 1 (l, u): 0 5
support: 5
query; 2
1 0 (l, u): 6 11
support: 5
1 1 (l, u): 6 11
support: 5
query; 2
2 0 (l, u): 12 17
support: 5
2 1 (l, u): 12 17
support: 5
query; 2
3 0 (l, u): 18 23
support: 5
3 1 (l, u): 18 23
support: 5
query; 2
4 0 (l, u): 24 29
support: 5
4 1 (l, u): 24 29
support: 5
query; 2
5 0 (l, u): 30 35
support: 5
5 1 (l, u): 30 35
support: 5
query; 2
6 0 (l, u): 36 41
support: 5
6 1 (l, u): 36 41
support: 5
query; 2
7 0 (l, u): 42 47
support: 5
7 1 (l, u): 42 47
support: 5
query; 2
8 0 (l, u): 48 53
support: 5
8 1 (l, u): 48 53
support: 5
query; 2
9 0 (l, u): 54 59
support: 5
9 1 (l, u): 54 59
support: 5
query; 2
10 0 (l, u): 60 65
support: 5
10 1 (l, u): 60 65
support: 5
query; 2
11 0 (l, u): 66 71
support: 5
11 1 (l, u): 66 71
support: 5
query; 2
12 0 (l, u): 72 77
support: 5
12 1 (l, u): 72 77
support: 5
query; 2
13 0 (l, u): 78 83
support: 5
13 1 (l, u): 78 83
support: 5
query; 2
14 0 (l, u): 84 89
suppor

In [89]:
ds.classes

['ship', 'truck']

In [50]:
set_seed(1)
ds_train, ds_valid, ds_test = get_cinic_few(FEW_SHOT_PATH, 32, 5)

[tensor([[[[-0.0343, -0.0505, -0.0505,  ...,  0.0000,  0.0000,  0.0000],
           [-0.0181, -0.0505, -0.0505,  ...,  0.0000,  0.0000,  0.0000],
           [-0.0019, -0.0181, -0.0343,  ...,  0.0000,  0.0000,  0.0000],
           ...,
           [-0.7472, -0.6824, -0.0505,  ...,  0.0000,  0.0000,  0.0000],
           [-0.6986, -0.6014, -0.6662,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
          [[ 0.1246,  0.1081,  0.1081,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.1739,  0.1410,  0.1410,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.2233,  0.1904,  0.1739,  ...,  0.0000,  0.0000,  0.0000],
           ...,
           [-0.7642, -0.6983, -0.0071,  ...,  0.0000,  0.0000,  0.0000],
           [-0.5996, -0.5008, -0.5831,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
          [[-0.6785, -0.6937, -0.6937,  ...,  0.0000,  0.0000,  0.0000],
           [-

In [54]:
samples, labels = next(iter(ds_train))
print(labels.tolist())

[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 

In [95]:
ds_train.dataset.classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [263]:
class SpecialCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(SpecialCrossEntropyLoss, self).__init__()
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        # inputs: (b, query_size, n_classes, h, w)
        # inputs: (b, query_size)
        inputs = inputs.view(inputs.size(0)*inputs.size(1), inputs.size(2), -1)
        targets = targets.view(-1)

        log_probs = self.logsoftmax(inputs)
        print(targets)
        targets = torch.zeros(inputs.size(0), inputs.size(1)).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
        print(targets)
        targets = targets.unsqueeze(-1)
        targets = targets.cuda()
        print("log_probs:", log_probs.shape, "targets:", targets.shape)
        loss = (-targets * log_probs).mean(0).sum()
        print(-targets*log_probs)
        return loss / inputs.size(2)

### Training time
---

In [227]:
n_way = 6
n_shots = 5
c, h, w = 3, 32, 32

set_seed(1)
ds_train, ds_valid, ds_test = get_cinic_few(FEW_SHOT_PATH, 1, n_shots)

In [236]:
def split_support_query(
        episode_samples: torch.Tensor,
        n_way: int,
        n_shots: int
) -> tuple[torch.Tensor, torch.Tensor]:
    episode_samples = episode_samples
    query = episode_samples[n_way * n_shots:].unsqueeze(0)
    support = episode_samples[:n_way * n_shots].unsqueeze(0)
    return support, query


Warm start

In [264]:
from torch.nn import CrossEntropyLoss

n_way = 6
n_shots = 5

can = CrossAttentionNetwork().cuda()
criterion = SpecialCrossEntropyLoss()
optimizer_params = {
    "lr": 0.1,
}
optimizer = SGD(can.parameters(), **optimizer_params, )
epochs = 25

for epoch in range(epochs):
    losses = []
    for episode_samples, episode_labels in tqdm(ds_train):
        optimizer.zero_grad()
        # print(episode_samples.shape)
        # print(episode_labels.shape)
        support, query = split_support_query(episode_samples, n_way, n_shots)
        support: torch.Tensor = support.cuda()
        y_support = episode_labels[:n_way * n_shots].unsqueeze(0)
        y_support_ohe: torch.Tensor = F.one_hot(y_support, n_way).float().cuda()

        y_query = episode_labels[n_way * n_shots:].unsqueeze(0)
        y_query_ohe: torch.Tensor =  F.one_hot(y_query,n_way).float().cuda()
        query: torch.Tensor = query.cuda()

        # print("SUP:", support.shape)
        # print("QUERY:", query.shape)
        # print("Y_SUP:", y_support_ohe.shape)
        # print("Y_QUERY:", y_query_ohe.shape)

        y_pred, cls_score = can(support, query, y_support_ohe, y_query_ohe)
        # print("y_pred", y_pred.shape, "cls_score", cls_score.shape)
        loss1 = criterion(y_pred, y_query.view(-1))
        loss2 = criterion(cls_score, y_query.view(-1))
        # print("l1:", loss1, "l2: ", loss2.item)
        loss = loss1 + 0.5 * loss2
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        break
    print(f"Epoch {epoch+1}, loss: {torch.tensor(losses).mean()}")
    print("Losses", losses)
    break

  0%|          | 0/120 [00:00<?, ?it/s]

tensor([0, 4, 3, 1, 5, 2])
tensor([[1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0.]])
log_probs: torch.Size([6, 6, 9]) targets: torch.Size([6, 6, 1])
tensor([[[1.4144, 1.8716, 2.0096, 1.5667, 1.6312, 1.9034, 1.2339, 1.1492,
          1.6190],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0

  0%|          | 0/120 [00:07<?, ?it/s]

Epoch 1, loss: 2.8213820457458496
Losses [2.8213820457458496]





In [252]:
from torch.nn import CrossEntropyLoss

can = CrossAttentionNetwork().cuda()
criterion = SpecialCrossEntropyLoss()
optimizer_params = {
    "lr": 0.1,
}
optimizer = SGD(can.parameters(), **optimizer_params, )
epochs = 25

for epoch in range(epochs):
    losses = []
    for episode_samples, episode_labels in tqdm(ds_train):
        optimizer.zero_grad()
        # print(episode_samples.shape)
        # print(episode_labels.shape)
        support, query = split_support_query(episode_samples, n_way, n_shots)
        support: torch.Tensor = support.cuda()
        y_support = episode_labels[:n_way * n_shots].unsqueeze(0)
        y_support_ohe: torch.Tensor = F.one_hot(y_support, n_way).float().cuda()

        y_query = episode_labels[n_way * n_shots:].unsqueeze(0)
        y_query_ohe: torch.Tensor =  F.one_hot(y_query,n_way).float().cuda()
        query: torch.Tensor = query.cuda()

        # print("SUP:", support.shape)
        # print("QUERY:", query.shape)
        # print("Y_SUP:", y_support_ohe.shape)
        # print("Y_QUERY:", y_query_ohe.shape)

        y_pred, cls_score = can(support, query, y_support_ohe, y_query_ohe)
        # print("y_pred", y_pred.shape, "cls_score", cls_score.shape)
        loss1 = criterion(y_pred, y_query.view(-1))
        loss2 = criterion(cls_score, y_query.view(-1))
        # print("l1:", loss1, "l2: ", loss2.item)
        loss = loss1 + 0.5 * loss2
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch {epoch+1}, loss: {torch.tensor(losses).mean()}")
    print("Losses", losses)
        # break

#         loss = criterion(y_pred, y_query_ohe)

# can_warmed_up = CrossAttentionNetwork()
# can_warmed_up.load_state_dict(can.state_dict())
# criterion_warmed_up = CANLoss(can_warmed_up, n_classes=n_way)
# criterion_warmed_up.load_state_dict(criterion.state_dict())

100%|██████████| 120/120 [00:09<00:00, 12.80it/s]


Epoch 1, loss: 4.283754825592041
Losses [2.8294456005096436, 9.136136054992676, 13.230528831481934, 13.052899360656738, 17.615345001220703, 8.604540824890137, 6.4407958984375, 13.702064514160156, 7.341562271118164, 28.1121826171875, 4.872560024261475, 3.0149474143981934, 13.354408264160156, 28.230745315551758, 8.545151710510254, 6.080874443054199, 11.377700805664062, 4.19000244140625, 3.9890012741088867, 4.6881513595581055, 4.505551338195801, 4.5318169593811035, 4.680295944213867, 4.42927360534668, 3.2379324436187744, 3.5702686309814453, 4.196882247924805, 3.2555174827575684, 5.165862083435059, 4.894817352294922, 3.3822593688964844, 3.249025583267212, 2.931387186050415, 3.2217676639556885, 2.6986289024353027, 3.2894160747528076, 3.72867751121521, 2.5520641803741455, 4.32594108581543, 4.045936107635498, 2.9464457035064697, 3.0410008430480957, 2.6001687049865723, 2.3025436401367188, 2.6005003452301025, 4.34052038192749, 2.542372465133667, 3.8469910621643066, 2.810100555419922, 2.60832762

100%|██████████| 120/120 [00:09<00:00, 13.11it/s]


Epoch 2, loss: 2.7312419414520264
Losses [3.060857057571411, 2.69051456451416, 3.01593279838562, 2.680072784423828, 2.6981589794158936, 2.9113516807556152, 2.809696912765503, 2.7537803649902344, 2.794569253921509, 2.5362462997436523, 2.792789936065674, 2.7387959957122803, 2.6019511222839355, 2.6583056449890137, 2.936633825302124, 2.519298553466797, 2.584510326385498, 2.68925142288208, 2.604865789413452, 2.9561057090759277, 2.5127692222595215, 2.562983989715576, 2.6948673725128174, 2.9041690826416016, 2.5802602767944336, 2.4910404682159424, 2.9335343837738037, 3.267958164215088, 2.98996901512146, 2.750950336456299, 2.7400858402252197, 2.7195911407470703, 2.7591662406921387, 2.755347967147827, 2.903384208679199, 2.6517157554626465, 3.0106372833251953, 2.669735908508301, 2.675886631011963, 2.453761100769043, 2.6791367530822754, 2.6913013458251953, 2.790091037750244, 2.547377586364746, 2.6052567958831787, 2.6911168098449707, 2.4100379943847656, 3.142364740371704, 2.624706983566284, 2.71562

100%|██████████| 120/120 [00:09<00:00, 12.75it/s]


Epoch 3, loss: 2.695554494857788
Losses [2.9881105422973633, 2.810535192489624, 2.9024295806884766, 2.6380889415740967, 2.7233080863952637, 2.7990829944610596, 2.735384464263916, 2.55119252204895, 2.586552143096924, 2.615338087081909, 2.6094746589660645, 2.7745773792266846, 2.5540523529052734, 2.6304104328155518, 2.6915531158447266, 2.721095085144043, 2.739131450653076, 2.6717708110809326, 2.4914939403533936, 2.6893017292022705, 2.4245686531066895, 2.5942366123199463, 2.6116881370544434, 2.838710308074951, 2.5809316635131836, 2.3684396743774414, 2.746819496154785, 3.128225803375244, 2.834019184112549, 2.7087676525115967, 2.6811811923980713, 2.735593557357788, 2.6355204582214355, 2.889894485473633, 2.9102349281311035, 2.7435309886932373, 2.69708514213562, 2.7676353454589844, 2.6213083267211914, 2.422738552093506, 2.722111463546753, 2.6096482276916504, 2.589284896850586, 2.5624194145202637, 2.6054177284240723, 2.578826904296875, 2.450066328048706, 2.7383148670196533, 2.4761083126068115, 

100%|██████████| 120/120 [00:09<00:00, 12.73it/s]


Epoch 4, loss: 2.660261869430542
Losses [2.995896339416504, 2.592712879180908, 2.8803415298461914, 2.500459671020508, 2.7844998836517334, 2.9148824214935303, 2.7239749431610107, 2.606924057006836, 2.6953272819519043, 2.556788921356201, 2.545518159866333, 2.847712278366089, 2.4874556064605713, 2.633477210998535, 2.6612706184387207, 2.465423107147217, 2.77932071685791, 2.732682466506958, 2.6241211891174316, 2.458482265472412, 2.3024072647094727, 2.5106375217437744, 2.61417293548584, 2.634122371673584, 2.515411853790283, 2.4734039306640625, 2.9274516105651855, 2.907456398010254, 2.8689892292022705, 2.607189178466797, 2.774837017059326, 2.6497464179992676, 2.7368292808532715, 2.7690834999084473, 2.855652093887329, 2.9001896381378174, 2.9385128021240234, 2.650587558746338, 2.583840847015381, 2.367122173309326, 2.6300883293151855, 2.6031270027160645, 2.5817291736602783, 2.3849806785583496, 2.3454012870788574, 2.694528579711914, 2.0938827991485596, 2.7906737327575684, 2.100104808807373, 2.565

100%|██████████| 120/120 [00:09<00:00, 12.82it/s]


Epoch 5, loss: 2.644871473312378
Losses [2.715864419937134, 2.446046829223633, 2.718503475189209, 2.5079174041748047, 2.6079843044281006, 2.852008581161499, 2.6082122325897217, 2.498654365539551, 2.5741381645202637, 2.382997512817383, 2.68794322013855, 2.9508676528930664, 2.3791017532348633, 2.752394437789917, 2.4401848316192627, 2.8840644359588623, 2.5968613624572754, 2.4563472270965576, 2.6126585006713867, 2.45238995552063, 2.4541845321655273, 2.606088876724243, 2.528163433074951, 2.7898061275482178, 2.7746782302856445, 2.4745044708251953, 2.535886764526367, 2.8351523876190186, 2.6301212310791016, 2.899625301361084, 2.5862808227539062, 2.862063407897949, 2.7823550701141357, 2.61970591545105, 2.905122756958008, 2.7759010791778564, 2.5495967864990234, 2.6596391201019287, 2.582819700241089, 2.5298261642456055, 2.623915195465088, 2.5845813751220703, 2.631376266479492, 2.4292123317718506, 2.5819778442382812, 2.4852333068847656, 2.2728567123413086, 2.9085185527801514, 2.284738063812256, 2.

100%|██████████| 120/120 [00:09<00:00, 12.84it/s]


Epoch 6, loss: 2.628109931945801
Losses [2.6869125366210938, 2.4434359073638916, 2.8326566219329834, 2.5476512908935547, 2.6259355545043945, 2.757293462753296, 2.5945191383361816, 2.6144115924835205, 2.5748980045318604, 2.451211452484131, 2.475947856903076, 2.6126160621643066, 2.5092597007751465, 2.774779796600342, 2.543365478515625, 2.523944139480591, 2.588050127029419, 2.2640862464904785, 3.184805154800415, 2.461103677749634, 2.569711685180664, 2.8326878547668457, 2.618192195892334, 2.687025308609009, 2.604093074798584, 2.5881290435791016, 2.5217649936676025, 2.7322733402252197, 2.6760380268096924, 2.508326292037964, 2.543675184249878, 2.986765146255493, 2.593953847885132, 2.9146833419799805, 2.854214668273926, 2.6632235050201416, 2.7216808795928955, 2.6184678077697754, 2.594362258911133, 2.4139065742492676, 2.487051248550415, 2.752257823944092, 2.609609603881836, 2.2532401084899902, 2.522672653198242, 2.4539215564727783, 2.2229154109954834, 2.6833889484405518, 2.171003580093384, 2.4

100%|██████████| 120/120 [00:09<00:00, 12.57it/s]


Epoch 7, loss: 2.614802598953247
Losses [2.764807939529419, 2.5266807079315186, 2.6219868659973145, 2.5885119438171387, 2.662684440612793, 2.784043788909912, 2.6994447708129883, 2.395488739013672, 2.546602725982666, 2.547227144241333, 2.5896289348602295, 2.8972764015197754, 2.403109073638916, 2.638826370239258, 2.4417545795440674, 2.647230625152588, 2.59023380279541, 2.391660213470459, 2.5512986183166504, 2.396986484527588, 2.3130457401275635, 2.4427173137664795, 2.689507007598877, 2.6897592544555664, 2.5542068481445312, 2.3892035484313965, 2.8203999996185303, 2.7411186695098877, 2.6117467880249023, 2.5621204376220703, 2.4849305152893066, 2.798586845397949, 2.798494577407837, 2.735610008239746, 2.8166744709014893, 2.8813085556030273, 2.9104695320129395, 2.55397629737854, 2.480375289916992, 2.227761745452881, 2.618595600128174, 2.569037437438965, 2.788572311401367, 2.2960872650146484, 2.5357556343078613, 2.5891003608703613, 2.4478061199188232, 2.6725006103515625, 2.6060352325439453, 2.5

100%|██████████| 120/120 [00:09<00:00, 12.50it/s]


Epoch 8, loss: 2.6055593490600586
Losses [2.6538336277008057, 2.4943222999572754, 2.6299986839294434, 2.4447576999664307, 2.665508270263672, 2.83394455909729, 2.630075216293335, 2.460599184036255, 2.689728260040283, 2.6663784980773926, 2.4773311614990234, 2.9674124717712402, 2.49946928024292, 2.6511309146881104, 2.3880343437194824, 2.4655442237854004, 2.5885472297668457, 2.4433085918426514, 2.3635833263397217, 2.396719217300415, 2.2879374027252197, 2.6253299713134766, 2.651721239089966, 2.7768449783325195, 2.5196704864501953, 2.161285638809204, 2.717418670654297, 2.716383218765259, 2.668393135070801, 2.8233790397644043, 2.519165515899658, 2.724757671356201, 2.7268638610839844, 2.859424114227295, 3.021409273147583, 2.947599411010742, 2.646371841430664, 2.669220447540283, 2.621185541152954, 2.2139930725097656, 2.496363401412964, 2.7018303871154785, 2.7856991291046143, 2.2898905277252197, 2.5480592250823975, 2.5529227256774902, 2.2917070388793945, 2.575984001159668, 2.687304735183716, 2.4

100%|██████████| 120/120 [00:09<00:00, 12.51it/s]


Epoch 9, loss: 2.5887644290924072
Losses [2.514286518096924, 2.342604160308838, 2.5894365310668945, 2.5348711013793945, 2.4508864879608154, 2.879136085510254, 2.552659511566162, 2.750056743621826, 2.6306777000427246, 2.4825968742370605, 2.3491287231445312, 2.8518242835998535, 2.3742432594299316, 2.6427412033081055, 2.48567271232605, 2.562464714050293, 2.479778528213501, 2.198087692260742, 2.6592211723327637, 2.2645480632781982, 2.3920531272888184, 2.696373224258423, 2.6706154346466064, 2.758934736251831, 2.5627663135528564, 2.3281593322753906, 2.5218799114227295, 2.659024715423584, 2.476799249649048, 2.687746286392212, 2.5640745162963867, 2.9356327056884766, 2.782414197921753, 3.028792381286621, 2.7728025913238525, 2.599249839782715, 2.789139747619629, 2.5535311698913574, 2.5164201259613037, 2.3095576763153076, 2.356282949447632, 2.721949338912964, 2.7296271324157715, 2.4004323482513428, 2.5634374618530273, 2.60402250289917, 2.262397050857544, 2.429792881011963, 2.515427827835083, 2.13

100%|██████████| 120/120 [00:09<00:00, 12.55it/s]


Epoch 10, loss: 2.5750062465667725
Losses [2.783693552017212, 2.5530755519866943, 2.7445218563079834, 2.5290653705596924, 2.550168752670288, 2.6789989471435547, 3.0021462440490723, 2.255223035812378, 2.460578680038452, 2.6500144004821777, 2.250804901123047, 2.8057498931884766, 2.402285575866699, 2.5225138664245605, 2.3658511638641357, 2.6098203659057617, 2.5574913024902344, 2.1092369556427, 2.3771746158599854, 2.405704975128174, 2.419577121734619, 2.695812225341797, 2.6141130924224854, 2.7985095977783203, 2.4562950134277344, 2.2715847492218018, 2.3045287132263184, 2.565887928009033, 2.453476905822754, 2.541428804397583, 2.6063640117645264, 2.826597213745117, 2.420703649520874, 2.7921109199523926, 3.052978038787842, 2.4574990272521973, 2.6371426582336426, 2.6189799308776855, 2.589548110961914, 2.466750144958496, 2.3541758060455322, 2.6582698822021484, 2.723414421081543, 2.217844009399414, 2.372398853302002, 2.569601058959961, 2.461646795272827, 2.3343167304992676, 2.295454740524292, 2.1

100%|██████████| 120/120 [00:09<00:00, 12.50it/s]


Epoch 11, loss: 2.559849977493286
Losses [2.6476874351501465, 2.407184600830078, 2.7556095123291016, 2.6301817893981934, 2.4587788581848145, 2.631277084350586, 2.5666966438293457, 2.3395142555236816, 2.4983408451080322, 2.329860210418701, 2.5991320610046387, 2.9918599128723145, 2.3582091331481934, 2.7802114486694336, 2.41625714302063, 2.5631864070892334, 2.472480058670044, 2.222628116607666, 2.521245002746582, 2.149017095565796, 2.3267788887023926, 2.426279067993164, 2.584888458251953, 2.6625821590423584, 2.634591579437256, 2.274618625640869, 2.2023091316223145, 2.3333113193511963, 2.664564609527588, 2.465121030807495, 2.5786685943603516, 2.8579869270324707, 2.4063234329223633, 2.820115327835083, 2.8235585689544678, 2.6699154376983643, 2.7528626918792725, 2.822218656539917, 2.442397117614746, 2.5383763313293457, 2.3987340927124023, 2.6333045959472656, 2.4922690391540527, 2.2573511600494385, 2.3584201335906982, 2.7626101970672607, 2.2695415019989014, 2.5905842781066895, 2.29055118560791

100%|██████████| 120/120 [00:09<00:00, 12.51it/s]


Epoch 12, loss: 2.5538277626037598
Losses [2.732764482498169, 2.3388917446136475, 2.640352964401245, 2.4802024364471436, 2.5641088485717773, 2.8163437843322754, 2.493596076965332, 2.2130393981933594, 2.5645713806152344, 2.2496657371520996, 2.523510456085205, 2.8043041229248047, 2.3314085006713867, 2.564570426940918, 2.5408992767333984, 2.491302013397217, 2.5468456745147705, 2.227689266204834, 2.512734889984131, 2.4169058799743652, 2.224144697189331, 2.3005075454711914, 2.571183919906616, 2.757951498031616, 2.364854574203491, 2.28761625289917, 2.3702573776245117, 2.5974016189575195, 2.7886242866516113, 2.614063024520874, 2.643561840057373, 2.88344144821167, 2.436933994293213, 3.139958381652832, 2.608304738998413, 2.7438395023345947, 2.742310047149658, 2.40922212600708, 2.5186710357666016, 2.159257650375366, 2.4368176460266113, 2.5708327293395996, 2.8424201011657715, 2.3697397708892822, 2.175736427307129, 2.642117738723755, 2.4988739490509033, 2.5483899116516113, 2.4878547191619873, 2.07

100%|██████████| 120/120 [00:10<00:00, 11.68it/s]


Epoch 13, loss: 2.5346083641052246
Losses [2.5883376598358154, 2.4749176502227783, 2.525010824203491, 2.410388231277466, 2.4467811584472656, 2.6670382022857666, 2.9393553733825684, 2.4752228260040283, 2.4115333557128906, 2.151427984237671, 2.439624547958374, 2.654843807220459, 2.2202208042144775, 2.323281764984131, 2.4139833450317383, 2.503342390060425, 2.586469888687134, 2.131030559539795, 2.2850148677825928, 2.3672966957092285, 2.3648312091827393, 2.4924826622009277, 2.7014477252960205, 2.8613779544830322, 2.229851245880127, 2.587125301361084, 2.4966049194335938, 2.468468427658081, 2.9641027450561523, 2.563852071762085, 2.7910518646240234, 2.4332034587860107, 2.6248316764831543, 2.7741971015930176, 2.7097537517547607, 2.687730312347412, 2.6319527626037598, 2.5143187046051025, 2.4804141521453857, 2.1278252601623535, 2.3970930576324463, 2.7222180366516113, 2.4228100776672363, 2.2345004081726074, 2.4715960025787354, 2.7223360538482666, 2.564624786376953, 2.4772531986236572, 2.1458644866

100%|██████████| 120/120 [00:09<00:00, 12.04it/s]


Epoch 14, loss: 2.5172054767608643
Losses [2.649017810821533, 2.40610933303833, 2.958345413208008, 2.4575672149658203, 2.483642101287842, 2.4349617958068848, 2.593290328979492, 2.1786680221557617, 2.5805554389953613, 1.9856452941894531, 2.6384212970733643, 2.8153953552246094, 2.2629811763763428, 2.4187827110290527, 2.285438060760498, 2.5313704013824463, 2.8448896408081055, 2.181893825531006, 2.2966392040252686, 2.395402193069458, 2.054990291595459, 2.1033847332000732, 2.3847367763519287, 3.023655891418457, 2.708109140396118, 2.1038265228271484, 2.188149929046631, 2.09616756439209, 2.59128475189209, 2.2837016582489014, 2.964162588119507, 2.606764793395996, 2.5687222480773926, 2.5310845375061035, 2.7108168601989746, 2.3946828842163086, 3.228548526763916, 2.5349626541137695, 2.5806286334991455, 2.4759886264801025, 2.612321376800537, 2.3945844173431396, 2.434431552886963, 2.2688815593719482, 2.3063206672668457, 2.66890811920166, 2.688891887664795, 2.3681259155273438, 2.5625240802764893, 2.

100%|██████████| 120/120 [00:09<00:00, 12.18it/s]


Epoch 15, loss: 2.477229595184326
Losses [2.7855052947998047, 2.371513605117798, 2.4260759353637695, 2.2942824363708496, 2.2319350242614746, 2.501021146774292, 2.4860575199127197, 2.605243682861328, 2.4216320514678955, 2.3592240810394287, 2.0556681156158447, 2.918083429336548, 2.3327841758728027, 2.70353364944458, 2.432361602783203, 2.3133840560913086, 2.545891761779785, 1.9198367595672607, 2.7289321422576904, 2.3734636306762695, 2.225738286972046, 2.379091501235962, 2.4818334579467773, 3.0642471313476562, 2.5073049068450928, 2.5481929779052734, 2.1568102836608887, 2.3709986209869385, 2.4839017391204834, 2.1277174949645996, 2.753335475921631, 2.926626443862915, 2.303392171859741, 2.638794422149658, 2.854881525039673, 2.5995097160339355, 2.5200083255767822, 2.478512763977051, 2.3202319145202637, 2.186305046081543, 2.4039876461029053, 2.4562182426452637, 2.5304605960845947, 2.257284641265869, 2.225947618484497, 2.6056036949157715, 2.2520532608032227, 2.3814666271209717, 2.461896419525146

100%|██████████| 120/120 [00:09<00:00, 12.18it/s]


Epoch 16, loss: 2.49756121635437
Losses [2.7260212898254395, 2.1985416412353516, 2.4240481853485107, 2.452897310256958, 2.5426530838012695, 2.5813193321228027, 2.648458957672119, 2.379106283187866, 2.380958080291748, 2.3079028129577637, 2.118537664413452, 2.7460427284240723, 2.2571451663970947, 2.801093578338623, 2.3734679222106934, 2.418710708618164, 2.5044467449188232, 2.2668416500091553, 2.7063536643981934, 2.2910337448120117, 2.335296630859375, 2.4868249893188477, 2.498459577560425, 2.7109336853027344, 2.466167688369751, 2.3999483585357666, 2.166904926300049, 2.1405234336853027, 2.45021653175354, 2.487978458404541, 2.9549808502197266, 3.1018104553222656, 2.53924822807312, 2.7565526962280273, 2.4993319511413574, 2.392704963684082, 2.2469139099121094, 2.334315776824951, 2.2722361087799072, 1.9808053970336914, 2.598773241043091, 2.4295711517333984, 2.471177339553833, 2.411205768585205, 2.526516914367676, 2.5151095390319824, 2.265958547592163, 2.3574953079223633, 2.606248378753662, 2.3

100%|██████████| 120/120 [00:09<00:00, 12.40it/s]


Epoch 17, loss: 2.452986240386963
Losses [2.5574886798858643, 2.5504817962646484, 2.466244697570801, 2.361182928085327, 2.2464511394500732, 2.4595136642456055, 2.3607406616210938, 2.117069721221924, 2.425652503967285, 2.152669668197632, 1.982121467590332, 2.9789390563964844, 2.124408006668091, 2.7468669414520264, 2.133892059326172, 2.3064370155334473, 2.3283960819244385, 2.185594081878662, 2.694215774536133, 2.3875226974487305, 2.128687620162964, 2.399564027786255, 2.210613250732422, 2.430924892425537, 2.689302682876587, 2.1302003860473633, 2.201016426086426, 2.2105584144592285, 2.890221118927002, 2.6019954681396484, 2.869192600250244, 2.5090396404266357, 2.5791540145874023, 2.4278621673583984, 2.6346867084503174, 2.5193593502044678, 2.735973596572876, 2.279233694076538, 2.1792306900024414, 2.013214588165283, 2.424215316772461, 2.7915797233581543, 2.326896905899048, 2.407046318054199, 2.059946298599243, 2.3637447357177734, 2.1024763584136963, 2.4533607959747314, 2.263824939727783, 2.17

100%|██████████| 120/120 [00:09<00:00, 12.30it/s]


Epoch 18, loss: 2.4376211166381836
Losses [2.62164306640625, 2.175971508026123, 2.5083138942718506, 2.3420891761779785, 2.4489827156066895, 2.3712985515594482, 2.394505023956299, 2.205413818359375, 2.5762503147125244, 2.4545233249664307, 2.193166732788086, 2.3580968379974365, 2.2020745277404785, 2.6235408782958984, 2.3721823692321777, 2.1947686672210693, 2.2691867351531982, 2.2019450664520264, 2.598097324371338, 2.332505702972412, 2.066635847091675, 2.1859421730041504, 2.5151236057281494, 2.8647594451904297, 2.568601131439209, 2.3224668502807617, 2.3083882331848145, 2.348057270050049, 2.910511016845703, 2.5910239219665527, 2.633537769317627, 2.621532917022705, 2.5728678703308105, 2.65988826751709, 2.3549606800079346, 2.4954802989959717, 2.454957962036133, 2.2209911346435547, 2.4663150310516357, 2.0379719734191895, 2.278048515319824, 2.4872541427612305, 2.209066867828369, 2.1147923469543457, 2.2295854091644287, 2.387836456298828, 2.581613540649414, 2.328462839126587, 2.593538999557495, 

100%|██████████| 120/120 [00:09<00:00, 12.27it/s]


Epoch 19, loss: 2.4120588302612305
Losses [2.686161518096924, 2.1665444374084473, 2.068446636199951, 2.447434425354004, 2.1164822578430176, 2.326284885406494, 2.8608078956604004, 2.1560678482055664, 2.314302921295166, 2.1290907859802246, 2.013777017593384, 2.8033761978149414, 2.1143527030944824, 2.456226348876953, 2.2522735595703125, 2.1851654052734375, 2.2799320220947266, 1.8955893516540527, 2.661043167114258, 2.2329823970794678, 2.3329875469207764, 2.0442495346069336, 2.496410369873047, 2.4254143238067627, 2.428884744644165, 2.3836684226989746, 1.9129948616027832, 2.1159825325012207, 2.6291937828063965, 2.160266637802124, 2.5817437171936035, 2.6806955337524414, 2.316883087158203, 2.2753231525421143, 2.211623191833496, 2.328022003173828, 2.5484113693237305, 2.24843692779541, 2.2246813774108887, 2.304356098175049, 2.1868820190429688, 2.4529688358306885, 2.457186460494995, 2.5024890899658203, 2.1648037433624268, 2.9826836585998535, 2.101393461227417, 2.1587135791778564, 2.38468694686889

100%|██████████| 120/120 [00:09<00:00, 12.43it/s]


Epoch 20, loss: 2.390683174133301
Losses [2.5678749084472656, 2.299727201461792, 2.5864059925079346, 2.269726037979126, 2.2986268997192383, 2.3534512519836426, 2.55181884765625, 2.2304091453552246, 2.736872911453247, 2.335381269454956, 2.00746750831604, 2.7696080207824707, 2.215702772140503, 2.6773955821990967, 2.2790746688842773, 2.0038022994995117, 2.55303955078125, 1.9495949745178223, 2.2477402687072754, 2.2744901180267334, 2.225564956665039, 2.122783899307251, 2.1847918033599854, 2.729092597961426, 2.655118703842163, 1.928307056427002, 2.0151329040527344, 2.2670986652374268, 2.4126367568969727, 2.272772789001465, 2.6654958724975586, 2.4440481662750244, 2.3966593742370605, 2.507870674133301, 2.345689296722412, 2.3810932636260986, 2.4581432342529297, 1.8365793228149414, 2.171679973602295, 1.7940199375152588, 2.619983434677124, 2.23736310005188, 2.0466606616973877, 2.1575021743774414, 2.6373085975646973, 2.469805955886841, 2.27943754196167, 2.1779298782348633, 2.1766364574432373, 2.09

100%|██████████| 120/120 [00:09<00:00, 12.55it/s]


Epoch 21, loss: 2.3848233222961426
Losses [3.2977700233459473, 2.297795295715332, 2.1286001205444336, 2.402513265609741, 2.160956859588623, 2.4216647148132324, 2.3201162815093994, 1.9444366693496704, 2.273001194000244, 2.1288528442382812, 2.2794766426086426, 2.570544481277466, 2.441664457321167, 2.59842586517334, 2.132251024246216, 2.176056385040283, 2.178302049636841, 1.7892982959747314, 2.733877658843994, 2.139626979827881, 2.855013847351074, 2.2213642597198486, 1.9864076375961304, 2.7252490520477295, 2.524768829345703, 1.8636043071746826, 1.881680965423584, 2.04061222076416, 2.4629266262054443, 2.2873244285583496, 2.3604276180267334, 2.4651877880096436, 2.7257487773895264, 2.2877297401428223, 2.404865264892578, 2.397928237915039, 2.4754996299743652, 2.3550283908843994, 2.2649765014648438, 1.8979331254959106, 2.2829227447509766, 2.331632375717163, 2.187875747680664, 2.2971200942993164, 2.303493022918701, 2.2241036891937256, 2.045424461364746, 2.1621551513671875, 2.3740415573120117, 2

100%|██████████| 120/120 [00:09<00:00, 12.17it/s]


Epoch 22, loss: 2.367060899734497
Losses [2.751924991607666, 2.1015238761901855, 2.264651298522949, 2.3895819187164307, 2.245710611343384, 2.4808952808380127, 2.560558319091797, 2.3132665157318115, 2.3128576278686523, 1.9330031871795654, 1.9731144905090332, 2.676266670227051, 2.1724390983581543, 2.8388571739196777, 2.2365505695343018, 2.037337303161621, 2.3868417739868164, 1.8411107063293457, 2.3994548320770264, 2.3488636016845703, 1.8018202781677246, 2.293588638305664, 1.960949182510376, 2.6688599586486816, 2.3485450744628906, 1.9855833053588867, 2.228804111480713, 2.1068596839904785, 2.469770669937134, 2.510061740875244, 2.665889024734497, 2.7531304359436035, 2.3719053268432617, 2.0927016735076904, 2.3112380504608154, 1.8918476104736328, 2.5567941665649414, 1.8628451824188232, 1.9342975616455078, 1.932403564453125, 2.302607536315918, 2.5851387977600098, 1.9867055416107178, 2.233185291290283, 2.05003023147583, 2.3336005210876465, 2.229238748550415, 2.2378950119018555, 2.07527875900268

100%|██████████| 120/120 [00:10<00:00, 11.64it/s]


Epoch 23, loss: 2.3314285278320312
Losses [2.757357358932495, 2.3198869228363037, 2.3301868438720703, 2.105756998062134, 2.1424505710601807, 2.479609966278076, 2.5505518913269043, 2.0671072006225586, 2.059544324874878, 1.826061725616455, 2.0257980823516846, 3.1022558212280273, 2.0048749446868896, 2.6278316974639893, 2.038682460784912, 2.3142154216766357, 2.4011483192443848, 2.1719818115234375, 2.5211877822875977, 2.1791791915893555, 1.9793322086334229, 2.078101873397827, 2.112499713897705, 2.270103931427002, 2.4766812324523926, 2.275151491165161, 2.0060205459594727, 2.5007858276367188, 2.4270355701446533, 2.1952567100524902, 2.640820026397705, 2.702566146850586, 2.4646894931793213, 2.0377578735351562, 2.437744379043579, 2.361577272415161, 2.3505077362060547, 2.023777961730957, 2.03017520904541, 1.878673791885376, 2.092095375061035, 3.316777229309082, 2.374136447906494, 2.255547285079956, 1.9288136959075928, 2.296382188796997, 2.574056625366211, 2.246173858642578, 2.282146692276001, 2.1

100%|██████████| 120/120 [00:10<00:00, 11.89it/s]


Epoch 24, loss: 2.3047127723693848
Losses [2.715297222137451, 2.1832921504974365, 2.230015993118286, 2.1808207035064697, 2.0945348739624023, 2.6949288845062256, 3.098728656768799, 2.1109657287597656, 2.1861839294433594, 2.042839527130127, 1.841050148010254, 2.6770193576812744, 2.0923755168914795, 2.7517402172088623, 2.2770121097564697, 2.014504909515381, 2.2955403327941895, 1.6851797103881836, 2.4439663887023926, 2.030555486679077, 2.0820066928863525, 1.837900161743164, 2.139676094055176, 2.7609212398529053, 2.3092403411865234, 1.9794446229934692, 1.9717025756835938, 1.823206901550293, 2.398892879486084, 2.143543243408203, 2.5644655227661133, 2.5306334495544434, 2.3532931804656982, 2.2206430435180664, 2.2850863933563232, 2.11296010017395, 2.26796817779541, 1.8263893127441406, 1.9930129051208496, 1.6643390655517578, 1.9138896465301514, 2.4252612590789795, 2.314427614212036, 2.2357635498046875, 1.7105040550231934, 2.0828516483306885, 1.9041626453399658, 2.211092472076416, 1.9311240911483

100%|██████████| 120/120 [00:09<00:00, 12.03it/s]

Epoch 25, loss: 2.300511598587036
Losses [2.7883827686309814, 2.1250288486480713, 2.5206480026245117, 2.463980197906494, 1.959113597869873, 2.3444247245788574, 2.328101873397827, 2.2106728553771973, 2.450039863586426, 2.0417299270629883, 1.8133907318115234, 2.6542890071868896, 1.9374346733093262, 2.774535655975342, 1.9397151470184326, 2.0586814880371094, 2.5984549522399902, 2.064669609069824, 2.660753011703491, 2.323810577392578, 1.8745503425598145, 1.8692831993103027, 2.052159309387207, 2.958265542984009, 2.2840216159820557, 1.8214857578277588, 1.7888915538787842, 1.8428361415863037, 2.049453020095825, 2.3722219467163086, 3.485842704772949, 2.8806281089782715, 2.3901381492614746, 2.6370623111724854, 2.2754034996032715, 2.0976076126098633, 1.9731249809265137, 1.8105762004852295, 2.1374480724334717, 1.7571771144866943, 2.4155192375183105, 2.736812114715576, 2.5160789489746094, 2.0285134315490723, 1.9638195037841797, 2.3086719512939453, 1.9733318090438843, 2.0892961025238037, 2.006710529




In [246]:
print(y_pred.shape, cls_score.shape)
print(cls_score[0])

torch.Size([6, 6, 3, 3]) torch.Size([6, 6, 3, 3])
tensor([[[0.6723, 0.6892, 0.6877],
         [0.6556, 0.6629, 0.6731],
         [0.6663, 0.6526, 0.6691]],

        [[0.6274, 0.6571, 0.6813],
         [0.6301, 0.6362, 0.6410],
         [0.6410, 0.6144, 0.6337]],

        [[0.6694, 0.6798, 0.7027],
         [0.6486, 0.6566, 0.6589],
         [0.6495, 0.6457, 0.6690]],

        [[0.5005, 0.5307, 0.5705],
         [0.5281, 0.5810, 0.5521],
         [0.5444, 0.5394, 0.5345]],

        [[0.6406, 0.6689, 0.6833],
         [0.6351, 0.6397, 0.6448],
         [0.6366, 0.6207, 0.6429]],

        [[0.5691, 0.6060, 0.6315],
         [0.5913, 0.6170, 0.6109],
         [0.6117, 0.5997, 0.5970]]], device='cuda:0',
       grad_fn=<SelectBackward0>)


In [45]:
can = CrossAttentionNetwork().cuda()
criterion = CANLoss(can, n_classes=n_way).cuda()
optimizer_params = {
    "lr": 0.001,
}
optimizer = SGD(criterion.parameters(), **optimizer_params, )
epochs = 10

for epoch in range(epochs):
    losses = []
    for episode_samples, episode_labels in tqdm(ds_train):
        optimizer.zero_grad()
        # print(episode_samples.shape)
        # print(episode_labels.shape)
        y_true: torch.Tensor = episode_labels[n_way * n_shots:].unsqueeze(-1).cuda()
        support, query = split_support_query(episode_samples, n_way, n_shots)
        support: torch.Tensor = support.cuda()
        query: torch.Tensor = query.cuda()

        # print(support.shape)
        # print(query.shape)
        # print(y_true.shape)

        loss = criterion(support, query, y_true)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch {epoch+1}, loss: {torch.tensor(losses).mean()}")
    print("Losses", losses)

can_warmed_up = CrossAttentionNetwork()
can_warmed_up.load_state_dict(can.state_dict())
criterion_warmed_up = CANLoss(can_warmed_up, n_classes=n_way)
criterion_warmed_up.load_state_dict(criterion.state_dict())

100%|██████████| 120/120 [00:14<00:00,  8.22it/s]


Epoch 1, loss: 198.50894165039062
Losses [192.42591857910156, 237.9314422607422, 237.76531982421875, 245.19598388671875, 221.25787353515625, 236.29574584960938, 239.2960968017578, 289.90252685546875, 234.09927368164062, 219.93170166015625, 228.55364990234375, 350.619873046875, 185.3072967529297, 324.6642761230469, 256.7618408203125, 210.46917724609375, 268.48870849609375, 251.00247192382812, 301.4913330078125, 205.4044952392578, 210.89247131347656, 223.46502685546875, 184.7008514404297, 223.15823364257812, 219.7668914794922, 250.71609497070312, 214.39227294921875, 212.18272399902344, 245.47213745117188, 215.8321533203125, 192.01109313964844, 196.920166015625, 247.02711486816406, 205.81088256835938, 184.94223022460938, 197.227783203125, 194.1817626953125, 183.76893615722656, 195.86712646484375, 233.53994750976562, 205.96414184570312, 185.79127502441406, 193.08319091796875, 187.62911987304688, 189.32125854492188, 181.9144287109375, 183.29896545410156, 191.9380340576172, 182.1277313232422

100%|██████████| 120/120 [00:15<00:00,  7.99it/s]


Epoch 2, loss: 178.96128845214844
Losses [179.9770050048828, 178.4154052734375, 178.51812744140625, 179.78636169433594, 179.15972900390625, 179.83270263671875, 178.6851806640625, 178.74208068847656, 178.71200561523438, 178.9686279296875, 180.3282012939453, 178.48582458496094, 178.55201721191406, 179.2447509765625, 179.7429962158203, 179.1844940185547, 178.46884155273438, 178.6881103515625, 178.9486083984375, 178.95611572265625, 178.5220184326172, 178.22509765625, 178.27886962890625, 179.30364990234375, 178.673095703125, 181.0144805908203, 179.02105712890625, 178.48541259765625, 178.2066650390625, 178.20907592773438, 180.08395385742188, 184.88461303710938, 178.8894500732422, 178.96788024902344, 198.44859313964844, 178.3725128173828, 181.89874267578125, 178.21633911132812, 178.33876037597656, 178.2165069580078, 181.2938995361328, 178.37440490722656, 179.03564453125, 178.23912048339844, 180.42269897460938, 178.78286743164062, 178.5854949951172, 178.0987548828125, 178.1309051513672, 178.30

100%|██████████| 120/120 [00:14<00:00,  8.07it/s]


Epoch 3, loss: 178.22218322753906
Losses [178.44183349609375, 178.35025024414062, 178.119873046875, 178.07376098632812, 178.24285888671875, 178.13284301757812, 178.43304443359375, 178.27415466308594, 178.18087768554688, 178.1523895263672, 178.25372314453125, 178.42291259765625, 178.26458740234375, 178.30197143554688, 178.06936645507812, 178.13427734375, 178.21983337402344, 180.2061767578125, 178.59146118164062, 178.23013305664062, 178.17788696289062, 178.19503784179688, 178.11444091796875, 178.81100463867188, 178.6336669921875, 178.43423461914062, 178.28765869140625, 178.2894287109375, 178.15841674804688, 178.54345703125, 178.18060302734375, 178.09329223632812, 178.24998474121094, 178.4127197265625, 178.19825744628906, 178.2690887451172, 178.50747680664062, 178.16970825195312, 178.25875854492188, 178.5123291015625, 178.17185974121094, 178.14859008789062, 178.07310485839844, 178.15000915527344, 178.26168823242188, 178.13995361328125, 178.3119354248047, 178.17269897460938, 178.1093902587

100%|██████████| 120/120 [00:15<00:00,  7.84it/s]


Epoch 4, loss: 178.12826538085938
Losses [178.28614807128906, 178.24887084960938, 178.21029663085938, 178.38241577148438, 178.19439697265625, 178.160888671875, 178.0693817138672, 178.229248046875, 178.079345703125, 178.07008361816406, 178.122314453125, 178.1374969482422, 178.17190551757812, 178.09617614746094, 178.05642700195312, 178.10360717773438, 178.08572387695312, 178.1634979248047, 178.05758666992188, 178.2420196533203, 178.1058807373047, 178.00912475585938, 178.10423278808594, 178.08236694335938, 178.07772827148438, 178.43399047851562, 178.1163330078125, 178.06285095214844, 178.07208251953125, 178.16546630859375, 178.08241271972656, 178.06143188476562, 178.1833953857422, 178.13658142089844, 178.11056518554688, 178.14163208007812, 178.0637969970703, 178.01687622070312, 178.1009521484375, 178.12664794921875, 178.23680114746094, 178.13040161132812, 178.06930541992188, 178.06741333007812, 178.17758178710938, 178.17440795898438, 178.12779235839844, 178.04026794433594, 178.15652465820

100%|██████████| 120/120 [00:15<00:00,  7.93it/s]


Epoch 5, loss: 178.2189483642578
Losses [178.1849365234375, 178.146728515625, 178.4124755859375, 178.08981323242188, 178.15037536621094, 178.5340576171875, 178.4835205078125, 178.28326416015625, 178.05093383789062, 178.14266967773438, 178.09918212890625, 178.0738525390625, 178.10702514648438, 178.02452087402344, 178.03013610839844, 178.0419921875, 178.18804931640625, 178.73690795898438, 178.0683135986328, 178.17666625976562, 178.0782928466797, 178.12489318847656, 178.0885467529297, 178.32574462890625, 182.28189086914062, 178.07423400878906, 178.2489471435547, 178.11976623535156, 180.00213623046875, 178.04736328125, 178.15045166015625, 178.14170837402344, 181.87420654296875, 178.13809204101562, 178.0547332763672, 178.260498046875, 178.73336791992188, 178.1239471435547, 178.0979461669922, 178.15142822265625, 178.42800903320312, 178.651611328125, 178.14544677734375, 178.12960815429688, 178.13845825195312, 178.10079956054688, 178.0794677734375, 178.25469970703125, 178.0694580078125, 178.10

100%|██████████| 120/120 [00:15<00:00,  7.88it/s]


Epoch 6, loss: 178.08486938476562
Losses [178.06622314453125, 178.36410522460938, 178.03855895996094, 178.02328491210938, 178.13650512695312, 178.02191162109375, 178.06134033203125, 178.57675170898438, 178.02066040039062, 178.0683135986328, 178.04769897460938, 178.09536743164062, 178.09713745117188, 178.0325164794922, 178.02362060546875, 178.0115203857422, 178.08526611328125, 178.02764892578125, 178.06126403808594, 178.20162963867188, 178.01541137695312, 178.02110290527344, 178.01771545410156, 178.11619567871094, 178.06082153320312, 178.09234619140625, 178.03701782226562, 178.07545471191406, 178.0655975341797, 178.02365112304688, 178.02456665039062, 178.0282440185547, 178.167236328125, 178.01654052734375, 178.14459228515625, 178.07798767089844, 178.28176879882812, 178.04025268554688, 178.03575134277344, 178.12579345703125, 178.064208984375, 178.5859375, 178.02877807617188, 178.0803680419922, 178.17039489746094, 178.03765869140625, 178.1770782470703, 178.36717224121094, 178.041473388671

100%|██████████| 120/120 [00:14<00:00,  8.04it/s]


Epoch 7, loss: 178.0553741455078
Losses [178.01602172851562, 178.0426483154297, 178.0181884765625, 178.0244140625, 178.17259216308594, 178.01890563964844, 178.07223510742188, 178.05618286132812, 178.02392578125, 178.04611206054688, 178.0930938720703, 178.05905151367188, 178.06546020507812, 178.00119018554688, 178.008056640625, 178.05926513671875, 178.10516357421875, 178.05599975585938, 178.05276489257812, 178.14199829101562, 178.05291748046875, 178.0447235107422, 178.00531005859375, 178.07980346679688, 178.0530242919922, 178.1043243408203, 178.09873962402344, 178.10182189941406, 178.02943420410156, 178.05014038085938, 178.08108520507812, 178.04452514648438, 178.06100463867188, 178.07272338867188, 178.0284881591797, 178.090576171875, 178.03323364257812, 178.11328125, 178.0310516357422, 178.04515075683594, 178.04299926757812, 178.0757293701172, 178.0349884033203, 178.0308380126953, 178.0506591796875, 178.02655029296875, 178.12307739257812, 178.03260803222656, 178.09786987304688, 178.0279

100%|██████████| 120/120 [00:15<00:00,  7.93it/s]


Epoch 8, loss: 178.06346130371094
Losses [178.03433227539062, 178.09771728515625, 178.00180053710938, 178.0249786376953, 178.095947265625, 178.1306915283203, 178.0175323486328, 178.045654296875, 178.03109741210938, 178.0653839111328, 178.04714965820312, 178.0985107421875, 178.00100708007812, 178.0218048095703, 177.9994659423828, 178.03109741210938, 178.01319885253906, 178.03353881835938, 178.2552490234375, 178.25344848632812, 178.04896545410156, 177.99526977539062, 178.0371551513672, 178.23907470703125, 178.02967834472656, 178.0779571533203, 178.04722595214844, 178.064208984375, 178.06130981445312, 178.0089111328125, 178.00421142578125, 178.05343627929688, 178.08084106445312, 178.03268432617188, 178.0548858642578, 178.1356201171875, 178.06460571289062, 178.05772399902344, 178.05001831054688, 178.04190063476562, 178.051513671875, 178.0622100830078, 178.02886962890625, 178.02880859375, 178.06785583496094, 178.06170654296875, 178.07408142089844, 178.0369110107422, 178.05572509765625, 178.

100%|██████████| 120/120 [00:15<00:00,  7.80it/s]


Epoch 9, loss: 178.0452117919922
Losses [178.01480102539062, 178.04432678222656, 178.02935791015625, 177.99264526367188, 178.07205200195312, 178.05343627929688, 178.0538787841797, 178.07432556152344, 178.04615783691406, 178.07736206054688, 178.05311584472656, 178.04226684570312, 178.03492736816406, 177.99124145507812, 178.01351928710938, 178.03269958496094, 178.03759765625, 178.0446319580078, 178.02395629882812, 178.00473022460938, 178.02102661132812, 178.0048828125, 178.1102294921875, 178.03880310058594, 178.0186004638672, 178.0362548828125, 178.043212890625, 178.022705078125, 178.00222778320312, 178.06289672851562, 178.01412963867188, 178.024169921875, 178.16273498535156, 178.02081298828125, 178.01226806640625, 178.0391387939453, 178.02029418945312, 178.0433807373047, 178.02053833007812, 178.0191192626953, 177.997314453125, 178.04519653320312, 178.0301971435547, 178.0805206298828, 178.04190063476562, 178.02972412109375, 178.08657836914062, 178.00050354003906, 178.0264892578125, 178.0

100%|██████████| 120/120 [00:15<00:00,  7.86it/s]

Epoch 10, loss: 178.05142211914062
Losses [178.04055786132812, 178.05270385742188, 177.99264526367188, 177.99847412109375, 178.34024047851562, 178.0435791015625, 178.0194854736328, 178.1002960205078, 177.99256896972656, 178.07032775878906, 178.13406372070312, 178.11837768554688, 178.02764892578125, 178.10238647460938, 178.00669860839844, 178.041748046875, 178.07769775390625, 178.11422729492188, 178.06370544433594, 178.05690002441406, 178.10110473632812, 178.0052490234375, 178.0106201171875, 178.05084228515625, 178.00030517578125, 178.07171630859375, 178.0263214111328, 178.04588317871094, 178.00881958007812, 178.01597595214844, 178.0190887451172, 178.01168823242188, 178.119873046875, 178.03311157226562, 178.04054260253906, 178.1046142578125, 178.0120849609375, 178.01341247558594, 178.0101776123047, 178.07696533203125, 178.03794860839844, 178.0901336669922, 178.0034637451172, 178.0672607421875, 178.11033630371094, 178.00221252441406, 178.06492614746094, 178.0002899169922, 178.06298828125




<All keys matched successfully>

In [51]:
can.load_state_dict(can_warmed_up.state_dict())
criterion.load_state_dict(criterion_warmed_up.state_dict())

<All keys matched successfully>

In [52]:
optimizer_params = {
    "lr": 0.001,
}
optimizer = SGD(criterion.parameters(), **optimizer_params, )
epochs = 10

for epoch in range(epochs):
    losses = []
    for episode_samples, episode_labels in tqdm(ds_train):
        optimizer.zero_grad()
        # print(episode_samples.shape)
        # print(episode_labels.shape)
        y_true: torch.Tensor = episode_labels[n_way * n_shots:].unsqueeze(-1).cuda()
        support, query = split_support_query(episode_samples, n_way, n_shots)
        support: torch.Tensor = support.cuda()
        query: torch.Tensor = query.cuda()

        # print(support.shape)
        # print(query.shape)
        # print(y_true.shape)

        loss = criterion(support, query, y_true)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch {epoch+1}, loss: {torch.tensor(losses).mean()}")
    print("Losses", losses)

100%|██████████| 120/120 [00:15<00:00,  7.82it/s]


Epoch 1, loss: 178.0390625
Losses [178.0507049560547, 178.0585174560547, 178.0208740234375, 178.004150390625, 178.07373046875, 178.0234832763672, 178.04244995117188, 178.07427978515625, 178.02801513671875, 178.05892944335938, 178.058837890625, 178.01821899414062, 178.01966857910156, 177.99734497070312, 177.99365234375, 178.03753662109375, 178.0151824951172, 178.03021240234375, 177.99514770507812, 178.02157592773438, 178.07846069335938, 177.9918212890625, 178.06390380859375, 178.0248260498047, 178.0069580078125, 178.07432556152344, 178.0235595703125, 178.06692504882812, 178.04476928710938, 178.00929260253906, 178.039794921875, 178.04180908203125, 178.15969848632812, 178.040771484375, 178.07904052734375, 178.13258361816406, 178.0528564453125, 178.0921630859375, 178.01751708984375, 178.02383422851562, 178.05450439453125, 178.04824829101562, 178.04150390625, 178.0391845703125, 178.11485290527344, 178.02041625976562, 178.0462646484375, 177.999755859375, 178.0139617919922, 178.04159545898438

100%|██████████| 120/120 [00:15<00:00,  7.78it/s]


Epoch 2, loss: 178.03428649902344
Losses [177.99905395507812, 178.03048706054688, 177.99755859375, 178.0045928955078, 178.0302734375, 178.01617431640625, 178.0039520263672, 178.032470703125, 178.02511596679688, 178.03909301757812, 178.0143585205078, 178.08926391601562, 178.0113525390625, 177.986572265625, 178.0303955078125, 178.05819702148438, 178.00157165527344, 178.0264434814453, 178.0155792236328, 178.00390625, 178.060791015625, 177.99758911132812, 178.01333618164062, 178.025634765625, 178.03009033203125, 178.04598999023438, 178.01791381835938, 178.00894165039062, 178.00205993652344, 178.0091552734375, 178.01889038085938, 178.0022430419922, 178.09632873535156, 178.000244140625, 178.03701782226562, 178.06353759765625, 178.0583953857422, 178.061767578125, 178.02517700195312, 177.99606323242188, 178.08358764648438, 178.02828979492188, 178.00985717773438, 178.03439331054688, 178.05552673339844, 178.01019287109375, 178.0233917236328, 177.99880981445312, 178.0674591064453, 178.06016540527

100%|██████████| 120/120 [00:15<00:00,  7.86it/s]


Epoch 3, loss: 178.032958984375
Losses [178.01939392089844, 178.0391387939453, 178.0081024169922, 177.98912048339844, 178.06321716308594, 178.0206756591797, 178.07723999023438, 178.05242919921875, 177.99896240234375, 178.0080108642578, 178.007080078125, 178.02183532714844, 177.9876708984375, 177.9996795654297, 178.01683044433594, 178.01173400878906, 178.00526428222656, 178.06887817382812, 178.0416259765625, 178.03289794921875, 178.014892578125, 177.99920654296875, 178.03236389160156, 178.0393524169922, 178.01548767089844, 178.0250244140625, 178.02011108398438, 178.02838134765625, 178.00375366210938, 177.9990234375, 178.0015411376953, 178.02505493164062, 178.05372619628906, 178.01791381835938, 178.01771545410156, 178.0804443359375, 178.02932739257812, 178.10008239746094, 178.00634765625, 178.05892944335938, 178.04116821289062, 178.00689697265625, 177.9995880126953, 178.10060119628906, 178.01513671875, 178.0260772705078, 178.0406951904297, 177.99420166015625, 178.03497314453125, 178.0109

100%|██████████| 120/120 [00:15<00:00,  7.81it/s]


Epoch 4, loss: 178.03604125976562
Losses [178.01934814453125, 178.09927368164062, 178.02386474609375, 178.01766967773438, 178.09771728515625, 178.03875732421875, 178.01162719726562, 178.01547241210938, 178.27243041992188, 178.20339965820312, 178.06414794921875, 178.05276489257812, 177.99771118164062, 178.01939392089844, 178.002685546875, 178.02053833007812, 178.0089874267578, 178.108642578125, 178.05978393554688, 178.072509765625, 178.04019165039062, 177.99012756347656, 178.0199737548828, 178.01742553710938, 178.0102081298828, 178.050537109375, 178.02297973632812, 178.02301025390625, 178.02316284179688, 178.01611328125, 178.00082397460938, 178.0130615234375, 178.13841247558594, 178.0157470703125, 178.02218627929688, 178.04873657226562, 178.04183959960938, 178.10888671875, 178.02593994140625, 178.00680541992188, 178.03671264648438, 178.00698852539062, 178.01513671875, 178.0296173095703, 178.01339721679688, 177.99191284179688, 178.08612060546875, 177.98971557617188, 178.0002899169922, 17

100%|██████████| 120/120 [00:15<00:00,  7.79it/s]


Epoch 5, loss: 178.02891540527344
Losses [178.02244567871094, 178.04339599609375, 177.9982452392578, 178.00428771972656, 178.10362243652344, 178.00045776367188, 178.0092315673828, 178.01449584960938, 178.0089111328125, 178.05386352539062, 178.1006317138672, 178.033203125, 178.00628662109375, 178.0438690185547, 178.005859375, 178.0050048828125, 178.01333618164062, 178.09552001953125, 178.0707550048828, 178.00494384765625, 178.09432983398438, 177.99240112304688, 178.02536010742188, 178.0741729736328, 178.00010681152344, 178.0247344970703, 178.01055908203125, 178.00555419921875, 178.01121520996094, 178.0437774658203, 178.00674438476562, 177.99868774414062, 178.1166229248047, 178.01077270507812, 178.00669860839844, 178.0201416015625, 178.07403564453125, 178.01133728027344, 178.0244140625, 177.99472045898438, 178.12522888183594, 178.0327911376953, 178.00076293945312, 178.0194091796875, 178.01089477539062, 178.0269317626953, 177.9929962158203, 177.9996337890625, 177.9952392578125, 177.993515

100%|██████████| 120/120 [00:16<00:00,  7.37it/s]


Epoch 6, loss: 178.02854919433594
Losses [178.0059814453125, 178.0265655517578, 177.99801635742188, 177.99319458007812, 178.05038452148438, 178.00070190429688, 177.99301147460938, 178.05169677734375, 178.00172424316406, 178.06591796875, 178.03155517578125, 178.0418701171875, 178.02365112304688, 177.98690795898438, 177.99447631835938, 177.9990234375, 178.00588989257812, 178.05677795410156, 178.01278686523438, 178.01461791992188, 178.00637817382812, 177.9945831298828, 178.05227661132812, 177.99691772460938, 177.9981231689453, 178.06834411621094, 178.02020263671875, 178.03543090820312, 178.0032501220703, 177.9992218017578, 178.03648376464844, 177.9952392578125, 178.03683471679688, 178.03213500976562, 177.99998474121094, 178.01329040527344, 178.0772705078125, 177.9962921142578, 178.0293426513672, 178.01242065429688, 178.09275817871094, 178.03585815429688, 178.01211547851562, 178.06626892089844, 178.05667114257812, 178.0072021484375, 178.03042602539062, 177.9891815185547, 178.02230834960938

100%|██████████| 120/120 [00:16<00:00,  7.24it/s]


Epoch 7, loss: 178.03826904296875
Losses [178.07382202148438, 178.0862274169922, 178.0128631591797, 177.9967041015625, 178.03253173828125, 178.0360565185547, 178.03952026367188, 178.052734375, 177.993408203125, 178.00262451171875, 178.06198120117188, 178.0276641845703, 178.0399169921875, 177.99169921875, 178.25473022460938, 178.0681915283203, 178.00653076171875, 178.0222930908203, 178.02114868164062, 178.0251922607422, 178.0535430908203, 177.98745727539062, 178.03045654296875, 178.01718139648438, 178.01939392089844, 178.04354858398438, 178.04087829589844, 178.05422973632812, 178.01467895507812, 177.99876403808594, 177.99493408203125, 177.99609375, 178.09716796875, 178.00634765625, 178.01510620117188, 178.07916259765625, 178.00645446777344, 178.05419921875, 178.01783752441406, 178.11215209960938, 178.07577514648438, 178.0399932861328, 178.03445434570312, 178.06549072265625, 178.09295654296875, 178.04498291015625, 178.05465698242188, 177.9898681640625, 178.0020294189453, 178.039581298828

100%|██████████| 120/120 [00:17<00:00,  7.03it/s]


Epoch 8, loss: 178.03121948242188
Losses [178.01185607910156, 178.04522705078125, 178.0711669921875, 177.9980010986328, 178.08372497558594, 178.0096435546875, 178.04241943359375, 178.0178985595703, 178.01260375976562, 177.99740600585938, 178.06973266601562, 178.12057495117188, 178.0122528076172, 177.9859619140625, 178.014404296875, 178.0413055419922, 178.00221252441406, 178.02523803710938, 178.04388427734375, 177.99166870117188, 178.0423126220703, 177.9960479736328, 177.9939422607422, 178.00816345214844, 177.99468994140625, 178.05691528320312, 178.0072021484375, 178.06423950195312, 178.0068359375, 177.99371337890625, 178.0346221923828, 178.02120971679688, 178.07481384277344, 178.0118408203125, 178.0138397216797, 178.056640625, 178.0389862060547, 178.01336669921875, 178.03872680664062, 178.0150146484375, 178.08921813964844, 178.043212890625, 178.02508544921875, 178.02047729492188, 178.03164672851562, 178.00733947753906, 178.01055908203125, 177.99057006835938, 178.02093505859375, 178.004

100%|██████████| 120/120 [00:15<00:00,  7.72it/s]


Epoch 9, loss: 178.1092071533203
Losses [177.99659729003906, 178.02719116210938, 178.01803588867188, 177.99168395996094, 178.03231811523438, 178.01815795898438, 178.027099609375, 178.04031372070312, 178.01560974121094, 178.01925659179688, 178.0301513671875, 178.03646850585938, 178.00753784179688, 177.99142456054688, 177.99256896972656, 178.0326690673828, 178.00360107421875, 178.04623413085938, 178.01654052734375, 177.9921112060547, 178.03250122070312, 177.99644470214844, 177.99220275878906, 177.99891662597656, 178.00567626953125, 178.03475952148438, 178.0072479248047, 178.01287841796875, 178.00698852539062, 178.0072784423828, 178.02267456054688, 177.9951171875, 178.0298614501953, 178.00181579589844, 178.01760864257812, 178.0435791015625, 178.0468292236328, 177.99229431152344, 178.0223388671875, 177.99932861328125, 178.0791015625, 178.02761840820312, 178.00466918945312, 178.02659606933594, 178.0367431640625, 178.01199340820312, 178.01841735839844, 177.98748779296875, 178.00242614746094,

100%|██████████| 120/120 [00:16<00:00,  7.29it/s]

Epoch 10, loss: 178.02137756347656
Losses [178.00250244140625, 178.05242919921875, 178.00369262695312, 177.98593139648438, 178.05307006835938, 178.003662109375, 177.99461364746094, 178.04580688476562, 177.991943359375, 178.01319885253906, 178.053955078125, 178.08901977539062, 177.9940185546875, 177.99400329589844, 177.99195861816406, 178.00567626953125, 177.99795532226562, 178.02911376953125, 178.0139617919922, 178.01580810546875, 178.02813720703125, 177.99203491210938, 177.99359130859375, 178.00045776367188, 177.9938507080078, 177.9940185546875, 177.9979248046875, 178.0817413330078, 177.99224853515625, 177.99130249023438, 178.00216674804688, 178.00607299804688, 178.096435546875, 178.2023468017578, 178.0006866455078, 178.22238159179688, 178.0074005126953, 178.00064086914062, 178.0033721923828, 178.0355224609375, 178.0518035888672, 178.01351928710938, 177.99185180664062, 178.04815673828125, 178.0233154296875, 177.9921112060547, 178.00303649902344, 177.9933624267578, 177.99871826171875, 


