# Import Lib

In [None]:
import pandas as pd
import torch, io
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torcheval.metrics import AUC
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os, cv2
from PIL import Image
from tempfile import TemporaryDirectory
from torch.utils.tensorboard import SummaryWriter
from metrics_compute_plot import PR_plot_CV, binary_auc_plot_cv
import torch.nn.functional as F
from functools import partial
import SimpleITK as sitk
from sklearn.metrics import roc_auc_score
from skimage.transform import resize
from tqdm import tqdm
from prefetch_generator import BackgroundGenerator
from livelossplot import PlotLosses
from sklearn.model_selection import StratifiedShuffleSplit,ShuffleSplit
from ctviewer import CTViewer


# Check the device

In [None]:
# Select the compute device: CUDA (pinned to GPU index 0) when available,
# otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # Make GPU index 0 the current CUDA device for this process.
    torch.cuda.set_device(0)

print("Using device:", device)

# General Methods
## Data Loader with batchsize in memory

In [None]:
# Let cuDNN benchmark and cache convolution algorithms for the shapes it sees.
cudnn.benchmark = True
# Switch matplotlib to interactive mode.
plt.ion()  # interact

class PrefetchDataLoader(torch.utils.data.DataLoader):
    """Drop-in DataLoader whose iterator is wrapped in a BackgroundGenerator.

    Use it wherever a regular ``DataLoader`` would be used.
    """

    def __iter__(self):
        base_iterator = super().__iter__()
        return BackgroundGenerator(base_iterator)
def label_smooth(label, C, epsilon, is_onehot=True):
    """
    Smooths the labels, commonly used in machine learning and deep learning to address label noise issues.
    Args:
        label: Labels, of type np.ndarray, with shape (batch_size, C)
        C: Number of classes
        epsilon: Smoothing factor, should be greater than or equal to 0 and less than or equal to 1
        is_onehot: Whether the labels are in one-hot encoding, default is True
    Returns:
        smooth_labels: Smoothed labels, of type np.ndarray, with the same shape as label
    """
    assert epsilon >= 0.0 and epsilon <= 1.0, "epsilon should be in [0.0, 1.0]"
    confidence = 1.0 - epsilon
    #print(label.shape)
    # one-hot label
    if label.shape[-1] == C and is_onehot is True:
        smooth_labels = label * confidence + epsilon / C

    # index label
    else:
        # Convert index labels to one-hot labels
        eye_matrix = np.eye(C)
        onehot_labels = eye_matrix[label]
        smooth_labels = onehot_labels * confidence + epsilon / C
    return smooth_labels#smooth_labels.astype(np.float64)

class CELoss_mixup(nn.Module):
    """Cross-entropy loss with optional label smoothing or soft (mix-up) targets.

    Exactly one of three modes is used, chosen in this priority order:

    1. ``label_smooth`` is not None: integer targets are one-hot encoded and
       clamped into ``[label_smooth / (class_num - 1), 1 - label_smooth]``.
    2. ``mix_up`` is not None: ``target`` is used directly as a soft-label
       distribution with the same shape as ``pred``.
    3. otherwise: standard cross entropy on integer class indices.

    Args:
        label_smooth: Smoothing factor, or None to disable label smoothing.
        class_num: Number of classes (used for one-hot encoding).
        mix_up: Any non-None value enables the soft-target branch when
            ``label_smooth`` is None.
    """

    def __init__(self, label_smooth=None, class_num=2, mix_up=None):
        super().__init__()
        self.label_smooth = label_smooth
        self.class_num = class_num
        self.mix_up = mix_up

    def forward(self, pred, target):
        """Return the mean loss over the batch.

        Args:
            pred: Model logits of shape [N, M].
            target: Class indices of shape [N] (label-smoothing and standard
                modes) or soft labels of shape [N, M] (mix-up mode).
        """
        if self.label_smooth is not None:
            # Cross entropy against smoothed one-hot targets.
            logprobs = F.log_softmax(pred, dim=1)
            target = F.one_hot(target, self.class_num)
            target = torch.clamp(
                target.float(),
                min=self.label_smooth / (self.class_num - 1),
                max=1.0 - self.label_smooth,
            )
            loss = -1 * torch.sum(target * logprobs, 1)
        elif self.mix_up is not None:
            # Cross entropy against caller-provided soft targets.
            logprobs = F.log_softmax(pred, dim=1)
            loss = -1 * torch.sum(target * logprobs, 1)
        else:
            # Standard cross entropy. logsumexp stays finite for large logits
            # (a plain log(sum(exp(.))) overflows to inf), and squeezing the
            # gathered column keeps the per-sample loss at shape [N] instead
            # of broadcasting to [N, N].
            picked = pred.gather(1, target.unsqueeze(-1)).squeeze(-1)
            loss = torch.logsumexp(pred, dim=1) - picked

        return loss.mean()


class Dataset_loader(torch.utils.data.Dataset):
    """Dataset that loads one arterial-phase (AP) NIfTI volume per table row."""

    def __init__(self, df_file, transform=None):
        """
        Args:
            df_file: Table with one row per case; rows are addressed
                positionally with ``.iloc``.
            transform (callable, optional): Optional preprocessing applied to
                the image array before it is converted to a tensor.
        """
        self.dataframe = df_file
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_array)`` for row ``idx``."""
        depth = 9

        # The volume lives under <root>/<column 3>/<basename of column 2>/AP.nii.gz.
        top_dir = str(self.dataframe.iloc[idx, 3])
        case_dir = os.path.basename(self.dataframe.iloc[idx, 2])
        volume = sitk.GetArrayFromImage(sitk.ReadImage('/home/u20210110/jupyterlab/HAIC_TACE/Processed_data_cut/'+top_dir+'/'+case_dir+ '/AP.nii.gz'))

        # Only crop when the volume has at least `depth` slices.
        if len(volume) // depth > 0:
            middle = len(volume) // 2
            half_window = depth // 2
            # NOTE(review): this window spans 2 * (depth // 2) = 8 slices; the
            # resize below interpolates it to `depth` = 9 — confirm intended.
            volume = volume[middle - half_window:middle + half_window, :, :]
        # (An earlier strided-sampling variant over the whole volume was disabled.)

        volume = resize(volume, (depth, 224, 224))
        # Stack the `depth` slices vertically into a single-channel image.
        volume = volume.reshape(1, depth * 224, 224)
        image = (volume - 0.) / 255

        # The label is read from the 8th column counted from the end.
        label = np.array([self.dataframe.iloc[idx, -8]])

        if self.transform:
            image = self.transform(image)
        return torch.Tensor(image), label

In [None]:
import time
import math
from functools import partial
from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
# Optional selective-scan kernels from mamba_ssm. The import is best-effort:
# if it fails, selective_scan_fn stays undefined and SS2D.forward_corev0
# raises NameError when it is first called.
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except:
    pass

# an alternative for mamba_ssm (in which causal_conv1d is needed)
# Also best-effort; SS2D.forward_corev1 needs selective_scan_fn_v1.
try:
    from selective_scan import selective_scan_fn as selective_scan_fn_v1
    from selective_scan import selective_scan_ref as selective_scan_ref_v1
except:
    pass

# Give timm's DropPath a compact repr that shows its drop probability.
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"




class PatchEmbed2D(nn.Module):
    r""" Image to Patch Embedding
    Args:
        patch_size1 (int): Patch height (kernel and stride along H). Default: 4.
        patch_size2 (int): Patch width (kernel and stride along W). Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    Raises:
        TypeError: If ``patch_size1`` or ``patch_size2`` is not an int.
    """
    def __init__(self, patch_size1=4,patch_size2=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        # Fail with a clear message instead of an unbound-variable error when
        # a non-int patch size is supplied.
        if not (isinstance(patch_size1, int) and isinstance(patch_size2, int)):
            raise TypeError(
                "patch_size1 and patch_size2 must be ints, got "
                f"{type(patch_size1).__name__} and {type(patch_size2).__name__}"
            )
        patch_size = (patch_size1, patch_size2)
        # Non-overlapping patches: the stride equals the kernel size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        # (B, C, H, W) -> (B, H', W', embed_dim), i.e. channels-last tokens.
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchMerging2D(nn.Module):
    r""" Patch Merging Layer.

    Halves the spatial grid of a channels-last feature map by concatenating
    each 2x2 neighbourhood along the channel axis (C -> 4C), normalizing, and
    projecting to 2C channels.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Map (B, H, W, C) to (B, H//2, W//2, 2*C)."""
        B, H, W, C = x.shape
        out_h, out_w = H // 2, W // 2

        odd_grid = (W % 2 != 0) or (H % 2 != 0)
        if odd_grid:
            print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)

        # The four members of every 2x2 neighbourhood, in the order
        # (even row, even col), (odd, even), (even, odd), (odd, odd).
        corners = [
            x[:, row::2, col::2, :]
            for row, col in ((0, 0), (1, 0), (0, 1), (1, 1))
        ]

        # With an odd H or W the strided views differ in size; crop them all
        # to the common floor-divided grid.
        if odd_grid and out_h > 0:
            corners = [part[:, :out_h, :out_w, :] for part in corners]

        merged = torch.cat(corners, -1)  # B H/2 W/2 4*C
        merged = merged.view(B, out_h, out_w, 4 * C)

        return self.reduction(self.norm(merged))
    



