In [4]:
from __future__ import division
import os
import argparse
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from loader import *
from networks.attention_swin_unet import SwinAttentionUnet as ViT_seg
from configs import swin_attention_unet as config
from scipy.ndimage.morphology import binary_fill_holes, binary_opening
from sklearn.metrics import f1_score
import pandas as pd
import glob
import nibabel as nib
from tqdm import tqdm
import numpy as np
import copy
import yaml
from types import SimpleNamespace  
from utils import load_pretrain

## Hyperparameters and data loaders

In [5]:


# Load the experiment settings and build the train / validation / test loaders.
#
# NOTE: this rebinds the name `config`, which the import cell bound to the
# `configs.swin_attention_unet` module. From here on `config` is the dict
# parsed from ./config_skin.yml.
with open('./config_skin.yml') as config_file:
    # NOTE(review): FullLoader is kept as in the original; prefer
    # yaml.safe_load if this file can ever come from an untrusted source.
    config = yaml.load(config_file, Loader=yaml.FullLoader)

number_classes = int(config['number_classes'])
input_channels = 3
best_val_loss  = np.inf
# Fall back to the CPU when no CUDA device is visible.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

data_path = config['path_to_data']

train_dataset = isic_loader(path_Data = data_path, train = True)
train_loader  = DataLoader(train_dataset, batch_size = int(config['batch_size_tr']), shuffle= True)
val_dataset   = isic_loader(path_Data = data_path, train = False)
val_loader    = DataLoader(val_dataset, batch_size = int(config['batch_size_va']), shuffle= False)

# The test loader yields one sample per batch. Evaluation does not need a
# random order, so shuffling is disabled to keep runs reproducible.
test_dataset  = isic_loader(path_Data = data_path, train = False, Test = True)
test_loader   = DataLoader(test_dataset, batch_size = 1, shuffle= False)


NameError: name 'isic_loader' is not defined

# config and arguments

In [3]:
# Command-line style options for the Swin-Attention U-Net. In the notebook the
# parser is only ever evaluated against an empty argv, so the defaults below
# are the effective settings.
parser = argparse.ArgumentParser()

# --- data locations -------------------------------------------------------
parser.add_argument('--root_path', type=str,
                    default='./Synapse/', help='root dir for data')
parser.add_argument('--eval_interval', type=int, default=5, help='eval interval')
parser.add_argument('--volume_path', type=str,
                    default='./Synapse/', help='root dir for validation volume data')
parser.add_argument('--dataset', type=str,
                    default='Synapse', help='experiment_name')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=9, help='output channel of network')
parser.add_argument('--saved_model', type=str,
                    default='./weights/weights_isic17.model', help='output dir')

# --- optimisation ---------------------------------------------------------
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int,
                    default=150, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=24, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.01,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')

# --- Swin configuration ---------------------------------------------------
parser.add_argument('--cfg', type=str,
                    default='configs/swin_tiny_patch4_window7_224_lite.yaml',
                    metavar="FILE", help='path to config file')
