In [29]:
import torch
import torch.nn.functional as F
import torch
import torch.nn.functional as F
from typing import List
from collections import defaultdict
from typing import Tuple


def to_one_hot(mask: torch.tensor,
               num_classes: int) -> torch.tensor:
    """
    inputs:
        mask : shape [n_task, shot, h, w]
        num_classes : Number of classes

    returns :
        one_hot_mask : shape [n_task, shot, num_class, h, w]
    """
    n_tasks, shot, h, w = mask.size()
    one_hot_mask = torch.zeros(n_tasks, shot, num_classes, h, w)
    new_mask = mask.unsqueeze(2).clone()
    new_mask[torch.where(new_mask == 255)] = 0  # Ignore_pixels are anyways filtered out in the losses
    one_hot_mask.scatter_(2, new_mask, 1).long()
    return one_hot_mask
def intersectionAndUnion(output: torch.Tensor, target: torch.Tensor, num_classes: int, ignore_index=255) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate intersection and union for single task and shot.

    Args:
        output (torch.Tensor): Predicted tensor.
        target (torch.Tensor): Ground truth tensor.
        num_classes (int): Number of classes.
        ignore_index (int, optional): Index to ignore in evaluation. Defaults to 255.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Intersection, Union, Target area.
    """
    ignore_mask = target != ignore_index
    output = output[ignore_mask]
    target = target[ignore_mask]

    intersection = torch.zeros(num_classes)
    union = torch.zeros(num_classes)
    target_area = torch.zeros(num_classes)

    for cls in range(num_classes):
        output_cls = output == cls
        target_cls = target == cls

        intersection[cls] = torch.sum(output_cls & target_cls)
        union[cls] = torch.sum(output_cls | target_cls)
        target_area[cls] = torch.sum(target_cls)

    return intersection, union, target_area

def batch_intersectionAndUnionGPU(logits: torch.Tensor, target: torch.Tensor, num_classes: int, ignore_index=255) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate intersection and union for batch of tasks and shots.

    Args:
        logits (torch.Tensor): Predicted logits tensor.
        target (torch.Tensor): Ground truth tensor.
        num_classes (int): Number of classes.
        ignore_index (int, optional): Index to ignore in evaluation. Defaults to 255.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Intersection, Union, Target area for each task and shot.
    """
    n_task, shots, _, h, w = logits.size()
    H, W = target.size()[-2:]

    logits = F.interpolate(logits.view(n_task * shots, num_classes, h, w), size=(H, W), mode='bilinear', align_corners=True).view(n_task, shots, num_classes, H, W)
    preds = logits.argmax(2)  # [n_task, shot, H, W]

    area_intersection = torch.zeros(n_task, shots, num_classes)
    area_union = torch.zeros(n_task, shots, num_classes)
    area_target = torch.zeros(n_task, shots, num_classes)

    for task in range(n_task):
        for shot in range(shots):
            i, u, t = intersectionAndUnion(preds[task][shot], target[task][shot], num_classes, ignore_index=ignore_index)
            area_intersection[task, shot, :] = i
            area_union[task, shot, :] = u
            area_target[task, shot, :] = t

    return area_intersection, area_union, area_target


In [30]:



