# Deep Learning - Project 1
Wojciech Kutak

---

### Cross Attention Network for few-shot learning problem
#### 1. Code

In [5]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
import torchvision
from torchinfo import summary


DATA_PATH = os.path.join("..", "data")
FEW_SHOT_PATH = os.path.join("..", "data_few_shot")

In [46]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1, downsampler = None):
        super().__init__()
        self.downsampler = downsampler

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)
        self.out_channels = out_channels


    def forward(self, x):
        identity = x
        if self.downsampler is not None:
            identity = self.downsampler(identity)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out = out + identity
        out = self.relu(out)
        return out


class ResNet_32x32(nn.Module):
    def __init__(self, block: nn.Module = ResidualBlock):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, (3, 3), stride=1, padding=1)
        self.bnorm = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.res_layer1 = self._residual_layer(block, 64, 128, 2)

        self.max_pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.res_layer2 = self._residual_layer(block, 128, 256, 2)

        self.max_pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.res_layer3 = self._residual_layer(block, 256, 512, 2)

        self.output_shape = (512, 3, 3)


    def _residual_layer(self, block, in_channels, out_channels, blocks_num, stride=1):
        """Creates a residual layer consisting out of residual blocks"""
        downsampler = None
        if stride != 1 or in_channels != out_channels:
            downsampler = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        layers = []
        layers.append(block(in_channels, out_channels, stride, downsampler))
        for _ in range(blocks_num - 1):
            layers.append(block(out_channels, out_channels, stride))

        return nn.Sequential(*layers)


    def forward(self, x):
        x = self.relu(self.bnorm(self.conv(x)))

        x = self.res_layer1(self.max_pool1(x))
        x = self.res_layer2(self.max_pool2(x))
        x = self.res_layer3(self.max_pool3(x))

        return x