parser.add_argument(
    "--opts",
    help="Modify config options by adding 'KEY VALUE' pairs. ",
    default=None,
    nargs='+',
)
parser.add_argument('--zip', action='store_true',
                    help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part',
                    choices=['no', 'full', 'part'],
                    help='no: no cache, '
                         'full: cache all data, '
                         'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
parser.add_argument('--use-checkpoint', action='store_true',
                    help="whether to use gradient checkpointing to save memory")
parser.add_argument('--amp-opt-level', type=str, default='O1',
                    choices=['O0', 'O1', 'O2'],
                    help='mixed precision opt level, if O0, no amp is used')
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')

# --- model variants -------------------------------------------------------
# The choices list was missing its closing bracket, which made the whole cell
# a SyntaxError.
parser.add_argument('--mode', help='Select our contribution',
                    choices=['swin', 'cross_contextual_attention'], default='swin')
# These are string-valued choices (no `type=`), so the parsed values are str.
parser.add_argument('--skip_num', help='Select our contribution',
                    choices=['0', '1', '2', '3'], default='3')
parser.add_argument('--operationaddatten', help='Select our contribution',
                    choices=['+', 'mul'], default='+')
parser.add_argument('--attention', help='0 or 1',
                    choices=['0', "1"], default="0")

                    

# Evaluate the parser against an empty argv so only the declared defaults are
# used (the kernel's own command line is ignored).
args = parser.parse_args(args=[])
if args.dataset == "Synapse":
    args.root_path = os.path.join(args.root_path, "train_npz")

# The name `config` is rebound to the YAML settings dict by the
# hyperparameter cell, so it cannot be relied on to still be the
# `configs.swin_attention_unet` module here. Import the module under a
# private alias and merge into a fresh name; this keeps the cell re-runnable
# and leaves `config` untouched for the cells that index it.
from configs import swin_attention_unet as _swin_config_module

swin_config = _swin_config_module.get_swin_unet_attention_configs().to_dict()
# Parsed arguments take precedence over the Swin defaults on key clashes.
swin_config.update(vars(args))
# Attribute-style view of the merged settings, passed to the model builder.
configs = SimpleNamespace(**swin_config)



=> merge config from configs/swin_tiny_patch4_window7_224_lite.yaml


# Build the model, optimizer, and loss

In [4]:
# Build the network, optionally restore a checkpoint, then create the
# optimizer, LR scheduler and loss.
#
# The class count comes from the YAML settings (`number_classes`), not from
# the argparse default of 9: the loss below is BCELoss, which needs the
# prediction and the mask to have the same shape.
# `.to(device)` replaces the hard-coded `.cuda()` so the CPU fallback chosen
# for `device` actually works.
Net   = ViT_seg(configs, num_classes=number_classes).to(device)
Net   = load_pretrain(configs, Net)
Net   = Net.to(device)
if int(config['pretrained']):
    # Read the checkpoint once and reuse it for both the weights and the score.
    checkpoint = torch.load(config['saved_model'], map_location='cpu')
    Net.load_state_dict(checkpoint['model_weights'])
    # Checkpoints written by the training cell store 'test_F1_score' rather
    # than 'val_loss', so keep the current value when the key is absent.
    best_val_loss = checkpoint.get('val_loss', best_val_loss)

optimizer = optim.Adam(Net.parameters(), lr= float(config['lr']))
# Halves the learning rate when the monitored value stops decreasing for
# `patience` scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor = 0.5, patience = config['patience'])
# BCELoss expects probabilities in [0, 1].
# NOTE(review): presumably the network ends in a sigmoid - confirm.
criteria  = torch.nn.BCELoss()


SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.2;num_classes:1
------------------------------ 0 <class 'str'>
mode: swin skip_num 3 cross_num 3
pretrained_path:./pretrained_ckpt/swin_tiny_patch4_window7_224.pth
---start load pretrained modle of swin encoder---


# Training

In [None]:
# Train for `config['epochs']` epochs. Every `args.eval_interval` epochs the
# model is scored on `test_loader` and a checkpoint is written whenever the
# F1 score improves.
best_F1_score = 0.0
num_epochs = int(config['epochs'])
# Print the running loss every `progress_p` fraction of an epoch. Clamp to at
# least 1 so a small loader cannot produce a modulo-by-zero.
log_every = max(1, int(float(config['progress_p']) * len(train_loader)))

for ep in range(num_epochs):
    Net.train()
    epoch_loss = 0
    for itter, batch in enumerate(train_loader):
        img = batch['image'].to(device, dtype=torch.float)
        msk = batch['mask'].to(device=device, dtype=torch.float32)
        msk_pred = Net(img)
        loss = criteria(msk_pred, msk)
        optimizer.zero_grad()
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        if itter % log_every == 0:
            print(f' Epoch>> {ep+1} and itteration {itter+1} Loss>> {((epoch_loss/(itter+1)))}')

    # NOTE(review): `scheduler` is created alongside the optimizer but is
    # never stepped in this loop, so the learning rate stays constant.
    # Confirm whether `scheduler.step(<monitored loss>)` was intended.

    if (ep+1) % args.eval_interval != 0:
        continue

    # NOTE(review): model selection below runs on `test_loader`; the
    # `val_loader` built in the data cell is not used in this cell.
    print('val_mode')
    predictions = []
    gt = []
    Net.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader):
            img = batch['image'].to(device, dtype=torch.float)
            msk_pred = Net(img)

            # Only element [0, 0] of each batch is scored; the test loader is
            # built with batch_size = 1.
            gt.append(batch['mask'].numpy()[0, 0])
            prob_map = msk_pred.cpu().numpy()[0, 0]
            # Binarise once at 0.4. The original re-thresholded this 0/1 map
            # at 0.5, which could not change it.
            predictions.append(np.where(prob_map >= 0.4, 1, 0))

    predictions = np.array(predictions)
    gt = np.array(gt)

    y_scores = predictions.reshape(-1)
    y_true   = np.where(gt.reshape(-1) > 0.5, 1, 0)

    #F1 score
    F1_score = f1_score(y_true, y_scores, labels=None, average='binary', sample_weight=None)
    print ("\nF1 score (F-measure) or DSC: " +str(F1_score))
    if F1_score > best_F1_score:
        print('New best F1 score, saving...')
        best_F1_score = F1_score
        # torch.save serialises immediately, so no deep copy of the weights
        # is needed first.
        state = {'model_weights': Net.state_dict(), 'test_F1_score': F1_score}
        save_dir = os.path.dirname(args.saved_model)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(state, args.saved_model)
        
        
        
        

 Epoch>> 1 and itteration 1 Loss>> 0.7598584294319153
 Epoch>> 1 and itteration 30 Loss>> 0.45233075122038524
 Epoch>> 1 and itteration 59 Loss>> 0.3408874323812582
 Epoch>> 2 and itteration 1 Loss>> 0.3103235065937042
 Epoch>> 2 and itteration 30 Loss>> 0.17979845528801283
 Epoch>> 2 and itteration 59 Loss>> 0.17059589619353666
 Epoch>> 3 and itteration 1 Loss>> 0.12587174773216248
 Epoch>> 3 and itteration 30 Loss>> 0.1560018355647723
 Epoch>> 3 and itteration 59 Loss>> 0.14095270254854428
 Epoch>> 4 and itteration 1 Loss>> 0.10675635188817978
 Epoch>> 4 and itteration 30 Loss>> 0.10480246866742769
 Epoch>> 4 and itteration 59 Loss>> 0.11470934066732051
 Epoch>> 5 and itteration 1 Loss>> 0.10782317817211151
 Epoch>> 5 and itteration 30 Loss>> 0.11757173538208007
 Epoch>> 5 and itteration 59 Loss>> 0.10765916816258835
