### 用于训练自己写的用于医学分割的ViT

In [1]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import math
import sys

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import timm

# assert timm.__version__ == "0.3.2"  # version check
import timm.optim.optim_factory as optim_factory

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler

import my_seg_vit

from my_seg_engine import train_one_epoch



def get_args_parser():
    """Build the command-line parser for supervised seg-ViT training.

    The parser is created with ``add_help=False``, so it does not register
    its own ``-h/--help`` option.

    Returns:
        argparse.ArgumentParser: parser holding the training, optimizer,
        dataset/output and distributed-training options with their defaults.
    """
    parser = argparse.ArgumentParser('my 2D supervised seg-ViT', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=1000, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='mae_vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')

    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')

    parser.add_argument('--mask_ratio', default=0.75, type=float,
                        help='Masking ratio (percentage of removed patches).')

    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # Optimizer parameters
    # The help text interpolates the real default so it cannot drift away
    # from the value actually used (it previously claimed 0.05).
    parser.add_argument('--weight_decay', type=float, default=0.0005,
                        help='weight decay (default: %(default)s)')

    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
                        help='epochs to warmup LR')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')

    parser.add_argument('--output_dir', default='./my_supervised_seg_ViT_earlystop_dice_lossforpatch',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./my_supervised_seg_ViT_earlystop_dice_lossforpatch',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda:3',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    # --pin_mem / --no_pin_mem write to the same destination; pinning is on by default.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    return parser

In [2]:
# Build the train / validation / test loaders from the project-local dataset
# helpers; `train_loader` and `valid_loader` are read as globals by main().
from util.ccdataset import *
train_paths, valid_paths, test_paths = getPathList()
# NOTE(review): the keyword used here is `B1`, while a later cell calls
# getDataloader(..., B=1) — confirm which keyword getDataloader accepts.
train_loader, valid_loader, test_loader = getDataloader(
    train_paths, valid_paths, test_paths, B1=64)
# Lengths of the three loaders (shown as the cell output).
len(train_loader),len(valid_loader),len(test_loader)
# torch.Size([64, 1, 224, 224]) torch.FloatTensor


(685, 6880, 6720)

In [3]:
# Parse the default options only: an explicit empty argv is passed so that
# argparse does not read the notebook kernel's own command line.
args = get_args_parser()
args = args.parse_args(args=[])
# Per the captured output, this adds `distributed=False` to args
# ("Not using distributed mode").
misc.init_distributed_mode(args)
# print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
# Print one option per line for readability.
print("{}".format(args).replace(', ', ',\n'))


Not using distributed mode
[18:59:38.845946] Namespace(batch_size=64,
epochs=1000,
accum_iter=1,
model='mae_vit_base_patch16',
input_size=224,
mask_ratio=0.75,
norm_pix_loss=False,
weight_decay=0.0005,
lr=0.0001,
blr=0.0001,
min_lr=0.0,
warmup_epochs=40,
data_path='/datasets01/imagenet_full_size/061417/',
output_dir='./my_supervised_seg_ViT_earlystop_dice_lossforpatch',
log_dir='./my_supervised_seg_ViT_earlystop_dice_lossforpatch',
device='cuda:3',
seed=0,
resume='',
start_epoch=0,
num_workers=10,
pin_mem=True,
world_size=1,
local_rank=-1,
dist_on_itp=False,
dist_url='env://',
distributed=False)


In [4]:
from my_metric import dice_coef_metric,jaccard_coef_metric
import util.lr_sched as lr_sched
def main(args):
    """Train the segmentation ViT, validating and checkpointing every 4 epochs.

    Relies on the module-level ``train_loader`` and ``valid_loader`` built in
    an earlier notebook cell; each batch is a dict with ``"image"`` and
    ``"mask"`` entries.

    Args:
        args: namespace produced by ``get_args_parser().parse_args(...)``.
            Fields read here: device, seed, log_dir, model, batch_size,
            accum_iter, lr, weight_decay, start_epoch, epochs, output_dir
            (plus whatever ``misc.load_model`` / ``misc.save_model`` /
            ``lr_sched.adjust_learning_rate`` consume).

    Side effects:
        Writes TensorBoard scalars to ``args.log_dir``, checkpoints via
        ``misc.save_model`` and ``torch.save``, and appends one JSON line per
        epoch to ``<output_dir>/log.txt``. Calls ``sys.exit(1)`` if the
        training loss becomes non-finite.
    """
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # Bind log_writer unconditionally: every later use is guarded by
    # `is not None`, which would raise NameError if the name were unbound
    # when no log directory is configured.
    log_writer = None
    if args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)

    # define the model
    model = my_seg_vit.__dict__[args.model](lossforpatch = False)

    model.to(device)

    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

    # The absolute learning rate args.lr is used directly; args.blr is not
    # scaled by the effective batch size here.
    print("actual lr:" , args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    param_groups = optim_factory.param_groups_weight_decay(
        model_without_ddp, args.weight_decay)

    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))

    print("查看optimizer = ",optimizer)
    loss_scaler = NativeScaler()
    misc.load_model(args=args, model_without_ddp=model_without_ddp,
                    optimizer=optimizer, loss_scaler=loss_scaler)
    print(f"Start training for {args.epochs} epochs")
    # Validation Dice scores in the order they were measured (one entry per
    # validated epoch: 0, 4, 8, ... when start_epoch is 0).
    dice_list = []
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):

        # ---------------------------- training ----------------------------
        model.train()
        metric_logger = misc.MetricLogger(delimiter="  ")
        metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        accum_iter = args.accum_iter
        optimizer.zero_grad()  # start every epoch with clean gradients
        if log_writer is not None:
            print('log_dir: {}'.format(log_writer.log_dir))

        for data_iter_step,data in enumerate(train_loader):
            # Per-iteration (fractional-epoch) LR schedule, refreshed at the
            # start of each accumulation window.
            if data_iter_step % accum_iter == 0:
                lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(train_loader) + epoch, args)
            samples = data["image"]
            labeles = data["mask"]
            samples = samples.to(device, non_blocking=True)
            labeles = labeles.to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                loss, _,  = model(samples, labeles)

            loss_value = loss.item()

            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            # Scale the loss so accumulated gradients average over the window.
            loss /= accum_iter
            loss_scaler(loss, optimizer, parameters=model.parameters(),update_grad=(data_iter_step + 1) % accum_iter == 0)

            if (data_iter_step + 1) % accum_iter == 0:
                optimizer.zero_grad()

            torch.cuda.synchronize()

            metric_logger.update(loss=loss_value)

            lr = optimizer.param_groups[0]["lr"]
            metric_logger.update(lr=lr)

            loss_value_reduce = misc.all_reduce_mean(loss_value)

            if log_writer is not None :
                """ We use epoch_1000x as the x-axis in tensorboard.
                This calibrates different curves when batch size changes.
                """
                epoch_1000x = int(
                    (data_iter_step / len(train_loader) + epoch) * 1000)
                log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
                log_writer.add_scalar('lr', lr, epoch_1000x)

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}

        # ---------------------------- validation ----------------------------
        # Validate (and checkpoint) once every 4 epochs.
        if(epoch%4==0):
            with torch.no_grad():  # no gradients are needed for validation
                model.eval()
                dice_score = 0.
                iou_score = 0.
                for data in valid_loader:
                    valid_image = data["image"]
                    valid_target = data["mask"]

                    valid_image = valid_image.to(device, non_blocking=True)
                    valid_target = valid_target.to(device, non_blocking=True)

                    _,out = model(valid_image, valid_target) #[N,196,256]
                    out = model.unpatchify(out) #[N,1,224,224]
                    dice_score += dice_coef_metric(out.cpu(), valid_target.cpu())
                    iou_score += jaccard_coef_metric(out.cpu(), valid_target.cpu())

                dice_score /= len(valid_loader)
                iou_score /= len(valid_loader)
                print("epoch=", epoch,
                    "dice_score=", dice_score,
                    "iou_score=", iou_score)

                print("------------------------------------------------------------------")
                # Guarded: log_writer is None when no log_dir is configured.
                if log_writer is not None:
                    log_writer.add_scalar(tag="dice_scalar", scalar_value=dice_score, global_step=epoch)
                    log_writer.add_scalar(tag="iou_scalar", scalar_value=iou_score, global_step=epoch)
                # Checkpoint via the MAE helper (bundles more than the weights).
                misc.save_model(args=args, model=model, model_without_ddp=model, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)
                # Plain state_dict checkpoint, tagged with epoch and Dice score.
                torch.save(model.state_dict(), '/data/zhanghao/skull_project/MODEL/cangku_cc359_ViT/' +
                                    str(epoch) + 'ViT' + str(dice_score) + '.pth')

            # Early stopping: compare the current Dice with the lowest of four
            # earlier validation scores (list indices epoch/4-5 .. epoch/4-2).
            if epoch > 20:
                # Slicing never raises when the history is shorter than the
                # window (e.g. a run resumed with a non-zero start_epoch);
                # the check is skipped when there is nothing to compare with.
                window = dice_list[int(epoch/4 - 5):int(epoch/4 - 1)]
                if window and dice_score < min(window):
                    # NOTE(review): this only rescales the local `lr`; the
                    # optimizer's learning rate is set again by
                    # lr_sched.adjust_learning_rate on the next iteration, so
                    # the decay does not appear to reach the optimizer —
                    # confirm the intended behaviour before changing it.
                    lr /= 10
                    print(lr)
                    if lr < 1e-6:
                        break

            dice_list.append(dice_score)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch, }

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    # Release the event-file handle once training ends or stops early.
    if log_writer is not None:
        log_writer.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