In [150]:
summary(ResNet_32x32(), (30, 3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet_32x32                             [30, 512, 3, 3]           --
├─Conv2d: 1-1                            [30, 64, 32, 32]          1,792
├─BatchNorm2d: 1-2                       [30, 64, 32, 32]          128
├─ReLU: 1-3                              [30, 64, 32, 32]          --
├─MaxPool2d: 1-4                         [30, 64, 15, 15]          --
├─Sequential: 1-5                        [30, 128, 15, 15]         --
│    └─ResidualBlock: 2-1                [30, 128, 15, 15]         --
│    │    └─Sequential: 3-1              [30, 128, 15, 15]         8,448
│    │    └─Conv2d: 3-2                  [30, 128, 15, 15]         73,856
│    │    └─BatchNorm2d: 3-3             [30, 128, 15, 15]         256
│    │    └─ReLU: 3-4                    [30, 128, 15, 15]         --
│    │    └─Conv2d: 3-5                  [30, 128, 15, 15]         147,584
│    │    └─BatchNorm2d: 3-6             [30, 128, 15, 15]         2

In [3]:
class FusionLayer(nn.Module):
    def __init__(self, m: int, bottleneck_size: int):
        super(FusionLayer, self).__init__()

        self.temperature = 1.0
        self.m = m
        self.bottleneck_size = bottleneck_size
        self.spatial_gap = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(self.m, self.bottleneck_size, kernel_size=1)
        self.conv2 = nn.Conv2d(self.bottleneck_size, self.m, kernel_size=1)
        self.relu = nn.ReLU()


    def forward(self, R: torch.Tensor):
        # print("FusionLayer forward")
        # print("R shape:", R.shape)
        b, M, m, h, w = R.shape
        w: torch.Tensor = self.spatial_gap(R)
        # print("spatial w:", w.shape)
        w = w.view(b * M, *w.shape[2:])
        # print("spatial w after:", w.shape)

        # w = w.unsqueeze(-2)
        # Meta learner
        w = self.conv1(w)
        # print("conv1 w:", w.shape)
        w = self.relu(w)
        # print("relu w:",w.shape)
        w = self.conv2(w)
        # print("conv2 w:",w.shape)

        # w = w.squeeze((-2, -1))
        # print(w.shape)

        # Convolution operation
        # w is now a vector of average class
        A = self.attention(w, R)
        # print("Attention:", A.shape)


        return A


    def attention(self, weights: torch.Tensor, R: torch.Tensor) -> torch.Tensor:
        weights_t = weights.transpose(-3, -1).squeeze(1)

        b, M, m, h, w = R.shape
        R = R.view(b * M, m, h*w)
        # print("weights:", weights.shape)
        # print("weights_t:", weights_t.shape)
        # print("R:", R.shape)

        R_mean = torch.bmm(weights_t, R) / self.temperature
        # print("R_mean:", R_mean.shape)
        R_mean = R_mean.view(b, M, 1, m)
        A = F.softmax(R_mean, dim=-1)
        return A

In [133]:
width, height = 5, 5
m = width*height
M = 5
batch_size = 3
summary(FusionLayer(m=m, bottleneck_size=15), (batch_size, M, m, height, width))

FusionLayer forward
R shape: torch.Size([3, 5, 25, 5, 5])
spatial w: torch.Size([3, 5, 25, 1, 1])
spatial w after: torch.Size([15, 25, 1, 1])
conv1 w: torch.Size([15, 15, 1, 1])
relu w: torch.Size([15, 15, 1, 1])
conv2 w: torch.Size([15, 25, 1, 1])
weights: torch.Size([15, 25, 1, 1])
weights_t: torch.Size([15, 1, 25])
R: torch.Size([15, 25, 25])
R_mean: torch.Size([15, 1, 25])
Attention: torch.Size([3, 5, 1, 25])


Layer (type:depth-idx)                   Output Shape              Param #
FusionLayer                              [3, 5, 1, 25]             --
├─AdaptiveAvgPool2d: 1-1                 [3, 5, 25, 1, 1]          --
├─Conv2d: 1-2                            [15, 15, 1, 1]            390
├─ReLU: 1-3                              [15, 15, 1, 1]            --
├─Conv2d: 1-4                            [15, 25, 1, 1]            400
Total params: 790
Trainable params: 790
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.01
Input size (MB): 0.04
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.05

In [59]:
a = torch.tensor([1.0, 2.0]).view(1, 2)
b = torch.tensor([[2.0, 3.0, 4.0], [10.0, 11.0, 112.0]]).view(1, 2, 3, 1)
a*b

tensor([[[[  2.,   4.],
          [  3.,   6.],
          [  4.,   8.]],

         [[ 10.,  20.],
          [ 11.,  22.],
          [112., 224.]]]])

In [4]:
class CrossAttentionModule(nn.Module):
    def __init__(self, input_shape: tuple[int,int,int]):
        super(CrossAttentionModule, self).__init__()

        # print("input shape", *input_shape)
        c, h, w = input_shape
        self.fusion_layer = FusionLayer(m=h*w, bottleneck_size=int(h*w*2/3))


    def forward(self, P: torch.Tensor, Q: torch.Tensor):
    # def forward(self, S: torch.Tensor):
        # P, Q = S[0], S[1, :, 0, :, :, :]
        # print("CrossAttentionModule forward")
        # print(P.shape, Q.shape)
        # assert P.shape == Q.shape
        b, M, c, h, w = P.shape
        assert (b, c, h, w) == Q.shape
        m = h*w
        # Change representation from tensor c*h*w to c*m (2 dims)
        P = P.view(b, M, c, m)
        Q = Q.view(b, 1, c, m)

        P_norm = F.normalize(P, p=2, dim=1)
        Q_norm = F.normalize(Q, p=2, dim=1)
        # print("P norm:", P_norm.shape)
        P_norm_t = P_norm.transpose(-2, -1)
        # print("P norm t:", P_norm_t.shape)
        # print("Q norm:", Q_norm.shape)
        R_q = torch.matmul(P_norm_t, Q_norm)
        R_p = R_q.transpose(-2, -1).view(b, M, m, h, w)
        R_q = R_q.view(b, M, m, h, w)
        # print("R_q", R_q.shape)
        # print("R_p", R_p.shape)



        A_p: torch.Tensor = self.fusion_layer(R_p)
        # print()
        A_q: torch.Tensor = self.fusion_layer(R_q)
        # print()

        # print("P:", P.shape, "A_p:", A_p.shape)
        P_feat = torch.mul(P, A_p) + P
        # print("P_feat:", P_feat.shape)

        # print("Q:", Q.shape, "A_q:", A_q.shape)
        Q = Q.expand(b, M, c, m)
        # print("Q exp:", Q.shape, "A_q:", A_q.shape)
        Q_feat = torch.mul(Q, A_q) + Q
        # print("Q_feat:", Q_feat.shape)
        return P_feat, Q_feat

In [None]:
a = torch.arange(8).reshape(2, 2, 2)
print(a)
a2 = a.expand(2, 2, 2, 2).transpose(0, 1)
a2[:, 0, :, :], a2[:, 1, :, :]

tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])


