In [1]:
import gc
import numpy as np
import os
import shutil
import cv2
import matplotlib.pyplot as plt
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import sys

from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from torch import optim , Tensor
from tqdm import tqdm
from os import listdir
from os.path import isfile, join
from pathlib import Path
from copy import deepcopy
from PIL import Image

from Utils import *
from Loss_Functions import *
from Aug import *
from U_Net import *


In [18]:
_METRIC_NAMES = ('iou', 'dice', 'sens', 'spec', 'precision', 'F1')


def _run_epoch(model, loader, device, calculate_loss, optimizer=None,
               augmentation=False, input_ch=None, size=None):
    """Run one pass over `loader` and return per-sample averages of loss and metrics.

    With an `optimizer` this is a training pass: optional on-the-fly `augment`,
    then zero_grad / backward / step for every batch. Without one it only
    evaluates; the caller is responsible for `model.eval()` and `torch.no_grad()`.

    Metric sums are divided by the number of samples seen, exactly as the
    original per-split loops did (so `calculate_metrics` presumably returns
    per-batch sums -- TODO confirm in Utils).

    Returns:
        dict with keys 'loss', 'iou', 'dice', 'sens', 'spec', 'precision', 'F1'.

    Raises:
        ValueError: if the loader yields no samples (instead of ZeroDivisionError).
    """
    totals = dict.fromkeys(('loss',) + _METRIC_NAMES, 0.0)
    num_samples = 0

    for batch in tqdm(loader):
        images, masks = batch['images'], batch['masks']
        if optimizer is not None and augmentation:
            images, masks = augment(images, masks, input_ch=input_ch, size=size)
        images = images.to(device=device)
        masks = masks.to(device=device)
        batch_len = len(images)
        num_samples += batch_len

        if optimizer is not None:
            optimizer.zero_grad()
        pred = model(images)
        loss = calculate_loss(pred, masks)
        # detach(): the running total must not keep every batch's autograd graph alive
        totals['loss'] += loss.detach() * batch_len

        pred_thres = (pred >= 0.5).float()
        iou, dice, sens, spec, precision, F1 = calculate_metrics(pred_thres, masks)
        for key, value in zip(_METRIC_NAMES, (iou, dice, sens, spec, precision, F1)):
            totals[key] += value

        if optimizer is not None:
            loss.backward()
            optimizer.step()

    if num_samples == 0:
        raise ValueError('DataLoader yielded no samples')
    return {key: float(total) / num_samples for key, total in totals.items()}


def _format_metrics(metrics):
    """Render the 'Loss: ... - F1: ...' report line shared by the train/val/test printouts."""
    return (f"Loss: {metrics['loss']:0.6f} - iou: {metrics['iou']:0.6f} - dice: {metrics['dice']:0.6f} - "
            f"sens: {metrics['sens']:0.6f} - precision: {metrics['precision']:0.6f} - F1: {metrics['F1']:0.6f}")


