In [None]:
import warnings
warnings.filterwarnings('ignore')

import wandb 

import numpy as np
import pandas as pd

import random
import os, shutil, gc, yaml

from tqdm import tqdm
# from glob import glob
import glob
from collections import defaultdict
from attrdict import AttrDict
import time
from copy import deepcopy
import joblib
from joblib import Parallel, delayed

from IPython import display as ipd
from colorama import Fore, Back, Style
c_ = Fore.GREEN
sr_ = Style.RESET_ALL

import cv2
import albumentations as A
import matplotlib.pyplot as plt

from sklearn.model_selection import StratifiedGroupKFold

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

import timm

import rasterio

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
## custom
from utils import *
from data import *
from model import *
from scheduler import *
from loss import *

In [None]:
#################  config  #################
# Select the experiment config by name. The trailing comments presumably note
# which GPU each config was run on; uncomment the other line to switch.
config_name = 'unet-efficient-b3.yaml'  # cuda:0
# config_name = 'unext-resnext101_32x4d.yaml'  # cuda:1

config_path = os.path.join('./configs', config_name)
with open(config_path) as f:
    CFG = AttrDict(yaml.load(f, Loader=yaml.FullLoader))
print(CFG)

@torch.no_grad()
def valid_one_epoch(model, dataloader, device, epoch):
    """Run one validation pass and return the loss and segmentation metrics.

    Gradients are disabled by ``torch.no_grad`` and the model is put in eval
    mode.

    Args:
        model: network whose raw outputs are treated as logits (a sigmoid is
            applied before scoring).
        dataloader: sized iterable yielding ``(images, masks)`` batches.
        device: device the batches are moved to before the forward pass.
        epoch: epoch index; accepted for symmetry with the caller and not used
            in the computation.

    Returns:
        ``(epoch_loss, val_scores, class_scores)``:
        ``epoch_loss`` is the batch-size-weighted mean of ``criterion``;
        ``val_scores`` is ``[mean dice, mean IoU]`` averaged over batches;
        ``class_scores`` is the batch-averaged dice of each output channel.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Note:
        Relies on module-level ``criterion``, ``dice_coef`` and ``iou_coef``.
        It reads the global ``optimizer`` only to display the learning rate.
    """
    model.eval()

    dataset_size = 0
    running_loss = 0.0

    val_scores = []
    class_scores = []

    pbar = tqdm(dataloader, total=len(dataloader), desc='Valid ')
    for images, masks in pbar:
        images = images.to(device, dtype=torch.float)
        masks = masks.to(device, dtype=torch.float)
        batch_size = images.size(0)

        y_pred = model(images)
        loss = criterion(y_pred, masks)

        # Weight by batch size so a short final batch does not skew the mean.
        running_loss += loss.item() * batch_size
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        # Metrics are computed on probabilities, not on raw logits.
        y_pred = torch.sigmoid(y_pred)

        # Dice for every output channel. The c:c + 1 slice keeps a singleton
        # channel axis, matching the former [:, c].unsqueeze(1) form.
        class_scores.append([
            dice_coef(masks[:, c:c + 1], y_pred[:, c:c + 1]).detach().cpu().numpy()
            for c in range(y_pred.size(1))
        ])

        val_dice = dice_coef(masks, y_pred).detach().cpu().numpy()
        val_jaccard = iou_coef(masks, y_pred).detach().cpu().numpy()
        val_scores.append([val_dice, val_jaccard])

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(valid_loss=f'{epoch_loss:0.4f}',
                         lr=f'{current_lr:0.5f}',
                         gpu_memory=f'{mem:0.2f} GB')

    # Without this guard an empty loader leaves epoch_loss unbound.
    if not val_scores:
        raise ValueError('valid_one_epoch: dataloader yielded no batches')

    val_scores = np.mean(val_scores, axis=0)
    class_scores = np.mean(class_scores, axis=0)
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss, val_scores, class_scores