(tensor([[[0, 1],
          [2, 3]],
 
         [[4, 5],
          [6, 7]]]),
 tensor([[[0, 1],
          [2, 3]],
 
         [[4, 5],
          [6, 7]]]))

In [113]:
b = 3
m = 5

summary(CrossAttentionModule((512, 5, 5)), (2, b, m, 512, 5, 5))

CrossAttentionModule forward
torch.Size([3, 5, 512, 5, 5]) torch.Size([3, 512, 5, 5])
P norm: torch.Size([3, 5, 512, 25])
P norm t: torch.Size([3, 5, 25, 512])
Q norm: torch.Size([3, 1, 512, 25])
R_q torch.Size([3, 5, 25, 5, 5])
R_p torch.Size([3, 5, 25, 5, 5])
FusionLayer forward
R shape: torch.Size([3, 5, 25, 5, 5])
spatial w: torch.Size([3, 5, 25, 1, 1])
spatial w after: torch.Size([15, 25, 1, 1])
conv1 w: torch.Size([15, 15, 1, 1])
relu w: torch.Size([15, 15, 1, 1])
conv2 w: torch.Size([15, 25, 1, 1])
weights: torch.Size([15, 25, 1, 1])
weights_t: torch.Size([15, 1, 25])
R: torch.Size([15, 25, 25])
R_mean: torch.Size([15, 1, 25])
Attention: torch.Size([3, 5, 1, 25])

FusionLayer forward
R shape: torch.Size([3, 5, 25, 5, 5])
spatial w: torch.Size([3, 5, 25, 1, 1])
spatial w after: torch.Size([15, 25, 1, 1])
conv1 w: torch.Size([15, 15, 1, 1])
relu w: torch.Size([15, 15, 1, 1])
conv2 w: torch.Size([15, 25, 1, 1])
weights: torch.Size([15, 25, 1, 1])
weights_t: torch.Size([15, 1, 25])


Layer (type:depth-idx)                   Output Shape              Param #
CrossAttentionModule                     [3, 5, 512, 25]           --
├─FusionLayer: 1-1                       [3, 5, 1, 25]             --
│    └─AdaptiveAvgPool2d: 2-1            [3, 5, 25, 1, 1]          --
│    └─Conv2d: 2-2                       [15, 15, 1, 1]            390
│    └─ReLU: 2-3                         [15, 15, 1, 1]            --
│    └─Conv2d: 2-4                       [15, 25, 1, 1]            400
├─FusionLayer: 1-2                       [3, 5, 1, 25]             (recursive)
│    └─AdaptiveAvgPool2d: 2-5            [3, 5, 25, 1, 1]          --
│    └─Conv2d: 2-6                       [15, 15, 1, 1]            (recursive)
│    └─ReLU: 2-7                         [15, 15, 1, 1]            --
│    └─Conv2d: 2-8                       [15, 25, 1, 1]            (recursive)
Total params: 790
Trainable params: 790
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.02
Input size (MB): 1.54


In [36]:
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
torch.mul(A, B), torch.matmul(A, B)

(tensor([[ 1.,  4.],
         [ 9., 16.]]),
 tensor([[ 7., 10.],
         [15., 22.]]))

In [108]:
A = torch.tensor([[1.0, 2.0], [3., 4.]])
B = torch.tensor([[7., 11.]])
A*B

tensor([[ 7., 22.],
        [21., 44.]])

In [5]:
class CrossAttentionNetwork(nn.Module):
    def __init__(self, cam: CrossAttentionModule = None):
        super(CrossAttentionNetwork, self).__init__()

        self.embedding = ResNet_32x32()
        self.cam = cam if cam is not None else CrossAttentionModule(self.embedding.output_shape)


    def forward(self, support: torch.Tensor, query: torch.Tensor):
    # def forward(self, X: torch.Tensor):

        # support = X[0]
        # query = X[1, :, 0, 0, :, :]
        b, M, K, c, h, w = support.shape
        # print("Support: ", support.shape)
        # print("query: ", query.shape)

        # Shapes of support and query tensors should be
        # - support.shape = (b, M, K, c, h, w),
        # - query.shape = (b, c, h, w),
        # where b - batch size, M - number of classes, K - number of shots (examples per class),
        # c - channels, h - height and w - width.


        # Class feature P is defined as an average of embedded samples from one class
        P = self.embedding(support.view(-1, c, h, w))
        embedding_shape = P.shape[1:]
        P = P.view(b, M, K, *embedding_shape)
        P = torch.mean(P, dim=2)

        Q = self.embedding(query)
        P_feature, Q_feature = self.cam(P, Q)

        # print("P_feature:", P_feature.shape)
        # print("Q_feature:", Q_feature.shape)

        return P_feature, Q_feature

