# Import Libraries

In [30]:
import os
import random
import time
import copy
import numpy as np
from tqdm import tqdm

from collections import defaultdict
import gc
from PIL import Image

# Sklearn
from sklearn.model_selection import KFold

# PyTorch 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

#model 
import segmentation_models_pytorch as smp

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

from colorama import Fore, Style
c_  = Fore.GREEN
sr_ = Style.RESET_ALL

#import wandb
import warnings
warnings.filterwarnings("ignore")

#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Configuration

In [31]:
class CFG:
  seed          = 101
  device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  
  img_size      = [512, 512]
  channels      = 3
  num_classes   = 1

  n_fold        = 4
  train_bs      = 16
  valid_bs      = 32
  n_accumulate  = 64//train_bs
  epochs        = 115

#Model Architecture
  model         = "UNet"
  encoder       = "efficientnet-b2"
  weights       = "imagenet"

#optimizer Hyperparameters Setting
  optimizer     = "Adam"               #'AdamW'
  lr            = 1e-3
  wd            = 1e-5

#Learning Rate Scheduler Setting
  lr_scheduler  = 'CosineAnnealingLR'  #1.'CosineAnnealingWarmRestarts' 2.'CosineAnnealingLR' 
  min_lr        = 1e-7
  
  #'CosineAnnealingLR' Hyperparameters
  T_max         = 1000
  
  #'CosineAnnealingWarmRestarts' Hyperparameters
  T_0           = 5  
  T_mult        = 2

# Set Path

In [32]:
TRAIN_IMG_DIR    = "../Datasets/Train_Images"
TRAIN_MASK_DIR   = "../Datasets/Train_masks"
IMPORT_MODEL_DIR = "../result/Train1_models"
SAVE_MODEL_DIR   = "../result/Train2_models"

# Set Seed 

In [33]:
def set_seed(seed = None):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
set_seed(CFG.seed)

> SEEDING DONE


# Dataset

In [34]:
class STASDataset(Dataset):
    def __init__(self, IMG_DIR, MASK_DIR, dataIndex, transform=None):
        self.image_dir = IMG_DIR
        self.mask_dir  = MASK_DIR
        self.dataIndex = dataIndex
        self.transform = transform

    def __len__(self):
        return len(self.dataIndex)

    def __getitem__(self, index):
        img_name = os.listdir(self.image_dir)[self.dataIndex[index]]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name.replace(".jpg", ".png")) 
        image = np.array(Image.open(img_path))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0
        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        mask = np.expand_dims(mask, axis=0)
        return image, mask

# Pre-processing

In [35]:
transforms = {
    "train": A.Compose([
        A.Resize(*CFG.img_size),
        
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
        A.Flip(p=0.5),
        A.GaussNoise(p=0.2),
        A.OneOf([
            A.MotionBlur(p=0.5),
            A.MedianBlur(blur_limit=3, p=0.5),
            A.Blur(blur_limit=3, p=0.5),
        ], p=0.5),
        A.HueSaturationValue(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.CoarseDropout(max_holes=10, max_height=CFG.img_size[0]//30, max_width=CFG.img_size[1]//30,
                    min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
        ToTensorV2()],p=1.0),
    
    "valid": A.Compose([
        A.Resize(*CFG.img_size),
        ToTensorV2()], p=1.0)
}

# Model

In [36]:
def build_model():
    model = smp.Unet(
        encoder_name     = CFG.encoder,      
        encoder_weights  = CFG.weights,    
        in_channels      = CFG.channels,                  
        classes          = CFG.num_classes,                      
        activation       = None,
    )
    model.to(CFG.device)
    return model
def load_model(path):
    model = build_model()
    model.load_state_dict(torch.load(path))
    print("Load model from: ",path)
    model.to(CFG.device)
    return model

# Loss Function

In [37]:
JaccardLoss = smp.losses.JaccardLoss(mode='binary')
Jaccard     = smp.losses.JaccardLoss(mode='binary', from_logits=False)
Dice        = smp.losses.DiceLoss(mode='binary', from_logits=False)
#BCELoss     = smp.losses.SoftBCEWithLogitsLoss()

def criterion(y_pred, y_true):
    return JaccardLoss(y_pred, y_true)

# Optimizer

In [38]:
def select_optimizer(model):
    if   CFG.optimizer  == 'Adam':
       optimizer = optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
    
    elif CFG.optimizer  == 'AdamW':
       optimizer = optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
    else:
       print("Optimizer Error")
    return optimizer

# LR_Scheduler

In [39]:
def LR_scheduler(optimizer):
    if CFG.lr_scheduler   == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, 
                                                   T_max=CFG.T_max, 
                                                   eta_min=CFG.min_lr)
    
    elif CFG.lr_scheduler == 'CosineAnnealingWarmRestarts':
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                             T_0    =CFG.T_0,
                                                             T_mult =CFG.T_mult, 
                                                             eta_min=CFG.min_lr)

    elif CFG.lr_scheduler == None:
        return None
    
    return scheduler

