# Prepare for HSID-2D

## Pip

In [1]:
#pip install opencv-python
#pip install albumentations
#pip install segmentation-models-pytorch
#pip install timm
#pip install wandb

import os
import nibabel as nib
import pandas as pd
from tqdm import tqdm
import logging
import imageio
import numpy as np
import torch
import torch.nn as nn  
import numpy as np
from tqdm import tqdm
import os,sys,cv2
from torch.cuda.amp import autocast
import matplotlib.pyplot as plt
import albumentations as A
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DataParallel
from glob import glob
from sklearn.model_selection import GroupKFold
from sklearn.model_selection import train_test_split
import random

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

  from .autonotebook import tqdm as notebook_tqdm


## Config

In [2]:
class CFG:
    #Predata
    predata = True
    top_level_dir = ['ExternalDataset/ExternalDataset/SpineSagT2Wdataset','ExternalDataset/ExternalDataset/dataset15','ExternalDataset/ExternalDataset/VerSe2019']
    output_dir = 'Pretrain_data'
    target_height = 512
    target_width = 512

    wandb = False
    seed = 42
    project = 'Spine'
    exp_name = 'exp01'
    n_fold = 5
    valid_fold = 4
    chopping_percentile = 1e-3
    in_chans = 1
    train_batch_size = 16
    valid_batch_size = 32
    
    train_aug_list = [
        A.Rotate(limit=270, p= 0.2),
        A.GaussianBlur(p=0.1),
        A.MotionBlur(p=0.1),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.1),
        ToTensorV2(transpose_mask=True),
    ]
    train_aug = A.Compose(train_aug_list)
    valid_aug_list = [
        ToTensorV2(transpose_mask=True),
    ]
    valid_aug = A.Compose(valid_aug_list)

In [3]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True 
    torch.backends.cudnn.enabled = True
    
seed_everything(CFG.seed)

In [4]:
if CFG.predata:
    os.makedirs(CFG.output_dir, exist_ok=True)
    # def normalize_upsample(image,is_label):
    #     # 缩放图像数据到0-255范围内以适应PNG格式
    #     if not is_label:
    #         image = image.astype(np.float32)
    #         min_val = np.min(image)
    #         max_val = np.max(image)
    #         image = (image - min_val) / (max_val - min_val + 1e-9) * 255
    #     image = cv2.resize(image, (CFG.target_width, CFG.target_height), interpolation=cv2.INTER_CUBIC)
    #     return image.astype(np.uint8)
    
    def normalize_and_pad(image, is_label):
        # 缩放图像数据到0-255范围内以适应PNG格式
        if not is_label:
            image = image.astype(np.float32)
            min_val = np.min(image)
            max_val = np.max(image)
            image = (image - min_val) / (max_val - min_val + 1e-9) * 255
            
        image_height, image_width = image.shape[:2]
        standard_l = max(image_height, image_width)
        # 计算两边的填充大小
        delta_w = standard_l - image_width
        delta_h = standard_l - image_height
        top, bottom = delta_h // 2, delta_h-(delta_h // 2)
        left, right = delta_w // 2, delta_w-(delta_w // 2)
        color = [0, 0, 0]
        image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
        image = cv2.resize(image, (CFG.target_width, CFG.target_height), interpolation=cv2.INTER_CUBIC)
        return image.astype(np.uint8)
        
    # 用于存储所有信息的DataFrame
    all_image_info = []
    # 遍历定义的stage
    for dir_index in range(len(CFG.top_level_dir)):
        if dir_index == 0:
            for stage in ['train']:
                #使用logger输出当前stage
                logger.info(f"当前阶段：{stage}")
                img_dir = os.path.join(CFG.top_level_dir[dir_index], stage, 'image')
                groundtruth_dir = os.path.join(CFG.top_level_dir[dir_index], stage, 'groundtruth')
                img_output_dir = os.path.join(CFG.output_dir, 'image')
                groundtruth_output_dir = os.path.join(CFG.output_dir, 'groundtruth')
                os.makedirs(img_output_dir, exist_ok=True)
                os.makedirs(groundtruth_output_dir, exist_ok=True)
                # 遍历文件夹内的文件
                for file_name in tqdm(os.listdir(img_dir)):
                    if file_name.endswith('.nii.gz'):
                        # 读取image和groundtruth文件
                        img_path = os.path.join(img_dir, file_name)
                        groundtruth_path = os.path.join(groundtruth_dir, f'mask_{file_name.lower()}')
                        try:
                            img = nib.load(img_path).get_fdata()
                            label = nib.load(groundtruth_path).get_fdata()
                            height, width, depth = img.shape
                            print(img.shape)
                            # 检查shape是否相同
                            if img.shape != label.shape:
                                raise ValueError("Image and label shape do not match.")
                        except FileNotFoundError as e:
                            logger.warning(f"文件未找到: {e}")
                            continue
                        except ValueError as e:
                            logger.warning(f"Shape不匹配警告: {e} - 跳过文件 {file_name}")
                            continue
                        except Exception as e:
                            logger.warning(f"处理文件 {file_name} 时发生未知错误: {e}")
                            continue
                        base_name = file_name.split('.')[0]
                        # 遍历所有切片并保存为PNG文件
                        for i in range(img.shape[2]):
                            image_2d = img[:, :, i]
                            label_2d = label[:, :, i]
                            img_filename = f'{dir_index}_{base_name}_{i}.png'
                            gt_filename = f'{dir_index}_{base_name}_{i}_label.png'
                            img_png_path = os.path.join(img_output_dir, img_filename)
                            gt_png_path = os.path.join(groundtruth_output_dir, gt_filename)
                            # 保存图片和标签的PNG文件
                            imageio.imwrite(img_png_path,normalize_and_pad(image_2d,is_label=False))
                            imageio.imwrite(gt_png_path,normalize_and_pad(label_2d,is_label=True))
                            # 将信息添加到列表中
                            all_image_info.append({
                                'Image': img_filename,
                                'Dataset_index':dir_index,
                                'Case':base_name,
                                'ImagePath': img_png_path,
                                'GroundTruthPath': gt_png_path,
                                'height':height,
                                'width':width,
                            })
        elif dir_index == 1:
            total_dir = os.path.join(CFG.top_level_dir[dir_index])
            for file_name in tqdm(os.listdir(total_dir)):
                if file_name.endswith('.nii.gz') and 'label' not in file_name:
                    base_name = file_name.split('.')[0]
                    img_path = os.path.join(total_dir, f'{file_name.lower()}')
                    groundtruth_path = os.path.join(total_dir, f'{base_name}_label.nii.gz')
                    try:
                        img = nib.load(img_path).get_fdata()
                        label = nib.load(groundtruth_path).get_fdata()
                        height, width, depth = img.shape
                        print('here',img.shape)
                        # 检查shape是否相同
                        if img.shape != label.shape:
                            raise ValueError("Image and label shape do not match.")
                    except FileNotFoundError as e:
                        logger.warning(f"文件未找到: {e}")
                        continue
                    except ValueError as e:
                        logger.warning(f"Shape不匹配警告: {e} - 跳过文件 {file_name}")
                        continue
                    except Exception as e:
                        logger.warning(f"处理文件 {file_name} 时发生未知错误: {e}")
                        continue
                    # 遍历所有切片并保存为PNG文件
                    start_index = img.shape[0] // 2 - 5
                    end_index = start_index + 15
                    for i in range(start_index, end_index):
                        image_2d = img[i, :, :]
                        label_2d = label[i, :, :]
                        img_filename = f'{dir_index}_{base_name}_{i}.png'
                        gt_filename = f'{dir_index}_{base_name}_{i}_label.png'
                        img_png_path = os.path.join(img_output_dir, img_filename)
                        gt_png_path = os.path.join(groundtruth_output_dir, gt_filename)
                        # 保存图片和标签的PNG文件
                        imageio.imwrite(img_png_path,normalize_and_pad(image_2d,is_label=False))
                        imageio.imwrite(gt_png_path,normalize_and_pad(label_2d,is_label=True))
                        # 将信息添加到列表中
                        all_image_info.append({
                            'Image': img_filename,
                            'Dataset_index':dir_index,
                            'Case':base_name,
                            'ImagePath': img_png_path,
                            'GroundTruthPath': gt_png_path,
                            'height':height,
                            'width':width,
                            })
        else:
            for stage in ['training_phase_1_release','training_phase_2_release','training_phase_3_release']: 
                #使用logger输出当前stage
                logger.info(f"当前阶段：{stage}")
                total_dir = os.path.join(CFG.top_level_dir[dir_index], stage)
                # 遍历文件夹内的文件
                for file_name in tqdm(os.listdir(total_dir)):
                    if file_name.endswith('.nii.gz') and 'seg' not in file_name:
                        base_name = file_name.split('.')[0]
                        # 读取image和groundtruth文件
                        img_path = os.path.join(total_dir, file_name.lower())
                        groundtruth_path = os.path.join(total_dir, f'{base_name}_seg.nii.gz')
                        try:
                            img = nib.load(img_path).get_fdata()
                            label = nib.load(groundtruth_path).get_fdata()
                            height, width, depth = img.shape
                            print(img.shape)
                            # 检查shape是否相同
                            if img.shape != label.shape:
                                raise ValueError("Image and label shape do not match.")
                        except FileNotFoundError as e:
                            logger.warning(f"文件未找到: {e}")
                            continue
                        except ValueError as e:
                            logger.warning(f"Shape不匹配警告: {e} - 跳过文件 {file_name}")
                            continue
                        except Exception as e:
                            logger.warning(f"处理文件 {file_name} 时发生未知错误: {e}")
                            continue
                        
                        # 遍历所有切片并保存为PNG文件
                        start_index = img.shape[0] // 2 - 5
                        end_index = start_index + 15
                        # 迭代中间20张图像
                        for i in range(start_index, end_index):
                            image_2d = img[i, :, :]
                            label_2d = label[i, :, :]
                            img_filename = f'{dir_index}_{base_name}_{i}.png'
                            gt_filename = f'{dir_index}_{base_name}_{i}_label.png'
                            img_png_path = os.path.join(img_output_dir, img_filename)
                            gt_png_path = os.path.join(groundtruth_output_dir, gt_filename)
                            # 保存图片和标签的PNG文件
                            imageio.imwrite(img_png_path,normalize_and_pad(image_2d,is_label=False))
                            imageio.imwrite(gt_png_path,normalize_and_pad(label_2d,is_label=True))
                            # 将信息添加到列表中
                            all_image_info.append({
                            'Image': img_filename,
                            'Dataset_index':dir_index,
                            'Case':base_name,
                            'ImagePath': img_png_path,
                            'GroundTruthPath': gt_png_path,
                            'height':height,
                            'width':width,
                            })
    # 使用列表创建DataFrame
    df = pd.DataFrame(all_image_info)
    # 保存DataFrame为CSV文件
    df.to_csv('image_groundtruth_data.csv', index=False)
    logger.info(f"处理完成，DataFrame已保存为 image_groundtruth_data.csv")