class SS2D(nn.Module):
    """2-D selective-scan block operating on channels-last feature maps.

    The input (B, H, W, d_model) is projected to two d_inner-wide branches.
    One branch goes through a depthwise convolution and is scanned along four
    flattened orderings (row-major, column-major and the reverse of each); the
    four scan outputs are summed, normalized, gated by SiLU of the other
    branch and projected back to d_model.
    """
    def __init__(
        self,
        d_model,
        d_state=16,
        # d_state="auto", # 20240109
        d_conv=3,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        # Single projection producing both the scanned branch and the gate.
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        # Depthwise convolution (groups == channels) with size-preserving padding.
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        # One projection per scan direction; only the stacked weights are kept
        # as a parameter and the temporary Linear modules are discarded.
        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
        )
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
        del self.x_proj

        # Same pattern for the dt projections. Here `self.dt_init` is the
        # static method; the plain name `dt_init` is the init-mode argument.
        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
        del self.dt_projs
        
        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K * D, N) after merge
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K * D,) after merge

        # self.selective_scan = selective_scan_fn
        # Scan implementation used by forward(); v0 relies on mamba_ssm's selective_scan_fn.
        self.forward_core = self.forward_corev0

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        """Build the dt projection (dt_rank -> d_inner) with its special init.

        Raises NotImplementedError for a ``dt_init`` other than "constant" or
        "random".
        """
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True
        
        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        """Return log(A) as a parameter, with A[d, n] = n + 1 for every channel d.

        With ``copies > 1`` the tensor is repeated along a new leading axis,
        which is flattened into the channel axis when ``merge`` is True.
        """
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        """Return the all-ones skip parameter D, repeated/flattened like A_log_init."""
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward_corev0(self, x: torch.Tensor):
        """Run the four-direction selective scan with ``selective_scan_fn``.

        Args:
            x: Tensor of shape (B, C, H, W).

        Returns:
            Four tensors of shape (B, C, H*W), all expressed in row-major
            order: the row-major scan, its reversed-direction scan, the
            column-major scan and its reversed-direction scan.
        """
        self.selective_scan = selective_scan_fn
        
        B, C, H, W = x.shape
        L = H * W
        K = 4

        # Direction 0: row-major flattening; direction 1: column-major (H/W transposed).
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        # Directions 2 and 3: the same two sequences reversed.
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)

        # Per-direction projection to (dt, B, C) inputs of the scan.
        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)

        xs = xs.float().view(B, -1, L) # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1) # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)

        # The dt bias is passed as delta_bias rather than added to dts above.
        out_y = self.selective_scan(
            xs, dts, 
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        # Undo the reversal of directions 2/3, then undo the H/W transpose of
        # the column-major outputs so everything is row-major again.
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    # an alternative to forward_corev0 that uses selective_scan_fn_v1
    def forward_corev1(self, x: torch.Tensor):
        """Same as ``forward_corev0`` but calls ``selective_scan_fn_v1``.

        The kernel is invoked without the ``z`` and ``return_last_state``
        keyword arguments.
        """
        self.selective_scan = selective_scan_fn_v1

        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)

        xs = xs.float().view(B, -1, L) # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1) # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)

        out_y = self.selective_scan(
            xs, dts, 
            As, Bs, Cs, Ds,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, x: torch.Tensor, **kwargs):
        """Apply the block to a channels-last map (B, H, W, d_model) -> same shape."""
        B, H, W, C = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1) # (b, h, w, d)

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x)) # (b, d, h, w)
        y1, y2, y3, y4 = self.forward_core(x)
        assert y1.dtype == torch.float32
        # Merge the four scan directions by summation.
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        # Gate with the second half of the input projection.
        y = y * F.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out