# Train Function

In [40]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    model.train()
    scaler = amp.GradScaler()
    
    dataset_size = 0
    running_loss = 0.0
    iters = len(dataloader)
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Train ')
    for step, (images, masks) in pbar:         
        images = images.to(device, dtype=torch.float)
        masks  = masks.to(device, dtype=torch.float)
        
        batch_size = images.size(0)
        
        with amp.autocast(enabled=True):
            y_pred = model(images)
            loss   = criterion(y_pred, masks)
            loss   = loss / CFG.n_accumulate
            
        scaler.scale(loss).backward()
    
        if (step + 1) % CFG.n_accumulate == 0:
            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad()
            
            if scheduler is not None:
                scheduler.step()
                
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        
        pbar.set_postfix(train_loss=f'{epoch_loss:0.4f}',
                        lr=optimizer.param_groups[0]['lr'])
    torch.cuda.empty_cache()
    gc.collect()
    
    return epoch_loss

# Validation Function

In [41]:
@torch.no_grad()
def valid_one_epoch(model, dataloader, device, epoch):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    
    TARGETS = []
    PREDS   = []
    
    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid ')
    for step, (images, masks) in pbar:        
        images  = images.to(device, dtype=torch.float)
        masks   = masks.to(device, dtype=torch.float)
        
        batch_size = images.size(0)
        
        y_pred  = model(images)
        loss    = criterion(y_pred, masks)
        
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
            
        PREDS.append(nn.Sigmoid()(y_pred))
        TARGETS.append(masks)

        pbar.set_postfix(valid_loss=f'{epoch_loss:0.4f}',
                        lr=optimizer.param_groups[0]['lr'])
    
    TARGETS = torch.cat(TARGETS,dim=0).to(torch.float32)
    PREDS   = (torch.cat(PREDS, dim=0)>0.5).to(torch.float32)
    val_dice    = 1. - Dice(TARGETS, PREDS).cpu().detach().numpy()
    val_jaccard = 1. - Jaccard(TARGETS, PREDS).cpu().detach().numpy()
    val_scores  = [val_dice, val_jaccard]  
    
    torch.cuda.empty_cache()
    gc.collect()
    
    return epoch_loss, val_scores

# Trainer Function

