In [1]:
# Mount Google Drive so checkpoints can be written under /content/drive/MyDrive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Install TorchIO into the Colab runtime (IPython "!" shell escape, not plain Python).
!pip install torchio

In [None]:
import os
import nibabel as nib
import numpy as np
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchio as tio
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from scipy.ndimage import rotate
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset, Dataset
# Reproducibility
def set_seed(seed=42, deterministic_algorithms=True):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for determinism.

    Args:
        seed: Seed passed to ``random``, ``numpy`` and ``torch`` (CPU and all GPUs).
        deterministic_algorithms: Forwarded to ``torch.use_deterministic_algorithms``.
            Strict mode makes PyTorch raise on ops that have no deterministic
            implementation, so pass ``False`` when the model uses such an op.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(deterministic_algorithms)


# Global configuration
CONFIG = {
    "seed": 42,
    "batch_size": 24,
    "epochs": 20,
    "learning_rate": 1.67e-5,
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}

# Strict deterministic mode stays off. Per the PyTorch docs, differentiating
# F.interpolate(mode='trilinear') on CUDA (used by the U-Net decoder in this
# notebook) raises under torch.use_deterministic_algorithms(True).
# cuDNN determinism is still enabled by set_seed.
set_seed(CONFIG["seed"], deterministic_algorithms=False)

device = torch.device(CONFIG["device"])