val_mode


400it [00:04, 82.60it/s]



F1 score (F-measure) or DSC: 0.9043919450327523
New best loss, saving...
 Epoch>> 6 and itteration 1 Loss>> 0.09028090536594391
 Epoch>> 6 and itteration 30 Loss>> 0.08701437128086885
 Epoch>> 6 and itteration 59 Loss>> 0.09338036709922855
 Epoch>> 7 and itteration 1 Loss>> 0.09963338822126389
 Epoch>> 7 and itteration 30 Loss>> 0.09091660069922607
 Epoch>> 7 and itteration 59 Loss>> 0.08712454190698721
 Epoch>> 8 and itteration 1 Loss>> 0.059745561331510544
 Epoch>> 8 and itteration 30 Loss>> 0.08953677602112294
 Epoch>> 8 and itteration 59 Loss>> 0.08596139300172612
 Epoch>> 9 and itteration 1 Loss>> 0.07752776145935059
 Epoch>> 9 and itteration 30 Loss>> 0.08311334513127804
 Epoch>> 9 and itteration 59 Loss>> 0.08043004149350069
 Epoch>> 10 and itteration 1 Loss>> 0.07093984633684158
 Epoch>> 10 and itteration 30 Loss>> 0.07712965222696463
 Epoch>> 10 and itteration 59 Loss>> 0.07644497319045714
val_mode


400it [00:05, 79.01it/s]



F1 score (F-measure) or DSC: 0.9152280165215335
New best loss, saving...
 Epoch>> 11 and itteration 1 Loss>> 0.07130124419927597
 Epoch>> 11 and itteration 30 Loss>> 0.07121715831259887
 Epoch>> 11 and itteration 59 Loss>> 0.07280064083762088
 Epoch>> 12 and itteration 1 Loss>> 0.08941862732172012
 Epoch>> 12 and itteration 30 Loss>> 0.093370030199488
 Epoch>> 12 and itteration 59 Loss>> 0.08433971730834347
 Epoch>> 13 and itteration 1 Loss>> 0.13843557238578796
 Epoch>> 13 and itteration 30 Loss>> 0.07845850611726443
 Epoch>> 13 and itteration 59 Loss>> 0.07514385039270935
 Epoch>> 14 and itteration 1 Loss>> 0.0605207160115242
 Epoch>> 14 and itteration 30 Loss>> 0.06449461753169695
 Epoch>> 14 and itteration 59 Loss>> 0.07079538431460575
 Epoch>> 15 and itteration 1 Loss>> 0.041732873767614365
 Epoch>> 15 and itteration 30 Loss>> 0.06670719670752684
 Epoch>> 15 and itteration 59 Loss>> 0.06604485971442724
val_mode


400it [00:05, 67.82it/s]



F1 score (F-measure) or DSC: 0.9172819426683395
New best loss, saving...
 Epoch>> 16 and itteration 1 Loss>> 0.0721331313252449
 Epoch>> 16 and itteration 30 Loss>> 0.06559696520368258
 Epoch>> 16 and itteration 59 Loss>> 0.06454508890539913
 Epoch>> 17 and itteration 1 Loss>> 0.042822740972042084
 Epoch>> 17 and itteration 30 Loss>> 0.06008548500637213
 Epoch>> 17 and itteration 59 Loss>> 0.06137927040710288
 Epoch>> 18 and itteration 1 Loss>> 0.05977281555533409
 Epoch>> 18 and itteration 30 Loss>> 0.05741407809158166
 Epoch>> 18 and itteration 59 Loss>> 0.059474343591827454
 Epoch>> 19 and itteration 1 Loss>> 0.05831508710980415
 Epoch>> 19 and itteration 30 Loss>> 0.055315382033586505
 Epoch>> 19 and itteration 59 Loss>> 0.055440444814956795
 Epoch>> 20 and itteration 1 Loss>> 0.04754256457090378