class ConvSSM(nn.Module):
    """Hybrid block: half of the channels go through convolutions, half through SS2D.

    The channels-last input is split in two along the channel axis. The first
    half passes through a 3x3 -> 3x3 -> 1x1 convolution stack; the second half
    passes through LayerNorm + SS2D with a residual connection. Both halves
    are concatenated, mixed by a 1x1 convolution and added to the block input.
    """

    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        attn_drop_rate: float = 0,
        d_state: int = 16,
        **kwargs,
    ):
        super().__init__()
        half_dim = hidden_dim // 2

        # Selective-scan branch.
        self.ln_1 = norm_layer(half_dim)
        self.self_attention = SS2D(d_model=half_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs)
        self.drop_path = DropPath(drop_path)

        # Convolutional branch.
        conv_stack = [
            nn.Conv2d(in_channels=half_dim, out_channels=half_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(half_dim),
            nn.ReLU(),
            nn.Conv2d(in_channels=half_dim, out_channels=half_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(half_dim),
            nn.ReLU(),
            nn.Conv2d(in_channels=half_dim, out_channels=half_dim, kernel_size=1, stride=1),
        ]
        self.conv33conv33conv11 = nn.Sequential(*conv_stack)

        # 1x1 convolution that mixes the concatenated branches.
        self.finalconv11 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1)

    def forward(self, input: torch.Tensor):
        """Map a channels-last tensor (B, H, W, hidden_dim) to the same shape."""
        conv_half, ssm_half = input.chunk(2, dim=-1)

        # Residual selective-scan branch (stays channels-last until the end).
        ssm_out = ssm_half + self.drop_path(self.self_attention(self.ln_1(ssm_half)))

        # Convolution branch works channels-first.
        conv_out = self.conv33conv33conv11(conv_half.permute(0, 3, 1, 2).contiguous())

        ssm_out = ssm_out.permute(0, 3, 1, 2).contiguous()
        fused = self.finalconv11(torch.cat((conv_out, ssm_out), dim=1))
        fused = fused.permute(0, 2, 3, 1).contiguous()
        return fused + input


class VSSLayer(nn.Module):
    """ One stage of the backbone: a stack of ConvSSM blocks plus optional downsampling.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        attn_drop (float, optional): Dropout rate forwarded to each block. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate; a list gives one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        d_state (int): State size forwarded to each block. Default: 16
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self, 
        dim, 
        depth, 
        attn_drop=0.,
        drop_path=0., 
        norm_layer=nn.LayerNorm, 
        downsample=None, 
        use_checkpoint=False, 
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        per_block_rate = isinstance(drop_path, list)
        self.blocks = nn.ModuleList()
        for block_idx in range(depth):
            self.blocks.append(
                ConvSSM(
                    hidden_dim=dim,
                    drop_path=drop_path[block_idx] if per_block_rate else drop_path,
                    norm_layer=norm_layer,
                    attn_drop_rate=attn_drop,
                    d_state=d_state,
                )
            )

        # "Fake" initialization: it only touches a detached clone of every
        # out_proj.weight, so no parameter changes, but the random number
        # generator advances exactly as a real init would (keeps seeds aligned).
        def _init_weights(module: nn.Module):
            for param_name, param in module.named_parameters():
                if param_name == "out_proj.weight":
                    scratch = param.clone().detach_()
                    nn.init.kaiming_uniform_(scratch, a=math.sqrt(5))

        self.apply(_init_weights)

        self.downsample = (
            downsample(dim=dim, norm_layer=norm_layer) if downsample is not None else None
        )

    def forward(self, x):
        """Run all blocks (optionally checkpointed), then the downsample layer if any."""
        for block in self.blocks:
            x = (
                checkpoint.checkpoint(block, x, use_reentrant=False)
                if self.use_checkpoint
                else block(x)
            )

        return x if self.downsample is None else self.downsample(x)

    


class VSSM(nn.Module):
    """Multi-stage ConvSSM backbone with a linear head and class prototypes.

    Every stage output is globally average-pooled and the pooled vectors are
    concatenated into one feature vector. ``forward`` returns both the
    Euclidean distances from that vector to the learnable class prototypes
    and the logits of the linear head.

    Args:
        patch_size1 (int): Patch height of the patch embedding.
        patch_size2 (int): Patch width of the patch embedding.
        in_chans (int): Number of input channels.
        num_classes (int): Number of classes (prototype rows / head outputs).
        depths (list[int]): Number of blocks per stage.
        depths_decoder: Accepted for interface compatibility; unused.
        dims (list[int] | int): Channels per stage. An int is expanded to
            ``dims * 2**i`` for each stage ``i``.
        dims_decoder: Accepted for interface compatibility; unused.
        d_state (int | None): SS2D state size; ``None`` selects
            ``ceil(dims[0] / 6)``.
        drop_rate (float): Dropout applied after patch embedding and on the
            pooled feature vector.
        attn_drop_rate (float): Dropout forwarded to the blocks.
        drop_path_rate (float): Maximum stochastic-depth rate (linearly
            increasing over all blocks).
        norm_layer: Normalization layer class.
        patch_norm (bool): Whether to normalize after patch embedding.
        use_checkpoint (bool): Whether the stages use activation checkpointing.
    """

    def __init__(self, patch_size1=4,patch_size2=4, in_chans=3, num_classes=1000, depths=[2, 2, 2, 2], depths_decoder=[2, 9, 2, 2],
                 dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        # Expand an int `dims` BEFORE anything indexes it; previously the
        # prototype size was computed first and an int `dims` raised TypeError.
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        # Concatenated pooled features: each downsampled stage output plus the
        # last stage output (which has no downsampling).
        self.num_features = sum(dims[1:])+dims[-1]
        self.dims = dims
        self.prototypes = nn.Parameter(torch.randn(num_classes, self.num_features))

        self.patch_embed = PatchEmbed2D(patch_size1=patch_size1,patch_size2=patch_size2, in_chans=in_chans, embed_dim=self.embed_dim,
            norm_layer=norm_layer if patch_norm else None)

        # Absolute position embedding is disabled (ape is hard-coded to False).
        # NOTE(review): the branch below reads patch_embed.patches_resolution;
        # confirm that attribute exists before ever enabling ape.
        self.ape = False
        if self.ape:
            self.patches_resolution = self.patch_embed.patches_resolution
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = VSSLayer(
                dim=dims[i_layer],
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
                drop=drop_rate, 
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                # Every stage except the last halves the resolution.
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        # _init_weights does not cover Conv2d, so initialize those here.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _init_weights(self, m: nn.Module):
        """Initialize Linear and LayerNorm modules.

        Linear weights get a truncated normal (std 0.02) and zero bias;
        LayerNorm gets unit weight and zero bias. Because every nn.Linear is
        re-initialized here, any earlier per-block out_proj initialization is
        overwritten. Conv2d is not handled by this method.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def compute_distances(self, x):
        """Return Euclidean distances (N, num_classes) from features (N, F) to the prototypes."""
        distances = torch.norm(x.unsqueeze(1) - self.prototypes, dim=2)
        return distances

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_backbone(self, x):
        """Return the list of channels-last outputs of every stage."""
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        stage_outputs = []
        for layer in self.layers:
            x = layer(x)
            stage_outputs.append(x)
        return stage_outputs

    def forward(self, x):
        """Return ``(distances, logits)`` for a batch of images."""
        stage_outputs = self.forward_backbone(x)
        # Pool every stage to (B, C_i, 1, 1) and concatenate along channels.
        pooled = self.avgpool(stage_outputs[0].permute(0,3,1,2))
        for single in stage_outputs[1:]:
            single = self.avgpool(single.permute(0,3,1,2))
            pooled = torch.cat((pooled,single),dim=1)
        x = torch.flatten(pooled,start_dim=1)
        x = self.pos_drop(x)
        distances = self.compute_distances(x)
        x = self.head(x)
        return distances,x


# model = VSSM(num_classes=6).to("cuda")
#
# data = torch.randn(1,3,224,224).to("cuda")
#
# print(model(data).shape)


# Accelerate data loading
## Train model function with dataset all in memory for AP

### Data analysis

In [None]:
# Load the cohort spreadsheet and keep only rows where both Phase_A and
# Phase_V equal 1 and the 'proliferative' column is empty (NaN).
df = pd.read_excel(f'/home/u20210110/jupyterlab/HAIC_TACE/Final_V7-Selected.xlsx')
idx = (df['Phase_A']==1)&(df['Phase_V']==1)&(pd.isna(df['proliferative']))
df = df[idx]

In [None]:
# Per-center and overall case counts, split by treatment
# ('HAIC/TACE' == 1 is reported as HAIC, == 0 as TACE).
for i in list(np.unique(df['Center'])):
    print(f"{i} has data {np.sum(df['Center']==i)}: HAIC {np.sum((df['Center']==i)&(df['HAIC/TACE']==1))}, TACE {np.sum((df['Center']==i)&(df['HAIC/TACE']==0))}")
print(f"ALL data {len(df['Center'])}: HAIC {np.sum(df['HAIC/TACE']==1)}, TACE {np.sum(df['HAIC/TACE']==0)}")
# NOTE(review): the totals below are hard-coded sums, not derived from df —
# presumably copied from an earlier run; verify before relying on them.
print(f'Train HAIC {396+129+442+21}, TACE {473+138}')
print(f'test HAIC {156+50+3+38}, TACE {28+33+12+132+10}')

In [None]:
# Treat "不易测量" entries of the tumour-diameter column as 0 so the numeric
# filter below can be applied.
df.loc[df['最大肿瘤直径']=='不易测量','最大肿瘤直径'] = 0
# Keep rows with 'HAIC/TACE' == 0 and diameter >= 5. Copy so the label
# assignments below write to an independent frame rather than a slice.
df = df[(df['HAIC/TACE']==0)&(df['最大肿瘤直径']>=5)].copy()

# Objective response (OR): CR/PR -> 1, PD/SD -> 0.
# The raw column mixes case and trailing spaces ('CR ', 'pR', 'Sd', ...), so
# compare on a stripped, upper-cased copy. This also maps variants such as
# 'PD ' that an explicit list of spellings misses.
response_col = '疗效评估 (PD or PR or CR)'
response_code = df[response_col].astype(str).str.strip().str.upper()
df.loc[response_code.isin(['CR', 'PR']), response_col] = 1
df.loc[response_code.isin(['PD', 'SD']), response_col] = 0
# Disease control (DCR) alternative, currently disabled:
# CR/PR/SD -> 1 and PD -> 0.

In [None]:
# Per-center and overall counts as "<cases>:<sum of the response column>"
# for each treatment group, after label cleaning.
for i in list(np.unique(df['Center'])):
    print(f"{i} has data {np.sum(df['Center']==i)}: HAIC {np.sum((df['Center']==i)&(df['HAIC/TACE']==1))}:{np.sum(df[(df['Center']==i)&(df['HAIC/TACE']==1)]['疗效评估 (PD or PR or CR)'])}, TACE {np.sum((df['Center']==i)&(df['HAIC/TACE']==0))}:{np.sum(df[(df['Center']==i)&(df['HAIC/TACE']==0)]['疗效评估 (PD or PR or CR)'])}")
print(f"ALL data {len(df['Center'])}: HAIC {np.sum(df['HAIC/TACE']==1)}:{np.sum(df[df['HAIC/TACE']==1]['疗效评估 (PD or PR or CR)'])}, TACE {np.sum(df['HAIC/TACE']==0)}:{np.sum(df[df['HAIC/TACE']==0]['疗效评估 (PD or PR or CR)'])}")


In [None]:
# Pass 1: drop every case whose AP or VP entry is listed in the Problem folder.
del_list = os.listdir('/home/u20210110/jupyterlab/HAIC_TACE/Problem')
index_selected = []
for idx in range(len(df)):
    case_prefix = str(df.iloc[idx, 3])+'_'+os.path.basename(df.iloc[idx, 2])
    vp_name = case_prefix+'_0000.nii.gz_VP'
    ap_name = case_prefix+'_0000.nii.gz_AP'
    index_selected.append(not (vp_name in del_list or ap_name in del_list))
df = df[index_selected]

# Pass 2: drop every case whose processed AP volume is missing on disk.
index_selected = []
for idx in range(len(df)):
    file_name = '/home/u20210110/jupyterlab/HAIC_TACE/Processed_data_cut/'+str(df.iloc[idx, 3])+'/'+os.path.basename(df.iloc[idx, 2])+ '/AP.nii.gz'
    index_selected.append(os.path.exists(file_name))
df = df[index_selected]


In [None]:
# Same per-center "<cases>:<sum of the response column>" report, repeated
# after the problem-case and missing-file filtering.
for i in list(np.unique(df['Center'])):
    print(f"{i} has data {np.sum(df['Center']==i)}: HAIC {np.sum((df['Center']==i)&(df['HAIC/TACE']==1))}:{np.sum(df[(df['Center']==i)&(df['HAIC/TACE']==1)]['疗效评估 (PD or PR or CR)'])}, TACE {np.sum((df['Center']==i)&(df['HAIC/TACE']==0))}:{np.sum(df[(df['Center']==i)&(df['HAIC/TACE']==0)]['疗效评估 (PD or PR or CR)'])}")
print(f"ALL data {len(df['Center'])}: HAIC {np.sum(df['HAIC/TACE']==1)}:{np.sum(df[df['HAIC/TACE']==1]['疗效评估 (PD or PR or CR)'])}, TACE {np.sum(df['HAIC/TACE']==0)}:{np.sum(df[df['HAIC/TACE']==0]['疗效评估 (PD or PR or CR)'])}")


In [None]:
# Split by center: three centers form the training set, four the validation set.
df_train = df[df['Center'].isin(['SYSUCC', 'SYSU_first', 'JNU'])]
df_val = df[df['Center'].isin(['CHCAMS', 'SYSU_third', 'Gaofei', 'LUHE'])]
print(f"HIAC Train ORR:{df_train['疗效评估 (PD or PR or CR)'].values.sum()}, Non-ORR:{len(df_train)-df_train['疗效评估 (PD or PR or CR)'].values.sum()}")
print(f"HIAC Val ORR:{df_val['疗效评估 (PD or PR or CR)'].values.sum()}, Non-ORR:{len(df_val)-df_val['疗效评估 (PD or PR or CR)'].values.sum()}")

# Oversample responders (label == 1): two extra copies in the training set.
train_positives = df_train[df_train['疗效评估 (PD or PR or CR)']==1]
df_train = pd.concat([df_train, train_positives, train_positives])
print(f"Ague HIAC Train ORR:{df_train['疗效评估 (PD or PR or CR)'].values.sum()}, Non-ORR:{len(df_train)-df_train['疗效评估 (PD or PR or CR)'].values.sum()}")

# ...and three extra copies in the validation set.
# NOTE(review): duplicating validation rows changes count-based validation
# metrics — confirm this is intended.
val_positives = df_val[df_val['疗效评估 (PD or PR or CR)']==1]
df_val = pd.concat([df_val, val_positives, val_positives, val_positives])
# This line reports the validation set; it was previously labelled "Train".
print(f"Ague HIAC Val ORR:{df_val['疗效评估 (PD or PR or CR)'].values.sum()}, Non-ORR:{len(df_val)-df_val['疗效评估 (PD or PR or CR)'].values.sum()}")

In [None]:
# Save the (oversampled) training table; GBK encoding is used for the file.
df_train.to_csv('TACE_train_20240523.csv',index=False,encoding='gbk')

In [None]:
# Number of rows in each split.
dataset_sizes = {'train': len(df_train), 'val': len(df_val)}
print(f'train: {len(df_train)}, val: {len(df_val)}')

In [None]:
# Wrap both splits in Dataset_loader (no transform); volumes are read from
# disk on each __getitem__ call.
train_data = Dataset_loader(df_file=df_train,)
val_data = Dataset_loader(df_file=df_val,)
# The string literal below is disabled code that preloaded each whole split
# into an in-memory TensorDataset using one full-size batch.
'''pre_data_loader = {
        'train': PrefetchDataLoader(train_data, batch_size=len(df_train), shuffle=False, num_workers=4),
        'val': PrefetchDataLoader(val_data, batch_size= len(df_val), shuffle=False, num_workers=4)
        }
for inputs, labels in pre_data_loader['train']:
    train_dataset =  torch.utils.data.TensorDataset(inputs, labels )
for inputs, labels  in pre_data_loader['val']:
    val_dataset =  torch.utils.data.TensorDataset(inputs, labels )    '''
data_set = {'train':train_data,'val':val_data}

# Ray finetune

In [None]:
from ray import train, tune
from ray.train import RunConfig
from ray.tune import Callback
from ray.train import Checkpoint
import tempfile 

from ray.tune.schedulers import ASHAScheduler,AsyncHyperBandScheduler
class MyCallback(Callback):
    """Tune callback that prints the AUC of every reported trial result."""

    def on_trial_result(self, iteration, trials, trial, result, **info):
        """Print the trial, its training iteration and the reported ``auc``."""
        step = result['training_iteration']
        auc_value = result['auc']
        print(f"Got {trial} result of {step}: {auc_value}")
def stop_fn(trial_id: str, result: dict) -> bool:
    """Decide whether a trial should be stopped early.

    A trial is stopped once it has run at least 10 training iterations while
    its reported accuracy is still below 0.6.

    Args:
        trial_id: Identifier of the trial (unused; required by the stopper signature).
        result: Metrics dict reported by the trial.

    Returns:
        True to stop the trial. Results that carry no ``"acc"`` (or no
        ``"training_iteration"``) never trigger a stop instead of raising KeyError.
    """
    acc = result.get("acc")
    if acc is None:
        return False
    return bool(acc < 0.6 and result.get("training_iteration", 0) >= 10)

In [None]:
def train_model(config,data_set):
    """Ray Tune trainable: train one VSSM configuration and report validation metrics.

    Args:
        config: Hyper-parameter dict. Keys read here: ``patch_size``, ``depths``,
            ``dims``, ``d_state``, ``drop_rate``, ``attn_drop_rate``,
            ``drop_path_rate``, ``smooth_label``, ``mix_up``, ``lr``, ``beta``,
            ``eps``, ``ep`` (epochs), ``bs`` (batch size) and ``alpha``
            (mix-up Beta parameter; mix-up is off when ``alpha <= 0``).
        data_set: Dict with ``'train'`` and ``'val'`` datasets yielding
            ``(inputs, labels)`` pairs.

    After every validation phase ``auc``, ``loss`` and ``acc`` are reported to
    Tune. When the validation AUC improves, a ``checkpoint.pt`` holding
    ``epoch``, ``model_state``, ``optimizer_state`` and ``best_auc`` is attached.
    """
    import tempfile  # local import keeps the trainable self-contained on Ray workers

    model = VSSM(patch_size1=9*config["patch_size"],patch_size2=config["patch_size"], in_chans=1, num_classes=2, depths=config["depths"], depths_decoder=[2, 9, 2, 2],
                 dims=config["dims"], dims_decoder=[768, 384, 192, 96], d_state=config["d_state"], drop_rate=config["drop_rate"], attn_drop_rate=config["attn_drop_rate"], drop_path_rate=config["drop_path_rate"],
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=True)
    model = model.to(device)
    criterion_train = CELoss_mixup(label_smooth=config['smooth_label'], class_num=2, mix_up=config['mix_up'])
    criterion_test = CELoss_mixup(label_smooth=config['smooth_label'], class_num=2, mix_up=config['mix_up'])
    optimizer = optim.Adam(model.parameters(), lr=config["lr"], betas=(config["beta"], 0.999), eps=config["eps"])

    since = time.time()
    num_epochs = config["ep"]
    data_loader = {
        'train': PrefetchDataLoader(data_set['train'], batch_size=config["bs"], shuffle=True, num_workers=4),
        'val': PrefetchDataLoader(data_set['val'], batch_size=config["bs"], shuffle=False, num_workers=4),
    }
    # Sizes come from the datasets actually passed in, not from a notebook global
    # that may be stale or missing inside the Ray worker.
    split_sizes = {phase: len(data_set[phase]) for phase in data_loader}

    # Resume from a Tune-provided checkpoint, using the same file name and keys
    # that are written further down.
    start_epoch = 0
    best_auc = 0.0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            state = torch.load(os.path.join(ckpt_dir, "checkpoint.pt"), map_location=device)
        model.load_state_dict(state["model_state"])
        if "optimizer_state" in state:
            optimizer.load_state_dict(state["optimizer_state"])
        start_epoch = state["epoch"] + 1
        best_auc = state.get("best_auc", 0.0)

    for epoch in range(start_epoch, num_epochs):
        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            epoch_scores = []
            epoch_labels = []
            for inputs, labels in data_loader[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # reshape(-1) instead of squeeze(): a final batch holding a single
                # sample must stay 1-D.
                flat_labels = labels.reshape(-1)
                sm_labels = labels
                use_mixup = is_train and config['alpha'] > 0
                if use_mixup:
                    alpha = config['alpha']
                    lam = np.random.beta(alpha, alpha)
                    # Permutation lives on the inputs' device (also works on CPU).
                    index = torch.randperm(inputs.size(0), device=inputs.device)
                    inputs = lam*inputs+(1-lam)*inputs[index,:]
                    # NOTE(review): this mixes the raw label values, whereas
                    # train_again_model mixes one-hot targets; confirm which form
                    # CELoss_mixup expects.
                    sm_labels = lam*sm_labels+(1-lam)*sm_labels[index]

                optimizer.zero_grad()

                # Gradients are tracked only during the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs, x = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    if use_mixup:
                        loss = criterion_train(outputs, sm_labels)
                    else:
                        # Both tensors returned by the model are supervised.
                        loss1 = criterion_test(outputs, flat_labels)
                        loss2 = criterion_test(x, flat_labels)
                        loss = loss1+loss2

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Class-1 probability from the softmax over both logits, so the AUC
                # score agrees with the argmax decision used for accuracy.
                scores = torch.softmax(outputs.detach(), dim=1)[:, 1]
                epoch_scores.extend(scores.cpu().numpy())
                epoch_labels.extend(flat_labels.cpu().numpy())
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == flat_labels).item()

            epoch_loss = running_loss / split_sizes[phase]
            epoch_acc = running_corrects / split_sizes[phase]

            if phase == 'val':
                epoch_auc = roc_auc_score(epoch_labels, epoch_scores)
                metrics = {"auc": epoch_auc, "loss": epoch_loss, "acc": epoch_acc}
                if epoch_auc > best_auc:
                    best_auc = epoch_auc
                    # A per-report temporary directory keeps concurrent trials from
                    # overwriting each other's checkpoint file.
                    with tempfile.TemporaryDirectory() as ckpt_dir:
                        torch.save(
                            {
                                "epoch": epoch,
                                "model_state": model.state_dict(),
                                "optimizer_state": optimizer.state_dict(),
                                "best_auc": best_auc,
                            },
                            os.path.join(ckpt_dir, "checkpoint.pt"),
                        )
                        train.report(metrics, checkpoint=Checkpoint.from_directory(ckpt_dir))
                else:
                    train.report(metrics)

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val AUC: {best_auc:4f}')


In [None]:
import ray
if __name__ == "__main__":
    # Ray Tune search over the VSSM architecture, then score both cohorts with the
    # best checkpoint and plot their ROC curves.
    fold_dir = '/home/u20210110/jupyterlab/HAIC_TACE/2_20240430_finetune_TACE_AP'
    os.makedirs(f'{fold_dir}', exist_ok=True)
    train_pred_lists = []
    train_label_lists = []
    test_pred_lists = []
    test_label_lists = []
    i = 5  # suffix of the exported CSV files
    config = {
        "lr": tune.choice([1e-06]),#tune.loguniform(1e-6, 1e-3),
        "ep": tune.choice([50]),
        "bs": 32,
        "beta": tune.choice([0.9]),#tune.choice([0.1,0.5,0.9]),
        "eps": tune.choice([1e-08]),#tune.choice([1e-08,1e-06,1e-04,1e-02]),
        "smooth_label": None,#tune.choice([0,0.1,0.3]),
        "alpha":tune.choice([0]),#tune.choice([0,0.1,0.5,1,10]),
        "patch_size": tune.choice([4,9]),
        "depths":tune.choice([[2, 2, 2, 2],[1,1,1,1],[2,4,4,2]]),
        "dims": tune.choice([[96, 192, 384, 768],[48,96,192,384],[128,256,512,1024]]),
        "d_state":tune.choice([16,81]),
        "drop_rate":tune.choice([0,0.1,0.3]),
        "attn_drop_rate":tune.choice([0,0.1,0.3]),
        "drop_path_rate":tune.choice([0,0.1,0.3]),
        "mix_up":None,
        }

    scheduler = AsyncHyperBandScheduler(grace_period=5)
    tuner = tune.Tuner(
        tune.with_resources(tune.with_parameters(train_model, data_set=data_set),
            resources={"cpu": 4, "gpu": 1}
        ),
        tune_config=tune.TuneConfig(
            scheduler=scheduler,
            metric="auc",
            mode="max",
            num_samples=30,
        ),
        param_space=config,
        # Every trial is stopped after 10 minutes of wall-clock time.
        run_config=train.RunConfig(stop={"time_total_s": 10*60},callbacks=[MyCallback()])
    )
    result = tuner.fit()
    best_trial = result.get_best_result("auc", "max", "last")
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial path: {best_trial.path}")
    print(f"Best trial final validation auc: {best_trial.metrics['auc']}")

    # Rebuild the best architecture and restore the weights saved by train_model.
    model = VSSM(patch_size1=9*best_trial.config["patch_size"],patch_size2=best_trial.config["patch_size"], in_chans=1, num_classes=2,
                 depths=best_trial.config["depths"], depths_decoder=[2, 9, 2, 2],
                 dims=best_trial.config["dims"], dims_decoder=[768, 384, 192, 96],
                 d_state=best_trial.config["d_state"],
                 drop_rate=best_trial.config["drop_rate"], attn_drop_rate=best_trial.config["attn_drop_rate"], drop_path_rate=best_trial.config["drop_path_rate"],
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=True)
    model.load_state_dict(torch.load(os.path.join(best_trial.checkpoint.path,'checkpoint.pt'), map_location=device)['model_state'])
    model = model.to(device)
    # A freshly built module starts in training mode; switch to eval so the
    # dropout / drop-path layers do not randomise the exported predictions.
    model.eval()

    # Unshuffled loaders so predictions line up row-by-row with df_train / df_val.
    data_loader = {
        'train': torch.utils.data.DataLoader(data_set['train'], batch_size=32, shuffle=False, num_workers=0),
        'val': torch.utils.data.DataLoader(data_set['val'], batch_size=32, shuffle=False, num_workers=0)}

    def _predict_cohort(loader):
        """Return (class-1 probabilities, labels) for all samples of `loader`, in order."""
        cohort_preds, cohort_labels = [], []
        with torch.no_grad():
            for inputs, labels in loader:
                # The model returns two tensors; only the first is used for scoring.
                # (The original train loop unpacked the input tensor instead.)
                outputs, _ = model(inputs.to(device))
                # Softmax over both logits, matching the final evaluation cell.
                probs = nn.Softmax(dim=1)(outputs)
                cohort_preds.extend(probs.cpu().numpy()[:, 1])
                cohort_labels.extend(labels.cpu().numpy())
        return cohort_preds, cohort_labels

    train_pred_list, train_label_list = _predict_cohort(data_loader['train'])
    df_train['pred1'] = train_pred_list
    df_train.to_csv(f'{fold_dir}/train_{i}.csv')
    train_pred_lists.append(train_pred_list)
    train_label_lists.append(train_label_list)

    test_pred_list, test_label_list = _predict_cohort(data_loader['val'])
    df_val['pred1'] = test_pred_list
    df_val.to_csv(f'{fold_dir}/val_{i}.csv')
    test_pred_lists.append(test_pred_list)
    test_label_lists.append(test_label_list)

    # ROC curves for the two cohorts, side by side.
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95,
                    hspace=0.2, wspace=0.2)
    parameters = {'lw': 2, 'alpha_fold': 0.3, 'alpha_ave': 0.8, 'fontsize': 20}

    title = f'Train Cohort'
    ax[0] = binary_auc_plot_cv(ax[0], train_pred_lists, train_label_lists, parameters, title)
    title = f'Test Cohort '
    ax[1] = binary_auc_plot_cv(ax[1], test_pred_lists, test_label_lists, parameters, title)
    fig.savefig(f'{fold_dir}/ROC_CV.png', dpi=200, bbox_inches='tight')

