In [None]:
import copy
import random
import torch.nn as nn
import datetime
import albumentations as A
from skimage.morphology import skeletonize
import seaborn as sns
import torch
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data import Dataset,DataLoader
from pathlib import Path
import json
from tqdm.notebook import tqdm
import matplotlib.patches as mpatches
import zarr
from albumentations.core.transforms_interface import ImageOnlyTransform
from albumentations.pytorch import ToTensorV2
import numpy as np
import cv2
from torch.utils.data import DataLoader
import shutil
import matplotlib.pyplot as plt
import os
import torch.nn.functional as F

In [None]:

import torch
from albumentations.pytorch import ToTensorV2
# In[1]:
import numpy as np
from torchvision.transforms import v2
import json
from tqdm.notebook import tqdm
#!/usr/bin/env python
from torch.utils.data import DataLoader
import albumentations as A
# coding: utf-8

# dataset

In [None]:
class UnetDataset(Dataset):
    """Training dataset that yields an image tensor plus a pyramid of label masks.

    Each item is ``(image, masks)``:
      * ``image`` - the augmented image converted by ToTensorV2, cast to float32.
      * ``masks[0]`` - the augmented full-resolution mask (int64 tensor).
      * ``masks[i]`` for i in 1..out_counts-1 - that mask resized with nearest-neighbour
        interpolation to ``base_size // 2**i`` on both sides (int64 tensor).

    `data` is an indexable of ``(image, mask)`` pairs of 2-D arrays; `transform` is an
    albumentations-style callable taking ``image=`` / ``mask=`` keywords.
    """
    def __init__(self,transform,data,base_size=512,out_counts=7):
        super().__init__()
        self.data = data
        self.transform = transform
        self.to_tensor = ToTensorV2()
        self.resizers = []
        for level in range(1, out_counts):
            side = base_size // (2 ** level)
            # INTER_NEAREST for both image and mask: the "image" fed to these resizers is a label map
            self.resizers.append(
                A.Resize(
                    side,
                    side,
                    interpolation=cv2.INTER_NEAREST,
                    mask_interpolation=cv2.INTER_NEAREST,
                )
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self,index):
        raw_img, raw_mask = self.data[index]
        # albumentations expects HWC, so add a trailing channel axis to both arrays
        augmented = self.transform(image=np.expand_dims(raw_img, axis=-1), mask=raw_mask[..., None])
        full_mask = augmented['mask'].squeeze(-1)

        pyramid = [full_mask]
        pyramid.extend(resize(image=full_mask)["image"] for resize in self.resizers)

        image_tensor = self.to_tensor(image=augmented['image'])["image"]
        return image_tensor.float(), [torch.tensor(level_mask).long() for level_mask in pyramid]

class ValidUnetDataset(Dataset):
    """Validation dataset: yields ``(image, [mask])`` with only the full-resolution mask.

    NOTE(review): `transform` must already return torch tensors (e.g. end with
    ToTensorV2), because ``.float()`` / ``.long()`` are called directly on its outputs.
    """
    def __init__(self,transform,data):
        super().__init__()
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self,index):
        raw_img, raw_mask = self.data[index]
        # add a trailing channel axis so the transform sees HWC arrays
        out = self.transform(image=np.expand_dims(raw_img, axis=-1), mask=raw_mask[..., None])
        label = out['mask'].squeeze(-1)
        # the mask is wrapped in a list to mirror UnetDataset's multi-scale mask output
        return out['image'].float(), [label.long()]

class UnetExampleDataset(Dataset):
    """Visualisation dataset returning an augmented sample next to its un-augmented version.

    Each item is ``(aug_image, aug_mask, raw_image, raw_mask)``. The raw pair comes from
    `base_transform`, which defaults to ``A.Compose([ToTensorV2()])`` (tensor conversion only).
    The image gets a trailing channel axis before both transforms; the mask is passed as-is.
    """
    def __init__(self,transform,data,base_transform=None):
        super().__init__()
        self.data = data
        self.transform = transform
        self.base_transform = A.Compose([ToTensorV2()]) if base_transform is None else base_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self,index):
        raw_img, raw_mask = self.data[index]
        hwc_img = np.expand_dims(raw_img, axis=-1)

        augmented = self.transform(image=hwc_img, mask=raw_mask)
        untouched = self.base_transform(image=hwc_img, mask=raw_mask)

        return (
            augmented['image'].float(),
            augmented['mask'],
            untouched['image'].float(),
            untouched['mask'],
        )

if __name__ == "__main__":
    # Smoke test: push two random 512x512 samples through UnetDataset and print the image
    # batch shape followed by the shape of every mask level.
    # NOTE(review): the random masks are floats in [0, 1), so .long() turns them into all-zero
    # label maps - this only checks shapes, not label handling.
    train_transforms = A.Compose(
        [
            A.GaussianBlur(sigma_limit=[0.1, 0.5], p=0.5),
            A.RandomBrightnessContrast(
                brightness_limit=0.1,
                contrast_limit=0.15,
                brightness_by_max=True,
                p=0.3,
            ),
            A.RandomGamma(gamma_limit=(90, 120), p=0.3),
            A.Rotate(limit=15, p=0.3, fill_mask=0),
            A.HorizontalFlip(p=0.3),
            A.VerticalFlip(p=0.3),
        ]
    )
    fake_samples = [[np.random.rand(512, 512), np.random.rand(512, 512)] for _ in range(2)]
    ds = UnetDataset(transform=train_transforms, data=fake_samples)
    dl = DataLoader(ds, batch_size=2)
    for img, masks in dl:
        print(img.shape)
        for mask in masks:
            print(mask.shape)

    

# helpers

In [None]:
def read_images(base_path, part,preprocessor,max_workers=None):
    """Load every image of a split together with its zarr label, using a thread pool.

    Files are taken from ``<base_path>/images/<part>`` (regular files only, sorted by
    name). Each is read as grayscale with OpenCV and passed through `preprocessor`
    when it is truthy. Its label is loaded from ``<base_path>/labels/<part>/<stem>.zarr``.

    Args:
        base_path: dataset root (str or Path).
        part: split sub-folder name, e.g. "train".
        preprocessor: optional callable applied to each uint8 grayscale image.
        max_workers: thread count; defaults to ``min(32, 4 * os.cpu_count())``.

    Returns:
        list of ``[image, label]`` pairs in sorted file-name order.

    Raises:
        FileNotFoundError: if an image cannot be decoded or its label cannot be loaded.
    """
    base_path = Path(base_path)
    images_base = base_path / "images" / part
    labels_base = base_path / "labels" / part

    image_names = sorted(entry.name for entry in os.scandir(images_base) if entry.is_file())
    if not preprocessor:
        print("NOTE : preprocessor is not defined . no preprocessing will be used !")

    def _read_one(fname):
        img_path = images_base / fname
        label_path = labels_base / f"{Path(fname).stem}.zarr"

        img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports unreadable/corrupt files by returning None instead of raising;
            # fail here rather than letting None reach the preprocessor or the dataset.
            raise FileNotFoundError(f"could not read image: {img_path}")
        if preprocessor:
            img = preprocessor(img)

        label = zarr.load(str(label_path))
        if label is None:
            # guard in case zarr.load yields None for a path that holds no array
            raise FileNotFoundError(f"could not load label: {label_path}")
        return img, label

    if max_workers is None:
        cpu = os.cpu_count() or 4
        max_workers = min(32, cpu * 4)

    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        # ex.map preserves input order, so results line up with image_names
        for img, label in tqdm(ex.map(_read_one, image_names), total=len(image_names)):
            results.append([img, label])

    return results

def to_device(img,gt_mask,device,binary_mode):
    """Move an image batch and its ground-truth mask to `device`.

    The mask is cast to int64 (``.long()``) before the move.

    Args:
        img: image tensor.
        gt_mask: ground-truth mask tensor.
        device: target device (anything accepted by ``Tensor.to``).
        binary_mode: kept for backward compatibility only. A truthy value used to
            raise UnboundLocalError (it referenced a never-defined ``gt_label``), and
            the falsy branch had no effect on the result, so it is now ignored.

    Returns:
        ``(img, gt_mask)`` on `device`.
    """
    return img.to(device), gt_mask.long().to(device)


def crop_dims(target , current):
    """Center-crop the spatial dims of `current` (N, C, H, W) to the H and W of `target`.

    Any odd leftover pixel is removed from the bottom/right. Assumes `current` is at
    least as large as `target` in both dimensions.

    Uses explicit ``start:start + size`` slices: the previous ``top:-down`` form became
    ``x:-0`` (an empty slice) whenever one of the two dimensions already matched.
    """
    diff_h = current.shape[2] - target.shape[2]
    diff_w = current.shape[3] - target.shape[3]
    top = diff_h // 2
    left = diff_w // 2
    return current[:, :, top:top + target.shape[2], left:left + target.shape[3]]
def padd_dims(target , current):
    """Zero-pad `current` (N, C, H, W) on the bottom and right so its H and W match `target`.

    A negative difference makes F.pad crop instead of pad.
    """
    extra_h = target.shape[2] - current.shape[2]
    extra_w = target.shape[3] - current.shape[3]
    # F.pad orders the last two dims as (left, right, top, bottom)
    return F.pad(current, pad=(0, extra_w, 0, extra_h), mode='constant', value=0)

@torch.no_grad()
def TP_TN_FP_FN(preds,gt,process_preds=True,return_TN=False):
    """Per-class true/false positive/negative pixel counts, summed over batch and space.

    Args:
        preds: (N, C, H, W). With ``process_preds=True`` these are scores and are turned
            into a one-hot map via argmax over C; otherwise they are used as-is.
        gt: (N, H, W) integer labels in ``[0, C)``.
        return_TN: compute true negatives; when False the TN slot is the integer 0.

    Returns:
        ``(TP, TN, FP, FN)``: float tensors of shape (C,) on the CPU (TN may be 0).
    """
    num_classes = preds.shape[1]
    if process_preds:
        hard = F.one_hot(preds.argmax(dim=1), num_classes=num_classes)
        pred_onehot = hard.permute(0, 3, 1, 2).float()
    else:
        pred_onehot = preds

    gt_onehot = F.one_hot(gt, num_classes=num_classes).permute(0, 3, 1, 2).float()
    reduce_dims = (0, 2, 3)
    gt_neg = 1 - gt_onehot
    pred_neg = 1 - pred_onehot

    tp = (gt_onehot * pred_onehot).sum(dim=reduce_dims).cpu()
    fp = (gt_neg * pred_onehot).sum(dim=reduce_dims).cpu()
    fn = (gt_onehot * pred_neg).sum(dim=reduce_dims).cpu()
    tn = (gt_neg * pred_neg).sum(dim=reduce_dims).cpu() if return_TN else 0
    return tp, tn, fp, fn