In [5]:
# Re-parse the default options (empty argv: notebook execution), make sure the
# output directory exists, then start training.
args = get_args_parser()
args = args.parse_args(args=[])
if args.output_dir:
    # An empty output_dir is falsy, so no directory is created in that case.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)


[18:59:39.105961] Namespace(batch_size=64,
epochs=1000,
accum_iter=1,
model='mae_vit_base_patch16',
input_size=224,
mask_ratio=0.75,
norm_pix_loss=False,
weight_decay=0.0005,
lr=0.0001,
blr=0.0001,
min_lr=0.0,
warmup_epochs=40,
data_path='/datasets01/imagenet_full_size/061417/',
output_dir='./my_supervised_seg_ViT_earlystop_dice_lossforpatch',
log_dir='./my_supervised_seg_ViT_earlystop_dice_lossforpatch',
device='cuda:3',
seed=0,
resume='',
start_epoch=0,
num_workers=10,
pin_mem=True,
world_size=1,
local_rank=-1,
dist_on_itp=False,
dist_url='env://')
[18:59:44.448606] Model = SegViT2D(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_featu

KeyboardInterrupt: 

In [None]:
# Re-create the default arguments and show the configured output directory.
args = get_args_parser()
args = args.parse_args(args=[])
if args.output_dir: 
    print(args.output_dir)

In [None]:
# Load the test set
from util.ccdataset import *
train_paths, valid_paths, test_paths = getPathList()
# NOTE(review): the keyword used here is `B`, while the earlier training cell
# calls getDataloader(..., B1=64) — confirm which keyword getDataloader accepts.
train_loader, valid_loader, test_loader = getDataloader(
    train_paths, valid_paths, test_paths, B=1)
# Length of the test loader (shown as the cell output).
len(test_loader)
# torch.Size([64, 1, 224, 224]) torch.FloatTensor

In [None]:
from util.diceloss import dice_coeff
import torch.nn as nn
from functools import partial
# model = my_seg_vit.SegViT2D(
#     patch_size=16, embed_dim=768, depth=12, num_heads=12,
#     decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
#     mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6))
# Evaluate a saved checkpoint on the test set and report the mean Dice score.
# NOTE(review): main() builds the model with `lossforpatch=False`, whereas this
# cell passes `norm_pix_loss=...` — confirm which constructor keyword applies.
model = my_seg_vit.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)