In [163]:
summary(CrossAttentionNetwork(), (2, 32, 5, 5, 3, 32, 32))

input shape 512 3 3
Support:  torch.Size([32, 5, 5, 3, 32, 32])
query:  torch.Size([32, 3, 32, 32])
CrossAttentionModule forward
torch.Size([32, 5, 512, 3, 3]) torch.Size([32, 512, 3, 3])
P norm: torch.Size([32, 5, 512, 9])
P norm t: torch.Size([32, 5, 9, 512])
Q norm: torch.Size([32, 1, 512, 9])
R_q torch.Size([32, 5, 9, 3, 3])
R_p torch.Size([32, 5, 9, 3, 3])
FusionLayer forward
R shape: torch.Size([32, 5, 9, 3, 3])
spatial w: torch.Size([32, 5, 9, 1, 1])
spatial w after: torch.Size([160, 9, 1, 1])
conv1 w: torch.Size([160, 6, 1, 1])
relu w: torch.Size([160, 6, 1, 1])
conv2 w: torch.Size([160, 9, 1, 1])
weights: torch.Size([160, 9, 1, 1])
weights_t: torch.Size([160, 1, 9])
R: torch.Size([160, 9, 9])
R_mean: torch.Size([160, 1, 9])
Attention: torch.Size([32, 5, 1, 9])

FusionLayer forward
R shape: torch.Size([32, 5, 9, 3, 3])
spatial w: torch.Size([32, 5, 9, 1, 1])
spatial w after: torch.Size([160, 9, 1, 1])
conv1 w: torch.Size([160, 6, 1, 1])
relu w: torch.Size([160, 6, 1, 1])
conv2 

Layer (type:depth-idx)                        Output Shape              Param #
CrossAttentionNetwork                         [32, 5, 512, 9]           --
├─ResNet_32x32: 1-1                           [800, 512, 3, 3]          --
│    └─Conv2d: 2-1                            [800, 64, 32, 32]         1,792
│    └─BatchNorm2d: 2-2                       [800, 64, 32, 32]         128
│    └─ReLU: 2-3                              [800, 64, 32, 32]         --
│    └─MaxPool2d: 2-4                         [800, 64, 15, 15]         --
│    └─Sequential: 2-5                        [800, 128, 15, 15]        --
│    │    └─ResidualBlock: 3-1                [800, 128, 15, 15]        230,400
│    │    └─ResidualBlock: 3-2                [800, 128, 15, 15]        295,680
│    └─MaxPool2d: 2-6                         [800, 128, 7, 7]          --
│    └─Sequential: 2-7                        [800, 256, 7, 7]          --
│    │    └─ResidualBlock: 3-3                [800, 256, 7, 7]          919,552
│

