In [None]:
import wandb
import os 

# Weights & Biases setup: name the notebook for wandb, provide the API key, then log in.
# NOTE(review): the right-hand side of the next assignment is missing (presumably
# redacted before sharing); the cell will not parse until a notebook name is supplied.
os.environ['WANDB_NOTEBOOK_NAME'] =
# NOTE(review): the key value is blank here, presumably redacted. Prefer providing
# WANDB_API_KEY from outside the notebook so the secret is never saved in a cell.
%env WANDB_API_KEY = 

wandb.login()

In [2]:
import time
from torch.nn import BCEWithLogitsLoss
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import gc  


INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.18 (you have 1.4.12). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.


In [3]:
from utils import (
    save_checkpoint,
    load_checkpoint,
    check_accuracy,
    save_test_images,
    check_accuracy_val,
    check_accuracy_test
)

from models import UNET
from monai.networks.nets import UNet
from train import train_epoch
from simulated_lung_dataset_gaussian import get_loaders_simulated
from simulated_lung_dataset_extra import get_loaders

In [4]:
# HYPERPARAMETERS
# -- compute target: GPU when PyTorch reports CUDA support, CPU otherwise --
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# -- optimisation --
LEARNING_RATE = 1e-4   # learning rate handed to the Adam optimizer in model_train
NUM_EPOCHS = 50        # number of training epochs per run

# -- data loading --
BATCH_SIZE = 8
NUM_WORKERS = 0
PIN_MEMORY = False     # alternative noted by the author: True if DEVICE == 'cuda' else False
IMAGE_SIZE = 256

# -- experiment tracking --
LOGGING = True

# train

In [5]:
def model_train(model_input, save_model_path, LOGGING, DEVICE):
    """Train a model for ``NUM_EPOCHS`` epochs, evaluate it on the test split and save it.

    The model is optimised with Adam (learning rate ``LEARNING_RATE``) against
    ``BCEWithLogitsLoss``. ``check_accuracy_val`` runs every 5th epoch; after the
    last epoch ``check_accuracy_test`` runs once and a checkpoint holding the
    model and optimizer state dicts is written via ``save_checkpoint``.

    Args:
        model_input: The ``torch.nn.Module`` to train (trained in place).
        save_model_path: Destination passed to ``save_checkpoint`` as ``filename``.
        LOGGING: Forwarded as ``logging=`` to the accuracy helpers.
        DEVICE: Device identifier (e.g. ``"cuda"`` or ``"cpu"``) forwarded to the
            training and evaluation helpers.

    Returns:
        None.
    """
    model = model_input
    loss_fn = BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Gradient scaling is a CUDA mixed-precision aid; a disabled scaler is a
    # pass-through, so the same train_epoch code also works for CPU runs.
    use_cuda = torch.device(DEVICE).type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    train_loader, test_loader, validation_loader = get_loaders()

    # Pre-bind ``epoch`` so the final test evaluation does not raise NameError
    # when NUM_EPOCHS < 1 and the loop body never executes.
    epoch = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        # NOTE(review): whether train_epoch / check_accuracy_val toggle
        # model.train() / model.eval() is not visible from here - confirm.
        train_epoch(train_loader, model, optimizer, loss_fn, scaler, DEVICE)
        if epoch % 5 == 0:
            check_accuracy_val(epoch, validation_loader, model, logging=LOGGING, device=DEVICE)
        # Release cached GPU blocks between epochs.
        torch.cuda.empty_cache()

    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    check_accuracy_test(epoch, test_loader, model, logging=LOGGING, device=DEVICE)
    save_checkpoint(checkpoint, filename=save_model_path)

# main

In [6]:
def model_train_log(i, model, FEATURES, IMAGE_SIZE, LOGGING, DEVICE = 'cuda'):
    """Run one training job, optionally wrapped in a Weights & Biases run.

    Args:
        i: Index of the current try; recorded in the wandb config as "try no.".
        model: The model handed to ``model_train``.
        FEATURES: Channel configuration; recorded in the wandb config only.
        IMAGE_SIZE: Image size; recorded in the wandb config only.
        LOGGING: When truthy a wandb run is opened around the training and the
            flag is forwarded to ``model_train``.
        DEVICE: Device identifier forwarded to ``model_train``.

    Returns:
        None.
    """
    # NOTE(review): the path does not depend on ``i``, so every call hands the
    # same filename to model_train - confirm that overwriting is intended.
    model_save_path = '/home/alex/Documents/new try/Data/Lung Unet/save_states/Own_Unet_gaussian_test_123.pth.tar'

    run = None
    if LOGGING:
        run = wandb.init(
            project = "Segmentation normal",
            group = "normal",
            job_type = 'normal',
            name = 'monai_UNet',

            # track hyperparameters and run metadata
            config={
                "try no." : i,
                "features" : FEATURES,
                "image_size" : IMAGE_SIZE,
                },
        )

    # Assume failure until training returns, so an exception or a kernel
    # interrupt still closes the wandb run and marks it as failed.
    exit_code = 1
    try:
        model_train(
            model,
            model_save_path,
            LOGGING,
            DEVICE)
        exit_code = 0
    finally:
        if run is not None:
            run.finish(exit_code=exit_code)
        torch.cuda.empty_cache()
   

