# In [None]:
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import numpy as np
import math
    
class CoPreNorm(nn.Module):
    """Layer-normalise two inputs with one shared ``LayerNorm``, then call ``fn``.

    Args:
        dim: Size of the last dimension handed to ``nn.LayerNorm``.
        fn: Callable invoked as ``fn(norm(x), norm(x1), **kwargs)``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, x1, **kwargs):
        """Return ``fn`` applied to the normalised ``x`` and ``x1``."""
        # Both streams pass through the same LayerNorm instance, so they
        # share its affine parameters.
        x_normed = self.norm(x)
        x1_normed = self.norm(x1)
        return self.fn(x_normed, x1_normed, **kwargs)

class CoAttention(nn.Module):
    """Multi-head cross-attention: queries from ``x``, keys/values from ``x1``.

    A single bias-free linear layer (``to_qkv``) projects both inputs; the
    query third of ``x``'s projection attends over the key and value thirds
    of ``x1``'s projection.

    Args:
        dim: Size of the last dimension of both inputs.
        heads: Number of attention heads.
        dim_head: Width of each head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        # With one head whose width already equals ``dim`` no output
        # projection is needed.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, x1):
        """Attend from ``x`` over ``x1``.

        Both inputs are unpacked with the pattern ``'b n (h d)'``, i.e. they
        must be 3-D with a last dimension of ``dim``. The result has the same
        leading two dimensions as ``x``.
        """
        h = self.heads
        # Only the query slice of x and the key/value slices of x1 take part
        # in the attention; the remaining slices of each projection were
        # previously reshaped and then thrown away.
        q, _, _ = self.to_qkv(x).chunk(3, dim = -1)
        _, k1, v1 = self.to_qkv(x1).chunk(3, dim = -1)
        q, k1, v1 = (rearrange(t, 'b n (h d) -> b h n d', h = h) for t in (q, k1, v1))

        dots = einsum('b h i d, b h j d -> b h i j', q, k1) * self.scale
        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v1)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class CoTransformer(nn.Module):
    """Stack of pre-norm cross-attention + feed-forward layers.

    Each layer updates ``x`` with a residual cross-attention over ``x1``
    followed by a residual feed-forward step; ``x1`` itself is never changed.

    Args:
        dim: Feature size of the tokens.
        depth: Number of (attention, feed-forward) layers.
        heads: Attention heads per layer.
        dim_head: Width of each attention head.
        mlp_dim: Hidden width of the feed-forward block.
        dropout: Dropout probability used inside both sub-blocks.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            cross_attn = CoPreNorm(
                dim, CoAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout))
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            self.layers.append(nn.ModuleList([cross_attn, feed_forward]))

    def forward(self, x, x1):
        """Return the layer-normalised result of running ``x`` through all layers."""
        for cross_attn, feed_forward in self.layers:
            attended = cross_attn(x, x1)
            x = attended + x
            x = feed_forward(x) + x
        return self.norm(x)
    
class cosa(nn.Module):
    """Cross-attention between two feature volumes, computed on flattened patches.

    Both inputs are cut into ``patch_size`` x ``patch_size`` patches, given a
    learned-sinusoidal positional offset, and fed to a ``CoTransformer`` whose
    token width is ``patch_size ** 2``. The output is folded back into a
    volume using ``num_frames`` and ``in_channels``.

    Args:
        image_size: Height/width of the (square) input; must be divisible by
            ``patch_size``.
        patch_size: Side length of one patch.
        in_channels: Channel count used when folding the output back.
        num_frames: Frame count used when folding the output back.
        depth, heads, dim_head, dropout: Forwarded to ``CoTransformer``.
        dim, scale_dim: Their product is the transformer's feed-forward width.
        emb_dropout: Probability of ``self.dropout``.
        num_classes, pool: Accepted but not used.

    NOTE(review): ``self.dropout`` is created but never applied in
    ``forward``.
    """

    def __init__(self, image_size, patch_size, in_channels ,num_frames, depth = 4, heads = 3,num_classes=1, pool = 'cls', 
                 dim = 8, dim_head = 64, dropout = 0.,emb_dropout = 0., scale_dim = 4, ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = image_size // patch_size
        patch_pixels = patch_size ** 2
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b t c (h p1) (w p2) -> b (t c) (h w) (p1 p2)', p1=patch_size, p2=patch_size),
        )
        # The positional embedding works on a whole flattened image
        # (patch count * pixels per patch).
        self.pos_embedding = PositionalEmbedding(patches_per_side ** 2 * patch_pixels)
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = CoTransformer(patch_pixels, depth, heads, dim_head, dim * scale_dim, dropout)
        self.img_out = image_size
        self.time = num_frames
        self.channel = in_channels
        self.patch_size = patch_size

    def forward(self, x1, x2):
        """Cross-attend ``x1`` over ``x2``.

        Both inputs are read with the pattern ``'b c t n d'``. The result is
        unsqueezed to a leading dimension of 1 and returned in the same
        ``'b c t n d'`` layout.
        """
        frames_first = 'b c t n d -> b t c n d'
        x1 = self.to_patch_embedding(rearrange(x1, frames_first))
        x2 = self.to_patch_embedding(rearrange(x2, frames_first))

        # After patch embedding the second axis holds frames*channels merged.
        _, _, n, d = x1.shape
        x1 = x1 + self.pos_embedding(x1, n, d)
        x2 = x2 + self.pos_embedding(x2, n, d)

        merge_batch = 'b t n d -> (b t) n d'
        x = self.transformer(rearrange(x1, merge_batch), rearrange(x2, merge_batch))

        # Fold the patch tokens back into images; the leading axis is split
        # into (frames, channels) only, so this presumably expects a batch
        # of one -- TODO confirm.
        patch_height = int(self.img_out/self.patch_size)
        x = rearrange(x, '(t c) (ph pw) (p1 p2) -> t c (ph p1) (pw p2)',
                      t=self.time, c=self.channel, ph=patch_height, p1=self.patch_size)
        return rearrange(x.unsqueeze(0), 'b t c n d -> b c t n d')
    
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import numpy as np
import math

class PreNorm(nn.Module):
    """Apply ``LayerNorm`` to the input before handing it to ``fn``.

    Args:
        dim: Size of the last dimension handed to ``nn.LayerNorm``.
        fn: Callable invoked as ``fn(norm(x), **kwargs)``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn`` applied to the normalised ``x``."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension stays ``dim``."""
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Size of the last dimension of the input.
        heads: Number of attention heads.
        dim_head: Width of each head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is skipped only for a single head whose
        # width already equals ``dim``.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Self-attend over the second axis of ``x`` (pattern ``'b n (h d)'``)."""
        heads = self.heads
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h=heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = scores.softmax(dim=-1)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of pre-norm self-attention + feed-forward layers.

    Besides the final hidden state, ``forward`` also returns the raw
    attention-branch output of the last layer.

    Args:
        dim: Feature size of the tokens.
        depth: Number of (attention, feed-forward) layers; must be at least 1
            for ``forward`` to have an attention output to return.
        heads: Attention heads per layer.
        dim_head: Width of each attention head.
        mlp_dim: Hidden width of the feed-forward block.
        dropout: Dropout probability used inside both sub-blocks.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        """Return ``(norm(x_out), norm(x_attn))``.

        ``x_out`` is the hidden state after all layers; ``x_attn`` is the
        attention-branch output (before the residual add) of the last layer.
        """
        for attn, ff in self.layers:
            # Evaluate the attention branch once and reuse it for the
            # residual. Calling it twice doubled the cost and, with dropout
            # active, made the returned map differ from the one actually
            # added to ``x``.
            x_attn = attn(x)
            x = x_attn + x
            x = ff(x) + x
        return self.norm(x), self.norm(x_attn)

class PositionalEmbedding(nn.Module):
    """Content-dependent sinusoidal embedding.

    Each flattened frame is reduced to one scalar by ``LayerNorm -> Linear``;
    that scalar is multiplied with a fixed frequency table and passed through
    sine and cosine to build the embedding.

    Args:
        demb: Length of one flattened frame and of the produced embedding.
    """

    def __init__(self, demb):
        super().__init__()

        self.demb = demb
        self.mlp = nn.Sequential(nn.LayerNorm(self.demb), nn.Linear(self.demb, 1))
        # Frequencies 10 ** (-2i / demb) for i = 0, 1, ...
        exponents = torch.arange(0.0, demb, 2.0) / demb
        self.register_buffer('inv_freq', 1 / (10 ** exponents))

    def forward(self, pos_seq, imgsize, imgdim, bsz=None):
        """Build the embedding for ``pos_seq``.

        Args:
            pos_seq: Tensor read with the pattern ``'b t n d'``.
            imgsize: Value of ``n`` used to unflatten the embedding.
            imgdim: Value of ``d`` used to unflatten the embedding.
            bsz: If given, the singleton batch axis is expanded to this size.

        Returns:
            Tensor laid out as ``'b t n d'``.
        """
        flat = rearrange(pos_seq, 'b t n d -> b t (n d)')
        score = self.mlp(flat)  # last dimension is 1
        # ``torch.ger`` needs a 1-D vector, so the leading axis has to be a
        # singleton for the first squeeze to remove it.
        score = score.squeeze(0).squeeze(1)

        angles = torch.ger(score, self.inv_freq)
        pos_emb = torch.cat([angles.sin(), angles.cos()], dim=-1)

        pos_emb = pos_emb[:, None, :]
        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)
        return rearrange(pos_emb, 't b (n d) -> b t n d', n=imgsize, d=imgdim)
    
class STFH(nn.Module):
    """Patch-wise self-attention over a sharpened feature volume.

    The input is first replaced by ``2 * x - maxpool3d(x)``, cut into
    patches, offset by a positional embedding and run through a
    ``Transformer`` whose token width is ``patch_size ** 2``. Both transformer
    outputs are folded back into volumes.

    Args:
        image_size: Height/width of the (square) input; must be divisible by
            ``patch_size``.
        patch_size: Side length of one patch.
        in_channels: Channel count used when folding the output back.
        num_frames: Frame count used when folding the output back.
        depth, heads, dim_head, dropout: Forwarded to ``Transformer``.
        dim, scale_dim: Their product is the transformer's feed-forward width.
        emb_dropout: Dropout probability applied after the positional offset.
        num_classes, pool: Accepted but not used.
    """

    def __init__(self, image_size,patch_size, in_channels ,num_frames, depth = 4, heads = 3,num_classes=1, pool = 'cls', 
                 dim = 8, dim_head = 64, dropout = 0.,emb_dropout = 0., scale_dim = 4, ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = image_size // patch_size
        patch_pixels = patch_size ** 2
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b t c (h p1) (w p2) -> b (t c) (h w) (p1 p2)', p1=patch_size, p2=patch_size),
        )
        self.pos_embedding = PositionalEmbedding(patches_per_side ** 2 * patch_pixels)
        self.dropout = nn.Dropout(emb_dropout)
        self.spatio_temporal_transformer = Transformer(
            patch_pixels, depth, heads, dim_head, dim * scale_dim, dropout)
        self.img_out = image_size
        self.time = num_frames
        self.channel = in_channels
        self.patch_size = patch_size

    def _to_volume(self, tokens):
        """Fold ``(frames*channels, patches, pixels)`` tokens into ``'b c t n d'``."""
        patch_height = int(self.img_out/self.patch_size)
        frames = rearrange(tokens, '(t c) (ph pw) (p1 p2) -> t c (ph p1) (pw p2)',
                           t=self.time, c=self.channel, ph=patch_height, p1=self.patch_size)
        # The leading axis of size 1 is added by hand, so this presumably
        # expects a batch of one -- TODO confirm.
        return rearrange(frames.unsqueeze(0), 'b t c n d -> b c t n d')

    def forward(self, x):
        """Return ``(features, attention_features)``, both laid out ``'b c t n d'``.

        ``x`` is read with the pattern ``'b c t n d'``.
        """
        theta = 5
        x_pool = F.max_pool3d(x, kernel_size=theta, stride=1, padding=(theta - 1) // 2)
        x = 2 * x - x_pool

        x = self.to_patch_embedding(rearrange(x, 'b c t n d -> b t c n d'))
        _, _, n, d = x.shape
        x = self.dropout(x + self.pos_embedding(x, n, d))

        x = rearrange(x, 'b t n d -> (b t) n d')
        x, x_att = self.spatio_temporal_transformer(x)
        return self._to_volume(x), self._to_volume(x_att)
    
class Dual_path(nn.Module):
    """Fuse two ``STFH`` branches that use different patch sizes.

    ``spatio_stfh`` uses patches of half the image side and receives the
    sharpened input ``2 * x - maxpool3d(x)``; ``temporal_stfh`` uses one patch
    covering the whole image and receives ``x`` unchanged.

    Args:
        image_size: Height/width of the (square) input.
        in_channels: Forwarded to both ``STFH`` branches.
        num_frames: Forwarded to both ``STFH`` branches.
        depth, heads, num_classes, pool, dim, dim_head, dropout, emb_dropout,
        scale_dim: Accepted but not forwarded; both branches use the ``STFH``
            defaults.
    """

    def __init__(self, image_size, in_channels ,num_frames, depth = 4, heads = 3,num_classes=1, pool = 'cls', 
                 dim = 8, dim_head = 64, dropout = 0.,emb_dropout = 0., scale_dim = 4, ):
        super().__init__()
        shared = dict(image_size=image_size, in_channels=in_channels, num_frames=num_frames)
        self.spatio_stfh = STFH(patch_size=image_size // 2, **shared)
        self.temporal_stfh = STFH(patch_size=image_size, **shared)

    def forward(self, x):
        """Return ``temporal + spatial + temporal_attention * spatial_attention``."""
        theta = 5
        x_pool = F.max_pool3d(x, kernel_size=theta, stride=1, padding=(theta - 1) // 2)
        x_bound = 2 * x - x_pool

        x_bound, x_att_b = self.spatio_stfh(x_bound)
        x, x_att_t = self.temporal_stfh(x)

        # Element-wise product of the two branches' attention outputs.
        x_croatt = x_att_t * x_att_b
        return x + x_bound + x_croatt

class TimeDistributed(nn.Module):
    """Apply one layer independently to every time step of a 5-D volume.

    Args:
        layer: Module applied to each ``(batch, channel, n, d)`` slice.
        time_steps: Number of time steps supported. The same ``layer`` object
            is stored once per step, so all steps share its parameters.
    """

    def __init__(self, layer, time_steps):
        super(TimeDistributed, self).__init__()
        self.layers = nn.ModuleList([layer for i in range(time_steps)])

    def forward(self, x):
        """Map ``(b, c, t, n, d)`` to ``(b, c', t, n', d')``.

        The per-step results are stacked on the device and in the dtype the
        wrapped layer produces, so the module works on CPU as well as GPU.

        Raises:
            IndexError: If ``x`` has more time steps than ``time_steps``.
        """
        # (b, c, t, n, d) -> (b, t, c, n, d); a plain axis swap.
        x = x.transpose(1, 2)
        steps = x.size(1)
        # Collect every step first and stack once, instead of growing a
        # CUDA-only accumulator with repeated ``torch.cat``.
        frames = [self.layers[i](x[:, i]) for i in range(steps)]
        output = torch.stack(frames, dim=1)
        # Back to (b, c, t, n, d).
        return output.transpose(1, 2)


class EncoderBottleneck3d(nn.Module):
    """3-D residual bottleneck that halves both spatial axes.

    Main path: 1x1x1 conv -> 3x3x3 conv with stride (1, 2, 2) -> 1x1x1 conv,
    each followed by batch norm. Shortcut: 1x1x1 conv with stride (1, 2, 2)
    plus batch norm. The time axis keeps its length.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        stride: Accepted but not used; the stride is fixed to (1, 2, 2).
        base_width: Scales the bottleneck width as
            ``out_channels * base_width / 64``.
    """

    def __init__(self, in_channels, out_channels, stride=1, base_width=64):
        super().__init__()
        spatial_stride = (1, 2, 2)

        self.downsample = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=spatial_stride, bias=False),
            nn.BatchNorm3d(out_channels),
        )

        width = int(out_channels * (base_width / 64))

        self.conv1 = nn.Conv3d(in_channels, width, kernel_size=1, stride=1, bias=False)
        self.norm1 = nn.BatchNorm3d(width)

        self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=spatial_stride,
                               groups=1, padding=1, dilation=1, bias=False)
        self.norm2 = nn.BatchNorm3d(width)

        self.conv3 = nn.Conv3d(width, out_channels, kernel_size=1, stride=1, bias=False)
        self.norm3 = nn.BatchNorm3d(out_channels)

        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(main(x) + shortcut(x))``."""
        shortcut = self.downsample(x)

        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))
        out = self.norm3(self.conv3(out))

        return self.relu(out + shortcut)
    
class Encoder(nn.Module):
    """3-D convolutional encoder returning the bottleneck and three skip features.

    Args:
        in_channels: Channels of the input frames.
        out_channels: Base channel width; stages use 1x, 2x, 4x and 8x of it.
        img_dim, head_num, mlp_dim, block_num, patch_dim, seq_frame: Accepted
            but not used by this module.
    """

    def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim,seq_frame):
        super().__init__()
        base = out_channels

        # Stem: 7x7x7 conv that halves the two spatial axes only.
        self.conv1 = nn.Conv3d(in_channels, base, kernel_size=7, stride=(1, 2, 2), padding=3, bias=False)
        self.norm1 = nn.BatchNorm3d(base)
        self.relu = nn.ReLU(inplace=False)

        self.encoder1 = EncoderBottleneck3d(base, base * 2, stride=2)
        self.encoder2 = EncoderBottleneck3d(base * 2, base * 4, stride=2)
        self.encoder3 = EncoderBottleneck3d(base * 4, base * 8, stride=2)

        self.conv2 = nn.Conv3d(base * 8, base * 4, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.BatchNorm3d(base * 4)

    def forward(self, x):
        """Encode ``x`` (read with the pattern ``'b t c n d'``).

        Returns:
            ``(x, x1, x2, x3)``: the bottleneck feature followed by the
            outputs of the stem, ``encoder1`` and ``encoder2``, all in
            channel-first ``'b c t n d'`` layout.
        """
        x = rearrange(x, 'b t c n d -> b c t n d')

        x1 = self.relu(self.norm1(self.conv1(x)))
        x2 = self.encoder1(x1)
        x3 = self.encoder2(x2)
        deepest = self.encoder3(x3)

        bottleneck = self.relu(self.norm2(self.conv2(deepest)))
        return bottleneck, x1, x2, x3

class DecoderBottleneck3d(nn.Module):
    """Upsample each frame, optionally concatenate a skip feature, then refine.

    Args:
        in_channels: Channels entering the conv stack (after concatenation).
        out_channels: Channels produced by the conv stack.
        seq_frame: Number of time steps given to ``TimeDistributed``.
        scale_factor: Spatial upsampling factor applied per frame.
    """

    def __init__(self, in_channels, out_channels,seq_frame, scale_factor=2):
        super().__init__()

        per_frame_upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)
        self.upsample = TimeDistributed(per_frame_upsample, time_steps=seq_frame)

        conv_stack = [
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=False),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=False),
        ]
        self.layer = nn.Sequential(*conv_stack)

    def forward(self, x, x_concat=None):
        """Upsample ``x``; if ``x_concat`` is given, prepend it on dim 1 first."""
        upsampled = self.upsample(x)
        if x_concat is None:
            return self.layer(upsampled)
        # The skip feature goes first along the channel axis.
        return self.layer(torch.cat([x_concat, upsampled], dim=1))

class dconv3d(nn.Module):
    """Bottleneck of three conv/batch-norm/ReLU stages with a dilated middle conv.

    The middle 3x3x3 convolution uses ``dilation=d_rate`` and
    ``padding=d_rate`` with stride 1, so the spatial and temporal sizes are
    preserved.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        d_rate: Dilation (and padding) of the middle convolution.
        stride: Accepted but not used; every convolution has stride 1.
        base_width: Scales the bottleneck width as
            ``out_channels * base_width / 64``.
    """

    def __init__(self, in_channels, out_channels,d_rate, stride=1, base_width=64):
        super().__init__()

        width = int(out_channels * (base_width / 64))

        self.conv1 = nn.Conv3d(in_channels, width, kernel_size=1, stride=1, bias=False)
        self.norm1 = nn.BatchNorm3d(width)

        self.conv2 = nn.Conv3d(width, width, kernel_size=3, stride=(1, 1, 1), groups=1,
                               padding=d_rate, dilation=d_rate, bias=False)
        self.norm2 = nn.BatchNorm3d(width)

        self.conv3 = nn.Conv3d(width, out_channels, kernel_size=1, stride=1, bias=False)
        self.norm3 = nn.BatchNorm3d(out_channels)

        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Run the three conv -> norm -> ReLU stages in order."""
        stages = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
        )
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x
    
