In [311]:
"""
Fewshot Semantic Segmentation
"""

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

# make a dummy Encoder class for pytorch:
class Encoder(nn.Module):
    def __init__(self, in_channels, pretrained_path):
        super(Encoder, self).__init__()
        self.pretrained_path = pretrained_path
        self.in_channels = in_channels
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2, ceil_mode=True),
        )

    def forward(self, x):
        return self.encoder(x)


class FewShotSeg(nn.Module):
    """
    Fewshot Segmentation model

    Args:
        in_channels:
            number of input channels
        pretrained_path:
            path of the model for initialization
        cfg:
            model configurations
    """
    def __init__(self, in_channels=3, pretrained_path=None, cfg=None):
        super().__init__()
        self.pretrained_path = pretrained_path
        self.config = cfg or {'align': True}

        # Encoder
        self.encoder = nn.Sequential(OrderedDict([
            ('backbone', Encoder(in_channels, self.pretrained_path)),]))


    def forward(self, supp_imgs, fore_mask, back_mask, qry_imgs):
        """
        Args:
            supp_imgs: support images
                way x shot x [B x 3 x H x W], list of lists of tensors
            fore_mask: foreground masks for support images
                way x shot x [B x H x W], list of lists of tensors
            back_mask: background masks for support images
                way x shot x [B x H x W], list of lists of tensors
            qry_imgs: query images
                N x [B x 3 x H x W], list of tensors
        """
        n_ways = len(supp_imgs)
        n_shots = len(supp_imgs[0])
        n_queries = len(qry_imgs)
        batch_size = supp_imgs[0][0].shape[0]
        img_size = supp_imgs[0][0].shape[-2:]
        print('n_ways:', n_ways, 'n_shots:', n_shots, 'n_queries:', n_queries, 'batch_size:', batch_size, 'img_size:', img_size)
        ###### Extract features ######
        print('supp_imgs shape:', supp_imgs[0][0].shape)
        imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs]
                                + [torch.cat(qry_imgs, dim=0),], dim=0)
        print('after concat', imgs_concat.shape, 'which is (Wa*Sh + N) x B x 3 x H x W, Wa =', n_ways, 'Sh =', n_shots, 'N =', n_queries, 'B =', batch_size)
        img_fts = self.encoder(imgs_concat)
        print('after go through encoder', img_fts.shape)
        fts_size = img_fts.shape[-2:]
        print('fts_size:', fts_size)

        supp_fts = img_fts[:n_ways * n_shots * batch_size].view(
            n_ways, n_shots, batch_size, -1, *fts_size)  # Wa x Sh x B x C x H' x W'
        print('supp_fts shape:', supp_fts.shape)
        qry_fts = img_fts[n_ways * n_shots * batch_size:].view(
            n_queries, batch_size, -1, *fts_size)   # N x B x C x H' x W'
        print('qry_fts shape:', qry_fts.shape)
        fore_mask = torch.stack([torch.stack(way, dim=0)
                                 for way in fore_mask], dim=0)  # Wa x Sh x B x H x W
        print('fore_mask shape:', fore_mask.shape)
        back_mask = torch.stack([torch.stack(way, dim=0)
                                 for way in back_mask], dim=0)  # Wa x Sh x B x H x W
        print('back_mask shape:', back_mask.shape)
        

        ###### Compute loss ######
        align_loss = 0
        outputs = []
        for epi in range(batch_size):
            ###### Extract prototype ######
            supp_fg_fts = [[self.getFeatures(supp_fts[way, shot, [epi]],
                                             fore_mask[way, shot, [epi]])
                            for shot in range(n_shots)] for way in range(n_ways)]
            supp_bg_fts = [[self.getFeatures(supp_fts[way, shot, [epi]],
                                             back_mask[way, shot, [epi]])
                            for shot in range(n_shots)] for way in range(n_ways)]
            print('supp_fg_fts shape:', len(supp_fg_fts), len(supp_fg_fts[0]), supp_fg_fts[0][0].shape)
            ###### Obtain the prototypes######
            fg_prototypes, bg_prototype = self.getPrototype(supp_fg_fts, supp_bg_fts)
            print('fg_prototypes shape:', len(fg_prototypes), fg_prototypes[0].shape)
            print('bg_prototype shape:', bg_prototype.shape)

            ###### Compute the distance ######
            prototypes = [bg_prototype,] + fg_prototypes
            print('prototypes shape:', len(prototypes), prototypes[0].shape)
            dist = [self.calDist(qry_fts[:, epi], prototype) for prototype in prototypes]
            print('dist shape:', len(dist), dist[0].shape)
            pred = torch.stack(dist, dim=1)  # N x (1 + Wa) x H' x W'
            print('pred shape:', pred.shape)
            outputs.append(F.interpolate(pred, size=img_size, mode='bilinear'))
            print('UPSAMPLING: outputs shape:', len(outputs), outputs[0].shape)

            ###### Prototype alignment loss ######
            if self.config['align'] :
                align_loss_epi = self.alignLoss(qry_fts[:, epi], pred, supp_fts[:, :, epi],
                                                fore_mask[:, :, epi], back_mask[:, :, epi])
                align_loss += align_loss_epi

        output = torch.stack(outputs, dim=1)  # N x B x (1 + Wa) x H x W
        print('FINAL OUTPUT shape N x B x (1 + Wa) x H x W:', output.shape)
        output = output.view(-1, *output.shape[2:])
        print(output.shape)
        return output, align_loss / batch_size


    def calDist(self, fts, prototype, scaler=20):
        """
        Calculate the distance between features and prototypes

        Args:
            fts: input features
                expect shape: N x C x H x W
            prototype: prototype of one semantic class
                expect shape: 1 x C
        """
        print('CALDIST fts shape:', fts.shape, 'prototype shape:', prototype.shape)
        dist = F.cosine_similarity(fts, prototype[..., None, None], dim=1) * scaler
        return dist


    def getFeatures(self, fts, mask):
        """
        Extract foreground and background features via masked average pooling

        Args:
            fts: input features, expect shape: 1 x C x H' x W'
            mask: binary mask, expect shape: 1 x H x W
        """
        fts = F.interpolate(fts, size=mask.shape[-2:], mode='bilinear')
        masked_fts = torch.sum(fts * mask[None, ...], dim=(2, 3)) \
            / (mask[None, ...].sum(dim=(2, 3)) + 1e-5) # 1 x C
        return masked_fts


    def getPrototype(self, fg_fts, bg_fts):
        """
        Average the features to obtain the prototype

        Args:
            fg_fts: lists of list of foreground features for each way/shot
                expect shape: Wa x Sh x [1 x C]
            bg_fts: lists of list of background features for each way/shot
                expect shape: Wa x Sh x [1 x C]
        """
        n_ways, n_shots = len(fg_fts), len(fg_fts[0])
        fg_prototypes = [sum(way) / n_shots for way in fg_fts]
        bg_prototype = sum([sum(way) / n_shots for way in bg_fts]) / n_ways
        return fg_prototypes, bg_prototype


    def alignLoss(self, qry_fts, pred, supp_fts, fore_mask, back_mask):
        """
        Compute the loss for the prototype alignment branch

        Args:
            qry_fts: embedding features for query images
                expect shape: N x C x H' x W'
            pred: predicted segmentation score
                expect shape: N x (1 + Wa) x H x W
            supp_fts: embedding features for support images
                expect shape: Wa x Sh x C x H' x W'
            fore_mask: foreground masks for support images
                expect shape: way x shot x H x W
            back_mask: background masks for support images
                expect shape: way x shot x H x W
        """
        n_ways, n_shots = len(fore_mask), len(fore_mask[0])

        # Mask and get query prototype
        pred_mask = pred.argmax(dim=1, keepdim=True)  # N x 1 x H' x W'
        binary_masks = [pred_mask == i for i in range(1 + n_ways)]
        skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0]
        pred_mask = torch.stack(binary_masks, dim=1).float()  # N x (1 + Wa) x 1 x H' x W'
        qry_prototypes = torch.sum(qry_fts.unsqueeze(1) * pred_mask, dim=(0, 3, 4))
        qry_prototypes = qry_prototypes / (pred_mask.sum((0, 3, 4)) + 1e-5)  # (1 + Wa) x C
        print('qry_prototypes shape:', qry_prototypes.shape)
        # Compute the support loss
        loss = 0
        for way in range(n_ways):
            if way in skip_ways:
                continue
            # Get the query prototypes
            prototypes = [qry_prototypes[[0]], qry_prototypes[[way + 1]]]
            for shot in range(n_shots):
                img_fts = supp_fts[way, [shot]]
                supp_dist = [self.calDist(img_fts, prototype) for prototype in prototypes]
                supp_pred = torch.stack(supp_dist, dim=1)
                supp_pred = F.interpolate(supp_pred, size=fore_mask.shape[-2:],
                                          mode='bilinear')
                # Construct the support Ground-Truth segmentation
                supp_label = torch.full_like(fore_mask[way, shot], 255,
                                             device=img_fts.device).long()
                supp_label[fore_mask[way, shot] == 1] = 1
                supp_label[back_mask[way, shot] == 1] = 0
                # Compute Loss
                loss = loss + F.cross_entropy(
                    supp_pred, supp_label[None, ...], ignore_index=255) / n_shots / n_ways
        return loss