def run_training(model, optimizer, scheduler, device, num_epochs):
    """Validate ``model`` once, checkpoint it, and return it with a metric history.

    The body takes no optimisation step. It runs a single validation pass over
    the module-level ``valid_loader`` and logs the best scores to the
    module-level W&B ``run``. It then writes checkpoints under
    ``./pths/{CFG.comment}/``, named using the module-level ``fold``:
    ``best_epoch-{fold}-dice{dice}.bin`` (unless Dice is NaN) and
    ``last_epoch-{fold}.bin`` (always).

    Args:
        model: network to evaluate; the best weights are loaded back before
            returning.
        optimizer: unused here; kept for interface compatibility. Note that
            ``valid_one_epoch`` reads the *global* ``optimizer`` for display.
        scheduler: unused here; kept for interface compatibility.
        device: device passed to ``valid_one_epoch``.
        num_epochs: unused here; kept for interface compatibility.

    Returns:
        ``(model, history)`` where ``history`` maps ``'Valid Loss'``,
        ``'Valid Dice'`` and ``'Valid Jaccard'`` to lists of per-pass values.
    """
    # Log gradients/parameters to W&B every 100 steps.
    wandb.watch(model, log_freq=100)

    if torch.cuda.is_available():
        print("cuda: {}\n".format(torch.cuda.get_device_name()))

    # Per-call state. best_* are reassigned below, so Python treats them as
    # locals; they must be initialised here or the comparison raises
    # UnboundLocalError.
    history = defaultdict(list)
    best_dice = -np.inf
    best_jaccard = -np.inf
    best_epoch = -1
    best_model_wts = deepcopy(model.state_dict())
    epoch = 1  # only a single evaluation pass is performed

    valid_loss, valid_scores, valid_class_scores = valid_one_epoch(
        model, valid_loader, device=device, epoch=epoch)
    valid_dice, valid_jaccard = valid_scores

    print("valid class score", valid_class_scores, np.mean(valid_class_scores, axis=0))

    history['Valid Loss'].append(valid_loss)
    history['Valid Dice'].append(valid_dice)
    history['Valid Jaccard'].append(valid_jaccard)

    print(f'Valid Dice: {valid_dice:0.4f} | Valid Jaccard: {valid_jaccard:0.4f}')

    # Snapshot the best weights; with best_dice starting at -inf this is taken
    # unless valid_dice is NaN.
    if valid_dice >= best_dice:
        print(f"{c_}Valid Score Improved ({best_dice:0.4f} ---> {valid_dice:0.4f})")
        best_dice = valid_dice
        best_jaccard = valid_jaccard
        best_epoch = epoch
        run.summary["Best Dice"] = best_dice
        run.summary["Best Jaccard"] = best_jaccard
        run.summary["Best Epoch"] = best_epoch
        best_model_wts = deepcopy(model.state_dict())
        best_path = f"./pths/{CFG.comment}/best_epoch-{fold:02d}-dice{best_dice:.4f}.bin"
        torch.save(model.state_dict(), best_path)
        print(f"Model Saved{sr_}")

    last_path = f"./pths/{CFG.comment}/last_epoch-{fold:02d}.bin"
    torch.save(model.state_dict(), last_path)

    print()
    print()

    # Selection is by Dice, so report both scores instead of Jaccard alone.
    print(f"Best Dice: {best_dice:.4f} | Best Jaccard: {best_jaccard:.4f}")

    # Restore the best weights before handing the model back.
    model.load_state_dict(best_model_wts)

    return model, history


os.makedirs(f"./pths/{CFG.comment}", exist_ok=True)

set_seed(CFG.seed)

# ---- data setting ----
# Every file matched by the CFG.data_root glob is an image. Its mask path is the
# same path with 'image' replaced by 'mask', and its id is the file name
# without '.npy'.
path_df = pd.DataFrame(glob.glob(CFG.data_root), columns=['image_path'])
path_df['mask_path'] = path_df.image_path.str.replace('image', 'mask', regex=False)
path_df['id'] = path_df.image_path.map(lambda x: os.path.basename(x).replace('.npy', ''))

df = pd.read_csv('./train_.csv')
df['segmentation'] = df.segmentation.fillna('')
df['rle_len'] = df.segmentation.map(len)

# One row per id: the list of its RLE strings and their summed string length.
df2 = df.groupby('id')['segmentation'].agg(list).to_frame().reset_index()
df2 = df2.merge(df.groupby('id')['rle_len'].agg('sum').to_frame().reset_index(), on='id')

# Keep one metadata row per id and attach the aggregated RLE columns.
df = df.drop(columns=['segmentation', 'class', 'rle_len'])
df = df.groupby(['id']).head(1).reset_index(drop=True)
df = df.merge(df2, on=['id'])
df['empty'] = (df.rle_len == 0)  # every RLE string for this id is empty

# Replace the CSV's path columns with the ones discovered on disk. This is an
# inner join, so ids without a matching file are dropped.
df = df.drop(columns=['image_path', 'mask_path'])
df = df.merge(path_df, on=['id'])

# Drop ids containing either substring (presumably known-faulty cases).
fault1 = 'case7_day0'
fault2 = 'case81_day30'
df = df[~df['id'].str.contains(fault1, regex=False)
        & ~df['id'].str.contains(fault2, regex=False)].reset_index(drop=True)

# Assign a 'fold' column. Folds are stratified on 'empty' and grouped by
# 'case', so all rows of one case share a fold. val_idx is positional, which
# matches the label index after reset_index above.
skf = StratifiedGroupKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['empty'], groups=df["case"])):
    df.loc[val_idx, 'fold'] = fold

    
# Optionally keep only rows whose masks are not all empty. Folds were assigned
# above on the unfiltered frame, so the remaining rows keep their fold labels.
if CFG.positive_only:
    df = df[~df['empty']]

# env setting: model, optimizer and scheduler are built per fold inside the
# loop below, so no module-level instances are created here.

# Train/validate each configured fold in its own W&B run.
for fold in CFG.folds:
    print('#' * 15)
    print(f'### Fold: {fold}')
    print('#' * 15)
    # attrdict.AttrDict is a dict subclass: the config entries live in the
    # mapping itself. vars(CFG) would only expose its private attributes, so
    # iterate the mapping to send the real hyper-parameters to W&B.
    run = wandb.init(project='uw-segmentation',
                     config={k: v for k, v in CFG.items() if '__' not in k},
                     anonymous='must',
                     name=f"fold-{fold}|dim-{CFG.img_size[0]}x{CFG.img_size[1]}|model-{CFG.model_name}",
                     group=CFG.comment,
                     )
    # Close the W&B run even if this fold fails, so the next fold starts cleanly.
    try:
        train_loader, valid_loader = prepare_loaders(fold, df, get_train_transforms(CFG), get_valid_transforms(CFG), CFG, debug=CFG.debug)
        model = build_model(CFG)
        optimizer = optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
        scheduler = fetch_scheduler(optimizer, CFG)
        model, history = run_training(model, optimizer, scheduler,
                                      device=CFG.device,
                                      num_epochs=CFG.epochs)
    finally:
        run.finish()