In [6]:
class CANLoss(nn.Module):
    def __init__(self, can: CrossAttentionNetwork, lamb: float = 0.5, n_classes: int = 5):
        super(CANLoss, self).__init__()

        self.can = can
        self.lamb = lamb
        self.classifier = nn.LazyConv1d(out_channels=n_classes, kernel_size=1)


    def cosine_dist(self, P_gap: torch.Tensor, Q_features: torch.Tensor) -> torch.Tensor:
        # print("\ncosine_dist")
        # print("Q_features",  Q_features.shape)
        # Q_features: (b, c, m)
        b, c, m = Q_features.shape
        # print("P_gap",  P_gap.shape)
        # P_gap: (b, c)
        P_gap_exp = P_gap.expand(m, b, c).transpose(0, 1).transpose(1, 2)
        # print("P_gap_exp", P_gap_exp.shape)
        cos_dist = F.cosine_similarity(P_gap_exp, Q_features, dim=-2)
        # print("cos_dist", cos_dist.shape)
        return cos_dist


    def L1_loss(self, P_sig: torch.Tensor, Q_sig: torch.Tensor, y_true: torch.Tensor):
        # P_features: (b, c, m)
        # Q_features: (b, c, m)
        # print("l1 loss")
        # print("P_features:", P_sig.shape)
        # print("Q_features:", Q_sig.shape)


        P_gap = torch.mean(P_sig, dim=-1)
        # P_gap: (b, c)
        distances = self.cosine_dist(P_gap, Q_sig)
        # distances: (b, m)
        # print("Distances:", distances.shape)

        likelihoods = torch.log(F.softmax(-distances, dim=-1))
        # print("likelihoods", likelihoods.shape)
        l1 = torch.sum(likelihoods)
        # print("l1", l1)
        return l1


    def L2_loss(self, Q_sig: torch.Tensor, y_true: torch.Tensor):
        # print("l2 loss")
        # print("Q_sig:", Q_sig.shape)
        # print()

        Z: torch.Tensor = self.classifier(Q_sig)
        # print("Z", Z.shape)
        Z_significant = []
        for y, z in zip(y_true, Z):
            Z_significant.append(z[y])
        Z_significant = torch.concat(Z_significant)
        # print("Z_significant:", Z_significant.shape)

        y_pred = F.softmax(Z_significant, dim=-1)
        # print("y_pred", y_pred.shape)
        l2 = torch.sum(y_pred)
        # print("l2:", l2.shape)
        return l2


    def forward(self, support: torch.Tensor, query: torch.Tensor, y_true: torch.Tensor):
        P_features, Q_features = self.can(support, query)
        Q_significant = []
        P_significant = []
        for y, q, p in zip(y_true, Q_features, P_features):
            Q_significant.append(q[y])
            P_significant.append(p[y])
        Q_significant = torch.concatenate(Q_significant, dim=0)
        P_significant = torch.concatenate(P_significant, dim=0)
        l1 = self.L1_loss(P_significant, Q_significant, y_true)
        l2 = self.L2_loss(Q_significant, y_true)
        loss = self.lamb*l1 + l2
        return loss

In [248]:
b = 32
M = 5
K = 5
c = 3
height, width = 32, 32
can = CrossAttentionNetwork()
criterion = CANLoss(can, n_classes=M)

support = torch.ones(b, M, K, c, height, width)
query = torch.ones(b, c, height, width)
y_true = torch.randint(0, 5, (b, 1))
loss = criterion(support, query, y_true)
loss.backward()

input shape 512 3 3
Support:  torch.Size([32, 5, 5, 3, 32, 32])
query:  torch.Size([32, 3, 32, 32])
CrossAttentionModule forward
torch.Size([32, 5, 512, 3, 3]) torch.Size([32, 512, 3, 3])
P norm: torch.Size([32, 5, 512, 9])
P norm t: torch.Size([32, 5, 9, 512])
Q norm: torch.Size([32, 1, 512, 9])
R_q torch.Size([32, 5, 9, 3, 3])
R_p torch.Size([32, 5, 9, 3, 3])
FusionLayer forward
R shape: torch.Size([32, 5, 9, 3, 3])
spatial w: torch.Size([32, 5, 9, 1, 1])
spatial w after: torch.Size([160, 9, 1, 1])
conv1 w: torch.Size([160, 6, 1, 1])
relu w: torch.Size([160, 6, 1, 1])
conv2 w: torch.Size([160, 9, 1, 1])
weights: torch.Size([160, 9, 1, 1])
weights_t: torch.Size([160, 1, 9])
R: torch.Size([160, 9, 9])
R_mean: torch.Size([160, 1, 9])
Attention: torch.Size([32, 5, 1, 9])

FusionLayer forward
R shape: torch.Size([32, 5, 9, 3, 3])
spatial w: torch.Size([32, 5, 9, 1, 1])
spatial w after: torch.Size([160, 9, 1, 1])
conv1 w: torch.Size([160, 6, 1, 1])
relu w: torch.Size([160, 6, 1, 1])
conv2 

In [214]:
a = torch.arange(30).view(3, 5, 2)
idx = torch.tensor([0, 1])
a, a[idx].shape, idx.shape

(tensor([[[ 0,  1],
          [ 2,  3],
          [ 4,  5],
          [ 6,  7],
          [ 8,  9]],
 
         [[10, 11],
          [12, 13],
          [14, 15],
          [16, 17],
          [18, 19]],
 
         [[20, 21],
          [22, 23],
          [24, 25],
          [26, 27],
          [28, 29]]]),
 torch.Size([2, 5, 2]),
 torch.Size([2]))