In [312]:
# create dummy model for testing
model = FewShotSeg()
model.eval()


FewShotSeg(
  (encoder): Sequential(
    (backbone): Encoder(
      (encoder): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): MaxPool2d(kernel_size=2, stride=2, padding=0,

In [313]:
# create dummy input
supp_imgs = [[torch.rand(2, 3, 125, 125) for _ in range(5)] for _ in range(5)]
fore_mask = [[torch.rand(2, 125, 125) for _ in range(5)] for _ in range(5)]
back_mask = [[torch.rand(2, 125, 125) for _ in range(5)] for _ in range(5)] # 5 ways, 5 shots, 2 batch size
qry_imgs = [torch.rand(2, 3, 125, 125) for _ in range(5)] # 5 images, 2 batch size

In [314]:
# run the model
output, align_loss = model(supp_imgs, fore_mask, back_mask, qry_imgs)

n_ways: 5 n_shots: 5 n_queries: 5 batch_size: 2 img_size: torch.Size([125, 125])
supp_imgs shape: torch.Size([2, 3, 125, 125])
after concat torch.Size([60, 3, 125, 125]) which is (Wa*Sh + N) x B x 3 x H x W, Wa = 5 Sh = 5 N = 5 B = 2
after go through encoder torch.Size([60, 512, 4, 4])
fts_size: torch.Size([4, 4])
supp_fts shape: torch.Size([5, 5, 2, 512, 4, 4])
qry_fts shape: torch.Size([5, 2, 512, 4, 4])
fore_mask shape: torch.Size([5, 5, 2, 125, 125])
back_mask shape: torch.Size([5, 5, 2, 125, 125])
supp_fg_fts shape: 5 5 torch.Size([1, 512])
fg_prototypes shape: 5 torch.Size([1, 512])
bg_prototype shape: torch.Size([1, 512])
prototypes shape: 6 torch.Size([1, 512])
CALDIST fts shape: torch.Size([5, 512, 4, 4]) prototype shape: torch.Size([1, 512])
CALDIST fts shape: torch.Size([5, 512, 4, 4]) prototype shape: torch.Size([1, 512])
CALDIST fts shape: torch.Size([5, 512, 4, 4]) prototype shape: torch.Size([1, 512])
CALDIST fts shape: torch.Size([5, 512, 4, 4]) prototype shape: torch.S

In [318]:
import numpy as np


output.shape, align_loss, np.array(output.argmax(dim=1).cpu().shape)

(torch.Size([10, 6, 125, 125]),
 tensor(nan, grad_fn=<DivBackward0>),
 array([ 10, 125, 125]))

In [307]:
# rebuild the whole thing to work with itme series: original input is 5x5x2x3x125x125 -> 5x5x2x6x100
# 5x5x2x3x125x125 -> 5x5x2x6x100, mask would be 5x5x2x125x125 -> 5x5x2x100

# create dummy encoder for time series (6x100)
# class time_Encoder(nn.Module):
#     # input could be Batchx6x100, just use very simple network with conv1d:
#     def __init__(self, in_channels, pretrained_path):
#         super(time_Encoder, self).__init__()
#         self.pretrained_path = pretrained_path
#         self.in_channels = in_channels
#         self.encoder = nn.Sequential(
#             nn.Linear(6, 512),
#         )
    
#     def forward(self, x):
#         x = x.transpose(1, 2)
#         x = self.encoder(x)
#         x = x.transpose(1, 2)
#         x = nn.Linear(100, 2)(x)
#         return  x
    
class time_Encoder(nn.Module):
    # input could be Batchx6x100, just use very simple network with conv1d:
    def __init__(self, in_channels, pretrained_path):
        super(time_Encoder, self).__init__()
        self.pretrained_path = pretrained_path
        self.in_channels = in_channels
        self.encoder = nn.Sequential(
            nn.Linear(6, 512),
        )
    
    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.encoder(x).transpose(1, 2)
        return  x
# create dummy model for testing
class time_FewShotSeg(nn.Module):
    """
    Fewshot Segmentation model

    Args:
        in_channels:
            number of input channels
        pretrained_path:
            path of the model for initialization
        cfg:
            model configurations
    """
    def __init__(self, in_channels=6, pretrained_path=None, cfg=None):
        super().__init__()
        self.pretrained_path = pretrained_path
        self.config = cfg or {'align': True}
        

        # Encoder
        self.encoder = nn.Sequential(OrderedDict([
            ('backbone', time_Encoder(in_channels, self.pretrained_path)),]))

    def forward(self, supp_imgs, fore_mask, back_mask, qry_imgs):
        """
        Args:
            supp_imgs: support images
                way x shot x [B x 6 x 100], list of lists of tensors
            fore_mask: foreground masks for support images
                way x shot x [B x 100], list of lists of tensors
            back_mask: background masks for support images
                way x shot x [B x 100], list of lists of tensors
            qry_imgs: query images
                N x [B x 6 x 100], list of tensors
        """
        n_ways = len(supp_imgs)
        n_shots = len(supp_imgs[0])
        n_queries = len(qry_imgs)
        batch_size = supp_imgs[0][0].shape[0]
        img_size = supp_imgs[0][0].shape[-1:]
        print('n_ways:', n_ways, 'n_shots:', n_shots, 'n_queries:', n_queries, 'batch_size:', batch_size, 'img_size:', img_size)
        ###### Extract features ######
        print('supp_imgs shape:', supp_imgs[0][0].shape)
        imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs]
                                + [torch.cat(qry_imgs, dim=0),], dim=0)
        print('after concat', imgs_concat.shape, 'which is (Wa*Sh + N) x B x 6 x 100, Wa =', n_ways, 'Sh =', n_shots, 'N =', n_queries, 'B =', batch_size)
        print(imgs_concat.shape)
        img_fts = self.encoder(imgs_concat)
        print('after go through encoder', img_fts.shape)
        fts_size = img_fts.shape[-1:] # this is the length, aka 100 (Time series length)
        print('fts_size:', fts_size)
        
        supp_fts = img_fts[:n_ways * n_shots * batch_size].view(
            n_ways, n_shots, batch_size, -1, *fts_size)
        print('supp_fts shape:', supp_fts.shape)
        qry_fts = img_fts[n_ways * n_shots * batch_size:].view(
            n_queries, batch_size, -1, *fts_size)
        print('qry_fts shape:', qry_fts.shape)
        fore_mask = torch.stack([torch.stack(way, dim=0)
                                 for way in fore_mask], dim=0)
        print('fore_mask shape:', fore_mask.shape)
        back_mask = torch.stack([torch.stack(way, dim=0)
                                 for way in back_mask], dim=0)
        print('back_mask shape:', back_mask.shape)
        

        ###### Compute loss ######
        align_loss = 0
        outputs = []
        for epi in range(batch_size):
            ###### Extract prototype ######
            supp_fg_fts = [[self.getFeatures(supp_fts[way, shot, [epi]],
                                             fore_mask[way, shot, [epi]])
                            for shot in range(n_shots)] for way in range(n_ways)] # 1 x C (channel embedding)
            supp_bg_fts = [[self.getFeatures(supp_fts[way, shot, [epi]],
                                             back_mask[way, shot, [epi]])
                            for shot in range(n_shots)] for way in range(n_ways)] # 1 x C (channel embedding)
            ###### Obtain the prototypes######
            fg_prototypes, bg_prototype = self.getPrototype(supp_fg_fts, supp_bg_fts) 
            print('fg_prototypes shape:', len(fg_prototypes), fg_prototypes[0].shape)
            print('bg_prototype shape:', bg_prototype.shape)

            ###### Compute the distance ######
            prototypes = [bg_prototype,] + fg_prototypes
            print('prototypes shape:', len(prototypes), prototypes[0].shape)
            dist = [self.calDist(qry_fts[:, epi], prototype) for prototype in prototypes]
            print('dist shape:', len(dist), dist[0].shape)
            pred = torch.stack(dist, dim=1)
            print('pred shape:', pred.shape)
            outputs.append(F.interpolate(pred, size=img_size, mode='linear'))
            print('upsampling', F.interpolate(pred, size=img_size, mode='linear').shape)
            print('outputs shape:', len(outputs), outputs[0].shape)

            ###### Prototype alignment loss ######
            if self.config['align']:
                align_loss_epi = self.alignLoss(qry_fts[:, epi], pred, supp_fts[:, :, epi],
                                                fore_mask[:, :, epi], back_mask[:, :, epi])
                align_loss += align_loss_epi
        print('outputs shape:', len(outputs), outputs[0].shape)
        output = torch.stack(outputs, dim=1)  # N x B x (1 + Wa) x H x W
        output = output.view(-1, *output.shape[2:])
        print(output.shape)
        return output, align_loss / batch_size
        
    def calDist(self, fts, prototype, scaler=20):
        """
        Calculate the distance between features and prototypes

        Args:
            fts: input features
                expect shape: N x C x H
            prototype: prototype of one semantic class
                expect shape: 1 x C
        """
        # print('CALDIST fts shape:', fts.shape, 'prototype shape:', prototype.shape)
        dist = F.cosine_similarity(fts, prototype[..., None], dim=1) * scaler
        return dist
        
    def getFeatures(self, fts, mask):
        """
        Extract foreground and background features via masked average pooling

        Args:
            fts: input features, expect shape: 1 x C x H'
            mask: binary mask, expect shape: 1 x H
            
        originally: 
        fts = F.interpolate(fts, size=mask.shape[-2:], mode='bilinear')
        # IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
        masked_fts = torch.sum(fts * mask[None, ...], dim=(2, 3)) \
            / (mask[None, ...].sum(dim=(2)) + 1e-5)
        return masked_fts
        """
        # interpolate on three dim data:
        fts = F.interpolate(fts, size=mask.shape[-1:], mode='linear')
        # print(mask.shape[-1:])
        # fts = fts
        # print(fts.shape)
        # print('fts shape:', fts.shape)x
        masked_fts = torch.sum(fts * mask[None, ...], dim=(2)) \
            / (mask[None, ...].sum(dim=(2)) + 1e-5)
        return masked_fts
        
    def getPrototype(self, fg_fts, bg_fts):
        """
        Average the features to obtain the prototype

        Args:
            fg_fts: lists of list of foreground features for each way/shot
                expect shape: Wa x Sh x [1 x C]
            bg_fts: lists of list of background features for each way/shot
                expect shape: Wa x Sh x [1 x C]
        """
        n_ways, n_shots = len(fg_fts), len(fg_fts[0])
        fg_prototypes = [sum(way) / n_shots for way in fg_fts]
        bg_prototype = sum([sum(way) / n_shots for way in bg_fts]) / n_ways
        return fg_prototypes, bg_prototype
    
    def alignLoss(self, qry_fts, pred, supp_fts, fore_mask, back_mask):
        """
        Compute the loss for the prototype alignment branch

        Args:
            qry_fts: embedding features for query images
                expect shape: N x C x H'
            pred: predicted segmentation score
                expect shape: N x (1 + Wa) x H
            supp_fts: embedding features for support images
                expect shape: Wa x Sh x C x H'
            fore_mask: foreground masks for support images
                expect shape: way x shot x H
            back_mask: background masks for support images
                expect shape: way x shot x H
        """
        n_ways, n_shots = len(fore_mask), len(fore_mask[0])

        # Mask and get query prototype
        pred_mask = pred.argmax(dim=1, keepdim=True)
        binary_masks = [pred_mask == i for i in range(1 + n_ways)]
        skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0]
        pred_mask = torch.stack(binary_masks, dim=1).float()
        # print(pred_mask.shape, qry_fts.shape, (qry_fts.unsqueeze(1) * pred_mask).shape)
        qry_prototypes = torch.sum(qry_fts.unsqueeze(1) * pred_mask, dim=(0, 3))
        qry_prototypes = qry_prototypes / (pred_mask.sum((0, 3)) + 1e-5)
        # print('qry_prototypes shape:', qry_prototypes.shape)
        # Compute the support loss
        loss = 0
        for way in range(n_ways):
            if way in skip_ways:
                continue
            # Get the query prototypes
            prototypes = [qry_prototypes[[0]], qry_prototypes[[way + 1]]]
            for shot in range(n_shots):
                img_fts = supp_fts[way, [shot]]
                supp_dist = [self.calDist(img_fts, prototype) for prototype in prototypes]
                supp_pred = torch.stack(supp_dist, dim=1)
                supp_pred = F.interpolate(supp_pred, size=fore_mask.shape[-1],
                                          mode='linear')
                # Construct the support Ground-Truth segmentation
                supp_label = torch.full_like(fore_mask[way, shot], 255,
                                             device=img_fts.device).long()
                supp_label[fore_mask[way, shot] == 1] = 1
                supp_label[back_mask[way, shot] == 1] = 0
                # Compute Loss
                loss = loss + F.cross_entropy(
                    supp_pred, supp_label[None, ...], ignore_index=255) / n_shots / n_ways
        return loss
                

