In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import torch
import torch.nn as nn

In [2]:
class UNetDecoderBlock(nn.Module):
    '''
    U-Net decoder stage: upsample, concatenate the skip connection, refine.

    Args:
        in_channels: Channels of the incoming feature map fed to the
            transposed convolution.
        res_channels: Channels of the skip-connection feature map.
        out_channels: Channels produced by the transposed convolution and by
            the block as a whole.
    '''
    def __init__(self, in_channels, res_channels, out_channels):
        super(UNetDecoderBlock, self).__init__()
        # kernel_size=2 with stride=2 doubles the spatial resolution.
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # The first conv consumes the upsampled map concatenated with the skip map.
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels + res_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, skip):
        '''
        Upsample `x`, fuse it with `skip` and return the refined feature map.

        `skip` must have the same batch and spatial size as the upsampled `x`
        (twice the spatial size of the input `x`), otherwise `torch.cat` raises.

        Returns:
            Tensor with `out_channels` channels at the resolution of `skip`.
        '''
        x = self.upconv(x)
        x = torch.cat([x, skip], dim=1)
        # The original fell through with a bare `return`, yielding None.
        return self.conv(x)

In [3]:
class FPN(nn.Module):
    '''
    Projects several feature maps, resizes them to one resolution and
    concatenates them along the channel dimension.

    Each input level gets its own branch:
    Conv(in_ch -> 2*out_ch) -> ReLU -> BatchNorm2d -> Conv(2*out_ch -> out_ch).

    Args:
        input_channels: Channel count of each incoming feature map.
        output_channels: Channel count each branch should produce.
            Paired with `input_channels` via `zip`, so the shorter list
            decides how many branches are built.
        mask_size: (height, width) every branch output is resized to.
    '''
    def __init__(self, input_channels: list, output_channels: list, mask_size = (112, 112)):
        super(FPN, self).__init__()
        self.mask_size = mask_size
        self.convs = nn.ModuleList(
            [nn.Sequential(
                # Original author's note for this layer: 384, 64
                # (batch size = 2*11 = 22 frames).
                nn.Conv2d(in_ch, out_ch * 2, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_ch * 2),
                # Original author's note for this layer: 64, 32.
                nn.Conv2d(out_ch * 2, out_ch, kernel_size=3, padding=1)
            ) for in_ch, out_ch in zip(input_channels, output_channels)]
        )

    def forward(self, xs: list):
        '''
        Run each feature map through its branch and fuse the results.

        Args:
            xs: Feature maps, one per branch, in the same order as the
                `input_channels` given to the constructor. Paired with the
                branches via `zip`, so surplus entries are ignored.

        Returns:
            Tensor of spatial size `mask_size` whose channels are the branch
            outputs concatenated in order.
        '''
        # Scale every projected feature map to the same resolution.
        resized = [
            F.interpolate(branch(feature), size=self.mask_size, mode='bilinear', align_corners=False)
            for branch, feature in zip(self.convs, xs)
        ]
        # Concatenate along the channel dimension.
        return torch.cat(resized, dim=1)

In [None]:
class _25DCnnRnnSegAux(nn.Module):
    '''
    2.5 D model : CNN model + RNN with Segmentation head
    '''
    def __init__(self,
                 encoder,
                 num_classes = 2,
                 num_classes_aux = 11,
                 n_channels = 3,
                 head_3d = "",
                 n_frames = 1,
                 dropout_rate=0,
                 segmentation_aux=True):
        super().__init__()
        #set up
        self.encoder = encoder
        self.num_features= encoder.num_features
        self.num_classes = num_classes
        self.num_classes_aux = num_classes_aux
        self.n_channels = n_channels
        self.head_3d = head_3d
        self.n_frames = n_frames
        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate > 0 else nn.Identity()
        # layers.
        if head_3d == 'lstm':
            self.lstm = nn.LSTM(input_size = self.num_features,
                                hidden_size = self.num_features // 4,
                                batch_first = True,
                                bidirectional=True)


        self.last_layer = nn.Linear(self.num_features, self.num_classes)

        if self.num_classes_aux > 0:
            if segmentation_aux:
                self.aux_layer = SegmentationHead(self.encoder)
            else:
                self.aux_layer = nn.Linear(self.num_features, self.num_classes_aux)


    def extract_features(self, x):
        feature_output = self.encoder(x)
        intermediate_feature = None
        if self.segmentation_aux and 'coat' in self.encoder.name:
            layer0_encoder_feature = self.encoder.stem(x)
            layer1_encoder_feature = self.encoder.stages[0](layer0_encoder_feature)
            layer2_encoder_feature = self.encoder.stages[1](layer1_encoder_feature)
            layer3_encoder_feature = self.encoder.stages[2](layer2_encoder_feature)
            layer4_encoder_feature = self.encoder.stages[3](layer3_encoder_feature)

            intermediate_features = [
                layer4_encoder_feature,
                layer3_encoder_feature,
                layer2_encoder_feature,
                layer1_encoder_feature,
                layer0_encoder_feature
                ]

        fts = self.dropout(fts)
        return fts, intermediate_features

    def forward_head3d(self, x):
        if self.head_3d == 'lstm':
            x, _ = self.lstm(x)
            mean  = x.mean(1)
            max_ = x.amax(1)
            x = torch.cat([mean, max_], dim=1)
        return x


    def forward(self, x):
        if self.head_3d:
            bs, n_frames, c, h , w  = x.size()
            x = x.view(bs*n_frames, c, h, w)

        fts, intermediate_features = self.extract_features(x)

        print("fts size: ", fts.shape)

        if self.head_3d != "":
            fts = fts.view(bs, n_frames, -1)
            fts = self.forward_head3d(fts)

        output = self.last_layer(fts)

        if self.num_classes_aux:
            segmentation_output = self.segmentation_head(intermediate_features)
        else:
            segmentation_output = None
        if self.num_classes_aux > 0:
            aux_output = self.aux_layer(fts)
        else:
            aux_output = torch.zeros((fts.size(0)))
        # processing volumn mask into slice -> input to the model

        return output, aux_output


In [None]:
class CustomLoss(nn.Module):
    '''
    Binary cross-entropy on the classification logits plus a weighted Dice
    loss on the mask logits.

    Depends on two names that must already exist in the notebook namespace:
    `smp` (provides `losses.DiceLoss`; presumably segmentation_models_pytorch,
    verify the import) and `CFG`, whose `segw` attribute weights the Dice term.
    '''
    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = smp.losses.DiceLoss(smp.losses.MULTILABEL_MODE, from_logits=True)

    def forward(self, outputs, targets, masks_outputs, masks_targets):
        '''
        Args:
            outputs: Classification logits.
            targets: Classification targets.
            masks_outputs: Mask logits; cast to float before the Dice loss.
            masks_targets: Mask targets; cast to float and its first two
                dimensions are merged before the Dice loss.

        Returns:
            The BCE loss plus `CFG.segw` times the Dice loss.
        '''
        classification_loss = self.bce(outputs, targets)

        # Merge the two leading dimensions of the target masks.
        flat_mask_targets = masks_targets.float().flatten(0, 1)
        segmentation_loss = self.dice(masks_outputs.float(), flat_mask_targets)

        return classification_loss + (segmentation_loss * CFG.segw)