In [257]:
b = 32
M = 5
K = 5
c = 3
height, width = 32, 32
can = CrossAttentionNetwork().cuda()
criterion = CANLoss(can, n_classes=M).cuda()

support = torch.ones(b, M, K, c, height, width).cuda()
query = torch.ones(b, c, height, width).cuda()
y_true = torch.randint(0, 5, (b, 1)).cuda()
loss = criterion(support, query, y_true)
loss.backward()

Utils

In [45]:
import numpy as np
import torch
import random


def set_seed(seed: int):
    random.seed(seed)
    np.random.RandomState(seed)
    torch.manual_seed(seed)

def get_device():
    return "cuda:0" if torch.cuda.device_count() > 0 else 'cpu'


Data prep

In [46]:
from torch.utils.data import Sampler
import random

class FewShotSampler(Sampler):
    def __init__(self, data_source, n_shots: int, shuffle: bool = True):
        self.data_source = data_source
        self.n_shots = n_shots
        self.classes = self._get_classes()
        self.shuffle = shuffle


    def _get_classes(self):
        classes = dict()
        for i, (_, label) in enumerate(self.data_source):
            if label not in classes:
                classes[label] = []
            classes[label].append(i)

        return classes


    def __iter__(self):
        indices = []
        n = len(self.data_source)
        n_classes = len(self.classes)
        episodes = n // ((self.n_shots + 1) * n_classes)

        for i in range(episodes):
            query = []
            for c in self.classes:
                class_ind = self.classes[c]
                lower = i*(self.n_shots + 1)
                upper = (i+1)*(self.n_shots + 1) - 1
                support_one_class = class_ind[lower : upper]
                query.append(class_ind[upper])
                if self.shuffle:
                    random.shuffle(support_one_class)
                indices.extend(support_one_class)
            indices.extend(query)
        return iter(indices)

    def __len__(self):
        return len(self.data_source)

In [47]:
import os
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
import torchvision
from torchvision import transforms
import shutil


CINIC_MEAN = [0.47889522, 0.47227842, 0.43047404]
CINIC_STD = [0.24205776, 0.23828046, 0.25874835]


def prepare_data_folder(data_path: str, few_shot_path: str):
    subsets = ["train", "valid", "test"]
    classes = {"train": [ 'bird', 'cat', 'deer', 'dog', 'frog', 'horse'],
               "valid": ['airplane','automobile'],
               "test": [ 'ship','truck']
    }
    [os.makedirs(os.path.join(few_shot_path, sub), exist_ok=True) for sub in subsets]
    [[os.makedirs(os.path.join(few_shot_path, sub, c), exist_ok=True) for c in classes[sub]]
     for sub in subsets]
    for sub in subsets:
        old_sub_path = os.path.join(data_path, sub)
        new_sub_path = os.path.join(few_shot_path, sub)
        for c in classes[sub]:
            old_dir = os.path.join(old_sub_path, c)
            new_dir = os.path.join(new_sub_path, c)
            imgs = os.listdir(old_dir)
            random.shuffle(imgs)
            # 600 support imgs -> 120 batches of 5-way
            # 120 imgs for query
            # 720 in total
            new_imgs = imgs[:720]
            [shutil.copyfile(os.path.join(old_dir, img), os.path.join(new_dir, img))
             for img in new_imgs]



def get_dataset(
        path: str,
        batch_size: int,
        n_shots: int,
        shuffle: bool,
        use_augmentations: bool,
) -> DataLoader:
    augmentations = ([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomErasing()
    ] if use_augmentations else [])
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=CINIC_MEAN,std=CINIC_STD),
                                    *augmentations])

    ds = torchvision.datasets.ImageFolder(path, transform=transform)
    sampler = FewShotSampler(ds, n_shots, shuffle=shuffle)
    n_way = len(ds.classes)
    episode_size = batch_size * n_way * (n_shots + 1)
    loader = DataLoader(ds, sampler=sampler, batch_size=episode_size, num_workers=2, pin_memory=True)
    return loader