def draw_mask(image,mask,args=None,colors=None):
    """Paint the non-zero labels of `mask` onto a copy of an (H, W, 3) image.

    Args:
        image: (H, W, 3) array; it is copied, never modified.
        mask: (H, W) integer label map (numpy array or CPU torch tensor). Label 0 is
            left untouched.
        args: unused; accepted for backward compatibility with existing callers.
        colors: optional (N, 3) palette; label ``c`` is painted with ``colors[c - 1]``.
            When None every non-zero label is painted pure green (0, 255, 0).

    Returns:
        uint8 copy of `image` with the labels painted in.
    """
    img = image.copy()
    labels = np.asarray(mask)
    foreground = labels != 0
    # Vectorised replacement of the former per-pixel Python loop (O(H*W) interpreter steps).
    if colors is not None:
        img[foreground] = np.asarray(colors)[labels[foreground] - 1]
    else:
        img[foreground] = (0, 255, 0)
    return img.astype(np.uint8)

def plot_some_images(data,transforms,image_counts=36,fig_shape=(6,6),base_transforms=None):
    """Show augmented samples (with mask overlay) next to their un-augmented originals.

    Draws one example per pair of subplot slots (18 examples for the default 36) on a
    ``fig_shape`` = (rows, cols) grid. The left panel is the raw image in grayscale; the
    right panel is the min-max-normalised augmented image with its mask painted via draw_mask.

    NOTE(review): batches are drawn with next(); a dataset yielding fewer than
    ceil(image_counts / 2) batches of 2 will raise StopIteration.
    """
    example_ds = UnetExampleDataset(transform=transforms, data=data, base_transform=base_transforms)
    batches = iter(
        DataLoader(
            example_ds,
            batch_size=2,
            num_workers=4,
            pin_memory=False,
            shuffle=True,
        )
    )

    rows, cols = fig_shape
    plt.figure(figsize=(rows * 5, cols * 5))
    for slot in range(1, image_counts + 1, 2):
        aug_imgs, aug_masks, raw_imgs, _raw_masks = next(batches)
        aug = aug_imgs[0].numpy()
        raw = raw_imgs[0].numpy()
        if aug.shape[0] == 1:
            # drop the singleton channel axis so imshow gets a 2-D image
            aug, raw = aug[0], raw[0]

        scaled = (aug - aug.min()) / (aug.max() - aug.min() + 1e-8)
        overlay = draw_mask(np.repeat(scaled[..., None], 3, axis=2) * 255, aug_masks[0])

        plt.subplot(rows, cols, slot)
        plt.imshow(raw, cmap="gray")
        plt.title("Old Image")

        plt.subplot(rows, cols, slot + 1)
        plt.imshow(overlay)
        plt.title("New Image")

def pre_hard_skeletonize(base_path,output_path):
    """Write a binary skeleton PNG for every zarr label of the train/val/test splits.

    Each file in ``<base_path>/labels/<part>`` is loaded with zarr. Any non-zero label is
    treated as foreground, skeletonised with skimage, and written as a 0/255 uint8 PNG to
    ``<output_path>/skels/<part>/<stem>.png``.
    """
    skel_root = os.path.join(output_path, "skels")
    os.makedirs(skel_root, exist_ok=True)
    for part in ("train", "val", "test"):
        labels_dir = os.path.join(base_path, "labels", part)
        os.makedirs(os.path.join(skel_root, part), exist_ok=True)

        for mask_name in tqdm(os.listdir(labels_dir)):
            label = zarr.load(str(os.path.join(labels_dir, mask_name)))
            foreground = (label != 0).astype(np.uint8)
            skel = skeletonize(foreground).astype(np.uint8) * 255
            cv2.imwrite(os.path.join(skel_root, part, f"{Path(mask_name).stem}.png"), skel)
@torch.no_grad()
def pre_soft_skeletonize(base_path,output_path,batch_size=10,k=25,device=None):
    """Write soft-skeleton PNGs for every zarr label of the train/val/test splits.

    Masks are binarised (non-zero -> 1.0), stacked into batches of `batch_size`, run
    through soft_skeletonize with `k` iterations on `device`, and written as uint8 0/255
    images to ``<output_path>/skels_soft/<part>/<stem>.png``. All masks in a batch must
    share the same spatial size (they are concatenated).

    Args:
        device: torch device for the skeletonisation. Defaults to "cuda" when available,
            else "cpu" (previously "cuda" was hard-coded and failed on CPU-only hosts).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    skel_root = os.path.join(output_path, "skels_soft")
    os.makedirs(skel_root, exist_ok=True)
    for part in ["train","val","test"]:
        mask_base_path = os.path.join(base_path,"labels",part)
        os.makedirs(os.path.join(skel_root, part), exist_ok=True)

        mask_list = os.listdir(mask_base_path)
        mask_buffer = []
        name_buffer = []
        for idx, mask_name in enumerate(tqdm(mask_list)):
            mask = zarr.load(str(os.path.join(mask_base_path, mask_name)))
            mask = (mask != 0).astype(np.float32)
            # (H, W) -> (1, 1, H, W) so buffered masks concatenate into an N,C,H,W batch
            mask_buffer.append(torch.from_numpy(mask).unsqueeze(0).unsqueeze(0))
            name_buffer.append(Path(mask_name).stem)

            # flush every `batch_size` masks and once more at the end of the split
            if (idx + 1) % batch_size == 0 or idx == len(mask_list) - 1:
                batch = torch.cat(mask_buffer, dim=0).to(device)
                skels = soft_skeletonize(batch, k=k).cpu().numpy()
                for b, o_name in enumerate(name_buffer):
                    # astype(uint8) truncates; inputs are 0/1 here, so skeleton values are expected to be 0/1 too
                    skel = skels[b, 0].astype(np.uint8) * 255
                    cv2.imwrite(os.path.join(skel_root, part, f"{o_name}.png"), skel)
                mask_buffer = []
                name_buffer = []

@torch.no_grad()
def compute_confution_matrix(data_loader,model,class_maps,output_folder_path=None,draw_plot = True,class_count=26):
    """Row-normalised pixel confusion matrix of `model` over `data_loader`.

    Uses only the first model output (``pred_masks[0]``) and the first mask of each
    batch (``masks[0]``). Rows are true classes and columns are predicted classes. Each row is
    divided by its total, so rows with no pixels stay 0. Labels must lie in ``[0, class_count)``
    or the bincount reshape fails.

    When `draw_plot` is set, a seaborn heatmap is drawn; unnamed indices are labelled
    "background". It is saved to ``<output_folder_path>/conf_mat.png`` when a folder is given.

    Returns:
        (class_count, class_count) numpy array.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_bins = class_count * class_count
    counts_total = torch.zeros((class_count, class_count))
    model.eval()
    for batch_imgs, batch_masks in tqdm(data_loader):
        batch_imgs = batch_imgs.to(device)
        with torch.autocast(device_type=device, dtype=torch.float16):
            true_flat = batch_masks[0].to(device).view(-1)
            outputs = model(batch_imgs)

        pred_flat = outputs[0].argmax(dim=1).view(-1)
        # encode each (true, pred) pair as true*C + pred so one bincount fills the whole matrix
        pair_ids = (true_flat * class_count + pred_flat).cpu()
        counts_total += torch.bincount(pair_ids, minlength=n_bins).view(class_count, class_count)

    row_totals = counts_total.sum(dim=1, keepdim=True).clamp(min=1)
    conf_mat = (counts_total.float() / row_totals).numpy()

    if draw_plot:
        class_names = ["background"] * class_count
        for idx, label in class_maps.items():
            class_names[idx] = label
        plt.figure(figsize=(20,20))
        ax = sns.heatmap(
            conf_mat,
            annot=True,
            fmt=".2f",
            xticklabels=class_names,
            yticklabels=class_names,
            cmap="Blues",
        )
        ax.set_xlabel("Predicted class")
        ax.set_ylabel("True class")
        ax.set_title("Confusion Matrix")
        plt.tight_layout()
        if output_folder_path:
            plt.savefig(os.path.join(output_folder_path, "conf_mat.png"))
    return conf_mat