else:
    df = pd.read_csv('image_groundtruth_data.csv')

INFO:root:当前阶段：train
  0%|          | 0/200 [00:00<?, ?it/s]

(880, 880, 12)


  0%|          | 1/200 [00:00<01:59,  1.67it/s]

(960, 960, 15)


  1%|          | 2/200 [00:01<02:21,  1.40it/s]

(880, 880, 12)


  2%|▏         | 3/200 [00:01<02:04,  1.59it/s]

(1008, 1008, 12)


  2%|▏         | 4/200 [00:02<02:00,  1.62it/s]

(880, 880, 12)


  2%|▎         | 5/200 [00:03<01:51,  1.74it/s]

(880, 880, 12)


  3%|▎         | 6/200 [00:03<01:49,  1.78it/s]

(880, 880, 12)


  4%|▎         | 7/200 [00:04<01:48,  1.77it/s]

(880, 880, 15)


  4%|▍         | 8/200 [00:04<01:52,  1.71it/s]

(880, 880, 12)


  4%|▍         | 9/200 [00:05<01:47,  1.77it/s]

(896, 896, 15)


  5%|▌         | 10/200 [00:05<01:55,  1.65it/s]

(880, 880, 15)


  6%|▌         | 11/200 [00:06<01:58,  1.60it/s]

(880, 880, 12)


  6%|▌         | 12/200 [00:07<01:52,  1.68it/s]

(880, 880, 12)


  6%|▋         | 13/200 [00:07<01:47,  1.73it/s]

(880, 880, 12)


  7%|▋         | 14/200 [00:08<01:45,  1.76it/s]

(880, 880, 12)


  8%|▊         | 15/200 [00:08<01:43,  1.79it/s]

(880, 880, 15)


  8%|▊         | 16/200 [00:09<01:46,  1.72it/s]

(880, 880, 12)


  8%|▊         | 17/200 [00:09<01:42,  1.79it/s]

(880, 880, 12)


  9%|▉         | 18/200 [00:10<01:40,  1.81it/s]

(880, 880, 15)


 10%|▉         | 19/200 [00:11<01:44,  1.72it/s]

(880, 880, 12)


 10%|█         | 20/200 [00:11<01:42,  1.76it/s]

(880, 880, 18)


 10%|█         | 21/200 [00:12<01:51,  1.61it/s]

(880, 880, 12)


 11%|█         | 22/200 [00:12<01:47,  1.66it/s]

(880, 880, 12)


 12%|█▏        | 23/200 [00:13<01:44,  1.70it/s]

(880, 880, 12)


 12%|█▏        | 24/200 [00:14<01:39,  1.77it/s]

(880, 880, 15)


 12%|█▎        | 25/200 [00:14<01:42,  1.71it/s]

(880, 880, 12)


 13%|█▎        | 26/200 [00:15<01:37,  1.78it/s]

(880, 880, 12)


 14%|█▎        | 27/200 [00:15<01:35,  1.81it/s]

(880, 880, 12)


 14%|█▍        | 28/200 [00:16<01:33,  1.84it/s]

(880, 880, 15)


 14%|█▍        | 29/200 [00:16<01:37,  1.76it/s]

(880, 880, 12)


 15%|█▌        | 30/200 [00:17<01:33,  1.81it/s]

(880, 880, 12)


 16%|█▌        | 31/200 [00:17<01:32,  1.83it/s]

(880, 880, 12)


 16%|█▌        | 32/200 [00:18<01:30,  1.86it/s]

(1008, 1008, 12)


 16%|█▋        | 33/200 [00:18<01:32,  1.81it/s]

(880, 880, 12)


 17%|█▋        | 34/200 [00:19<01:30,  1.84it/s]

(880, 880, 12)


 18%|█▊        | 35/200 [00:20<01:28,  1.86it/s]

(880, 880, 12)


 18%|█▊        | 36/200 [00:20<01:29,  1.84it/s]

(880, 880, 12)


 18%|█▊        | 37/200 [00:21<01:27,  1.86it/s]

(880, 880, 12)


 19%|█▉        | 38/200 [00:21<01:26,  1.87it/s]

(880, 880, 12)


 20%|█▉        | 39/200 [00:22<01:26,  1.85it/s]

(880, 880, 12)


 20%|██        | 40/200 [00:22<01:26,  1.85it/s]

(880, 880, 12)


 20%|██        | 41/200 [00:23<01:24,  1.88it/s]

(960, 960, 15)


 21%|██        | 42/200 [00:23<01:33,  1.69it/s]

(880, 880, 12)


 22%|██▏       | 43/200 [00:24<01:29,  1.76it/s]

(880, 880, 12)


 22%|██▏       | 44/200 [00:25<01:26,  1.79it/s]

(880, 880, 12)


 22%|██▎       | 45/200 [00:25<01:26,  1.78it/s]

(880, 880, 12)


 23%|██▎       | 46/200 [00:26<01:25,  1.81it/s]

(880, 880, 12)


 24%|██▎       | 47/200 [00:26<01:24,  1.82it/s]

(880, 880, 12)


 24%|██▍       | 48/200 [00:27<01:24,  1.80it/s]

(880, 880, 12)


 24%|██▍       | 49/200 [00:27<01:21,  1.85it/s]

(960, 960, 12)


 25%|██▌       | 50/200 [00:28<01:26,  1.74it/s]

(880, 880, 12)


 26%|██▌       | 51/200 [00:28<01:24,  1.77it/s]

(880, 880, 15)


 26%|██▌       | 52/200 [00:29<01:27,  1.70it/s]

(880, 880, 12)


 26%|██▋       | 53/200 [00:30<01:26,  1.71it/s]

(880, 880, 15)


 27%|██▋       | 54/200 [00:30<01:29,  1.63it/s]

(880, 880, 15)


 28%|██▊       | 55/200 [00:31<01:31,  1.58it/s]

(880, 880, 12)


 28%|██▊       | 56/200 [00:32<01:26,  1.66it/s]

(880, 880, 15)


 28%|██▊       | 57/200 [00:32<01:28,  1.61it/s]

(880, 880, 12)


 29%|██▉       | 58/200 [00:33<01:25,  1.65it/s]

(880, 880, 12)


 30%|██▉       | 59/200 [00:33<01:22,  1.71it/s]

(880, 880, 12)


 30%|███       | 60/200 [00:34<01:21,  1.73it/s]

(880, 880, 12)


 30%|███       | 61/200 [00:34<01:17,  1.78it/s]

(880, 880, 12)


 31%|███       | 62/200 [00:35<01:15,  1.82it/s]

(880, 880, 12)


 32%|███▏      | 63/200 [00:35<01:13,  1.86it/s]

(880, 880, 12)


 32%|███▏      | 64/200 [00:36<01:12,  1.87it/s]

(880, 880, 12)


 32%|███▎      | 65/200 [00:36<01:11,  1.89it/s]

(880, 880, 15)


 33%|███▎      | 66/200 [00:37<01:15,  1.77it/s]

(880, 880, 12)


 34%|███▎      | 67/200 [00:38<01:13,  1.80it/s]

(880, 880, 12)


 34%|███▍      | 68/200 [00:38<01:13,  1.80it/s]

(880, 880, 12)


 34%|███▍      | 69/200 [00:39<01:11,  1.83it/s]

(880, 880, 12)


 35%|███▌      | 70/200 [00:39<01:10,  1.84it/s]

(880, 880, 12)


 36%|███▌      | 71/200 [00:40<01:10,  1.83it/s]