def get_cinic_few(
        data_path: str,
        batch_size: int,
        n_shots: int,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    train_path = os.path.join(data_path, "train")
    valid_path = os.path.join(data_path, "valid")
    test_path = os.path.join(data_path, "test")

    cinic_train = get_dataset(train_path, batch_size, n_shots, True, True)
    cinic_validation = get_dataset(valid_path, batch_size, n_shots, False, False)
    cinic_test = get_dataset(test_path, batch_size, n_shots, False, False)

    return cinic_train, cinic_validation, cinic_test


def parse_episode_batch(episode_samples: torch.Tensor, episode_labels: torch.Tensor, n_way: int, n_shot: int):
    pass

In [None]:
# prepare_data_folder(DATA_PATH, FEW_SHOT_PATH)

In [7]:
len(os.listdir(os.path.join(FEW_SHOT_PATH, "train", "bird")))

720

In [41]:
n_way = 2
n_shots = 5
ds = torchvision.datasets.ImageFolder(os.path.join(FEW_SHOT_PATH, "test"), transform=transforms.ToTensor())
print(len(ds))
sampler = FewShotSampler(ds, n_shots)
loader = DataLoader(ds, sampler=sampler, batch_size=(n_way * n_shots + n_way))

for samples, labels in loader:
    support = samples[:n_way * n_shots]
    support_labels = labels[:n_way * n_shots]
    support_labels = support_labels.view(1, n_way, n_shots)
    print(support_labels)
    break
    # print(samples.shape, label.shape)


1440
0 : 720
1 : 720
1440
1440 120
0 0 (l, u): 0 5
support: 5
0 1 (l, u): 0 5
support: 5
query; 2
1 0 (l, u): 6 11
support: 5
1 1 (l, u): 6 11
support: 5
query; 2
2 0 (l, u): 12 17
support: 5
2 1 (l, u): 12 17
support: 5
query; 2
3 0 (l, u): 18 23
support: 5
3 1 (l, u): 18 23
support: 5
query; 2
4 0 (l, u): 24 29
support: 5
4 1 (l, u): 24 29
support: 5
query; 2
5 0 (l, u): 30 35
support: 5
5 1 (l, u): 30 35
support: 5
query; 2
6 0 (l, u): 36 41
support: 5
6 1 (l, u): 36 41
support: 5
query; 2
7 0 (l, u): 42 47
support: 5
7 1 (l, u): 42 47
support: 5
query; 2
8 0 (l, u): 48 53
support: 5
8 1 (l, u): 48 53
support: 5
query; 2
9 0 (l, u): 54 59
support: 5
9 1 (l, u): 54 59
support: 5
query; 2
10 0 (l, u): 60 65
support: 5
10 1 (l, u): 60 65
support: 5
query; 2
11 0 (l, u): 66 71
support: 5
11 1 (l, u): 66 71
support: 5
query; 2
12 0 (l, u): 72 77
support: 5
12 1 (l, u): 72 77
support: 5
query; 2
13 0 (l, u): 78 83
support: 5
13 1 (l, u): 78 83
support: 5
query; 2
14 0 (l, u): 84 89
suppor

In [89]:
ds.classes

['ship', 'truck']

In [None]:
set_seed(1)
ds_train, ds_valid, ds_test = get_cinic_few(FEW_SHOT_PATH, 32, 5)

[tensor([[[[-0.0343, -0.0505, -0.0505,  ...,  0.0000,  0.0000,  0.0000],
           [-0.0181, -0.0505, -0.0505,  ...,  0.0000,  0.0000,  0.0000],
           [-0.0019, -0.0181, -0.0343,  ...,  0.0000,  0.0000,  0.0000],
           ...,
           [-0.7472, -0.6824, -0.0505,  ...,  0.0000,  0.0000,  0.0000],
           [-0.6986, -0.6014, -0.6662,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
          [[ 0.1246,  0.1081,  0.1081,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.1739,  0.1410,  0.1410,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.2233,  0.1904,  0.1739,  ...,  0.0000,  0.0000,  0.0000],
           ...,
           [-0.7642, -0.6983, -0.0071,  ...,  0.0000,  0.0000,  0.0000],
           [-0.5996, -0.5008, -0.5831,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
          [[-0.6785, -0.6937, -0.6937,  ...,  0.0000,  0.0000,  0.0000],
           [-

In [54]:
samples, labels = next(iter(ds_train))
print(labels.tolist())

[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0, 1, 1, 1, 1, 

In [95]:
ds_train.dataset.classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [None]:
can = CrossAttentionNetwork().cuda()
criterion = CANLoss(can, n_classes=M).cuda()
optimizer = SGD(CANLoss.parameters())

for