def train_model(
        input_ch,
        model,
        device,
        augmentation,
        use_aug_data,
        epochs,
        T_max,
        train_batch_size,
        val_batch_size,
        test_batch_size,
        learning_rate,
        eta_min,
        momentum,
        weight_decay,
        size,
        alpha_Focal,
        gamma_Focal,
        alpha_FocalTversky,
        gamma_FocalTversky,
        alpha_WBCE,
        reduction,
        use_loss1,
        use_loss2,
        use_loss3,
        use_loss4,
        use_loss5,
        use_loss6,
        first_norm,
        optimizer_type,
        scheduler_type,
        time_str):
    """Train a segmentation model, checkpoint it, and report test metrics.

    Each epoch trains on the train split (optionally with on-the-fly `augment`),
    then evaluates on the validation split. Whenever the validation loss
    improves, the weights are saved to
    `/content/drive/MyDrive/DENTISTRY/MODELS/{time_str}/best_model.pth`.
    After the last epoch the current weights are saved as `final_model.pth`,
    the best checkpoint is reloaded, and the test split is evaluated.

    Dataset locations are read from the module-level globals
    `aug_train_images_path`, `aug_train_masks_path`, `train_images_path`,
    `train_masks_path`, `val_images_path`, `val_masks_path`,
    `test_images_path`, `test_masks_path`.

    Args:
        optimizer_type: 'Adam', 'AdamW' or 'SGD' (`momentum` is used only by SGD).
        scheduler_type: 'Reduce' (ReduceLROnPlateau on val IoU, mode='max'),
            'Cosine' (CosineAnnealingLR with `T_max`, `eta_min`); any other value
            trains without a scheduler.
        first_norm: datasets are built with `normalize=not first_norm`.
        alpha_*/gamma_*/reduction/use_loss1..6: forwarded to `criterion`.
        time_str: run id, used as the checkpoint directory name.

    Raises:
        ValueError: for an unknown `optimizer_type` or an empty split.
    """
    calculate_loss = criterion(alpha_Focal, gamma_Focal, alpha_FocalTversky, gamma_FocalTversky, alpha_WBCE, reduction,
                               use_loss1, use_loss2, use_loss3, use_loss4, use_loss5, use_loss6)

    #------------------------------------------------------------------------------------------------

    if optimizer_type == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif optimizer_type == 'AdamW':
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif optimizer_type == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
    else:
        raise ValueError(f'Unknown optimizer_type: {optimizer_type!r}')

    if scheduler_type == 'Reduce':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, cooldown=1)
    elif scheduler_type == 'Cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=eta_min)
    else:
        scheduler = None  # unchanged behaviour: other values mean "no LR scheduling"

    #------------------------------------------------------------------------------------------------

    print('num params:' + str(get_num_params(model)) + '\n')
    model_dir = f'/content/drive/MyDrive/DENTISTRY/MODELS/{time_str}'
    os.makedirs(model_dir, exist_ok=True)
    best_path = os.path.join(model_dir, 'best_model.pth')
    final_path = os.path.join(model_dir, 'final_model.pth')

    # Datasets do not change between epochs, so build them once; the train
    # DataLoader still reshuffles on every iteration.
    if use_aug_data:
        train_images_dir, train_masks_dir = aug_train_images_path, aug_train_masks_path
    else:
        train_images_dir, train_masks_dir = train_images_path, train_masks_path
    train_dataset = dataset(images_path=train_images_dir, masks_path=train_masks_dir, input_ch=input_ch, size=size, normalize=not first_norm)
    val_dataset = dataset(images_path=val_images_path, masks_path=val_masks_path, input_ch=input_ch, size=size, normalize=not first_norm)
    train_data = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True)
    val_data = DataLoader(dataset=val_dataset, batch_size=val_batch_size, shuffle=False)

    #------------------------------------------------------------------------------------------------

    min_val_loss = float('inf')
    max_val_iou = 0

    for epoch in range(1, epochs + 1):
        print('epoch:' + str(epoch))
        print('lr:' + str(optimizer.param_groups[0]['lr']))

        model.train()
        enable_grad(model)
        train_metrics = _run_epoch(model, train_data, device, calculate_loss, optimizer=optimizer,
                                   augmentation=augmentation, input_ch=input_ch, size=size)

        # eval() is required: the model may contain dropout/normalisation layers
        # whose training-mode behaviour would make validation metrics stochastic.
        model.eval()
        disable_grad(model)
        with torch.no_grad():
            val_metrics = _run_epoch(model, val_data, device, calculate_loss)

        if scheduler_type == 'Reduce':
            scheduler.step(val_metrics['iou'])
        elif scheduler is not None:
            scheduler.step()

        if val_metrics['loss'] < min_val_loss:
            min_val_loss = val_metrics['loss']
            max_val_iou = val_metrics['iou']
            torch.save(model.state_dict(), best_path)

        print('\n')
        print('Train Set')
        print(_format_metrics(train_metrics))

        print('Val Set')
        print(_format_metrics(val_metrics) + f' - best loss: {min_val_loss:0.6f} - best iou: {max_val_iou:0.6f}')
        print('------------------------------------------------------------------------------------------------------------------')

    #------------------------------------------------------------------------------------------------

    torch.save(model.state_dict(), final_path)

    #------------------------------------------------------------------------------------------------

    test_dataset = dataset(images_path=test_images_path, masks_path=test_masks_path, input_ch=input_ch, size=size, normalize=not first_norm)
    test_data = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False)

    if os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path, map_location=device))
    else:
        # Val loss never beat inf (e.g. NaN every epoch, or epochs == 0).
        print(f'No best checkpoint at {best_path}; evaluating the final weights instead')
    model.eval()
    disable_grad(model)

    with torch.no_grad():
        test_metrics = _run_epoch(model, test_data, device, calculate_loss)

    print('\n')
    print('Test Set')
    print(_format_metrics(test_metrics))
    print('------------------------------------------------------------------------------------------------------------------')


