In [None]:
import os
import nibabel as nib
import pandas as pd
from tqdm import tqdm
import logging
import imageio
import numpy as np
import torch
import torch.nn as nn  
import numpy as np
from tqdm import tqdm
import os,sys,cv2,gc
from torch.cuda.amp import autocast
import matplotlib.pyplot as plt
import albumentations as A
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DataParallel
from glob import glob
from sklearn.model_selection import GroupKFold
from sklearn.model_selection import train_test_split
import random
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from warmup_scheduler import GradualWarmupScheduler
from contextlib import contextmanager
import time
import math
from datetime import datetime
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler

In [None]:
class CFG:
    """Central experiment configuration, read as class attributes (never instantiated in this notebook).

    NOTE(review): the preprocessing cell reads ``CFG.stages`` and ``CFG.top_level_dir``,
    which are not defined here; they must be added before ``predata`` is set to True.
    """
    # --- Pre-processing -------------------------------------------------
    # When True the NIfTI volumes are re-sliced into PNGs; otherwise a prepared CSV is loaded.
    predata = False
    output_dir = 'EXP_MyDataSet'
    # Size every exported slice is resized to (also used as Spine_Dataset.image_size).
    target_height = 512
    target_width = 512
    wandb = False
    seed = 42
    project = 'Spine'
    exp_name = 'exp01'
    # Number of GroupKFold splits assigned to the dataframe.
    n_fold = 5
    valid_fold = 4
    chopping_percentile = 1e-3  # not referenced in the visible part of the notebook
    # Number of consecutive slices stacked per sample (1 => the slice is repeated to 3 channels).
    in_chans = 1
    train_batch_size = 4
    valid_batch_size = 8
    
    # Training augmentations; ToTensorV2(transpose_mask=True) moves the mask channels first.
    train_aug_list = [
        A.Affine(scale={"x":(0.8, 1.2), "y":(0.8, 1.2)}, translate_percent={"x":(0, 0.1), "y":(0, 0.1)}, rotate=(-30, 30), shear=(-20, 20), p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.4, contrast_limit=0.4, p=0.5),
        A.ShiftScaleRotate(scale_limit=0.2),
        # A.HorizontalFlip(p=0.5),
        # A.VerticalFlip(p=0.5),
        # A.RandomRotate90(p=0.5),
        # A.OneOf([
        #         A.RandomBrightnessContrast(), 
        #         A.RandomGamma(),
        #   ],p=0.5,),
        ToTensorV2(transpose_mask=True),
    ]
    train_aug = A.Compose(train_aug_list)
    # Validation only converts to tensors; no geometric or photometric changes.
    valid_aug_list = [
        ToTensorV2(transpose_mask=True),
    ]
    valid_aug = A.Compose(valid_aug_list)

    nprocs=1 
    fold_num=5  # NOTE(review): duplicates n_fold; only n_fold is used in the visible cells
    num_classes=1
    
    accum_iter=1 
    max_grad_norm=1e3
    print_freq=100 
    
    # Fold ids reserved for test / validation / training.
    test_fold_list=[0] 
    valid_fold_list=[1]
    train_fold_list = [2,3,4]

    epochs=25
    
    # Encoder name passed to smp.DeepLabV3Plus in build_model().
    model_arch="efficientnet-b1"
    
    optimizer="AdamW" 
    
    scheduler="CosineAnnealingLR"
    loss_fn="BCEWithLogitsLoss"
    scheduler_warmup= "GradualWarmupSchedulerV3"

    warmup_epo = 10
    warmup_factor = 10
    # Cosine period after warm-up; with the V3 warm-up selected above this is 25 - 10 - 1 = 14.
    T_max= epochs-warmup_epo-2 if scheduler_warmup=="GradualWarmupSchedulerV2" else \
           epochs-warmup_epo-1 if scheduler_warmup=="GradualWarmupSchedulerV3" else epochs-1
    lr=1e-3 
    min_lr=1e-6 # lower learning-rate bound (presumably the scheduler's eta_min -- confirm in the training cell)
    weight_decay=1e-2
    n_early_stopping=20
    
    # timm model name used by WIDRMS for its encoder, and the local weight file for it.
    model_name = 'seresnext26d_32x4d.bt_in1k'
    model_path = 'Encoder_backbone/Encoder_backbone/Encoder/seresnext26d_32x4d_bt_in1k.bin' #0.8827
    
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def seed_everything(seed=42):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (safe no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The auto-tuner may pick different kernels from run to run, which defeats
    # the determinism requested on the line above, so it has to stay off.
    torch.backends.cudnn.benchmark = False

# Fix all RNG seeds once, right after the configuration is defined.
seed_everything(CFG.seed)

@contextmanager
def timer(name):
    """Context manager that logs the start and the elapsed wall time of a block.

    Args:
        name: Label placed in square brackets in both log lines.
    """
    started_at = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    elapsed = time.time() - started_at
    LOGGER.info(f'[{name}] done in {elapsed:.0f} s.')

# Logging helper
def init_logger(log_file):
    """Create (or re-configure) the notebook logger.

    Messages are emitted unformatted to both the console and ``log_file``.
    The function is safe to call repeatedly: handlers left over from a
    previous call are closed and removed first, so re-running the cell does
    not duplicate every log line. The parent directory of ``log_file`` is
    created when it does not exist yet.

    Args:
        log_file: Path of the log file to append to.

    Returns:
        The configured ``logging.Logger`` for this module.
    """
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # getLogger returns the same object on every call; drop stale handlers.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    # FileHandler does not create missing directories on its own.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

# Notebook-wide logger; writes to the console and to <output_dir>/train_<exp_name>.log.
LOGGER = init_logger(CFG.output_dir+f'/train_{CFG.exp_name}.log')
# Convenience aliases for the logger's info method and the built-in print.
loginfo = LOGGER.info
cusprint = print

def get_timediff(time1,time2):
    """Format the span between two timestamps (in seconds) as ``MM:SS``.

    Args:
        time1: Start time in seconds.
        time2: End time in seconds.

    Returns:
        Zero-padded ``"minutes:seconds"`` string.
    """
    whole_minutes, leftover_seconds = divmod(time2 - time1, 60)
    whole_minutes = int(whole_minutes)
    leftover_seconds = int(leftover_seconds)
    return f"{whole_minutes:02d}:{leftover_seconds:02d}"

class AverageMeter(object):
    """Running statistics tracker: keeps the latest value and the weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new observation.

        Args:
            val: Latest value (e.g. the mean loss of a batch).
            n: Weight of the observation (e.g. the batch size).
        """
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Render a duration given in seconds as ``"<m>m <s>s"``.

    Args:
        s: Duration in seconds.

    Returns:
        String with whole minutes and the remaining (truncated) seconds.
    """
    whole_minutes = math.floor(s / 60)
    remaining_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, remaining_seconds)


def timeSince(since, percent):
    """Report elapsed time and a linear estimate of the time remaining.

    Args:
        since: Start timestamp as returned by ``time.time()``.
        percent: Fraction of the work completed so far (must be non-zero).

    Returns:
        String of the form ``"<elapsed> (remain <remaining>)"``.
    """
    elapsed = time.time() - since
    # Extrapolate the total duration from the completed fraction.
    projected_total = elapsed / (percent)
    remaining = projected_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
if CFG.predata:
    # Slice every NIfTI volume into per-slice PNGs and record them in a CSV.
    # NOTE(review): this branch reads CFG.stages and CFG.top_level_dir, which the
    # CFG class shown in this notebook does not define -- add them before enabling it.
    os.makedirs(CFG.output_dir, exist_ok=True)

    def normalize_upsample(image,is_label):
        """Resize one 2-D slice to the configured target size and return it as uint8.

        Images are min-max scaled to 0-255 and resized with bicubic
        interpolation (clipped afterwards, because bicubic can overshoot and
        the uint8 cast would otherwise wrap around). Label masks are resized
        with nearest-neighbour interpolation so that no label values are
        invented along object borders.
        """
        if is_label:
            interpolation = cv2.INTER_NEAREST
        else:
            # Scale the intensities into the 0-255 range expected by PNG.
            image = image.astype(np.float32)
            min_val = np.min(image)
            max_val = np.max(image)
            image = (image - min_val) / (max_val - min_val + 1e-9) * 255
            interpolation = cv2.INTER_CUBIC
        image = cv2.resize(image, (CFG.target_width, CFG.target_height), interpolation=interpolation)
        if not is_label:
            image = np.clip(image, 0, 255)
        return image.astype(np.uint8)

    # One dict per exported slice; turned into a DataFrame at the end.
    all_image_info = []
    # Walk over every configured stage.
    for stage in CFG.stages:
        # Log the stage currently being processed.
        LOGGER.info(f"当前阶段：{stage}")
        img_dir = os.path.join(CFG.top_level_dir, stage, 'image')
        groundtruth_dir = os.path.join(CFG.top_level_dir, stage, 'groundtruth')
        img_output_dir = os.path.join(CFG.output_dir, stage,'image')
        groundtruth_output_dir = os.path.join(CFG.output_dir, stage,'groundtruth')
        os.makedirs(img_output_dir, exist_ok=True)
        os.makedirs(groundtruth_output_dir, exist_ok=True)

        # Walk over the files of the stage's image folder.
        for file_name in tqdm(os.listdir(img_dir)):
            if not file_name.endswith('.nii.gz'):
                continue
            # Load the image volume and its ground-truth volume.
            img_path = os.path.join(img_dir, file_name)
            groundtruth_path = os.path.join(groundtruth_dir, f'mask_{file_name.lower()}')
            try:
                img = nib.load(img_path).get_fdata()
                label = nib.load(groundtruth_path).get_fdata()
                height, width, depth = img.shape

                # Both volumes must have identical shapes.
                if img.shape != label.shape:
                    raise ValueError("Image and label shape do not match.")
            except FileNotFoundError as e:
                LOGGER.warning(f"文件未找到: {e}")
                continue
            except ValueError as e:
                LOGGER.warning(f"Shape不匹配警告: {e} - 跳过文件 {file_name}")
                continue
            except Exception as e:
                LOGGER.warning(f"处理文件 {file_name} 时发生未知错误: {e}")
                continue
            base_name = file_name.split('.')[0]
            # Export every slice along the third axis as a PNG pair.
            for i in range(img.shape[2]):
                image_2d = img[:, :, i]
                label_2d = label[:, :, i]
                img_filename = f'{base_name}_{i}.png'
                gt_filename = f'{base_name}_{i}_label.png'
                img_png_path = os.path.join(img_output_dir, img_filename)
                gt_png_path = os.path.join(groundtruth_output_dir, gt_filename)

                # Write the image slice and the label slice.
                imageio.imwrite(img_png_path,normalize_upsample(image_2d,is_label=False))
                imageio.imwrite(gt_png_path,normalize_upsample(label_2d,is_label=True))
                # Remember where the slice came from and where it was written.
                all_image_info.append({
                    'Image': img_filename,
                    'Case':base_name,
                    'ImagePath': img_png_path,
                    'GroundTruthPath': gt_png_path,
                    'height':height,
                    'width':width,
                    'Stage': stage
                })

    # Build the DataFrame from the collected records and persist it.
    df = pd.DataFrame(all_image_info)
    df.to_csv('image_groundtruth_data.csv', index=False)
    LOGGER.info(f"处理完成，DataFrame已保存为 image_groundtruth_data.csv")
else:
    # NOTE(review): this reads a different file than the branch above writes, and
    # Data_loader2D additionally expects a 'Gt_all_path' column that the branch
    # above does not produce -- confirm how 'image_groundtruth_data_my.csv' is built.
    df = pd.read_csv('image_groundtruth_data_my.csv')

In [None]:
# Display the slice table (notebook rich output).
df

In [None]:
if CFG.wandb:
    # Optional Weights & Biases tracking; a failed login must not abort the notebook.
    try:
        import wandb
        wandb.login()
        run = wandb.init(project=CFG.project, 
                 name=CFG.exp_name,
                ) 
    except Exception as wandb_error:
        # `except Exception` (instead of a bare except) lets Ctrl-C through, and
        # LOGGER is the logger this notebook actually defines (`logger` is not).
        LOGGER.info(f"Check your WANDB account")
        LOGGER.info(f"wandb setup failed: {wandb_error!r}")

In [None]:
# Assign a fold id to every slice. Splitting is grouped by 'Case', so all
# slices that come from the same case end up in the same fold.
gkf = GroupKFold(n_splits=CFG.n_fold)
df["fold"] = -1  # -1 marks rows that have not been assigned yet
for fold_id, (_, val_idx) in enumerate(
    gkf.split(df, y=df["GroundTruthPath"], groups=df["Case"])
):
    # The validation indices of split k define fold k.
    df.loc[val_idx, "fold"] = fold_id
# Show how many slices each fold received.
df.fold.value_counts()

In [None]:
class Data_loader2D(Dataset):
    """Dataset that reads one image slice, its label and its per-structure masks from disk.

    The dataframe must provide the columns ``ImagePath``, ``GroundTruthPath``
    and ``Gt_all_path``. ``Gt_all_path`` is parsed as the string form of a
    list of quoted paths, e.g. ``"['a.png', 'b.png']"``.
    """

    def __init__(self,df):
        # Positional access below relies on a clean 0..n-1 index.
        self.df = df.reset_index()

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as a grayscale image and return it as a tensor.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
                signals that by returning ``None``, which would otherwise
                surface as an opaque ``TypeError`` in ``torch.from_numpy``.
        """
        image = cv2.imread(path,cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {path!r}")
        return torch.from_numpy(image)

    def __getitem__(self,index):
        """Return ``(image, label, masks)`` for row ``index``; masks are stacked on dim 0."""
        row = self.df.iloc[index]
        img = self._read_gray(row.ImagePath)
        label = self._read_gray(row.GroundTruthPath)
        # Drop the surrounding brackets, split the entries, then drop each entry's quotes.
        bone_paths = [entry[1:-1] for entry in row.Gt_all_path[1:-1].split(', ')]
        bone_list = torch.stack([self._read_gray(path) for path in bone_paths])
        return img,label,bone_list
    
def load_data(df):
    """Read every row of ``df`` through ``Data_loader2D`` into three in-memory tensors.

    Args:
        df: DataFrame accepted by ``Data_loader2D``.

    Returns:
        Tuple ``(images, labels, bones)`` where each tensor is the
        concatenation along dim 0 of all loaded batches.
    """
    loader = DataLoader(Data_loader2D(df), batch_size=16, num_workers=0)
    img_batches, label_batches, bone_batches = [], [], []
    for img_batch, label_batch, bone_batch in tqdm(loader):
        img_batches.append(img_batch)
        label_batches.append(label_batch)
        bone_batches.append(bone_batch)
    img_c = torch.cat(img_batches, dim=0)
    label_c = torch.cat(label_batches, dim=0)
    bone_c = torch.cat(bone_batches, dim=0)
    # Release the per-batch lists before handing back the merged tensors.
    del img_batches, label_batches, bone_batches
    return img_c, label_c, bone_c

In [None]:
# Load every slice listed in df into memory and report the resulting tensor shapes.
train_img,train_label,train_bone=load_data(df)
print(train_img.shape)
print(train_label.shape)
print(train_bone.shape)

In [None]:
class Spine_Dataset(Dataset):
    """Serves (image, binary mask, multi-label mask) samples from in-memory volumes.

    Args:
        x: List of image volumes, each indexed as ``[slice, H, W]``.
        y: List of label volumes aligned with ``x``.
        z: List of per-structure mask volumes, each indexed as ``[slice, K, H, W]``.
        arg: When True the training augmentations are applied, otherwise the
            validation transform.

    Every volume contributes ``slices - in_chans`` samples (see ``__len__``).
    """

    def __init__(self,x:list,y:list,z:list,arg=False):
        super().__init__()
        self.x = x
        self.y = y #list[(C,H,W),...]
        self.z = z
        self.image_size=CFG.target_height
        self.in_chans=CFG.in_chans
        self.arg=arg
        if arg:
            self.transform=CFG.train_aug
        else:
            self.transform=CFG.valid_aug

    def __len__(self) -> int:
        return sum(y.shape[0]-self.in_chans for y in self.y)

    def __getitem__(self, index):
        # Translate the global index into (volume i, local index). The comparison
        # has to be >= so that each volume owns exactly the number of samples
        # that __len__ counts for it; with > the first window of every volume
        # after the first was unreachable.
        i=0
        for volume in self.x:
            samples_in_volume = volume.shape[0]-self.in_chans
            if index>=samples_in_volume:
                index-=samples_in_volume
                i+=1
            else:
                break
        x=self.x[i]
        y=self.y[i]
        z=self.z[i]

        # Input window of in_chans slices; targets are taken at the window centre.
        x=x[index:index+self.in_chans,:,:]
        y=y[index+self.in_chans//2,:,:]
        z=z[index+self.in_chans//2,:,:,:]
        y = y.unsqueeze(0)
        # Channel 0 carries the overall label, the remaining channels the per-structure masks.
        y = torch.cat((y, z), dim=0)

        if self.in_chans == 1:
            # Replicate the single slice to the three channels the encoder expects.
            x = x.repeat(3, 1, 1)
        # albumentations works on channel-last numpy arrays.
        data = self.transform(image=x.numpy().transpose(1, 2, 0), mask=y.numpy().transpose(1, 2, 0))
        x = data['image']
        y = data['mask']
        # Binarise: every non-zero label value becomes 1.0.
        mask = (y != 0).to(torch.float32)
        binary_mask = mask[0:1]
        multilabel_mask = mask[1:]
        # Final input x, the binary mask and the multi-label mask.
        return x, binary_mask, multilabel_mask

In [None]:
# Wrap all loaded slices as a single volume and enable the training augmentations.
train_dataset = Spine_Dataset([train_img],[train_label],[train_bone],arg=True)
# NOTE: the name is rebound here; from this line on `train_dataset` is a DataLoader, not a Dataset.
train_dataset = DataLoader(train_dataset, batch_size=CFG.train_batch_size ,num_workers=0, shuffle=True, pin_memory=True)

In [None]:
# Get an iterator over the DataLoader
data_iterator = iter(train_dataset)
# Pull the first batch from the iterator
first_batch = next(data_iterator)
# Sanity check: print the shapes of the image, binary-label and multi-label tensors.
print('img shape:',first_batch[0].shape,'binary_label_shape:',first_batch[1].shape,'multi_label_shape:',first_batch[2].shape)

In [None]:
import segmentation_models_pytorch as smp

def build_model():
    """Create the DeepLabV3+ baseline described by ``CFG`` and move it to ``device``.

    Returns:
        The ``smp.DeepLabV3Plus`` model with an ImageNet-initialised encoder,
        three input channels and ``CFG.num_classes`` output channels.
    """
    network = smp.DeepLabV3Plus(
        encoder_name=CFG.model_arch,    # encoder backbone
        encoder_weights="imagenet",     # ImageNet pre-trained encoder weights
        in_channels=3,                  # grayscale slices are replicated to 3 channels upstream
        classes=CFG.num_classes,        # number of output channels
    )
    # nn.Module.to moves the parameters in place and returns the same module.
    return network.to(device)

def load_model(path):
    """Build the baseline model, restore its weights from ``path`` and switch to eval mode.

    Args:
        path: File containing a ``state_dict`` saved from a ``build_model()`` network.

    Returns:
        The model on ``device`` in evaluation mode.
    """
    model = build_model()
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (and vice versa) instead of failing during deserialisation.
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [None]:
import torch
import torch.nn as nn
class SELayer(nn.Module):
    """Squeeze-and-Excitation block: re-weights channels using globally pooled statistics.

    Args:
        channel: Number of input (and output) channels.
        reduction: Bottleneck reduction factor of the gating MLP.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel//reduction
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale every channel of ``x`` (N, C, H, W) by its learned gate in (0, 1)."""
        batch, channels, _, _ = x.size()
        squeezed = self.avgpool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)

# Smoke test: SELayer must preserve the input shape.
if __name__ == "__main__":
    bs, c, h, w = 10, 16, 64, 64
    in_tensor = torch.ones(bs, c, h, w)

    cs_se = SELayer(c)
    print("in shape:",in_tensor.shape)
    out_tensor = cs_se(in_tensor)
    print("out shape:", out_tensor.shape)

In [None]:
import torch
import torch.nn as nn

class sSE(nn.Module):
    """Spatial squeeze-and-excitation: one sigmoid gate per pixel, shared by all channels.

    Args:
        in_channels: Number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
        self.norm = nn.Sigmoid()

    def forward(self, U):
        """Return ``U`` multiplied by its (N, 1, H, W) spatial gate."""
        spatial_gate = self.norm(self.Conv1x1(U))
        return U * spatial_gate

class cSE(nn.Module):
    """Channel squeeze-and-excitation built from 1x1 convolutions.

    Args:
        in_channels: Number of channels of the input feature map; the
            bottleneck uses ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        squeezed_channels = in_channels // 2
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.Conv_Squeeze = nn.Conv2d(in_channels, squeezed_channels, kernel_size=1, bias=False)
        self.Conv_Excitation = nn.Conv2d(squeezed_channels, in_channels, kernel_size=1, bias=False)
        self.norm = nn.Sigmoid()

    def forward(self, U):
        """Return ``U`` scaled per channel by a gate computed from its global average."""
        pooled = self.avgpool(U)                    # [bs, c, 1, 1]
        squeezed = self.Conv_Squeeze(pooled)        # [bs, c//2, 1, 1]
        excited = self.Conv_Excitation(squeezed)    # [bs, c, 1, 1]
        channel_gate = self.norm(excited)
        return U * channel_gate.expand_as(U)

class csSE(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (sum of both branches).

    Args:
        in_channels: Number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.cSE = cSE(in_channels)
        self.sSE = sSE(in_channels)

    def forward(self, U):
        """Return the element-wise sum of the channel-gated and spatially-gated input."""
        spatial_branch = self.sSE(U)
        channel_branch = self.cSE(U)
        return channel_branch + spatial_branch

# Smoke test: csSE and sSE must both preserve the input shape.
if __name__ == "__main__":
    bs, c, h, w = 10, 3, 64, 64
    in_tensor = torch.ones(bs, c, h, w)
    cs_se = csSE(c)
    print("in shape:",in_tensor.shape)
    out_tensor = cs_se(in_tensor)
    print("out shape:", out_tensor.shape)

    s_se = sSE(c)
    print("in shape:",in_tensor.shape)
    out_tensor = s_se(in_tensor)
    print("out shape:", out_tensor.shape)

In [None]:
import torch
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    settings = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **settings)


class ChannelAttention(nn.Module):
    """CBAM channel attention: a shared MLP applied to average- and max-pooled features.

    Args:
        in_planes: Number of input channels.
        ratio: Reduction factor of the shared bottleneck MLP.

    NOTE: a later cell of this notebook defines another class with the same
    name, which replaces this one once that cell has run.
    """

    def __init__(self, in_planes, ratio=4):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel gates of shape (N, C, 1, 1) with values in (0, 1).

        The debug prints and the unused middle-column slice that used to run on
        every call were removed; the gates themselves are computed exactly as before.
        """
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)


class SpatialAttention(nn.Module):
    """CBAM spatial attention computed from channel-wise mean and max maps.

    Args:
        kernel_size: Size of the gating convolution; must be 3 or 7.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        # "Same" padding for the two supported kernel sizes.
        if kernel_size == 7:
            padding = 3
        else:
            padding = 1

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (N, 1, H, W) gate map with values in (0, 1)."""
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map = torch.max(x, dim=1, keepdim=True)[0]
        stacked = torch.cat([mean_map, max_map], dim=1)
        return self.sigmoid(self.conv(stacked))


class BasicBlock(nn.Module):
    """ResNet basic block with CBAM (channel attention followed by spatial attention).

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first convolution.
        downsample: Optional module that projects the input so it can be added
            to the block output; required whenever ``stride != 1`` or
            ``inplanes != planes``.

    NOTE: a later cell of this notebook defines another class with the same
    name, which replaces this one once that cell has run.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, both attention gates, then the residual connection."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.ca(out) * out  # broadcast over the spatial dimensions
        out = self.sa(out) * out  # broadcast over the channel dimension
        # The shortcut was stored but never applied, so the addition below
        # failed for any strided or channel-changing block.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)

        return out


# Smoke test: the CBAM BasicBlock must preserve the input shape.
if __name__ == "__main__":

    x = torch.ones(3, 16, 32, 32)

    model = BasicBlock(16, 16, stride=1)

    print(model(x).shape)

In [None]:
import torch

def pool_variance(Z_prev, f_h=2, f_w=2, padding=0, stride_h=2, stride_w=3):
    """Variance pooling: the (population) variance of every sliding window.

    Args:
        Z_prev: Input tensor of shape (batch, channels, height, width).
        f_h: Window height.
        f_w: Window width.
        padding: Zero padding added to both sides of height and width.
        stride_h: Vertical stride of the window.
        stride_w: Horizontal stride of the window.

    Returns:
        Tensor of shape (batch, channels, n_H, n_W) holding ``E[x^2] - E[x]^2``
        for each window position.
    """
    (b, n_C_prev, n_H_prev, n_W_prev) = Z_prev.shape

    # Output size of a strided window over the (padded) input.
    n_H = 1 + (n_H_prev + 2 * padding - f_h) // stride_h
    n_W = 1 + (n_W_prev + 2 * padding - f_w) // stride_w

    # `padding` must be passed to unfold as well; it used to be applied only to
    # the size formula above, so any padding != 0 ended in a shape mismatch.
    windows = torch.nn.functional.unfold(
        Z_prev, (f_h, f_w), padding=padding, stride=(stride_h, stride_w)
    )  # (b, n_C_prev * f_h * f_w, L), channel-major along dim 1
    windows = windows.reshape(b, n_C_prev, f_h * f_w, n_H * n_W)

    mean_squared = torch.mean(windows ** 2, dim=2)
    mean = torch.mean(windows, dim=2)
    variance = mean_squared - mean ** 2  # (b, n_C_prev, L)

    Z_var = variance.reshape(b, n_C_prev, n_H, n_W)

    assert(Z_var.size() == (b, n_C_prev, n_H, n_W))
    return Z_var

# Smoke test: print the output shape of variance pooling with 8x16 non-overlapping windows.
if __name__ == "__main__":
    sample_input = torch.randn(1, 64, 32, 16)  
    result = pool_variance(sample_input, f_h=8, f_w=16, padding=0, stride_h=8, stride_w=16)
    print(result.shape) 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    """Channel attention driven by mean and variance statistics of the central columns.

    The middle half of the width is cut out, split into 8 horizontal bands,
    and each band is summarised by its mean and its variance. Both summaries
    pass through a shared MLP, are placed side by side and are collapsed by an
    (8, 2) convolution into one gate per channel.

    Args:
        in_planes: Number of input channels.
        ratio: Reduction factor of the shared bottleneck MLP.
    """

    def __init__(self, in_planes, ratio=4):
        super(ChannelAttention, self).__init__()
        reduced = in_planes // ratio
        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, reduced, 1, bias=False),
            nn.BatchNorm2d(reduced),
            nn.ReLU(),
            nn.Conv2d(reduced, in_planes, 1, bias=False),
            nn.BatchNorm2d(in_planes),
            nn.ReLU(),
        )
        self.sigmoid = nn.Sigmoid()
        self.height_reduce = nn.Sequential(
            nn.Conv2d(in_planes, in_planes, (8, 2), bias=False),
        )

    def forward(self, x):
        """Return per-channel gates computed from the central columns of ``x``."""
        quarter = x.shape[3]//4
        col_start, col_end = quarter, quarter * 3
        centre = x[:, :, :, col_start:col_end]

        # Non-overlapping windows: 8 bands in height, the full centre strip in width.
        window_h = centre.shape[2] // 8
        window_w = col_end - col_start
        window = (window_h, window_w)

        var_pooled = pool_variance(centre, f_h=window_h, f_w=window_w, stride_h=window_h, stride_w=window_w)
        avg_pooled = F.avg_pool2d(centre, window, stride=window)

        avgout = self.sharedMLP(avg_pooled)
        varout = self.sharedMLP(var_pooled)
        side_by_side = torch.cat((avgout, varout), dim=3)

        return self.sigmoid(self.height_reduce(side_by_side))

# Module smoke test: print the shape of the channel gates.
if __name__ == "__main__":
    sample_input = torch.randn(4, 64, 32, 32) 
    channel_attention = ChannelAttention(in_planes=64)
    output = channel_attention(sample_input)
    print(output.shape)  

In [None]:
import torch
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1):
    """Return a 3x3 ``nn.Conv2d`` with padding 1 and no bias term.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
    """
    layer = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )
    return layer