(880, 880, 12)


 36%|███▌      | 72/200 [00:40<01:09,  1.83it/s]

(880, 880, 15)


 36%|███▋      | 73/200 [00:41<01:13,  1.72it/s]

(880, 880, 12)


 37%|███▋      | 74/200 [00:42<01:13,  1.73it/s]

(880, 880, 12)


 38%|███▊      | 75/200 [00:42<01:11,  1.75it/s]

(880, 880, 12)


 38%|███▊      | 76/200 [00:43<01:10,  1.76it/s]

(880, 880, 15)


 38%|███▊      | 77/200 [00:43<01:13,  1.68it/s]

(880, 880, 12)


 39%|███▉      | 78/200 [00:44<01:10,  1.73it/s]

(880, 880, 12)


 40%|███▉      | 79/200 [00:44<01:07,  1.78it/s]

(880, 880, 12)


 40%|████      | 80/200 [00:45<01:06,  1.82it/s]

(880, 880, 12)


 40%|████      | 81/200 [00:46<01:05,  1.81it/s]

(880, 880, 15)


 41%|████      | 82/200 [00:46<01:08,  1.73it/s]

(880, 880, 12)


 42%|████▏     | 83/200 [00:47<01:05,  1.79it/s]

(880, 880, 12)


 42%|████▏     | 84/200 [00:47<01:03,  1.81it/s]

(880, 880, 12)


 42%|████▎     | 85/200 [00:48<01:02,  1.84it/s]

(880, 880, 15)


 43%|████▎     | 86/200 [00:48<01:05,  1.75it/s]

(880, 880, 12)


 44%|████▎     | 87/200 [00:49<01:02,  1.81it/s]

(880, 880, 12)


 44%|████▍     | 88/200 [00:49<01:01,  1.83it/s]

(880, 880, 12)


 44%|████▍     | 89/200 [00:50<00:59,  1.86it/s]

(880, 880, 12)


 45%|████▌     | 90/200 [00:50<00:58,  1.88it/s]

(880, 880, 12)


 46%|████▌     | 91/200 [00:51<00:57,  1.90it/s]

(880, 880, 12)


 46%|████▌     | 92/200 [00:51<00:56,  1.91it/s]

(880, 880, 12)


 46%|████▋     | 93/200 [00:52<00:57,  1.87it/s]

(880, 880, 12)


 47%|████▋     | 94/200 [00:53<00:56,  1.89it/s]

(864, 864, 12)


 48%|████▊     | 95/200 [00:53<00:56,  1.85it/s]

(880, 880, 12)


 48%|████▊     | 96/200 [00:54<00:55,  1.88it/s]

(880, 880, 12)


 48%|████▊     | 97/200 [00:54<00:54,  1.90it/s]

(880, 880, 15)


 49%|████▉     | 98/200 [00:55<00:57,  1.78it/s]

(880, 880, 15)


 50%|████▉     | 99/200 [00:55<00:59,  1.69it/s]

(880, 880, 12)


 50%|█████     | 100/200 [00:56<00:56,  1.78it/s]

(880, 880, 15)


 50%|█████     | 101/200 [00:57<00:58,  1.69it/s]

(880, 880, 15)


 51%|█████     | 102/200 [00:57<00:59,  1.64it/s]

(880, 880, 12)


 52%|█████▏    | 103/200 [00:58<00:56,  1.72it/s]

(880, 880, 15)


 52%|█████▏    | 104/200 [00:58<00:57,  1.66it/s]

(880, 880, 12)


 52%|█████▎    | 105/200 [00:59<00:54,  1.75it/s]

(880, 880, 12)


 53%|█████▎    | 106/200 [00:59<00:52,  1.79it/s]

(880, 880, 12)


 54%|█████▎    | 107/200 [01:00<00:51,  1.82it/s]

(880, 880, 15)


 54%|█████▍    | 108/200 [01:01<00:52,  1.74it/s]

(880, 880, 12)


 55%|█████▍    | 109/200 [01:01<00:51,  1.78it/s]

(880, 880, 12)


 55%|█████▌    | 110/200 [01:02<00:48,  1.86it/s]

(880, 880, 12)


 56%|█████▌    | 111/200 [01:02<00:47,  1.87it/s]

(880, 880, 15)


 56%|█████▌    | 112/200 [01:03<00:49,  1.77it/s]

(880, 880, 15)


 56%|█████▋    | 113/200 [01:03<00:50,  1.71it/s]

(880, 880, 12)


 57%|█████▋    | 114/200 [01:04<00:49,  1.75it/s]

(880, 880, 12)


 57%|█████▊    | 115/200 [01:05<00:48,  1.76it/s]

(880, 880, 12)


 58%|█████▊    | 116/200 [01:05<00:46,  1.80it/s]

(880, 880, 12)


 58%|█████▊    | 117/200 [01:06<00:45,  1.81it/s]

(880, 880, 12)


 59%|█████▉    | 118/200 [01:06<00:44,  1.86it/s]

(880, 880, 12)


 60%|█████▉    | 119/200 [01:07<00:43,  1.88it/s]

(880, 880, 12)


 60%|██████    | 120/200 [01:07<00:42,  1.87it/s]

(880, 880, 12)


 60%|██████    | 121/200 [01:08<00:42,  1.88it/s]

(880, 880, 12)


 61%|██████    | 122/200 [01:08<00:41,  1.90it/s]

(880, 880, 12)


 62%|██████▏   | 123/200 [01:09<00:40,  1.89it/s]

(880, 880, 12)


 62%|██████▏   | 124/200 [01:09<00:39,  1.90it/s]

(880, 880, 12)


 62%|██████▎   | 125/200 [01:10<00:39,  1.89it/s]

(880, 880, 12)


 63%|██████▎   | 126/200 [01:10<00:39,  1.89it/s]

(880, 880, 12)


 64%|██████▎   | 127/200 [01:11<00:39,  1.86it/s]

(512, 512, 12)


 64%|██████▍   | 128/200 [01:11<00:36,  1.99it/s]

(880, 880, 12)


 64%|██████▍   | 129/200 [01:12<00:36,  1.94it/s]

(880, 880, 12)


 65%|██████▌   | 130/200 [01:12<00:35,  1.95it/s]

(880, 880, 12)


 66%|██████▌   | 131/200 [01:13<00:36,  1.91it/s]

(1008, 1008, 12)


 66%|██████▌   | 132/200 [01:14<00:37,  1.82it/s]

(880, 880, 12)


 66%|██████▋   | 133/200 [01:14<00:36,  1.82it/s]

(880, 880, 12)


 67%|██████▋   | 134/200 [01:15<00:35,  1.85it/s]

(512, 512, 12)


 68%|██████▊   | 135/200 [01:15<00:32,  2.01it/s]

(880, 880, 12)


 68%|██████▊   | 136/200 [01:16<00:33,  1.93it/s]

(880, 880, 12)


 68%|██████▊   | 137/200 [01:16<00:32,  1.91it/s]

(880, 880, 12)


 69%|██████▉   | 138/200 [01:17<00:32,  1.93it/s]

(880, 880, 12)


 70%|██████▉   | 139/200 [01:17<00:31,  1.91it/s]

(1008, 1008, 12)


 70%|███████   | 140/200 [01:18<00:32,  1.84it/s]

(512, 512, 12)


 70%|███████   | 141/200 [01:18<00:29,  1.98it/s]

(880, 880, 15)


 71%|███████   | 142/200 [01:19<00:32,  1.81it/s]

(880, 880, 15)


 72%|███████▏  | 143/200 [01:19<00:32,  1.73it/s]

(880, 880, 12)


 72%|███████▏  | 144/200 [01:20<00:31,  1.79it/s]

(880, 880, 12)


 72%|███████▎  | 145/200 [01:20<00:30,  1.82it/s]

(880, 880, 12)


 73%|███████▎  | 146/200 [01:21<00:29,  1.83it/s]

(880, 880, 12)


 74%|███████▎  | 147/200 [01:22<00:28,  1.85it/s]

(880, 880, 12)


 74%|███████▍  | 148/200 [01:22<00:27,  1.87it/s]

(880, 880, 12)


 74%|███████▍  | 149/200 [01:23<00:27,  1.86it/s]

(880, 880, 15)


 75%|███████▌  | 150/200 [01:23<00:28,  1.76it/s]

(880, 880, 12)


 76%|███████▌  | 151/200 [01:24<00:27,  1.80it/s]

(880, 880, 15)


 76%|███████▌  | 152/200 [01:24<00:27,  1.72it/s]

(880, 880, 12)


 76%|███████▋  | 153/200 [01:25<00:26,  1.76it/s]

(880, 880, 12)


 77%|███████▋  | 154/200 [01:26<00:25,  1.78it/s]

(880, 880, 12)


 78%|███████▊  | 155/200 [01:26<00:25,  1.79it/s]

(880, 880, 12)


 78%|███████▊  | 156/200 [01:27<00:24,  1.76it/s]

(880, 880, 12)


 78%|███████▊  | 157/200 [01:27<00:24,  1.77it/s]

(880, 880, 12)


 79%|███████▉  | 158/200 [01:28<00:24,  1.74it/s]

(880, 880, 15)


 80%|███████▉  | 159/200 [01:28<00:24,  1.68it/s]