In [7]:
def run_tests():
    """Train and log a freshly built MONAI UNet for every (try, image size, features) setting.

    The settings are visited in the order: try index (0..2), then image size,
    then feature tuple. A new network is created for each setting and handed
    to ``model_train_log`` with logging switched on.
    """
    settings = [
        (attempt, size, feats)
        for attempt in range(3)
        for size in [256]
        for feats in [(8,16, 32, 64)]
    ]

    for attempt, size, feats in settings:
        log_to_wandb = True
        network = UNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=1,
            channels=feats,
            strides=(2, 2, 2),
            bias=False,
            num_res_units=0,
        )
        network = network.to(device=DEVICE)
        model_train_log(attempt, network, feats, size, log_to_wandb, DEVICE)
                

# Launch the training runs when this code executes as the main module; the
# captured output below shows that this cell triggers them in the notebook.
if __name__ == "__main__":
    run_tests()

100%|██████████| 62/62 [00:46<00:00,  1.33it/s, loss=0.639]
100%|██████████| 62/62 [00:46<00:00,  1.35it/s, loss=0.589]
100%|██████████| 62/62 [00:46<00:00,  1.34it/s, loss=0.548]
100%|██████████| 62/62 [00:46<00:00,  1.33it/s, loss=0.523]
100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.498]


Val-Epoch: 5, Acc: 0.87, and Dice score: 0.57, IoU: 0.81, hd: 76.09


100%|██████████| 62/62 [00:46<00:00,  1.34it/s, loss=0.465]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.453]
100%|██████████| 62/62 [00:44<00:00,  1.41it/s, loss=0.425]
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.398]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.372]


Val-Epoch: 10, Acc: 0.95, and Dice score: 0.85, IoU: 0.83, hd: 25.44


100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.37] 
100%|██████████| 62/62 [00:43<00:00,  1.43it/s, loss=0.334]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.317]
100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.314]
100%|██████████| 62/62 [00:43<00:00,  1.44it/s, loss=0.292]


Val-Epoch: 15, Acc: 0.96, and Dice score: 0.88, IoU: 0.85, hd: 21.20


100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.275]
100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.275]
100%|██████████| 62/62 [00:44<00:00,  1.38it/s, loss=0.268]
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.255]
100%|██████████| 62/62 [00:45<00:00,  1.38it/s, loss=0.235]


Val-Epoch: 20, Acc: 0.97, and Dice score: 0.90, IoU: 0.88, hd: 18.32


100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.232]
100%|██████████| 62/62 [00:45<00:00,  1.38it/s, loss=0.213]
100%|██████████| 62/62 [00:46<00:00,  1.35it/s, loss=0.215]
100%|██████████| 62/62 [00:45<00:00,  1.36it/s, loss=0.196]
100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.185]


Val-Epoch: 25, Acc: 0.97, and Dice score: 0.91, IoU: 0.89, hd: 17.37


100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.182]
100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.201]
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.189]
100%|██████████| 62/62 [00:45<00:00,  1.36it/s, loss=0.168]
100%|██████████| 62/62 [00:45<00:00,  1.38it/s, loss=0.169]


Val-Epoch: 30, Acc: 0.97, and Dice score: 0.91, IoU: 0.90, hd: 16.95


100%|██████████| 62/62 [00:46<00:00,  1.33it/s, loss=0.15] 
100%|██████████| 62/62 [00:45<00:00,  1.36it/s, loss=0.14] 
100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.142]
100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.14] 
100%|██████████| 62/62 [00:46<00:00,  1.33it/s, loss=0.14] 


Val-Epoch: 35, Acc: 0.97, and Dice score: 0.92, IoU: 0.91, hd: 15.75


100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.137]
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.125]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.133]
100%|██████████| 62/62 [00:43<00:00,  1.43it/s, loss=0.131]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.118]


Val-Epoch: 40, Acc: 0.98, and Dice score: 0.93, IoU: 0.92, hd: 14.58