In [None]:
# Training run: dataset paths, hyper-parameters, model construction and the train_model call.
dataset_name = 'NEW_DATASET_3'
dataset_root = f"/content/{dataset_name}"

aug_train_images_path = f"{dataset_root}/AUG-TRAIN-IMAGES"
aug_train_masks_path = f"{dataset_root}/AUG-TRAIN-MASKS"

train_images_path = f"{dataset_root}/TRAIN-IMAGES"
train_masks_path = f"{dataset_root}/TRAIN-MASKS"

val_images_path = f"{dataset_root}/VAL-IMAGES"
val_labels_path = f"{dataset_root}/VAL-LABELS"
val_masks_path = f"{dataset_root}/VAL-MASKS"

test_images_path = f"{dataset_root}/TEST-IMAGES"
test_labels_path = f"{dataset_root}/TEST-LABELS"
test_masks_path = f"{dataset_root}/TEST-MASKS"

#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Run id; train_model uses it as the checkpoint directory name.
current_time = datetime.datetime.now()
time_str = current_time.strftime('%Y-%m-%d -- %H-%M-%S')
print(time_str)

#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

seed = 0
input_ch = 1
dropout = 0.1
size = (256, 256)
batch_size = 20

alpha_Focal = 3
gamma_Focal = 1
alpha_WBCE = 1
alpha_FocalTversky = 0.3
gamma_FocalTversky = 1
reduction = 'mean'

use_loss1 = True
use_loss2 = False
use_loss3 = False
use_loss4 = True
use_loss5 = False
use_loss6 = False

augmentation = True
use_aug_data = False

optimizer_type = 'SGD'
scheduler_type = 'Reduce'

first_norm = True

epochs = 50
learning_rate = 1e-1
eta_min = 1e-5
momentum = 0.9
weight_decay = 1e-4
T_max = 100

torch.cuda.empty_cache()

# Seed weight init, DataLoader shuffling and (presumably) augmentation so the run
# is reproducible; torch.manual_seed also seeds the CUDA generators.
torch.manual_seed(seed)
np.random.seed(seed)

#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

d = 32
device = torch.device('cuda')
# NOTE(review): the model is built with first_norm=False while the datasets are created with
# normalize=not first_norm (== False), so no input normalisation is visible anywhere in this
# notebook -- confirm this is intended.
model = UNet(in_channels=input_ch, d1=d, d2=2*d, d3=4*d, d4=8*d, d5=16*d, kernel1=3, padding1=1, kernel2=3, padding2=1,
             device=device, dtype=torch.float32, dropout=dropout, first_norm=False)
model.to(device)
# To resume from an earlier run, load e.g. '<MODELS dir>/<run id>/best_model.pth' here.

#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

train_model(
    input_ch=input_ch,
    model=model,
    device=device,
    augmentation=augmentation,
    use_aug_data=use_aug_data,
    time_str=time_str,
    train_batch_size=batch_size,
    val_batch_size=batch_size,
    test_batch_size=batch_size,
    size=size,
    alpha_Focal=alpha_Focal,
    gamma_Focal=gamma_Focal,
    alpha_FocalTversky=alpha_FocalTversky,
    gamma_FocalTversky=gamma_FocalTversky,
    alpha_WBCE=alpha_WBCE,
    reduction=reduction,
    use_loss1=use_loss1,
    use_loss2=use_loss2,
    use_loss3=use_loss3,
    use_loss4=use_loss4,
    use_loss5=use_loss5,
    use_loss6=use_loss6,
    first_norm=first_norm,
    optimizer_type=optimizer_type,
    scheduler_type=scheduler_type,
    epochs=epochs,
    learning_rate=learning_rate,
    eta_min=eta_min,
    momentum=momentum,
    weight_decay=weight_decay,
    T_max=T_max)


---

In [None]:
# Evaluate a saved checkpoint on one split (train / val / test) and report loss + metrics.
dataset_name = 'NEW_DATASET_3'
dataset_root = f"/content/{dataset_name}"

aug_train_images_path = f"{dataset_root}/AUG-TRAIN-IMAGES"
aug_train_masks_path = f"{dataset_root}/AUG-TRAIN-MASKS"