class Classifier(object):
    def __init__(self):
        self.num_classes = 2
        self.temperature = 20
        self.adapt_iter = 50
        self.weights = [1.0, 'auto', 'auto']
        self.lr = 0.025
        self.FB_param_update = [10]
        self.visdom_freq = 5
        self.FB_param_type = 'soft'
        self.FB_param_noise = 0

    def init_prototypes(self, features_s: torch.tensor, features_q: torch.tensor,
                        gt_s: torch.tensor, gt_q: torch.tensor, subcls: List[int],
                        callback) -> None:
        """
        inputs:
            features_s : shape [n_task, shot, c, h, w]
            features_q : shape [n_task, 1, c, h, w]
            gt_s : shape [n_task, shot, H, W]
            gt_q : shape [n_task, 1, H, W]

        returns :
            prototypes : shape [n_task, c]
            bias : shape [n_task]
        """

        # DownSample support masks
        n_task, shot, c, h, w = features_s.size()
        ds_gt_s = F.interpolate(gt_s.float(), size=features_s.shape[-2:], mode='nearest')
        ds_gt_s = ds_gt_s.long().unsqueeze(2)  # [n_task, shot, 1, h, w]

        # Computing prototypes
        fg_mask = (ds_gt_s == 1)
        fg_prototype = (features_s * fg_mask).sum(dim=(1, 3, 4))
        fg_prototype /= (fg_mask.sum(dim=(1, 3, 4)) + 1e-10)  # [n_task, c]
        self.prototype = fg_prototype

        logits_q = self.get_logits(features_q)  # [n_tasks, shot, h, w]
        self.bias = logits_q.mean(dim=(1, 2, 3))

        assert self.prototype.size() == (n_task, c), self.prototype.size()
        assert torch.isnan(self.prototype).sum() == 0, self.prototype

        if callback is not None:
            self.update_callback(callback, 0, features_s, features_q, subcls, gt_s, gt_q)

    def get_logits(self, features: torch.tensor) -> torch.tensor:

        """
        Computes the cosine similarity between self.prototype and given features
        inputs:
            features : shape [n_tasks, shot, c, h, w]

        returns :
            logits : shape [n_tasks, shot, h, w]
        """

        # Put prototypes and features in the right shape for multiplication
        features = features.permute((0, 1, 3, 4, 2))  # [n_task, shot, h, w, c]
        prototype = self.prototype.unsqueeze(1).unsqueeze(2)  # [n_tasks, 1, 1, c]

        # Compute cosine similarity
        print('feature', features.size(), 'prototype', prototype.unsqueeze(4).size())
        print("features.matmul(prototype.unsqueeze(4)", features.matmul(prototype.unsqueeze(4)).shape)
        cossim = features.matmul(prototype.unsqueeze(4)).squeeze(4)  # [n_task, shot, h, w]
        cossim /= ((prototype.unsqueeze(3).norm(dim=4) * \
                    features.norm(dim=4)) + 1e-10)  # [n_tasks, shot, h, w]

        return self.temperature * cossim

    def get_probas(self, logits: torch.tensor) -> torch.tensor:
        """
        inputs:
            logits : shape [n_tasks, shot, h, w]

        returns :
            probas : shape [n_tasks, shot, num_classes, h, w]
        """
        logits_fg = logits - self.bias.unsqueeze(1).unsqueeze(2).unsqueeze(3)  # [n_tasks, shot, h, w]
        probas_fg = torch.sigmoid(logits_fg).unsqueeze(2)
        probas_bg = 1 - probas_fg
        probas = torch.cat([probas_bg, probas_fg], dim=2)
        return probas

    def compute_FB_param(self, features_q: torch.tensor, gt_q: torch.tensor) -> torch.tensor:
        """
        inputs:
            features_q : shape [n_tasks, shot, c, h, w]
            gt_q : shape [n_tasks, shot, h, w]

        updates :
             self.FB_param : shape [n_tasks, num_classes]
        """
        ds_gt_q = F.interpolate(gt_q.float(), size=features_q.size()[-2:], mode='nearest').long()
        valid_pixels = (ds_gt_q != 255).unsqueeze(2)  # [n_tasks, shot, num_classes, h, w]
        assert (valid_pixels.sum(dim=(1, 2, 3, 4)) == 0).sum() == 0, valid_pixels.sum(dim=(1, 2, 3, 4))

        one_hot_gt_q = to_one_hot(ds_gt_q, self.num_classes)  # [n_tasks, shot, num_classes, h, w]

        oracle_FB_param = (valid_pixels * one_hot_gt_q).sum(dim=(1, 3, 4)) / valid_pixels.sum(dim=(1, 3, 4))
        logits_q = self.get_logits(features_q)
        probas = self.get_probas(logits_q).detach()
        self.FB_param = (valid_pixels * probas).sum(dim=(1, 3, 4))
        self.FB_param /= valid_pixels.sum(dim=(1, 3, 4))

        # Compute the relative error
        deltas = self.FB_param[:, 1] / oracle_FB_param[:, 1] - 1
        return deltas

    def get_entropies(self,
                      valid_pixels: torch.tensor,
                      probas: torch.tensor,
                      reduction='sum') -> Tuple[torch.tensor, torch.tensor, torch.tensor]:
        """
        inputs:
            probas : shape [n_tasks, shot, num_class, h, w]
            valid_pixels: shape [n_tasks, shot, h, w]

        returns:
            d_kl : FB proportion kl [n_tasks,]
            cond_entropy : Entropy of predictions [n_tasks,]
            marginal : Current marginal distribution over labels [n_tasks, num_classes]
        """
        n_tasks, shot, num_classes, h, w = probas.size()
        assert (valid_pixels.sum(dim=(1, 2, 3)) == 0).sum() == 0, \
               (valid_pixels.sum(dim=(1, 2, 3)) == 0).sum()  # Make sure all tasks have a least 1 valid pixel

        cond_entropy = - ((valid_pixels.unsqueeze(2) * (probas * torch.log(probas + 1e-10))).sum(2))
        cond_entropy = cond_entropy.sum(dim=(1, 2, 3))
        cond_entropy /= valid_pixels.sum(dim=(1, 2, 3))

        marginal = (valid_pixels.unsqueeze(2) * probas).sum(dim=(1, 3, 4))
        marginal /= valid_pixels.sum(dim=(1, 2, 3)).unsqueeze(1)

        d_kl = (marginal * torch.log(marginal / (self.FB_param + 1e-10))).sum(1)

        if reduction == 'sum':
            cond_entropy = cond_entropy.sum(0)
            d_kl = d_kl.sum(0)
            assert not torch.isnan(cond_entropy), cond_entropy
            assert not torch.isnan(d_kl), d_kl
        elif reduction == 'mean':
            cond_entropy = cond_entropy.mean(0)
            d_kl = d_kl.mean(0)
        return d_kl, cond_entropy, marginal

    def get_ce(self,
               probas: torch.tensor,
               valid_pixels: torch.tensor,
               one_hot_gt: torch.tensor,
               reduction: str = 'sum') -> torch.tensor:
        """
        inputs:
            probas : shape [n_tasks, shot, c, h, w]
            one_hot_gt: shape [n_tasks, shot, num_classes, h, w]
            valid_pixels : shape [n_tasks, shot, h, w]

        updates :
             ce : Cross-Entropy between one_hot_gt and probas, shape [n_tasks,]
        """
        ce = - ((valid_pixels.unsqueeze(2) * (one_hot_gt * torch.log(probas + 1e-10))).sum(2))  # [n_tasks, shot, h, w]
        ce = ce.sum(dim=(1, 2, 3))  # [n_tasks]
        ce /= valid_pixels.sum(dim=(1, 2, 3))
        if reduction == 'sum':
            ce = ce.sum(0)
        elif reduction == 'mean':
            ce = ce.mean(0)
        return ce

    def RePRI(self,
              features_s: torch.tensor,
              features_q: torch.tensor,
              gt_s: torch.tensor,
              gt_q: torch.tensor,
              subcls: List,
              n_shots: torch.tensor,
              callback) -> torch.tensor:
        """
        Performs RePRI inference

        inputs:
            features_s : shape [n_tasks, shot, c, h, w]
            features_q : shape [n_tasks, shot, c, h, w]
            gt_s : shape [n_tasks, shot, h, w]
            gt_q : shape [n_tasks, shot, h, w]
            subcls : List of classes present in each task
            n_shots : # of support shots for each task, shape [n_tasks,]

        updates :
            prototypes : torch.Tensor of shape [n_tasks, num_class, c]

        returns :
            deltas : Relative error on FB estimation right after first update, for each task,
                     shape [n_tasks,]
        """
        deltas = torch.zeros_like(n_shots)
        l1, l2, l3 = self.weights
        if l2 == 'auto':
            l2 = 1 / n_shots
        else:
            l2 = l2 * torch.ones_like(n_shots)
        if l3 == 'auto':
            l3 = 1 / n_shots
        else:
            l3 = l3 * torch.ones_like(n_shots)

        self.prototype.requires_grad_()
        self.bias.requires_grad_()
        optimizer = torch.optim.SGD([self.prototype, self.bias], lr=self.lr)

        ds_gt_q = F.interpolate(gt_q.float(), size=features_s.size()[-2:], mode='nearest').long()
        ds_gt_s = F.interpolate(gt_s.float(), size=features_s.size()[-2:], mode='nearest').long()

        valid_pixels_q = (ds_gt_q != 255).float()  # [n_tasks, shot, h, w]
        valid_pixels_s = (ds_gt_s != 255).float()  # [n_tasks, shot, h, w]

        one_hot_gt_s = to_one_hot(ds_gt_s, self.num_classes)  # [n_tasks, shot, num_classes, h, w]

        for iteration in range(1, self.adapt_iter):

            logits_s = self.get_logits(features_s)  # [n_tasks, shot, num_class, h, w]
            logits_q = self.get_logits(features_q)  # [n_tasks, 1, num_class, h, w]
            proba_q = self.get_probas(logits_q)
            proba_s = self.get_probas(logits_s)

            d_kl, cond_entropy, marginal = self.get_entropies(valid_pixels_q,
                                                              proba_q,
                                                              reduction='none')
            ce = self.get_ce(proba_s, valid_pixels_s, one_hot_gt_s, reduction='none')
            loss = l1 * ce + l2 * d_kl + l3 * cond_entropy

            optimizer.zero_grad()
            loss.sum(0).backward()
            optimizer.step()

            # Update FB_param
            if (iteration + 1) in self.FB_param_update  \
                    and ('oracle' not in self.FB_param_type) and (l2.sum().item() != 0):
                deltas = self.compute_FB_param(features_q, gt_q).cpu()
                l2 += 1

            if callback is not None and (iteration + 1) % self.visdom_freq == 0:
                self.update_callback(callback, iteration, features_s, features_q, subcls, gt_s, gt_q)
        return deltas

    def get_mIoU(self,
                 probas: torch.tensor,
                 gt: torch.tensor,
                 subcls: torch.tensor,
                 reduction: str = 'mean') -> torch.tensor:
        """
        Computes the mIoU over the current batch of tasks being processed

        inputs:
            probas : shape [n_tasks, shot, num_class, h, w]
            gt : shape [n_tasks, shot, h, w]
            subcls : List of classes present in each task


        returns :
            class_IoU : Classwise IoU (or mean of it), shape
        """
        intersection, union, _ = batch_intersectionAndUnionGPU(probas, gt, self.num_classes)  # [num_tasks, shot, num_class]
        inter_count = defaultdict(int)
        union_count = defaultdict(int)

        for i, classes_ in enumerate(subcls):
            inter_count[0] += intersection[i, 0, 0]
            union_count[0] += union[i, 0, 0]
            for j, class_ in enumerate(classes_):
                inter_count[class_] += intersection[i, 0, j + 1]  # Do not count background
                union_count[class_] += union[i, 0, j + 1]
        class_IoU = torch.tensor([inter_count[subcls] / union_count[subcls] for subcls in inter_count if subcls != 0])
        if reduction == 'mean':
            return class_IoU.mean()
        elif reduction == 'none':
            return class_IoU

    # def update_callback(self, callback, iteration: int, features_s: torch.tensor,
    #                     features_q: torch.tensor, subcls: List[int],
    #                     gt_s: torch.tensor, gt_q: torch.tensor) -> None:
    #     """
    #     Updates the visdom callback in case live visualization of metrics is desired

    #     inputs:
    #         iteration: Current inference iteration
    #         features_s : shape [n_tasks, shot, c, h, w]
    #         features_q : shape [n_tasks, shot, c, h, w]
    #         gt_s : shape [n_tasks, shot, h, w]
    #         gt_q : shape [n_tasks, shot, h, w]
    #         subcls : List of classes present in each task


    #     returns :
    #         callback : Visdom logger
    #     """
    #     logits_q = self.get_logits(features_q)  # [n_tasks, shot, num_class, h, w]
    #     logits_s = self.get_logits(features_s)  # [n_tasks, shot, num_class, h, w]
    #     proba_q = self.get_probas(logits_q).detach()  # [n_tasks, shot, num_class, h, w]
    #     proba_s = self.get_probas(logits_s).detach()  # [n_tasks, shot, num_class, h, w]

    #     f_resolution = features_s.size()[-2:]
    #     ds_gt_q = F.interpolate(gt_q.float(), size=f_resolution, mode='nearest').long()
        # ds_gt_s = F.interpolate(gt_s.float(), size=f_resolution, mode='nearest').long()

        # valid_pixels_q = (ds_gt_q != 255).float()  # [n_tasks, shot, h, w]
        # valid_pixels_s = (ds_gt_s != 255).float()  # [n_tasks, shot, h, w]

        # one_hot_gt_q = to_one_hot(ds_gt_q, self.num_classes)  # [n_tasks, shot, num_classes, h, w]
        # oracle_FB_param = (valid_pixels_q.unsqueeze(2) * one_hot_gt_q).sum(dim=(1, 3, 4))
        # oracle_FB_param /= (valid_pixels_q.unsqueeze(2)).sum(dim=(1, 3, 4))

        # one_hot_gt_s = to_one_hot(ds_gt_s, self.num_classes)  # [n_tasks, shot, num_classes, h, w]
        # ce_s = self.get_ce(proba_s, valid_pixels_s, one_hot_gt_s)
        # ce_q = self.get_ce(proba_q, valid_pixels_q, one_hot_gt_q)

        # mIoU_q = self.get_mIoU(proba_q, gt_q, subcls)

        # callback.scalar('mIoU_q', iteration, mIoU_q, title='mIoU')
        # if iteration > 0:
        #     d_kl, cond_entropy, marginal = self.get_entropies(valid_pixels_q,
        #                                                       proba_q,
        #                                                       reduction='mean')
        #     marginal2oracle = (oracle_FB_param * torch.log(oracle_FB_param / marginal + 1e-10)).sum(1).mean()
        #     FB_param2oracle = (oracle_FB_param * torch.log(oracle_FB_param / self.FB_param + 1e-10)).sum(1).mean()
        #     callback.scalars(['Cond', 'marginal2oracle', 'FB_param2oracle'], iteration,
        #                      [cond_entropy, marginal2oracle, FB_param2oracle], title='Entropy')
        # callback.scalars(['ce_s', 'ce_q'], iteration, [ce_s, ce_q], title='CE')