In [None]:
# Display the hyper-parameter dict of the best trial as the notebook cell output.
best_trial.config

In [None]:
if __name__ == "__main__":
    # Ray Tune search over optimiser / regularisation settings with a fixed
    # architecture, then score both cohorts with the best checkpoint.
    fold_dir = '/home/u20210110/jupyterlab/HAIC_TACE/20240430_finetune_AP'
    os.makedirs(f'{fold_dir}', exist_ok=True)
    train_pred_lists = []
    train_label_lists = []
    test_pred_lists = []
    test_label_lists = []
    i = 5  # suffix of the exported CSV files
    config = {
        "lr": tune.loguniform(1e-7, 1e-4),
        "ep": tune.choice([50,100,200]),
        "bs": tune.choice([32]),
        "beta": tune.choice([0.5,0.1,0.9]),
        "eps": tune.choice([1e-08,1e-06,1e-04,1e-02]),
        "smooth_label": tune.choice([0.,0.1,0.3]),
        "alpha":tune.choice([0,0.1,0.5,1,5,10]),
        "patch_size": 9,
        "depths":[2, 4, 4, 2],
        "dims":[128, 256, 512, 1024],
        "d_state":16,
        "drop_rate":0.1,
        "attn_drop_rate":0.1,
        "drop_path_rate":0.3,
        # train_model reads config['mix_up']; without this key every trial
        # fails with KeyError. None matches the other search cell.
        "mix_up": None,
        }

    scheduler = AsyncHyperBandScheduler(grace_period=5)

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_model,data_set = data_set),
            resources={"cpu": 2, "gpu": 1}
        ),
        tune_config=tune.TuneConfig(
            scheduler=scheduler,
            metric="auc",
            mode="max",
            num_samples=100,
        ),
        param_space=config,
        # Every trial is stopped after 20 minutes of wall-clock time.
        run_config=train.RunConfig(stop={"time_total_s": 20*60},callbacks=[MyCallback()])
    )
    result = tuner.fit()
    best_trial = result.get_best_result("auc", "max", "last")
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial path: {best_trial.path}")
    print(f"Best trial final validation auc: {best_trial.metrics['auc']}")

    # NOTE(review): train_model trains a VSSM with num_classes=2 that returns two
    # tensors, while this cell rebuilds `medmamba` with num_classes=1 and treats
    # the output as a single tensor. Confirm the checkpoint is loadable here.
    model = medmamba(patch_size1=9*best_trial.config["patch_size"],patch_size2=best_trial.config["patch_size"], in_chans=1, num_classes=1,
                 depths=best_trial.config["depths"], depths_decoder=[2, 9, 2, 2],
                 dims=best_trial.config["dims"], dims_decoder=[768, 384, 192, 96],
                 d_state=best_trial.config["d_state"],
                 drop_rate=best_trial.config["drop_rate"], attn_drop_rate=best_trial.config["attn_drop_rate"], drop_path_rate=best_trial.config["drop_path_rate"],
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=True)
    model.load_state_dict(torch.load(os.path.join(best_trial.checkpoint.path,'checkpoint.pt'), map_location=device)['model_state'])
    model = model.to(device)
    # Eval mode so dropout / drop-path do not randomise the exported predictions.
    model.eval()

    # Unshuffled loaders so predictions line up row-by-row with df_train / df_val.
    data_loader = {
        'train': torch.utils.data.DataLoader(data_set['train'], batch_size=32, shuffle=False, num_workers=0),
        'val': torch.utils.data.DataLoader(data_set['val'], batch_size=32, shuffle=False, num_workers=0)}

    def _predict_cohort(loader):
        """Return (sigmoid outputs, labels) for all samples of `loader`, in order."""
        cohort_preds, cohort_labels = [], []
        with torch.no_grad():
            for inputs, labels in loader:
                outputs = nn.Sigmoid()(model(inputs.to(device)))
                cohort_preds.extend(outputs.cpu().numpy())
                cohort_labels.extend(labels.cpu().numpy())
        return cohort_preds, cohort_labels

    train_pred_list, train_label_list = _predict_cohort(data_loader['train'])
    df_train['pred1'] = train_pred_list
    df_train.to_csv(f'{fold_dir}/train_{i}.csv')
    train_pred_lists.append(train_pred_list)
    train_label_lists.append(train_label_list)

    test_pred_list, test_label_list = _predict_cohort(data_loader['val'])
    df_val['pred1'] = test_pred_list
    df_val.to_csv(f'{fold_dir}/val_{i}.csv')
    test_pred_lists.append(test_pred_list)
    test_label_lists.append(test_label_list)

    # ROC curves for the two cohorts, side by side.
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95,
                    hspace=0.2, wspace=0.2)
    parameters = {'lw': 2, 'alpha_fold': 0.3, 'alpha_ave': 0.8, 'fontsize': 20}

    title = f'Train Cohort'
    ax[0] = binary_auc_plot_cv(ax[0], train_pred_lists, train_label_lists, parameters, title)
    title = f'Test Cohort '
    ax[1] = binary_auc_plot_cv(ax[1], test_pred_lists, test_label_lists, parameters, title)
    fig.savefig(f'{fold_dir}/ROC_CV.png', dpi=200, bbox_inches='tight')