train_images_path = f"{dataset_root}/TRAIN-IMAGES"
train_masks_path = f"{dataset_root}/TRAIN-MASKS"

val_images_path = f"{dataset_root}/VAL-IMAGES"
val_labels_path = f"{dataset_root}/VAL-LABELS"
val_masks_path = f"{dataset_root}/VAL-MASKS"

test_images_path = f"{dataset_root}/TEST-IMAGES"
test_labels_path = f"{dataset_root}/TEST-LABELS"
test_masks_path = f"{dataset_root}/TEST-MASKS"

#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

input_ch = 1
dropout = 0.1
size = (256, 256)
batch_size = 20

alpha_Focal = 3
gamma_Focal = 1
alpha_WBCE = 1
alpha_FocalTversky = 0.3
gamma_FocalTversky = 1
reduction = 'mean'

use_loss1 = True
use_loss2 = False
use_loss3 = False
use_loss4 = True
use_loss5 = False
use_loss6 = False

augmentation = True
use_aug_data = False

optimizer_type = 'SGD'
scheduler_type = 'Reduce'

first_norm = True
torch.cuda.empty_cache()

#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

d = 32
device = torch.device('cuda')
# first_norm must match the architecture the checkpoint was trained with.
model = UNet(in_channels=input_ch, d1=d, d2=2*d, d3=4*d, d4=8*d, d5=16*d, kernel1=3, padding1=1, kernel2=3, padding2=1,
             device=device, dtype=torch.float32, dropout=dropout, first_norm=False)
model.to(device)

checkpoint_path = '/content/drive/MyDrive/DENTISTRY/MODELS/2024-07-13 -- 21-33-29/best_model.pth'
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# nn.Module starts in training mode; the model is built with dropout > 0, so without
# eval() the reported metrics would depend on random dropout masks.
model.eval()
disable_grad(model)

#--------------------------------------------------------------------------------------------------------------------------------------------

train_dataset = dataset(images_path=train_images_path, masks_path=train_masks_path, input_ch=input_ch, size=size, normalize=not first_norm)
val_dataset = dataset(images_path=val_images_path, masks_path=val_masks_path, input_ch=input_ch, size=size, normalize=not first_norm)
test_dataset = dataset(images_path=test_images_path, masks_path=test_masks_path, input_ch=input_ch, size=size, normalize=not first_norm)

train_data = DataLoader(dataset=train_dataset, batch_size=5, shuffle=False)
val_data = DataLoader(dataset=val_dataset, batch_size=5, shuffle=False)
test_data = DataLoader(dataset=test_dataset, batch_size=5, shuffle=False)

#--------------------------------------------------------------------------------------------------------------------------------------------

calculate_loss = criterion(alpha_Focal, gamma_Focal, alpha_FocalTversky, gamma_FocalTversky, alpha_WBCE, reduction,
                           use_loss1, use_loss2, use_loss3, use_loss4, use_loss5, use_loss6)

#--------------------------------------------------------------------------------------------------------------------------------------------

total_loss = 0
total_iou = 0
total_dice = 0
total_sens = 0
total_spec = 0
total_precision = 0
total_F1 = 0
total_mae = 0
num_samples = 0

#--------------------------------------------------------------------------------------------------------------------------------------------

torch.cuda.manual_seed(0)
torch.manual_seed(0)

# Index into `dataloader` / `name`: 0 = train, 1 = val, 2 = test.
i = 2
dataloader = [train_data, val_data, test_data]
name = ['train_data', 'val_data', 'test_data']

#--------------------------------------------------------------------------------------------------------------------------------------------

with torch.no_grad():
    for batch in tqdm(dataloader[i]):

        images, masks = batch['images'], batch['masks']
        images = images.to(device=device)
        masks = masks.to(device=device)

        num_samples = num_samples + len(images)
        pred = model(images)

        loss = calculate_loss(pred, masks)
        total_loss += loss * len(images)

        pred_thres = (pred >= 0.5).float()
        iou, dice, sens, spec, precision, F1, mae = calculate_metrics_for_test(pred_thres, masks)
        total_iou += iou
        total_dice += dice
        total_sens += sens
        total_spec += spec
        total_precision += precision
        total_F1 += F1
        total_mae += mae

    total_iou = total_iou.item() / num_samples
    total_dice = total_dice.item() / num_samples
    total_sens = total_sens.item() / num_samples
    total_spec = total_spec.item() / num_samples
    total_precision = total_precision.item() / num_samples
    total_F1 = total_F1.item() / num_samples
    total_mae = total_mae.item() / num_samples
    total_loss = total_loss.item() / num_samples

    print('\n')
    print(f'{name[i]} - loss: {total_loss:0.6f} - iou: {total_iou:0.6f} - dice: {total_dice:0.6f} - sens (recall): {total_sens:0.6f} - spec: {total_spec:0.6f} - precision: {total_precision:0.6f} - F1: {total_F1:0.6f} - mae: {total_mae:0.6f}')
    print('------------------------------------------------------------------------------------------------------------------')