In [308]:
# create dummy model for testing
model = time_FewShotSeg()
model.eval()

s = time_Encoder(6, None)
s(torch.rand(2, 6, 100)).shape

torch.Size([2, 512, 100])

In [309]:
# create dummy input
supp_imgs = [[torch.rand(1, 6, 100) for _ in range(5)] for _ in range(5)]
fore_mask = [[torch.rand(1, 100) for _ in range(5)] for _ in range(5)]
back_mask = [[torch.rand(1, 100) for _ in range(5)] for _ in range(5)]
qry_imgs = [torch.rand(1, 6, 100) for _ in range(5)]


In [310]:
# run
output, align_loss = model(supp_imgs, fore_mask, back_mask, qry_imgs)

output.shape, align_loss

n_ways: 5 n_shots: 5 n_queries: 5 batch_size: 1 img_size: torch.Size([100])
supp_imgs shape: torch.Size([1, 6, 100])
after concat torch.Size([30, 6, 100]) which is (Wa*Sh + N) x B x 6 x 100, Wa = 5 Sh = 5 N = 5 B = 1
torch.Size([30, 6, 100])
after go through encoder torch.Size([30, 512, 100])
fts_size: torch.Size([100])
supp_fts shape: torch.Size([5, 5, 1, 512, 100])
qry_fts shape: torch.Size([5, 1, 512, 100])
fore_mask shape: torch.Size([5, 5, 1, 100])
back_mask shape: torch.Size([5, 5, 1, 100])
fg_prototypes shape: 5 torch.Size([1, 512])
bg_prototype shape: torch.Size([1, 512])
prototypes shape: 6 torch.Size([1, 512])
dist shape: 6 torch.Size([5, 100])
pred shape: torch.Size([5, 6, 100])
upsampling torch.Size([5, 6, 100])
outputs shape: 1 torch.Size([5, 6, 100])
outputs shape: 1 torch.Size([5, 6, 100])
torch.Size([5, 6, 100])