In [42]:
def trainer(model, optimizer, scheduler, device, num_epochs):
    # To automatically log gradients
    #wandb.watch(model, log_freq=100)
    if torch.cuda.is_available():
        print("GPU: {}\n".format(torch.cuda.get_device_name()))
    
    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_dice  = -np.inf
    best_epoch = -1
    history = defaultdict(list)
    
    for epoch in range(1, num_epochs + 1): 
        gc.collect()
        print(f'Epoch {epoch}/{num_epochs}', end='')
        train_loss = train_one_epoch(model, optimizer, scheduler, 
                                           dataloader=train_loader, 
                                           device=CFG.device, epoch=epoch)
        
        val_loss, val_scores = valid_one_epoch(model, valid_loader, 
                                                        device=CFG.device, 
                                                        epoch=epoch)
        
        val_dice, val_jaccard = val_scores
            
        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)
        history['Valid Dice'].append(val_dice)
        history['Valid Jaccard'].append(val_jaccard)
        
        #wandb.log({"Train Loss": train_loss, 
        #           "Valid Loss": val_loss,
        #           "Valid Dice": val_dice,
        #           "Valid Jaccard": val_jaccard,
        #           "LR": scheduler.get_last_lr()[0]})
        
        print(f'Valid Dice: {val_dice:0.4f} | Valid Jaccard: {val_jaccard:0.4f}')
        
        if val_dice >= best_dice:
            print(f"{c_}Valid Dice Improved ({best_dice:0.4f} --> {val_dice:0.4f})")
            best_dice    = val_dice
            best_jaccard = val_jaccard
            best_epoch   = epoch
            #run.summary["Best Dice"]    = best_dice
            #run.summary["Best Jaccard"] = best_jaccard
            #run.summary["Best Epoch"]   = best_epoch
            #best_model_wts = copy.deepcopy(model.state_dict())
            #setting save model path
            SAVE_MODEL_PATH = f"{SAVE_MODEL_DIR}/best_retrain_Fold{fold:02d}.bin"
            torch.save(model.state_dict(), SAVE_MODEL_PATH)
            #wandb.save(PATH)
            print(f"Model Saved{sr_}")
        print()
    
    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Score: {:.4f}".format(best_dice))
    
    #model.load_state_dict(best_model_wts)
    
    #return model, history

# Start Training

In [43]:
dataset_len = np.array(range(0, 1053))
kfold = KFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
for fold, (train_ids, valid_ids) in enumerate(kfold.split(dataset_len)):
    print("***************************************")
    print(f'*************** Fold: {fold} ***************')
    print("***************************************")
    
    # Define dataset for training and validation data in this fold
    train_ds = STASDataset(TRAIN_IMG_DIR,TRAIN_MASK_DIR,train_ids,transform= transforms["train"])
    valid_ds = STASDataset(TRAIN_IMG_DIR,TRAIN_MASK_DIR,valid_ids,transform= transforms["valid"])
    
    train_loader = DataLoader(train_ds, 
                      batch_size=CFG.train_bs, shuffle=True,pin_memory=True)#, num_workers=2)
    valid_loader = DataLoader(valid_ds,
                      batch_size=CFG.valid_bs, shuffle=False, pin_memory=True)#, num_workers=2)

    model     = load_model(f"{IMPORT_MODEL_DIR}/best_Fold{fold:02d}.bin")
    optimizer = select_optimizer(model)
    scheduler = LR_scheduler(optimizer)
    trainer(model, optimizer, scheduler,
                        device     = CFG.device,
                        num_epochs = CFG.epochs)

***************************************
*************** Fold: 0 ***************
***************************************
Load model from:  C:/Ting/STAS/result/Train1_models/best_Fold00.bin
GPU: NVIDIA GeForce RTX 2080 Ti

Epoch 1/115

Train : 100%|██████████| 50/50 [00:39<00:00,  1.28it/s, lr=0.001, train_loss=0.0729]
Valid : 100%|██████████| 9/9 [00:07<00:00,  1.13it/s, lr=0.001, valid_loss=0.2486]


Valid Dice: 0.8531 | Valid Jaccard: 0.7438
[32mValid Dice Improved (-inf --> 0.8531)
Model Saved[0m

Epoch 2/115

Train :  18%|█▊        | 9/50 [00:07<00:33,  1.21it/s, lr=0.001, train_loss=0.0745]


KeyboardInterrupt: 