---

In [None]:
# Visual check on the first few samples: each panel is
# (input image | ground-truth overlay | thresholded prediction in yellow).
# Uses the in-memory `model`; load a checkpoint from the run directory first to inspect a saved one.
num_to_show = 20
thresh = 0.5

# Deterministic inference: disable dropout and any other training-only behaviour.
model.eval()
disable_grad(model)

new_dataset = special_dataset(images_path=test_images_path, masks_path=test_masks_path, labels_path=test_labels_path, size=size, normalize=not first_norm)
#new_dataset = special_dataset(images_path=val_images_path,masks_path=val_masks_path, labels_path=val_labels_path,size=size, normalize=not first_norm)

# min(): the split may contain fewer than `num_to_show` images.
for i in range(min(num_to_show, len(new_dataset))):

    data, img_name = new_dataset[i]

    with torch.no_grad():
        output = model(data['images'].to(device).unsqueeze(0))
    output = output.squeeze(0, 1)

    output_thresh = (output >= thresh).float().cpu().numpy()
    yellow_output = color(data['images_view'], output_thresh, 1, 1, 0)
    red_output = data['labels']
    gray_rgb = data['images_view'].squeeze(0).unsqueeze(2).repeat(1, 1, 3)
    test = torch.cat((gray_rgb, red_output, yellow_output), dim=1)

    fig, ax = plt.subplots(figsize=(20, 20))
    fig.text(0, 0.5, img_name, fontsize=15)
    ax.imshow(test)
    # show (and release) each figure inside the loop instead of accumulating them
    plt.show()


In [None]:
# Export a (input | ground truth | prediction) panel for every image of the split:
# written to a local `output_<run id>` folder, then copied to Google Drive.
thresh = 0.5
output_folder = f"output_{time_str}"
destination_path = f'/content/drive/MyDrive/DENTISTRY/{output_folder}'
os.makedirs(output_folder, exist_ok=True)
os.makedirs(destination_path, exist_ok=True)

# Deterministic inference: disable dropout and any other training-only behaviour.
model.eval()
disable_grad(model)

new_dataset = special_dataset(images_path=test_images_path, masks_path=test_masks_path, labels_path=test_labels_path, size=size, normalize=not first_norm)
#new_dataset = special_dataset(images_path=val_images_path,masks_path=val_masks_path, labels_path=val_labels_path,size=size, normalize=not first_norm)

for i in range(len(new_dataset)):

    data, img_name = new_dataset[i]

    with torch.no_grad():
        output = model(data['images'].to(device).unsqueeze(0))
    output = output.squeeze(0, 1)

    output_thresh = (output >= thresh).float().cpu().numpy()
    yellow_output = color(data['images_view'], output_thresh, 1, 1, 0)
    red_output = data['labels']
    gray_rgb = data['images_view'].squeeze(0).unsqueeze(2).repeat(1, 1, 3)
    panel_rgb = torch.cat((gray_rgb, red_output, yellow_output), dim=1).numpy()

    # Assumes panel values lie in [0, 1] (hence the *255) -- TODO confirm for `color`/labels.
    # Convert explicitly to uint8: cvtColor rejects float64 and imwrite needs 8-bit for PNG/JPEG.
    panel_u8 = np.clip(panel_rgb * 255, 0, 255).round().astype(np.uint8)
    # The panel is RGB (prediction colour (1, 1, 0) = yellow); OpenCV writes BGR.
    panel_bgr = cv2.cvtColor(panel_u8, cv2.COLOR_RGB2BGR)

    source_file = os.path.join(output_folder, img_name)
    if not cv2.imwrite(source_file, panel_bgr):
        raise IOError(f'cv2.imwrite failed for {source_file}')
    shutil.copy(source_file, destination_path)