100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.159]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.135]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.11] 
100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.11] 
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.104] 


Val-Epoch: 45, Acc: 0.98, and Dice score: 0.93, IoU: 0.92, hd: 13.73


100%|██████████| 62/62 [00:45<00:00,  1.36it/s, loss=0.0965]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.0925]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.104] 
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.101] 
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.118] 


Val-Epoch: 50, Acc: 0.98, and Dice score: 0.93, IoU: 0.93, hd: 12.90
Epoch: 50, Acc: 0.97, and Dice score: 0.93, IoU: 0.93, hd: 13.54
=> Saving checkpoint


0,1
epoch,▁
test_acc,▁
test_box_iou,▁
test_dice,▁
test_hd,▁
test_loss,▁

0,1
epoch,50.0
test_acc,0.97474
test_box_iou,0.92908
test_dice,0.93056
test_hd,13.54497
test_loss,8.99351


100%|██████████| 62/62 [00:44<00:00,  1.41it/s, loss=0.618]
100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.552]
100%|██████████| 62/62 [00:46<00:00,  1.34it/s, loss=0.484]
100%|██████████| 62/62 [00:45<00:00,  1.36it/s, loss=0.455]
100%|██████████| 62/62 [00:45<00:00,  1.36it/s, loss=0.415]


Val-Epoch: 5, Acc: 0.93, and Dice score: 0.78, IoU: 0.79, hd: 40.59


100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.398]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.376]
100%|██████████| 62/62 [00:44<00:00,  1.38it/s, loss=0.343]
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.339]
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.313]


Val-Epoch: 10, Acc: 0.96, and Dice score: 0.86, IoU: 0.83, hd: 48.63


100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.297]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.285]
100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.271]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.258]
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.267]


Val-Epoch: 15, Acc: 0.97, and Dice score: 0.89, IoU: 0.87, hd: 21.76


100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.246]
100%|██████████| 62/62 [00:44<00:00,  1.41it/s, loss=0.224]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.227]
100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.222]
100%|██████████| 62/62 [00:44<00:00,  1.38it/s, loss=0.188]


Val-Epoch: 20, Acc: 0.97, and Dice score: 0.90, IoU: 0.88, hd: 20.31


100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.202]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.186]
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.198]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.172]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.166]


Val-Epoch: 25, Acc: 0.97, and Dice score: 0.91, IoU: 0.90, hd: 18.58


100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.163]
100%|██████████| 62/62 [00:43<00:00,  1.43it/s, loss=0.159]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.17] 
100%|██████████| 62/62 [00:44<00:00,  1.41it/s, loss=0.139]
100%|██████████| 62/62 [00:44<00:00,  1.41it/s, loss=0.14] 


Val-Epoch: 30, Acc: 0.98, and Dice score: 0.92, IoU: 0.92, hd: 15.75


100%|██████████| 62/62 [00:44<00:00,  1.38it/s, loss=0.148]
100%|██████████| 62/62 [00:45<00:00,  1.36it/s, loss=0.141]
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.127]
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.129]
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.127]


Val-Epoch: 35, Acc: 0.98, and Dice score: 0.92, IoU: 0.92, hd: 16.03


100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.122]
100%|██████████| 62/62 [00:45<00:00,  1.36it/s, loss=0.112]
100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.122]
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.103]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.108] 


Val-Epoch: 40, Acc: 0.98, and Dice score: 0.93, IoU: 0.92, hd: 13.80


100%|██████████| 62/62 [00:43<00:00,  1.43it/s, loss=0.11]  
100%|██████████| 62/62 [00:43<00:00,  1.43it/s, loss=0.094] 
100%|██████████| 62/62 [00:44<00:00,  1.41it/s, loss=0.0939]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.119] 
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.111] 


Val-Epoch: 45, Acc: 0.98, and Dice score: 0.93, IoU: 0.92, hd: 14.64


100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.118] 
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.0825]
100%|██████████| 62/62 [00:46<00:00,  1.33it/s, loss=0.0809]
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.0842]
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.102] 


Val-Epoch: 50, Acc: 0.98, and Dice score: 0.94, IoU: 0.94, hd: 14.02
Epoch: 50, Acc: 0.98, and Dice score: 0.93, IoU: 0.94, hd: 14.01
=> Saving checkpoint


0,1
epoch,▁
test_acc,▁
test_box_iou,▁
test_dice,▁
test_hd,▁
test_loss,▁

0,1
epoch,50.0
test_acc,0.97535
test_box_iou,0.93842
test_dice,0.93248
test_hd,14.00759
test_loss,9.00191