(880, 880, 12)


 80%|████████  | 160/200 [01:29<00:23,  1.73it/s]

(880, 880, 12)


 80%|████████  | 161/200 [01:30<00:21,  1.78it/s]

(880, 880, 12)


 81%|████████  | 162/200 [01:30<00:21,  1.78it/s]

(880, 880, 12)


 82%|████████▏ | 163/200 [01:31<00:20,  1.77it/s]

(880, 880, 12)


 82%|████████▏ | 164/200 [01:31<00:20,  1.79it/s]

(880, 880, 12)


 82%|████████▎ | 165/200 [01:32<00:19,  1.81it/s]

(880, 880, 12)


 83%|████████▎ | 166/200 [01:32<00:18,  1.81it/s]

(880, 880, 15)


 84%|████████▎ | 167/200 [01:33<00:19,  1.70it/s]

(1008, 1008, 12)


 84%|████████▍ | 168/200 [01:34<00:19,  1.66it/s]

(880, 880, 15)


 84%|████████▍ | 169/200 [01:34<00:19,  1.61it/s]

(880, 880, 15)


 85%|████████▌ | 170/200 [01:35<00:19,  1.57it/s]

(880, 880, 12)


 86%|████████▌ | 171/200 [01:35<00:17,  1.66it/s]

(880, 880, 12)


 86%|████████▌ | 172/200 [01:36<00:16,  1.72it/s]

(880, 880, 12)


 86%|████████▋ | 173/200 [01:36<00:15,  1.79it/s]

(880, 880, 12)


 87%|████████▋ | 174/200 [01:37<00:14,  1.80it/s]

(880, 880, 12)


 88%|████████▊ | 175/200 [01:38<00:13,  1.81it/s]

(880, 880, 12)


 88%|████████▊ | 176/200 [01:38<00:13,  1.81it/s]

(880, 880, 12)


 88%|████████▊ | 177/200 [01:39<00:12,  1.83it/s]

(880, 880, 15)


 89%|████████▉ | 178/200 [01:39<00:12,  1.73it/s]

(880, 880, 12)


 90%|████████▉ | 179/200 [01:40<00:11,  1.81it/s]

(880, 880, 12)


 90%|█████████ | 180/200 [01:40<00:10,  1.83it/s]

(880, 880, 12)


 90%|█████████ | 181/200 [01:41<00:10,  1.84it/s]

(880, 880, 15)


 91%|█████████ | 182/200 [01:41<00:10,  1.76it/s]

(880, 880, 12)


 92%|█████████▏| 183/200 [01:42<00:09,  1.79it/s]

(880, 880, 12)


 92%|█████████▏| 184/200 [01:43<00:08,  1.82it/s]

(880, 880, 12)


 92%|█████████▎| 185/200 [01:43<00:08,  1.84it/s]

(880, 880, 12)


 93%|█████████▎| 186/200 [01:44<00:07,  1.82it/s]

(880, 880, 15)


 94%|█████████▎| 187/200 [01:44<00:07,  1.74it/s]

(880, 880, 12)


 94%|█████████▍| 188/200 [01:45<00:06,  1.76it/s]

(880, 880, 12)


 94%|█████████▍| 189/200 [01:45<00:06,  1.80it/s]

(880, 880, 15)


 95%|█████████▌| 190/200 [01:46<00:05,  1.72it/s]

(880, 880, 12)


 96%|█████████▌| 191/200 [01:47<00:05,  1.78it/s]

(1024, 1024, 12)


 96%|█████████▌| 192/200 [01:47<00:04,  1.66it/s]

(880, 880, 12)


 96%|█████████▋| 193/200 [01:48<00:04,  1.73it/s]

(880, 880, 12)


 97%|█████████▋| 194/200 [01:48<00:03,  1.72it/s]

(880, 880, 12)


 98%|█████████▊| 195/200 [01:49<00:02,  1.78it/s]

(880, 880, 12)


 98%|█████████▊| 196/200 [01:49<00:02,  1.78it/s]

(880, 880, 12)


 98%|█████████▊| 197/200 [01:50<00:01,  1.82it/s]

(880, 880, 12)


 99%|█████████▉| 198/200 [01:50<00:01,  1.85it/s]

(880, 880, 12)


100%|█████████▉| 199/200 [01:51<00:00,  1.86it/s]

(880, 880, 15)