class SpatialAttention(nn.Module):
    """Spatial attention built from channel-wise mean, channel-wise variance and a learned map.

    Args:
        planes: Number of channels of the input feature map.
        kernel_size: Size of both convolutions; must be 3 or 7.
    """

    def __init__(self, planes, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        self.kernel_size = kernel_size
        padding = 3 if kernel_size == 7 else 1
        self.padding = padding
        # Fuses the three single-channel descriptors into one gate map.
        self.conv = nn.Conv2d(3, 1, kernel_size, padding=padding, bias=False)
        # Learned projection of all input channels onto a single map.
        self.p2pconv = nn.Conv2d(planes,1,kernel_size,padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (N, 1, H, W) gate map with values in (0, 1).

        The channel-wise maximum that used to be computed here was never part
        of the concatenated descriptors, so that dead computation was removed.
        NOTE(review): ``torch.var`` over a single channel appears to produce
        NaN with its default (unbiased) estimator -- confirm ``planes > 1``.
        """
        convout = self.p2pconv(x)
        avgout = torch.mean(x, dim=1, keepdim=True)
        varout = torch.var(x, dim=1, keepdim=True)
        x = torch.cat([avgout, varout,convout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)

class BasicBlock(nn.Module):
    """Attention-only block: channel attention, spatial attention, then ReLU.

    Args:
        inplanes: Accepted for signature compatibility; not used.
        planes: Number of channels of the feature map being refined.
        stride: Accepted for signature compatibility; not used.
        downsample: Accepted for signature compatibility; not used.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention(planes)

    def forward(self, x):
        """Gate ``x`` per channel, then per pixel, and apply ReLU."""
        channel_gated = self.ca(x) * x  # gates broadcast over H and W
        fully_gated = self.sa(channel_gated) * channel_gated  # gates broadcast over channels
        return self.relu(fully_gated)

# Smoke test: the attention-only BasicBlock must preserve the input shape.
if __name__ == "__main__":

    x = torch.ones(3, 16, 32, 32)

    model = BasicBlock(16, 16, stride=1)

    print(model(x).shape)

In [None]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``, a piecewise-linear stand-in for the sigmoid.

    Args:
        inplace: Forwarded to the internal ``nn.ReLU6``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        clipped = self.relu(shifted)
        return clipped / 6

class h_swish(nn.Module):
    """Hard swish: the input multiplied by its hard sigmoid.

    Args:
        inplace: Forwarded to the internal ``h_sigmoid``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate

class CoordAtt(nn.Module):
    """Coordinate attention: separate gates along the height and the width axis.

    Args:
        inp: Number of input channels.
        oup: Number of channels of the two gate maps (must be broadcastable
            against the input for the final product).
        reduction: Reduction factor of the shared bottleneck (at least 8 channels are kept).
    """

    def __init__(self, inp, oup, reduction=32):
        super().__init__()
        # Pool along one axis at a time, keeping the other axis intact.
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``x`` multiplied by its width gate and its height gate."""
        _, _, height, width = x.size()
        pooled_h = self.pool_h(x)                       # (n, c, h, 1)
        pooled_w = self.pool_w(x).permute(0, 1, 3, 2)   # (n, c, w, 1)

        # Encode both directions jointly, then separate them again.
        joint = torch.cat([pooled_h, pooled_w], dim=2)
        joint = self.act(self.bn1(self.conv1(joint)))
        feat_h, feat_w = torch.split(joint, [height, width], dim=2)
        feat_w = feat_w.permute(0, 1, 3, 2)

        gate_h = self.conv_h(feat_h).sigmoid()
        gate_w = self.conv_w(feat_w).sigmoid()

        return x * gate_w * gate_h
# Smoke test: CoordAtt must preserve the input shape.
if __name__ == "__main__":
    bs, c, h, w = 10, 16, 64, 64
    in_tensor = torch.ones(bs, c, h, w)

    cs_se = CoordAtt(c,c)
    print("in shape:",in_tensor.shape)
    out_tensor = cs_se(in_tensor)
    print("out shape:", out_tensor.shape)

In [None]:
class MS_CAM(nn.Module):
    """Multi-scale channel attention: a per-pixel (local) and a pooled (global) branch.

    Args:
        channels: Number of input (and output) channels.
        r: Reduction factor of both bottleneck branches.
    """

    def __init__(self, channels=64, r=8):
        super(MS_CAM, self).__init__()
        inter_channels = int(channels // r)

        def bottleneck():
            # conv -> bn -> relu -> conv -> bn, identical for both branches
            return [
                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(inter_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(channels),
            ]

        self.local_att = nn.Sequential(*bottleneck())
        self.global_att = nn.Sequential(nn.AdaptiveAvgPool2d(1), *bottleneck())

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by the sigmoid of the summed local and global responses."""
        local_response = self.local_att(x)
        global_response = self.global_att(x)
        weights = self.sigmoid(local_response + global_response)
        return x * weights

# Smoke test: MS_CAM must preserve the input shape.
if __name__ == "__main__":
    bs, c, h, w = 10, 16, 64, 64
    in_tensor = torch.ones(bs, c, h, w)

    cs_se = MS_CAM(c)
    print("in shape:",in_tensor.shape)
    out_tensor = cs_se(in_tensor)
    print("out shape:", out_tensor.shape)

In [None]:
from torch import Tensor
class FPN(nn.Module):
    """Projects several decoder levels, upsamples them and concatenates them with the last layer.

    Args:
        input_channels: Channel count of each incoming feature map.
        output_channels: Channel count each map is projected to.
    """

    def __init__(self, input_channels:list, output_channels:list):
        super().__init__()
        heads = []
        for in_ch, out_ch in zip(input_channels, output_channels):
            hidden = out_ch*2
            heads.append(nn.Sequential(
                nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(hidden),
                nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1),
            ))
        self.convs = nn.ModuleList(heads)

    def forward(self, xs:list, last_layer):
        """Upsample level ``i`` by ``2 ** (levels - i + 1)`` and concatenate everything on dim 1.

        Args:
            xs: Feature maps, one per projection head.
            last_layer: Tensor appended unchanged after the upsampled maps.
        """
        levels = len(self.convs)
        pyramid = []
        for level, (head, feature) in enumerate(zip(self.convs, xs)):
            factor = 2**(levels-level+1)
            pyramid.append(F.interpolate(head(feature), scale_factor=factor, mode='bilinear'))
        pyramid.append(last_layer)
        return torch.cat(pyramid, dim=1)

class UnetBlock(nn.Module):
    """U-Net decoder block: pixel-shuffle upsampling, skip concatenation and two conv layers.

    Args:
        up_in_c: Channels of the tensor coming from the deeper decoder stage.
        x_in_c: Channels of the encoder skip connection.
        nf: Output channels; defaults to ``max(up_in_c // 2, 32)`` when omitted.
        blur: Forwarded to ``PixelShuffle_ICNR``.
        self_attention: When True a ``SelfAttention`` layer is appended to the second conv.
        **kwargs: Forwarded to ``PixelShuffle_ICNR`` and ``ConvLayer``.
    """

    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 self_attention:bool=False, **kwargs):
        super().__init__()
        # Resolve the default before any layer that depends on nf is built;
        # BatchNorm2d(nf) used to be created first and failed for nf=None.
        nf = nf if nf is not None else max(up_in_c//2,32)
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        self.bn2 = nn.BatchNorm2d(nf)
        ni = up_in_c//2 + x_in_c
        self.conv1 = ConvLayer(ni, nf, norm_type=None, **kwargs)
        self.conv2 = ConvLayer(nf, nf, norm_type=None,
            xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.gelu = nn.GELU()
        self.relu = nn.ReLU(inplace=True)
        self.nf = nf

    def forward(self, up_in, left_in):
        """Upsample ``up_in``, concatenate it with the normalised skip ``left_in`` and convolve."""
        s = left_in
        up_out = self.shuf(up_in)
        cat_x = self.gelu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.gelu(self.bn2(self.conv1(cat_x))))
        
class UnetBlockWithAtt(nn.Module):
    """U-Net decoder block whose skip connection is refined by an attention block first.

    Args:
        up_in_c: Channels of the tensor coming from the deeper decoder stage.
        x_in_c: Channels of the encoder skip connection.
        nf: Output channels; defaults to ``max(up_in_c // 2, 32)`` when omitted.
        blur: Forwarded to ``PixelShuffle_ICNR``.
        self_attention: When True a ``SelfAttention`` layer is appended to the second conv.
        **kwargs: Forwarded to ``PixelShuffle_ICNR`` and ``ConvLayer``.

    The attention block is now a registered sub-module (``self.att``). It used
    to be re-created with fresh random weights inside every ``forward`` call,
    so it was never trained, never saved, and tied to ``.cuda()``.
    NOTE: checkpoints written by the old code contain no ``att.*`` entries;
    load them with ``strict=False``.
    """

    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 self_attention:bool=False, **kwargs):
        super().__init__()
        # Resolve the default before any layer that depends on nf is built.
        nf = nf if nf is not None else max(up_in_c//2,32)
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        self.bn2 = nn.BatchNorm2d(nf)
        ni = up_in_c//2 + x_in_c
        self.conv1 = ConvLayer(ni, nf, norm_type=None, **kwargs)
        self.conv2 = ConvLayer(nf, nf, norm_type=None,
            xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.gelu = nn.GELU()
        self.relu = nn.ReLU(inplace=True)
        self.nf = nf
        # Attention applied to the skip connection. Alternatives tried by the
        # author (score noted next to each in the original code): csSE 0.8832,
        # SELayer 0.8789, BasicBlock/CBAM 0.8833, CoordAtt 0.8845.
        self.att = BasicBlock(x_in_c, x_in_c, stride=1)

    def forward(self, up_in, left_in):
        """Refine the skip ``left_in`` with attention, then merge it with the upsampled ``up_in``."""
        s = self.att(left_in)
        up_out = self.shuf(up_in)
        cat_x = self.gelu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.gelu(self.bn2(self.conv1(cat_x))))
        
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, groups=1):
        super().__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                stride=1, padding=padding, dilation=dilation, bias=False, groups=groups)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)
        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    """Atrous spatial pyramid pooling with a global max-pooled context branch.

    Args:
        inplanes: Number of input channels.
        mid_c: Channels produced by every branch.
        dilations: Dilation rate of each 3x3 branch (a 1x1 branch is always added).
        out_c: Output channels; defaults to ``mid_c``.
    """

    def __init__(self, inplanes=512, mid_c=256, dilations=[6, 12, 18, 24], out_c=None):
        super().__init__()
        branches = [_ASPPModule(inplanes, mid_c, 1, padding=0, dilation=1)]
        for rate in dilations:
            branches.append(_ASPPModule(inplanes, mid_c, 3, padding=rate, dilation=rate, groups=4))
        self.aspps = nn.ModuleList(branches)
        self.global_pool = nn.Sequential(
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Conv2d(inplanes, mid_c, 1, stride=1, bias=False),
            nn.BatchNorm2d(mid_c),
            nn.ReLU(inplace=True),
        )
        if out_c is None:
            out_c = mid_c
        # 1x1 branch + one branch per dilation + the global branch.
        fused_channels = mid_c*(2+len(dilations))
        self.out_conv = nn.Sequential(
            nn.Conv2d(fused_channels, out_c, 1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )
        # Not used by forward(); kept so the parameter set stays unchanged.
        self.conv1 = nn.Conv2d(fused_channels, out_c, 1, bias=False)
        self._init_weight()

    def forward(self, x):
        """Run all branches on ``x``, concatenate them with the global context and fuse."""
        context = self.global_pool(x)
        branch_outputs = [branch(x) for branch in self.aspps]
        # Broadcast the 1x1 context back to the spatial size of the branches.
        context = F.interpolate(context, size=branch_outputs[0].size()[2:],
                                mode='bilinear', align_corners=True)
        fused = torch.cat([context] + branch_outputs, dim=1)
        return self.out_conv(fused)

    def _init_weight(self):
        """Kaiming-normal init for convolutions; batch norm starts as the identity."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()

In [None]:
class WIDRMS(nn.Module):
    """Encoder/decoder segmentation network built on a timm backbone.

    The encoder reuses ``conv1``/``bn1``/``layer1..layer4`` of the timm model
    named by ``CFG.model_name``. An ASPP block sits on the deepest features,
    followed by four U-Net style decoder blocks, an FPN that merges the
    intermediate decoder outputs, and a 1x1 head producing
    ``CFG.num_classes`` channels.

    Args:
        stride: multiplier applied to the ASPP dilations (1, 2, 3, 4).
        **kwargs: accepted but not used.
    """

    def __init__(self, stride=1, **kwargs):
        super().__init__()
        # ---- encoder -------------------------------------------------
        backbone = timm.create_model(CFG.model_name, pretrained=False)
        # Loading backbone weights is currently disabled:
        #weights = torch.load(CFG.model_path, map_location=torch.device('cpu'))
        # apply these weights to the model
        #backbone.load_state_dict(weights)
        self.enc0 = nn.Sequential(backbone.conv1, backbone.bn1, nn.ReLU(inplace=True))
        stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1)
        self.enc1 = nn.Sequential(stem_pool, backbone.layer1)  # 256 channels per the original note
        self.enc2 = backbone.layer2  # 512
        self.enc3 = backbone.layer3  # 1024
        self.enc4 = backbone.layer4  # 2048

        # ---- ASPP with stride-scaled dilations -----------------------
        aspp_rates = [stride * k for k in (1, 2, 3, 4)]
        self.aspp = ASPP(2048, 256, out_c=512, dilations=aspp_rates)
        self.drop_aspp = nn.Dropout2d(0.0)

        # ---- decoder -------------------------------------------------
        self.dec4 = UnetBlock(512, 1024, 256)
        self.dec3 = UnetBlock(256, 512, 128)
        self.dec2 = UnetBlockWithAtt(128, 256, 64)
        self.dec1 = UnetBlockWithAtt(64, 64, 32)
        self.fpn = FPN([512, 256, 128, 64], [16] * 4)
        self.drop = nn.Dropout2d(0.0)
        # Head input = 32 decoder channels + 4 FPN branches of 16 channels.
        self.final_conv = ConvLayer(32 + 16 * 4, CFG.num_classes, ks=1, norm_type=None, act_cls=None)
        self.num_classes = CFG.num_classes

    def forward(self, x):
        # Encoder pyramid.
        feat0 = self.enc0(x)
        feat1 = self.enc1(feat0)
        feat2 = self.enc2(feat1)
        feat3 = self.enc3(feat2)
        feat4 = self.enc4(feat3)
        bridge = self.aspp(feat4)

        # Decoder with skip connections from the matching encoder level.
        up3 = self.dec4(self.drop_aspp(bridge), feat3)
        up2 = self.dec3(up3, feat2)
        up1 = self.dec2(up2, feat1)
        up0 = self.dec1(up1, feat0)
        up0 = F.interpolate(up0, scale_factor=2, mode='bilinear')

        # Fuse the coarser decoder outputs with the finest one, then classify.
        fused = self.fpn([bridge, up3, up2, up1], up0)
        return self.final_conv(self.drop(fused))

In [None]:
import timm
from fastai.vision.all import PixelShuffle_ICNR
from fastai.vision.all import ConvLayer,SelfAttention
from fastai.vision.all import AdaptiveConcatPool2d, Flatten

import torch
# Smoke test: build the network on the GPU and push one random batch through
# it to check that the forward pass runs and to inspect the output shape.
model = WIDRMS().cuda()
# Optional: restore a previously trained checkpoint (disabled).
# state_dict = torch.load('Pretrained_Weight/seresnext26d_32x4d.bt_in1k_exp01_fold0_epoch24.pth')
# if 'model' in state_dict.keys():
#     state_dict = state_dict['model']
# model.load_state_dict(state_dict)
# Random batch of 2 three-channel 512x512 inputs.
input_tensor = torch.randn(2, 3, 512, 512).cuda()
output = model(input_tensor)
print('Input shape:', input_tensor.shape)
print('Output shape:', output.shape)

In [None]:
class MedNeXtBlock(nn.Module):
    """Basic MedNeXt block.

    Pipeline: grouped (depthwise by default) conv -> normalization ->
    1x1 expansion conv -> GELU -> optional GRN -> 1x1 compression conv,
    with an optional identity residual around the whole block.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the compression conv.
        exp_r: expansion ratio; the hidden width is ``exp_r * in_channels``.
        kernel_size: kernel size of the first (grouped) convolution.
        do_res: add the input to the output. The addition is a plain ``+``,
            so it only works when input and output shapes match.
        norm_type: ``'group'`` (GroupNorm with one group per channel) or
            ``'layer'`` (channels-first ``LayerNorm`` defined in this file).
        n_groups: groups of the first conv; ``None`` means ``in_channels``.
        dim: ``'2d'`` or ``'3d'``; selects Conv2d / Conv3d.
        grn: enable the global response normalization step.

    Raises:
        ValueError: if ``norm_type`` is not ``'group'`` or ``'layer'``.
    """

    def __init__(self, 
                in_channels:int, 
                out_channels:int, 
                exp_r:int=4, 
                kernel_size:int=7, 
                do_res:bool=True,
                norm_type:str = 'group',
                n_groups:"int | None" = None,
                dim = '2d',
                grn = False
                ):
        super().__init__()
        self.do_res = do_res
        assert dim in ['2d', '3d']
        self.dim = dim
        if self.dim == '2d':
            conv = nn.Conv2d
        elif self.dim == '3d':
            conv = nn.Conv3d

        # First convolution: grouped, "same" padding, depthwise unless
        # n_groups is given.
        self.conv1 = conv(
            in_channels = in_channels,
            out_channels = in_channels,
            kernel_size = kernel_size,
            stride = 1,
            padding = kernel_size//2,
            groups = in_channels if n_groups is None else n_groups,
        )

        # Normalization layer. GroupNorm is used by default.
        if norm_type=='group':
            self.norm = nn.GroupNorm(
                num_groups=in_channels, 
                num_channels=in_channels
                )
        elif norm_type=='layer':
            self.norm = LayerNorm(
                normalized_shape=in_channels, 
                data_format='channels_first'
                )
        else:
            # Fail at construction time instead of with an AttributeError
            # on the first forward pass.
            raise ValueError(
                f"Unsupported norm_type {norm_type!r}; expected 'group' or 'layer'"
            )

        # Second convolution: pointwise expansion.
        self.conv2 = conv(
            in_channels = in_channels,
            out_channels = exp_r*in_channels,
            kernel_size = 1,
            stride = 1,
            padding = 0
        )

        self.act = nn.GELU()

        # Third convolution: pointwise compression.
        self.conv3 = conv(
            in_channels = exp_r*in_channels,
            out_channels = out_channels,
            kernel_size = 1,
            stride = 1,
            padding = 0
        )

        self.grn = grn
        if grn:
            # Learnable affine parameters of GRN, one per expanded channel,
            # shaped to broadcast over the spatial axes.
            if dim == '3d':
                self.grn_beta = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1,1), requires_grad=True)
                self.grn_gamma = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1,1), requires_grad=True)
            elif dim == '2d':
                self.grn_beta = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1), requires_grad=True)
                self.grn_gamma = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1), requires_grad=True)

    def forward(self, x, dummy_tensor=None):
        """Apply the block. ``dummy_tensor`` is accepted and ignored."""
        x1 = x
        x1 = self.conv1(x1)
        x1 = self.act(self.conv2(self.norm(x1)))
        if self.grn:
            # L2 norm of every channel over its spatial extent ...
            if self.dim == '3d':
                gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True)
            elif self.dim == '2d':
                gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True)
            # ... divided by the mean norm across channels.
            nx = gx / (gx.mean(dim=1, keepdim=True)+1e-6)
            x1 = self.grn_gamma * (x1 * nx) + self.grn_beta + x1
        x1 = self.conv3(x1)
        if self.do_res:
            x1 = x + x1  
        return x1

    def _init_weight(self):
        """Kaiming-normal init for Conv2d; unit/zero init for BatchNorm2d.

        Not called by ``__init__``; available for callers that want it.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class MedNeXtDownBlock(MedNeXtBlock):
    """MedNeXt block whose first convolution has stride 2.

    The parent's identity residual is always switched off. When ``do_res``
    is true, a strided 1x1 convolution of the input is added to the block
    output instead.
    """

    def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7, 
                do_res=False, norm_type = 'group', dim='2d', grn=False):
        # do_res=False: the plain residual of the base block is never used here.
        super().__init__(
            in_channels,
            out_channels,
            exp_r,
            kernel_size,
            do_res=False,
            norm_type=norm_type,
            dim=dim,
            grn=grn,
        )

        # The base constructor has already checked that dim is '2d' or '3d'.
        conv_cls = nn.Conv2d if dim == '2d' else nn.Conv3d

        self.resample_do_res = do_res
        if do_res:
            # Strided 1x1 projection used as the residual branch.
            self.res_conv = conv_cls(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=2,
            )

        # Replace the stride-1 depthwise conv of the base block with a stride-2 one.
        self.conv1 = conv_cls(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=kernel_size // 2,
            groups=in_channels,
        )

    def forward(self, x, dummy_tensor=None):
        """Run the block; ``dummy_tensor`` is accepted and ignored."""
        out = super().forward(x)
        if not self.resample_do_res:
            return out
        return out + self.res_conv(x)


class MedNeXtUpBlock(MedNeXtBlock):
    """MedNeXt block whose first layer is a stride-2 transposed convolution.

    The parent's identity residual is always switched off. When ``do_res``
    is true, a strided 1x1 transposed convolution of the input is added to
    the block output. Both branches are padded by one element at the start
    of every spatial axis before being returned.
    """

    def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7, 
                do_res=False, norm_type = 'group', dim='2d', grn = False):
        # The base constructor validates dim and stores it as self.dim.
        super().__init__(
            in_channels,
            out_channels,
            exp_r,
            kernel_size,
            do_res=False,
            norm_type=norm_type,
            dim=dim,
            grn=grn,
        )

        self.resample_do_res = do_res

        deconv_cls = nn.ConvTranspose2d if dim == '2d' else nn.ConvTranspose3d

        if do_res:
            # Strided 1x1 transposed projection used as the residual branch.
            self.res_conv = deconv_cls(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=2,
            )

        # Replace the depthwise conv of the base block with a stride-2
        # transposed depthwise conv.
        self.conv1 = deconv_cls(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=2,
            padding=kernel_size // 2,
            groups=in_channels,
        )

    def forward(self, x, dummy_tensor=None):
        """Run the block; ``dummy_tensor`` is accepted and ignored."""
        # One leading element per spatial axis; asymmetric on purpose so the
        # output matches the shape expected by the caller.
        lead_pad = {'2d': (1, 0, 1, 0), '3d': (1, 0, 1, 0, 1, 0)}.get(self.dim)

        out = super().forward(x)
        if lead_pad is not None:
            out = torch.nn.functional.pad(out, lead_pad)

        if self.resample_do_res:
            shortcut = self.res_conv(x)
            if lead_pad is not None:
                shortcut = torch.nn.functional.pad(shortcut, lead_pad)
            out = out + shortcut

        return out


class OutBlock(nn.Module):
    """Output head: a single 1x1 transposed convolution to ``n_classes`` channels.

    Args:
        in_channels: channels of the incoming feature map.
        n_classes: channels of the produced map.
        dim: ``'2d'`` or ``'3d'``; selects ConvTranspose2d / ConvTranspose3d.
    """

    def __init__(self, in_channels, n_classes, dim):
        super().__init__()

        if dim == '2d':
            head_cls = nn.ConvTranspose2d
        elif dim == '3d':
            head_cls = nn.ConvTranspose3d
        self.conv_out = head_cls(in_channels, n_classes, kernel_size=1)

    def forward(self, x, dummy_tensor=None): 
        """Project ``x``; ``dummy_tensor`` is accepted and ignored."""
        projected = self.conv_out(x)
        return projected


class LayerNorm(nn.Module):
    """ LayerNorm that supports two data formats: channels_last (default) or channels_first. 
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
    shape (batch_size, ..., channels) while channels_first corresponds to inputs 
    with shape (batch_size, channels, ...).

    In channels_first mode the statistics are computed over axis 1 and the
    affine parameters are broadcast along axis 1 for any number of trailing
    spatial axes (2-D images as well as 3-D volumes).

    Args:
        normalized_shape: number of channels (an int).
        eps: value added to the variance for numerical stability.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: for any other ``data_format``.
    """
    def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))        # gamma (scale)
        self.bias = nn.Parameter(torch.zeros(normalized_shape))         # beta (shift)
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x, dummy_tensor=False):
        """Normalize ``x``; ``dummy_tensor`` is accepted and ignored."""
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            # Reshape the per-channel parameters to (1, C, 1, ..., 1) so they
            # always line up with axis 1. A fixed [:, None, None, None] only
            # does so for 5-D input and mis-broadcasts on 4-D input.
            affine_shape = (1, -1) + (1,) * (x.dim() - 2)
            x = self.weight.view(affine_shape) * x + self.bias.view(affine_shape)
            return x

In [None]:
class MedNeXt(nn.Module):
    """U-shaped MedNeXt network: 4 down-sampling stages, a bottleneck and
    4 up-sampling stages with additive skip connections.

    Channel widths double at every down step (``n_channels`` up to
    ``16 * n_channels``) and halve at every up step.

    Args:
        in_channels: channels of the network input.
        n_channels: width of the first stage.
        n_classes: channels of every output head.
        exp_r: expansion ratio; an int is broadcast to one value per stage,
            otherwise a sequence indexed 0..8 (enc0..enc3, bottleneck,
            dec3..dec0).
        kernel_size: when not ``None`` it overrides both
            ``enc_kernel_size`` and ``dec_kernel_size``.
        enc_kernel_size: kernel size of the encoder blocks.
        dec_kernel_size: kernel size of the bottleneck and decoder blocks.
        deep_supervision: also return predictions from the four coarser
            decoder levels.
        do_res: identity residual inside every plain block.
        do_res_up_down: projected residual inside up/down blocks.
        checkpoint_style: ``None`` or ``'outside_block'``. The value is only
            validated; ``forward`` never enables checkpointing in this
            implementation.
        block_counts: number of blocks in each of the 9 stages.
        norm_type: normalization used by the blocks.
        dim: ``'2d'`` or ``'3d'``.
        grn: enable global response normalization in the blocks.
    """

    def __init__(self, 
        in_channels: int, 
        n_channels: int,
        n_classes: int, 
        exp_r: int = 4,                            # Expansion ratio as in Swin Transformers
        kernel_size: int = 7,                      # Ofcourse can test kernel_size
        enc_kernel_size: int = None,
        dec_kernel_size: int = None,
        deep_supervision: bool = False,             # Can be used to test deep supervision
        do_res: bool = False,                       # Can be used to individually test residual connection
        do_res_up_down: bool = False,             # Additional 'res' connection on up and down convs
        checkpoint_style: bool = None,            # Either inside block or outside block
        block_counts: list = [2,2,2,2,2,2,2,2,2], # Can be used to test staging ratio: 
                                            # [3,3,9,3] in Swin as opposed to [2,2,2,2,2] in nnUNet
        norm_type = 'group',
        dim = '2d',                                # 2d or 3d
        grn = False
    ):

        super().__init__()

        self.do_ds = deep_supervision
        assert checkpoint_style in [None, 'outside_block']
        self.inside_block_checkpointing = False
        self.outside_block_checkpointing = False
        assert dim in ['2d', '3d']
        if kernel_size is not None:
            enc_kernel_size = kernel_size
            dec_kernel_size = kernel_size

        if dim == '2d':
            conv = nn.Conv2d
        elif dim == '3d':
            conv = nn.Conv3d

        self.stem = conv(in_channels, n_channels, kernel_size=1)
        if type(exp_r) == int:
            exp_r = [exp_r for i in range(len(block_counts))]

        def make_stage(width_mult, idx, ks):
            """Stack of block_counts[idx] same-width MedNeXt blocks."""
            width = n_channels * width_mult
            return nn.Sequential(*[
                MedNeXtBlock(
                    in_channels=width,
                    out_channels=width,
                    exp_r=exp_r[idx],
                    kernel_size=ks,
                    do_res=do_res,
                    norm_type=norm_type,
                    dim=dim,
                    grn=grn
                    )
                for _ in range(block_counts[idx])]
            )

        def make_resample(block_cls, in_mult, out_mult, idx, ks, use_grn=grn):
            """One up- or down-sampling block between two stages."""
            return block_cls(
                in_channels=in_mult * n_channels,
                out_channels=out_mult * n_channels,
                exp_r=exp_r[idx],
                kernel_size=ks,
                do_res=do_res_up_down,
                norm_type=norm_type,
                dim=dim,
                grn=use_grn
            )

        # Modules are created in the same order, under the same attribute
        # names, as the hand-unrolled version so that parameter names and
        # ordering (and therefore checkpoints) are unaffected.

        # ---- encoder ----
        self.enc_block_0 = make_stage(1, 0, enc_kernel_size)
        # NOTE(review): down_0 has never received the `grn` flag, unlike every
        # other resampling block. Kept as-is because enabling it would add
        # parameters and change checkpoint layout -- confirm whether intended.
        self.down_0 = make_resample(MedNeXtDownBlock, 1, 2, 1, enc_kernel_size, use_grn=False)
        self.enc_block_1 = make_stage(2, 1, enc_kernel_size)
        self.down_1 = make_resample(MedNeXtDownBlock, 2, 4, 2, enc_kernel_size)
        self.enc_block_2 = make_stage(4, 2, enc_kernel_size)
        self.down_2 = make_resample(MedNeXtDownBlock, 4, 8, 3, enc_kernel_size)
        self.enc_block_3 = make_stage(8, 3, enc_kernel_size)
        self.down_3 = make_resample(MedNeXtDownBlock, 8, 16, 4, enc_kernel_size)

        # ---- bottleneck (uses the decoder kernel size) ----
        self.bottleneck = make_stage(16, 4, dec_kernel_size)

        # ---- decoder ----
        self.up_3 = make_resample(MedNeXtUpBlock, 16, 8, 5, dec_kernel_size)
        self.dec_block_3 = make_stage(8, 5, dec_kernel_size)
        self.up_2 = make_resample(MedNeXtUpBlock, 8, 4, 6, dec_kernel_size)
        self.dec_block_2 = make_stage(4, 6, dec_kernel_size)
        self.up_1 = make_resample(MedNeXtUpBlock, 4, 2, 7, dec_kernel_size)
        self.dec_block_1 = make_stage(2, 7, dec_kernel_size)
        self.up_0 = make_resample(MedNeXtUpBlock, 2, 1, 8, dec_kernel_size)
        self.dec_block_0 = make_stage(1, 8, dec_kernel_size)

        self.out_0 = OutBlock(in_channels=n_channels, n_classes=n_classes, dim=dim)

        # Used to fix PyTorch checkpointing bug
        self.dummy_tensor = nn.Parameter(torch.tensor([1.]), requires_grad=True)  

        if deep_supervision:
            self.out_1 = OutBlock(in_channels=n_channels*2, n_classes=n_classes, dim=dim)
            self.out_2 = OutBlock(in_channels=n_channels*4, n_classes=n_classes, dim=dim)
            self.out_3 = OutBlock(in_channels=n_channels*8, n_classes=n_classes, dim=dim)
            self.out_4 = OutBlock(in_channels=n_channels*16, n_classes=n_classes, dim=dim)

        self.block_counts = block_counts


    def iterative_checkpoint(self, sequential_block, x):
        """
        This simply forwards x through each block of the sequential_block while
        using gradient_checkpointing. This implementation is designed to bypass
        the following issue in PyTorch's gradient checkpointing:
        https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/9
        """
        # Imported here because the notebook never imports the checkpoint
        # module at the top level; without it this method raised NameError.
        from torch.utils import checkpoint

        for l in sequential_block:
            x = checkpoint.checkpoint(l, x, self.dummy_tensor)
        return x


    def forward(self, x):
        """Return the full-resolution prediction, or with deep supervision
        the list ``[x, x_ds_1, x_ds_2, x_ds_3, x_ds_4]`` ordered from the
        finest to the coarsest level."""

        x = self.stem(x)

        # Encoder: keep each stage output for the matching decoder level.
        x_res_0 = self.enc_block_0(x)
        x = self.down_0(x_res_0)
        x_res_1 = self.enc_block_1(x)
        x = self.down_1(x_res_1)
        x_res_2 = self.enc_block_2(x)
        x = self.down_2(x_res_2)
        x_res_3 = self.enc_block_3(x)
        x = self.down_3(x_res_3)

        x = self.bottleneck(x)
        if self.do_ds:
            x_ds_4 = self.out_4(x)

        # Decoder: skips are merged by addition, not concatenation.
        x_up_3 = self.up_3(x)
        dec_x = x_res_3 + x_up_3 
        x = self.dec_block_3(dec_x)

        if self.do_ds:
            x_ds_3 = self.out_3(x)
        del x_res_3, x_up_3

        x_up_2 = self.up_2(x)
        dec_x = x_res_2 + x_up_2 
        x = self.dec_block_2(dec_x)
        if self.do_ds:
            x_ds_2 = self.out_2(x)
        del x_res_2, x_up_2

        x_up_1 = self.up_1(x)
        dec_x = x_res_1 + x_up_1 
        x = self.dec_block_1(dec_x)
        if self.do_ds:
            x_ds_1 = self.out_1(x)
        del x_res_1, x_up_1

        x_up_0 = self.up_0(x)
        dec_x = x_res_0 + x_up_0 
        x = self.dec_block_0(dec_x)
        del x_res_0, x_up_0, dec_x

        x = self.out_0(x)

        if self.do_ds:
            return [x, x_ds_1, x_ds_2, x_ds_3, x_ds_4]
        else: 
            return x

In [None]:
def create_mednextv1_small(num_input_channels, num_classes, kernel_size=3, ds=False):
    """Build the 'S' MedNeXt variant.

    32 base channels, expansion ratio 2 in every stage, two blocks per
    stage, residuals enabled both inside blocks and in up/down blocks.
    ``ds`` toggles deep supervision.
    """
    variant = dict(
        n_channels=32,
        exp_r=2,
        do_res=True,
        do_res_up_down=True,
        block_counts=[2, 2, 2, 2, 2, 2, 2, 2, 2],
    )
    return MedNeXt(
        in_channels=num_input_channels,
        n_classes=num_classes,
        kernel_size=kernel_size,
        deep_supervision=ds,
        **variant,
    )


def create_mednextv1_base(num_input_channels, num_classes, kernel_size=3, ds=False):
    """Build the 'B' MedNeXt variant.

    32 base channels, per-stage expansion ratios 2-3-4-4-4-4-4-3-2, two
    blocks per stage, residuals enabled both inside blocks and in up/down
    blocks. ``ds`` toggles deep supervision.
    """
    variant = dict(
        n_channels=32,
        exp_r=[2, 3, 4, 4, 4, 4, 4, 3, 2],
        do_res=True,
        do_res_up_down=True,
        block_counts=[2, 2, 2, 2, 2, 2, 2, 2, 2],
    )
    return MedNeXt(
        in_channels=num_input_channels,
        n_classes=num_classes,
        kernel_size=kernel_size,
        deep_supervision=ds,
        **variant,
    )


def create_mednextv1_medium(num_input_channels, num_classes, kernel_size=3, ds=False):
    """Build the 'M' MedNeXt variant.

    32 base channels, per-stage expansion ratios 2-3-4-4-4-4-4-3-2, block
    counts 3-4-4-4-4-4-4-4-3, residuals enabled, and
    ``checkpoint_style='outside_block'`` passed to the model. ``ds``
    toggles deep supervision.
    """
    variant = dict(
        n_channels=32,
        exp_r=[2, 3, 4, 4, 4, 4, 4, 3, 2],
        do_res=True,
        do_res_up_down=True,
        block_counts=[3, 4, 4, 4, 4, 4, 4, 4, 3],
        checkpoint_style='outside_block',
    )
    return MedNeXt(
        in_channels=num_input_channels,
        n_classes=num_classes,
        kernel_size=kernel_size,
        deep_supervision=ds,
        **variant,
    )

def create_mednextv1_large(num_input_channels, num_classes, kernel_size=3, ds=False):
    """Build the 'L' MedNeXt variant.

    32 base channels; expansion ratios and block counts are both
    3-4-8-8-8-8-8-4-3; residuals enabled, and
    ``checkpoint_style='outside_block'`` passed to the model. ``ds``
    toggles deep supervision.
    """
    variant = dict(
        n_channels=32,
        exp_r=[3, 4, 8, 8, 8, 8, 8, 4, 3],
        do_res=True,
        do_res_up_down=True,
        block_counts=[3, 4, 8, 8, 8, 8, 8, 4, 3],
        checkpoint_style='outside_block',
    )
    return MedNeXt(
        in_channels=num_input_channels,
        n_classes=num_classes,
        kernel_size=kernel_size,
        deep_supervision=ds,
        **variant,
    )

def create_mednext_v1(num_input_channels, num_classes, model_id, kernel_size=3,
                      deep_supervision=False):
    """Build a MedNeXt-v1 model by size code.

    ``model_id`` is one of ``'S'``, ``'B'``, ``'M'``, ``'L'``; any other
    value raises ``KeyError``. Remaining arguments are forwarded
    positionally to the selected factory.
    """
    factories = {
        'S': create_mednextv1_small,
        'B': create_mednextv1_base,
        'M': create_mednextv1_medium,
        'L': create_mednextv1_large,
    }
    build = factories[model_id]
    return build(num_input_channels, num_classes, kernel_size, deep_supervision)

In [None]:
def dice_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    """Mean smoothed Dice score between a target and a thresholded prediction.

    ``y_pred`` is binarized with ``> thr``. Intersection and sizes are
    summed over ``dim``; the per-entry score
    ``(2 * inter + epsilon) / (sizes + epsilon)`` is then averaged over
    axes 0 and 1 of what remains.
    """
    target = y_true.to(torch.float32)
    binary_pred = (y_pred > thr).to(torch.float32)

    overlap = torch.sum(target * binary_pred, dim=dim)
    total_size = torch.sum(target, dim=dim) + torch.sum(binary_pred, dim=dim)

    per_entry = (2 * overlap + epsilon) / (total_size + epsilon)
    return per_entry.mean(dim=(1, 0))

In [None]:
from medpy import metric
def calculate_metric_percase(gt, pred, thr=0.5):
    """Compute medpy binary metrics between a ground-truth and a prediction tensor.

    ``pred`` is binarized with ``> thr``; both tensors are moved to the CPU
    and converted to NumPy before being handed to ``medpy.metric.binary``.

    Returns:
        Tuple ``(dice, jaccard, hd95, asd)``.

    NOTE(review): medpy's surface-distance metrics (hd95, asd) are understood
    to raise when either mask has no foreground -- confirm, and decide what
    the caller should record for such batches.
    """
    gt_np = gt.cpu().detach().numpy()
    pred_np = (pred > thr).cpu().detach().numpy()

    measures = (
        metric.binary.dc,
        metric.binary.jc,
        metric.binary.hd95,
        metric.binary.asd,
    )
    # medpy expects (result, reference) argument order.
    return tuple(measure(pred_np, gt_np) for measure in measures)

In [None]:
def plot_examples(images, masks, preds, epoch, step):
    """Show a 3x3 grid: image, ground truth and prediction for three samples.

    Rows 0..2 index into ``images``, ``masks`` and ``preds``, so each needs
    at least three entries. For the image column, channel index 1 of every
    sample is displayed.
    """
    fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(9, 9))
    for row in range(3):
        panels = (images[row][1], masks[row], preds[row])
        titles = (
            f"Epoch {epoch} Batch {step} Sample {row} - Image",
            f"Epoch {epoch} Batch {step} Sample {row} - Ground Truth",
            f"Epoch {epoch} Batch {step} Sample {row} - Prediction",
        )
        for col, (panel, title) in enumerate(zip(panels, titles)):
            cell = ax[row, col]
            cell.imshow(panel.cpu().squeeze(), cmap='gray')
            cell.set_title(title)
            cell.axis("off")
    plt.show()

In [None]:
from SegLossFunc import SegLoss
# Loss objects instantiated once at notebook level. Only BoundaryDoULoss,
# BCELoss and DiceLoss are combined by `criterion` below; the others are
# available as alternatives.

# Losses from the project-local SegLossFunc module.
BoundaryDoULoss = SegLoss.BoundaryDoULoss()

SoftDiceCLDiceBoundaryDoULoss = SegLoss.SoftDiceCLDiceBoundaryDoULoss()

StructureLoss = SegLoss.StructureLoss()
StructureLossBoundaryDOU = SegLoss.StructureLossBoundaryDOU()
StructureLossBoundaryDOUV2 = SegLoss.StructureLossBoundaryDOUV2()

# Losses from segmentation_models_pytorch, configured for multilabel targets.
JaccardLoss = smp.losses.JaccardLoss(mode='multilabel')
DiceLoss    = smp.losses.DiceLoss(mode='multilabel')
BCELoss     = smp.losses.SoftBCEWithLogitsLoss()
LovaszLoss  = smp.losses.LovaszLoss(mode='multilabel', per_image=False)
TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)
FocalLoss = smp.losses.FocalLoss(mode="multilabel")

def criterion(y_pred, y_true):
    """Training loss: BoundaryDoU + 0.5 * BCE-with-logits + 0.5 * Dice."""
    boundary_term = BoundaryDoULoss(y_pred, y_true)
    bce_term = BCELoss(y_pred, y_true)
    dice_term = DiceLoss(y_pred, y_true)
    return boundary_term + 0.5*bce_term + 0.5*dice_term

In [None]:
# Custom gradual warm-up scheduler
class GradualWarmupSchedulerV3(GradualWarmupScheduler):
    """Warm-up scheduler with a customised ``get_lr``.

    While ``last_epoch < total_epoch`` the learning rate is ramped linearly:
    from 0 up to the base rate when ``multiplier == 1.0``, otherwise from the
    base rate up to ``base_lr * multiplier``. Afterwards the rate comes from
    ``after_scheduler`` (whose ``base_lrs`` are set to the multiplied base
    rates on first hand-over) or, without one, stays at
    ``base_lr * multiplier``.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        if self.last_epoch < self.total_epoch:
            # Still warming up.
            if self.multiplier == 1.0:
                ramp = float(self.last_epoch) / self.total_epoch
            else:
                ramp = (self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.
            return [base_lr * ramp for base_lr in self.base_lrs]

        # Warm-up finished.
        if not self.after_scheduler:
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if not self.finished:
            # First hand-over: the follow-up scheduler starts from the
            # warmed-up learning rates.
            self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
            self.finished = True
        return self.after_scheduler.get_lr()

In [None]:
def train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """Train ``model`` for one pass over ``train_loader`` with mixed precision.

    Each batch yields ``(images, bi_masks, multi_mask)``; the binary masks are
    used when ``CFG.num_classes == 1``, the multi-class masks otherwise.
    Gradients are accumulated over ``CFG.accum_iter`` batches. On every
    optimizer step the gradients are first unscaled and then clipped to
    ``CFG.max_grad_norm``.

    Args:
        train_loader: iterable of training batches.
        model: network to optimize (switched to train mode here).
        criterion: callable ``(y_pred, y_true) -> loss``.
        optimizer: optimizer holding ``model``'s parameters.
        epoch: epoch index, used for logging only.
        scheduler: accepted for signature compatibility; not used here.
        device: device the batches are moved to.

    Returns:
        ``(average loss over the epoch, learning rate of param group 0)``.
    """
    scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model.train()
    start = end = time.time()
    # Norm measured at the most recent optimizer step; NaN until the first one.
    grad_norm = float("nan")
    for step, (images, bi_masks, multi_mask) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device, dtype=torch.float)
        if CFG.num_classes == 1:
            masks = bi_masks.to(device, dtype=torch.float)
        else:
            masks = multi_mask.to(device, dtype=torch.float)
        batch_size = images.size(0)
        with autocast(enabled=True):
            y_preds = model(images)
            loss = criterion(y_preds, masks)
        # record the un-normalized loss
        losses.update(loss.item(), batch_size)
        if CFG.accum_iter > 1:
            loss = loss / CFG.accum_iter
        scaler.scale(loss).backward()
        if (step + 1) % CFG.accum_iter == 0:
            # Gradients carry the GradScaler factor until unscaled. Clipping
            # them as-is would compare scaled norms against max_grad_norm.
            # unscale_ may be called only once per optimizer step, so clipping
            # lives inside the step branch.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            cusprint('Epoch: [{0}][{1}/{2}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Elapsed {remain:s} '
                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                'Grad: {grad_norm:.4f}  '
                'LR: {lr:.7f}  '
                .format(
                epoch, step, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses,
                remain=timeSince(start, float(step+1)/len(train_loader)),
                grad_norm=grad_norm,
                lr=optimizer.param_groups[0]["lr"],
                ))
    return losses.avg, optimizer.param_groups[0]["lr"]