def erode(mask):
    """Soft (grey-level) erosion with a 3x3 cross-shaped neighbourhood.

    A min-pool is written as ``-max_pool(-x)``. The vertical (3x1) and horizontal (1x3)
    min-pools are combined with an element-wise minimum. max_pool2d pads with -inf, so
    image borders are not eroded by the padding. Output has the input's shape.
    """
    vertical_min = -F.max_pool2d(-mask, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
    horizontal_min = -F.max_pool2d(-mask, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
    return torch.minimum(horizontal_min, vertical_min)
def dilate(mask):
    """Soft (grey-level) dilation: each pixel becomes the max of its 3x3 neighbourhood; shape is preserved."""
    return F.max_pool2d(mask, kernel_size=3, stride=1, padding=1)
def soft_open(mask):
    """Soft morphological opening: `erode` followed by `dilate`."""
    eroded = erode(mask)
    return dilate(eroded)
def soft_skeletonize(I,k=25):
    """Iterative soft skeleton of a mask (appears to follow the clDice soft-skeleton recipe).

    Starts from the opening residue ``relu(I - open(I))``. Then `k` times: erode the mask,
    take the new opening residue, and merge it with a probabilistic union
    ``S <- S + (1 - S) * residue``. Intended for masks with values in [0, 1].
    """
    skeleton = F.relu(I - soft_open(I))
    current = I
    for _ in range(k):
        current = erode(current)
        residue = F.relu(current - soft_open(current))
        skeleton = skeleton + (1 - skeleton) * residue
    return skeleton

# logger

In [None]:
# Overlay palette for class labels 1..25, RGB uint8. Label c is drawn with colors[c - 1]
# by draw_mask; draw_examples builds its legend patches from the same rows.
colors = np.array(
    [
        [242, 24, 24],    # Red
        [242, 77, 24],    # Red-Orange
        [242, 129, 24],   # Orange
        [242, 181, 24],   # Yellow-Orange
        [24, 242, 216],   # Cyan
        [242, 234, 24],   # Yellow
        [146, 24, 242],   # Purple
        [199, 242, 24],   # Yellow-Green
        [146, 242, 24],   # Lime
        [94, 242, 24],    # Green
        [242, 24, 181],   # Fuchsia
        [42, 242, 24],    # Green (brighter)
        [94, 24, 242],    # Violet
        [24, 242, 59],    # Spring Green
        [242, 24, 129],   # Pink
        [24, 242, 111],   # Aquamarine
        [24, 242, 164],   # Turquoise
        [24, 164, 242],   # Azure
        [199, 24, 242],   # Magenta
        [24, 216, 242],   # Sky Blue
        [24, 111, 242],   # Blue
        [242, 24, 234],   # Hot Pink
        [24, 59, 242],    # Royal Blue
        [42, 24, 242],    # Indigo
        [242, 24, 77],    # Rose
    ],
    dtype=np.uint8,
)

@torch.no_grad()
def save_full_report(recorder,output_base_path,model,valid_loader,
                     args,class_map,name=None,notebook_name="nnUnetAttention.ipynb"):
    """Write a complete experiment report into a new timestamped folder.

    Creates ``<output_base_path>/<now>[ [name]]`` and fills it with:
      * recorder histories and args as JSON (save_memory),
      * loss / average-metric / per-class-metric plots,
      * the validation confusion matrix (conf_mat.png),
      * qualitative prediction examples,
      * a plain-text report (report.txt),
      * a copy of the notebook as notebook.ipynb.
    It then calls build_kaggle_project on the folder. That function is presumably defined
    in a later cell; confirm.

    Args:
        notebook_name: file in the current working directory that is copied into the
            report. Defaults to the previously hard-coded name.
    """
    # NOTE(review): str(datetime) contains ':' characters, which are invalid in Windows paths.
    save_folder_name = str(datetime.datetime.now())
    if name:
        save_folder_name += f" [{name}]"
    output_folder_path = os.path.join(output_base_path,save_folder_name)
    os.makedirs(output_folder_path,exist_ok=True)

    print("Saving Memory")
    save_memory(recorder,args,output_folder_path)

    print("Saving All Plots")
    draw_loss_plots(recorder , output_folder_path)
    draw_avg_metric_plots(recorder , output_folder_path)
    draw_all_metric_plots(recorder , output_folder_path)
    compute_confution_matrix(
        data_loader=valid_loader,
        model = model,
        class_maps = class_map,
        draw_plot = True,
        # +1 presumably for background (index 0), which class_map does not name
        class_count=len(class_map)+1,
        output_folder_path=output_folder_path
    )
    print("Saving Examples")
    draw_examples(model,valid_loader,args,class_map,output_folder_path)

    print("Saving Verbal Results")
    write_verbal_results(recorder,output_folder_path)

    print("Copying Notebook To Results")
    notebook_out_path = os.path.join(output_folder_path,"notebook.ipynb")
    shutil.copyfile(f"./{notebook_name}",notebook_out_path )
    print("building kaggle project")
    build_kaggle_project(output_folder_path)

def write_verbal_results(recorder,output_base_path,train_count_path="./data/train_count.json"):
    """Write a plain-text summary of the best epoch of every part to ``report.txt``.

    For each part in ``recorder.metric_avg_list``, the best epoch is the one with the
    highest class-averaged dice. The report lists, for that epoch:
      * the averaged dice / precision / recall,
      * every loss in ``recorder.losses_keys`` (on one line),
      * per-class dice / precision / recall with the class's training count.

    Args:
        recorder: HistoryRecorder-like object (losses_keys, history, metric_avg_list,
            metric_history, class_maps).
        output_base_path: folder that receives ``report.txt``.
        train_count_path: JSON file mapping class name -> training count. Defaults to
            the previously hard-coded location. Classes missing from it are reported as "N/A".
    """
    separator = "<=><=><=><=><=><=><=><=><=><=><=><=><=><=><=><=><=>\n"
    report_path = os.path.join(output_base_path,"report.txt")
    losses_keys = recorder.losses_keys
    with open(train_count_path,"r") as f:
        train_count = json.load(f)

    report = ""
    for part,data in recorder.metric_avg_list.items():
        report += f"======= > {part} verbal Report < =======\n"

        dice_list = data["dice"]
        precision_list = data["precision"]
        recall_list = data["recall"]
        if len(dice_list) == 0:
            # np.argmax raises on an empty sequence; note the missing data instead of aborting
            report += "no metrics recorded\n"
            report += separator
            continue

        best_idx = int(np.argmax(dice_list))
        report += (
            f"best epoch : [{best_idx+1}]\n"
            f"best dice : [{dice_list[best_idx]}] - best precision : [{precision_list[best_idx]}] - best recall : [{recall_list[best_idx]}] \n"
        )

        # all losses on one line; the trailing newline keeps the class lines from being
        # glued onto it (previously the line ended in " - " with no newline)
        loss_parts = [
            f"best {loss_name} : [{recorder.history[part][loss_name][best_idx]}]"
            for loss_name in losses_keys
        ]
        report += " - ".join(loss_parts) + "\n"

        for index , c in recorder.class_maps.items():
            dice = recorder.metric_history[part]["dice"][index][best_idx]
            precision = recorder.metric_history[part]["precision"][index][best_idx]
            recall = recorder.metric_history[part]["recall"][index][best_idx]

            counts = train_count.get(c, "N/A")
            report += f"{c} => dice : {dice} - p : {precision} - r : {recall} || train counts : {counts}\n"
        report += separator
    with open(report_path , "w") as f :
        f.write(report)

def save_memory(recorder,args,output_folder_path):
    """Dump the recorder histories and the run arguments to JSON files in `output_folder_path`.

    Writes loss_history.json, full_metric_hostory.json, avg_metric_hostory.json and
    args.json, in that order. The "hostory" spelling is kept because other tooling may
    expect those file names.

    numpy scalars/arrays and torch tensors (e.g. per-class metrics that were never
    converted to Python floats) are serialised via ``.tolist()``. Previously json.dump
    raised TypeError on them.
    """
    def _to_builtin(obj):
        # json calls this only for objects it cannot encode natively
        if hasattr(obj, "tolist"):
            return obj.tolist()
        if hasattr(obj, "item"):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    outputs = [
        ("loss_history.json", recorder.history),
        ("full_metric_hostory.json", recorder.metric_history),
        ("avg_metric_hostory.json", recorder.metric_avg_list),
        ("args.json", args),
    ]
    for file_name, payload in outputs:
        with open(os.path.join(output_folder_path, file_name), "w") as f:
            json.dump(payload, f, indent=4, default=_to_builtin)

def draw_loss_plots(recorder,output_folder_path):
    """Plot every loss curve of every part, one subplot row per part, to ``loss_plot.png``.

    The last element of each loss history is dropped before plotting. It is the
    (normally empty) batch list that HistoryRecorder.avg_losses opens for the next epoch.
    """
    plt.figure(figsize=(15,20))

    losses_keys = recorder.losses_keys
    palette = ["g","r","b","y","orange"]
    # cycle the palette: more than five loss terms used to raise IndexError
    colors_per_loss = {name: palette[i % len(palette)] for i, name in enumerate(losses_keys)}

    plt_path = os.path.join(output_folder_path,"loss_plot.png")
    n_parts = max(len(recorder.history), 1)

    for i,part in enumerate(recorder.history):
        plt.subplot(n_parts,1,i+1)
        for loss_name,data in recorder.history[part].items():
            finished_epochs = data[:-1]
            x = np.arange(len(finished_epochs))
            plt.plot(x,finished_epochs,color = colors_per_loss[loss_name],label=loss_name)
        plt.title(f"{part} loss plot")
        plt.legend()
    plt.savefig(plt_path,dpi=150)

def draw_avg_metric_plots(recorder,output_folder_path):
    """Plot the class-averaged dice / precision / recall curves of every part.

    Uses a fixed 2x1 subplot layout, one row per part in ``recorder.metric_avg_list``.
    The figure is saved to ``<output_folder_path>/avg_metrics.png``.
    """
    plt.figure(figsize=(15,20))
    plt_path = os.path.join(output_folder_path,"avg_metrics.png")
    styles = (("dice", "g"), ("precision", "r"), ("recall", "b"))
    for row, (part, metrics) in enumerate(recorder.metric_avg_list.items(), start=1):
        plt.subplot(2, 1, row)
        curves = [(label, color, metrics[label]) for label, color in styles]
        epochs = np.arange(len(curves[0][2]))
        for label, color, values in curves:
            plt.plot(epochs, values, color=color, label=label)
        plt.title(f"{part} avg dice plot")
        plt.legend()
    plt.savefig(plt_path)
def draw_all_metric_plots(recorder,output_folder_path):
    """Save one figure per part with a dice/precision/recall-per-epoch subplot for each class.

    Files are written to ``<output_folder_path>/<part>_full_metric.png``. The grid is
    5x5 (as before) and grows to the smallest square that fits when there are more than
    25 classes; the fixed 5x5 grid used to raise in that case. Classes missing from
    ``recorder.class_maps`` are titled with their numeric index.
    """
    for part in recorder.history:
        plt_path = os.path.join(output_folder_path,f"{part}_full_metric.png")
        plt.figure(figsize=(30,30))
        part_metrics = recorder.metric_history[part]
        class_indices = list(part_metrics["dice"])
        grid = max(5, int(np.ceil(np.sqrt(len(class_indices)))))
        for i , class_index in enumerate(class_indices):
            dice_data = part_metrics["dice"][class_index]
            precision_data = part_metrics["precision"][class_index]
            recall_data = part_metrics["recall"][class_index]

            class_name = recorder.class_maps.get(class_index, str(class_index))
            plt.subplot(grid,grid,i+1)
            x = np.arange(len(dice_data))
            plt.plot(x,dice_data,color="g",label="dice")
            plt.plot(x,precision_data,color="r",label="precision")
            plt.plot(x,recall_data,color="b",label="recall")
            plt.title(f"{class_name}")
            plt.legend()

        plt.savefig(plt_path)



@torch.no_grad()
def draw_examples(model,valid_loader,args,class_map,output_folder_path,w=6,h=6):
    """Save a grid of ground-truth vs. predicted mask overlays to ``examples.png``.

    The grid has `h` rows and `w` columns. Each example takes two adjacent panels
    (ground truth, then prediction), so up to ``(w*h)//2`` batches are drawn; that is 18
    for the 6x6 default. Only the first sample of each batch and the first model output
    (``pred_masks[0]``) are shown. Drawing stops early if the loader runs out. A class
    legend is attached once the first row is complete.

    Args:
        args: must provide "device"; also forwarded to draw_mask.
        class_map: {class_index: name} for indices 1..len(class_map); used for the legend.
    """
    plt_path = os.path.join(output_folder_path,"examples.png")
    plt.figure(figsize=(30,30))
    plot_count = (w * h) // 2
    patches = [
        mpatches.Patch(color=np.array(colors[j-1]) / 255.0, label=class_map[j])
        for j in range(1,len(class_map)+1)
    ]
    device = args["device"]
    # torch.autocast needs a device *type* ("cuda"/"cpu"); an indexed string like "cuda:0" is rejected
    device_type = torch.device(device).type

    legend_drawn = False
    img_index = 1
    model.eval()
    # zip stops cleanly when the loader has fewer than plot_count batches (next() used to raise StopIteration)
    for _, (img, mask) in zip(tqdm(range(plot_count)), valid_loader):
        with torch.autocast(device_type=device_type,dtype=torch.float16):
            pred_masks = model(img.to(device))
        # argmax in torch on a float copy so low-precision autocast outputs never reach numpy
        pred_mask = pred_masks[0].float().argmax(dim=1).cpu().numpy()
        mask = mask[0].numpy()
        first_img = img[0, 0].numpy()
        x_disp = (first_img - first_img.min()) / (first_img.max() - first_img.min() + 1e-8)
        rgb = np.repeat(x_disp[..., None], 3, axis=2)*255
        real_annoted = draw_mask(rgb,mask[0],args,colors)
        pred_annoted = draw_mask(rgb,pred_mask[0],args,colors)

        plt.subplot(h,w,img_index)
        plt.imshow(real_annoted)
        plt.title(f"Ground Truth ")
        plt.subplot(h,w,img_index+1)
        plt.imshow(pred_annoted)
        plt.title(f"Predicted ")
        img_index+=2
        # a row holds `w` panels; the old check compared against `h` and only worked for square grids
        if not legend_drawn and img_index - 1 >= w:
            plt.legend(
                handles=patches,
                bbox_to_anchor=(1.05, 1),
                loc='upper left',
                borderaxespad=0.,
                title="Classes"
            )
            legend_drawn = True
    plt.savefig(plt_path)

# preprocessing

In [None]:
class CLAHE :
    """OpenCV CLAHE (contrast-limited adaptive histogram equalisation) as a plain callable.

    Intended as a `preprocessor` for single-channel images. Values are clipped to
    [0, 255] first.
    """
    def __init__(self,clipLimit=2.0,tileGridSize=(8, 8)):
        self.clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
    def __call__(self,img):
        enhanced = np.clip(img, 0, 255)
        # CLAHE.apply accepts only 8-bit or 16-bit single-channel images; other dtypes
        # (float or int32 output of an earlier step) used to raise, so cast them to uint8.
        if enhanced.dtype not in (np.uint8, np.uint16):
            enhanced = enhanced.astype(np.uint8)
        return self.clahe.apply(enhanced)
class WhiteTopHat:
    """Dark-structure enhancement via a top-hat on the inverted image.

    The image is inverted with bitwise_not. Its morphological top-hat (rectangular
    kernel of `kernel_size`, replicated borders) is subtracted from the original with
    cv2.subtract, which darkens structures that are dark in `img` and smaller than the kernel.

    NOTE(review): `turn_neg` is stored but never consulted; the inversion always happens.
    """
    def __init__(self,kernel_size = (50, 50),turn_neg = True):
        self.kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
        self.turn_neg = turn_neg

    def __call__(self,img):
        inverted = cv2.bitwise_not(img)
        dark_detail = cv2.morphologyEx(
            inverted,
            cv2.MORPH_TOPHAT,
            self.kernel,
            borderType=cv2.BORDER_REPLICATE,
        )
        return cv2.subtract(img, dark_detail)
        
# Augmentations
def normalize_xca(img, **kwargs):
    """Z-score the strictly positive (foreground) pixels of an image and zero the rest.

    Mean and std are computed over pixels with value > 0 only. Those pixels become
    ``(x - mean) / (std + 1e-8)`` and every other pixel is set to 0. If no pixel is
    positive, the image is returned unchanged (as float32).

    Accepts ``**kwargs`` so it can be used as an ``A.Lambda(image=...)`` callback.

    Returns:
        A new float32 array; `img` is never modified.
    """
    # Always copy: astype(..., copy=False) returned `img` itself for float32 input, so the
    # in-place normalisation below used to overwrite the caller's array.
    x = img.astype(np.float32, copy=True)
    foreground = x > 0
    if np.any(foreground):
        values = x[foreground]
        x[foreground] = (values - values.mean()) / (values.std() + 1e-8)
        x[~foreground] = 0.0
    return x

# recorder

In [None]:
class HistoryRecorder:
    """Collects per-epoch losses and per-class dice/precision/recall for "train" and "valid".

    Layout:
        history[part][loss_name]: list whose last element is the list of batch losses of
            the epoch in progress. avg_losses() replaces that list with its mean and
            appends a fresh empty list.
        metric_history[part][metric][class_idx]: per-epoch values for class_idx in
            1..class_count (add_metrics drops element 0 of its inputs).
        metric_avg_list[part][metric]: per-epoch averages over the classes in class_maps.

    Args:
        class_maps: {class_idx: class_name}; selects the classes that are printed and averaged.
        losses_keys: names of the losses logged via add_losses().
        class_count: number of per-class slots stored by add_metrics().
    """
    def __init__(self,class_maps,losses_keys,class_count = 25):

        self.history = {
            "train":{},
            "valid":{}
        }
        for key in losses_keys:
            self.history["train"][key] = [[]]
            self.history["valid"][key] = [[]]
        self.losses_keys = losses_keys
        self.metric_history={}
        self.class_maps =class_maps
        for part in self.history :
            self.metric_history[part]={"dice":{},"precision":{},"recall":{}}
            for i in range(1,class_count+1):
                self.metric_history[part]["dice"][i]=[]
                self.metric_history[part]["recall"][i]=[]
                self.metric_history[part]["precision"][i]=[]

        self.metric_avg_list = {
            "train":{"dice":[],"precision":[],"recall":[]},
            "valid":{"dice":[],"precision":[],"recall":[]}
        }

        self.class_count = class_count

    def add_losses(self,part,loss_dict):
        """Append one batch's losses ({loss_name: value}) to the current epoch of `part`."""
        for loss_name,loss in loss_dict.items():
            self.history[part][loss_name][-1].append(loss)

    def add_metrics(self,dice,precision,recall,part):
        """Store one epoch of per-class metrics.

        Element 0 of each sequence (presumably background) is dropped; element i is
        stored under class index i for i in 1..class_count.
        """
        dice = dice[1:]
        precision = precision[1:]
        recall = recall[1:]
        for i in range(self.class_count):
            self.metric_history[part]["dice"][i+1].append(dice[i])
            self.metric_history[part]["recall"][i+1].append(recall[i])
            self.metric_history[part]["precision"][i+1].append(precision[i])

    def avg_losses(self,part):
        """Replace the current epoch's batch-loss list with its mean and open a new epoch list."""
        for key in self.history[part]:
            self.history[part][key][-1] = np.mean(self.history[part][key][-1])
            self.history[part][key].append([])

    def print_loss_report(self,part,epoch,avg_first=True):
        """Print the last finished epoch's mean losses, three per line (averaging first by default)."""
        if(avg_first):
            self.avg_losses(part)

        report = f"{part} ==> epoch ({epoch})\n"
        for co, loss_name in enumerate(self.losses_keys):
            # [-1] is the freshly opened (empty) epoch list, so [-2] is the last finished mean
            loss = self.history[part][loss_name][-2]
            report += f"{loss_name} : {loss}"
            report += "\n" if (co+1) % 3 == 0 else " - "

        print(report)

    def print_metrics_report(self,part,epoch,class_wise=False):
        """Print and record the averages of the latest per-class metrics over class_maps.

        The averages are appended to ``metric_avg_list[part]``. They are taken over the
        classes actually summed (those in class_maps); the previous divisor,
        ``class_count``, under-reported whenever class_maps listed fewer classes.
        """
        report_temp = f"{part} avg metrics for epoch {epoch} :\n"
        report_class_wise_temp = ""
        avg_dice = 0
        avg_precision = 0
        avg_recall = 0
        for index , c in self.class_maps.items():
            dice = self.metric_history[part]["dice"][index][-1]
            precision = self.metric_history[part]["precision"][index][-1]
            recall = self.metric_history[part]["recall"][index][-1]
            avg_dice += dice
            avg_precision += precision
            avg_recall += recall
            if(class_wise):
                report_class_wise_temp += f"{c} => dice : {dice} p : {precision} , r : {recall}\n"

        n_classes = max(len(self.class_maps), 1)
        avg_dice = avg_dice/n_classes
        avg_precision = avg_precision/n_classes
        avg_recall = avg_recall/n_classes

        self.metric_avg_list[part]["dice"]+=[avg_dice]
        self.metric_avg_list[part]["precision"]+=[avg_precision]
        self.metric_avg_list[part]["recall"]+=[avg_recall]

        report_temp+=f"avg dice : {avg_dice} - avg precision : {avg_precision} - avg recall : {avg_recall}"

        if(class_wise):
            report_temp = report_temp + "\n" +report_class_wise_temp[:-1] # drop the final "\n"
        print(report_temp)





# nnunet_blocks

In [None]:
def crop_dims(target, current):
    """Center-crop the spatial dims of `current` (N, C, H, W) to the H and W of `target`.

    This cell used to hold a no-argument ``pass`` stub, which shadowed the working helper
    of the same name. DecoderBlock/Head call ``crop_dims(x_in, x_skip)``, so that stub made
    them raise TypeError whenever skip and input sizes differed. Any odd leftover pixel
    is removed from the bottom/right; assumes `current` is at least as large as `target`.
    """
    diff_h = current.shape[2] - target.shape[2]
    diff_w = current.shape[3] - target.shape[3]
    top = diff_h // 2
    left = diff_w // 2
    return current[:, :, top:top + target.shape[2], left:left + target.shape[3]]
class Conv(nn.Module):
    """3x3 convolution (stride 1) -> InstanceNorm2d (affine) -> LeakyReLU(0.01).

    `p` is the convolution's zero padding; p=1 keeps H and W unchanged.
    """
    def __init__(self,in_c , out_c,p):
        super().__init__()
        conv = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=3, stride=1, padding=p)
        norm = nn.InstanceNorm2d(out_c, eps=1e-5, affine=True)
        act = nn.LeakyReLU(negative_slope=1e-2, inplace=True)
        # attribute name and order kept so state_dict keys stay layers.0.* / layers.1.*
        self.layers = nn.Sequential(conv, norm, act)

    def forward(self,x):
        return self.layers(x)
class DownsampleConv(nn.Module):
    """Learned 2x downsampling: 3x3 stride-2 convolution (padding 1) -> InstanceNorm2d (affine) -> LeakyReLU(0.01)."""
    def __init__(self,in_c , out_c):
        super().__init__()
        conv = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=3, stride=2, padding=1)
        norm = nn.InstanceNorm2d(out_c, eps=1e-5, affine=True)
        act = nn.LeakyReLU(negative_slope=1e-2, inplace=True)
        # attribute name and order kept for state_dict compatibility
        self.layers = nn.Sequential(conv, norm, act)

    def forward(self,x):
        return self.layers(x)