100%|██████████| 200/200 [01:52<00:00,  1.78it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

here (128, 128, 368)


  5%|▌         | 1/20 [00:00<00:09,  1.92it/s]

here (128, 128, 365)


 10%|█         | 2/20 [00:01<00:09,  1.88it/s]

here (128, 128, 350)


 25%|██▌       | 5/20 [00:01<00:04,  3.58it/s]

here (128, 128, 387)


 35%|███▌      | 7/20 [00:02<00:03,  3.66it/s]

here (145, 145, 422)


 45%|████▌     | 9/20 [00:02<00:03,  3.51it/s]

here (128, 128, 384)


 55%|█████▌    | 11/20 [00:03<00:02,  3.62it/s]

here (141, 141, 404)


 65%|██████▌   | 13/20 [00:03<00:01,  3.58it/s]

here (128, 128, 348)


 75%|███████▌  | 15/20 [00:04<00:01,  3.61it/s]

here (128, 128, 342)


 85%|████████▌ | 17/20 [00:04<00:00,  3.66it/s]

here (148, 148, 389)


100%|██████████| 20/20 [00:05<00:00,  3.63it/s]
INFO:root:当前阶段：training_phase_1_release
  0%|          | 0/121 [00:00<?, ?it/s]

(154, 295, 154)


  1%|          | 1/121 [00:00<01:00,  1.98it/s]

(459, 459, 66)


  4%|▍         | 5/121 [00:01<00:24,  4.67it/s]

(275, 605, 81)


  7%|▋         | 9/121 [00:01<00:22,  4.99it/s]

(231, 462, 97)


 11%|█         | 13/121 [00:02<00:20,  5.20it/s]

(234, 234, 88)


 14%|█▍        | 17/121 [00:03<00:17,  5.89it/s]

(407, 407, 75)


 17%|█▋        | 21/121 [00:03<00:16,  5.96it/s]

(198, 198, 50)


 21%|██        | 25/121 [00:04<00:13,  7.10it/s]

(392, 392, 453)


 24%|██▍       | 29/121 [00:06<00:29,  3.15it/s]

(196, 196, 80)


 27%|██▋       | 33/121 [00:07<00:22,  3.91it/s]

(445, 579, 132)


 31%|███       | 37/121 [00:09<00:25,  3.33it/s]

(162, 162, 253)


 34%|███▍      | 41/121 [00:09<00:20,  3.91it/s]

(512, 512, 688)


 37%|███▋      | 45/121 [00:17<00:58,  1.29it/s]

(381, 381, 252)


 40%|████      | 49/121 [00:19<00:49,  1.44it/s]

(444, 444, 709)


 44%|████▍     | 53/121 [00:26<01:06,  1.02it/s]

(250, 250, 372)


 47%|████▋     | 57/121 [00:27<00:51,  1.23it/s]

(365, 365, 619)


 50%|█████     | 61/121 [00:32<00:53,  1.11it/s]

(512, 512, 315)


 54%|█████▎    | 65/121 [00:35<00:51,  1.10it/s]

(312, 312, 497)


 57%|█████▋    | 69/121 [00:38<00:43,  1.20it/s]

(182, 182, 219)


 60%|██████    | 73/121 [00:39<00:31,  1.54it/s]

(187, 187, 269)


 64%|██████▎   | 77/121 [00:40<00:22,  1.96it/s]

(350, 350, 440)


 67%|██████▋   | 81/121 [00:43<00:23,  1.69it/s]

(183, 183, 211)


 70%|███████   | 85/121 [00:44<00:17,  2.11it/s]

(168, 168, 509)


 74%|███████▎  | 89/121 [00:44<00:12,  2.53it/s]

(350, 350, 637)


 77%|███████▋  | 93/121 [00:48<00:16,  1.72it/s]

(427, 427, 472)


 80%|████████  | 97/121 [00:53<00:17,  1.35it/s]

(380, 380, 640)


 83%|████████▎ | 101/121 [00:57<00:17,  1.17it/s]

(153, 300, 61)


 87%|████████▋ | 105/121 [00:58<00:09,  1.61it/s]

(191, 489, 67)


 90%|█████████ | 109/121 [00:58<00:05,  2.12it/s]

(174, 174, 247)


100%|██████████| 121/121 [00:59<00:00,  2.03it/s]
INFO:root:当前阶段：training_phase_2_release


(152, 179, 152)


  0%|          | 0/120 [00:00<?, ?it/s]

(161, 338, 61)


  1%|          | 1/120 [00:00<00:41,  2.89it/s]

(915, 1189, 121)


  8%|▊         | 9/120 [00:05<00:54,  2.04it/s]

(250, 325, 38)
(269, 269, 73)


 11%|█         | 13/120 [00:05<00:34,  3.10it/s]

(186, 186, 40)


 14%|█▍        | 17/120 [00:05<00:23,  4.44it/s]

(291, 291, 58)


 18%|█▊        | 21/120 [00:06<00:17,  5.67it/s]

(216, 216, 74)


 21%|██        | 25/120 [00:06<00:14,  6.70it/s]

(175, 175, 57)


 24%|██▍       | 29/120 [00:06<00:11,  7.82it/s]

(183, 363, 115)


 28%|██▊       | 33/120 [00:07<00:11,  7.69it/s]

(204, 272, 62)


 34%|███▍      | 41/120 [00:09<00:15,  5.05it/s]

(492, 640, 110)
(256, 256, 71)


 38%|███▊      | 45/120 [00:09<00:12,  6.01it/s]

(369, 369, 669)


 41%|████      | 49/120 [00:14<00:33,  2.15it/s]

(382, 382, 541)


 44%|████▍     | 53/120 [00:18<00:41,  1.63it/s]

(300, 300, 346)


 48%|████▊     | 57/120 [00:20<00:36,  1.75it/s]

(357, 357, 619)


 51%|█████     | 61/120 [00:23<00:41,  1.43it/s]

(391, 391, 670)


 54%|█████▍    | 65/120 [00:28<00:47,  1.16it/s]

(157, 266, 40)


 57%|█████▊    | 69/120 [00:29<00:31,  1.60it/s]

(246, 279, 82)


 61%|██████    | 73/120 [00:29<00:22,  2.10it/s]

(283, 239, 75)


 64%|██████▍   | 77/120 [00:30<00:15,  2.69it/s]

(114, 198, 61)


 68%|██████▊   | 81/120 [00:30<00:11,  3.47it/s]

(182, 182, 258)


 71%|███████   | 85/120 [00:31<00:09,  3.66it/s]

(160, 213, 61)


 74%|███████▍  | 89/120 [00:31<00:06,  4.56it/s]

(331, 696, 100)


 78%|███████▊  | 93/120 [00:33<00:06,  4.01it/s]

(194, 636, 61)


 81%|████████  | 97/120 [00:33<00:04,  4.65it/s]

(143, 379, 61)


 84%|████████▍ | 101/120 [00:34<00:03,  5.63it/s]

(228, 376, 135)


 88%|████████▊ | 105/120 [00:34<00:02,  5.37it/s]

(195, 629, 61)


 91%|█████████ | 109/120 [00:35<00:01,  5.83it/s]

(209, 209, 603)


 94%|█████████▍| 113/120 [00:37<00:01,  4.16it/s]

(179, 692, 88)


100%|██████████| 120/120 [00:37<00:00,  3.17it/s]
INFO:root:当前阶段：training_phase_3_release
  0%|          | 0/80 [00:00<?, ?it/s]

(177, 177, 300)


  6%|▋         | 5/80 [00:01<00:15,  4.82it/s]

(512, 512, 38)
(210, 210, 292)


 11%|█▏        | 9/80 [00:02<00:15,  4.65it/s]

(152, 152, 195)


 16%|█▋        | 13/80 [00:02<00:13,  5.05it/s]

(199, 199, 227)


 21%|██▏       | 17/80 [00:03<00:14,  4.42it/s]

(171, 171, 149)


 26%|██▋       | 21/80 [00:04<00:13,  4.52it/s]

(158, 158, 68)


 31%|███▏      | 25/80 [00:05<00:10,  5.44it/s]

(221, 243, 51)


 36%|███▋      | 29/80 [00:05<00:07,  6.61it/s]

(210, 210, 183)


 41%|████▏     | 33/80 [00:06<00:08,  5.65it/s]

(287, 287, 270)


 46%|████▋     | 37/80 [00:07<00:10,  4.23it/s]

(257, 257, 214)


 51%|█████▏    | 41/80 [00:08<00:09,  4.04it/s]

(173, 173, 184)


 56%|█████▋    | 45/80 [00:09<00:08,  4.29it/s]

(371, 371, 594)


 61%|██████▏   | 49/80 [00:14<00:15,  2.06it/s]

(346, 705, 174)


 66%|██████▋   | 53/80 [00:16<00:13,  2.02it/s]

(191, 345, 52)


 71%|███████▏  | 57/80 [00:16<00:08,  2.69it/s]

(216, 568, 68)


 76%|███████▋  | 61/80 [00:17<00:05,  3.31it/s]

(250, 250, 571)


 81%|████████▏ | 65/80 [00:18<00:05,  2.80it/s]

(222, 222, 589)


 86%|████████▋ | 69/80 [00:20<00:03,  2.76it/s]

(208, 208, 543)


 91%|█████████▏| 73/80 [00:21<00:02,  2.79it/s]

(317, 317, 559)


100%|██████████| 80/80 [00:24<00:00,  3.24it/s]
INFO:root:处理完成，DataFrame已保存为 image_groundtruth_data.csv


In [5]:
df

Unnamed: 0,Image,Dataset_index,Case,ImagePath,GroundTruthPath,height,width
0,0_Case1_0.png,0,Case1,Pretrain_data/image/0_Case1_0.png,Pretrain_data/groundtruth/0_Case1_0_label.png,880,880
1,0_Case1_1.png,0,Case1,Pretrain_data/image/0_Case1_1.png,Pretrain_data/groundtruth/0_Case1_1_label.png,880,880
2,0_Case1_2.png,0,Case1,Pretrain_data/image/0_Case1_2.png,Pretrain_data/groundtruth/0_Case1_2_label.png,880,880
3,0_Case1_3.png,0,Case1,Pretrain_data/image/0_Case1_3.png,Pretrain_data/groundtruth/0_Case1_3_label.png,880,880
4,0_Case1_4.png,0,Case1,Pretrain_data/image/0_Case1_4.png,Pretrain_data/groundtruth/0_Case1_4_label.png,880,880
...,...,...,...,...,...,...,...
3853,2_verse270_163.png,2,verse270,Pretrain_data/image/2_verse270_163.png,Pretrain_data/groundtruth/2_verse270_163_label...,317,317
3854,2_verse270_164.png,2,verse270,Pretrain_data/image/2_verse270_164.png,Pretrain_data/groundtruth/2_verse270_164_label...,317,317
3855,2_verse270_165.png,2,verse270,Pretrain_data/image/2_verse270_165.png,Pretrain_data/groundtruth/2_verse270_165_label...,317,317
3856,2_verse270_166.png,2,verse270,Pretrain_data/image/2_verse270_166.png,Pretrain_data/groundtruth/2_verse270_166_label...,317,317


In [6]:
if CFG.wandb:
    try:
        import wandb
        wandb.login()
        run = wandb.init(project=CFG.project, 
                 name=CFG.exp_name,
                ) 
    except:
        logger.info(f"Check your WANDB account")

In [7]:
gkf = GroupKFold(n_splits=CFG.n_fold)
df["fold"] = -1
for fold_id, (_, val_idx) in enumerate(
    gkf.split(df, y=df["GroundTruthPath"], groups=df["Case"])
):
    df.loc[val_idx, "fold"] = fold_id
df.fold.value_counts()

fold
3    774
1    774
0    774
2    774
4    762
Name: count, dtype: int64

# Train&Valid&Test

## pip

In [8]:
import os
import nibabel as nib
import pandas as pd
from tqdm import tqdm
import logging
import imageio
import numpy as np
import torch
import torch.nn as nn  
import numpy as np
from tqdm import tqdm
import os,sys,cv2,gc
from torch.cuda.amp import autocast
import matplotlib.pyplot as plt
import albumentations as A
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DataParallel
from glob import glob
from sklearn.model_selection import GroupKFold
from sklearn.model_selection import train_test_split
import random
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from warmup_scheduler import GradualWarmupScheduler
from contextlib import contextmanager
import time
import math
from datetime import datetime
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler

## Config

In [9]:
class CFG:
    #Predata
    predata = False
    output_dir = 'Exp001'
    target_height = 512
    target_width = 512

    wandb = False
    seed = 42
    project = 'Spine'
    exp_name = 'exp01'
    n_fold = 5
    valid_fold = 4
    chopping_percentile = 1e-3
    in_chans = 1
    train_batch_size = 16
    valid_batch_size = 32
    
    train_aug_list = [
        A.Rotate(limit=270, p= 0.3),
        A.GaussianBlur(p=0.1),
        A.MotionBlur(p=0.1),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.1),
        ToTensorV2(transpose_mask=True),
    ]
    train_aug = A.Compose(train_aug_list)
    valid_aug_list = [
        ToTensorV2(transpose_mask=True),
    ]
    valid_aug = A.Compose(valid_aug_list)

    nprocs=1 
    fold_num=5 
    num_classes=1
    
    accum_iter=1 
    max_grad_norm=1000 
    print_freq=100 
    
    test_fold_list=[0] 
    valid_fold_list=[1]
    train_fold_list = [2,3,4]
    
    epochs=25

    optimizer="AdamW" 
    scheduler="CosineAnnealingLR"
    loss_fn="BCEWithLogitsLoss"
    scheduler_warmup= "GradualWarmupSchedulerV3"
    
    model_name = 'seresnext26d_32x4d.bt_in1k'
    model_path = 'Encoder_backbone/Encoder_backbone/Encoder/seresnext26d_32x4d_bt_in1k.bin' #0.8827
    
    warmup_epo=10
    warmup_factor = 10
    T_max= epochs-warmup_epo-2 if scheduler_warmup=="GradualWarmupSchedulerV2" else \
           epochs-warmup_epo-1 if scheduler_warmup=="GradualWarmupSchedulerV3" else epochs-1
    lr=1e-3 
    min_lr=1e-6 #
    weight_decay=1e-2
    n_early_stopping=25

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Tools