class task_specific_decoder(nn.Module):
    """Two decoder stages with per-frame dropout in between.

    Args:
        out_channels: Base channel width of the network.
        class_num: Accepted but not used by this module.
        drop_rate: Dropout probability (stored as ``self.d_rate``).
        seq_frame: Number of time steps handled by the sub-modules.
    """

    def __init__(self, out_channels, class_num, drop_rate,seq_frame):
        super().__init__()
        self.d_rate = drop_rate
        self.decoder1 = DecoderBottleneck3d(out_channels * 8, out_channels * 2, seq_frame)
        self.decoder2 = DecoderBottleneck3d(out_channels * 4, out_channels, seq_frame)
        self.dropout = TimeDistributed(torch.nn.Dropout(p=self.d_rate), time_steps=seq_frame)

    def forward(self, x, x2, x3):
        """Decode ``x`` using skip feature ``x3`` first and ``x2`` second."""
        stage1 = self.decoder1(x, x3)
        stage1 = self.dropout(stage1)
        return self.decoder2(stage1, x2)

    
class segment_head(nn.Module):
    """Segmentation head: two decoder stages and a 1x1x1 classifier conv.

    Args:
        out_channels: Base channel width of the network.
        class_num: Number of output channels of the final convolution.
        drop_rate: Accepted but not used by this module.
        seq_frame: Number of time steps handled by the decoder stages.
    """

    def __init__(self, out_channels, class_num, drop_rate,seq_frame):
        super().__init__()
        half = int(out_channels * 1 / 2)
        eighth = int(out_channels * 1 / 8)
        self.decoder3 = DecoderBottleneck3d(out_channels * 2, half, seq_frame)
        self.decoder4 = DecoderBottleneck3d(half, eighth, seq_frame)
        self.conv1 = nn.Conv3d(eighth, class_num, kernel_size=1)

    def forward(self, x, x1):
        """Decode ``x`` with skip feature ``x1``; output is laid out ``'b t c n d'``."""
        decoded = self.decoder4(self.decoder3(x, x1))
        logits = self.conv1(decoded)
        # Move the time axis in front of the channel axis for the caller.
        return rearrange(logits, 'b c t n d -> b t c n d')

class point_head(nn.Module):
    """Point-regression head producing two values per class map plus a 1-channel map.

    Args:
        out_channels: Base channel width of the network.
        class_num: Number of class maps produced by ``conv1``.
        seq_frame: Number of time steps handled by the decoder stages.
        height: Height of the decoded maps.
        weight: Second spatial size of the decoded maps; ``height * weight``
            is the input width of ``mlp_out``.
    """

    def __init__(self, out_channels, class_num, seq_frame,height,weight):
        super().__init__()
        self.seq_frame = seq_frame
        self.height = height
        self.weight = weight

        half = int(out_channels * 1 / 2)
        eighth = int(out_channels * 1 / 8)
        self.decoder3 = DecoderBottleneck3d(out_channels * 2, half, seq_frame)
        self.decoder4 = DecoderBottleneck3d(half, eighth, seq_frame)
        self.conv1 = nn.Conv3d(eighth, class_num, kernel_size=1)
        self.conv2 = nn.Conv3d(class_num, 1, kernel_size=1)

        flat_pixels = self.height * self.weight
        self.mlp_out = nn.Sequential(
            nn.LayerNorm(flat_pixels),
            nn.Linear(flat_pixels, 2))

    def forward(self, x, x1):
        """Return ``(x_pot, x_map)``.

        ``x_pot`` holds two regressed values per (batch, frame, class);
        ``x_map`` is the single-channel map from ``conv2`` laid out
        ``'b t c n d'``.
        """
        decoded = self.decoder4(self.decoder3(x, x1))
        class_maps = self.conv1(decoded)

        x_map = rearrange(self.conv2(class_maps), 'b c t n d -> b t c n d')

        # Swap to time-first and flatten each map into one vector.
        flat_maps = rearrange(class_maps, 'b c t n d -> b t c (n d)')
        x_pot = self.mlp_out(flat_maps)
        return x_pot, x_map

# class classifi_head(nn.Module):
#     def __init__(self, out_channels, class_num,seq_frame):
#         super().__init__()
#         self.out_ch = out_channels
#         self.class_num = class_num
#         self.seq_frame = seq_frame
#         self.avgpool = nn.AvgPool2d(32, stride=1)
#         self.fc = nn.Linear(self.out_ch,self.class_num)
#         for m in self.modules():
#             if isinstance(m, nn.Conv2d):
#                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#             elif isinstance(m, nn.BatchNorm2d):
#                 nn.init.constant_(m.weight, 1)
#                 nn.init.constant_(m.bias, 0)
                
#     def forward(self,x):
#         x = x[:,:,0,:,:]
#         x = self.avgpool(x)
#         x = x.view(x.size(0), -1)
#         x = self.fc(x)
#         return x
    
class frame_head(nn.Module):
    """Per-frame classification head built from dilated 3-D conv blocks.

    Three ``dconv3d`` blocks (dilation 3, 2, 1) alternate with per-frame
    adaptive average pooling down to 8x8, 2x2 and 1x1; the pooled features go
    through ``LayerNorm -> Linear`` to give ``class_num`` scores.

    Args:
        out_channels: Base channel width of the network.
        class_num: Number of scores produced per row.
        seq_frame: Number of time steps handled by the pooling wrappers.
    """

    def __init__(self, out_channels, class_num,seq_frame):
        super().__init__()
        self.out_ch = out_channels
        self.class_num = class_num
        self.seq_frame = seq_frame

        base = self.out_ch
        self.dconv3d2 = dconv3d(base * 1, base * 2, d_rate=3, stride=2)
        self.avgpool2d2 = TimeDistributed(nn.AdaptiveAvgPool2d(8), time_steps=seq_frame)
        self.dconv3d3 = dconv3d(base * 2, base * 4, d_rate=2, stride=2)
        self.avgpool2d3 = TimeDistributed(nn.AdaptiveAvgPool2d(2), time_steps=seq_frame)
        self.dconv3d4 = dconv3d(base * 4, base * 8, d_rate=1, stride=2)
        self.avgpool2d4 = TimeDistributed(nn.AdaptiveAvgPool2d(1), time_steps=seq_frame)
        self.mlp_out = nn.Sequential(
            nn.LayerNorm(base * 8),
            nn.Linear(base * 8, self.class_num))

        # NOTE(review): this loop only matches 2-D layers, while the blocks
        # built above use Conv3d / BatchNorm3d, so it looks like it never
        # changes any weight. Kept as-is because widening it to the 3-D
        # types would alter the initialisation of newly created models.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self,x):
        """Return the scores of shape ``(rows, class_num)``.

        The incoming tensor is channel-first, so in the flattening pattern
        below the label ``t`` binds the channel axis and ``c`` the time
        axis: rows run over time (and the pooled 1x1 grid), columns over
        batch * channels.
        """
        stages = (
            (self.dconv3d2, self.avgpool2d2),
            (self.dconv3d3, self.avgpool2d3),
            (self.dconv3d4, self.avgpool2d4),
        )
        for conv_block, pool in stages:
            x = pool(conv_block(x))

        x = rearrange(x, 'b t c n d -> (c n d)(b t)')
        # The scores are returned without squeezing singleton axes.
        return self.mlp_out(x)
    
class TaskAttention(nn.Module):
    """Channel gate over two concatenated feature volumes.

    The inputs are concatenated on the channel axis, globally average- and
    max-pooled, and both pooled vectors go through a shared two-layer 1x1x1
    conv bottleneck. The sigmoid of their sum is *added* to the concatenated
    features, which are then reduced back to ``in_ch`` channels.

    Args:
        in_ch: Channels of each input (the concatenation has ``2 * in_ch``).
        ratio: Reduction ratio of the bottleneck; its width is
            ``in_ch // ratio``, clamped to at least one channel. The default
            of 16 reproduces the previously hard-coded width.
    """

    def __init__(self, in_ch, ratio=16):
        super(TaskAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)

        # Honour ``ratio`` (it used to be ignored in favour of a literal 16)
        # and never build a zero-channel bottleneck for small ``in_ch``.
        hidden = max(in_ch // ratio, 1)
        self.fc = nn.Sequential(nn.Conv3d(in_ch*2, hidden, kernel_size=1),
                               nn.ReLU(),
                               nn.Conv3d(hidden, in_ch*2, kernel_size=1))
        self.sigmoid = nn.Sigmoid()
        self.conv3 = nn.Conv3d(in_ch*2, in_ch, kernel_size=1, stride=1, bias=False)
        self.norm3 = nn.BatchNorm3d(in_ch)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x1,x2):
        """Fuse ``x1`` and ``x2`` into a volume with ``in_ch`` channels."""
        x = torch.cat([x1, x2], dim=1)
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        # The gate is broadcast-added to the features, not multiplied.
        x = self.sigmoid(out)+x
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.relu(x)
        return x
    
class pattern_strcture(nn.Module):
    """Exchange information between three task feature volumes.

    Each task feature passes through one shared ``STFH``; for every task the
    other two are fused by ``TaskAttention`` and cross-attended with the
    task's own feature through ``cosa``.

    Args:
        image_size: Height/width of the (square) task features.
        in_channels: Channels of each task feature.
        num_frames: Frames of each task feature.
        depth, heads, num_classes, pool, dim, dim_head, dropout, emb_dropout,
        scale_dim: Accepted but not forwarded to the sub-modules.
    """

    def __init__(self, image_size, in_channels ,num_frames, depth = 4, heads = 3,num_classes=1, pool = 'cls', 
                 dim = 8, dim_head = 64, dropout = 0.,emb_dropout = 0., scale_dim = 4, ):
        super().__init__()
        self.spatio_stfh = STFH(image_size=image_size, patch_size=image_size//4, in_channels=in_channels ,num_frames=num_frames)
        self.ta = TaskAttention(in_channels)
        # ``cosa`` consumes tensors shaped like the STFH output, so it is
        # built from the same geometry. It used to be fixed to 32x32 images,
        # 32 channels and 30 frames, which only fits that one configuration
        # (and yields the identical module for it).
        self.cosa = cosa(image_size=image_size, patch_size=image_size // 2,
                         in_channels=in_channels, num_frames=num_frames)

    def forward(self,x1,x2,x3):
        """Return the three refined task features, in input order."""
        # The attention outputs of STFH are not needed here.
        x1_sa, _ = self.spatio_stfh(x1)
        x2_sa, _ = self.spatio_stfh(x2)
        x3_sa, _ = self.spatio_stfh(x3)

        # Fixed mixing weight: the fused "other tasks" signal is scaled by
        # beta, the task's own feature by (1 - beta).
        beta = 0.1
        x1 = self.cosa(beta*self.ta(x2_sa,x3_sa),(1-beta)*x1_sa)
        x2 = self.cosa(beta*self.ta(x1_sa,x3_sa),(1-beta)*x2_sa)
        x3 = self.cosa(beta*self.ta(x2_sa,x1_sa),(1-beta)*x3_sa)
        return x1,x2,x3

# Module-level frame count. NOTE(review): none of the visible definitions read
# this global -- the classes take their own ``seq_frame`` argument -- so it is
# presumably consumed by training code elsewhere; confirm before removing.
seq_frame = 30

class Multivit_net(nn.Module):
    """Shared 3-D encoder with segmentation, point and frame task branches.

    Args:
        img_dim: Input image side; ``img_dim / 4`` is the feature size given
            to ``pattern_strcture``.
        in_channels: Channels of the input frames.
        out_channels: Base channel width of the network.
        head_num, mlp_dim, block_num, patch_dim: Forwarded to ``Encoder``.
        class_num: Output channels of the segmentation head.
        drop_rate: Dropout probability of the task decoders.
        seq_frame: Number of frames per clip.
        mode: One of ``'seg'``, ``'pot'``, ``'frm'`` or ``'mtl'``; selects
            which branch(es) ``forward`` runs.
        height, weight: Spatial size forwarded to ``point_head``.
    """

    def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim, class_num, drop_rate,seq_frame
                 ,mode,height,weight):
        super().__init__()

        self.encoder = Encoder(img_dim, in_channels, out_channels,
                               head_num, mlp_dim, block_num, patch_dim,seq_frame)
        self.seg_decoder = task_specific_decoder(out_channels, class_num, drop_rate,seq_frame)
        self.pot_decoder = task_specific_decoder(out_channels, class_num, drop_rate,seq_frame)
        self.frm_decoder = task_specific_decoder(out_channels, class_num, drop_rate,seq_frame)
        self.seg_head = segment_head(out_channels, class_num, drop_rate,seq_frame)
        self.pot_head = point_head(out_channels, class_num=4,seq_frame=seq_frame,height=height,weight=weight)
        self.frm_head = frame_head(out_channels, class_num=3,seq_frame=seq_frame)
        self.mode = mode
        self.stps = pattern_strcture(image_size=int(img_dim/4), in_channels=out_channels ,num_frames=seq_frame)

    def forward(self, x):
        """Run the branch(es) selected by ``self.mode``.

        Returns:
            * ``'seg'``: the segmentation head output.
            * ``'pot'``: whatever ``point_head`` returns (a
              ``(points, map)`` pair).
            * ``'frm'``: the frame head output.
            * ``'mtl'``: ``(x_pot, x_potmap, x_frm, x_seg)``.

        Raises:
            ValueError: If ``self.mode`` is none of the four supported values.
        """
        x, x1, x2, x3 = self.encoder(x)

        if self.mode == 'seg':
            x_seg = self.seg_decoder(x, x2, x3)
            return self.seg_head(x_seg, x1)

        if self.mode == 'pot':
            x_pot = self.pot_decoder(x, x2, x3)
            return self.pot_head(x_pot, x1)

        if self.mode == 'frm':
            x_frm = self.frm_decoder(x, x2, x3)
            return self.frm_head(x_frm)

        if self.mode == 'mtl':
            x_seg = self.seg_decoder(x, x2, x3)
            x_pot = self.pot_decoder(x, x2, x3)
            x_frm = self.frm_decoder(x, x2, x3)
            # Let the three task features exchange information before the
            # task-specific heads.
            x_seg, x_pot, x_frm = self.stps(x_seg, x_pot, x_frm)
            x_seg = self.seg_head(x_seg, x1)
            x_pot, x_potmap = self.pot_head(x_pot, x1)
            x_frm = self.frm_head(x_frm)
            # Shapes per the original author's note (not verified here):
            # x_pot (1,30,4,2), x_frm (30,3), x_seg (1,30,1,128,128).
            return x_pot, x_potmap, x_frm, x_seg

        # An unsupported mode used to fall through to the multi-task return
        # and fail on unbound locals; report it explicitly instead.
        raise ValueError(
            f"unsupported mode {self.mode!r}; expected 'seg', 'pot', 'frm' or 'mtl'")

# In [None]:
class EmbbedingSpace(nn.Module):
    """Flatten a task output per frame and project it with a small MLP.

    Args:
        dim: Length of one flattened row (input width of the MLP).
        frame: Stored as ``self.frame``; not used inside this module.
        mode: ``'seg'`` and ``'pot'`` select a flattening pattern; any other
            value (e.g. ``'frm'``) feeds the input to the MLP unchanged.
        class_num: Output width of the MLP.
    """

    def __init__(self, dim, frame ,mode,class_num=1):
        super().__init__()
        self.dim = dim
        self.frame = frame
        self.mode = mode
        self.class_num = class_num
        self.mlp_out = nn.Sequential(
            nn.LayerNorm(self.dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.dim, self.class_num))

    def forward(self, x):
        """Return the MLP output for ``x`` after the mode-specific flattening."""
        if self.mode == 'seg':
            pattern = 'b t c n d -> (b t) (c n d)'
        elif self.mode == 'pot':
            pattern = 'b t n d -> (b t) (n d)'
        else:
            pattern = None

        rows = x if pattern is None else rearrange(x, pattern)
        return self.mlp_out(rows)
    
class Embbeding_net(nn.Module):
    """Project point, frame and segmentation targets into embedding vectors.

    Args:
        dim: Stored as ``self.dim``; the three sub-embeddings use fixed input
            widths (8, 3 and 128*128) instead.
        frame: Forwarded to every ``EmbbedingSpace``.
        device: Device the sub-modules and intermediate tensors are moved to.
        class_num: Stored as ``self.class_num``; not forwarded.
    """

    def __init__(self, dim, frame,device,class_num=1):
        super().__init__()
        self.dim = dim
        self.frame = frame
        self.class_num = class_num
        self.device = device
        self.pot_emb = EmbbedingSpace(8,frame,'pot').to(self.device)
        self.frm_emb = EmbbedingSpace(3,frame,'frm').to(self.device)
        self.seg_emb = EmbbedingSpace(128*128,frame,'seg').to(self.device)

    def forward(self, x_pot,x_frm,x_seg):
        """Return ``(xp_emb, xf_emb, xs_emb)`` for the three targets.

        ``x_frm`` holds integer class labels; it is one-hot encoded over three
        classes and smoothed with ``tr_oh`` before being embedded.
        """
        # Use the device this module was configured with rather than a
        # module-level ``device`` global that may not exist.
        x_frm = one_hot(x_frm.long(),3).to(self.device)
        # NOTE(review): the sequence length 30 and window 3 are fixed here
        # rather than derived from ``self.frame`` -- confirm this is intended.
        x_frm = tr_oh(x_frm,30,3)
        xp_emb = self.pot_emb(x_pot.to(torch.float32)).to(self.device)
        xf_emb = self.frm_emb(x_frm.to(torch.float32)).to(self.device)
        xs_emb = self.seg_emb(x_seg.to(torch.float32)).to(self.device)
        return xp_emb,xf_emb,xs_emb