In [None]:
def valid_one_epoch(valid_loader, model, criterion, device,epoch):
    """Evaluate ``model`` on ``valid_loader``.

    For every batch the loss, the thresholded Dice from ``dice_coef`` and the
    four metrics from ``calculate_metric_percase`` are computed on the
    sigmoid of the model output. When the mean ``dice_coef`` exceeds 0.95 and
    the last batch holds at least three samples, three random samples of that
    batch are plotted.

    Args:
        valid_loader: iterable yielding ``(images, bi_masks, multi_mask)``.
        model: network to evaluate (switched to eval mode here).
        criterion: callable ``(y_pred, y_true) -> loss``.
        device: device the batches are moved to.
        epoch: epoch index, used in plot titles only.

    Returns:
        ``(average loss, val_scores)`` where ``val_scores`` is the per-batch
        mean of ``[dice_coef, dice, jaccard, hd95, asd]``.
    """
    n_plot_samples = 3  # plot_examples draws exactly three rows
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # switch to evaluation mode
    model.eval()
    start = end = time.time()
    val_scores = []
    for step, (images, bi_masks, multi_mask) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device, dtype=torch.float)
        if CFG.num_classes == 1:
            masks = bi_masks.to(device, dtype=torch.float)
        else:
            masks = multi_mask.to(device, dtype=torch.float)
        batch_size = images.size(0)

        # forward pass and loss, no autograd bookkeeping needed
        with torch.no_grad():
            y_pred = model(images)
            loss = criterion(y_pred, masks)
        losses.update(loss.item(), batch_size)

        # metrics are computed on probabilities
        y_pred = y_pred.sigmoid()

        val_dice = dice_coef(masks, y_pred).cpu().detach().numpy()
        dice, jc, hd, asd = calculate_metric_percase(masks, y_pred)
        val_scores.append([val_dice, dice, jc, hd, asd])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            cusprint('EVAL: [{0}/{1}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Elapsed {remain:s} '
                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                .format(
                step, len(valid_loader), batch_time=batch_time,
                data_time=data_time, loss=losses,
                remain=timeSince(start, float(step+1)/len(valid_loader)),
                ))
    val_scores = np.mean(val_scores, axis=0)

    # Sampling 3 items without replacement raises when the final batch is
    # smaller than that, so plotting is skipped for such batches.
    if val_scores[0]>0.95 and images.shape[0] >= n_plot_samples:
        selected_indices = np.random.choice(images.shape[0], n_plot_samples, replace=False)
        images = images[selected_indices]
        masks = masks[selected_indices]
        preds = y_pred[selected_indices]
        plot_examples(images, masks, preds,epoch,step)
    torch.cuda.empty_cache()
    gc.collect()
    return losses.avg, val_scores