(torch.Size([5, 6, 100]), tensor(nan, grad_fn=<DivBackward0>))

In [291]:
"""
n_ways: 5 n_shots: 5 n_queries: 5 batch_size: 2 img_size: torch.Size([125, 125])
supp_imgs shape: torch.Size([2, 3, 125, 125])
after concat torch.Size([60, 3, 125, 125]) which is (Wa*Sh + N) x B x 3 x H x W, Wa = 5 Sh = 5 N = 5 B = 2
after go through encoder torch.Size([60, 512, 4, 4])
fts_size: torch.Size([4, 4])
supp_fts shape: torch.Size([5, 5, 2, 512, 4, 4])
qry_fts shape: torch.Size([5, 2, 512, 4, 4])
fore_mask shape: torch.Size([5, 5, 2, 125, 125])
back_mask shape: torch.Size([5, 5, 2, 125, 125])
supp_fg_fts shape: 5 5 torch.Size([1, 512])
fg_prototypes shape: 5 torch.Size([1, 512])
bg_prototype shape: torch.Size([1, 512])
prototypes shape: 6 torch.Size([1, 512])
dist shape: 6 torch.Size([5, 4, 4])
pred shape: torch.Size([5, 6, 4, 4])
UPSAMPLING: outputs shape: 1 torch.Size([5, 6, 125, 125])
torch.Size([10, 6, 125, 125])
.....
n_ways: 5 n_shots: 5 n_queries: 5 batch_size: 2 img_size: torch.Size([100])
supp_imgs shape: torch.Size([2, 6, 100])
after concat torch.Size([60, 6, 100]) which is (Wa*Sh + N) x B x 6 x 100, Wa = 5 Sh = 5 N = 5 B = 2
after go through encoder torch.Size([60, 512, 100])
fts_size: torch.Size([100])
supp_fts shape: torch.Size([5, 5, 2, 512, 100])
qry_fts shape: torch.Size([5, 2, 512, 100])
fore_mask shape: torch.Size([5, 5, 2, 100])
back_mask shape: torch.Size([5, 5, 2, 100])
supp_fg_fts shape: 5 5 torch.Size([1, 512])
fg_prototypes shape: 5 torch.Size([1, 512])
bg_prototype shape: torch.Size([1, 512])
prototypes shape: 6 torch.Size([1, 512])
dist shape: 6 torch.Size([5, 100])
pred shape: torch.Size([5, 6, 100])
outputs shape: 1 torch.Size([5, 6, 100])
"""