100%|██████████| 62/62 [00:43<00:00,  1.43it/s, loss=0.74] 
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.661]
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.588]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.554]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.526]


Val-Epoch: 5, Acc: 0.72, and Dice score: 0.47, IoU: 0.31, hd: 88.99


100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.52] 
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.459]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.445]
100%|██████████| 62/62 [00:44<00:00,  1.41it/s, loss=0.443]
100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.433]


Val-Epoch: 10, Acc: 0.90, and Dice score: 0.72, IoU: 0.80, hd: 91.68


100%|██████████| 62/62 [00:45<00:00,  1.36it/s, loss=0.392]
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.379]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.355]
100%|██████████| 62/62 [00:44<00:00,  1.41it/s, loss=0.334]
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.334]


Val-Epoch: 15, Acc: 0.95, and Dice score: 0.84, IoU: 0.86, hd: 74.81


100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.319]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.271]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.28] 
100%|██████████| 62/62 [00:43<00:00,  1.42it/s, loss=0.296]
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.28] 


Val-Epoch: 20, Acc: 0.97, and Dice score: 0.89, IoU: 0.88, hd: 22.31


100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.238]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.223]
100%|██████████| 62/62 [00:45<00:00,  1.35it/s, loss=0.215]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.214]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.213]


Val-Epoch: 25, Acc: 0.97, and Dice score: 0.90, IoU: 0.90, hd: 19.57


100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.191]
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.189]
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.18] 
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.182]
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.18] 


Val-Epoch: 30, Acc: 0.97, and Dice score: 0.91, IoU: 0.91, hd: 17.91


100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.162]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.163]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.145]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.147]
100%|██████████| 62/62 [00:43<00:00,  1.43it/s, loss=0.134]


Val-Epoch: 35, Acc: 0.98, and Dice score: 0.93, IoU: 0.92, hd: 16.28


100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.138]
100%|██████████| 62/62 [00:43<00:00,  1.43it/s, loss=0.151]
100%|██████████| 62/62 [00:44<00:00,  1.38it/s, loss=0.125]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.146]
100%|██████████| 62/62 [00:45<00:00,  1.38it/s, loss=0.135]


Val-Epoch: 40, Acc: 0.98, and Dice score: 0.93, IoU: 0.93, hd: 15.77


100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.111]
100%|██████████| 62/62 [00:43<00:00,  1.43it/s, loss=0.132]
100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.119]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.128]
100%|██████████| 62/62 [00:44<00:00,  1.39it/s, loss=0.12]  


Val-Epoch: 45, Acc: 0.98, and Dice score: 0.93, IoU: 0.93, hd: 14.89


100%|██████████| 62/62 [00:45<00:00,  1.38it/s, loss=0.106] 
100%|██████████| 62/62 [00:45<00:00,  1.36it/s, loss=0.0931]
100%|██████████| 62/62 [00:43<00:00,  1.41it/s, loss=0.122] 
100%|██████████| 62/62 [00:44<00:00,  1.40it/s, loss=0.093] 
100%|██████████| 62/62 [00:45<00:00,  1.37it/s, loss=0.0833]


Val-Epoch: 50, Acc: 0.98, and Dice score: 0.93, IoU: 0.93, hd: 13.53
Epoch: 50, Acc: 0.97, and Dice score: 0.93, IoU: 0.93, hd: 13.45
=> Saving checkpoint


0,1
epoch,▁
test_acc,▁
test_box_iou,▁
test_dice,▁
test_hd,▁
test_loss,▁

0,1
epoch,50.0
test_acc,0.97462
test_box_iou,0.92765
test_dice,0.92907
test_hd,13.44573
test_loss,8.9861


: 

In [None]:
# Visual inspection: rebuild the loaders and a MONAI UNet, load saved weights,
# then call the plotting helper once per batch yielded by the test loader.
from utils import plot_image_mask_box_pred_box
train_loader, test_loader, validation_loader = get_loaders()
# NOTE(review): channels here are (4, 8, 16, 32) while run_tests() trains with
# (8, 16, 32, 64); the architecture must match whatever produced
# Monai_Unet_final_1.pth.tar - confirm.
monai_UNet = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), bias = False, num_res_units = 0).to(device=DEVICE)
load_checkpoint(monai_UNet,r'/home/alex/Documents/new try/Data/Lung Unet/save_states/Monai_Unet_final_1.pth.tar')

# NOTE(review): no model.eval(), torch.no_grad() or device transfer of the batch
# is visible in this cell; presumably the plotting helper handles them - verify.
for image, target in test_loader:
    plot_image_mask_box_pred_box(monai_UNet, image, target)