In [10]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True 

seed_everything(CFG.seed)

@contextmanager
def timer(name):
    t0 = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    LOGGER.info(f'[{name}] done in {time.time() - t0:.0f} s.')

# 日志记录函数
def init_logger(log_file):
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    if not os.path.exists(os.path.dirname(log_file)):
        os.makedirs(os.path.dirname(log_file), exist_ok=True)  # 创建日志文件所在目录
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = init_logger(CFG.output_dir+f'/train_{CFG.exp_name}.log')
loginfo = LOGGER.info
cusprint = print

def get_timediff(time1,time2):
    minute_,second_ = divmod(time2-time1,60)
    return f"{int(minute_):02d}:{int(second_):02d}"  

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

## Dataloader

In [11]:
class Data_loader2D(Dataset):
    def __init__(self,df):
        self.df = df.reset_index()
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self,index):
        row = self.df.iloc[index]
        img_path = row.ImagePath
        label_path = row.GroundTruthPath
        img = cv2.imread(img_path,cv2.IMREAD_GRAYSCALE)
        label = cv2.imread(label_path,cv2.IMREAD_GRAYSCALE)
        
        img=torch.from_numpy(img)
        label=torch.from_numpy(label)
        return img,label
    
def load_data(df):
    data_loader=Data_loader2D(df)
    data_loader=DataLoader(data_loader, batch_size=16, num_workers=0)
    img=[]
    label=[]
    for x,y in tqdm(data_loader):
        img.append(x)
        label.append(y)
    img_c=torch.cat(img,dim=0)
    label_c=torch.cat(label,dim=0)
    del img,label
    return img_c,label_c

In [12]:
train_img,train_label=load_data(df)
print(train_img.shape)
print(train_label.shape)

100%|██████████| 242/242 [00:12<00:00, 18.76it/s]


torch.Size([3858, 512, 512])
torch.Size([3858, 512, 512])


In [13]:
class Spine_Dataset(Dataset):
    def __init__(self,x:list,y:list,arg=False):
        super(Dataset,self).__init__()
        self.x = x
        self.y= y 
        self.image_size=CFG.target_height
        self.in_chans=CFG.in_chans
        self.arg=arg
        if arg:
            self.transform=CFG.train_aug
        else: 
            self.transform=CFG.valid_aug
            
    def __len__(self) -> int:
        return sum([y.shape[0]-self.in_chans for y in self.y])
        
    def __getitem__(self, index):
        i=0
        for x in self.x:
            if index>x.shape[0]-self.in_chans:
                index-=x.shape[0]-self.in_chans
                i+=1
            else:
                break
        x=self.x[i]
        y=self.y[i]
        
        x=x[index:index+self.in_chans,:,:]
        y=y[index+self.in_chans//2,:,:]
 
        if self.in_chans == 1:
             x = x.repeat(3, 1, 1)
        # 应用任何预处理或增强转换
        data = self.transform(image=x.numpy().transpose(1, 2, 0), mask=y.numpy())
        x = data['image']
        y = data['mask']
        
        # 创建二值标签mask：所有不等于0的标签变为1
        binary_mask = torch.where(y != 0, torch.tensor(1.0, dtype=torch.float32), torch.tensor(0.0, dtype=torch.float32))
        binary_mask = binary_mask.unsqueeze(0)
        multilabel_mask = torch.zeros((22, y.shape[0], y.shape[1]), dtype=torch.float32)
        
        # 填充每个标签的channel
        for label_num in range(1, 23):
            multilabel_mask[label_num - 1] = torch.where(y == label_num, torch.tensor(1.0, dtype=torch.float32), torch.tensor(0.0, dtype=torch.float32))

        return x, binary_mask, multilabel_mask  

In [14]:
train_dataset = Spine_Dataset([train_img],[train_label],arg=True)
train_dataset = DataLoader(train_dataset, batch_size=CFG.train_batch_size ,num_workers=0, shuffle=True, pin_memory=True)

In [15]:
# 获取DataLoader的迭代器
data_iterator = iter(train_dataset)
# 从迭代器中获取第一个批次
first_batch = next(data_iterator)
print('img shape:',first_batch[0].shape,'binary_label_shape:',first_batch[1].shape,'multi_label_shape:',first_batch[2].shape)

img shape: torch.Size([16, 3, 512, 512]) binary_label_shape: torch.Size([16, 1, 512, 512]) multi_label_shape: torch.Size([16, 22, 512, 512])


## CoordAtt

In [16]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6

class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)

class CoordAtt(nn.Module):   #0.8824 0.8845???
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        

    def forward(self, x):
        identity = x
        
        n,c,h,w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y) 
        
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out
if __name__ == "__main__":
    bs, c, h, w = 10, 16, 64, 64
    in_tensor = torch.ones(bs, c, h, w)

    cs_se = CoordAtt(c,c)
    print("in shape:",in_tensor.shape)
    out_tensor = cs_se(in_tensor)
    print("out shape:", out_tensor.shape)

in shape: torch.Size([10, 16, 64, 64])
out shape: torch.Size([10, 16, 64, 64])


## MSA

In [17]:
import torch

def pool_variance(Z_prev, f_h=2, f_w=2, padding=0, stride_h=2, stride_w=3):
    (b, n_C_prev, n_H_prev, n_W_prev) = Z_prev.shape

    n_H = 1 + int((n_H_prev + 2 * padding - f_h) / stride_h)
    n_W = 1 + int((n_W_prev + 2 * padding - f_w) / stride_w)

    Z_prev_unfold = torch.nn.functional.unfold(Z_prev, (f_h, f_w), stride=(stride_h, stride_w)) # (b, n_C_prev * f_h * f_w, L)
    Z_prev_unfold = Z_prev_unfold.transpose(1, 2) # (b, L, n_C_prev * f_h * f_w)
    Z_prev_unfold = Z_prev_unfold.view(b, n_H * n_W, n_C_prev, -1) # (b, L, n_C_prev, f_h * f_w)
    
    mean_squared = torch.mean(Z_prev_unfold ** 2, dim=3, keepdim=False)
    mean = torch.mean(Z_prev_unfold, dim=3, keepdim=False)
    
    variance = mean_squared - mean ** 2
    
    variance = variance.transpose(1, 2)
    
    Z_var = torch.nn.functional.fold(variance, (n_H, n_W), (1, 1))
    
    assert(Z_var.size() == (b, n_C_prev, n_H, n_W))
    return Z_var

if __name__ == "__main__":
    sample_input = torch.randn(1, 64, 32, 16)  
    result = pool_variance(sample_input, f_h=8, f_w=16, padding=0, stride_h=8, stride_w=16)
    print(result.shape)  

torch.Size([1, 64, 4, 1])


In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=4):
        super(ChannelAttention, self).__init__()
        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.BatchNorm2d(in_planes// ratio),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False),
            nn.BatchNorm2d(in_planes),
            nn.ReLU(),
        )
        self.sigmoid = nn.Sigmoid()
        self.height_reduce = nn.Sequential(
            nn.Conv2d(in_planes, in_planes, (8, 2), bias=False),
        )
        # self.height_reduce = nn.Sequential(
        #     nn.Conv2d(in_planes, in_planes, (4, 2), bias=False),
        #    
        # )
    def forward(self, x):
        mid_col_idx_start = x.shape[3]//4
        mid_col_idx_end = x.shape[3]//4 *3
  
        mid_col = x[:, :, :, mid_col_idx_start:mid_col_idx_end]

        pool_height = mid_col.shape[2] // 8
        pool_width =  mid_col_idx_end - mid_col_idx_start
        stride = (pool_height, pool_width)  # stride matches the pool size for non-overlapping pooling
 
        avg_pool = nn.AvgPool2d((pool_height, pool_width), stride=stride)
        var_pooled = pool_variance(mid_col,f_h=pool_height, f_w=pool_width, stride_h=pool_height, stride_w=pool_width)

        avg_pooled = avg_pool(mid_col)

        avgout = self.sharedMLP(avg_pooled)
        varout = self.sharedMLP(var_pooled)
        concat_pooled = torch.cat((avgout,varout), dim=3)

        mlp_output = self.height_reduce(concat_pooled)

        return self.sigmoid(mlp_output)

In [19]:
import torch
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=1,
                     bias=False)

