# In [102]:
from functools import partial

import torch
import torch.nn as nn
# from util.diceloss import SoftDiceLoss
from util.diceloss import DiceLoss
from timm.models.vision_transformer import PatchEmbed, Block

from util.pos_embed import get_2d_sincos_pos_embed

import seg3d_2dencoder
from util.pos_embed import interpolate_pos_embed
from swin_transformer import *
from email.mime import image
from turtle import forward
from util.diceloss import DiceLoss
from einops import rearrange
class SegVit3D(nn.Module):
    """3D segmentation model: a 2D slice encoder followed by a windowed-transformer decoder.

    encoder: a 2D MAE-style encoder supplied by the caller (intended to be
        pretrained). NOTE: this class does not freeze the encoder's weights
        itself; do that in the caller if frozen weights are required.
    decoder: a stack of ``BasicLayer`` (shifted-window transformer) stages
        followed by a final normalization layer.
    """

    def __init__(self, img_size=224,img_deep=160, patch_size=16, in_chans=1,embed_dim=256,
                 decoder_embed_dim=256,decoder_depth=[2,2,2,2],decoder_num_heads=[4,8,16,32],
                 encoder:nn.Module=None,encoder_finetune:str=None,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 window_size = 4, shift_size = 1,
                 qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.1,
                 act_layer=nn.GELU,
                 fused_window_process=False,
                 lossforpatch = True,
                 ):
        r"""Build the decoder stages and attach the caller-supplied encoder.

        Args:
            img_size (int): Height/width of one slice of the 3D volume.
            img_deep (int): Number of slices (depth) of the 3D volume.
            patch_size (int): Side length of the square 2D patches.
            in_chans (int): Number of channels of the 3D volume.
            embed_dim (int): Encoder patch-embedding dim. Accepted for
                interface compatibility; not used by this class.
            decoder_embed_dim (int): Channel dim of the decoder tokens.
            decoder_depth (list[int]): Number of blocks in each decoder stage.
            decoder_num_heads (list[int]): Attention heads in each decoder stage.
            encoder (nn.Module): Encoder applied to the stack of 2D slices.
            encoder_finetune (str): Pretrained-model path. Accepted for
                interface compatibility; not used by this class.
            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
            norm_pix_loss (bool): Accepted for interface compatibility; unused.
            window_size (int): Attention window size.
            shift_size (int): Accepted for interface compatibility; it is not
                forwarded to ``BasicLayer``.
            qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
            qk_scale (float | None, optional): Override the default qk scale if set.
            drop (float, optional): Dropout rate. Default: 0.0
            attn_drop (float, optional): Attention dropout rate. Default: 0.0
            drop_path (float, optional): Maximum stochastic depth rate; the
                per-block rate grows linearly from 0 to this value. Default: 0.1
            act_layer (nn.Module, optional): Accepted for interface
                compatibility; it is not forwarded to ``BasicLayer``.
            fused_window_process (bool, optional): Forwarded to ``BasicLayer``. Default: False
            lossforpatch (bool): If True the loss is computed on patchified
                targets; otherwise on un-patchified predictions.
        """
        super().__init__()
        self.img_size = img_size
        self.norm = norm_layer(decoder_embed_dim)
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.img_deep = img_deep
        # Must be stored before the decoder stages are built: they read it.
        self.mlp_ratio = mlp_ratio

        # Encoder (supplied by the caller).
        self.encoder = encoder

        # Decoder: stochastic depth decay rule, one rate per decoder block.
        total_blocks = sum(decoder_depth)
        dpr = [rate.item() for rate in torch.linspace(0, drop_path, total_blocks)]

        self.num_layers = len(decoder_depth)

        # The decoder treats the token sequence as a 2D grid of
        # (patches per slice) x (number of slices).
        grid = (int((img_size // patch_size) ** 2), img_deep)

        self.layers = nn.ModuleList()
        first_block = 0
        for i_layer in range(self.num_layers):
            last_block = first_block + decoder_depth[i_layer]
            layer = BasicLayer(dim=int(decoder_embed_dim),
                               input_resolution=grid,
                               depth=decoder_depth[i_layer],
                               num_heads=decoder_num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop, attn_drop=attn_drop,
                               drop_path=dpr[first_block:last_block],
                               norm_layer=norm_layer,
                               downsample=None,
                               use_checkpoint=False,
                               fused_window_process=fused_window_process)
            self.layers.append(layer)
            first_block = last_block

        # Initialise only the decoder. Applying the initialiser to the whole
        # model would also overwrite the weights of a pretrained encoder.
        self.layers.apply(self._init_weights)
        self.norm.apply(self._init_weights)
        self.lossforpatch = lossforpatch

    def _init_weights(self, m):
        """Initialise one module: truncated-normal Linear weights, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names that should be excluded from weight decay."""
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return parameter-name keywords that should be excluded from weight decay."""
        return {'relative_position_bias_table'}

    def encoder_forward(self, x):
        """Encode a single 3D volume slice by slice.

        Args:
            x: Tensor of shape (1, C, D, H, W), e.g. (1, 1, 160, 224, 224).
                Only a batch size of 1 is supported, because the depth axis
                is moved into the batch position for the 2D encoder.

        Returns:
            Whatever ``self.encoder`` returns for the (D, C, H, W) slice
            stack; presumably (D, patches_per_slice, dim), e.g.
            (160, 196, 256) -- TODO confirm against the encoder.

        Raises:
            ValueError: If ``x`` is not 5-D or its batch size is not 1.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D (N, C, D, H, W) tensor, got shape {tuple(x.shape)}"
            )
        if x.shape[0] != 1:
            raise ValueError(
                f"encoder_forward supports batch size 1 only, got {x.shape[0]}"
            )

        # (1, C, D, H, W) -> (C, D, H, W) -> (D, C, H, W): slices become the batch.
        slices = x.squeeze(0).permute(1, 0, 2, 3)
        return self.encoder(slices)

    def decoder_forward(self, x):
        """Run the token sequence through every decoder stage, then normalise.

        Args:
            x: Token tensor; presumably (B, patches_per_slice * img_deep, dim),
                e.g. (B, 196 * 160, 256) -- TODO confirm against ``BasicLayer``.

        Returns:
            The decoded tokens after the final normalization layer.
        """
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)

    def forward_loss(self, pred, target):
        """Compute the Dice loss between prediction and target.

        With ``lossforpatch`` set, the target volume is patchified to match the
        prediction; otherwise the prediction is un-patchified to match the
        target volume.

        Args:
            pred: Decoder output tokens, (N, L, patch_dim).
            target: Label volume, (N, C, T, H, W).

        Returns:
            The value returned by ``DiceLoss``.
        """
        criterion = DiceLoss()
        if self.lossforpatch:
            target = self.patchify3D(target)
        else:
            pred = self.unpatchify3D(pred)
        assert pred.shape == target.shape, (
            f"shape mismatch: pred {tuple(pred.shape)} vs target {tuple(target.shape)}"
        )
        return criterion(pred, target)

    def forward(self, imgs, label):
        """Encode, decode and score one volume.

        Returns:
            tuple: ``(loss, pred)`` where ``pred`` is the decoder output.
        """
        latent = self.encoder_forward(imgs)
        pred = self.decoder_forward(latent)
        loss = self.forward_loss(pred, label)
        return loss, pred

    def patchify3D(self, imgs):
        """Split a volume into per-slice 2D patches.

        imgs: (N, C, T, H, W)
        returns: (N, T * h * w, patch_size**2 * C), with h = H // patch_size
            and w = W // patch_size (the temporal patch extent is fixed to 1).
        """
        x = rearrange(imgs, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=1, p1=self.patch_size, p2=self.patch_size)
        x = rearrange(x, 'b n p c -> b n (p c)')
        return x

    def unpatchify3D(self, x):
        """Reassemble per-slice 2D patches into a volume.

        x: (N, L, p0 * patch_size**2 * in_chans), where L = T * h * w and
            h = w = img_size // patch_size; p0 is inferred by einops from the
            size of the last dimension.
        returns: (N, in_chans, T * p0, H, W)
        """
        x = rearrange(x, 'b (t h w) (p0 p1 p2 c) -> b c (t p0) (h p1) (w p2)', p1=self.patch_size, p2=self.patch_size, c=self.in_chans, h=int(self.img_size//self.patch_size), w=int(self.img_size//self.patch_size))
        return x
        
        

