In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math
from functools import partial
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url

from cell_segmentation.utils.post_proc_stardist import StarDistPostProcessor
from models.segmentation.cell_segmentation.cellvit import CellViT
from models.segmentation.cell_segmentation.cpp_net_stardist_rn50 import up, outconv, resnet50

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from models.segmentation.cell_segmentation.cellvit_cpp_net import CellViTCPP

In [4]:
# Smoke test: build the imported model once. The call is left as a bare
# expression so the notebook displays the resulting object.
CellViTCPP(
    num_nuclei_classes=6,
    num_tissue_classes=19,
    embed_dim=384,
    input_channels=3,
    depth=12,
    num_heads=6,
    extract_layers=[3, 6, 9, 12],
)

TypeError: __init__() got an unexpected keyword argument 'kernsel_size'

In [None]:

class CPPNet(nn.Module):
    """Ray-regression network on a ``resnet50`` backbone with confidence-weighted ray refinement.

    The decoder produces an object-probability map, ``nrays`` ray distances per
    pixel and a per-ray confidence map. The ray and confidence maps are then
    passed through ``self.sampling_feature`` once per entry of
    ``erosion_factor_list`` (with offsets ``(ray - 1) * erosion_factor``), and
    the resulting candidates are blended with softmax weights computed by
    ``conv_1_confidence``. An optional second decoder yields a segmentation map
    with ``n_seg_cls`` channels.

    NOTE(review): ``SamplingFeatures`` is used below but is not imported in the
    visible part of this notebook — make sure it is in scope before
    instantiating the class.

    Parameters
    ----------
    nrays : int
        Number of ray channels predicted per pixel.
    n_seg_cls : int
        Number of output channels of the segmentation head.
    erosion_factor_list : sequence of float
        Factors applied to ``(ray - 1)`` to build the sampling offsets; one
        refinement candidate is produced per factor.
    return_conf : bool
        If True, ``forward`` returns the raw and the softmax-normalised
        confidence tensors; otherwise ``None`` is returned in their place.
    with_seg : bool
        If True, build and evaluate the segmentation branch; otherwise the
        segmentation output is ``None``.
    """

    def __init__(
        self,
        nrays=32,
        n_seg_cls=6,
        erosion_factor_list=(0.2, 0.4, 0.6, 0.8, 1.0),
        return_conf=False,
        with_seg=True,
    ):
        super().__init__()
        # Refinement. Stored as a flat list of floats (a stray trailing comma
        # previously wrapped the list in a one-element tuple).
        self.erosion_factor_list = list(erosion_factor_list)
        self.sampling_feature = SamplingFeatures(nrays)
        self.nrays = nrays
        self.n_seg_cls = n_seg_cls
        self.return_conf = return_conf
        self.with_seg = with_seg

        self.backbone = resnet50(True)
        self.up1 = up(2048 + 1024, 1024, bilinear=True)
        self.up2 = up(1024 + 512, 512, bilinear=True)
        self.up3 = up(512 + 256, 256, bilinear=True)
        self.up4 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        self.features = nn.Conv2d(256, 256, 3, padding=1)
        self.out_prob = outconv(256, 1)
        self.out_ray = outconv(256, nrays)
        self.conv_0_confidence = outconv(256, nrays)
        # One channel per refinement candidate: the unrefined prediction plus
        # one sampled candidate per erosion factor.
        n_candidates = 1 + len(self.erosion_factor_list)
        self.conv_1_confidence = outconv(n_candidates, n_candidates)

        # init
        nn.init.constant_(self.conv_1_confidence.conv.bias, 1.0)

        # Segmentation branch (second decoder on the same backbone features).
        if self.with_seg:
            self.up1_seg = up(2048 + 1024, 1024, bilinear=True)
            self.up2_seg = up(1024 + 512, 512, bilinear=True)
            self.up3_seg = up(512 + 256, 256, bilinear=True)
            self.up4_seg = nn.Upsample(
                scale_factor=2, mode="bilinear", align_corners=True
            )
            self.out_seg = outconv(256, n_seg_cls)
            if self.n_seg_cls == 1:
                self.final_activation_seg = nn.Sigmoid()
            else:
                self.final_activation_seg = nn.Softmax(dim=1)

        self.final_activation_ray = nn.ReLU()
        self.final_activation_prob = nn.Sigmoid()

    def forward(self, img, gt_dist=None):
        """Run the network.

        Parameters
        ----------
        img : torch.Tensor
            Input batch handed to ``self.backbone``.
        gt_dist : torch.Tensor, optional
            If given, it replaces the predicted ray map as the basis for the
            sampling offsets and as the first refinement candidate.

        Returns
        -------
        tuple of four lists
            ``[out_ray, ray_refined]``, ``[out_prob]``, ``[out_seg]`` and
            ``[out_conf]``. ``out_seg`` is ``None`` when ``with_seg`` is False;
            ``out_conf`` is ``[out_confidence, confidence_refined]`` when
            ``return_conf`` is True and ``None`` otherwise.
        """
        x1, x2, x3, x4 = self.backbone(img)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.up4(x)
        x = self.features(x)

        out_prob = self.final_activation_prob(self.out_prob(x))
        # ReLU is idempotent, so it is applied once here rather than again
        # after the refinement step.
        out_ray = self.final_activation_ray(self.out_ray(x))
        out_confidence = self.conv_0_confidence(x)

        out_ray_for_sampling = out_ray if gt_dist is None else gt_dist

        ray_refined = [out_ray_for_sampling]
        confidence_refined = [out_confidence]
        for erosion_factor in self.erosion_factor_list:
            base_dist = (out_ray_for_sampling - 1.0) * erosion_factor
            # Only the first returned element (the sampled map) is needed;
            # indexing keeps this independent of how many extra values the
            # sampler returns.
            ray_sampled = self.sampling_feature(out_ray_for_sampling, base_dist, 1)[0]
            conf_sampled = self.sampling_feature(out_confidence, base_dist, 1)[0]
            ray_refined.append(ray_sampled + base_dist)
            confidence_refined.append(conf_sampled)
        ray_refined = torch.stack(ray_refined, dim=1)
        b, k, c, h, w = ray_refined.shape

        # Fold the ray axis into the batch so conv_1_confidence mixes the k
        # candidates of each ray independently, then restore (b, k, c, h, w).
        confidence_refined = torch.stack(confidence_refined, dim=1)
        confidence_refined = (
            confidence_refined.permute([0, 2, 1, 3, 4]).contiguous().view(b * c, k, h, w)
        )
        confidence_refined = self.conv_1_confidence(confidence_refined)
        confidence_refined = confidence_refined.view(b, c, k, h, w).permute([0, 2, 1, 3, 4])
        confidence_refined = F.softmax(confidence_refined, dim=1)

        out_conf = [out_confidence, confidence_refined] if self.return_conf else None

        # Confidence-weighted sum over the candidate axis.
        ray_refined = (ray_refined * confidence_refined).sum(dim=1)
        ray_refined = self.final_activation_ray(ray_refined)

        if self.with_seg:
            x_seg = self.up1_seg(x4, x3)
            x_seg = self.up2_seg(x_seg, x2)
            x_seg = self.up3_seg(x_seg, x1)
            # NOTE(review): up4_seg was constructed but never applied before;
            # it is applied here to mirror the main branch (up4) — confirm the
            # segmentation map is meant to share the ray/prob resolution.
            x_seg = self.up4_seg(x_seg)
            out_seg = self.out_seg(x_seg)
            # Multi-class logits stay unnormalised during training; the
            # activation is applied for the single-class case or at eval time.
            if self.n_seg_cls == 1 or not self.training:
                out_seg = self.final_activation_seg(out_seg)
        else:
            out_seg = None

        return [out_ray, ray_refined], [out_prob], [out_seg], [out_conf]

    def init_weight(self):
        """Re-initialise all Conv2d / BatchNorm2d / GroupNorm submodules.

        Conv weights use Kaiming-normal (``fan_out``) with zero bias, norm
        layers get weight 1 and bias 0, and the bias of
        ``conv_1_confidence.conv`` is finally set to 1. The loop walks every
        submodule returned by ``self.modules()``, which includes the backbone.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                # Conv2d always has a `bias` attribute; it is None for
                # bias=False layers and must then be skipped.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Non-affine norm layers carry no weight/bias parameters.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
        nn.init.constant_(self.conv_1_confidence.conv.bias, 1.0)

In [None]:
class CellViTCPP(CellViT):
    """CellViT variant with a ray-regression head and confidence-weighted ray refinement.

    A ViT encoder feeds three upsampling branches (stardist rays, distance map,
    nuclei type map), each followed by a 1x1 convolution head. The stardist
    output is additionally refined in :meth:`cppnet_refine` by blending
    candidates obtained from ``self.sampling_features`` for every entry of
    ``erosion_factors``.

    NOTE(review): ``ViTCellViT``, ``Conv2DBlock``, ``Deconv2DBlock`` and
    ``SamplingFeatures`` are used below but are not imported in the visible
    part of this notebook — make sure they are in scope.

    Parameters
    ----------
    num_nuclei_classes : int
        Number of output channels of the nuclei type head.
    num_tissue_classes : int
        Number of classes passed to the encoder and the classifier head.
    embed_dim : int
        Embedding dimension of the ViT encoder; also selects the widths of the
        skip-connection and bottleneck blocks.
    input_channels : int
        Stored on the instance. Note that ``decoder0`` is built for 3 input
        channels regardless of this value.
    depth : int
        Depth passed to the ViT encoder.
    num_heads : int
        Number of attention heads passed to the ViT encoder.
    extract_layers : List
        Exactly four encoder layers used for the skip connections.
    nrays : int, optional
        Number of ray channels of the stardist head. Default 32.
    mlp_ratio : float, optional
        Passed to the ViT encoder. Default 4.
    qkv_bias : bool, optional
        Passed to the ViT encoder. Default True.
    drop_rate : float, optional
        Dropout used in the encoder and the decoder blocks. Default 0.
    attn_drop_rate : float, optional
        Passed to the ViT encoder. Default 0.
    drop_path_rate : float, optional
        Passed to the ViT encoder. Default 0.
    erosion_factors : Tuple[float, ...], optional
        Factors applied to ``(ray - 1)`` to build sampling offsets during
        refinement. Default (0.2, 0.4, 0.6, 0.8, 1.0).
    """

    def __init__(
        self,
        num_nuclei_classes: int,
        num_tissue_classes: int,
        embed_dim: int,
        input_channels: int,
        depth: int,
        num_heads: int,
        extract_layers: List,
        nrays: int = 32,
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        drop_rate: float = 0,
        attn_drop_rate: float = 0,
        drop_path_rate: float = 0,
        # cpp-net specific
        erosion_factors: Tuple[float, ...] = (0.2, 0.4, 0.6, 0.8, 1.0),
    ):
        # super(CellViT, self) starts the lookup *after* CellViT in the MRO, so
        # CellViT.__init__ itself is skipped and every attribute needed below
        # has to be set in this constructor.
        super(CellViT, self).__init__()
        assert len(extract_layers) == 4, "Please provide 4 layers for skip connections"

        self.patch_size = 16
        self.num_tissue_classes = num_tissue_classes
        self.num_nuclei_classes = num_nuclei_classes
        self.embed_dim = embed_dim
        self.input_channels = input_channels
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.extract_layers = extract_layers
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        self.nrays = nrays
        self.prompt_embed_dim = 256

        # Widths of the skip-connection / bottleneck blocks. They were read
        # below without ever being assigned (CellViT.__init__ is bypassed).
        # NOTE(review): values are assumed to mirror the parent CellViT
        # configuration — confirm against CellViT.__init__.
        if self.embed_dim < 512:
            self.skip_dim_11 = 256
            self.skip_dim_12 = 128
            self.bottleneck_dim = 312
        else:
            self.skip_dim_11 = 512
            self.skip_dim_12 = 256
            self.bottleneck_dim = 512

        self.encoder = ViTCellViT(
            patch_size=self.patch_size,
            num_classes=self.num_tissue_classes,
            embed_dim=self.embed_dim,
            depth=self.depth,
            num_heads=self.num_heads,
            mlp_ratio=self.mlp_ratio,
            qkv_bias=self.qkv_bias,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            extract_layers=self.extract_layers,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
        )

        self.decoder0 = nn.Sequential(
            Conv2DBlock(3, 32, 3, dropout=self.drop_rate),
            Conv2DBlock(32, 64, 3, dropout=self.drop_rate),
        )  # skip connection after positional encoding, shape should be H, W, 64
        self.decoder1 = nn.Sequential(
            Deconv2DBlock(self.embed_dim, self.skip_dim_11, dropout=self.drop_rate),
            Deconv2DBlock(self.skip_dim_11, self.skip_dim_12, dropout=self.drop_rate),
            Deconv2DBlock(self.skip_dim_12, 128, dropout=self.drop_rate),
        )  # skip connection 1
        self.decoder2 = nn.Sequential(
            Deconv2DBlock(self.embed_dim, self.skip_dim_11, dropout=self.drop_rate),
            Deconv2DBlock(self.skip_dim_11, 256, dropout=self.drop_rate),
        )  # skip connection 2
        self.decoder3 = nn.Sequential(
            Deconv2DBlock(self.embed_dim, self.bottleneck_dim, dropout=self.drop_rate)
        )

        # all decoders here are without a head and return 32 features
        self.stardist_decoder = self.create_upsampling_branch(32)
        self.dist_decoder = self.create_upsampling_branch(32)
        self.nuclei_type_maps_decoder = self.create_upsampling_branch(32)

        # 1x1 heads mapping the 32 branch features to the task outputs.
        self.stardist_head = nn.Conv2d(
            in_channels=32, out_channels=self.nrays, kernel_size=1, bias=False
        )
        self.dist_head = nn.Conv2d(
            in_channels=32, out_channels=1, kernel_size=1, bias=False
        )
        self.type_head = nn.Conv2d(
            in_channels=32,
            out_channels=self.num_nuclei_classes,
            kernel_size=1,
            bias=False,
        )

        self.classifier_head = (
            nn.Linear(self.prompt_embed_dim, num_tissue_classes)
            if num_tissue_classes > 0
            else nn.Identity()
        )

        # cpp-net specific head
        self.erosion_factors = list(erosion_factors)
        self.conv_0_confidence = nn.Conv2d(
            in_channels=32, out_channels=self.nrays, kernel_size=1, bias=False
        )
        # One channel per refinement candidate: the unrefined prediction plus
        # one sampled candidate per erosion factor.
        n_candidates = 1 + len(self.erosion_factors)
        self.conv_1_confidence = nn.Conv2d(
            in_channels=n_candidates,
            out_channels=n_candidates,
            kernel_size=1,
            bias=True,
        )
        self.sampling_features = SamplingFeatures(nrays=nrays)
        self.final_activation_ray = nn.ReLU(inplace=True)

    def cppnet_refine(
        self, stardist_map: torch.Tensor, features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Refine the stardist map and confidence map.

        Parameters
        ----------
            stardist_map : torch.Tensor
                The stardist map. Shape: (B, n_rays, H, W)
            features : torch.Tensor
                Feature map fed to ``conv_0_confidence`` (which expects 32
                input channels). Shape: (B, 32, H, W)

        Returns
        -------
            Tuple[torch.Tensor, torch.Tensor]
                - refined stardist map. Shape: (B, n_rays, H, W)
                - softmax-normalised candidate weights.
                  Shape: (B, 1 + len(erosion_factors), n_rays, H, W)
        """
        # cppnet specific ops
        out_confidence = self.conv_0_confidence(features)
        out_ray_for_sampling = stardist_map

        ray_refined = [out_ray_for_sampling]
        confidence_refined = [out_confidence]

        for erosion_factor in self.erosion_factors:
            base_dist = (out_ray_for_sampling - 1.0) * erosion_factor
            # Only the first returned element (the sampled map) is needed;
            # indexing keeps this independent of how many extra values the
            # sampler returns.
            ray_sampled = self.sampling_features(out_ray_for_sampling, base_dist, 1)[0]
            conf_sampled = self.sampling_features(out_confidence, base_dist, 1)[0]
            ray_refined.append(ray_sampled + base_dist)
            confidence_refined.append(conf_sampled)
        ray_refined = torch.stack(ray_refined, dim=1)
        b, k, c, h, w = ray_refined.shape

        # Fold the ray axis into the batch so conv_1_confidence mixes the k
        # candidates of each ray independently, then restore (b, k, c, h, w).
        confidence_refined = torch.stack(confidence_refined, dim=1)
        confidence_refined = (
            confidence_refined.permute([0, 2, 1, 3, 4])
            .contiguous()
            .view(b * c, k, h, w)
        )
        confidence_refined = self.conv_1_confidence(confidence_refined)
        confidence_refined = confidence_refined.view(b, c, k, h, w).permute(
            [0, 2, 1, 3, 4]
        )
        confidence_refined = F.softmax(confidence_refined, dim=1)

        # Confidence-weighted sum over the candidate axis.
        ray_refined = (ray_refined * confidence_refined).sum(dim=1)
        ray_refined = self.final_activation_ray(ray_refined)

        return ray_refined, confidence_refined

    def forward(self, x: torch.Tensor, retrieve_tokens: bool = False):
        """Run encoder, the three decoder branches and the ray refinement.

        Parameters
        ----------
            x : torch.Tensor
                Input batch; its last two dimensions must be divisible by
                ``self.patch_size``.
            retrieve_tokens : bool, optional
                If True, the reshaped tokens of the last extracted encoder
                layer are added to the output under ``"tokens"``.

        Returns
        -------
            dict
                Keys: ``stardist_map``, ``stardist_refined``,
                ``confidence_refined``, ``dist_map_head_out``,
                ``type_map_head_out``, ``tissue_types`` and, optionally,
                ``tokens``.
        """
        assert (
            x.shape[-2] % self.patch_size == 0
        ), "Img must have a shape of that is divisble by patch_soze (token_size)"
        assert (
            x.shape[-1] % self.patch_size == 0
        ), "Img must have a shape of that is divisble by patch_soze (token_size)"

        classifier_logits, _, z = self.encoder(x)

        z0, z1, z2, z3, z4 = x, *z

        # performing reshape for the convolutional layers and upsampling (restore spatial dimension)
        # The first token of every sequence is dropped (presumably the class
        # token) before the remaining tokens are laid out on the patch grid.
        patch_dim = [d // self.patch_size for d in x.shape[-2:]]
        z1, z2, z3, z4 = (
            tokens[:, 1:, :].transpose(-1, -2).view(-1, self.embed_dim, *patch_dim)
            for tokens in (z1, z2, z3, z4)
        )

        stardist_features = self._forward_upsample(
            z0, z1, z2, z3, z4, self.stardist_decoder
        )
        dist_map_features = self._forward_upsample(
            z0, z1, z2, z3, z4, self.dist_decoder
        )
        # The type branch has its own decoder (it previously reused
        # dist_decoder by mistake, leaving nuclei_type_maps_decoder unused).
        type_map_features = self._forward_upsample(
            z0, z1, z2, z3, z4, self.nuclei_type_maps_decoder
        )

        stardist_head_out = self.stardist_head(stardist_features)
        dist_map_head_out = self.dist_head(dist_map_features)
        type_map_head_out = self.type_head(type_map_features)

        ray_refined, confidence_refined = self.cppnet_refine(
            stardist_head_out, stardist_features
        )

        out_dict = {
            "stardist_map": stardist_head_out,
            "stardist_refined": ray_refined,
            "confidence_refined": confidence_refined,
            "dist_map_head_out": dist_map_head_out,
            "type_map_head_out": type_map_head_out,
            "tissue_types": classifier_logits,
        }

        if retrieve_tokens:
            out_dict["tokens"] = z4

        return out_dict