In [None]:
# Draw one AUC-vs-iteration curve per trial; the first plot call creates the
# axes and every later trial is drawn onto that same axes.
ax = None
for result_single in result:
    label = f"lr={result_single.config['lr']}"
    plot_kwargs = {"label": label}
    if ax is not None:
        plot_kwargs["ax"] = ax
    drawn_axes = result_single.metrics_dataframe.plot("training_iteration", "auc", **plot_kwargs)
    if ax is None:
        ax = drawn_axes
ax.set_title("AUC vs. Training Iteration for All Trials")
ax.set_ylabel("AUC")
ax.legend(
    loc="lower right",
    bbox_to_anchor=(1.05,1.1),
    borderaxespad = 0.,
)

# tune again

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``-alpha * (1 - p_t) ** gamma * log(p_t)``.

    ``p_t`` is the softmax probability assigned to the true class. With
    ``gamma == 0`` and ``alpha == 1`` the loss equals plain cross-entropy.

    Args:
        gamma: Focusing exponent; larger values down-weight easy samples.
        alpha: Scalar weight applied to every sample.
        size_average: Return the mean over samples if True, otherwise the sum.
    """

    def __init__(self, gamma = 2, alpha = 1, size_average = True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average
        self.elipson = 0.000001  # unused; kept so existing attribute access still works

    def forward(self, logits, labels):
        """Compute the focal loss.

        Args:
            logits: Raw scores of shape ``(batch, classes)`` or
                ``(batch, classes, d1, ...)``.
            labels: Integer class indices with one entry per scored position,
                e.g. ``(batch,)``, ``(batch, 1)`` or ``(batch, d1, ...)``.

        Returns:
            Scalar tensor (mean or sum over all scored positions).

        Raises:
            ValueError: If the number of labels differs from the number of
                scored positions.
        """
        if logits.dim() > 2:
            # (B, C, d1, ...) -> (B * d1 * ..., C): one row of class scores per position.
            num_classes = logits.size(1)
            logits = logits.reshape(logits.size(0), num_classes, -1)
            logits = logits.transpose(1, 2).reshape(-1, num_classes)
        labels = labels.reshape(-1).long()
        if logits.size(0) != labels.size(0):
            raise ValueError(
                f"FocalLoss got {logits.size(0)} predictions but {labels.size(0)} labels"
            )

        # Log-probability of the true class only. The previous version used
        # (1 - log p) as the modulating factor and summed over every class.
        log_p = F.log_softmax(logits, dim=1)
        log_pt = log_p.gather(1, labels.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        fl = -self.alpha * (1 - pt) ** self.gamma * log_pt
        if self.size_average:
            return fl.mean()
        return fl.sum()

In [None]:

def train_again_model(config, dataset):
    """Train a VSSM with fixed hyper-parameters and return the best-AUC model.

    Per-batch loss, running accuracy and running AUC are written to TensorBoard;
    per-epoch metrics are sent to livelossplot. Whenever the validation AUC
    improves, the weights are saved to ``{fold_dir}/best_model_params.pt``
    (``fold_dir`` is read from the enclosing notebook scope) and those weights
    are loaded back into the model before returning.

    Args:
        config: Hyper-parameter dict. Keys read here: ``patch_size``, ``depths``,
            ``dims``, ``d_state``, ``drop_rate``, ``attn_drop_rate``,
            ``drop_path_rate``, ``smooth_label``, ``lr``, ``beta``, ``eps``,
            ``ep``, ``bs``, ``alpha`` (mix-up is off when ``<= 0``) and
            ``train`` (if truthy, start from the saved best weights).
        dataset: Dict with ``'train'`` and ``'val'`` datasets yielding
            ``(inputs, labels)`` pairs.

    Returns:
        The model carrying the weights of the best validation AUC epoch.
    """
    model = VSSM(patch_size1=9*config["patch_size"],patch_size2=config["patch_size"], in_chans=1, num_classes=2, depths=config["depths"], depths_decoder=[2, 9, 2, 2],
                 dims=config["dims"], dims_decoder=[768, 384, 192, 96], d_state=config["d_state"], drop_rate=config["drop_rate"], attn_drop_rate=config["attn_drop_rate"], drop_path_rate=config["drop_path_rate"],
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=True)

    optimizer = optim.Adam(model.parameters(), lr=config["lr"],betas=(config["beta"], 0.999), eps=config["eps"])
    criterion_train = CELoss_mixup(label_smooth=config['smooth_label'],class_num=2,mix_up=None)
    criterion_test = CELoss_mixup(label_smooth=config['smooth_label'],class_num=2,mix_up=None)

    model = model.to(device)
    num_epochs = config["ep"]

    since = time.time()

    writer = SummaryWriter(comment=f'LR_{config["lr"]}_BS_{config["bs"]}')

    data_loader = {
        'train': torch.utils.data.DataLoader(dataset['train'], batch_size=config["bs"], shuffle=True, num_workers=4,pin_memory=True),
        'val': torch.utils.data.DataLoader(dataset['val'], batch_size=config["bs"], shuffle=False, num_workers=4,pin_memory=True)}
    # Sizes come from the datasets actually passed in; the notebook-level
    # `dataset_sizes` can be stale after the splits are rebuilt.
    split_sizes = {phase: len(dataset[phase]) for phase in ('train', 'val')}

    if config.get("train", False):
        model.load_state_dict(torch.load(f'{fold_dir}/best_model_params.pt'))
    ep_loss = PlotLosses()
    best_auc = 0.0
    # TensorBoard x-axis: batches seen so far in each phase, across all epochs.
    global_step = {'train': 0, 'val': 0}
    tb_tag = {'train': 'train', 'val': 'test'}

    for epoch in range(num_epochs):
        logs = {}

        # One progress bar per epoch, covering training and validation batches.
        with tqdm(total=len(data_loader['train'])+len(data_loader['val'])) as pbar:
            for phase in ['train', 'val']:
                is_train = phase == 'train'
                if is_train:
                    model.train()
                else:
                    model.eval()

                running_loss = 0.0
                running_corrects = 0
                seen = 0  # samples processed so far in this phase of this epoch
                phase_pred = []
                phase_label = []
                for inputs, labels in data_loader[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    # reshape(-1) instead of squeeze(): a final batch holding a
                    # single sample must stay 1-D.
                    flat_labels = labels.reshape(-1)
                    use_mixup = is_train and config['alpha']>0
                    if use_mixup:
                        # Mix inputs and one-hot targets with a Beta(alpha, alpha) weight.
                        target = F.one_hot(labels, 2)  # assumes integer class-index labels; TODO confirm
                        alpha = config['alpha']
                        lam = np.random.beta(alpha,alpha)
                        index = torch.randperm(inputs.size(0), device=inputs.device)
                        inputs = lam*inputs+(1-lam)*inputs[index,:]
                        sm_labels = lam*target+(1-lam)*target[index]

                    optimizer.zero_grad()

                    # Gradients are tracked only during the training phase.
                    with torch.set_grad_enabled(is_train):
                        outputs,x = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        if use_mixup:
                            loss = criterion_train(outputs, sm_labels)
                        else:
                            # Both tensors returned by the model are supervised.
                            loss1 = criterion_test(outputs, flat_labels)
                            loss2 = criterion_test(x, flat_labels)
                            loss = loss1+loss2
                        if is_train:
                            loss.backward()
                            optimizer.step()

                    batch_n = inputs.size(0)
                    running_loss += loss.item() * batch_n
                    running_corrects += torch.sum(preds == flat_labels).item()
                    seen += batch_n
                    phase_pred.extend(nn.Softmax(dim=1)(outputs)[:,1].cpu().detach().numpy())
                    phase_label.extend(flat_labels.cpu().numpy())

                    # The running AUC is undefined until both classes have been seen.
                    if len(np.unique(phase_label)) > 1:
                        auc_score = roc_auc_score(phase_label, phase_pred)
                    else:
                        auc_score = float('nan')
                    running_acc = running_corrects / seen

                    tag = tb_tag[phase]
                    step = global_step[phase]
                    writer.add_scalar(f'Loss/{tag}', loss.item() * batch_n / config["bs"], step)
                    writer.add_scalar(f'ACC/{tag}', running_acc, step)
                    # Training AUC goes under its own tag; it used to be written
                    # to 'AUC/test'.
                    writer.add_scalar(f'AUC/{tag}', auc_score, step)
                    global_step[phase] += 1

                    pbar.set_description(f'Epoch {epoch}/{num_epochs - 1}')
                    pbar.update(1)
                    pbar.set_postfix(loss = loss.item(),acc = running_acc,auc=auc_score)

                epoch_loss = running_loss / split_sizes[phase]
                epoch_acc = running_corrects / split_sizes[phase]
                epoch_auc = roc_auc_score(phase_label, phase_pred)
                if phase == 'val':
                    prefix = 'val_'
                    pbar.set_postfix(val_loss = epoch_loss,val_acc=epoch_acc,val_auc=epoch_auc)
                    if epoch_auc > best_auc:
                        best_auc = epoch_auc
                        torch.save(model.state_dict(),f'{fold_dir}/best_model_params.pt')
                else:
                    prefix = ''
                    pbar.set_postfix(loss = epoch_loss,acc=epoch_acc,auc=epoch_auc)
                logs[prefix + 'loss'] = epoch_loss
                logs[prefix + 'ACC'] = epoch_acc
                logs[prefix + 'AUC'] = epoch_auc

        ep_loss.update(logs)
        ep_loss.send()
    writer.flush()
    writer.close()
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val AUC: {best_auc:4f}')

    # Load the best model weights before handing the model back.
    model.load_state_dict(torch.load(f'{fold_dir}/best_model_params.pt'))
    return model

In [None]:
# Split the cohort by treating centre: three centres form the training set and
# four other centres form the validation set.
response_col = '疗效评估 (PD or PR or CR)'  # summed below as the "ORR" count
train_centers = ['SYSUCC', 'SYSU_first', 'JNU']
val_centers = ['CHCAMS', 'SYSU_third', 'Gaofei', 'LUHE']

df_train = df[df['Center'].isin(train_centers)]
df_val = df[df['Center'].isin(val_centers)]
print(f"HIAC Train ORR:{df_train[response_col].values.sum()}, Non-ORR:{len(df_train)-df_train[response_col].values.sum()}")
print(f"HIAC Val ORR:{df_val[response_col].values.sum()}, Non-ORR:{len(df_val)-df_val[response_col].values.sum()}")

# Oversample the positive class: responders appear 3x in train; in val they
# appear 4x plus one more copy of the first 13 responders.
train_pos = df_train[df_train[response_col] == 1]
df_train = pd.concat([df_train, train_pos, train_pos])
print(f"Ague HIAC Train ORR:{df_train[response_col].values.sum()}, Non-ORR:{len(df_train)-df_train[response_col].values.sum()}")
val_pos = df_val[df_val[response_col] == 1]
df_val = pd.concat([df_val, val_pos, val_pos, val_pos, val_pos[:13]])
# These are the validation counts; the original message mislabelled them as "Train".
print(f"Ague HIAC Val ORR:{df_val[response_col].values.sum()}, Non-ORR:{len(df_val)-df_val[response_col].values.sum()}")

# The validation split is larger than in the earlier split cell, so refresh the
# sizes that the training loop divides its epoch totals by.
dataset_sizes = {'train': len(df_train), 'val': len(df_val)}

In [None]:
# Load each split completely into memory: a loader whose batch size equals the
# split size yields the whole split as one batch, which is then wrapped in a
# TensorDataset so later epochs do not re-read the source data.
# (Disabled debugging code used to dump every cached slice as a JPEG under
# Check_train_<label>/ and Check_val_<label>/ for visual inspection.)
train_data = Dataset_loader(df_file=df_train)
val_data = Dataset_loader(df_file=df_val)
pre_data_loader = {
    'train': PrefetchDataLoader(train_data, batch_size=len(df_train), shuffle=False, num_workers=4),
    'val': PrefetchDataLoader(val_data, batch_size=len(df_val), shuffle=False, num_workers=4),
}
cached = {}
for split_name, loader in pre_data_loader.items():
    for inputs, labels in loader:
        cached[split_name] = torch.utils.data.TensorDataset(inputs, labels)
train_dataset = cached['train']
val_dataset = cached['val']
data_set = {'train': train_dataset, 'val': val_dataset}

In [None]:
if __name__ == "__main__":
    # Final training run with fixed hyper-parameters, followed by scoring both
    # cohorts and plotting their ROC curves.
    fold_dir = '/home/u20210110/jupyterlab/HAIC_TACE/2_final_TACE_AP'
    os.makedirs(f'{fold_dir}', exist_ok=True)
    train_pred_lists = []
    train_label_lists = []
    test_pred_lists = []
    test_label_lists = []
    # Suffix of the exported CSV files. It used to be picked up from an earlier
    # cell, which raises NameError on a fresh kernel; same value as those cells.
    i = 5
    config = {
        'lr': 1e-06,
        'ep': 200,
        'bs': 32,
        'beta': 0.5,
        'eps': 0.01,
        'smooth_label':None,
        'alpha': 0,#10,
        'patch_size': 9,
        'depths': [2, 4, 4, 2],
        'dims': [128, 256, 512, 1024],
        'd_state': 16,
        'drop_rate': 0.1,
        'attn_drop_rate': 0.1,
        'drop_path_rate': 0.3,
        'train':False,
        }

    model = train_again_model(config, data_set)
    # Make sure dropout / drop-path are disabled while exporting predictions.
    model.eval()

    # Unshuffled loaders so predictions line up row-by-row with df_train / df_val.
    data_loader = {
        'train': torch.utils.data.DataLoader(data_set['train'], batch_size=32, shuffle=False, num_workers=0),
        'val': torch.utils.data.DataLoader(data_set['val'], batch_size=32, shuffle=False, num_workers=0)}

    def _predict_cohort(loader):
        """Return (class-1 probabilities, labels) for all samples of `loader`, in order."""
        cohort_preds, cohort_labels = [], []
        # no_grad: inference only, so no autograd graph is kept per batch.
        with torch.no_grad():
            for inputs, labels in loader:
                outputs, _ = model(inputs.to(device))
                probs = nn.Softmax(dim=1)(outputs)
                cohort_preds.extend(probs.cpu().numpy()[:, 1])
                cohort_labels.extend(labels.cpu().numpy())
        return cohort_preds, cohort_labels

    train_pred_list, train_label_list = _predict_cohort(data_loader['train'])
    df_train['pred1'] = train_pred_list
    df_train.to_csv(f'{fold_dir}/train_{i}.csv')
    train_pred_lists.append(train_pred_list)
    train_label_lists.append(train_label_list)

    test_pred_list, test_label_list = _predict_cohort(data_loader['val'])
    df_val['pred1'] = test_pred_list
    df_val.to_csv(f'{fold_dir}/val_{i}.csv')
    test_pred_lists.append(test_pred_list)
    test_label_lists.append(test_label_list)

    # ROC curves for the two cohorts, side by side.
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95,
                    hspace=0.2, wspace=0.2)
    parameters = {'lw': 2, 'alpha_fold': 0.3, 'alpha_ave': 0.8, 'fontsize': 20}

    title = f'Train Cohort'
    ax[0] = binary_auc_plot_cv(ax[0], train_pred_lists, train_label_lists, parameters, title)
    title = f'Test Cohort '
    ax[1] = binary_auc_plot_cv(ax[1], test_pred_lists, test_label_lists, parameters, title)
    fig.savefig(f'{fold_dir}/ROC_CV.png', dpi=200, bbox_inches='tight')

In [None]:
import pandas as pd
import numpy as np
import os 
import SimpleITK as sitk
from PIL import Image

def cut_roi(data):
    """Return the inclusive bounding box of the positive region of a 3-D volume.

    An index along an axis counts as occupied when the sum of all values in
    the corresponding 2-D cross-section is greater than zero.

    Args:
        data: 3-D numpy array indexed as (slice, y, x).

    Returns:
        A 6-tuple ``(z_min, z_max, y_min, y_max, x_min, x_max)`` of plain
        Python ints (inclusive indices).  When any axis has no occupied
        index, six ``False`` values are returned instead.  Because ``0`` is a
        legitimate index, callers must test the result with ``is False``
        rather than relying on truthiness.
    """
    data = np.asarray(data)
    bounds = []
    for axis in range(3):
        # Collapse the two other axes so each axis is scanned over its OWN
        # length (the x scan previously reused shape[1], which is wrong for
        # non-square volumes).
        other_axes = tuple(a for a in range(3) if a != axis)
        occupied = np.flatnonzero(data.sum(axis=other_axes) > 0)
        if occupied.size == 0:
            return False, False, False, False, False, False
        bounds.extend((int(occupied[0]), int(occupied[-1])))
    return tuple(bounds)
# ROI extraction script: for every volume of the selected centers, mask the
# image with its nnUNet segmentation, crop it along z, save the crop as NIfTI
# and write a stacked 2-D preview JPEG for visual checking.
phase_str = 'AP'
files_dir = '/home/u20210110/jupyterlab/HAIC_TACE/All_data/'+phase_str
center = 40   # window level used for intensity windowing
width = 350   # window width used for intensity windowing
depth = 9     # number of slices in the preview image
patients_id = os.listdir(files_dir)#VP
wrong_list = []  # files for which no usable ROI was found

val_centers = ('CHCAMS', 'SYSU_third', 'Gaofei', 'LUHE')
df_val = df[df['Center'].isin(val_centers)]

# Loop-invariant paths derived from files_dir.
save_root = files_dir.replace('All_data/'+phase_str,'Processed_data_cut_nnUnet')
mask_dir = files_dir.replace('All_data/'+phase_str,'nnUNet/output_all_'+phase_str)
check_dir = '/home/u20210110/jupyterlab/HAIC_TACE/Check'

# Loop-invariant windowing constants (window width / window level -> 0..255).
min_ = (2 * center - width) / 2.0 + 0.5
max_ = (2 * center + width) / 2.0 + 0.5
dFactor = 255.0 / (max_ - min_)

for single in tqdm(patients_id):
    # Only files whose name contains one of the selected center tags.
    if not any(tag in single for tag in val_centers):
        continue

    # File names look like <center>_<id>_0000.nii.gz
    ids = single.split('_')[-2]
    center_dir = single.replace('_'+ids+'_0000.nii.gz','')
    full_save_dir = os.path.join(save_root,center_dir,ids)
    os.makedirs(full_save_dir,exist_ok=True)

    phase_data = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(files_dir,single)))
    # Read the segmentation once; derive both masks from the same array.
    phase_mask = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(mask_dir,single.replace('_0000',''))))

    # z-range: label 1 is dropped so only the remaining labels define it.
    z_mask = phase_mask.copy()
    z_mask[z_mask==1] = 0
    z_min,z_max,_,_,_,_ = cut_roi(z_mask)

    # In-plane range and multiplication mask: every label >= 1 counts.
    phase_mask[phase_mask>=1] = 255
    _,_,y_min,y_max,x_min,x_max = cut_roi(phase_mask)

    # cut_roi signals "nothing found" with False.  Index 0 is a valid bound,
    # and a bitwise AND of indices (e.g. 2 & 5 == 0) is not a validity test,
    # so compare against the sentinel explicitly.
    if any(bound is False for bound in (z_min,z_max,y_min,y_max,x_min,x_max)):
        wrong_list.append(single)
        continue

    # Apply the intensity window and clip to [0, 255].
    phase_data = phase_data - min_
    phase_data = np.trunc(phase_data * dFactor)
    phase_data[phase_data < 0.0] = 0
    phase_data[phase_data > 255.0] = 255

    # Zero everything outside the mask and crop along z only (inclusive).
    roi_data = (phase_data*(phase_mask/255).astype('uint8'))[z_min:z_max+1,:,:]
    roi = sitk.GetImageFromArray(roi_data)
    sitk.WriteImage(roi,os.path.join(full_save_dir,phase_str+'.nii.gz'))

    # Sub-sample slices at a regular stride before resizing to `depth`.
    steps = len(roi_data)//depth
    if steps>0:
        if np.mod(len(roi_data),depth)==0:
            roi_data = roi_data[::steps,:,:]
        else:
            roi_data = roi_data[steps::steps,:,:]

    # Stack the `depth` resized slices vertically into one preview image.
    image = resize(roi_data,(depth,224,224))
    image = image.reshape(depth*224,224)
    image = Image.fromarray(image.astype('uint8'))
    os.makedirs(check_dir,exist_ok =True)
    image.save(check_dir+'/'+str(single)+'_'+phase_str+'.jpg')

pd.DataFrame(wrong_list).to_csv('/home/u20210110/jupyterlab/HAIC_TACE/All_data/'+phase_str+'_wronglist.csv',index=False)