In [None]:
import os
import gc
import time
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from sklearn.metrics import f1_score, jaccard_score, precision_score, recall_score, accuracy_score

from albumentations import Compose, HorizontalFlip, VerticalFlip, ColorJitter, Affine
from albumentations.pytorch import ToTensorV2

from ModelArchitecture.RoSeg import RoSeg
from ModelArchitecture.DiceLoss import dice_loss
from ImageLoader2D import load_data

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pandas as pd

import os
import random

In [None]:
class CustomDataset(Dataset):
    """Index-aligned image/mask dataset with an optional augmentation callable.

    Args:
        images: indexable collection of image tensors; each item is permuted
            with ``(1, 2, 0)`` before augmentation.
        masks: indexable collection of mask tensors, aligned with ``images``.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, images, masks, transform=None):
        self.images, self.masks, self.transform = images, masks, transform

    def __len__(self):
        """Number of samples, taken from the image collection."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx``.

        Without a transform the stored tensors are returned untouched. With a
        transform, the pair is converted to NumPy (image permuted by
        ``(1, 2, 0)``, mask additionally squeezed of all size-1 axes) and the
        transform's ``'image'`` / ``'mask'`` outputs are returned instead.
        """
        sample_image, sample_mask = self.images[idx], self.masks[idx]

        # The NumPy views are built unconditionally, before the transform check.
        image_array = sample_image.permute(1, 2, 0).cpu().numpy()
        mask_array = np.squeeze(sample_mask.permute(1, 2, 0).cpu().numpy())

        if not self.transform:
            return sample_image, sample_mask

        result = self.transform(image=image_array, mask=mask_array)
        return result['image'], result['mask']
    
def polynomial_decay_lambda(epoch):
    """Multiplicative LR factor for a polynomial decay schedule.

    Reads the module-level ``starter_learning_rate``, ``end_learning_rate``,
    ``decay_steps`` and ``decay_power``. The returned value is the target
    learning rate divided by ``starter_learning_rate``, i.e. the factor form
    expected by ``LambdaLR``.

    Args:
        epoch: step index supplied by the scheduler.

    Returns:
        The decayed factor while ``epoch < decay_steps``; afterwards the
        constant ``end_learning_rate / starter_learning_rate``.
    """
    # Past the decay horizon the rate stays at its floor.
    if epoch >= decay_steps:
        return end_learning_rate / starter_learning_rate

    lr_span = starter_learning_rate - end_learning_rate
    remaining = (1 - epoch / decay_steps) ** decay_power
    return (lr_span * remaining + end_learning_rate) / starter_learning_rate

In [None]:
# ---- Augmentation pipelines ----
# `p=1.0` replaces the deprecated `always_apply=True`. On albumentations
# releases that ignore `always_apply`, the transform would otherwise fall back
# to its default probability instead of always running.
train_transform = Compose([
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
    ColorJitter(brightness=0.6, contrast=0.2, saturation=0.1, hue=0.01, p=1.0),
    # NOTE(review): the shear range (-22.5, 22) is asymmetric; possibly a typo
    # for 22.5 -- left unchanged, confirm the intended value.
    Affine(scale=(0.5, 1.5), translate_percent=(-0.125, 0.125), rotate=(-180, 180), shear=(-22.5, 22), p=1.0),
    ToTensorV2()
])

# Validation/test samples are only converted to tensors, never augmented.
val_transform = Compose([ToTensorV2()])

# ---- Experiment configuration ----
img_size = 352
dataset_type = 'cvc-clinicdb' # Options: kvasir/cvc-clinicdb/cvc-colondb/etis-laribpolypdb

# Hyperparameters read by `polynomial_decay_lambda` at call time.
starter_learning_rate = 1e-4
end_learning_rate = 1e-6
decay_steps = 1000
decay_power = 0.2

batch_size = 8
epochs = 300
seed_value = 58800
starting_filters = 17
min_loss_for_saving = 0.1  # not referenced by the visible cells
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Reproducibility ----
# NOTE(review): depending on the albumentations version, augmentation
# randomness may not be governed by these global seeds -- confirm.
torch.manual_seed(seed_value)
np.random.seed(seed_value)
random.seed(seed_value)

# Dedicated generator handed to the training DataLoader for shuffling.
g = torch.Generator()
g.manual_seed(seed_value)

# ---- Output locations ----
# The timestamp makes every run write to its own checkpoint file.
ct = datetime.now().strftime('%Y%m%d-%H%M%S')
model_type = "RoSeg"

base_dir = os.getcwd()
model_path = os.path.join(base_dir, f'ModelSave/{dataset_type}/{model_type}_{ct}_v2.pt')

In [None]:
# Create the checkpoint directory. Deriving it from `model_path` keeps the
# folder and the save target in sync if the path template ever changes.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

In [None]:
# ---- Load data ----
# NOTE(review): the dataset name and path are hard-coded here instead of being
# derived from `dataset_type`; changing `dataset_type` alone only renames the
# outputs -- confirm this is intended.
x_data, y_data = load_data(img_size, img_size, -1, 'cvc-clinicdb', "../datasets/CVC-ClinicDB/CVC-ClinicDB/")

# ---- 80 / 10 / 10 split ----
total_size = len(x_data)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

# A dedicated, freshly seeded generator keeps the split identical across runs
# and cell re-executions, regardless of how much of the global RNG stream was
# consumed beforehand.
split_generator = torch.Generator().manual_seed(seed_value)
indices = torch.randperm(total_size, generator=split_generator)

x_data = x_data[indices]
y_data = y_data[indices]

x_train, y_train = x_data[:train_size], y_data[:train_size]
x_val, y_val = x_data[train_size:train_size + val_size], y_data[train_size:train_size + val_size]
x_test, y_test = x_data[train_size + val_size:], y_data[train_size + val_size:]

train_dataset = CustomDataset(x_train, y_train, transform=train_transform)
val_dataset = CustomDataset(x_val, y_val, transform=val_transform)
test_dataset = CustomDataset(x_test, y_test, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, generator=g)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Fail early with a clear message instead of a ZeroDivisionError after the
# first epoch (train_loader drops the last partial batch).
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError(
        f"Not enough data to train: {len(train_dataset)} train / {len(val_dataset)} val samples "
        f"with batch_size={batch_size}."
    )

# ---- CSV training log ----
log_path = os.path.join(base_dir, f"ProgressFull/{dataset_type}_training_log_{model_type}_{ct}_v2.csv")
# The log directory is not created anywhere else; without this the first
# to_csv call fails on a fresh checkout.
os.makedirs(os.path.dirname(log_path), exist_ok=True)
log_df = pd.DataFrame(columns=["epoch", "train_loss", "val_loss", "lr"])
log_df.to_csv(log_path, index=False)

# ---- Model, optimizer, scheduler ----
model = RoSeg(input_channels=3, out_classes=1, starting_filters=starting_filters).to(device)
# NOTE(review): weight_decay reuses the starting learning rate value -- confirm this is deliberate.
optimizer = optim.AdamW(model.parameters(), lr=starter_learning_rate, weight_decay=starter_learning_rate)
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=polynomial_decay_lambda)

loss_fn = dice_loss

# ---- Training loop ----
# NOTE(review): masks leave CustomDataset squeezed of size-1 axes; assumes
# dice_loss handles any shape difference against the model output -- confirm.
best_val_loss = float('inf')
for epoch in range(epochs):
    model.train()
    total_train_loss = 0
    num_train_batches = 0

    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(x_batch)
        loss = loss_fn(outputs, y_batch)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        num_train_batches += 1

    avg_train_loss = total_train_loss / num_train_batches

    # Record the learning rate that was actually used for this epoch before
    # the scheduler advances it for the next one.
    current_lr = optimizer.param_groups[0]['lr']
    lr_scheduler.step()

    model.eval()
    val_loss = 0
    num_val_batches = 0

    with torch.no_grad():
        # Distinct loop names so the x_val / y_val split tensors are not
        # overwritten by the last validation batch.
        for x_val_batch, y_val_batch in val_loader:
            x_val_batch, y_val_batch = x_val_batch.to(device), y_val_batch.to(device)
            outputs = model(x_val_batch)
            loss = loss_fn(outputs, y_val_batch)
            val_loss += loss.item()
            num_val_batches += 1

    avg_val_loss = val_loss / num_val_batches

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, LR: {current_lr:.6f}")

    # Append one row per epoch so progress survives an interrupted run.
    new_row = {"epoch": epoch+1, "train_loss": avg_train_loss, "val_loss": avg_val_loss, "lr": current_lr}
    pd.DataFrame([new_row]).to_csv(log_path, mode='a', header=False, index=False)

    # Keep only the checkpoint with the lowest validation loss so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), model_path)
        print(f"Model saved with val_loss {avg_val_loss:.4f}")
    gc.collect()

In [None]:
def evaluate(model, dataloader, threshold=0.5):
    """Compute pixel-level binary segmentation metrics over a dataloader.

    The model is put in eval mode and run without gradients on the
    module-level ``device``. Its outputs are binarised with ``> threshold``;
    labels are treated as foreground wherever they are non-zero. All pixels of
    all batches are pooled before the metrics are computed.

    Args:
        model: callable module mapping an input batch to per-pixel scores.
            Assumes the scores are already in a range where ``threshold``
            is meaningful (e.g. probabilities) -- TODO confirm against RoSeg.
        dataloader: iterable yielding ``(inputs, labels)`` tensor pairs.
        threshold: cut-off applied to the model output (default 0.5, the
            previously hard-coded value).

    Returns:
        dict with keys ``'dice'``, ``'iou'``, ``'precision'``, ``'recall'``
        and ``'accuracy'``.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for x_batch, y_batch in dataloader:
            scores = model(x_batch.to(device)).cpu()
            # Store flat boolean chunks: smaller than keeping float labels
            # around and avoids re-flattening for every metric.
            pred_chunks.append((scores > threshold).reshape(-1))
            label_chunks.append(y_batch.cpu().bool().reshape(-1))

    if not pred_chunks:
        raise ValueError("evaluate() received an empty dataloader; nothing to score.")

    preds = torch.cat(pred_chunks).numpy()
    labels = torch.cat(label_chunks).numpy()

    return {
        'dice': f1_score(labels, preds),
        'iou': jaccard_score(labels, preds),
        'precision': precision_score(labels, preds),
        'recall': recall_score(labels, preds),
        'accuracy': accuracy_score(labels, preds)
    }

In [None]:
# Reload the best checkpoint written during training and score it on the
# held-out test split. `map_location` lets a checkpoint saved on GPU be loaded
# on a CPU-only machine (and vice versa).
# NOTE: `model_path` contains the run timestamp from the config cell, so this
# cell only finds the checkpoint within the same kernel session as training.
model.load_state_dict(torch.load(model_path, map_location=device))
metrics = evaluate(model, test_loader)

print("Test Metrics:")
for k, v in metrics.items():
    print(f"{k}: {v:.4f}")