class EncoderBlock(nn.Module):
    """Encoder stage: two `Conv` layers, then a strided-conv downsample.

    forward returns ``(features, downsampled)``. `features` is kept for the skip
    connection; `downsampled` feeds the next stage.
    """
    def __init__(self,in_c , out_c,p=1):
        super().__init__()
        first = Conv(in_c=in_c, out_c=out_c, p=p)
        second = Conv(in_c=out_c, out_c=out_c, p=p)
        self.layers = nn.Sequential(first, second)
        self.pool = DownsampleConv(in_c=out_c, out_c=out_c)

    def forward(self,x):
        skip_features = self.layers(x)
        return skip_features, self.pool(skip_features)

class DecoderBlock(nn.Module):
    """Decoder stage with optional attention gating and optional deep-supervision head.

    forward(x_in, x_skip, x_gate):
      1. If `attention`, re-weight `x_skip` with an AttentionGate driven by `x_gate`.
      2. Center-crop the skip features to x_in's H/W when they differ.
      3. Concatenate [skip, x_in] on channels and apply two `Conv` layers.
      4. Upsample 2x with a transposed conv that halves the channels.

    Returns ``(upsampled, gate_features, dsv_logits)``:
      * gate_features - the pre-upsample features when `attention` is on (presumably the
        gating signal for the next stage), else None.
      * dsv_logits - a 1x1-conv `class_count`-channel map of the same features when `dsv`
        is on, else None.

    NOTE(review): the original comment states gate_c is expected to be in_c * 2.
    """
    def __init__(self,in_c ,out_c , f_int_scale,class_count,gate_c = None , attention=False,dsv=False):
        super().__init__()
        self.dsv = dsv
        self.conv1 = Conv(in_c=in_c, out_c=out_c, p=1)
        self.conv2 = Conv(in_c=out_c, out_c=out_c, p=1)
        self.upsampler = nn.ConvTranspose2d(
            in_channels=out_c,
            out_channels=out_c // 2,
            kernel_size=2,
            stride=2,
        )
        if self.dsv:
            self.dsv_block = nn.Conv2d(in_channels=out_c, out_channels=class_count, kernel_size=1)
        if attention:
            self.gate = AttentionGate(gate_in_c=gate_c, f_int_scale=f_int_scale, skip_in_c=in_c // 2)
        self.attention = attention

    def forward(self,x_in,x_skip,x_gate):
        if self.attention:
            x_skip = self.gate(x_skip, x_gate)
        if x_in.shape[2:4] != x_skip.shape[2:4]:
            x_skip = crop_dims(x_in, x_skip)

        features = self.conv2(self.conv1(torch.cat([x_skip, x_in], dim=1)))
        upsampled = self.upsampler(features)
        gate_out = features if self.attention else None
        dsv_out = self.dsv_block(features) if self.dsv else None
        return upsampled, gate_out, dsv_out

class BottleNeck(nn.Module):
    """Bottleneck: two `Conv` layers, then a 2x2 stride-2 transposed conv that halves channels.

    forward returns ``(upsampled, features)`` when `attention` is on (features presumably
    serve as the first gating signal), else ``(upsampled, None)``.
    """
    def __init__(self,in_c , out_c,p,attention=False):
        super().__init__()
        self.conv1 = Conv(in_c=in_c, out_c=out_c, p=p)
        self.conv2 = Conv(in_c=out_c, out_c=out_c, p=p)
        self.upsampler = nn.ConvTranspose2d(
            in_channels=out_c,
            out_channels=out_c // 2,
            kernel_size=2,
            stride=2,
        )
        self.attention = attention

    def forward(self,x):
        features = self.conv2(self.conv1(x))
        upsampled = self.upsampler(features)
        return upsampled, (features if self.attention else None)

class AttentionGate(nn.Module):
    """Additive attention gate that re-weights skip features with a gating signal.

    Both inputs are projected to `f_int` channels with 1x1 convs. The skip projection is
    bilinearly resized to the gate's H/W if they differ. The projections are summed,
    passed through ReLU, squeezed to one channel by a 1x1 conv, and passed through a sigmoid.
    The resulting (N, 1, h, w) map is resized back to x_skip's H/W when needed and
    multiplied into `x_skip`.

    Args:
        gate_in_c / skip_in_c: channel counts of x_gate / x_skip.
        f_int_scale: f_int defaults to ``min(gate_in_c, skip_in_c) // f_int_scale``
            (computed per input), with 0 bumped to 1.
        f_int: explicit intermediate channel count (overrides the default).
        scaler: only "sigmoid" is supported.

    Raises:
        ValueError: for an unsupported `scaler`. Previously such a value left
            ``self.scaler`` unset and failed later inside forward().
    """
    def __init__(self,gate_in_c,skip_in_c,f_int_scale=2,f_int=None,scaler="sigmoid"):
        super(AttentionGate,self).__init__()
        if scaler != "sigmoid":
            raise ValueError(f"unsupported scaler: {scaler!r} (only 'sigmoid' is implemented)")
        if f_int is None:
            f_int = min(gate_in_c//f_int_scale, skip_in_c//f_int_scale)
        if f_int == 0:
            f_int = 1
        self.conv_gate = nn.Conv2d(in_channels = gate_in_c , out_channels = f_int ,
                                   kernel_size = 1)
        self.conv_skip = nn.Conv2d(in_channels = skip_in_c , out_channels = f_int ,
                                   kernel_size = 1)
        self.relu = nn.ReLU(inplace=True)
        self.conv_shrink = nn.Conv2d(in_channels = f_int , out_channels = 1 ,
                                     kernel_size = 1)
        self.scaler = nn.Sigmoid()

    def forward(self,x_skip,x_gate):
        x_gate_int = self.conv_gate(x_gate)
        x_skip_int = self.conv_skip(x_skip)
        if x_skip_int.shape[2:] != x_gate_int.shape[2:]:
            x_skip_int = F.interpolate(
                x_skip_int,
                size=x_gate_int.shape[2:],
                mode="bilinear",
                align_corners=False
            )

        attn = self.scaler(self.conv_shrink(self.relu(x_skip_int + x_gate_int)))
        # Start from the raw map and resize only when needed. The old code assigned its
        # result only inside this `if`, so equal gate/skip resolutions raised UnboundLocalError.
        if attn.shape[2:] != x_skip.shape[2:]:
            attn = F.interpolate(
                attn,
                size=x_skip.shape[2:],
                mode="bilinear",
                align_corners=False
            )

        return attn*x_skip

class Head(nn.Module):
    """Last decoder stage: fuse a skip tensor with incoming features and emit class logits.

    Pipeline in ``forward``:
    1. Optionally gate ``x_skip`` with an ``AttentionGate`` driven by ``x_gate``.
    2. Crop the skip with ``crop_dims`` when its spatial size differs from ``x_in``.
    3. Concatenate ``[skip, x_in]`` on the channel axis.
    4. Apply two ``Conv`` blocks.
    5. Project to ``class_count`` channels with a 1x1 convolution.

    Args:
        in_c: channel count of the concatenated tensor fed to the first ``Conv``.
        out_c: channel count of both ``Conv`` blocks.
        class_count: number of output logit channels.
        f_int_scale: forwarded to ``AttentionGate`` when ``attention`` is set.
        gate_c: gating-signal channels for ``AttentionGate`` (needed only with attention).
        attention: whether to gate the skip tensor.

    NOTE(review): the gate is built with ``skip_in_c=in_c // 2``, i.e. it assumes the
    skip supplies half of the concatenated channels — confirm against callers.
    """

    def __init__(self, in_c, out_c, class_count, f_int_scale, gate_c=None, attention=False):
        super(Head, self).__init__()
        self.conv1 = Conv(in_c=in_c, out_c=out_c, p=1)
        self.conv2 = Conv(in_c=out_c, out_c=out_c, p=1)
        self.conv1x1 = nn.Conv2d(in_channels=out_c, out_channels=class_count, kernel_size=1)
        if attention:
            self.gate = AttentionGate(gate_in_c=gate_c, f_int_scale=f_int_scale, skip_in_c=in_c // 2)
        self.attention = attention

    def forward(self, x_in, x_skip, x_gate):
        """Return the class logit map for ``x_in`` fused with ``x_skip``."""
        skip = self.gate(x_skip, x_gate) if self.attention else x_skip
        if (x_in.shape[2], x_in.shape[3]) != (skip.shape[2], skip.shape[3]):
            skip = crop_dims(x_in, skip)
        fused = torch.cat([skip, x_in], dim=1)
        features = self.conv2(self.conv1(fused))
        return self.conv1x1(features)

# nnunet

In [None]:
class nnUnet(nn.Module):
    """Configurable (optionally attention-gated) U-Net with optional deep supervision.

    Depth is derived from ``args["image_shape"]``: the spatial size is halved until
    either side is <= 4. This gives ``co`` encoder stages, one bottleneck and
    ``co - 1`` decoder stages, followed by a ``Head``.

    Args:
        args: config dict. Keys read here are ``class_count``, ``attention``,
            ``image_shape``, ``base_channel``, ``f_int_scale``, ``max_channels``,
            ``input_channels`` and ``deep_super_vision``.
        encoder_channel_settings: optional output channels per encoder stage.
            By default it starts at ``base_channel`` and doubles per stage, capped
            at ``max_channels``.
        decoder_channel_settings: optional output channels per decoder stage.
            Index 0 is the stage fed by the bottleneck. By default it is the doubled
            encoder channels in reverse order.

    Raises:
        ValueError: if ``image_shape`` is too small for even one downsampling stage.

    ``forward(x)`` returns a list of logit maps. Element 0 is the head output. Any
    further elements are the non-``None`` decoder side outputs, ordered from the
    stage nearest the head towards the bottleneck.
    """

    def __init__(self,args,encoder_channel_settings=None,decoder_channel_settings=None):
        super(nnUnet,self).__init__()

        class_count = args["class_count"]
        attention = args["attention"]
        image_shape = args["image_shape"]
        base_channel = args["base_channel"]
        f_int_scale = args["f_int_scale"]
        max_channels = args["max_channels"]
        input_channels = args["input_channels"]
        self.deep_super_vision = args["deep_super_vision"]

        # Number of encoder stages: halve the spatial size until a side reaches <= 4.
        h, w = image_shape[0], image_shape[1]
        co = 0
        while w > 4 and h > 4:
            w /= 2
            h /= 2
            co += 1
        print(f"number of layers : {co}")
        if co == 0:
            raise ValueError(
                f"image_shape {tuple(image_shape)} is too small: both sides must be greater than 4"
            )

        # Encoder channels: double per stage, capped at max_channels.
        if encoder_channel_settings is None:
            self.encoder_channel_settings = [base_channel]
            for _ in range(co - 1):
                self.encoder_channel_settings.append(
                    min(self.encoder_channel_settings[-1] * 2, max_channels)
                )
        else:
            self.encoder_channel_settings = encoder_channel_settings

        self.bottle_neck_channel_setting = self.encoder_channel_settings[-1] * 2

        # Decoder channels: doubled encoder channels, listed from the bottleneck side outwards.
        if decoder_channel_settings is None:
            self.decoder_channel_settings = [
                c * 2 for c in reversed(self.encoder_channel_settings[:co - 1])
            ]
        else:
            self.decoder_channel_settings = decoder_channel_settings

        # Modules are constructed in the same order as before (encoders, bottleneck,
        # decoders from the bottleneck side, head) so seeded initialisation is unchanged.
        self.encoders = nn.ModuleList()
        for i in range(co):
            output_channels = self.encoder_channel_settings[i]
            self.encoders.append(EncoderBlock(in_c=input_channels, out_c=output_channels, p=1))
            input_channels = output_channels

        self.bottle_neck = BottleNeck(
            in_c=input_channels,
            out_c=self.bottle_neck_channel_setting,
            p=1,
            attention=attention,
        )

        input_channels = self.bottle_neck_channel_setting
        decoders = []
        for i in range(co - 1):
            output_channels = self.decoder_channel_settings[i]
            decoders.append(
                DecoderBlock(
                    in_c=input_channels,
                    out_c=output_channels,
                    gate_c=input_channels,
                    attention=attention,
                    f_int_scale=f_int_scale,
                    dsv=self.deep_super_vision,
                    class_count=class_count,
                )
            )
            input_channels = output_channels
        # Stored head-side first (index 0 nearest the head); forward() walks it backwards.
        self.decoders = nn.ModuleList(reversed(decoders))
        self.attention = attention

        self.head = Head(
            in_c=input_channels,
            out_c=input_channels // 2,
            class_count=class_count,
            gate_c=input_channels,
            attention=False,
            f_int_scale=f_int_scale,
        )
        print("encoder settings : ", self.encoder_channel_settings)
        print("bottle-neck settings : ", self.bottle_neck_channel_setting)
        print("decoder settings : ", self.decoder_channel_settings)
        print("head settings : ",class_count)

    def forward(self,x):
        """Run encoder, bottleneck, decoders and head; return ``[head_out, *side_outputs]``."""
        skips = []
        for encoder in self.encoders:
            skip, x = encoder(x)
            skips.append(skip)
        x_in, gate_in = self.bottle_neck(x)

        side_outputs = []
        # Start from the bottleneck-side decoder (last index); decoder i consumes skip i + 1.
        for i in reversed(range(len(self.decoders))):
            x_in, gate_in, dsv_out = self.decoders[i](x_in, skips[i + 1], gate_in)
            if dsv_out is not None:
                side_outputs.insert(0, dsv_out)

        head_out = self.head(x_in=x_in, x_skip=skips[0], x_gate=gate_in)
        return [head_out] + side_outputs
if __name__ == "__main__":
    # Smoke test: build the model and run one forward pass on a dummy batch.
    # NOTE: inside a notebook ``__name__ == "__main__"`` is True, so this runs every time
    # the cell runs; it therefore uses the configured device rather than assuming CUDA.
    args = {
        "base_path" : "../arcade/nnUnet_dataset/syntax",
        "in_c" : 1,
        "base_channel" :32,
        "image_shape" : (512,512),
        "class_count" : 26 ,
        "attention" : True,
        "k":40,
        "batch_size" : 10,
        "num_workers" : 10,
        "device" : "cuda" if torch.cuda.is_available() else "cpu",
        "lr" : 0.01,
        "momentum" : 0.99,
        "weight_decay" : 3e-5,
        "epcohs":30,
        "f_int_scale" : 2,
        "full_report_cycle" : 10,
        "max_channels":512,
        "input_channels":1,
        "loss_type":"dice loss",
        "alpha":0.75,
        "beta":0.25,
        "gamma":1.00,
        "f_gamma":2.0,
        "f_loss_scale":1,
        "loss_coefs":{"CE":1.0,"Second":1.0},
        "output_base_path" : "./outputs",
        "name" : "Attention7-AllClass",
        "deep_super_vision" : True
    }
    class_map = {
        1: '1',2: '2', 3: '3',4: '4',
        5: '5',6: '6',7: '7',8: '8',
        9: '9',10: '9a',11: '10',12: '10a',
        13: '11',14: '12',15: '12a',16: '13',
        17: '14',18: '14a',19: '15',20: '16',
        21: '16a',22: '16b',23: '16c',
        24: '12b',25: '14b'
    }
    device = args["device"]
    model = nnUnet(args).to(device)
    dummy_batch = torch.ones(
        (10, args["input_channels"], *args["image_shape"])
    ).float().to(device)
    # Inference only: no autograd graph is needed for a shape check.
    with torch.no_grad():
        outs = model(dummy_batch)
    for out in outs:
        print(out.shape)

# losses

In [None]:
class UnetLoss(nn.Module):
    """Segmentation loss: (optionally class-weighted) cross-entropy plus an overlap loss.

    ``forward(pred_mask, gt_mask)`` takes logits ``pred_mask`` with classes on dim 1
    and an integer label map ``gt_mask`` (one-hot encoded and permuted to
    channels-first, so it is expected to be (B, H, W)). It returns
    ``(total_loss, loss_dict)``. ``loss_dict`` maps ``"CE loss"`` and
    ``args["loss_type"]`` to the coefficient-scaled terms.

    Args:
        args: config dict. Keys read are ``class_count``, ``loss_type``
            (``"dice loss"`` or ``"tversky loss"``), ``alpha``, ``beta``, ``t_gamma``,
            ``f_gamma``, ``k``, ``loss_coefs`` (``{"CE": ..., "Second": ...}``) and
            ``f_alpha`` (per-class CE weights or ``None``). ``device`` is optional and
            sets where the CE weight tensor is created.
        eps: smoothing constant passed to the overlap losses.

    Raises:
        ValueError: if ``loss_type`` is not a supported name.
    """

    def __init__(self, args, eps=1e-8):
        super(UnetLoss, self).__init__()
        self.class_count = args["class_count"]
        self.loss_type = args["loss_type"]
        self.alpha = args["alpha"]
        self.beta = args["beta"]
        self.t_gamma = args["t_gamma"]
        self.f_gamma = args["f_gamma"]
        self.k = args["k"]
        self.loss_coefs = args["loss_coefs"]
        if args["f_alpha"] is not None:
            # Put the class weights on the configured training device instead of
            # assuming CUDA, so CPU-only runs work.
            device = args.get("device", "cuda" if torch.cuda.is_available() else "cpu")
            w = torch.tensor(args["f_alpha"], dtype=torch.float32, device=device)
            self.ce_fn = nn.CrossEntropyLoss(weight=w)
        else:
            self.ce_fn = nn.CrossEntropyLoss()
        self.softmax = nn.Softmax(dim=1)
        self.eps = eps
        self.sum_dims = (0, 2, 3)
        if self.loss_type == "dice loss":
            print("loss is set to dice")
            self.loss_fn = DiceLoss(self.eps, self.sum_dims)
        elif self.loss_type == "tversky loss":
            print("loss is set to tversky")
            self.loss_fn = TverskyLoss(self.eps, self.sum_dims, self.alpha, self.beta, self.t_gamma)
        else:
            # Fail early instead of with an AttributeError on self.loss_fn in forward().
            raise ValueError(
                f"Unsupported loss_type {self.loss_type!r}; expected 'dice loss' or 'tversky loss'."
            )
        self.cldice_fn = CLDiceLoss(sum_dims=self.sum_dims, eps=self.eps, k=self.k)

    def forward(self, pred_mask, gt_mask):
        """Return ``(total_loss, loss_dict)`` for logits ``pred_mask`` and labels ``gt_mask``."""
        onehot_mask = F.one_hot(gt_mask, num_classes=self.class_count).permute(0, 3, 1, 2).float()
        prob = self.softmax(pred_mask)

        ce_loss = self.loss_coefs["CE"] * self.ce_fn(pred_mask, gt_mask)
        # The overlap loss skips channel 0 (presumably background) and scores only the other classes.
        second_loss = self.loss_coefs["Second"] * self.loss_fn(
            pred_probs=prob[:, 1:],
            gt=onehot_mask[:, 1:],
        )

        loss_dict = {
            "CE loss": ce_loss,
            self.loss_type: second_loss,
        }
        return second_loss + ce_loss, loss_dict

class FocalCrossEntropy(nn.Module):
    """Focal cross-entropy computed from probabilities and a one-hot target.

    For each pixel, ``pt`` is the probability assigned to the true class, clamped
    to ``[eps, 1 - eps]``. The loss is
    ``-f_loss_scale * mean(class_w * (1 - pt)**f_gamma * log(pt))``.
    ``class_w`` is the per-pixel ``f_alpha`` weight of the true class, or 1.0 when
    ``f_alpha`` is ``None``.

    Args:
        f_gamma: focusing exponent.
        eps: clamp margin keeping ``log(pt)`` finite.
        f_loss_scale: multiplier applied to the final loss.
        f_alpha: optional per-class weights (sequence indexed by class).
    """

    def __init__(self, f_gamma, eps, f_loss_scale=1, f_alpha=None):
        super(FocalCrossEntropy, self).__init__()
        self.f_gamma = f_gamma
        self.f_alpha = (
            None
            if f_alpha is None
            else torch.tensor(f_alpha).to("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.eps = eps
        self.f_loss_scale = f_loss_scale

    def forward(self, prob, onehot_mask):
        """``prob`` and ``onehot_mask`` share a channels-on-dim-1 layout, e.g. (B, C, H, W)."""
        pt = (prob * onehot_mask).sum(dim=1).clamp(self.eps, 1 - self.eps)
        weighted_log = (1 - pt) ** self.f_gamma * torch.log(pt)
        if self.f_alpha is None:
            class_w = 1.0
        else:
            per_class = self.f_alpha.view(1, -1, 1, 1).type_as(prob)
            class_w = (per_class * onehot_mask).sum(dim=1)
        return -self.f_loss_scale * (class_w * weighted_log).mean()

class CLDiceLoss(nn.Module):
    """Centre-line Dice (clDice) loss between a thresholded prediction and a target mask.

    The prediction is binarised at 0.5 and skeletonised with ``soft_skeletonize``.
    Topological precision and sensitivity are then computed against the target
    mask and the provided target skeleton ``gt_skel``. The loss is
    ``1 - mean(harmonic_mean(t_prec, t_rec))``.

    NOTE(review): hard-thresholding the prediction blocks gradients through this
    term, so it acts like a metric rather than a trainable loss.

    Args:
        eps: smoothing constant, added once per image to each ratio.
        sum_dims: accepted for interface compatibility but ignored. Reduction is
            always over dims (1, 2), which assumes (B, H, W) inputs — TODO confirm.
        k: iteration count forwarded to ``soft_skeletonize``.
    """

    def __init__(self, eps, sum_dims, k=40):
        super(CLDiceLoss, self).__init__()
        self.k = k
        self.eps = eps
        self.sum_dims = (1, 2)

    def forward(self, pred_binary_mask, gt_mask, gt_skel):
        """Return the scalar clDice loss averaged over the batch."""
        binary_pred = (pred_binary_mask >= 0.5).type_as(pred_binary_mask)
        binary_gt = (gt_mask != 0).type_as(gt_mask)

        pred_skel = soft_skeletonize(binary_pred, k=self.k)

        dims = self.sum_dims
        # Smooth once per image, not once per pixel. Adding eps inside the sum made the
        # numerator grow by H*W*eps, so e.g. an empty skeleton gave a "precision" of H*W.
        t_prec = ((pred_skel * binary_gt).sum(dim=dims) + self.eps) / (pred_skel.sum(dim=dims) + self.eps)
        t_rec = ((gt_skel * binary_pred).sum(dim=dims) + self.eps) / (gt_skel.sum(dim=dims) + self.eps)

        # eps in the denominator guards the 0/0 case when both ratios vanish.
        cldice = 2 * (t_prec * t_rec) / (t_prec + t_rec + self.eps)
        return 1 - cldice.mean()
class DiceLoss(nn.Module):
    """Negative mean soft-Dice score.

    True positives, false positives and false negatives are summed over
    ``sum_dims``. A smoothed Dice score is computed for every remaining entry
    (e.g. per class) and the negated mean is returned.

    Args:
        eps: smoothing term added to numerator and denominator.
        sum_dims: dimensions reduced when accumulating TP/FP/FN.
    """

    def __init__(self, eps, sum_dims):
        super(DiceLoss, self).__init__()
        self.eps = eps
        self.sum_dims = sum_dims

    def forward(self, pred_probs, gt):
        """``pred_probs`` and ``gt`` must have matching shapes; returns a scalar tensor."""
        axes = self.sum_dims
        true_pos = (pred_probs * gt).sum(dim=axes)
        false_pos = (pred_probs * (1 - gt)).sum(dim=axes)
        false_neg = (gt * (1 - pred_probs)).sum(dim=axes)
        scores = (2 * true_pos + self.eps) / (2 * true_pos + false_pos + false_neg + self.eps)
        return -scores.mean()

class TverskyLoss(nn.Module):
    """Focal-Tversky style loss: ``(1 - mean(tversky_index)) ** gamma``.

    The Tversky index per entry left after reducing ``sum_dims`` is
    ``(TP + eps) / (TP + alpha * FP + beta * FN + eps)``.

    Args:
        eps: smoothing term.
        sum_dims: dimensions reduced when accumulating TP/FP/FN.
        alpha: weight on false positives.
        beta: weight on false negatives.
        gamma: exponent applied to ``1 - mean index``.
    """

    def __init__(self, eps, sum_dims, alpha=0.3, beta=0.7, gamma=1.33):
        super(TverskyLoss, self).__init__()
        self.eps = eps
        self.sum_dims = sum_dims
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def forward(self, pred_probs, gt):
        """``pred_probs`` and ``gt`` must have matching shapes; returns a scalar tensor."""
        axes = self.sum_dims
        overlap = (pred_probs * gt).sum(dim=axes)
        spurious = (pred_probs * (1 - gt)).sum(dim=axes)
        missed = (gt * (1 - pred_probs)).sum(dim=axes)
        index = (overlap + self.eps) / (overlap + self.alpha * spurious + self.beta * missed + self.eps)
        return (1 - index.mean()) ** self.gamma

    
    

# trainer

In [None]:
def train_fn(model,img,gt_masks,optimizer,loss_fn,scaler,args,device,loss_weights=[1]):
    optimizer.zero_grad()
    loss_dict={}
    with torch.autocast(device_type=args["device"],dtype=torch.float16):
        pred_masks =  model(img)
        loss = 0
        for i,pred_mask in enumerate(pred_masks) : 

            gt_mask = gt_masks[i].to(device)
            loss_weight = loss_weights[i]
            layer_loss , layer_loss_dict = loss_fn(pred_mask , gt_mask)
            if(i==0):
                loss_dict = layer_loss_dict
                pred_mask_last = pred_mask.detach()
                gt_mask_last = gt_mask
            else:
                loss_dict = {key : loss_dict[key]+ loss_weight*layer_loss_dict[key] for key in layer_loss_dict}

            loss += loss_weight*layer_loss

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)

    if random.random() < 0.01:
        with torch.no_grad():
            print("--- Total Norm ---")
            print(total_norm)
            # print("\n--- Gradient norms ---")
            # for name, param in model.named_parameters():
            #     if param.grad is not None:
            #         grad_norm = param.grad.data.norm().item()
            #         print(f"{name:30s}: {grad_norm:.6f}")
            # print("----------------------\n")
    scaler.step(optimizer)
    scaler.update()
    
    loss = loss.detach().cpu().item()
    for loss_name in loss_dict:
        loss_dict[loss_name] = loss_dict[loss_name].detach().cpu().item()
    
    loss_dict["total loss"] = loss
    return loss_dict , pred_mask_last , gt_mask_last
    
def trainer(args,recorder,model,optimizer,loss_fn,train_loader,valid_loader,loss_weights=[1]):
    """Train for ``args["epcohs"]`` epochs, validating after each one.

    Each epoch does the following:
    - Run ``train_fn`` on every batch of ``train_loader``.
    - Accumulate per-class TP/FP/FN from the prediction/target pair it returns.
    - Log batch losses and epoch Dice/precision/recall to ``recorder``.
    - Call ``evaluation`` on ``valid_loader``. The validation report is class-wise
      every ``args["full_report_cycle"]`` epochs.
    """
    device = args["device"]
    n_epochs = args["epcohs"]
    class_count = args["class_count"]
    report_every = args["full_report_cycle"]
    grad_scaler = torch.amp.GradScaler(device = device)

    for epoch in tqdm(range(n_epochs)):
        tp_sum, fp_sum, fn_sum = (torch.zeros(class_count) for _ in range(3))

        model.train()
        for img, gt_masks in tqdm(train_loader):
            step_losses, pred_mask, gt_mask = train_fn(
                model=model,
                img=img.to(device),
                gt_masks=gt_masks,
                optimizer=optimizer,
                loss_fn=loss_fn,
                scaler=grad_scaler,
                args=args,
                device=device,
                loss_weights=loss_weights,
            )

            TP, _, FP, FN = TP_TN_FP_FN(pred_mask, gt_mask, process_preds=True)
            tp_sum += TP
            fp_sum += FP
            fn_sum += FN

            recorder.add_losses("train", step_losses)

        dice_score = (2 * tp_sum + 1e-8) / (2 * tp_sum + fp_sum + fn_sum + 1e-8)
        precision = tp_sum / (fp_sum + tp_sum + 1e-8)
        recall = tp_sum / (fn_sum + tp_sum + 1e-8)

        recorder.add_metrics(
            dice_score.tolist(),
            precision.tolist(),
            recall.tolist(),
            part="train",
        )

        recorder.print_loss_report("train", epoch)
        recorder.print_metrics_report("train", epoch, class_wise=False)
        print("<=>" * 20)

        evaluation(
            recorder=recorder,
            model=model,
            loss_fn=loss_fn,
            valid_loader=valid_loader,
            class_wise_report=(epoch + 1) % report_every == 0,
            class_count=class_count,
            epoch=epoch,
            device=device,
        )
            
@torch.no_grad()
def evaluation(recorder,model,loss_fn,valid_loader,class_count,class_wise_report=False,epoch=None,device="cuda"):
    """Score ``model`` on ``valid_loader`` and log losses and metrics to ``recorder``.

    Only the model's first output is evaluated. The target for each batch is
    element 0 of what the loader yields — presumably the full-resolution mask;
    confirm against the validation dataset.

    Per-batch losses (plus ``"total loss"``) go to ``recorder.add_losses("valid", ...)``.
    Dice/precision/recall from accumulated per-class TP/FP/FN go to
    ``recorder.add_metrics``. A class-wise metrics report is printed when
    ``class_wise_report`` is true.
    """
    model.eval()
    tp_sum = torch.zeros(class_count)
    fp_sum = torch.zeros(class_count)
    fn_sum = torch.zeros(class_count)

    for img, targets in valid_loader:
        img = img.to(device)
        gt_mask = targets[0].to(device)

        with torch.autocast(device_type=device, dtype=torch.float16):
            pred_mask = model(img)[0]
            loss, loss_dict = loss_fn(pred_mask, gt_mask)

        for name in loss_dict:
            loss_dict[name] = loss_dict[name].detach().cpu().item()
        loss_dict["total loss"] = loss.detach().cpu().item()

        TP, _, FP, FN = TP_TN_FP_FN(pred_mask, gt_mask, process_preds=True)
        tp_sum += TP
        fp_sum += FP
        fn_sum += FN

        recorder.add_losses("valid", loss_dict)

    dice_score = (2 * tp_sum + 1e-8) / (2 * tp_sum + fp_sum + fn_sum + 1e-8)
    precision = tp_sum / (fp_sum + tp_sum + 1e-8)
    recall = tp_sum / (fn_sum + tp_sum + 1e-8)

    recorder.add_metrics(
        dice_score.tolist(),
        precision.tolist(),
        recall.tolist(),
        part="valid",
    )
    recorder.print_loss_report("valid", epoch)
    recorder.print_metrics_report("valid", epoch, class_wise=class_wise_report)
    print("-" * 60)

# training script

In [None]:


# # Training

# In[2]:


# Seed every RNG used in the main process so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

args = {
    "base_path" : "../arcade/nnUnet_dataset/syntax",
    "in_c" : 1,
    "base_channel" :32,
    "image_shape" : (512,512),
    "class_count" : 26 ,
    "attention" : True,
    "k":40,
    "batch_size" : 10,
    "num_workers" : 10,
    "device" : "cuda" if torch.cuda.is_available() else "cpu",
    "lr" : 0.001,
    "momentum" : 0.99,
    "weight_decay" : 3e-5,
    "epcohs":100,
    "f_int_scale" : 2,
    "full_report_cycle" : 10,
    "max_channels":512,
    "input_channels":1,
    "loss_type":"tversky loss",
    "alpha":0.3,
    "beta":0.7,
    "t_gamma":2.00,
    "f_gamma":2.0,
    "f_loss_scale":1,
    "loss_coefs":{"CE":1.0,"Second":1.0},
    "output_base_path" : "./outputs",
    "name" : "Attention7-DSV-tev",
    "deep_super_vision" : True,
    "f_alpha":None
}
class_map = {
    1: '1',2: '2', 3: '3',4: '4',
    5: '5',6: '6',7: '7',8: '8',
    9: '9',10: '9a',11: '10',12: '10a',
    13: '11',14: '12',15: '12a',16: '13',
    17: '14',18: '14a',19: '15',20: '16',
    21: '16a',22: '16b',23: '16c',
    24: '12b',25: '14b'
}
losses_keys = ["total loss","CE loss",args["loss_type"]]

# Number of supervised outputs. nnUnet's depth is the number of halvings of image_shape
# until a side reaches <= 4. With deep supervision, presumably each decoder adds a side
# output next to the head (TODO confirm against DecoderBlock). UnetDataset must
# produce the same number of target masks.
n_stages = 0
side = min(args["image_shape"][0], args["image_shape"][1])
while side > 4:
    side /= 2
    n_stages += 1
out_counts = n_stages if args["deep_super_vision"] else 1
loss_weights = [1/(2**i) for i in range(out_counts)]


# In[3]:


# with open("./data/train_class_counts.json","r") as f:
#     train_class_counts = json.load(f)

# b = 0.999999

# counts = [0]*(len(train_class_counts))
# for k,v in train_class_counts.items():
#     counts[int(k)] = int(v)
# counts = np.array(counts,dtype=np.float64)

# f_alpha = (1-b)/(1-np.power(b,counts))
# f_alpha = f_alpha / f_alpha.sum()
# f_alpha[12] = 0.25
# # args["f_alpha"] = f_alpha.tolist()
# args["f_alpha"]=None
# args["f_alpha"]


# In[4]:


# Per-class cross-entropy weights from inverse log-frequency:
#   w_c = log(N / n_c), then normalised so the weights average to 1.
# Index 0 is the unnamed class in class_map (presumably background).
train_class_counts = [
    1200,374,375,369,303,525,525,
    340,310,198,70,21,1,320,61,
    129,305,107,49,38,232,43,48,31,63,127
]
assert len(train_class_counts) == args["class_count"], "need exactly one count per class"
total = float(np.sum(train_class_counts))
# Clamp to 1 so a class absent from the training set does not get an infinite weight.
safe_counts = np.maximum(np.asarray(train_class_counts, dtype=np.float64), 1.0)
f_alpha = np.log(total / safe_counts)
f_alpha = (f_alpha / f_alpha.mean()).tolist()
args["f_alpha"] = f_alpha


# In[5]:


# pre_soft_skeletonize(args["base_path"],output_path=args["base_path"],batch_size=10,k=40)


# In[6]:


# Training pipeline:
# - light intensity jitter (blur, brightness/contrast, gamma);
# - small rotations and flips (rotated-in mask pixels filled with 0);
# - the shared XCA normalisation.
# Tensor conversion for training happens inside UnetDataset.
train_transforms = A.Compose(
    [
        A.GaussianBlur(sigma_limit=[0.1, 0.5], p=0.5),
        A.RandomBrightnessContrast(
            brightness_limit=0.1,
            contrast_limit=0.15,
            brightness_by_max=True,
            p=0.3,
        ),
        A.RandomGamma(gamma_limit=(90, 120), p=0.3),
        A.Rotate(limit=15, p=0.3, fill_mask=0),
        A.HorizontalFlip(p=0.3),
        A.VerticalFlip(p=0.3),
        A.Lambda(image=normalize_xca),
    ]
)
# Evaluation pipeline: normalisation only, then conversion to a tensor.
test_transforms = A.Compose([A.Lambda(image=normalize_xca), ToTensorV2()])
# Optional image preprocessing (e.g. white top-hat + CLAHE) is currently disabled.
train_preprocess = None


# In[7]:


# Load image/mask pairs for both splits with the same (optional) preprocessing.
train_images = read_images(base_path=args["base_path"], preprocessor=train_preprocess, part="train")
valid_images = read_images(base_path=args["base_path"], preprocessor=train_preprocess, part="val")

# The training dataset also emits down-scaled masks for the deep-supervision outputs.
train_ds = UnetDataset(
    transform=train_transforms,
    data=train_images,
    base_size=args["image_shape"][0],
    out_counts=out_counts,
)
valid_ds = ValidUnetDataset(transform=test_transforms, data=valid_images)

# Only the training loader shuffles; both pin host memory for faster device copies.
train_loader = DataLoader(
    train_ds,
    batch_size=args["batch_size"],
    num_workers=args["num_workers"],
    pin_memory=True,
    shuffle=True,
)
valid_loader = DataLoader(
    valid_ds,
    batch_size=args["batch_size"],
    num_workers=args["num_workers"],
    pin_memory=True,
    shuffle=False,
)


# In[8]:


# plot_some_images(train_images, train_transforms, image_counts=36, fig_shape=(6,6), base_transforms=test_transforms)


# In[ ]:


# Build the network, loss and optimiser, then launch training.
model = nnUnet(args).to(args["device"])
loss_fn = UnetLoss(args)
# Adam with the configured learning rate. args["momentum"] / args["weight_decay"] are only
# meaningful for the disabled SGD alternative (Nesterov SGD with weight decay).
optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"])
recorder = HistoryRecorder(losses_keys=losses_keys, class_maps=class_map)

trainer(
    args=args,
    recorder=recorder,
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    valid_loader=valid_loader,
    loss_weights=loss_weights,
)


# In[ ]:


# torch.save(model.state_dict(), "model.pth")


# In[ ]:


# Hand the training history, trained model, validation loader and config to
# save_full_report to write this run's report.
save_full_report(
    recorder=recorder,
    model=model,
    valid_loader=valid_loader,
    args=args,
    class_map=class_map,
    output_base_path=args["output_base_path"],
    name=args["name"],
)
