In [None]:
# Training Methods

> Training, loss functions and metrics are developed here.

In [None]:
#| default_exp pytorch_training_and_loss

In [None]:
# export
#import torch
#from torch.optim.lr_scheduler import CosineAnnealingLR
#from torch.optim.swa_utils import SWALR
#from torch import nn
#from torchvision.transforms.functional import to_pil_image


In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
#| export
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    The per-element loss is ``alpha * (1 - pt) ** gamma * BCE`` where ``BCE``
    is the element-wise binary cross-entropy with logits and
    ``pt = exp(-BCE)``. ``alpha`` is applied as one uniform scale factor to
    every element (it is not switched between positive and negative targets).

    Args:
        alpha: Scalar multiplier applied to every element of the loss.
        gamma: Focusing exponent applied to ``(1 - pt)``.
        reduction: ``'mean'`` or ``'sum'`` to reduce to a scalar; any other
            value (e.g. ``'none'``) returns the unreduced element-wise loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss between logits ``inputs`` and ``targets``.

        ``targets`` must have the same shape as ``inputs``; it may be a float,
        integer or boolean tensor.
        """
        # binary_cross_entropy_with_logits only accepts floating-point targets,
        # so integer/bool masks are cast to the dtype of the logits first.
        targets = targets.to(dtype=inputs.dtype)

        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # exp(-BCE) recovers the probability the model assigns to the true class.
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        return F_loss


In [None]:
#| export
def iou_metric(preds, labels):
    """Mean smoothed intersection-over-union over a batch of masks.

    Dimension 1 of ``preds`` is squeezed away (BATCH x 1 x H x W becomes
    BATCH x H x W) and the result is combined with ``labels`` through bitwise
    AND / OR, so both tensors have to be boolean or integer typed. Counts are
    summed over dimensions 1 and 2, giving one IoU per sample, and the mean of
    those values is returned as a 0-dim tensor.
    """
    eps = 1e-6  # smoothing term so an empty union yields 1 instead of 0/0

    masks = preds.squeeze(1)
    overlap = torch.bitwise_and(masks, labels).float().sum(dim=(1, 2))
    combined = torch.bitwise_or(masks, labels).float().sum(dim=(1, 2))

    per_sample_iou = (overlap + eps) / (combined + eps)
    return per_sample_iou.mean()


In [None]:
#| export
def false_positive_negative(outputs, labels):
    """Count false positives and false negatives between two tensors.

    A false positive is an element where ``outputs`` equals 1 and ``labels``
    equals 0; a false negative is an element where ``outputs`` equals 0 and
    ``labels`` equals 1. Elements holding any other value are ignored.

    Returns:
        ``(FP, FN)`` as two 0-dim float tensors.
    """
    predicted_positive = outputs == 1
    predicted_negative = outputs == 0

    false_pos = (predicted_positive & (labels == 0)).float().sum()
    false_neg = (predicted_negative & (labels == 1)).float().sum()
    return false_pos, false_neg

In [None]:
# Instantiate the network, its AdamW optimizer and the focal-loss criterion.
# NOTE(review): AttentionUnet is neither defined nor imported in this notebook;
# it has to be in scope before this cell runs.
model = AttentionUnet(num_classes=1)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)
criterion = FocalLoss()


In [None]:
#| export
total_epochs = 10

# NOTE(review): this cell carries the export directive, so the generated module
# would run the whole training loop at import time — confirm that is intended.
# NOTE(review): `model`, `optimizer`, `criterion`, `train_loader`, `val_loader`
# and `tqdm` are expected to be in scope already; none of them is created here.

# Linear warm-up configuration.
warmup_epochs = 2  # placeholder — TODO confirm the intended warm-up length
initial_lr = optimizer.param_groups[0]['lr']  # peak LR = the LR the optimizer was built with

# Define LR scheduler (created before warm-up touches the LR, so it records the peak LR)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs, eta_min=0)


def _score_batch(model, criterion, data, target, threshold=0.5):
    """Run one batch through ``model`` and return ``(loss, iou, fp, fn)``.

    The loss is computed on the raw logits. The metrics are computed on hard
    predictions obtained by thresholding ``sigmoid(logits)``, because
    ``iou_metric`` needs integer masks and ``false_positive_negative`` compares
    against exact 0/1 values.
    """
    logits = model(data)
    # assumes logits are BATCH x 1 x H x W and target is either the same shape
    # or BATCH x H x W — TODO confirm against the data loader
    masks = target if target.dim() == logits.dim() else target.unsqueeze(1)

    loss = criterion(logits, masks.to(dtype=logits.dtype))

    # Metrics never need gradients, so work on a detached copy of the logits.
    preds = (torch.sigmoid(logits.detach()) > threshold).long()
    hard_masks = masks.long()  # presumes masks hold 0/1 values; soft labels would be truncated
    iou_score = iou_metric(preds, hard_masks.squeeze(1))
    fp, fn = false_positive_negative(preds, hard_masks)
    return loss, iou_score, fp, fn


# Training loop
for epoch in range(total_epochs):
    # Linear warm-up: set the LR *before* the epoch trains, ramping from
    # initial_lr / warmup_epochs up to initial_lr on the last warm-up epoch.
    if epoch < warmup_epochs:
        lr = initial_lr * (epoch + 1) / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    model.train()

    train_loss, train_iou_sum, train_fp_sum, train_fn_sum = 0.0, 0.0, 0.0, 0.0
    train_total = 0

    with tqdm(total=len(train_loader), desc=f'Epoch {epoch+1}/{total_epochs}', unit="batch") as pbar:

        for data, target in train_loader:
            optimizer.zero_grad()
            loss, iou_score, fp, fn = _score_batch(model, criterion, data, target)
            loss.backward()
            optimizer.step()

            # Update metrics
            train_loss += loss.item()
            train_iou_sum += iou_score.item()
            train_fp_sum += fp.item()
            train_fn_sum += fn.item()
            train_total += 1

            # Update progress bar
            pbar.set_postfix({'loss': loss.item(), 'iou': iou_score.item()})
            pbar.update()

    # Per-batch averages; max(..., 1) keeps an empty loader from dividing by zero.
    train_batches = max(train_total, 1)
    train_loss /= train_batches
    train_iou = train_iou_sum / train_batches
    train_fp = train_fp_sum / train_batches
    train_fn = train_fn_sum / train_batches

    # Validation
    model.eval()
    val_loss, val_iou_sum, val_fp_sum, val_fn_sum = 0.0, 0.0, 0.0, 0.0
    val_total = 0
    with torch.no_grad():
        for data, target in val_loader:
            loss, iou_score, fp, fn = _score_batch(model, criterion, data, target)

            # Update metrics
            val_loss += loss.item()
            val_iou_sum += iou_score.item()
            val_fp_sum += fp.item()
            val_fn_sum += fn.item()
            val_total += 1

    val_batches = max(val_total, 1)
    val_loss /= val_batches
    val_iou = val_iou_sum / val_batches
    val_fp = val_fp_sum / val_batches
    val_fn = val_fn_sum / val_batches

    # Cosine annealing takes over once the last warm-up epoch has finished;
    # stepping here sets the LR used by the next epoch.
    if epoch + 1 >= warmup_epochs:
        scheduler.step()

    # Save model with metrics in the filename
    torch.save(model.state_dict(), f'model_epoch_{epoch+1}_valiou_{val_iou:.2f}_fp_{val_fp:.0f}_fn_{val_fn:.0f}.pth')

    # Print epoch summary
    print(f'Epoch {epoch+1}/{total_epochs} - Train Loss: {train_loss:.4f}, IoU: {train_iou:.4f}, FP: {train_fp}, FN: {train_fn}')
    print(f'Validation Loss: {val_loss:.4f}, IoU: {val_iou:.4f}, FP: {val_fp}, FN: {val_fn}')


In [None]:
#| export
class SSIMLoss(nn.Module):
    """Structural-similarity loss, ``1 - SSIM(img1, img2)``.

    SSIM is computed with a Gaussian window (sigma 1.5) applied per channel
    and the usual stabilising constants ``C1 = (0.01 * L) ** 2`` and
    ``C2 = (0.03 * L) ** 2``, where ``L`` is the dynamic range of the images.

    Args:
        window_size: Side length of the square Gaussian window. It is clamped
            to the image height/width when the images are smaller.
        size_average: If True return one scalar averaged over the whole batch;
            otherwise return one loss value per batch element.
        val_range: Dynamic range ``L`` of the images. When None it is guessed
            from ``img1``: the maximum is taken as 255 if any value exceeds
            128 (else 1) and the minimum as -1 if any value is below -0.5
            (else 0).
    """

    def __init__(self, window_size=11, size_average=True, val_range=None):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

    @staticmethod
    def _gaussian_window(size, channels, sigma=1.5):
        """Build a normalised 2-D Gaussian kernel shaped (channels, 1, size, size)."""
        offsets = torch.arange(size, dtype=torch.float32) - size // 2
        kernel_1d = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
        kernel_1d = kernel_1d / kernel_1d.sum()
        kernel_2d = kernel_1d.unsqueeze(1) * kernel_1d.unsqueeze(0)
        return kernel_2d.expand(channels, 1, size, size).contiguous()

    def _dynamic_range(self, img):
        """Return ``val_range`` if given, else guess the range from ``img``."""
        if self.val_range is not None:
            return self.val_range
        max_val = 255 if img.max() > 128 else 1
        min_val = -1 if img.min() < -0.5 else 0
        return max_val - min_val

    def forward(self, img1, img2):
        """Return ``1 - SSIM`` for two image batches shaped (N, C, H, W).

        Raises:
            ValueError: If the inputs are not 4-D or their shapes differ.
        """
        if img1.dim() != 4 or img1.shape != img2.shape:
            raise ValueError(
                f'SSIMLoss expects two (N, C, H, W) tensors of equal shape, '
                f'got {tuple(img1.shape)} and {tuple(img2.shape)}'
            )
        # Convolution needs floating-point input (e.g. uint8 images are cast).
        if not img1.is_floating_point():
            img1 = img1.float()
        img2 = img2.to(dtype=img1.dtype)

        channels = img1.size(1)
        size = min(self.window_size, img1.size(2), img1.size(3))
        window = self._gaussian_window(size, channels).to(device=img1.device, dtype=img1.dtype)
        pad = size // 2

        # Local means, variances and covariance; groups=channels filters each
        # channel independently.
        mu1 = F.conv2d(img1, window, padding=pad, groups=channels)
        mu2 = F.conv2d(img2, window, padding=pad, groups=channels)
        mu1_sq, mu2_sq, mu1_mu2 = mu1 * mu1, mu2 * mu2, mu1 * mu2
        sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channels) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channels) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channels) - mu1_mu2

        dynamic_range = self._dynamic_range(img1)
        c1 = (0.01 * dynamic_range) ** 2
        c2 = (0.03 * dynamic_range) ** 2

        ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / (
            (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
        )

        if self.size_average:
            return 1 - ssim_map.mean()
        return 1 - ssim_map.mean(dim=(1, 2, 3))



In [None]:
#| hide
# Notebook-only utility: regenerates the library module from this notebook.
# It is hidden rather than exported so the generated module does not call
# nbdev_export again every time it is imported.
import nbdev; nbdev.nbdev_export('10_training.ipynb')