'\nn_ways: 5 n_shots: 5 n_queries: 5 batch_size: 2 img_size: torch.Size([125, 125])\nsupp_imgs shape: torch.Size([2, 3, 125, 125])\nafter concat torch.Size([60, 3, 125, 125]) which is (Wa*Sh + N) x B x 3 x H x W, Wa = 5 Sh = 5 N = 5 B = 2\nafter go through encoder torch.Size([60, 512, 4, 4])\nfts_size: torch.Size([4, 4])\nsupp_fts shape: torch.Size([5, 5, 2, 512, 4, 4])\nqry_fts shape: torch.Size([5, 2, 512, 4, 4])\nfore_mask shape: torch.Size([5, 5, 2, 125, 125])\nback_mask shape: torch.Size([5, 5, 2, 125, 125])\nsupp_fg_fts shape: 5 5 torch.Size([1, 512])\nfg_prototypes shape: 5 torch.Size([1, 512])\nbg_prototype shape: torch.Size([1, 512])\nprototypes shape: 6 torch.Size([1, 512])\ndist shape: 6 torch.Size([5, 4, 4])\npred shape: torch.Size([5, 6, 4, 4])\nUPSAMPLING: outputs shape: 1 torch.Size([5, 6, 125, 125])\n\n.....\nn_ways: 5 n_shots: 5 n_queries: 5 batch_size: 2 img_size: torch.Size([100])\nsupp_imgs shape: torch.Size([2, 6, 100])\nafter concat torch.Size([60, 6, 100]) which 