def one_hot(label, n_classes, requires_grad=True):
    """Return the one-hot encoding of an integer label tensor.

    Args:
        label: Integer tensor of class indices in ``[0, n_classes)``.
        n_classes: Number of classes (length of the new last dimension).
        requires_grad: Whether the identity matrix that is indexed should
            require gradients.

    Returns:
        Float tensor of shape ``label.shape + (n_classes,)`` on the same
        device as ``label``.
    """
    # Build the lookup table on the label's device so that labels living on
    # an accelerator can index it directly.
    identity = torch.eye(n_classes, device=label.device, requires_grad=requires_grad)
    return identity[label]
    
def trend(frm, f, k):
    """Weighted forward-looking average of a sequence.

    Position ``i`` becomes ``(k*frm[s] + (k-1)*frm[s+1] + ... + 1*frm[s+k-1]) / k``.
    For the first ``f - k`` positions the window starts at ``s = i``; for the
    last ``k`` positions it starts at ``s = i - (f - k)``, i.e. those entries
    reuse the windows from the beginning of the sequence.

    Args:
        frm: Tensor indexed along its first dimension.
        f: Number of positions to fill.
        k: Window length.

    Returns:
        Tensor shaped like ``frm``.
    """
    weights = [k - offset for offset in range(k)]  # k, k-1, ..., 1
    smoothed = torch.zeros_like(frm)
    head = f - k

    for i in range(head):
        for offset, weight in enumerate(weights):
            smoothed[i] = smoothed[i] + weight * frm[i + offset]

    for start in range(k):
        i = head + start
        for offset, weight in enumerate(weights):
            smoothed[i] = smoothed[i] + weight * frm[start + offset]

    return smoothed / k

def tr_oh(frm, f, k):
    """Smooth the first three columns of ``frm`` in place with ``trend``.

    Args:
        frm: 2-D tensor; columns 0, 1 and 2 are overwritten.
        f: Sequence length passed to ``trend``.
        k: Window length passed to ``trend``.

    Returns:
        The same ``frm`` object, modified in place.
    """
    # Compute all three smoothed columns before writing any of them back.
    smoothed = [trend(frm[:, col], f, k) for col in range(3)]
    for col, values in enumerate(smoothed):
        frm[:, col] = values
    return frm

def cosine_loss(input,target):
    """Return ``1 - cosine_similarity`` between two embedding tensors.

    A singleton second dimension is squeezed from both tensors, a leading
    dimension of 1 is added, and the similarity is taken along dim 1.
    """
    flat_input = input.squeeze(1)
    flat_target = target.squeeze(1)
    similarity = torch.cosine_similarity(flat_input[None], flat_target[None])
    return 1 - similarity

# In [None]:
from scipy.ndimage import morphology
def coefficients(gt, pred, smooth=1e-12):
    """Return ``(dice, jaccard, f_beta)`` for a binary segmentation.

    Args:
        gt: Ground-truth mask tensor.
        pred: Raw logits; they are passed through a sigmoid and thresholded
            at 0.5.
        smooth: Small constant added to every denominator.
    """
    mask = (torch.sigmoid(pred) > 0.5).type(torch.float32)

    overlap = torch.sum(gt * mask)
    gt_area = torch.sum(gt)
    pred_area = torch.sum(mask)
    union = gt_area + pred_area - overlap

    precision = overlap / (pred_area + smooth)
    recall = overlap / (gt_area + smooth)

    # F-measure with beta^2 = 0.3.
    beta_square = 0.3
    f_beta_coeff = (1 + beta_square) * precision * recall / (beta_square * precision + recall + smooth)
    dice_coeff = (2. * overlap) / (union + overlap + smooth)
    jaccard_coeff = overlap / (union + smooth)
    return dice_coeff, jaccard_coeff, f_beta_coeff

def acc(pred, gt, n_frames=30):
    """Per-class frame accuracy for labels 0, 1 and 2.

    Args:
        pred: Score tensor of shape ``(frames, classes)``; the predicted
            class of a frame is the argmax over dim 1.
        gt: Tensor of frame labels; frames whose label is not 0, 1 or 2 are
            ignored.
        n_frames: Number of leading frames to evaluate (default 30, the
            previously hard-coded value).

    Returns:
        ``(acc_label1, acc_label0, acc_label2)``. The label-0 and label-2
        accuracies carry a 0.002 smoothing term in the denominator; the
        label-1 accuracy is exact and is ``0.0`` when no label-1 frame is
        present (this case used to raise ``ZeroDivisionError``).
    """
    right, error = 0, 0
    right_es, error_es = 0, 0
    right_ed, error_ed = 0, 0
    for i in range(n_frames):
        label = float(gt[i:i+1])
        if label not in (0.0, 1.0, 2.0):
            continue
        hit = label == float(torch.max(pred[i:i+1], 1)[1])
        if label == 0:
            right_es, error_es = right_es + hit, error_es + (not hit)
        elif label == 1:
            right, error = right + hit, error + (not hit)
        else:
            right_ed, error_ed = right_ed + hit, error_ed + (not hit)

    total = right + error
    acc_mid = right / total if total else 0.0
    return acc_mid, right_es/(right_es+error_es+0.002), right_ed/(right_ed+error_ed+0.002)
    

def get_hausdorff(gt, pred, sampling=0.3, connectivity=1):
    """Return ``(hausdorff_distance, mean_abs_distance)`` between two masks.

    Args:
        gt: Ground-truth mask tensor (non-zero means foreground).
        pred: Raw logits; they are passed through a sigmoid and thresholded
            at 0.5.
        sampling: Element spacing handed to the distance transform.
        connectivity: Connectivity of the structuring element used to
            extract the mask borders.

    Raises:
        ValueError: If neither mask has any border element, because the
            maximum of an empty distance set is undefined.
    """
    # Imported here so the function does not rely on the deprecated
    # ``scipy.ndimage.morphology`` namespace.
    from scipy import ndimage

    pred_mask = torch.gt(torch.sigmoid(pred), 0.5)

    # ``detach`` lets the metric be computed on tensors that track
    # gradients; the builtin ``bool`` replaces the ``np.bool`` alias, which
    # is missing from NumPy 1.24-1.26.
    gt_np = np.atleast_1d(gt.detach().cpu().numpy().astype(bool))
    pred_np = np.atleast_1d(pred_mask.detach().cpu().numpy().astype(bool))

    conn = ndimage.generate_binary_structure(gt_np.ndim, connectivity)

    # Border = mask minus its erosion.
    S = gt_np ^ ndimage.binary_erosion(gt_np, conn)
    Sprime = pred_np ^ ndimage.binary_erosion(pred_np, conn)

    # Distance of every element to the nearest border element of each mask.
    dta = ndimage.distance_transform_edt(~S, sampling)
    dtb = ndimage.distance_transform_edt(~Sprime, sampling)

    sds = np.concatenate([np.ravel(dta[Sprime != 0]), np.ravel(dtb[S != 0])])
    hausdorff_distance = sds.max()
    mean_abs_distance = np.abs(sds).mean()
    return hausdorff_distance, mean_abs_distance
    
def seg_loss(y_pred,y_true):
    """Segmentation loss: weighted cross-entropy + balanced Dice + soft MAE term.

    Args:
        y_pred: Raw logits; a sigmoid is applied first.
        y_true: Binary ground-truth mask with the same shape.

    Reads the module-level ``img_size`` to weight the foreground term of the
    cross-entropy.
    """
    smooth = 1e-12
    prob = torch.sigmoid(y_pred)
    fg_true, bg_true = y_true, 1 - y_true
    fg_prob, bg_prob = prob, 1 - prob

    # Dice over foreground and background, each weighted by the inverse
    # squared size of its ground-truth region.
    alpha = 1 / (torch.pow(torch.sum(fg_true), 2) + smooth)
    beta = 1 / (torch.pow(torch.sum(bg_true), 2) + smooth)
    numerater = alpha * torch.sum(fg_true * fg_prob) + beta * torch.sum(bg_true * bg_prob)
    denominator = alpha * torch.sum(fg_true + fg_prob) + beta * torch.sum(bg_true + bg_prob)
    dice_loss = 1 - (2. * numerater) / (denominator + smooth)

    # Softplus of the absolute error, averaged over all elements.
    mae_loss = torch.mean(torch.log(1 + torch.exp(torch.abs(fg_prob - fg_true))))

    # Foreground weight: (image area - predicted mass) / predicted mass,
    # damped by a fixed factor.
    predicted_mass = torch.sum(prob)
    w = (img_size * img_size - predicted_mass) / (predicted_mass + smooth)
    key_w = 0.003
    crossentropy = - torch.mean(
        key_w * w * fg_true * torch.log(fg_prob + smooth)
        + bg_true * torch.log(bg_prob + smooth))

    return crossentropy + dice_loss + mae_loss

def one_hot(label, n_classes, requires_grad=True):
    """Return the one-hot encoding of an integer label tensor.

    Args:
        label: Integer tensor of class indices in ``[0, n_classes)``.
        n_classes: Number of classes (length of the new last dimension).
        requires_grad: Whether the identity matrix that is indexed should
            require gradients.

    Returns:
        Float tensor of shape ``label.shape + (n_classes,)`` on the same
        device as ``label``.
    """
    # Take the device from the label itself; the lookup used to depend on a
    # module-level ``device`` name instead of the device computed here.
    identity = torch.eye(n_classes, device=label.device, requires_grad=requires_grad)
    return identity[label]

def boundary_cos_loss(gt , pred):
    """Cosine-similarity loss between ground-truth and predicted contours.

    For every frame a contour map is built as
    ``max_pool2d(1 - x, 3) - (1 - x)``; the row-wise cosine similarity of the
    two contour maps is summed, divided by the height and turned into
    ``1 - 2 * similarity``. The per-frame values are averaged over frames.

    Args:
        gt: Ground-truth tensor unpacked as (B, T, C, H, W).
        pred: Raw logits with the same layout; sigmoid is applied here.

    Returns:
        Tensor holding the frame-averaged loss. Only batch index 0 is read.
    """
    pred = torch.sigmoid(pred)
    _, t, _, h, _ = gt.shape
    kernel = 3
    pad = (kernel - 1) // 2
    for idx in range(t):
        gt_inv = 1 - gt[0, idx:idx + 1]
        pred_inv = 1 - pred[0, idx:idx + 1]
        # Dilating the inverted mask and subtracting it leaves the contour.
        gt_edge = F.max_pool2d(gt_inv, kernel_size=kernel, stride=1, padding=pad) - gt_inv
        pred_edge = F.max_pool2d(pred_inv, kernel_size=kernel, stride=1, padding=pad) - pred_inv
        row_sim = torch.cosine_similarity(gt_edge[0].squeeze(0), pred_edge[0].squeeze(0))
        frame_loss = 1 - (row_sim.sum() / h) * 2
        loss = frame_loss if idx == 0 else loss + frame_loss
    return loss / t

def cos_sim_loss(gt , pred):
    """Row-wise cosine-similarity loss between masks and predictions.

    Per frame, the cosine similarity between matching rows of the
    ground-truth and the sigmoid of the prediction is summed, divided by the
    height and mapped to ``1 - 2 * similarity``; frames are then averaged.

    Args:
        gt: Ground-truth tensor unpacked as (B, T, C, H, W).
        pred: Raw logits with the same layout; sigmoid is applied here.

    Returns:
        Tensor holding the frame-averaged loss. Only batch index 0 is read.
    """
    pred = torch.sigmoid(pred)
    _, t, _, h, _ = gt.shape
    for idx in range(t):
        gt_map = gt[0, idx:idx + 1][0].squeeze(0)
        pred_map = pred[0, idx:idx + 1][0].squeeze(0)
        row_sim = torch.cosine_similarity(gt_map, pred_map)
        frame_loss = 1 - (row_sim.sum() / h) * 2
        loss = frame_loss if idx == 0 else loss + frame_loss
    return loss / t

def corr_loss(pred,gt):
    """Correlation loss over the first three rows of ``pred`` and ``gt``.

    A Pearson correlation is computed for row 0 (named ``lva`` in this
    function), row 1 (``mvd``) and row 2 (``lvd``). The rows are combined
    with weights 1, 2 and 2, and the result is subtracted from 5, so a
    perfect correlation on all three rows gives a loss close to 0.

    Args:
        pred: Tensor indexed as ``pred[row, :]`` with at least three rows.
        gt: Target tensor with the same layout.

    Returns:
        Scalar tensor ``5 - (corr_lva + 2*corr_mvd + 2*corr_lvd + 1e-12)``.
    """
    def _pearson(a, b):
        # Centre both series, then normalise the covariance; the 1e-12 guards
        # against a zero denominator for constant series.
        a_c = a - torch.mean(a)
        b_c = b - torch.mean(b)
        scale = torch.sqrt(torch.sum(a_c ** 2)) * torch.sqrt(torch.sum(b_c ** 2))
        return torch.sum(a_c * b_c) / (scale + 1e-12)

    corr_lva = _pearson(pred[0, :], gt[0, :])
    corr_mvd = _pearson(pred[1, :], gt[1, :])
    corr_lvd = _pearson(pred[2, :], gt[2, :])
    weighted = corr_lva + 2 * corr_mvd + 2 * corr_lvd + 1e-12
    return 5 - weighted

def mae_loss(pred,gt):
    """Weighted mean-absolute-error loss over rows 1 and 2.

    Args:
        pred: Tensor indexed as ``pred[row, :]`` with at least three rows.
        gt: Target tensor with the same layout.

    Returns:
        Scalar tensor ``MAE(row 1) + 2 * MAE(row 2)``.
    """
    # NOTE(review): row 0 does not contribute to the loss (the original
    # computed its MAE but never used it) -- confirm this is intentional.
    mae2 = torch.mean(torch.abs(pred[1,:]-gt[1,:]))
    mae3 = torch.mean(torch.abs(pred[2,:]-gt[2,:]))
    return mae2+mae3*2

def mae_point(pred,gt):
    """Return the mean absolute error between ``pred`` and ``gt``.

    Args:
        pred: Predicted tensor.
        gt: Target tensor broadcastable against ``pred``.

    Returns:
        Scalar tensor with ``mean(|pred - gt|)``.
    """
    # The original also evaluated a log-cosh term that was never returned;
    # that dead computation has been removed.
    return torch.mean(torch.abs(pred-gt))

def mae_cal(pred,gt):
    """Weighted mean absolute error over rows 1 and 2 (metric variant).

    Args:
        pred: Tensor indexed as ``pred[row, :]`` with at least three rows.
        gt: Target tensor with the same layout.

    Returns:
        Scalar tensor ``MAE(row 1) + 2 * MAE(row 2)``.
    """
    # NOTE(review): row 0 is ignored, and the original's unused row-0 MAE and
    # log-cosh computations were dead code; both have been removed.
    mae2 = torch.mean(torch.abs(pred[1,:]-gt[1,:]))
    mae3 = torch.mean(torch.abs(pred[2,:]-gt[2,:]))
    return mae2+mae3*2


def person_corr(pred, gt):  # Pearson correlation coefficient
    """Pearson correlation of the first three rows of ``pred`` and ``gt``.

    Args:
        pred: Tensor indexed as ``pred[row, :]`` with at least three rows.
        gt: Target tensor with the same layout.

    Returns:
        Tuple ``(corr_lva, corr_mvd, corr_lvd)`` holding the correlation of
        rows 0, 1 and 2 respectively, each a scalar tensor.
    """
    def _pearson(a, b):
        # 1e-12 keeps the ratio finite when either series is constant.
        a_c = a - torch.mean(a)
        b_c = b - torch.mean(b)
        scale = torch.sqrt(torch.sum(a_c ** 2)) * torch.sqrt(torch.sum(b_c ** 2))
        return torch.sum(a_c * b_c) / (scale + 1e-12)

    corr_lva = _pearson(pred[0, :], gt[0, :])
    corr_mvd = _pearson(pred[1, :], gt[1, :])
    corr_lvd = _pearson(pred[2, :], gt[2, :])
    return corr_lva, corr_mvd, corr_lvd


def point2linear(point, num_frames=30, min_ratio=0.6):
    """Convert per-frame landmark pairs into two distance sequences.

    For every frame ``t`` the Euclidean distance between landmarks 0 and 1
    (``mvd``) and between landmarks 2 and 3 (``lvd``) is computed. An ``mvd``
    value that falls below ``min_ratio`` times the previous frame's
    (already corrected) value is replaced by that previous value.

    Args:
        point: Array indexed as ``point[t, landmark, coord]`` with at least
            ``num_frames`` frames, four landmarks and two coordinates.
        num_frames: Number of leading frames to process. Defaults to 30, the
            value that was previously hard-coded.
        min_ratio: Drop threshold relative to the previous ``mvd``. Defaults
            to 0.6, the value that was previously hard-coded.

    Returns:
        Tuple ``(mvdt, lvdt)`` of two lists with ``num_frames`` distances.
    """
    mvdt = []
    lvdt = []
    for t in range(num_frames):
        mvd = ((point[t,0,0]-point[t,1,0])**2+(point[t,0,1]-point[t,1,1])**2)**0.5
        lvd = ((point[t,2,0]-point[t,3,0])**2+(point[t,2,1]-point[t,3,1])**2)**0.5
        # Suppress sudden collapses of the first distance by holding the
        # previous value.
        if t > 0 and mvd < mvdt[t-1]*min_ratio:
            mvd = mvdt[t-1]
        mvdt.append(mvd)
        lvdt.append(lvd)
    return mvdt,lvdt

def pot2index(srcnpy):
    """Build a distance-index array from the first sample of ``srcnpy``.

    Only ``srcnpy[0]`` is passed to ``point2linear``; the two returned
    sequences are stacked along a new trailing axis.

    Args:
        srcnpy: Landmark array whose first entry is accepted by
            ``point2linear``.

    Returns:
        Array of shape ``(1, n_frames, 2)`` where ``[..., 0]`` holds the
        ``lvd`` sequence and ``[..., 1]`` the ``mvd`` sequence.
    """
    mvd_seq, lvd_seq = point2linear(srcnpy[0])
    # Wrap each sequence in a length-1 batch axis and add a channel axis.
    lvd_col = np.array([lvd_seq])[:, :, np.newaxis]
    mvd_col = np.array([mvd_seq])[:, :, np.newaxis]
    return np.concatenate((lvd_col, mvd_col), axis=2)

def mae_div(pred,gt):
    """Mean absolute error of the two distance indices from landmarks.

    Both tensors are moved to the CPU, detached, converted to NumPy and
    passed through ``pot2index`` before being compared.

    Args:
        pred: Predicted landmark tensor.
        gt: Ground-truth landmark tensor.

    Returns:
        Tuple ``(mae1, mae2)``: the MAE of index channel 0 and of index
        channel 1 of the ``pot2index`` output.
    """
    pred_ind = pot2index(pred.cpu().detach().numpy())
    gt_ind = pot2index(gt.cpu().detach().numpy())
    mae1 = np.mean(np.abs(pred_ind[..., 0] - gt_ind[..., 0]))
    mae2 = np.mean(np.abs(pred_ind[..., 1] - gt_ind[..., 1]))
    return mae1, mae2