In [31]:
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch


# ==================================================================================================
# Taken from https://github.com/hszhao/semseg/blob/master/model/resnet.py ==========================
# ==================================================================================================

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, deep_base=True):
        super(ResNet, self).__init__()
        self.deep_base = deep_base
        if not self.deep_base:
            self.inplanes = 64
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
        else:
            self.inplanes = 128
            self.conv1 = conv3x3(3, 64, stride=2)
            self.bn1 = nn.BatchNorm2d(64)
            self.conv2 = conv3x3(64, 64)
            self.bn2 = nn.BatchNorm2d(64)
            self.conv3 = conv3x3(64, 128)
            self.bn3 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        if self.deep_base:
            x = self.relu(self.bn2(self.conv2(x)))
            x = self.relu(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

In [32]:
def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
        model_path = './initmodel/resnet50_v2.pth'
        model.load_state_dict(torch.load(model_path), strict=False)
    return model

In [33]:
model = resnet50(pretrained=False)

In [34]:
inp = torch.randn(1, 3, 224, 224)
out = model(inp)
out.shape

torch.Size([1, 1000])

In [35]:
classifier = Classifier()
batch_size_val, shot, c, h, w = 2, 5, 2048, 7, 7
image_size = 224
gt_s = 255 * torch.rand(batch_size_val, shot, image_size,
                                    image_size).long()
gt_q = 255 * torch.rand(batch_size_val, 1, image_size,
                        image_size).long()
n_shots = torch.rand(batch_size_val)
features_s = torch.rand(batch_size_val, shot, c, h, w)
features_q = torch.rand(batch_size_val, 1, c, h, w)
classifier.init_prototypes(features_s, features_q, gt_s, gt_q, [1, 2], None)
batch_deltas = classifier.compute_FB_param(features_q=features_q, gt_q=gt_q)
features_s.shape, features_q.shape, classifier.prototype.shape, classifier.FB_param.shape

            

feature torch.Size([2, 1, 7, 7, 2048]) prototype torch.Size([2, 1, 1, 2048, 1])
features.matmul(prototype.unsqueeze(4) torch.Size([2, 1, 7, 7, 1])
feature torch.Size([2, 1, 7, 7, 2048]) prototype torch.Size([2, 1, 1, 2048, 1])
features.matmul(prototype.unsqueeze(4) torch.Size([2, 1, 7, 7, 1])


(torch.Size([2, 5, 2048, 7, 7]),
 torch.Size([2, 1, 2048, 7, 7]),
 torch.Size([2, 2048]),
 torch.Size([2, 2]))