class SpatialAttention(nn.Module):
    def __init__(self, planes, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        self.kernel_size = kernel_size
        padding = 3 if kernel_size == 7 else 1
        self.padding = padding
        self.conv = nn.Conv2d(3, 1, kernel_size, padding=padding, bias=False)
        self.p2pconv = nn.Conv2d(planes,1,kernel_size,padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        convout = self.p2pconv(x)
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        varout = torch.var(x, dim=1, keepdim=True)
        x = torch.cat([avgout, varout,convout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention(planes)

    def forward(self, x):
        out = self.ca(x) * x  # 广播机制
        out = self.sa(out) * out  # 广播机制
        out = self.relu(out)
        return out

if __name__ == "__main__":

    x = torch.ones(3, 16, 32, 32)

    model = BasicBlock(16, 16, stride=1)

    print(model(x).shape)

torch.Size([3, 16, 32, 32])


## WIDRM-S Model

In [20]:
from torch import Tensor
class FPN(nn.Module):
    def __init__(self, input_channels:list, output_channels:list):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(in_ch, out_ch*2, kernel_size=3, padding=1),
             nn.ReLU(inplace=True), nn.BatchNorm2d(out_ch*2),
             nn.Conv2d(out_ch*2, out_ch, kernel_size=3, padding=1))
            for in_ch, out_ch in zip(input_channels, output_channels)])
    def forward(self, xs:list, last_layer):
        hcs = [F.interpolate(c(x),scale_factor=2**(len(self.convs)-i+1),mode='bilinear') 
               for i,(c,x) in enumerate(zip(self.convs, xs))]
        hcs.append(last_layer)
        return torch.cat(hcs, dim=1)

class UnetBlock(nn.Module):
    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 self_attention:bool=False, **kwargs):
        super().__init__()
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        self.bn2 = nn.BatchNorm2d(nf)             
        ni = up_in_c//2 + x_in_c
        nf = nf if nf is not None else max(up_in_c//2,32)
        self.conv1 = ConvLayer(ni, nf, norm_type=None, **kwargs)
        self.conv2 = ConvLayer(nf, nf, norm_type=None,
            xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.gelu = nn.GELU()
        self.relu = nn.ReLU(inplace=True)
        self.nf = nf
    def forward(self, up_in, left_in):
        s = left_in
        #c_se = csSE(s.shape[1]).cuda()  #0.8832
        #c_se = SELayer(s.shape[1]).cuda() #0.8789
        #c_se = BasicBlock(s.shape[1], s.shape[1], stride=1).cuda() #0.8833 CBAM
        #c_se = CoordAtt(s.shape[1],s.shape[1]).cuda()  #0.8845
        #s = c_se(s)
        up_out = self.shuf(up_in)
        cat_x = self.gelu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.gelu(self.bn2(self.conv1(cat_x))))
        
class UnetBlockWithAtt(nn.Module):
    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 self_attention:bool=False, **kwargs):
        super().__init__()
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        self.bn2 = nn.BatchNorm2d(nf)             
        ni = up_in_c//2 + x_in_c
        nf = nf if nf is not None else max(up_in_c//2,32)
        self.conv1 = ConvLayer(ni, nf, norm_type=None, **kwargs)
        self.conv2 = ConvLayer(nf, nf, norm_type=None,
            xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.gelu = nn.GELU()
        self.relu = nn.ReLU(inplace=True)
        self.nf = nf
    def forward(self, up_in, left_in):
        s = left_in
        #c_se = csSE(s.shape[1]).cuda()  #0.8832
        #c_se = SELayer(s.shape[1]).cuda() #0.8789
        c_se = BasicBlock(s.shape[1], s.shape[1], stride=1).cuda() #0.8833 CBAM
        #c_se = CoordAtt(s.shape[1],s.shape[1]).cuda()  #0.8845
        s = c_se(s)
        up_out = self.shuf(up_in)
        cat_x = self.gelu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.gelu(self.bn2(self.conv1(cat_x))))
        
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, groups=1):
        super().__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                stride=1, padding=padding, dilation=dilation, bias=False, groups=groups)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    def __init__(self, inplanes=512, mid_c=256, dilations=[6, 12, 18, 24], out_c=None):
        super().__init__()
        self.aspps = [_ASPPModule(inplanes, mid_c, 1, padding=0, dilation=1)] + \
            [_ASPPModule(inplanes, mid_c, 3, padding=d, dilation=d,groups=4) for d in dilations]
        self.aspps = nn.ModuleList(self.aspps)
        self.global_pool = nn.Sequential(nn.AdaptiveMaxPool2d((1, 1)),
                        nn.Conv2d(inplanes, mid_c, 1, stride=1, bias=False),
                        nn.BatchNorm2d(mid_c), nn.ReLU(inplace=True))
        out_c = out_c if out_c is not None else mid_c
        self.out_conv = nn.Sequential(nn.Conv2d(mid_c*(2+len(dilations)), out_c, 1, bias=False),
                                    nn.BatchNorm2d(out_c), nn.ReLU(inplace=True))
        self.conv1 = nn.Conv2d(mid_c*(2+len(dilations)), out_c, 1, bias=False)
        self._init_weight()

    def forward(self, x):
        x0 = self.global_pool(x)
        xs = [aspp(x) for aspp in self.aspps]
        x0 = F.interpolate(x0, size=xs[0].size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x0] + xs, dim=1)
        return self.out_conv(x)
    
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

In [21]:
class WIDRMS(nn.Module):
    def __init__(self, stride=1, **kwargs):
        super().__init__()
        #encoder
        m = timm.create_model('seresnext26d_32x4d.bt_in1k', pretrained=False)
        weights = torch.load('Encoder_backbone/Encoder_backbone/Encoder/seresnext26d_32x4d_bt_in1k.bin', map_location=torch.device('cpu'))
        # 应用这些权重到模型上
        m.load_state_dict(weights)
        self.enc0 = nn.Sequential(m.conv1, m.bn1, nn.ReLU(inplace=True))
        self.enc1 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1),
                            m.layer1) #256
        self.enc2 = m.layer2 #512
        self.enc3 = m.layer3 #1024
        self.enc4 = m.layer4 #2048
        #aspp with customized dilatations
        self.aspp = ASPP(2048,256,out_c=512,dilations=[stride*1,stride*2,stride*3,stride*4])
        self.drop_aspp = nn.Dropout2d(0.0)
        #decoder
        self.dec4 = UnetBlockWithAtt(512,1024,256)
        self.dec3 = UnetBlockWithAtt(256,512,128)
        self.dec2 = UnetBlock(128,256,64)
        self.dec1 = UnetBlock(64,64,32)
        self.fpn = FPN([512,256,128,64],[16]*4)
        self.drop = nn.Dropout2d(0.0)
        self.final_conv = ConvLayer(32+16*4, CFG.num_classes, ks=1, norm_type=None, act_cls=None)
        #self.final_conv = ConvLayer(32, CFG.num_classes, ks=1, norm_type=None, act_cls=None)
        self.num_classes = CFG.num_classes
        
    def forward(self, x):
        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.aspp(enc4)
        dec3 = self.dec4(self.drop_aspp(enc5),enc3)
        #dec3 = self.dec4(enc4,enc3)
        dec2 = self.dec3(dec3,enc2)
        dec1 = self.dec2(dec2,enc1)
        dec0 = self.dec1(dec1,enc0)
        dec0 = F.interpolate(dec0,scale_factor=2,mode='bilinear')
        x = self.fpn([enc5, dec3, dec2, dec1], dec0)
        x = self.final_conv(self.drop(x))
        return x

In [23]:
import timm
from fastai.vision.all import PixelShuffle_ICNR
from fastai.vision.all import ConvLayer
from fastai.vision.all import AdaptiveConcatPool2d, Flatten
import torch

model = WIDRMS().cuda()
input_tensor = torch.randn(2, 3, 256, 256).cuda()

output = model(input_tensor)

print('Input shape:', input_tensor.shape)
print('Output1 shape:', output.shape)

Input shape: torch.Size([2, 3, 256, 256])
Output1 shape: torch.Size([2, 1, 256, 256])


## Metric

In [24]:
def dice_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred>thr).to(torch.float32)
    inter = (y_true*y_pred).sum(dim=dim)
    den = y_true.sum(dim=dim) + y_pred.sum(dim=dim)
    dice = ((2*inter+epsilon)/(den+epsilon)).mean(dim=(1,0))
    return dice

In [25]:
from medpy import metric
def calculate_metric_percase(gt, pred, thr=0.5):
    gt = gt.cpu().detach().numpy()
    pred = (pred>thr).cpu().detach().numpy()
    dice = metric.binary.dc(pred, gt)
    jc = metric.binary.jc(pred, gt)
    hd = metric.binary.hd95(pred, gt)
    asd = metric.binary.asd(pred, gt)
    return dice, jc, hd, asd

In [26]:
def plot_examples(images, masks, preds, epoch, step):
    fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(9, 9))
    for row in range(3):
        ax[row, 0].imshow(images[row][1].cpu().squeeze(), cmap='gray')
        ax[row, 0].set_title(f"Epoch {epoch} Batch {step} Sample {row} - Image")
        ax[row, 1].imshow(masks[row].cpu().squeeze(), cmap='gray')
        ax[row, 1].set_title(f"Epoch {epoch} Batch {step} Sample {row} - Ground Truth")
        ax[row, 2].imshow(preds[row].cpu().squeeze(), cmap='gray')
        ax[row, 2].set_title(f"Epoch {epoch} Batch {step} Sample {row} - Prediction")
        for col in range(3):
            ax[row, col].axis("off")
    plt.show()

## Loss

In [27]:
from SegLossFunc import SegLoss
BoundaryDoULoss = SegLoss.BoundaryDoULoss()

SoftDiceCLDiceBoundaryDoULoss = SegLoss.SoftDiceCLDiceBoundaryDoULoss()