In [None]:
def l2_norm(seg , pot):
    """Ratio relating the point map to the segmentation border band.

    For a frame, a border band is built from the segmentation as
    ``max_pool2d(1 - s, 9) - (1 - s)``. ``potsum`` counts pixels where
    ``pot / 3`` does not exceed the band value, ``segsum`` counts pixels
    where the band is positive, and the ratio ``potsum / segsum`` is
    returned.

    Args:
        seg: Segmentation logits unpacked as (B, T, C, H, W).
        pot: Point-map logits with the same layout.

    Returns:
        Scalar tensor ``potsum / segsum``. Only batch index 0 is read.
    """
    # NOTE(review): the counts are overwritten on every iteration, so only
    # the LAST frame determines the result -- confirm whether an aggregate
    # over frames was intended.
    seg = torch.sigmoid(seg)
    pot = torch.sigmoid(pot)
    _, t, _, _, _ = seg.shape
    kernel = 9
    pad = (kernel - 1) // 2
    for idx in range(t):
        seg_inv = 1 - seg[0, idx:idx + 1]
        seg_band = F.max_pool2d(seg_inv, kernel_size=kernel, stride=1, padding=pad) - seg_inv
        pot_map = pot[0, idx:idx + 1].squeeze(0).squeeze(0).to(torch.float32)
        band_map = seg_band.squeeze(0).squeeze(0).to(torch.float32)
        above_band = torch.gt(pot_map / 3 - band_map, 0).to(torch.float32)
        potsum = torch.sum(1 - above_band)
        segsum = torch.sum(torch.gt(seg_band, 0).to(torch.float32))
    return potsum/segsum

def cal_a(mask):
    """Return the standardised per-frame sum of a mask sequence.

    The mask is reshaped to ``(batch, 30, 1, 128*128)``, summed over the
    last axis, and the resulting values are shifted by their global mean and
    divided by their global standard deviation.

    Args:
        mask: Tensor holding ``batch * 30 * 128 * 128`` elements, with the
            batch on axis 0.

    Returns:
        Tensor of shape ``(batch, 30, 1)``.
    """
    flat = mask.reshape(mask.shape[0], 30, 1, 128 * 128)
    per_frame = flat.sum(dim=3)
    centred = per_frame - per_frame.mean()
    return centred / per_frame.std()


def cosine_loss(input,target):
    """Return ``1 - cosine_similarity`` between two column vectors.

    Axis 1 of each argument is squeezed, a leading axis is added, and the
    cosine similarity is taken along ``torch.cosine_similarity``'s default
    dimension.

    Args:
        input: Tensor with at least two axes; axis 1 is squeezed.
        target: Tensor with the same layout as ``input``.

    Returns:
        Tensor with ``1 - similarity``.
    """
    lhs = input.squeeze(1).unsqueeze(0)
    rhs = target.squeeze(1).unsqueeze(0)
    return 1 - torch.cosine_similarity(lhs, rhs)

In [None]:
# training set
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OrdinalEncoder
import numpy.matlib
def landmark(center_x,center_y,IMAGE_HEIGHT, IMAGE_WIDTH):
    """Render a single landmark as a radially decaying heat map.

    Each pixel gets ``exp(-0.5 * d / R)`` where ``d`` is its Euclidean
    distance to ``(center_x, center_y)`` and ``R = sqrt(2**2 + 2**2)``. The
    decay is exponential in the distance, not in its square.

    Args:
        center_x: Landmark column coordinate.
        center_y: Landmark row coordinate.
        IMAGE_HEIGHT: Number of rows of the output map.
        IMAGE_WIDTH: Number of columns of the output map.

    Returns:
        Array of shape ``(IMAGE_HEIGHT, IMAGE_WIDTH)`` with value 1 at the
        landmark and decaying values elsewhere.
    """
    R = np.sqrt(2**2 + 2**2)
    # Broadcasting a row of x indices against a column of y indices replaces
    # the four full-size np.matlib.repmat arrays of the original.
    x_map = np.arange(IMAGE_WIDTH)[np.newaxis, :]
    y_map = np.arange(IMAGE_HEIGHT)[:, np.newaxis]
    dist = np.sqrt((x_map-center_x)**2+(y_map-center_y)**2)
    return np.exp(-0.5*dist/R)

def locmap(pot):
    """Render landmark heat maps for every clip and frame.

    For each ``pot[i, j]`` the four landmarks are rendered with ``landmark``
    on a 128x128 grid and averaged into one map.

    Args:
        pot: Array indexed as ``pot[clip, frame, landmark, coord]`` with four
            landmarks; coordinate 0 is passed as ``center_x`` and coordinate
            1 as ``center_y``.

    Returns:
        Array of shape ``(clips, frames, 128, 128, 1)``.
    """
    batch_maps = []
    for clip in range(pot.shape[0]):
        clip_maps = []
        for frm_idx in range(pot.shape[1]):
            heat = landmark(pot[clip, frm_idx, 0, 0], pot[clip, frm_idx, 0, 1], 128, 128)
            for k in range(1, 4):
                heat = heat + landmark(pot[clip, frm_idx, k, 0], pot[clip, frm_idx, k, 1], 128, 128)
            clip_maps.append(heat / 4)
        batch_maps.append(clip_maps)
    # Append a trailing channel axis.
    return np.array(batch_maps)[:, :, :, :, np.newaxis]