In [None]:
# Training function
def train_loop(train_df, fold_1, fold_2, criterion):
    loginfo(f"========== training ==========")
    # ====================================================
    # loader 
    # ====================================================
    train_folds = train_df[train_df["fold"].isin(fold_1)].reset_index(drop=True)
    valid_folds = train_df[train_df["fold"].isin(fold_2)].reset_index(drop=True)

    train_img,train_label=load_data(train_folds)
    valid_img,valid_label=load_data(valid_folds)
    
    train_dataset = Spine_Dataset([train_img],[train_label],arg=True)
    train_loader = DataLoader(train_dataset, batch_size=CFG.train_batch_size ,num_workers=0, shuffle=True, pin_memory=True)
    valid_dataset = Spine_Dataset([valid_img],[valid_label],arg=False)
    valid_loader = DataLoader(valid_dataset, batch_size=CFG.valid_batch_size ,num_workers=0, shuffle=False, pin_memory=True)
    # ====================================================
    # model & optimizer & scheduler & loss 
    # ====================================================
    model = WIDRMS().cuda()
    #model = torch.nn.DataParallel(model, device_ids=[0, 1])
    # optimizer
    if CFG.optimizer == "AdamW":
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            optimizer = AdamW(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
        else:
            optimizer = AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)  
    # scheduler
    if CFG.scheduler=='ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
    elif CFG.scheduler=='CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    elif CFG.scheduler=='CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)

    if CFG.scheduler_warmup=="GradualWarmupSchedulerV3":
        scheduler_warmup = GradualWarmupSchedulerV3(optimizer, multiplier=10, total_epoch=CFG.warmup_epo, after_scheduler=scheduler)

    # ====================================================
    # loop 
    # ====================================================

    valid_acc_max=0
    valid_acc_max_cnt=0
    for epoch in range(CFG.epochs):
        loginfo(f"***** Epoch {epoch} *****")
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            pass
            # loginfo(f"schwarmup_last_epoch:{scheduler_warmup.last_epoch}, schwarmup_lr:{scheduler_warmup.get_last_lr()[0]}")
        if CFG.scheduler=='CosineAnnealingLR':
            loginfo(f"scheduler_last_epoch:{scheduler.last_epoch}, scheduler_lr:{scheduler.get_last_lr()[0]}")
        loginfo(f"optimizer_lr:{optimizer.param_groups[0]['lr']}")
                
        start_time = time.time() # 记录当前时间
        avg_loss, cur_lr = train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device)
        avg_val_loss, valid_scores = valid_one_epoch(valid_loader, model, criterion, device,epoch)
        # scoring
        elapsed = time.time() - start_time
        # print("valid_scores:", valid_scores, type(valid_scores))
        val_dice, dice, jc, hd, asd = valid_scores
        
        loginfo(f'Epoch {epoch} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        loginfo(f'Epoch {epoch} -val_dice: {val_dice:.4f} - Dice Score: {dice:.4f}- Jaccard Score: {jc:.4f}- HD95 Score: {hd:.4f}- ASSD Score: {asd:.4f}')
    
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            scheduler_warmup.step()
        elif CFG.scheduler == "ReduceLROnPlateau":
            scheduler.step(avg_val_loss)
        elif CFG.scheduler in ["CosineAnnealingLR", "CosineAnnealingWarmRestarts"]:
            scheduler.step()

        torch.save({'model': model.state_dict()}, CFG.output_dir+f'/{CFG.model_name}_{CFG.exp_name}_fold{fold}_epoch{epoch}.pth')
        
        # early stopping 
        if val_dice > valid_acc_max:
            valid_acc_max = dice
            valid_acc_max_cnt=0
            best_acc_epoch = epoch
        else:
            valid_acc_max_cnt+=1
        
        if valid_acc_max_cnt >= CFG.n_early_stopping:
            torch.save({'model': model.state_dict()}, CFG.output_dir+f'/{CFG.model_name}_{CFG.exp_name}_fold{fold}_epoch{epoch}.pth')
        
            print("early_stopping")
            break
        
        torch.save({'model': model.state_dict()}, CFG.output_dir+f'/{CFG.model_name}_{CFG.exp_name}_fold{fold}_epoch{epoch}.pth')

In [None]:
def valid():
    """Run train_loop with CFG.train_fold_list for training and CFG.valid_fold_list for evaluation.

    Uses the module-level ``df`` and ``criterion``.
    """
    train_loop(df, CFG.train_fold_list, CFG.valid_fold_list, criterion)
def test():
    """Run train_loop with CFG.train_fold_list for training and CFG.test_fold_list for evaluation.

    Uses the module-level ``df`` and ``criterion``.
    """
    train_loop(df, CFG.train_fold_list, CFG.test_fold_list, criterion)

In [None]:
# Start the run: test() calls train_loop with CFG.train_fold_list as the
# training folds and CFG.test_fold_list as the evaluation folds.
test()