StructureLoss = SegLoss.StructureLoss()
StructureLossBoundaryDOU = SegLoss.StructureLossBoundaryDOU()
StructureLossBoundaryDOUV2 = SegLoss.StructureLossBoundaryDOUV2()

JaccardLoss = smp.losses.JaccardLoss(mode='multilabel')
DiceLoss    = smp.losses.DiceLoss(mode='multilabel')
BCELoss     = smp.losses.SoftBCEWithLogitsLoss()
LovaszLoss  = smp.losses.LovaszLoss(mode='multilabel', per_image=False)
TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)
FocalLoss = smp.losses.FocalLoss(mode="multilabel")

def criterion(y_pred, y_true):
    return BoundaryDoULoss(y_pred, y_true) + 0.5*BCELoss(y_pred, y_true) + 0.5*DiceLoss(y_pred, y_true)

## Scheduler

In [28]:
# 自定义逐渐升温调度器
class GradualWarmupSchedulerV3(GradualWarmupScheduler):
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super(GradualWarmupSchedulerV3, self).__init__(optimizer, multiplier, total_epoch, after_scheduler)
    def get_lr(self):
        if self.last_epoch >= self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

## Train&Valid&Test

In [29]:
def train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model.train()
    start = end = time.time()
    for step, (images, bi_masks, multi_mask) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device, dtype=torch.float)
        if CFG.num_classes == 1:
            masks = bi_masks.to(device, dtype=torch.float)
        else:
            masks = multi_mask.to(device, dtype=torch.float)
        batch_size = images.size(0)
        with autocast(enabled=True):
            y_preds = model(images)
            loss = criterion(y_preds, masks)
        # record loss
        losses.update(loss.item(), batch_size)
        if CFG.accum_iter > 1:
            loss = loss / CFG.accum_iter
        scaler.scale(loss).backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        if (step + 1) % CFG.accum_iter == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            cusprint('Epoch: [{0}][{1}/{2}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Elapsed {remain:s} '
                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                'Grad: {grad_norm:.4f}  '
                'LR: {lr:.7f}  '
                .format(
                epoch, step, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses,
                remain=timeSince(start, float(step+1)/len(train_loader)),
                grad_norm=grad_norm,
                lr=optimizer.param_groups[0]["lr"],
                ))
    return losses.avg, optimizer.param_groups[0]["lr"]

In [30]:
def valid_one_epoch(valid_loader, model, criterion, device,epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to evaluation mode
    model.eval()
    start = end = time.time()
    val_scores = []
    for step, (images, bi_masks, multi_mask) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device, dtype=torch.float)
        if CFG.num_classes == 1:
            masks = bi_masks.to(device, dtype=torch.float)
        else:
            masks = multi_mask.to(device, dtype=torch.float)
        batch_size = images.size(0)

        # compute loss
        with torch.no_grad():
            y_pred = model(images)

        loss = criterion(y_pred, masks)
        losses.update(loss.item(), batch_size)
        # record accuracy
        y_pred = y_pred.sigmoid() ####
        # y_pred = y_pred.sigmoid().to('cpu').numpy()
        
        val_dice = dice_coef(masks, y_pred).cpu().detach().numpy()
        #val_scores.append([val_dice])
        dice, jc, hd, asd = calculate_metric_percase(masks, y_pred)

        val_scores.append([val_dice, dice, jc, hd, asd])
        
        if CFG.accum_iter > 1:
            loss = loss / CFG.accum_iter
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            cusprint('EVAL: [{0}/{1}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Elapsed {remain:s} '
                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                .format(
                step, len(valid_loader), batch_time=batch_time,
                data_time=data_time, loss=losses,
                remain=timeSince(start, float(step+1)/len(valid_loader)),
                ))
    val_scores = np.mean(val_scores, axis=0)
    
    if val_scores[0]>0.95:
        selected_indices = np.random.choice(images.shape[0], 3, replace=False)
        images = images[selected_indices]
        masks = masks[selected_indices]
        preds = y_pred[selected_indices]
        plot_examples(images, masks, preds,epoch,step)
    torch.cuda.empty_cache()
    gc.collect()
    return losses.avg, val_scores

In [40]:
# 训练函数
def train_loop(train_df, fold_1, fold_2, criterion):
    loginfo(f"========== training ==========")
    # ====================================================
    # loader 
    # ====================================================
    train_folds = train_df[train_df["fold"].isin(fold_1)].reset_index(drop=True)
    valid_folds = train_df[train_df["fold"].isin(fold_2)].reset_index(drop=True)

    train_img,train_label=load_data(train_folds)
    valid_img,valid_label=load_data(valid_folds)
    
    train_dataset = Spine_Dataset([train_img],[train_label],arg=True)
    train_loader = DataLoader(train_dataset, batch_size=CFG.train_batch_size ,num_workers=0, shuffle=True, pin_memory=True)
    valid_dataset = Spine_Dataset([valid_img],[valid_label],arg=False)
    valid_loader = DataLoader(valid_dataset, batch_size=CFG.valid_batch_size ,num_workers=0, shuffle=False, pin_memory=True)
    # ====================================================
    # model & optimizer & scheduler & loss 
    # ====================================================
    model = WIDRMS().cuda()
    #model = torch.nn.DataParallel(model, device_ids=[0, 1])
    # optimizer
    if CFG.optimizer == "AdamW":
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            optimizer = AdamW(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
        else:
            optimizer = AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)  
    # scheduler
    if CFG.scheduler=='ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
    elif CFG.scheduler=='CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    elif CFG.scheduler=='CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)

    if CFG.scheduler_warmup=="GradualWarmupSchedulerV3":
        scheduler_warmup = GradualWarmupSchedulerV3(optimizer, multiplier=10, total_epoch=CFG.warmup_epo, after_scheduler=scheduler)

    # ====================================================
    # loop 
    # ====================================================

    valid_acc_max=0
    valid_acc_max_cnt=0
    for epoch in range(CFG.epochs):
        loginfo(f"***** Epoch {epoch} *****")
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            pass
            # loginfo(f"schwarmup_last_epoch:{scheduler_warmup.last_epoch}, schwarmup_lr:{scheduler_warmup.get_last_lr()[0]}")
        if CFG.scheduler=='CosineAnnealingLR':
            loginfo(f"scheduler_last_epoch:{scheduler.last_epoch}, scheduler_lr:{scheduler.get_last_lr()[0]}")
        loginfo(f"optimizer_lr:{optimizer.param_groups[0]['lr']}")
                
        start_time = time.time() # 记录当前时间
        avg_loss, cur_lr = train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device)
        avg_val_loss, valid_scores = valid_one_epoch(valid_loader, model, criterion, device,epoch)
        # scoring
        elapsed = time.time() - start_time
        # print("valid_scores:", valid_scores, type(valid_scores))
        val_dice, dice, jc, hd, asd = valid_scores
        
        loginfo(f'Epoch {epoch} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        loginfo(f'Epoch {epoch} -val_dice: {val_dice:.4f} - Dice Score: {dice:.4f}- Jaccard Score: {jc:.4f}- HD95 Score: {hd:.4f}- ASSD Score: {asd:.4f}')
    
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            scheduler_warmup.step()
        elif CFG.scheduler == "ReduceLROnPlateau":
            scheduler.step(avg_val_loss)
        elif CFG.scheduler in ["CosineAnnealingLR", "CosineAnnealingWarmRestarts"]:
            scheduler.step()

        torch.save({'model': model.state_dict()}, CFG.output_dir+f'/{CFG.model_name}_{CFG.exp_name}_fold{fold}_epoch{epoch}.pth')
        
        # early stopping 
        if val_dice > valid_acc_max:
            valid_acc_max = dice
            valid_acc_max_cnt=0
            best_acc_epoch = epoch
        else:
            valid_acc_max_cnt+=1
        
        if valid_acc_max_cnt >= CFG.n_early_stopping:
            torch.save({'model': model.state_dict()}, CFG.output_dir+f'/{CFG.model_name}_{CFG.exp_name}_fold{fold}_epoch{epoch}.pth')
        
            print("early_stopping")
            break
        
        torch.save({'model': model.state_dict()}, CFG.output_dir+f'/{CFG.model_name}_{CFG.exp_name}_fold{fold}_epoch{epoch}.pth')

In [41]:
def valid():
    fold_1 = CFG.train_fold_list
    fold_2 = CFG.valid_fold_list
    train_loop(df, fold_1, fold_2, criterion)
def test():
    fold_1 = CFG.train_fold_list
    fold_2 = CFG.test_fold_list
    train_loop(df, fold_1, fold_2, criterion)

In [42]:
test()

100%|██████████| 145/145 [00:07<00:00, 18.82it/s]
100%|██████████| 49/49 [00:02<00:00, 20.37it/s]
***** Epoch 0 *****
INFO:__main__:***** Epoch 0 *****
scheduler_last_epoch:0, scheduler_lr:0.0001
INFO:__main__:scheduler_last_epoch:0, scheduler_lr:0.0001
optimizer_lr:0.0001
INFO:__main__:optimizer_lr:0.0001


Epoch: [0][0/145] Data 0.468 (0.468) Elapsed 0m 9s (remain 21m 57s) Loss: 1.7557(1.7557) Grad: 64263.5312  LR: 0.0001000  


KeyboardInterrupt: 