# NOTE(review): if this file was written by misc.save_model it may be a dict
# that wraps the weights (e.g. under a 'model' key) rather than a bare
# state_dict — verify the checkpoint layout.
# map_location='cpu' avoids needing the GPU the checkpoint was saved from.
model.load_state_dict(torch.load('./output_dir/checkpoint-20.pth', map_location='cpu'))
# The batches are moved to GPU 1 below, so the model has to live there too.
model = model.cuda(1)
model.eval()
dicecoeff = 0.
with torch.no_grad():  # inference only: no autograd graph is needed
    for data in test_loader:
        image = data['image'].cuda(1)
        mask = data['mask'].cuda(1)
        # Same call pattern as the validation loop in main(): the model takes
        # (image, mask) and returns (loss, patch-level prediction).
        _, pred = model(image, mask)
        # Fold the patch sequence back into an image so that it matches the
        # mask's shape (dice_coeff asserts equal sizes).
        pred = model.unpatchify(pred)
        # Dice is measured between the prediction and the ground-truth mask,
        # not between the input image and the prediction.
        dicecoeff += dice_coeff(pred, mask)
dicecoeff /= len(test_loader)
print("平均dicecoeff=",dicecoeff)

    

In [None]:
import torch.nn as nn
import torch.nn.functional as F
class SoftDiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    For each sample the Dice score
    ``(2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)`` is computed
    over all non-batch elements, and the loss is ``1 - mean(score)``.

    Args:
        weight: accepted for signature compatibility; not used.
        size_average: accepted for signature compatibility; not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, logits, targets):
        """Return the scalar soft Dice loss.

        Args:
            logits: raw (pre-sigmoid) predictions; the first dimension is the batch.
            targets: ground-truth tensor with the same number of elements per
                sample as ``logits``.

        Returns:
            Zero-dimensional tensor holding ``1 - mean per-sample Dice``.
        """
        num = targets.size(0)
        smooth = 1e-9
        probs = logits.sigmoid()
        # reshape (rather than view) also accepts non-contiguous inputs.
        m1 = probs.reshape(num, -1)
        m2 = targets.reshape(num, -1)
        intersection = (m1 * m2).sum(1)
        # The smoothing term sits outside the factor 2 so that an empty
        # prediction paired with an empty target scores 1 instead of 2.
        score = (2. * intersection + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        return 1 - score.sum() / num


In [None]:
import torch
from torch import Tensor
import torch.nn.functional as F


def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Return the Dice coefficient of ``input`` against ``target``.

    A 2-D input (or any input when ``reduce_batch_first`` is true) is treated
    as a single mask: ``(2 * <input, target> + epsilon) / (sum(input) +
    sum(target) + epsilon)``. Otherwise the coefficient is computed for every
    slice along the first dimension and the results are averaged.

    Args:
        input: predicted mask tensor.
        target: reference mask tensor of the same size as ``input``.
        reduce_batch_first: if true, flatten the whole tensor into one mask
            instead of averaging over the first dimension.
        epsilon: smoothing constant added to numerator and denominator.

    Returns:
        Tensor holding the (averaged) Dice coefficient.

    Raises:
        AssertionError: if the two tensors differ in size.
        ValueError: if ``reduce_batch_first`` is requested for a 2-D input.
    """
    assert input.size() == target.size()
    if input.dim() == 2 and reduce_batch_first:
        raise ValueError(
            f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')

    if input.dim() == 2 or reduce_batch_first:
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        sets_sum = torch.sum(input) + torch.sum(target)
        if sets_sum.item() == 0:
            # Both masks are empty: make the ratio collapse to epsilon/epsilon = 1.
            sets_sum = 2 * inter

        return (2 * inter + epsilon) / (sets_sum + epsilon)

    # Compute the metric for each element along the first dimension and
    # average. The caller's epsilon is forwarded so that a custom smoothing
    # constant is honoured on this path as well.
    dice = 0
    for i in range(input.shape[0]):
        dice += dice_coeff(input[i, ...], target[i, ...], epsilon=epsilon)
    return dice / input.shape[0]


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    """Return the Dice coefficient averaged over dimension 1 (the classes).

    ``dice_coeff`` is evaluated once per index along dimension 1, with
    ``reduce_batch_first`` and ``epsilon`` forwarded unchanged, and the
    per-class results are averaged.

    Raises:
        AssertionError: if ``input`` and ``target`` differ in size.
    """
    assert input.size() == target.size()
    num_classes = input.shape[1]
    per_class = (
        dice_coeff(input[:, c, ...], target[:, c, ...], reduce_batch_first, epsilon)
        for c in range(num_classes)
    )
    return sum(per_class) / num_classes


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Return ``1 - Dice`` as a loss to minimise.

    The coefficient comes from ``multiclass_dice_coeff`` when ``multiclass``
    is true and from ``dice_coeff`` otherwise; both are called with
    ``reduce_batch_first=True``.

    Raises:
        AssertionError: if ``input`` and ``target`` differ in size.
    """
    assert input.size() == target.size()
    if multiclass:
        coefficient = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        coefficient = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - coefficient