# ---- "nature" dataset (A4C only) -------------------------------------------
# Loads entries [train_set_down:train_set_up] of each .npy file. Image (ims)
# and mask (gts) arrays get a new trailing axis, so the files must hold 4-D
# arrays. pot / potmap / frm are only sliced on their first axis.
data_path = '/data/zhangzhenxuan/nature_data'
train_set_down = 0
train_set_up = 1800
nat_a4c_ims  = np.load(data_path + '/' + 'ims_a4c_1.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
nat_a4c_gts  = np.load(data_path + '/' + 'gts_a4c_1.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
# Constant class id 4 for every sample (presumably the A4C view -- confirm).
# NOTE(review): the length is train_set_up, not the number of rows actually
# loaded; they only match if train_set_down is 0 and the file is long enough.
nat_a4c_cls  = 4*np.ones((train_set_up))
nat_a4c_pot  = np.load(data_path + '/' + 'pot_a4c_1.npy')[train_set_down:train_set_up, :]
nat_a4c_potmap  = np.load(data_path + '/' + 'potmap_a4c_1.npy')[train_set_down:train_set_up, :]
nat_a4c_frm  = np.load(data_path + '/' + 'frm_a4c_1.npy')[train_set_down:train_set_up, :]
print('==============load nature===============')

# ---- HMC-QU dataset (A4C only) ---------------------------------------------
# Same layout as the other loaders: the first train_set_up entries are kept,
# and image / mask arrays receive a trailing axis.
data_path = '/data/zhangzhenxuan/HMC_QU_data'
train_set_down = 0
train_set_up = 100
hmc_a4c_ims  = np.load(data_path + '/' + 'ims_a4c_hmc.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
hmc_a4c_gts  = np.load(data_path + '/' + 'gts_a4c_hmc.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
# Constant class id 4 for every sample (presumably the A4C view -- confirm).
hmc_a4c_cls  = 4*np.ones((train_set_up))
hmc_a4c_pot  = np.load(data_path + '/' + 'pot_a4c_hmc.npy')[train_set_down:train_set_up, :]
hmc_a4c_potmap  = np.load(data_path + '/' + 'potmap_a4c_hmc.npy')[train_set_down:train_set_up, :]
hmc_a4c_frm  = np.load(data_path + '/' + 'frm_a4c_hmc.npy')[train_set_down:train_set_up, :]
print('==============load hmc===============')

# ---- CAMUS dataset (A2C and A4C) -------------------------------------------
# The first train_set_up entries of each file are kept; image / mask arrays
# receive a trailing axis.
data_path = '/data/zhangzhenxuan/camus_data'
train_set_down = 0
train_set_up = 500
camus_a2c_ims  = np.load(data_path + '/' + 'ims_a2c_camus.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
camus_a4c_ims  = np.load(data_path + '/' + 'ims_a4c_camus.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
camus_a2c_gts  = np.load(data_path + '/' + 'gts_a2c_camus.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
camus_a4c_gts  = np.load(data_path + '/' + 'gts_a4c_camus.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
# Constant class ids: 2 for the a2c arrays, 4 for the a4c arrays.
camus_a2c_cls  = 2*np.ones((train_set_up))
camus_a4c_cls  = 4*np.ones((train_set_up))
camus_a2c_pot  = np.load(data_path + '/' + 'pot_a2c_camus.npy')[train_set_down:train_set_up, :]
camus_a4c_pot  = np.load(data_path + '/' + 'pot_a4c_camus.npy')[train_set_down:train_set_up, :]
camus_a2c_potmap  = np.load(data_path + '/' + 'potmap_a2c_camus.npy')[train_set_down:train_set_up, :]
camus_a4c_potmap  = np.load(data_path + '/' + 'potmap_a4c_camus.npy')[train_set_down:train_set_up, :]
camus_a2c_frm  = np.load(data_path + '/' + 'frm_a2c_camus.npy')[train_set_down:train_set_up, :]
camus_a4c_frm  = np.load(data_path + '/' + 'frm_a4c_camus.npy')[train_set_down:train_set_up, :]
print('==============load camus===============')

# ---- "lm" dataset (A2C, A3C and A4C) ---------------------------------------
# The first train_set_up entries of each file are kept; image / mask arrays
# receive a trailing axis.
data_path = '/data/zhangzhenxuan/lm_data'
train_set_down = 0
train_set_up = 122
lm_a2c_ims  = np.load(data_path + '/' + 'ims_a2c_lm.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
lm_a3c_ims  = np.load(data_path + '/' + 'ims_a3c_lm.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
lm_a4c_ims  = np.load(data_path + '/' + 'ims_a4c_lm.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
lm_a2c_gts  = np.load(data_path + '/' + 'gts_a2c_lm.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
lm_a3c_gts  = np.load(data_path + '/' + 'gts_a3c_lm.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
lm_a4c_gts  = np.load(data_path + '/' + 'gts_a4c_lm.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
# Constant class ids: 2 / 3 / 4 for the a2c / a3c / a4c arrays.
lm_a2c_cls  = 2*np.ones((train_set_up))
lm_a3c_cls  = 3*np.ones((train_set_up))
lm_a4c_cls  = 4*np.ones((train_set_up))
lm_a2c_pot  = np.load(data_path + '/' + 'pot_a2c_lm.npy')[train_set_down:train_set_up, :]
lm_a3c_pot  = np.load(data_path + '/' + 'pot_a3c_lm.npy')[train_set_down:train_set_up, :]
lm_a4c_pot  = np.load(data_path + '/' + 'pot_a4c_lm.npy')[train_set_down:train_set_up, :]
lm_a2c_potmap  = np.load(data_path + '/' + 'potmap_a2c_lm.npy')[train_set_down:train_set_up, :]
lm_a3c_potmap  = np.load(data_path + '/' + 'potmap_a3c_lm.npy')[train_set_down:train_set_up, :]
lm_a4c_potmap  = np.load(data_path + '/' + 'potmap_a4c_lm.npy')[train_set_down:train_set_up, :]
lm_a2c_frm  = np.load(data_path + '/' + 'frm_a2c_lm.npy')[train_set_down:train_set_up, :]
lm_a3c_frm  = np.load(data_path + '/' + 'frm_a3c_lm.npy')[train_set_down:train_set_up, :]
lm_a4c_frm  = np.load(data_path + '/' + 'frm_a4c_lm.npy')[train_set_down:train_set_up, :]
print('==============load lm===============')

# ---- "mx" dataset (A2C, A3C and A4C) ---------------------------------------
# The first train_set_up entries of each file are kept; image / mask arrays
# receive a trailing axis.
data_path = '/data/zhangzhenxuan/mx_data'
train_set_down = 0
train_set_up = 80
mx_a2c_ims  = np.load(data_path + '/' + 'ims_a2c_mx.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
mx_a3c_ims  = np.load(data_path + '/' + 'ims_a3c_mx.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
mx_a4c_ims  = np.load(data_path + '/' + 'ims_a4c_mx.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
mx_a2c_gts  = np.load(data_path + '/' + 'gts_a2c_mx.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
mx_a3c_gts  = np.load(data_path + '/' + 'gts_a3c_mx.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
mx_a4c_gts  = np.load(data_path + '/' + 'gts_a4c_mx.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
# Constant class ids: 2 / 3 / 4 for the a2c / a3c / a4c arrays.
mx_a2c_cls  = 2*np.ones((train_set_up))
mx_a3c_cls  = 3*np.ones((train_set_up))
mx_a4c_cls  = 4*np.ones((train_set_up))
mx_a2c_pot  = np.load(data_path + '/' + 'pot_a2c_mx.npy')[train_set_down:train_set_up, :]
mx_a3c_pot  = np.load(data_path + '/' + 'pot_a3c_mx.npy')[train_set_down:train_set_up, :]
mx_a4c_pot  = np.load(data_path + '/' + 'pot_a4c_mx.npy')[train_set_down:train_set_up, :]
mx_a2c_potmap  = np.load(data_path + '/' + 'potmap_a2c_mx.npy')[train_set_down:train_set_up, :]
mx_a3c_potmap  = np.load(data_path + '/' + 'potmap_a3c_mx.npy')[train_set_down:train_set_up, :]
mx_a4c_potmap  = np.load(data_path + '/' + 'potmap_a4c_mx.npy')[train_set_down:train_set_up, :]
mx_a2c_frm  = np.load(data_path + '/' + 'frm_a2c_mx.npy')[train_set_down:train_set_up, :]
mx_a3c_frm  = np.load(data_path + '/' + 'frm_a3c_mx.npy')[train_set_down:train_set_up, :]
mx_a4c_frm  = np.load(data_path + '/' + 'frm_a4c_mx.npy')[train_set_down:train_set_up, :]
print('==============load mx===============')

# ---- "szkid" dataset (A2C, A3C, A4C plus an "asc" view) --------------------
# The first train_set_up entries of each file are kept; image / mask arrays
# receive a trailing axis.
# NOTE(review): for the "asc" view only images, masks and class ids are
# loaded here -- there is no sz_asc_pot / sz_asc_potmap / sz_asc_frm.
data_path = '/data/zhangzhenxuan/szkid'
train_set_down = 0
train_set_up = 80
sz_a2c_ims  = np.load(data_path + '/' + 'ims_a2c.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
sz_a3c_ims  = np.load(data_path + '/' + 'ims_a3c.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
sz_a4c_ims  = np.load(data_path + '/' + 'ims_a4c.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
sz_asc_ims  = np.load(data_path + '/' + 'ims_asc.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
sz_a2c_gts  = np.load(data_path + '/' + 'gts_a2c.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
sz_a3c_gts  = np.load(data_path + '/' + 'gts_a3c.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
sz_a4c_gts  = np.load(data_path + '/' + 'gts_a4c.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
sz_asc_gts  = np.load(data_path + '/' + 'gts_asc.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
sz_out_gts  = np.load(data_path + '/' + 'gts_out.npy')[train_set_down:train_set_up, :, :, :, np.newaxis]
# Constant class ids: 2 / 3 / 4 for a2c / a3c / a4c and 1 for asc.
sz_a2c_cls  = 2*np.ones((train_set_up))
sz_a3c_cls  = 3*np.ones((train_set_up))
sz_a4c_cls  = 4*np.ones((train_set_up))
sz_asc_cls  = 1*np.ones((train_set_up))
sz_a2c_pot  = np.load(data_path + '/' + 'pot_a2c.npy')[train_set_down:train_set_up, :]
sz_a3c_pot  = np.load(data_path + '/' + 'pot_a3c.npy')[train_set_down:train_set_up, :]
sz_a4c_pot  = np.load(data_path + '/' + 'pot_a4c.npy')[train_set_down:train_set_up, :]
sz_a2c_potmap  = np.load(data_path + '/' + 'potmap_a2c.npy')[train_set_down:train_set_up, :]
sz_a3c_potmap  = np.load(data_path + '/' + 'potmap_a3c.npy')[train_set_down:train_set_up, :]
sz_a4c_potmap  = np.load(data_path + '/' + 'potmap_a4c.npy')[train_set_down:train_set_up, :]
sz_a2c_frm  = np.load(data_path + '/' + 'frm_a2c.npy')[train_set_down:train_set_up, :]
sz_a3c_frm  = np.load(data_path + '/' + 'frm_a3c.npy')[train_set_down:train_set_up, :]
sz_a4c_frm  = np.load(data_path + '/' + 'frm_a4c.npy')[train_set_down:train_set_up, :]
print('==============load sz===============')

# ims         = np.concatenate((hmc_a4c_ims,camus_a2c_ims, camus_a4c_ims,mx_a2c_ims, mx_a3c_ims, mx_a4c_ims,
#                               lm_a2c_ims, lm_a3c_ims, lm_a4c_ims,sz_a2c_ims, sz_a3c_ims, sz_a4c_ims, sz_asc_ims), axis=0)
# gts         = np.concatenate((hmc_a4c_gts,camus_a2c_gts, camus_a4c_gts,mx_a2c_gts, mx_a3c_gts, mx_a4c_gts,
#                               lm_a2c_gts, lm_a3c_gts, lm_a4c_gts,sz_a2c_gts, sz_a3c_gts, sz_a4c_gts, sz_out_gts), axis=0)
# clss         = np.concatenate((hmc_a4c_cls,camus_a2c_cls, camus_a4c_cls,mx_a2c_cls, mx_a3c_cls, mx_a4c_cls,
#                              lm_a2c_cls, lm_a3c_cls, lm_a4c_cls,sz_a2c_cls, sz_a3c_cls, sz_a4c_cls, sz_asc_cls), axis=0)
# ims         = np.concatenate((hmc_a4c_ims,camus_a2c_ims, camus_a4c_ims,mx_a2c_ims, mx_a3c_ims, mx_a4c_ims,
#                               lm_a2c_ims, lm_a3c_ims, lm_a4c_ims,sz_a2c_ims, sz_a3c_ims, sz_a4c_ims), axis=0)
# gts         = np.concatenate((hmc_a4c_gts,camus_a2c_gts, camus_a4c_gts,mx_a2c_gts, mx_a3c_gts, mx_a4c_gts,
#                               lm_a2c_gts, lm_a3c_gts, lm_a4c_gts,sz_a2c_gts, sz_a3c_gts, sz_a4c_gts), axis=0)
# clss        = np.concatenate((hmc_a4c_cls,camus_a2c_cls, camus_a4c_cls,mx_a2c_cls, mx_a3c_cls, mx_a4c_cls,
#                               lm_a2c_cls, lm_a3c_cls, lm_a4c_cls,sz_a2c_cls, sz_a3c_cls, sz_a4c_cls), axis=0)
# reg         = np.concatenate((hmc_a4c_reg,camus_a2c_reg, camus_a4c_reg,mx_a2c_reg, mx_a3c_reg, mx_a4c_reg,
#                               lm_a2c_reg, lm_a3c_reg, lm_a4c_reg,sz_a2c_reg, sz_a3c_reg, sz_a4c_reg), axis=0)

# Select which of the loaded datasets are pooled into ims / gts / clss / pot /
# potmap / frm. Every branch concatenates along axis 0 (the sample axis), and
# the order of datasets must be identical across all arrays of a branch so
# that samples stay aligned.
data_mode = 'nat'
if data_mode == 'psax':
    # szkid only, including the "asc" view (paired with sz_out_gts masks).
    # NOTE(review): sz_asc_pot and sz_asc_frm are not assigned in the loading
    # code visible above, and this branch never sets `potmap`, which the
    # conversion code below requires -- selecting 'psax' looks like it would
    # raise NameError. Confirm before using this mode.
    ims         = np.concatenate((sz_a2c_ims, sz_a3c_ims, sz_a4c_ims, sz_asc_ims), axis=0)
    gts         = np.concatenate((sz_a2c_gts, sz_a3c_gts, sz_a4c_gts, sz_out_gts), axis=0)
    clss         = np.concatenate((sz_a2c_cls, sz_a3c_cls, sz_a4c_cls, sz_asc_cls), axis=0)
    pot         = np.concatenate((sz_a2c_pot, sz_a3c_pot, sz_a4c_pot, sz_asc_pot), axis=0)
    frm         = np.concatenate((sz_a2c_frm, sz_a3c_frm, sz_a4c_frm, sz_asc_frm), axis=0)
elif data_mode == 'apic_1':
    # szkid A2C/A3C/A4C only; the point maps are rendered from the landmarks
    # with locmap() instead of being read from the potmap files.
    ims         = np.concatenate((sz_a2c_ims, sz_a3c_ims, sz_a4c_ims), axis=0)
    gts         = np.concatenate((sz_a2c_gts, sz_a3c_gts, sz_a4c_gts), axis=0)
    clss         = np.concatenate((sz_a2c_cls, sz_a3c_cls, sz_a4c_cls), axis=0)
    pot         = np.concatenate((sz_a2c_pot, sz_a3c_pot, sz_a4c_pot), axis=0)
    potmap      = locmap(pot)
    frm         = np.concatenate((sz_a2c_frm, sz_a3c_frm, sz_a4c_frm), axis=0) 
elif data_mode == 'apic_2':
    # hmc + camus + mx + lm + szkid (everything except the "nature" set).
    ims         = np.concatenate((hmc_a4c_ims,camus_a2c_ims, camus_a4c_ims,mx_a2c_ims, mx_a3c_ims, mx_a4c_ims,
                                  lm_a2c_ims, lm_a3c_ims, lm_a4c_ims,sz_a2c_ims, sz_a3c_ims, sz_a4c_ims), axis=0)
    print(ims.shape)
    gts         = np.concatenate((hmc_a4c_gts,camus_a2c_gts, camus_a4c_gts,mx_a2c_gts, mx_a3c_gts, mx_a4c_gts,
                                  lm_a2c_gts, lm_a3c_gts, lm_a4c_gts,sz_a2c_gts, sz_a3c_gts, sz_a4c_gts), axis=0)
    print(gts.shape)
    clss        = np.concatenate((hmc_a4c_cls,camus_a2c_cls, camus_a4c_cls,mx_a2c_cls, mx_a3c_cls, mx_a4c_cls,
                                  lm_a2c_cls, lm_a3c_cls, lm_a4c_cls,sz_a2c_cls, sz_a3c_cls, sz_a4c_cls), axis=0)
    print(clss.shape)
    pot         = np.concatenate((hmc_a4c_pot,camus_a2c_pot, camus_a4c_pot,mx_a2c_pot, mx_a3c_pot, mx_a4c_pot,
                                  lm_a2c_pot, lm_a3c_pot, lm_a4c_pot,sz_a2c_pot, sz_a3c_pot, sz_a4c_pot), axis=0)
    potmap      = np.concatenate((hmc_a4c_potmap,camus_a2c_potmap, camus_a4c_potmap,mx_a2c_potmap, mx_a3c_potmap, mx_a4c_potmap,
                                  lm_a2c_potmap, lm_a3c_potmap, lm_a4c_potmap,sz_a2c_potmap, sz_a3c_potmap, sz_a4c_potmap), axis=0)
    print(pot.shape,potmap.shape)
    #potmap      = locmap(pot)
    frm         = np.concatenate((hmc_a4c_frm,camus_a2c_frm, camus_a4c_frm,mx_a2c_frm, mx_a3c_frm, mx_a4c_frm,
                                  lm_a2c_frm, lm_a3c_frm, lm_a4c_frm,sz_a2c_frm, sz_a3c_frm, sz_a4c_frm), axis=0) 
elif data_mode == 'nat':
    # Same pool as 'apic_2' with the "nature" A4C set prepended.
    ims         = np.concatenate((nat_a4c_ims,hmc_a4c_ims,camus_a2c_ims, camus_a4c_ims,mx_a2c_ims, mx_a3c_ims, mx_a4c_ims,
                                  lm_a2c_ims, lm_a3c_ims, lm_a4c_ims,sz_a2c_ims, sz_a3c_ims, sz_a4c_ims), axis=0)
    print(ims.shape)
    gts         = np.concatenate((nat_a4c_gts,hmc_a4c_gts,camus_a2c_gts, camus_a4c_gts,mx_a2c_gts, mx_a3c_gts, mx_a4c_gts,
                                  lm_a2c_gts, lm_a3c_gts, lm_a4c_gts,sz_a2c_gts, sz_a3c_gts, sz_a4c_gts), axis=0)
    print(gts.shape)
    clss        = np.concatenate((nat_a4c_cls,hmc_a4c_cls,camus_a2c_cls, camus_a4c_cls,mx_a2c_cls, mx_a3c_cls, mx_a4c_cls,
                                  lm_a2c_cls, lm_a3c_cls, lm_a4c_cls,sz_a2c_cls, sz_a3c_cls, sz_a4c_cls), axis=0)
    print(clss.shape)
    pot         = np.concatenate((nat_a4c_pot,hmc_a4c_pot,camus_a2c_pot, camus_a4c_pot,mx_a2c_pot, mx_a3c_pot, mx_a4c_pot,
                                  lm_a2c_pot, lm_a3c_pot, lm_a4c_pot,sz_a2c_pot, sz_a3c_pot, sz_a4c_pot), axis=0)
    potmap      = np.concatenate((nat_a4c_potmap,hmc_a4c_potmap,camus_a2c_potmap, camus_a4c_potmap,mx_a2c_potmap, mx_a3c_potmap, mx_a4c_potmap,
                                  lm_a2c_potmap, lm_a3c_potmap, lm_a4c_potmap,sz_a2c_potmap, sz_a3c_potmap, sz_a4c_potmap), axis=0)
    print(pot.shape,potmap.shape)
    #potmap      = locmap(pot)
    frm         = np.concatenate((nat_a4c_frm,hmc_a4c_frm,camus_a2c_frm, camus_a4c_frm,mx_a2c_frm, mx_a3c_frm, mx_a4c_frm,
                                  lm_a2c_frm, lm_a3c_frm, lm_a4c_frm,sz_a2c_frm, sz_a3c_frm, sz_a4c_frm), axis=0) 
    print(frm.shape) 
    
# ims         = camus_a4c_ims
# gts         = camus_a4c_gts
#oe = OrdinalEncoder()
#clss = oe.fit_transform(clss.reshape(-1, 1)).ravel()
#clss = np.eye(4)[np.array(clss, dtype=np.int32)]
# Sanity check: every array must have the same number of samples on axis 0.
print(ims.shape, gts.shape, clss.shape,pot.shape,potmap.shape,frm.shape)

# Convert to tensors. Image-like arrays are permuted by (0, 1, 4, 2, 3),
# which moves the trailing channel axis in front of the two spatial axes.
image = torch.from_numpy(ims)
image = image.permute(0, 1, 4, 2, 3)
label = torch.from_numpy(gts)
label = label.permute(0, 1, 4, 2, 3)
classi = torch.from_numpy(clss)
point = torch.from_numpy(pot)
pointmap = torch.from_numpy(potmap)
pointmap = pointmap.permute(0, 1, 4, 2, 3)
frame = torch.from_numpy(frm)

# Shuffle all tensors with one shared permutation so samples stay aligned.
# NOTE(review): the permutation is unseeded, so the train/test membership
# changes on every run -- relevant when pretrained weights are reloaded.
pi = np.random.permutation(image.shape[0])
image = image[pi, :, :,:,:]
label = label[pi, :, :,:,:]
classi = classi[pi]-1  # shift class ids so they start one lower
point = point[pi]
pointmap = pointmap[pi]
frame = frame[pi]

# Split train / test: samples [0, sp) train, [sp, sup) test. Samples at index
# sup and beyond are left unused.
sp = 3500
sup = 3700
image_test = image[sp:sup,:,:,:,:]
label_test = label[sp:sup,:,:,:,:]
classi_test = classi[sp:sup]
point_test = point[sp:sup,:,:]
pointmap_test = pointmap[sp:sup,:,:,:,:]
frame_test = frame[sp:sup,:]
image = image[:sp,:,:,:,:]
label = label[:sp,:,:,:,:]
classi = classi[:sp]
# The original never truncated `point`, so the training tensor still
# contained the test (and unused) rows, unlike every sibling tensor.
point = point[:sp]
pointmap = pointmap[:sp,:,:,:,:]
frame = frame[:sp,:]

In [None]:
# Training of the segmentation branch, with visualization
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torchvision
import torch
from torchvision.transforms import transforms
from torch import nn, optim
import timeit
import os
from tqdm import tqdm
from tqdm import trange

# ---- Training configuration -------------------------------------------------
img_size = 128  # read by seg_loss() as a module-level global
# NOTE(review): this only takes effect if CUDA has not been initialised yet
# in this process -- confirm no earlier cell touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
torch.backends.cudnn.benchmark = True
l_r = 0.0002  #0.0002
device = torch.device("cuda")
# Selects which branch of the training loop runs ('cls', 'pot', 'frm', 'seg'
# or 'mtl').
train_mode = 'mtl'
model = Multivit_net(img_dim=128,in_channels=1,out_channels=32,head_num=4,mlp_dim=512,block_num=8,
                     patch_dim=16,class_num=1,drop_rate = 0.2,seq_frame=30,mode =train_mode,height=128,weight=128).to(device)
#model.load_state_dict(torch.load('./weight/mlt_weights_beta0.1.pth'),True)
# Start from pretrained weights; the second positional argument is strict=True.
model.load_state_dict(torch.load('./weight/mlt_weights_base.pth'),True)
# Embedding network used for the embedding loss; it is not passed to the
# optimizer below, so its weights are not updated by it.
enet = Embbeding_net(dim=1, frame=30,device = device).to(device)
enet.load_state_dict(torch.load( './weight_emb/mlt_emb_min4.pth'),True)
param_optim = []
layers = []
optimizer = torch.optim.Adam(model.parameters(), lr=l_r)
# Threshold values (their use is not visible in this part of the file).
dice_less = 0.85
cor_less = 0.85
mae_less = 30
pot_less = 30
frm_less = 30
cls_less = 0.8
criterion = nn.CrossEntropyLoss()
# `size_average=True` is deprecated in PyTorch; `reduction='mean'` is the
# documented equivalent and avoids the deprecation warning.
criterion_f = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array([5/13,3/13,5/13])).float(),reduction='mean')
criterion_f.to(device)
crit_pot = torch.nn.SmoothL1Loss()
kl_loss = nn.KLDivLoss(reduction="batchmean")
pot_loss = nn.MSELoss(reduction='mean')
# Per-iteration loss histories filled by the training loop.
lossseg,losspot,lossfrm,losscls = [],[],[],[]
for t in range(50):
    # Forward pass: Compute predicted y by passing x to the model
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [5, 10,15], 0.1)
    lens = image.shape[0]
    print('epoch=',t)
    dice,jcd, fb = 0,0,0
    cor_lva ,cor_mvd,cor_lvd = 0,0,0
    mae_all = 0
    mae_all_pot = 0
    mae_all_frm = 0
    right,error = 0,0
    image = image
    label = label
    classi = classi
    point = point
    pointmap = pointmap
    frame = frame
    lens2 = image_test.shape[0]
    image_test = image_test.to(device)
    label_test = label_test.to(device)
    classi_test = classi_test.to(device)
    point_test = point_test.to(device)
    frame_test = frame_test.to(device)
    model.train()
    print('training')
    with trange(lens) as tr:
        for i in tr:
            video = image[i:i+1,:,:,:,:].to(device)
            #print('batch=',i,' of ',lens)
            labelv = label[i:i+1,:,:,:,:].to(device)
            classiv = classi[i:i+1].to(device)
            pointi = point[i:i+1].to(device)
            pointmapi = pointmap[i:i+1].to(device)
            framei = frame[i].to(device)
            if train_mode=='cls':
                pred_cls = model(video)
                # Compute and print loss
                loss_cls = criterion(pred_cls,classiv.long())
                loss = loss_cls
                tr.set_description('batch= %i' % i)
                tr.set_postfix(loss=float(loss),class_ori=classiv,class_pred=float(torch.max(pred_cls, 1)[1]))
                if float(classiv) == float(torch.max(pred_cls, 1)[1]):
                    right = right+1
                else:
                    error = error+1
            elif train_mode=='pot':
                pred_pot,pred_potmap = model(video)
                #print(pred_pot.shape,pointi.shape)
                loss_pot = crit_pot(pred_pot[0].to(torch.float32),pointi.to(torch.float32))
                # loss_pot_map = kl_loss(pred_potmap.to(torch.float32), pointmapi.to(torch.float32))
                loss_pot_map = pot_loss(pointmapi.to(torch.float32),pred_potmap.to(torch.float32))
                loss = loss_pot_map+loss_pot
                mae = mae_point(pred_pot,pointi)
                tr.set_description('batch= %i' % i)
                tr.set_postfix(loss=float(loss),mae=float(mae),loss_map=float(loss_pot_map),loss_pot=float(loss_pot))
                mae_all+=float(mae)
            elif train_mode=='frm':
                pred_frm = model(video)
                # pred_frm = torch.tanh(pred_frm)
                # loss_frm = crit_pot(pred_frm.to(torch.float32),framei.to(torch.float32))
                framei = framei+1
                loss_frm = criterion_f(pred_frm,framei.long())
                loss = loss_frm
                mae = acc(pred_frm,framei)
                tr.set_description('batch= %i' % i)
                tr.set_postfix(loss=float(loss),mae=float(mae))
                mae_all+=float(mae)
            elif train_mode=='seg':
                pred_seg = model(video)
                dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv,pred_seg) # 计算指标
                dice = dice + float(dice_coeff)
                loss1 = seg_loss(pred_seg, labelv)
                loss2 = boundary_cos_loss(labelv,pred_seg)
                a = t/50
                loss = (1-a)*loss1+a*loss2
                hd,md = get_hausdorff(labelv,pred_seg)
                tr.set_description('batch= %i' % i)
                tr.set_postfix(loss=float(loss),dice=float(dice_coeff))
            elif train_mode == 'mtl':
                pred_pot,pred_potmap,pred_frm,pred_seg = model(video)
                # loss_cls = criterion(pred_cls,classiv.long())
                loss_pot1 = crit_pot(pred_pot[0].to(torch.float32),pointi.to(torch.float32))
                loss_pot_map = pot_loss(torch.sigmoid(pred_potmap).to(torch.float32), pointmapi.to(torch.float32))
                #print(pred_potmap.shape, pointmapi.shape)
                #loss_pot_map = criterion(pred_potmap.to(torch.float32), pointmapi.to(torch.float32))
                loss_pot = loss_pot1+loss_pot_map
                # pred_frm = torch.tanh(pred_frm)
                framei = framei+1
                # print(framei.shape)
                loss_frm = criterion_f(pred_frm,framei.long())
                loss1 = seg_loss(pred_seg, labelv)
                loss2 = boundary_cos_loss(labelv,pred_seg)
                loss_con1 = l2_norm(pred_seg,pointmapi)
                loss_con2 = kl_loss(F.log_softmax(cal_a(pred_seg)[0,:,0], dim=0), F.softmax(torch.topk(pred_frm, 1)[1].squeeze(1).float(), dim=0))
                a = t/50
                #embspace
                pot_emb,frm_emb,seg_emb = enet(pred_pot,framei,pred_seg)
                loss_emb = cosine_loss(pot_emb,framei.unsqueeze(1))+cosine_loss(seg_emb,framei.unsqueeze(1))+crit_pot(pot_emb,seg_emb)
                loss_seg = (1-a)*loss1+a*loss2
                w_1,w_2,w_3 = 1.0,1.0,1.0
                loss = w_1*loss_pot+w_2*loss_frm+w_3*loss_seg+loss_emb
                # loss = loss_con1+loss_con2
                mae_lvd,mae_mvd = mae_div(pred_pot,pointi)
                mae_frm,mae_es,mae_ed = acc(pred_frm,framei)
                dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv,pred_seg) # 计算指标
                mae_all_frm+=float(mae_frm)
                mae_all_pot+=float(mae_lvd)
                dice = dice + float(dice_coeff)
                tr.set_description('batch= %i' % i)
                tr.set_postfix(loss=float(loss),mae_lvd=float(mae_lvd),mae_mvd=float(mae_mvd)
                              ,mae_frm=float(mae_frm),mae_ed = float(mae_ed),mae_es = float(mae_es),
                               dice=float(dice_coeff),loss_seg=float(loss_seg),loss_frm=float(loss_frm),loss_pot=float(loss_pot))
                # if float(classiv) == float(torch.max(pred_cls, 1)[1]):
                #     right = right+1
                # else:
                #     error = error+1
                lossseg.append(float(loss_seg))
                lossfrm.append(float(loss_frm))
                losspot.append(float(loss_pot))
                # losscls.append(float(loss_cls))
            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad() # 梯度置零，因为反向传播过程中梯度会累加上一次循环的梯度
            loss.backward() # loss反向传播
            optimizer.step() # 反向传播后参数更新 
    #print('loss=', loss)
    if train_mode=='cls':
        print('epoch_train_acc=',right/(right+error))
    elif train_mode=='pot':
        print('epoch_train_acc=',mae_all/lens)
    elif train_mode=='frm':
        print('epoch_train_acc=',mae_all/lens)
    elif train_mode=='seg':
        print('epoch_train_acc=',dice/lens)
    elif train_mode=='mtl':
        print('epoch_train_frame=',mae_all_frm/lens,
              'epoch_train_point=',mae_all_pot/lens,'epoch_train_dice=',dice/lens)
    
    with torch.no_grad():
    #测试阶段
        #model.eval()
        print('evaluating')
        dice,jcd, fb = 0,0,0
        cor_lva ,cor_mvd,cor_lvd = 0,0,0
        mae_all = 0
        mae_all_pot = 0
        mae_all_frm = 0
        right,error = 0,0
        with trange(lens2) as tr:
            for i in tr:
                video_test = image_test[i:i+1,:,:,:,:]
                #print('batch=',i,' of ',lens)
                labelv_test = label_test[i:i+1,:,:,:,:]
                classiv_test = classi_test[i:i+1]
                pointi_test = point_test[i:i+1]
                framei_test = frame_test[i]
                if train_mode == 'cls':
                    pred_test_cls = model(video_test)
                    tr.set_postfix(loss=float(loss),class_ori=classiv,class_pred=float(torch.max(pred_cls, 1)[1]))
                    if float(classiv_test) == float(torch.max(pred_test_cls, 1)[1]):
                        right = right+1
                    else:
                        error = error+1
                elif train_mode == 'pot':
                    pred_pot_test,_ = model(video_test)
                    mae = mae_point(pred_pot_test,pointi_test)
                    tr.set_description('batch= %i' % i)
                    tr.set_postfix(loss=float(loss),mae=float(mae))
                    mae_all+=float(mae) 
                elif train_mode == 'frm':
                    pred_frm_test = model(video_test)
                    framei_test = framei_test+1
                    mae = acc(pred_frm_test,framei_test)
                    tr.set_description('batch= %i' % i)
                    tr.set_postfix(loss=float(loss),mae=float(mae))
                    mae_all+=float(mae) 
                elif train_mode=='seg':
                    pred_seg_test = model(video_test)
                    dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv_test,pred_seg_test) # 计算指标
                    dice = dice + float(dice_coeff)
                    hd,md = get_hausdorff(labelv,pred_seg_test)
                    tr.set_description('batch= %i' % i)
                    tr.set_postfix(loss=float(loss),dice=float(dice_coeff))
                elif train_mode=='mtl':
                    pred_pot_test,_,pred_frm_test,pred_seg_test = model(video_test)
                    pred_frm_test = torch.tanh(pred_frm_test)
                    mae_lvd_test,mae_mvd_test = mae_div(pred_pot_test,pointi_test)
                    framei_test = framei_test+1
                    mae_frm_test,mae_es_test,mae_ed_test = acc(pred_frm_test,framei_test)
                    # mae_frm_test = mae_point(pred_frm_test,framei_test)
                    dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv_test,pred_seg_test) # 计算指标
                    mae_all_frm+=float(mae_frm_test)
                    mae_all_pot+=float(mae_lvd_test)
                    dice = dice + float(dice_coeff)
                    tr.set_description('batch= %i' % i)
                    tr.set_postfix(loss=float(loss),
                                   mae_lvd=float(mae_lvd_test),mae_mvd=float(mae_mvd_test),mae_frm=float(mae_frm_test),
                                   mae_ed=float(mae_ed_test),mae_es=float(mae_es_test),dice=float(dice_coeff))
                    # if float(classiv_test) == float(torch.max(pred_cls_test, 1)[1]):
                    #     right = right+1
                    # else:
                    #     error = error+1
                    
            if train_mode=='cls':
                print('epoch_train_acc=',right/(right+error))
            elif train_mode=='pot':
                print('epoch_train_acc=',mae_all/lens2)
            elif train_mode=='frm':
                print('epoch_train_acc=',mae_all/lens2)
            elif train_mode=='seg':
                print('epoch_train_acc=',dice/lens2)
            elif train_mode=='mtl':
                print('epoch_train_point=',mae_all_frm/lens2,
                      'epoch_train_frame=',mae_all_pot/lens2,'epoch_train_dice=',dice/lens2)
                
    if train_mode == 'cls':
        cls_acc =   right/(right+error)  
        if cls_acc >= cls_less:
            print('save model')
            cls_less = cls_acc
            torch.save(model.state_dict(), './weight_cls/cls_weights3.pth')    
        else:
            print('not save,the best cor is:',cls_less)
    elif train_mode == 'pot':
        pot_acc =   mae_all/lens 
        if pot_acc <= pot_less:
            print('save model')
            pot_less = pot_acc
            torch.save(model.state_dict(), './weight_pot/pot_weights4.pth')    
        else:
            print('not save,the best cor is:',pot_less)
    elif train_mode == 'frm':
        frm_acc =   mae_all/lens2 
        if frm_acc >= 0.85:
            print('save model')
            frm_less = frm_acc
            torch.save(model.state_dict(), './weight_frm/frm_weights3.pth')    
        else:
            print('not save,the best cor is:',frm_less)
    elif train_mode == 'seg':
        seg_acc =   dice/lens2
        if seg_acc >= dice_less:
            print('save model')
            dice_less = seg_acc
            torch.save(model.state_dict(), './weight_seg/seg_weights3.pth')    
        else:
            print('not save,the best cor is:',dice_less)
    elif train_mode == 'mtl':
        seg_acc =   dice/lens2
        if seg_acc >= dice_less:
            print('save model')
            dice_less = seg_acc
            torch.save(model.state_dict(), './weight/mlt_weights_base2.pth')    
        else:
            print('not save,the best cor is:',dice_less)
scheduler.step()

In [None]:
# Scratch cell: show the shape of cal_a's output for the current pred_seg.
# Both cal_a and pred_seg come from other cells of the notebook.
# Earlier experiment, kept for reference:
# torch.topk(pred_frm, 1)[1].squeeze(1)
cal_a(pred_seg).shape

In [None]:
# Spot-check one test sample (index 23) with gradient tracking disabled:
# print the arg-max of the model output next to the stored frame label + 1.
with torch.no_grad():
    i = 23
    video_test = image_test[i:i+1,:,:,:,:]  # i:i+1 keeps the leading dimension
    framei_test = frame_test[i:i+1]
    # NOTE(review): torch.max(b, 1) needs `b` to be a single tensor, so this
    # presumably expects a model built in a single-output mode — confirm.
    b = model(video_test)
    print(torch.max(b, 1)[1])  # index of the maximum along dim 1
    c = framei_test+1  # label shifted by one before printing
    print(c)

In [None]:
import numpy as np

import torch
import torch.utils.data
from torch.autograd import Variable

import pickle

import numpy as np
import torch
import os

class MinNormSolver:
    r"""
    Minimum-norm point solver over the convex hull of a set of vectors.

    Each "vector" is given as a sequence of 1-D tensors (chunks); the inner
    product of two vectors is the sum of the chunk-wise ``torch.dot`` products.
    All methods are static, so they can be called through the class or through
    an instance.
    """
    MAX_ITER = 250      # upper bound on refinement iterations
    STOP_CRIT = 1e-5    # stop once the L1 change of the coefficients is below this

    @staticmethod
    def _min_norm_element_from2(v1v1, v1v2, v2v2):
        r"""
        Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2
        d is the distance (objective) optimzed
        v1v1 = <x1,x1>
        v1v2 = <x1,x2>
        v2v2 = <x2,x2>

        Returns the pair (c, d).
        """
        if v1v2 >= v1v1:
            # Case: Fig 1, third column
            gamma = 0.999
            cost = v1v1
            return gamma, cost
        if v1v2 >= v2v2:
            # Case: Fig 1, first column
            gamma = 0.001
            cost = v2v2
            return gamma, cost
        # Case: Fig 1, second column
        # Both guards failed, so v1v1 + v2v2 - 2*v1v2 > 0 and the division is safe.
        gamma = -1.0 * ( (v1v2 - v2v2) / (v1v1+v2v2 - 2*v1v2) )
        cost = v2v2 + gamma*(v1v2 - v2v2)
        return gamma, cost

    @staticmethod
    def _min_norm_2d(vecs, dps):
        r"""
        Find the minimum norm solution as combination of two points
        This is correct only in 2D
        ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j

        ``dps`` is a dict cache of inner products keyed by index pair; missing
        entries are filled in place. Returns ``([(i, j), c, d], dps)`` for the
        best pair.

        Raises ValueError when fewer than two vectors are given or when no pair
        yields a comparable cost (e.g. NaN inner products).
        """
        if len(vecs) < 2:
            raise ValueError('MinNormSolver needs at least two vectors')

        def _inner(a, b):
            # Sum of the chunk-wise dot products, as a Python float.
            return sum((torch.dot(x, y).item() for x, y in zip(a, b)), 0.0)

        # Start from +inf rather than a fixed 1e8 so that large-norm inputs
        # still select a pair instead of leaving `sol` unset.
        dmin = float('inf')
        sol = None
        for i in range(len(vecs)):
            for j in range(i+1,len(vecs)):
                if (i,j) not in dps:
                    dps[(i, j)] = _inner(vecs[i], vecs[j])
                    dps[(j, i)] = dps[(i, j)]
                if (i,i) not in dps:
                    dps[(i, i)] = _inner(vecs[i], vecs[i])
                if (j,j) not in dps:
                    dps[(j, j)] = _inner(vecs[j], vecs[j])
                c,d = MinNormSolver._min_norm_element_from2(dps[(i,i)], dps[(i,j)], dps[(j,j)])
                if d < dmin:
                    dmin = d
                    sol = [(i,j),c,d]
        if sol is None:
            raise ValueError('no comparable pairwise cost found (NaN in the inputs?)')
        return sol, dps

    @staticmethod
    def _projection2simplex(y):
        r"""
        Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
        """
        m = len(y)
        sorted_y = np.flip(np.sort(y), axis=0)
        tmpsum = 0.0
        tmax_f = (np.sum(y) - 1.0)/m
        for i in range(m-1):
            tmpsum+= sorted_y[i]
            tmax = (tmpsum - 1)/ (i+1.0)
            if tmax > sorted_y[i+1]:
                tmax_f = tmax
                break
        return np.maximum(y - tmax_f, np.zeros(y.shape))

    @staticmethod
    def _next_point(cur_val, grad, n):
        """
        Take one step from ``cur_val`` along ``grad`` with its mean removed,
        using the smallest positive step at which a coordinate reaches 0 or 1
        (1 when there is none), then project the result onto the simplex.
        """
        proj_grad = grad - ( np.sum(grad) / n )
        # Step lengths at which a decreasing coordinate hits 0 ...
        tm1 = -1.0*cur_val[proj_grad<0]/proj_grad[proj_grad<0]
        # ... and at which an increasing coordinate hits 1.
        tm2 = (1.0 - cur_val[proj_grad>0])/(proj_grad[proj_grad>0])

        t = 1
        if len(tm1[tm1>1e-7]) > 0:
            t = np.min(tm1[tm1>1e-7])
        if len(tm2[tm2>1e-7]) > 0:
            t = min(t, np.min(tm2[tm2>1e-7]))

        next_point = proj_grad*t + cur_val
        next_point = MinNormSolver._projection2simplex(next_point)
        return next_point

    @staticmethod
    def _gram_matrix(dps, n):
        """Copy the cached inner products ``dps[(i, j)]`` into an (n, n) array."""
        grad_mat = np.zeros((n,n))
        for i in range(n):
            for j in range(n):
                grad_mat[i,j] = dps[(i, j)]
        return grad_mat

    @staticmethod
    def find_min_norm_element(vecs):
        r"""
        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
        Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence

        Returns ``(coefficients, cost)``. The refinement stops when the L1
        change of the coefficients drops below STOP_CRIT or after MAX_ITER
        iterations, whichever comes first.
        """
        # Solution lying at the combination of two points
        dps = {}
        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)

        n=len(vecs)
        sol_vec = np.zeros(n)
        sol_vec[init_sol[0][0]] = init_sol[1]
        sol_vec[init_sol[0][1]] = 1 - init_sol[1]

        if n < 3:
            # This is optimal for n=2, so return the solution
            return sol_vec , init_sol[2]

        grad_mat = MinNormSolver._gram_matrix(dps, n)
        nd = init_sol[2]

        # A bounded for-loop enforces MAX_ITER; the former while-loop never
        # advanced its counter and could spin forever.
        for _ in range(MinNormSolver.MAX_ITER):
            grad_dir = -1.0*np.dot(grad_mat, sol_vec)
            new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
            # Re-compute the inner products for line search
            v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
            v1v2 = np.dot(sol_vec, np.dot(grad_mat, new_point))
            v2v2 = np.dot(new_point, np.dot(grad_mat, new_point))
            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc*sol_vec + (1-nc)*new_point
            change = new_sol_vec - sol_vec
            if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
        # Iteration budget exhausted: return the latest iterate instead of None.
        return sol_vec, nd

    @staticmethod
    def find_min_norm_element_FW(vecs):
        r"""
        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
        Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence

        Returns ``(coefficients, cost)``. The refinement stops when the L1
        change of the coefficients drops below STOP_CRIT or after MAX_ITER
        iterations, whichever comes first.
        """
        # Solution lying at the combination of two points
        dps = {}
        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)

        n=len(vecs)
        sol_vec = np.zeros(n)
        sol_vec[init_sol[0][0]] = init_sol[1]
        sol_vec[init_sol[0][1]] = 1 - init_sol[1]

        if n < 3:
            # This is optimal for n=2, so return the solution
            return sol_vec , init_sol[2]

        grad_mat = MinNormSolver._gram_matrix(dps, n)
        nd = init_sol[2]

        for _ in range(MinNormSolver.MAX_ITER):
            # Vertex with the smallest inner product against the current point.
            t_iter = np.argmin(np.dot(grad_mat, sol_vec))

            v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
            v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
            v2v2 = grad_mat[t_iter, t_iter]

            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc*sol_vec
            new_sol_vec[t_iter] += 1 - nc

            change = new_sol_vec - sol_vec
            if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
        # Iteration budget exhausted: return the latest iterate instead of None.
        return sol_vec, nd


def gradient_normalizers(grads, losses, normalization_type):
    """
    Compute one normalisation factor per task.

    Args:
        grads: mapping task -> iterable of gradient tensors.
        losses: mapping task -> loss value (read for 'loss' and 'loss+').
        normalization_type: one of
            'l2'    - L2 norm over all of the task's gradient tensors,
            'loss'  - the task's loss,
            'loss+' - the task's loss times the L2 norm,
            'none'  - the constant 1.0.

    Returns:
        dict task -> factor. For an unknown type an error line is printed and
        the dict is returned empty.
    """
    def _l2_norm(task_grads):
        # .item() reads the 0-dim sum as a Python number; indexing it with
        # .data[0] raises on current PyTorch versions.
        return np.sqrt(np.sum([gr.pow(2).sum().item() for gr in task_grads]))

    gn = {}
    if normalization_type == 'l2':
        for t in grads:
            gn[t] = _l2_norm(grads[t])
    elif normalization_type == 'loss':
        for t in grads:
            gn[t] = losses[t]
    elif normalization_type == 'loss+':
        for t in grads:
            gn[t] = losses[t] * _l2_norm(grads[t])
    elif normalization_type == 'none':
        for t in grads:
            gn[t] = 1.0
    else:
        print('ERROR: Invalid Normalization Type')
    return gn

def get_d_paretomtl_init(grads,value,weights,i):
    """
    Calculate the gradient direction for ParetoMTL initialization.

    Args:
        grads: 2-D tensor with one row per objective (only read when more
            than one constraint is active).
        value: 1-D tensor of the current objective values.
        weights: 2-D tensor of preference vectors, one per row.
        i: row index of the preference vector being targeted.

    Returns:
        (flag, weight). ``flag`` is True when no constraint is active, in
        which case ``weight`` is a zero tensor shaped like ``value``.
        Otherwise ``weight`` holds one coefficient per column of ``weights``.
    """
    # check active constraints: offsets of every preference vector from the i-th
    w = weights - weights[i]

    gx =  torch.matmul(w,value/torch.norm(value))
    idx = gx >  0
    n_active = int(torch.sum(idx))

    # calculate the descent direction
    if n_active <= 0:
        return True, torch.zeros(value.shape)

    w_active = w[idx]
    if n_active == 1:
        # Single active constraint: its coefficient is 1. Created on the same
        # device as the weights rather than unconditionally on CUDA.
        sol = torch.ones(1, dtype=torch.float32, device=w.device)
    else:
        vec =  torch.matmul(w_active,grads)
        sol, nd = MinNormSolver.find_min_norm_element([[vec[t]] for t in range(len(vec))])

    # weight[k] = sum_j sol[j] * w_active[j, k], for every objective column k
    weight = torch.stack([
        torch.sum(torch.stack([sol[j] * w_active[j, k] for j in range(n_active)]))
        for k in range(w.shape[1])
    ])

    return False, weight


def get_d_paretomtl(grads,value,weights,i):
    """
    Calculate the gradient direction for ParetoMTL.

    Args:
        grads: 2-D tensor with one row per objective.
        value: 1-D tensor of the current objective values.
        weights: 2-D tensor of preference vectors, one per row.
        i: row index of the preference vector being targeted.

    Returns:
        1-D tensor with one coefficient per objective (row of ``grads``).
    """
    # check active constraints: offsets of every preference vector from the i-th
    w = weights - weights[i]

    gx =  torch.matmul(w,value/torch.norm(value))
    idx = gx >  0
    n_active = int(torch.sum(idx))

    # calculate the descent direction
    if n_active <= 0:
        # No active constraint: plain min-norm combination of the task gradients.
        sol, nd = MinNormSolver.find_min_norm_element([[grads[t]] for t in range(len(grads))])
        return torch.tensor(sol, dtype=torch.float32, device=w.device)

    w_active = w[idx]
    n_tasks = len(grads)

    # Rows 0..n_tasks-1 are the task gradients, rows n_tasks.. are the
    # active-constraint gradients.
    vec =  torch.cat((grads, torch.matmul(w_active,grads)))
    sol, nd = MinNormSolver.find_min_norm_element([[vec[t]] for t in range(len(vec))])

    # Fold the constraint coefficients back onto the tasks:
    #   weight[k] = sol[k] + sum_j sol[n_tasks + j] * w_active[j, k]
    # The constraint coefficients start at index n_tasks; the previous fixed
    # offset of 2 read task coefficients as constraint coefficients once more
    # than two tasks were used.
    weight = torch.stack([
        sol[k] + torch.sum(torch.stack(
            [sol[n_tasks + j] * w_active[j, k] for j in range(n_active)]))
        for k in range(n_tasks)
    ])

    return weight


def circle_points(r, n):
    """
    Build preference vectors sampled along a quarter arc.

    ``r`` and ``n`` are iterated in lockstep. For each (radius, count) pair,
    ``count`` angles are spaced evenly over [0, pi/2] and a (count, 4) array is
    produced whose columns are, in order:
    radius*cos, 2/3*radius*cos + 1/3*radius*sin,
    1/3*radius*cos + 2/3*radius*sin, radius*sin.

    Returns the list of those arrays, one per pair.
    """
    def _arc(radius, count):
        theta = np.linspace(0, 0.5 * np.pi, count)
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)
        return np.column_stack((
            radius * cos_t,
            2/3*radius * cos_t+1/3*radius * sin_t,
            1/3*radius * cos_t+2/3*radius * sin_t,
            radius * sin_t,
        ))

    return [_arc(radius, count) for radius, count in zip(r, n)]

In [None]:
# Train the segmentation branch (with progress visualization)
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torchvision
import torch
from torchvision.transforms import transforms
from torch import nn, optim
import timeit
import os
from tqdm import tqdm
from tqdm import trange

# Training / evaluation cell.
# Each epoch runs one optimisation step per training sample (batch size 1),
# then a pass over the test set under torch.no_grad(), then saves a checkpoint
# when the mode's test metric passes its threshold.
# Names taken from other cells: Multivit_net, the image/label/classi/point/
# pointmap/frame tensors and their *_test counterparts, and the helpers
# mae_point, coefficients, seg_loss, boundary_cos_loss and get_hausdorff.
img_size = 128
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA
# is initialised in this process — confirm no earlier cell touched the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
torch.backends.cudnn.benchmark = True
l_r = 0.0002  #0.0002
device = torch.device("cuda")
train_mode = 'mtl'
model = Multivit_net(img_dim=128,in_channels=1,out_channels=128,head_num=4,mlp_dim=512,block_num=8,
                     patch_dim=16,class_num=1,drop_rate = 0.2,seq_frame=30,mode =train_mode,height=128,weight=128).to(device)
#model.load_state_dict(torch.load('./weight/mlt_weights_beta0.1.pth'),True)
#model.load_state_dict(torch.load('./weight_seg/seg_weights.pth'),True)
param_optim = []
layers = []
optimizer = torch.optim.Adam(model.parameters(), lr=l_r)
# Build the schedule once and step it at the end of every epoch. Re-creating it
# inside the loop restarted its epoch counter, so the milestones were never reached.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [25, 30], 0.1)
# Thresholds a test metric must pass before a checkpoint is written; each is
# replaced by the new best value after a save.
dice_less = 0.85
cor_less = 0.85
mae_less = 30
pot_less = 30
frm_less = 30
cls_less = 0.8
criterion = nn.CrossEntropyLoss()
crit_pot = torch.nn.SmoothL1Loss()
# NOTE(review): the PyTorch docs state that reduction="mean" does not return the
# true KL divergence and recommend "batchmean" — left unchanged to keep the loss scale.
kl_loss = nn.KLDivLoss(reduction="mean")
lossseg,losspot,lossfrm,losscls = [],[],[],[]

# Move the datasets to the device once; they do not change between epochs.
image = image.to(device)
label = label.to(device)
classi = classi.to(device)
point = point.to(device)
pointmap = pointmap.to(device)
frame = frame.to(device)
image_test = image_test.to(device)
label_test = label_test.to(device)
classi_test = classi_test.to(device)
point_test = point_test.to(device)
frame_test = frame_test.to(device)
lens = image.shape[0]        # number of training samples
lens2 = image_test.shape[0]  # number of test samples

for t in range(10):
    print('epoch=',t)
    dice,jcd, fb = 0,0,0
    cor_lva ,cor_mvd,cor_lvd = 0,0,0
    mae_all = 0
    mae_all_pot = 0
    mae_all_frm = 0
    right,error = 0,0
    model.train()
    print('training')
    with trange(lens) as tr:
        for i in tr:
            # i:i+1 slices keep the leading (batch) dimension; plain [i] drops it.
            video = image[i:i+1,:,:,:,:]
            #print('batch=',i,' of ',lens)
            labelv = label[i:i+1,:,:,:,:]
            classiv = classi[i:i+1]
            pointi = point[i]
            pointmapi = pointmap[i:i+1]
            framei = frame[i]
            if train_mode=='cls':
                pred_cls = model(video)
                # Compute and print loss
                loss_cls = criterion(pred_cls,classiv.long())
                loss = loss_cls
                tr.set_description('batch= %i' % i)
                tr.set_postfix(loss=float(loss),class_ori=classiv,class_pred=float(torch.max(pred_cls, 1)[1]))
                if float(classiv) == float(torch.max(pred_cls, 1)[1]):
                    right = right+1
                else:
                    error = error+1
            elif train_mode=='pot':
                pred_pot,pred_potmap = model(video)
                #print(pred_pot.shape,pointi.shape)
                loss_pot = crit_pot(pred_pot[0].to(torch.float32),pointi.to(torch.float32))
                # NOTE(review): KLDivLoss expects log-probabilities as input; the mtl
                # branch applies sigmoid().log() first but this branch does not — confirm
                # what the model emits in 'pot' mode.
                loss_pot_map = kl_loss(pred_potmap.to(torch.float32), pointmapi.to(torch.float32))
                loss = loss_pot+loss_pot_map
                mae = mae_point(pred_pot,pointi)
                tr.set_description('batch= %i' % i)
                tr.set_postfix(loss=float(loss),mae=float(mae))
                mae_all+=float(mae)
            elif train_mode=='frm':
                pred_frm = model(video)
                pred_frm = torch.tanh(pred_frm)
                loss_frm = crit_pot(pred_frm.to(torch.float32),framei.to(torch.float32))
                loss = loss_frm
                mae = mae_point(pred_frm,framei)
                tr.set_description('batch= %i' % i)
                tr.set_postfix(loss=float(loss),mae=float(mae))
                mae_all+=float(mae)
            elif train_mode=='seg':
                pred_seg = model(video)
                dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv,pred_seg) # compute metrics
                dice = dice + float(dice_coeff)
                loss1 = seg_loss(pred_seg, labelv)
                loss2 = boundary_cos_loss(labelv,pred_seg)
                # Blend factor grows with the epoch index, shifting weight from
                # loss1 to loss2.
                a = t/50
                loss = (1-a)*loss1+a*loss2
                hd,md = get_hausdorff(labelv,pred_seg)
                tr.set_description('batch= %i' % i)
                tr.set_postfix(loss=float(loss),dice=float(dice_coeff))
            elif train_mode == 'mtl':
                pred_pot,pred_potmap,pred_frm,pred_seg = model(video)
                # loss_cls = criterion(pred_cls,classiv.long())
                loss_pot1 = crit_pot(pred_pot[0].to(torch.float32),pointi.to(torch.float32))
                loss_pot_map = kl_loss(torch.sigmoid(pred_potmap).log().to(torch.float32), pointmapi.to(torch.float32))
                #print(pred_potmap.shape, pointmapi.shape)
                #loss_pot_map = criterion(pred_potmap.to(torch.float32), pointmapi.to(torch.float32))
                loss_pot = loss_pot1+loss_pot_map
                pred_frm = torch.tanh(pred_frm)
                loss_frm = crit_pot(pred_frm.to(torch.float32),framei.to(torch.float32))
                loss1 = seg_loss(pred_seg, labelv)
                loss2 = boundary_cos_loss(labelv,pred_seg)
                a = t/50
                loss_seg = (1-a)*loss1+a*loss2
                loss = loss_pot+loss_frm+loss_seg
                mae_pot = mae_point(pred_pot,pointi)
                mae_frm = mae_point(pred_frm,framei)
                dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv,pred_seg) # compute metrics
                mae_all_frm+=float(mae_frm)
                mae_all_pot+=float(mae_pot)
                dice = dice + float(dice_coeff)
                tr.set_description('batch= %i' % i)
                # loss_cls is not computed in this mode, so it is not reported
                # (referencing it raised NameError).
                tr.set_postfix(loss=float(loss),mae_pot=float(mae_pot)
                              ,mae_frm=float(mae_frm),dice=float(dice_coeff),loss_seg=float(loss_seg),loss_frm=float(loss_frm),
                              loss_pot=float(loss_pot))
                # if float(classiv) == float(torch.max(pred_cls, 1)[1]):
                #     right = right+1
                # else:
                #     error = error+1
                lossseg.append(float(loss_seg))
                lossfrm.append(float(loss_frm))
                losspot.append(float(loss_pot))
                # losscls.append(float(loss_cls))
            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad() # backward() accumulates, so clear the previous step's gradients
            loss.backward() # back-propagate the loss
            optimizer.step() # update the parameters
    #print('loss=', loss)
    if train_mode=='cls':
        print('epoch_train_acc=',right/(right+error))
    elif train_mode=='pot':
        print('epoch_train_acc=',mae_all/lens)
    elif train_mode=='frm':
        print('epoch_train_acc=',mae_all/lens)
    elif train_mode=='seg':
        print('epoch_train_acc=',dice/lens)
    elif train_mode=='mtl':
        print('epoch_train_frame=',mae_all_frm/lens,
              'epoch_train_point=',mae_all_pot/lens,'epoch_train_dice=',dice/lens)

    with torch.no_grad():
        # Test phase.
        # NOTE(review): model.eval() is left disabled as in the original, so dropout
        # stays active while the test metrics are computed — confirm this is intended.
        print('evaluating')
        dice,jcd, fb = 0,0,0
        cor_lva ,cor_mvd,cor_lvd = 0,0,0
        mae_all = 0
        mae_all_pot = 0
        mae_all_frm = 0
        right,error = 0,0
        with trange(lens2) as tr:
            for i in tr:
                video_test = image_test[i:i+1,:,:,:,:]
                #print('batch=',i,' of ',lens)
                labelv_test = label_test[i:i+1,:,:,:,:]
                classiv_test = classi_test[i:i+1]
                pointi_test = point_test[i:i+1]
                framei_test = frame_test[i]
                # The `loss` shown in the progress bar below is the last training
                # loss; no loss is computed during evaluation.
                if train_mode == 'cls':
                    pred_test_cls = model(video_test)
                    # Report the test label/prediction, not the last training ones.
                    tr.set_postfix(loss=float(loss),class_ori=classiv_test,class_pred=float(torch.max(pred_test_cls, 1)[1]))
                    if float(classiv_test) == float(torch.max(pred_test_cls, 1)[1]):
                        right = right+1
                    else:
                        error = error+1
                elif train_mode == 'pot':
                    # Evaluate on the test clip (the training `video` was passed before).
                    pred_pot_test,_ = model(video_test)
                    mae = mae_point(pred_pot_test,pointi_test)
                    tr.set_description('batch= %i' % i)
                    tr.set_postfix(loss=float(loss),mae=float(mae))
                    mae_all+=float(mae)
                elif train_mode == 'frm':
                    pred_frm_test = model(video_test)
                    # Same squashing as in training, so both MAEs are on the same scale.
                    pred_frm_test = torch.tanh(pred_frm_test)
                    mae = mae_point(pred_frm_test,framei_test)
                    tr.set_description('batch= %i' % i)
                    tr.set_postfix(loss=float(loss),mae=float(mae))
                    mae_all+=float(mae)
                elif train_mode=='seg':
                    pred_seg_test = model(video_test)
                    dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv_test,pred_seg_test) # compute metrics
                    dice = dice + float(dice_coeff)
                    hd,md = get_hausdorff(labelv_test,pred_seg_test)
                    tr.set_description('batch= %i' % i)
                    tr.set_postfix(loss=float(loss),dice=float(dice_coeff))
                elif train_mode=='mtl':
                    pred_pot_test,_,pred_frm_test,pred_seg_test = model(video_test)
                    pred_frm_test = torch.tanh(pred_frm_test)
                    mae_pot_test = mae_point(pred_pot_test,pointi_test)
                    mae_frm_test = mae_point(pred_frm_test,framei_test)
                    dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv_test,pred_seg_test) # compute metrics
                    mae_all_frm+=float(mae_frm_test)
                    mae_all_pot+=float(mae_pot_test)
                    dice = dice + float(dice_coeff)
                    tr.set_description('batch= %i' % i)
                    tr.set_postfix(loss=float(loss),
                                   mae_pot=float(mae_pot_test),mae_frm=float(mae_frm_test),dice=float(dice_coeff))
                    # if float(classiv_test) == float(torch.max(pred_cls_test, 1)[1]):
                    #     right = right+1
                    # else:
                    #     error = error+1

            # Test-set summary. The labels say "test", and the mtl frame/point
            # labels now match the values printed next to them.
            if train_mode=='cls':
                print('epoch_test_acc=',right/(right+error))
            elif train_mode=='pot':
                print('epoch_test_acc=',mae_all/lens2)
            elif train_mode=='frm':
                print('epoch_test_acc=',mae_all/lens2)
            elif train_mode=='seg':
                print('epoch_test_acc=',dice/lens2)
            elif train_mode=='mtl':
                print('epoch_test_frame=',mae_all_frm/lens2,
                      'epoch_test_point=',mae_all_pot/lens2,'epoch_test_dice=',dice/lens2)

    # Checkpointing: accuracy/dice must rise to or above the threshold, MAE must
    # fall to or below it.
    if train_mode == 'cls':
        cls_acc =   right/(right+error)
        if cls_acc >= cls_less:
            print('save model')
            cls_less = cls_acc
            torch.save(model.state_dict(), './weight_cls/cls_weights3.pth')
        else:
            print('not save,the best cor is:',cls_less)
    elif train_mode == 'pot':
        # mae_all was accumulated over the test set, so average over lens2.
        pot_acc =   mae_all/lens2
        if pot_acc <= pot_less:
            print('save model')
            pot_less = pot_acc
            torch.save(model.state_dict(), './weight_pot/pot_weights3.pth')
        else:
            print('not save,the best cor is:',pot_less)
    elif train_mode == 'frm':
        frm_acc =   mae_all/lens2
        if frm_acc <= frm_less:
            print('save model')
            frm_less = frm_acc
            torch.save(model.state_dict(), './weight_frm/frm_weights3.pth')
        else:
            print('not save,the best cor is:',frm_less)
    elif train_mode == 'seg':
        seg_acc =   dice/lens2
        if seg_acc >= dice_less:
            print('save model')
            dice_less = seg_acc
            torch.save(model.state_dict(), './weight_seg/seg_weights3.pth')
        else:
            print('not save,the best cor is:',dice_less)
    elif train_mode == 'mtl':
        seg_acc =   dice/lens2
        if seg_acc >= dice_less:
            print('save model')
            dice_less = seg_acc
            torch.save(model.state_dict(), './weight/mlt_weights_beta1.pth')
        else:
            print('not save,the best cor is:',dice_less)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

In [None]:
# Train the segmentation branch (with progress visualization)
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torchvision
import torch
from torchvision.transforms import transforms
from torch import nn, optim
import timeit
import os
from tqdm import tqdm
from tqdm import trange

# Global configuration and state for the multi-task training loop below.
img_size = 128  # not read again in this cell
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process
torch.backends.cudnn.benchmark = True
l_r = 0.0002  #0.0002
device = torch.device("cuda")
# Every branch of the training loop in this cell is guarded by train_mode == 'mtl'.
train_mode = 'mtl'
# Multivit_net is defined elsewhere in the notebook; the forward pass is unpacked
# below as (pred_cls, pred_pot, pred_potmap, pred_frm, pred_seg).
model = Multivit_net(img_dim=128,in_channels=1,out_channels=16,head_num=4,mlp_dim=512,block_num=8,
                     patch_dim=16,class_num=1,drop_rate = 0.2,seq_frame=30,mode =train_mode,height=128,weight=128).to(device)
#model.load_state_dict(torch.load('./weight_frm/frm_weights.pth'),True)
#model.load_state_dict(torch.load('./weight_seg/seg_weights.pth'),True)
param_optim = []  # not read again in this cell
layers = []  # not read again in this cell
optimizer = torch.optim.Adam(model.parameters(), lr=l_r)
# Best-so-far thresholds. Only dice_less is used in this cell: a checkpoint is
# written when the test Dice is >= dice_less, which is then raised to that value.
# pot_less / frm_less / cls_less are read by the single-task branches of the
# preceding training cell; cor_less / mae_less are not read in the visible code.
dice_less = 0.85
cor_less = 0.85
mae_less = 30
pot_less = 30
frm_less = 30
cls_less = 0.8
criterion = nn.CrossEntropyLoss()  # applied to the classification output (pred_cls)
crit_pot = torch.nn.SmoothL1Loss()  # applied to the point and frame regression outputs
def _mtl_forward_losses(video, labelv, classiv, pointi, pointmapi, framei, epoch):
    """Run one forward pass of ``model`` and build the four task losses.

    Returns ``(task_loss, preds)`` where ``task_loss`` is
    ``[loss_cls, loss_pot, loss_frm, loss_seg]`` and ``preds`` is
    ``(pred_cls, pred_pot, pred_frm, pred_seg)``; ``pred_frm`` has already been
    passed through ``tanh``.
    """
    pred_cls, pred_pot, pred_potmap, pred_frm, pred_seg = model(video)
    loss_cls = criterion(pred_cls, classiv.long())
    # Point task: coordinate regression plus a loss on the predicted point map.
    loss_pot1 = crit_pot(pred_pot[0].to(torch.float32), pointi.to(torch.float32))
    loss_pot_map = seg_loss(pred_potmap.to(torch.float32), pointmapi.to(torch.float32))
    loss_pot = loss_pot1 + loss_pot_map
    pred_frm = torch.tanh(pred_frm)
    loss_frm = crit_pot(pred_frm.to(torch.float32), framei.to(torch.float32))
    # Segmentation task: the blend moves linearly from seg_loss towards
    # boundary_cos_loss as the epoch index grows (a = epoch / 50).
    loss1 = seg_loss(pred_seg, labelv)
    loss2 = boundary_cos_loss(labelv, pred_seg)
    a = epoch / 50
    loss_seg = (1 - a) * loss1 + a * loss2
    task_loss = [loss_cls, loss_pot, loss_frm, loss_seg]
    return task_loss, (pred_cls, pred_pot, pred_frm, pred_seg)


def _mtl_task_gradients(sample, epoch, n_tasks):
    """Return one flattened gradient vector and one detached loss value per task.

    Each task gets its own forward/backward pass; the gradients of all
    parameters that have a ``.grad`` are concatenated into a single vector.
    """
    grads_list = []
    losses_vec = []
    for task_idx in range(n_tasks):
        task_loss, _ = _mtl_forward_losses(*sample, epoch)
        optimizer.zero_grad()
        losses_vec.append(task_loss[task_idx].data)
        task_loss[task_idx].backward()
        # Plain detached copies; the deprecated ``Variable`` wrapper is not needed.
        grads_list.append(torch.cat([
            param.grad.detach().clone().flatten()
            for param in model.parameters()
            if param.grad is not None
        ]))
    return grads_list, torch.stack(losses_vec)


def _mtl_weighted_step(sample, epoch, weight_vec, n_tasks):
    """Take one optimizer step on the weighted sum of the task losses.

    Returns ``(loss_total, task_loss, preds)``; ``task_loss`` and ``preds`` come
    from the last forward pass.
    """
    optimizer.zero_grad()
    loss_total = None
    for task_idx in range(n_tasks):
        # NOTE(review): one forward pass per task is kept from the original, so
        # every term comes from its own pass; a single shared pass would be
        # cheaper but would change the sampled activations if the model is
        # stochastic in train mode.
        task_loss, preds = _mtl_forward_losses(*sample, epoch)
        term = weight_vec[task_idx] * task_loss[task_idx]
        loss_total = term if loss_total is None else loss_total + term
    loss_total.backward()
    optimizer.step()
    return loss_total, task_loss, preds


# Build the scheduler once, before the epoch loop, and step it at the end of
# every epoch so the learning rate is multiplied by 0.1 at milestones 25 and 30.
# (Re-creating it inside the loop reset its epoch counter every epoch.)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [25, 30], 0.1)
n_tasks = 4
npref = 5

for t in range(50):
    lens = image.shape[0]
    lens2 = image_test.shape[0]
    print('epoch=',t)
    dice = 0
    mae_all_pot = 0
    mae_all_frm = 0
    right,error = 0,0
    model.train()
    print('training')

    # ---- Pareto-MTL initialisation step ---------------------------------
    # NOTE(review): as in the original, this loop stops after the first sample
    # when train_mode == 'mtl'. The comments it carried ("continue if no feasible
    # solution is found" / "break the loop once a feasible solutions is found")
    # suggest a for/else was intended; behaviour is kept, please confirm.
    with trange(lens) as tr:
        for i in tr:
            video = image[i:i+1,:,:,:,:].to(device)
            labelv = label[i:i+1,:,:,:,:].to(device)
            classiv = classi[i:i+1].to(device)
            pointi = point[i].to(device)
            pointmapi = pointmap[i].to(device)
            framei = frame[i].to(device)
            if train_mode != 'mtl':
                continue
            sample = (video, labelv, classiv, pointi, pointmapi, framei)
            grads_list, losses_vec = _mtl_task_gradients(sample, t, n_tasks)

            # calculate the weights
            # NOTE(review): the init routine receives the un-stacked list of
            # gradient vectors, whereas the main loop stacks them first.
            ref_vec = torch.tensor(circle_points([1], [npref])[0]).cuda().float()
            flag, weight_vec = get_d_paretomtl_init(grads_list,losses_vec,ref_vec,2)

            # early stop once a feasible solution is obtained
            if flag:
                print("fealsible solution is obtained.")
                break

            loss_total, task_loss, preds = _mtl_weighted_step(sample, t, weight_vec, n_tasks)
            pred_cls, pred_pot, pred_frm, pred_seg = preds
            mae_pot = mae_point(pred_pot,pointi)
            mae_frm = mae_point(pred_frm,framei)
            dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv,pred_seg)  # compute metrics
            # The warm-up sample is only displayed; it is not added to the epoch
            # totals, which are averaged over the `lens` samples of the main loop.
            tr.set_description('batch= %i' % i)
            tr.set_postfix(loss=float(loss_total),class_ori=classiv,class_pred=float(torch.max(pred_cls, 1)[1]),mae_pot=float(mae_pot)
                          ,mae_frm=float(mae_frm),dice=float(dice_coeff),task_loss=task_loss,weight=weight_vec)
            break

    # ---- Main Pareto-MTL training pass over every training sample --------
    with trange(lens) as tr:
        for i in tr:
            video = image[i:i+1,:,:,:,:].to(device)
            labelv = label[i:i+1,:,:,:,:].to(device)
            classiv = classi[i:i+1].to(device)
            pointi = point[i].to(device)
            pointmapi = pointmap[i].to(device)
            framei = frame[i].to(device)
            if train_mode == 'mtl':
                sample = (video, labelv, classiv, pointi, pointmapi, framei)
                grads_list, losses_vec = _mtl_task_gradients(sample, t, n_tasks)
                grads = torch.stack(grads_list)

                # calculate the weights and rescale them so their absolute
                # values sum to n_tasks
                ref_vec = torch.tensor(circle_points([1], [npref])[0]).cuda().float()
                weight_vec = get_d_paretomtl(grads,losses_vec,ref_vec,2)
                normalize_coeff = n_tasks / torch.sum(torch.abs(weight_vec))
                weight_vec = weight_vec * normalize_coeff

                # optimization step
                loss_total, task_loss, preds = _mtl_weighted_step(sample, t, weight_vec, n_tasks)
                pred_cls, pred_pot, pred_frm, pred_seg = preds
                mae_pot = mae_point(pred_pot,pointi)
                mae_frm = mae_point(pred_frm,framei)
                dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv,pred_seg)  # compute metrics
                mae_all_frm+=float(mae_frm)
                mae_all_pot+=float(mae_pot)
                dice = dice + float(dice_coeff)
                # `i` is the sample index here; the task loops use their own index.
                tr.set_description('batch= %i' % i)
                tr.set_postfix(loss=float(loss_total),class_ori=classiv,class_pred=float(torch.max(pred_cls, 1)[1]),mae_pot=float(mae_pot)
                              ,mae_frm=float(mae_frm),dice=float(dice_coeff),weight=weight_vec)
                if float(classiv) == float(torch.max(pred_cls, 1)[1]):
                    right = right+1
                else:
                    error = error+1

    if train_mode=='mtl':
        # Point MAE comes from mae_all_pot and frame MAE from mae_all_frm.
        print('epoch_train_class=',right/(right+error),'epoch_train_point=',mae_all_pot/lens,
              'epoch_train_frame=',mae_all_frm/lens,'epoch_train_dice=',dice/lens)

    with torch.no_grad():
        # Testing phase.
        # NOTE(review): model.eval() was commented out in the original and is
        # still not called, so the model is evaluated in train mode. Confirm
        # this is intentional.
        print('evaluating')
        dice = 0
        mae_all_pot = 0
        mae_all_frm = 0
        right,error = 0,0
        with trange(lens2) as tr:
            for i in tr:
                video_test = image_test[i:i+1,:,:,:,:].to(device)
                labelv_test = label_test[i:i+1,:,:,:,:].to(device)
                classiv_test = classi_test[i:i+1].to(device)
                pointi_test = point_test[i].to(device)
                framei_test = frame_test[i].to(device)
                if train_mode=='mtl':
                    pred_cls_test,pred_pot_test,_,pred_frm_test,pred_seg_test = model(video_test)
                    pred_frm_test = torch.tanh(pred_frm_test)
                    mae_pot_test = mae_point(pred_pot_test,pointi_test)
                    mae_frm_test = mae_point(pred_frm_test,framei_test)
                    dice_coeff, jaccard_coeff, f_beta_coeff = coefficients(labelv_test,pred_seg_test)  # compute metrics
                    mae_all_frm+=float(mae_frm_test)
                    mae_all_pot+=float(mae_pot_test)
                    dice = dice + float(dice_coeff)
                    tr.set_description('batch= %i' % i)
                    tr.set_postfix(class_ori=classiv_test,class_pred=float(torch.max(pred_cls_test, 1)[1]),
                                   mae_pot=float(mae_pot_test),mae_frm=float(mae_frm_test),dice=float(dice_coeff))
                    if float(classiv_test) == float(torch.max(pred_cls_test, 1)[1]):
                        right = right+1
                    else:
                        error = error+1

            if train_mode=='mtl':
                # These are test-set metrics, so they are labelled as such.
                print('epoch_test_class=',right/(right+error),'epoch_test_point=',mae_all_pot/lens2,
                      'epoch_test_frame=',mae_all_frm/lens2,'epoch_test_dice=',dice/lens2)

    if train_mode == 'mtl':
        # Keep the checkpoint with the best mean test Dice seen so far.
        seg_acc =   dice/lens2
        if seg_acc >= dice_less:
            print('save model')
            dice_less = seg_acc
            torch.save(model.state_dict(), './weight/mlt_weights.pth')
        else:
            print('not save,the best cor is:',dice_less)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

In [None]:
# training set
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OrdinalEncoder
import numpy.matlib
from torchvision.transforms import transforms
from PIL import Image
import matplotlib.pyplot as plt
import cv2
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import timeit
import os
from tqdm import tqdm
from tqdm import trange
import os
import sys
import random
import itertools
import colorsys

import numpy as np
from skimage.measure import find_contours
import matplotlib.pyplot as plt
from matplotlib import patches,  lines
from matplotlib.patches import Polygon
import IPython.display

from torchvision.transforms import transforms
from PIL import Image
import matplotlib.pyplot as plt
import cv2
unloader = transforms.ToPILImage()

def tensor_to_img(tensor_pred, data_num, frame_num, title=None):
    """Convert an image tensor into a 256x256 NumPy image.

    The tensor is passed through the module-level ``unloader``
    (``transforms.ToPILImage()``), turned into an array and resized with
    area interpolation. ``data_num``, ``frame_num`` and ``title`` are accepted
    but not used.
    """
    as_array = np.array(unloader(tensor_pred))
    return cv2.resize(as_array, (256, 256), interpolation=cv2.INTER_AREA)

def tensor_to_lb(tensor_pred, data_num, frame_num, title=None):
    """Turn raw segmentation logits into a 256x256 binary mask image.

    Applies a sigmoid, thresholds at 0.5, converts the result to a float32
    CPU copy, renders it through the module-level ``unloader`` and resizes
    it with area interpolation. ``data_num``, ``frame_num`` and ``title`` are
    accepted but not used.
    """
    binary = (torch.sigmoid(tensor_pred) > 0.5).type(torch.float32)
    # Work on a CPU copy so the caller's tensor is left untouched.
    mask_image = unloader(binary.cpu().clone())
    return cv2.resize(np.array(mask_image), (256, 256), interpolation=cv2.INTER_AREA)


def random_colors(N, bright=True):
    """Return ``N`` RGB tuples in shuffled order.

    Hues are spaced evenly over the HSV colour wheel at full saturation
    (value 1.0 when ``bright`` is true, 0.7 otherwise), converted to RGB
    and then shuffled with the global ``random`` generator.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(step / N, 1, value) for step in range(N)]
    random.shuffle(palette)
    return palette


def apply_mask(image, mask, color, alpha=0.4):
    """Overlay ``mask`` on a copy of ``image`` and draw its outline.

    Pixels where ``mask >= 0.5`` are blended with ``color`` (components
    multiplied by 255) using opacity ``alpha``; pixels on the dilated Canny
    outline of the mask are set to the solid colour. The input image is not
    modified; the tinted copy is returned.
    """
    image = image.copy()
    outline = cv2.Canny(mask.astype(np.uint8), 0, 1)
    kernel = np.ones((2, 2), dtype=np.uint8)
    # The third positional parameter of cv2.dilate is `dst`, not the iteration
    # count, so the count has to be passed by keyword.
    outline = cv2.dilate(outline, kernel, iterations=1)
    inside = mask >= 0.5
    on_outline = outline >= 0.5
    for c in range(3):
        image[:, :, c] = np.where(inside,
                                  image[:, :, c] *
                                  (1 - alpha) + alpha * color[c] * 255,
                                  image[:, :, c])
    for c in range(3):
        image[:, :, c] = np.where(on_outline,
                                  color[c] * 255,
                                  image[:, :, c])
    return image

# Module-level state for the visualization cell.
unloader = transforms.ToPILImage()  # tensor -> PIL converter used by the tensor_* helpers
train_mode = 'mtl'  # passed as `mode` when the model is rebuilt further down
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process
torch.backends.cudnn.benchmark = True
l_r = 0.0002  #0.0002
device = torch.device("cuda")
def fill_contour(img):
    """Set the bounding box of every non-largest contour to 1, in place.

    Contours are extracted with ``cv2.findContours``. When there is exactly
    one contour nothing is changed; otherwise every contour whose area is
    strictly smaller than the largest area has its bounding rectangle
    written as 1 into ``img``. The (mutated) ``img`` is returned.
    """
    contours, _ = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 1:
        return img

    areas = [cv2.contourArea(contour) for contour in contours]
    # The running maximum starts from 0, so 0 takes part in the comparison.
    largest = max([0] + areas)
    for contour, area in zip(contours, areas):
        if area < largest:
            x, y, w, h = cv2.boundingRect(contour)
            img[y:y + h, x:x + w] = 1
    return img

def tensor_im(tensor_pred, data_num, frame_num, title=None):
    """Convert a tensor to a NumPy image without thresholding or resizing.

    The tensor is cast to float32, copied to the CPU, rendered through the
    module-level ``unloader`` and returned as an array. ``data_num``,
    ``frame_num`` and ``title`` are accepted but not used.
    """
    cpu_copy = tensor_pred.type(torch.float32).cpu().clone()
    return np.array(unloader(cpu_copy))
    
def tensor_save(tensor_pred, data_num, frame_num, title=None):
    """Turn raw logits into a binary mask image at the tensor's own size.

    Applies a sigmoid, thresholds at 0.5, converts the result to a float32
    CPU copy, renders it through the module-level ``unloader`` and returns it
    as a NumPy array. ``data_num``, ``frame_num`` and ``title`` are accepted
    but not used.

    The two structuring elements the previous version built on every call
    were never applied to the mask and have been removed.
    """
    binary = torch.gt(torch.sigmoid(tensor_pred), 0.5).type(torch.float32)
    # Work on a CPU copy so the caller's tensor is left untouched.
    return np.array(unloader(binary.cpu().clone()))

def landmark(center_x, center_y, IMAGE_HEIGHT, IMAGE_WIDTH):
    """Build a distance-decay heat map centred on one landmark.

    The returned array has shape ``(IMAGE_HEIGHT, IMAGE_WIDTH)``. The value at
    row ``y``, column ``x`` is ``exp(-0.5 * d / sqrt(2))`` where ``d`` is the
    Euclidean distance from ``(x, y)`` to ``(center_x, center_y)``, so the map
    is 1.0 at the landmark and falls off exponentially with distance (not
    with squared distance, despite the Gaussian-style naming used by callers).

    Implemented with NumPy broadcasting instead of ``numpy.matlib.repmat``
    tiling; the computed values are the same.
    """
    R = np.sqrt(2)  # same value as the former np.sqrt(1**1 + 1**1)
    x_map = np.arange(IMAGE_WIDTH)[np.newaxis, :]   # column index, shape (1, W)
    y_map = np.arange(IMAGE_HEIGHT)[:, np.newaxis]  # row index, shape (H, 1)
    distance = np.sqrt((x_map - center_x) ** 2 + (y_map - center_y) ** 2)
    return np.exp(-0.5 * distance / R)

def norm(img):
    """Return a float32 copy of ``img`` with zero mean and unit variance.

    The mean and standard deviation are taken over all elements; ``1e-12`` is
    added to the standard deviation so constant input does not divide by
    zero. The input itself is not modified.
    """
    standardized = np.array(img, dtype=np.float32)
    standardized -= standardized.mean()
    standardized /= (standardized.std() + 1e-12)
    return standardized
    
def locmap(pot):
    """Render the first four landmarks of every frame as normalised heat maps.

    ``pot`` is indexed as ``pot[i, j, k, 0]`` / ``pot[i, j, k, 1]`` for
    ``k`` in 0..3; each pair is handed to ``landmark`` (as centre x, centre y)
    on a 128x128 grid, and the four maps of a frame are standardised together
    by ``norm``. The result is a NumPy array indexed
    ``[i, j, landmark, row, col]``.
    """
    batch_maps = [
        [
            norm([
                landmark(pot[sample, step, k, 0], pot[sample, step, k, 1], 128, 128)
                for k in range(4)
            ])
            for step in range(pot.shape[1])
        ]
        for sample in range(pot.shape[0])
    ]
    # The five-axis slice is kept from the original: it requires a 5-D result.
    return np.array(batch_maps)[:, :, :, :, :]

def point_color(potmap):
    """Paint four landmark maps onto a 128x128 RGB canvas.

    For each of the first four channels of ``potmap``, pixels whose value is
    at least 10 are painted with that channel's colour (yellow, cyan, magenta,
    red, in that order). Later channels overwrite earlier ones where they
    overlap. Returns an integer array of shape ``(128, 128, 3)``.
    """
    threshold = 10
    opacity = 1
    palette = ([1, 1, 0], [0, 1, 1], [1, 0, 1], [1, 0, 0])
    canvas = np.zeros((128, 128, 3), dtype=int)
    for channel, rgb in enumerate(palette):
        hit = potmap[channel, :, :] >= threshold
        for c, component in enumerate(rgb):
            canvas[:, :, c] = np.where(hit, opacity * component * 255,
                                       canvas[:, :, c])
    return canvas
    
# Rebuild the network, load a saved checkpoint and plot predictions for a few
# test samples.
# NOTE(review): out_channels is 128 here, while the training cell above builds
# the model with out_channels=16, and the checkpoint name differs from the one
# that cell saves ('./weight/mlt_weights.pth') — confirm they match.
model = Multivit_net(img_dim=128,in_channels=1,out_channels=128,head_num=4,mlp_dim=512,block_num=8,
                     patch_dim=16,class_num=1,drop_rate = 0.2,seq_frame=30,mode =train_mode,height=128,weight=128).to(device)
model.load_state_dict(torch.load('./weight/mlt_weights_beta0.5.pth'),True)
# NOTE(review): the model is left in train mode for inference (no model.eval());
# presumably this keeps dropout active — confirm this is intentional.
model.train()
size = 128
# Move the whole test set (images, masks, points, frame labels) to the device.
ki = image_test.to(device)
kg = label_test.to(device)
kp = point_test.to(device)
kf = frame_test.to(device)
# Inspect test samples 99..108, one figure per sample.
for i in range(99,109):
    a = ki[i:i+1]  # input clip
    d = kg[i:i+1]  # ground-truth segmentation
    c = kp[i:i+1]  # ground-truth points
    e = kf[i:i+1]  # ground-truth frame labels
    plt.subplots(figsize=(15,40))
    color_lvla = [(1,0,0),(0,1,0)]
    with torch.no_grad():
        # Forward outputs, in the order the training cell unpacks them:
        # classification, points, point map, frame, segmentation.
        b,b1,b2,b3,b4 = model(a)
        #torch.tanh(b)
        print(a.shape,c.shape,b.shape,b1.shape,b2.shape,b3.shape,b4.shape)
        print(b1[0,0,:,:],c[0,0,:,:])
        print(e,b3)
        # Index j runs over the first two positions of axis 1 (presumably frames).
        # NOTE(review): every pass draws into the same six subplot slots, so the
        # second pass presumably replaces what the first one drew — confirm.
        for k in range(2):
            j = k
            # Panel 1: input image, rescaled for display with alp/bet.
            plt.subplot(161)
            img = tensor_im(a[0,j,0,:,:].to(torch.uint8),0,1).astype(np.uint8)
            alp,bet = 53,11  # display gain/offset; origin of the values not shown here
            plt.imshow(img*alp+bet,cmap = 'gray')
            plt.axis('off')
            # Grey image replicated into three channels as the overlay background.
            image_ori = np.zeros((128,128,3),dtype=int)
            image_ori[:,:,0] = (img*alp+bet)[:,:]
            image_ori[:,:,1] = (img*alp+bet)[:,:]
            image_ori[:,:,2] = (img*alp+bet)[:,:]
            # Panel 2: ground-truth mask overlaid in green.
            plt.subplot(162)
            gt = tensor_save(d[0,j,0,:,:].to(torch.float32),0,1)
            gt = apply_mask(image_ori,gt,color_lvla[1])
            plt.imshow(gt)
            plt.axis('off')
            # Panel 3: predicted mask overlaid in green.
            plt.subplot(163)
            pred = tensor_save(b4[0,j,0,:,:].to(torch.float32),0,1)
            pred = apply_mask(image_ori,pred,color_lvla[1])
            plt.imshow(pred)
            plt.axis('off')
            # Panel 4: ground-truth landmark maps.
            plt.subplot(164)
            x = locmap(np.array(c.cpu()))[0,j,:,:]
            # max_index / max_value are only referenced by the commented-out print below.
            max_index = np.unravel_index(np.argmax(x, axis=None), x.shape)
            max_value = x[max_index]
            #print(max_index,max_value)
            comap = point_color(locmap(np.array(c.cpu()))[0,j,:,:])
            plt.imshow(comap)
            plt.axis('off')
            # Panel 5: predicted landmark maps.
            plt.subplot(165)
            b1map = point_color(locmap(np.array(b1.cpu()))[0,j,:,:])
            plt.imshow(b1map)
            plt.axis('off')
            # Panel 6: the input image with a colormap chosen by the frame label.
            plt.subplot(166)
            b1map = point_color(locmap(np.array(b1.cpu()))[0,j,:,:])  # recomputed but not displayed
            # plt.imshow(b1map)
            # plt.axis('off')
            peak_list = np.array(e.cpu()[0])
            # print(peak_list.shape)
            # Label -1 -> 'PuRd_r', 0 -> default colormap, anything else -> 'BrBG_r'.
            if peak_list[j] == -1:
                plt.imshow(img,cmap = 'PuRd_r')
            elif peak_list[j] == 0:
                plt.imshow(img)
            else:
                plt.imshow(img,cmap = 'BrBG_r')
            plt.axis('off')