# Focal Loss
class FocalLoss(nn.Module):
    """Binary focal loss computed on probabilities.

    ``preds`` are passed straight to ``F.binary_cross_entropy``, so they must
    already be probabilities in [0, 1] (e.g. sigmoid outputs).

    Args:
        alpha: Scalar multiplier applied uniformly to every element's loss.
            It is not a class-dependent (alpha_t) weight.
        gamma: Focusing exponent; larger values down-weight well-classified elements.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (unreduced per-element loss).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, preds, targets):
        targets = targets.float()  # BCE requires floating-point targets
        bce_loss = F.binary_cross_entropy(preds, targets, reduction="none")
        # For binary targets BCE = -log(p_t), so exp(-BCE) recovers p_t.
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss  # 'none'

# Dice Loss
class WeightedDiceLoss(nn.Module):
    """Soft Dice loss over the whole batch, optionally class-weighted.

    All elements of the batch are flattened together, so the Dice coefficient is
    computed globally rather than per sample.

    Args:
        weight: Optional indexable pair ``(w_background, w_foreground)`` (e.g. a
            list or 1-D tensor). When given, the loss is
            ``w_fg * (1 - dice_fg) + w_bg * (1 - dice_bg)``, where ``dice_bg``
            is the Dice of the complements ``1 - pred`` and ``1 - target``.
            When ``None``, the loss is ``1 - dice_fg``.
        smooth: Additive smoothing in numerator and denominator. It keeps the
            result finite when both masks are empty.
    """

    def __init__(self, weight=None, smooth=1e-6):
        super(WeightedDiceLoss, self).__init__()
        self.weight = weight
        self.smooth = smooth

    def _soft_dice(self, pred, target):
        # Smoothed soft Dice coefficient of two flattened tensors.
        intersection = (pred * target).sum()
        return (2. * intersection + self.smooth) / (
            pred.sum() + target.sum() + self.smooth
        )

    def forward(self, pred, target):
        pred = pred.contiguous().view(-1)
        # Cast so integer/bool masks work with the complement (1 - target) below.
        target = target.contiguous().view(-1).to(pred.dtype)

        dice_fg = self._soft_dice(pred, target)
        if self.weight is None:
            return 1 - dice_fg

        # Every weighted term has the form (1 - Dice), so each one shrinks as
        # overlap improves and a perfect prediction scores 0.
        dice_bg = self._soft_dice(1 - pred, 1 - target)
        return self.weight[1] * (1 - dice_fg) + self.weight[0] * (1 - dice_bg)


# Combined loss: binary cross-entropy + Dice (the CE term is plain BCE, not focal)
class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy and the (weighted) soft Dice loss.

    ``preds`` must be probabilities because ``nn.BCELoss`` is used; a model
    that emits raw logits would need ``nn.BCEWithLogitsLoss`` instead.

    Args:
        weight: Forwarded unchanged to ``WeightedDiceLoss``.
        ce_weight: Coefficient of the BCE term.
        dice_weight: Coefficient of the Dice term.
    """

    def __init__(self, weight=None, ce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.ce_weight, self.dice_weight = ce_weight, dice_weight
        self.ce = nn.BCELoss()
        self.dice = WeightedDiceLoss(weight=weight)

    def forward(self, preds, targets):
        bce_term = self.ce_weight * self.ce(preds, targets.float())
        dice_term = self.dice_weight * self.dice(preds, targets)
        return bce_term + dice_term

# Dice score for validation
def dice_score(preds, targets, threshold=0.5, eps=1e-8):
    """Hard Dice coefficient between thresholded predictions and a mask.

    All elements of the inputs are scored together (a batch counts as one
    volume, not an average of per-sample scores).

    Args:
        preds: Predicted probabilities; elements ``> threshold`` are foreground.
        targets: Ground-truth mask (assumed binary).
        threshold: Binarisation threshold applied to ``preds``.
        eps: Guard added to the denominator against division by zero.

    Returns:
        A 0-dim tensor. When both the prediction and the target are empty, the
        prediction is perfect and 1.0 is returned (rather than 0.0). This keeps
        correct "nothing there" predictions on background patches from lowering
        the score.
    """
    preds = (preds > threshold).float()
    intersection = (preds * targets).sum()
    union = preds.sum() + targets.sum()
    dice = (2. * intersection) / (union + eps)
    return torch.where(union > 0, dice, torch.ones_like(dice))

# Dataset wrapper around an in-memory list of (image, mask) patches
class PatchDataset(Dataset):
    """Map-style dataset over a list of ``(image, mask)`` pairs.

    Each element is copied into a new float32 tensor, and a leading channel
    axis of size 1 is prepended.
    """

    def __init__(self, patches):
        self.patches = patches

    def __len__(self):
        return len(self.patches)

    @staticmethod
    def _to_channel_tensor(array):
        # Copy into a tensor, prepend a singleton channel dim, cast to float32.
        return torch.tensor(array).unsqueeze(0).float()

    def __getitem__(self, idx):
        image, mask = self.patches[idx]
        return self._to_channel_tensor(image), self._to_channel_tensor(mask)




class AttentionBlock3D(nn.Module):
    """Additive attention gate for 3-D feature maps.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels by a 1x1x1 conv + BatchNorm. The projections are summed,
    passed through ReLU, and reduced to a single-channel sigmoid map that
    rescales ``x``. ``g`` and ``x`` must share spatial size because their
    projections are added element-wise.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip features ``x``.
        F_int: Intermediate channel count.
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock3D, self).__init__()

        def project(in_ch, out_ch):
            # 1x1x1 convolution followed by batch norm.
            return [
                nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm3d(out_ch),
            ]

        # Attribute names and Sequential indices define the state_dict keys.
        self.W_g = nn.Sequential(*project(F_g, F_int))
        self.W_x = nn.Sequential(*project(F_l, F_int))
        self.psi = nn.Sequential(*project(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        joint = self.relu(self.W_g(g) + self.W_x(x))
        attention_map = self.psi(joint)  # values in (0, 1) after the sigmoid
        return x * attention_map

class UNet3D_Attention(nn.Module):
    """3-D U-Net with an attention gate on every skip connection.

    Encoder: four conv blocks (32, 64, 128, 256 channels) separated by 2x max
    pooling, followed by a 512-channel bottleneck. Decoder, at each level:
    resize the deeper features trilinearly to the skip's spatial size, re-weight
    the skip with an ``AttentionBlock3D`` gated on the resized features,
    concatenate the two, and fuse them with a conv block. A final 1x1x1 conv and
    a sigmoid produce a single-channel probability map.

    Args:
        dropout_rate: ``Dropout3d`` probability applied after every conv of every block.

    The first conv has one input channel. The four poolings mean each spatial
    dim must be at least 16, otherwise the last pooling produces an empty
    tensor. TODO: confirm against the patch size used upstream.
    """

    def __init__(self, dropout_rate=0.3):
        super(UNet3D_Attention, self).__init__()

        def conv_block(c_in, c_out):
            # (Conv3d 3x3x3 -> BN -> ReLU -> Dropout3d) twice. Sequential indices
            # 0..7 appear in checkpoint keys, so this layout must not change.
            layers = []
            for src in (c_in, c_out):
                layers += [
                    nn.Conv3d(src, c_out, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm3d(c_out),
                    nn.ReLU(inplace=True),
                    nn.Dropout3d(p=dropout_rate),
                ]
            return nn.Sequential(*layers)

        self.pool = nn.MaxPool3d(2)

        # Encoder
        self.enc1 = conv_block(1, 32)
        self.enc2 = conv_block(32, 64)
        self.enc3 = conv_block(64, 128)
        self.enc4 = conv_block(128, 256)

        self.bottleneck = conv_block(256, 512)

        # Attention gates: F_g = channels of the upsampled decoder features,
        # F_l = channels of the matching encoder skip.
        self.att4 = AttentionBlock3D(F_g=512, F_l=256, F_int=128)
        self.att3 = AttentionBlock3D(F_g=256, F_l=128, F_int=64)
        self.att2 = AttentionBlock3D(F_g=128, F_l=64, F_int=32)
        self.att1 = AttentionBlock3D(F_g=64, F_l=32, F_int=16)

        # Decoder: input = upsampled features concatenated with the gated skip
        self.dec4 = conv_block(512 + 256, 256)
        self.dec3 = conv_block(256 + 128, 128)
        self.dec2 = conv_block(128 + 64, 64)
        self.dec1 = conv_block(64 + 32, 32)

        self.final = nn.Conv3d(32, 1, kernel_size=1)

    def center_crop(self, enc_feat, target_size):
        """Return the centred ``(D, H, W) = target_size`` window of a 5-D tensor."""
        _, _, depth, height, width = enc_feat.size()
        out_d, out_h, out_w = target_size
        z0 = (depth - out_d) // 2
        y0 = (height - out_h) // 2
        x0 = (width - out_w) // 2
        return enc_feat[:, :, z0:z0 + out_d, y0:y0 + out_h, x0:x0 + out_w]

    def forward(self, x):
        # Encoder: keep every level's output as a skip connection.
        skips = []
        features = x
        for level, encoder in enumerate((self.enc1, self.enc2, self.enc3, self.enc4)):
            features = encoder(features if level == 0 else self.pool(features))
            skips.append(features)

        features = self.bottleneck(self.pool(features))

        # Decoder: deepest skip first.
        stages = (
            (self.att4, self.dec4),
            (self.att3, self.dec3),
            (self.att2, self.dec2),
            (self.att1, self.dec1),
        )
        for skip, (gate, decoder) in zip(reversed(skips), stages):
            upsampled = F.interpolate(
                features, size=skip.shape[2:], mode='trilinear', align_corners=True
            )
            # center_crop is a no-op here: upsampled already has the skip's size.
            gated_skip = gate(g=upsampled, x=self.center_crop(skip, upsampled.shape[2:]))
            features = decoder(torch.cat([upsampled, gated_skip], dim=1))

        return torch.sigmoid(self.final(features))

# Training pass for one epoch
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(avg_loss, avg_dice)``:
        - ``avg_loss`` is the criterion averaged per sample (each batch loss is
          weighted by its batch size).
        - ``avg_dice`` is the unweighted mean of the per-batch ``dice_score``
          values, computed on the training-mode outputs.
    """
    model.train()
    loss_sum = 0
    batch_dices = []

    for images, targets in loader:
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        predictions = model(images)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * images.size(0)
        with torch.no_grad():
            batch_dices.append(dice_score(predictions, targets).item())

    return loss_sum / len(loader.dataset), np.mean(batch_dices)

# Validation pass
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    Returns:
        ``(avg_loss, avg_dice)``:
        - ``avg_loss`` is the batch-size-weighted mean of the criterion.
        - ``avg_dice`` is the unweighted mean of the per-batch ``dice_score``
          values.
    """
    model.eval()
    running_loss = 0
    scores = []

    with torch.no_grad():
        for batch_imgs, batch_masks in loader:
            batch_imgs, batch_masks = batch_imgs.to(device), batch_masks.to(device)
            probs = model(batch_imgs)
            running_loss += criterion(probs, batch_masks).item() * batch_imgs.size(0)
            scores.append(dice_score(probs, batch_masks).item())

    mean_loss = running_loss / len(loader.dataset)
    return mean_loss, np.mean(scores)

# K-fold cross-validation with early stopping
def run_kfold_cv(patch_dataset, num_folds=5, num_epochs=50, batch_size=2, lr=1e-4,
                 patience=10, device='cuda', save_dir="/content/drive/MyDrive"):
    """Train a fresh ``UNet3D_Attention`` per fold and checkpoint its best epoch.

    Args:
        patch_dataset: Dataset split with ``KFold`` (shuffled, ``random_state=42``).
        num_folds: Number of folds.
        num_epochs: Maximum epochs per fold.
        batch_size: Training batch size (validation always uses batch size 1).
        lr: Adam learning rate.
        patience: Consecutive epochs without a validation-Dice improvement
            before training stops early.
        device: Device for the model and the batches.
        save_dir: Directory that receives ``best_model_fold_{k}.pt`` for each fold.
            The checkpoints hold CPU tensors.

    Returns:
        List with the best validation Dice of each fold.

    NOTE(review): folds are split at the patch level. If several patches come
    from the same scan or patient, they can land in both the train and the
    validation part of a fold, which inflates validation Dice. Consider a
    grouped split (e.g. sklearn's GroupKFold).
    """
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)
    all_fold_scores = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(patch_dataset)):
        print(f"\n🔁 Fold {fold+1}/{num_folds}")
        train_set = Subset(patch_dataset, train_idx)
        val_set = Subset(patch_dataset, val_idx)

        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

        model = UNet3D_Attention().to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.BCELoss()

        best_dice = 0.0
        patience_counter = 0
        best_model_state = None

        for epoch in range(num_epochs):
            train_loss, train_dice = train_one_epoch(model, train_loader, optimizer, criterion, device)
            val_loss, val_dice = validate(model, val_loader, criterion, device)

            print(f"Epoch {epoch+1}: "
                  f"Train Loss = {train_loss:.4f}, Train Dice = {train_dice:.4f} | "
                  f"Val Loss = {val_loss:.4f}, Val Dice = {val_dice:.4f}")

            if val_dice > best_dice:
                best_dice = val_dice
                patience_counter = 0
                # state_dict() returns references to the live parameters, which
                # keep changing as training continues. Clone them (onto the CPU)
                # so the checkpoint really holds this epoch's weights.
                best_model_state = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in model.state_dict().items()
                }
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"⏹️ Early stopping triggered at epoch {epoch+1}")
                    break

        all_fold_scores.append(best_dice)
        if best_model_state is None:
            # No epoch beat the initial Dice of 0.0, so there is no best state to save.
            print(f"⚠️ Fold {fold+1}: validation Dice never exceeded 0.0; no checkpoint saved")
        else:
            torch.save(best_model_state, os.path.join(save_dir, f"best_model_fold_{fold+1}.pt"))
        print(f"✅ Fold {fold+1} Best Dice: {best_dice:.4f}")

    avg_dice = np.mean(all_fold_scores)
    print(f"\n📊 Final Dice Score AVG over {num_folds} folds: {avg_dice:.4f}")
    return all_fold_scores


# Build the dataset from the pre-extracted patches and run 5-fold CV.
# NOTE(review): all_patches is not defined in the visible cells; it is assumed
# to be a list of (image, mask) array pairs produced by an earlier cell.
dataset = PatchDataset(all_patches)
# Use the device chosen at configuration time rather than hard-coding 'cuda',
# so the notebook does not crash on a CPU-only runtime.
fold_scores = run_kfold_cv(dataset, num_folds=5, num_epochs=50, batch_size=2, patience=10, device=device)