In [1]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
from torchmetrics import JaccardIndex
from torch import optim
from tqdm import tqdm
from unet import UNet

In [2]:
from utils.CustomDataset import load_data
from utils.loss import DiceLoss

In [3]:
train_loader, valid_loader = load_data(test_size=0.3, batch_size=1, img_size=256, dir='./data/all_data/', artificial_increase=20)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
jaccard = JaccardIndex(task='binary',num_classes=1)

In [6]:
model = UNet(in_channels=1, out_channels=1)

In [7]:
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
# criterion = nn.BCEWithLogitsLoss()
# criterion = JaccardLoss().to(device)
# criterion_val = JaccardLoss(log_loss=True).to(device)
criterion = DiceLoss()

In [8]:
def soft_jaccard_score(
    output: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 0.0,
    eps: float = 1e-7,

) -> torch.Tensor:
    intersection = torch.sum(output * target)
    cardinality = torch.sum(output + target)

    union = cardinality - intersection
    jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
    return jaccard_score

In [9]:
@torch.inference_mode()
def evaluate(model, dataloader, device):
    model.eval()
    score = 0

    # iterate over the validation set
    for batch in tqdm(dataloader, total=len(dataloader), desc= f'Validation', unit='batch', leave=False):
        image, mask_true = batch['image'], batch['mask']

        image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
        mask_true = mask_true.to(device=device, dtype=torch.long)

        mask_pred = model(image)
            
        iou_curr = soft_jaccard_score(mask_true.cpu(), mask_pred.cpu())
        score += iou_curr
            
            

    return iou_curr / max(len(dataloader), 1)

In [10]:
def train_model(
    model,
    device,
    criterion, 
    train_loader,
    val_loader,
    epochs: int=10,
    learning_rate: float=1e-5,
    img_scale: float=0.5, 
):
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)
    # criterion = nn.BCEWithLogitsLoss()
    # criterion = JaccardLoss().to(device)
    
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0
        train_score = 0
        with tqdm(total=len(train_loader), desc=f'Epoch {epoch}/{epochs}', unit='img') as epoch_bar:
            for batch in train_loader:
                images, mask = batch['image'], batch['mask']
                images = images.to(device, dtype=torch.float32)
                mask = mask.to(device, dtype=torch.long)
                
                optimizer.zero_grad()
                
                output = model(images)
                loss = criterion(output.squeeze(1), mask.squeeze(1).float())
                    
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                # grad_scaler.step(optimizer)
                # grad_scaler.update()
                optimizer.step()
                
                train_loss += loss.item()
                epoch_bar.update()
            
                iou_curr = soft_jaccard_score(mask.cpu(), output.cpu())
                train_score += iou_curr
                
                epoch_bar.set_postfix(**{'loss (batch)': loss.item()})
                
        model.eval()
        val_score = evaluate(model, val_loader, device)
        print(f'Train {epoch} epoch with iou-score:{train_score/len(train_loader)}')
        print(f'Valid {epoch} epoch with iou-score:{val_score}')
        scheduler.step(val_score)
        
                
    

In [11]:
train_model(model, device, criterion, train_loader, valid_loader)

Epoch 1/10: 100%|██████████| 266/266 [00:11<00:00, 24.13img/s, loss (batch)=0.00397]
                                                                

Train 1 epoch with iou-score:0.007801605388522148
Valid 1 epoch with iou-score:7.312961315619759e-06


Epoch 2/10: 100%|██████████| 266/266 [00:08<00:00, 30.02img/s, loss (batch)=0.00089]
                                                               

Train 2 epoch with iou-score:0.007533056195825338
Valid 2 epoch with iou-score:7.27738051864435e-06


Epoch 3/10: 100%|██████████| 266/266 [00:08<00:00, 30.18img/s, loss (batch)=0.000624]
                                                                

Train 3 epoch with iou-score:0.00750366086140275
Valid 3 epoch with iou-score:7.22636877981131e-06


Epoch 4/10: 100%|██████████| 266/266 [00:08<00:00, 30.23img/s, loss (batch)=0.000439]
                                                                

Train 4 epoch with iou-score:0.0074931420385837555
Valid 4 epoch with iou-score:7.2292609729629476e-06


Epoch 5/10: 100%|██████████| 266/266 [00:08<00:00, 30.66img/s, loss (batch)=0.000284]
                                                                

Train 5 epoch with iou-score:0.00749823497608304
Valid 5 epoch with iou-score:7.277587428689003e-06


Epoch 6/10: 100%|██████████| 266/266 [00:08<00:00, 30.37img/s, loss (batch)=0.000203]
                                                      

KeyboardInterrupt: 