In [None]:
!pip install monai > /dev/null
!pip install itk > /dev/null
!pip install ltntorch > /dev/null

In [None]:
def nested(hard):
    """Run ``sample_and_check`` with label 1 as segment endpoints and label 2 as the target.

    Args:
    - hard: integer label map; presumably (b, h, w, d) as produced by an argmax
      over the class dimension (TODO confirm for other callers).

    Returns:
    - the result of ``sample_and_check`` on the two binary masks, each given a
      singleton channel dimension at position 1.
    """
    endpoint_mask, target_mask = ((hard == label).unsqueeze(1) for label in (1, 2))
    return sample_and_check(endpoint_mask, target_mask)

def sample_and_check(mask1, mask2, num_samples=20, num_steps=50):
    """Flag batch items where a straight segment between two mask1 voxels crosses mask2.

    For every batch item, ``num_samples`` (src, dst) pairs are drawn uniformly,
    with replacement, from the voxels where ``mask1 == 1``. Each segment is
    discretised into ``num_steps`` evenly spaced points (endpoints included).
    Those points are rounded to the nearest voxel and clamped to the volume.
    The item is flagged when any such point has ``mask2 == 1`` in channel 0.

    Args:
    - mask1: tensor of shape (b, c, h, w, d); only the spatial position of
      entries equal to 1 is used (the channel index is dropped).
    - mask2: tensor of shape (b, c, h, w, d); only channel 0 is inspected.
    - num_samples: number of (src, dst) pairs drawn per batch item.
    - num_steps: number of interpolation points per segment.

    Returns:
    - int32 tensor of shape (b,) on ``mask1.device``. It holds 1 when at least one
      sampled segment hits mask2 and 0 otherwise. Items with fewer than two
      mask1 voxels get 0.
    """
    b, _, h, w, d = mask1.shape
    dev = mask1.device

    counts = torch.zeros(b, dtype=torch.int, device=dev)
    # (1, num_steps, 1) so it broadcasts over pairs and the 3 spatial axes.
    steps = torch.linspace(0, 1, num_steps, device=dev).view(1, -1, 1)
    upper = torch.tensor([h - 1, w - 1, d - 1], device=dev)

    for i in range(b):
        candidates = torch.nonzero(mask1[i] == 1)  # (N, 4): (channel, h, w, d)
        if candidates.shape[0] < 2:
            continue  # need at least two distinct voxels to form a segment

        pairs = candidates[torch.randint(0, candidates.shape[0], (num_samples, 2))]
        # Drop the channel coordinate and keep (h, w, d).
        src = pairs[:, 0, 1:].float().unsqueeze(1)  # (num_samples, 1, 3)
        dst = pairs[:, 1, 1:].float().unsqueeze(1)  # (num_samples, 1, 3)

        points = src + steps * (dst - src)  # (num_samples, num_steps, 3)
        points = torch.minimum(torch.round(points).long().clamp(min=0), upper)

        hits = mask2[i, 0, points[..., 0], points[..., 1], points[..., 2]] == 1
        if bool(hits.any()):
            counts[i] = 1

    return counts


import torch
from monai.metrics import DiceMetric
from monai.apps import DecathlonDataset
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Resized
from monai.data import partition_dataset
from torch.utils.data import DataLoader, random_split
from monai.networks.nets import SwinUNETR
from monai.losses import DiceLoss
from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm
import ltn

# Module-level Dice accumulator used by `evaluate`. The background class is
# included in the mean, and aggregate() also returns the count of non-NaN entries.
dice_metric = DiceMetric(
    reduction="mean",
    get_not_nans=True,
    include_background=True,
)


def evaluate(model, val_loader, device='cuda'):
    """Return the mean Dice score of ``model`` on ``val_loader``.

    The argmax predictions and the labels are both one-hot encoded to
    ``(B, C, *spatial)`` before being passed to the module-level ``dice_metric``.
    That metric expects one-hot / multi-channel input. Passing bare
    ``(B, H, W, D)`` label maps would make it read the H axis as the class axis.

    Args:
    - model: network returning per-class scores of shape (B, C, ...).
    - val_loader: iterable of dicts with "image" and "label" tensors. Labels are
      assumed to be integer-valued maps with a singleton channel at dim 1
      (TODO confirm).
    - device: device the batches are moved to.

    Returns:
    - the aggregated Dice score as a Python float.
    """

    def to_one_hot(label_map, num_classes):
        # (B, *spatial) integer map -> (B, num_classes, *spatial) float one-hot.
        one_hot = torch.nn.functional.one_hot(label_map.long(), num_classes)
        return one_hot.movedim(-1, 1).float()

    model.eval()
    dice_metric.reset()
    with torch.no_grad():
        for batch in val_loader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device).squeeze(1)
            outputs = model(images)
            num_classes = outputs.shape[1]
            preds = torch.argmax(outputs, dim=1)
            dice_metric(to_one_hot(preds, num_classes), to_one_hot(labels, num_classes))
    dice_score, _ = dice_metric.aggregate()
    dice_metric.reset()
    return dice_score.cpu().item()


# Data transformation: load image/label, move channels first, resample to
# 1.5 mm isotropic spacing, then resize to a 64x64x64 grid. Labels always use
# nearest-neighbour interpolation so class ids stay integral.
# Resized interpolates via torch.nn.functional.interpolate, whose "bilinear"
# mode only accepts 2-D spatial input; 3-D volumes need "trilinear".
transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 1.5), mode=("bilinear", "nearest")),
    Resized(keys=["image", "label"], spatial_size=(64, 64, 64), mode=("trilinear", "nearest")),
])


def dimension(pred, gamma=0.0001, epsilon=5000.):
    """Soft truth value that labels 1 and 2 have similar voxel counts.

    ``pred`` is reduced to a hard label map by an argmax over dim 1. For each
    batch item, let ``n1`` and ``n2`` be the number of voxels labelled 1 and 2.
    The score is ``exp(-gamma * max(|n1 - n2| - epsilon, 0) ** 2)``. It equals 1
    when the counts differ by at most ``epsilon`` and decays towards 0 beyond that.
    Because of the argmax, the result carries no gradient w.r.t. ``pred``.

    Args:
    - pred: scores with the class axis at dim 1, e.g. (B, C, H, W, D).
    - gamma: decay rate of the penalty.
    - epsilon: tolerated absolute difference in voxel counts.

    Returns:
    - tensor of shape (B,) on ``pred.device``.
    """
    hard = pred.argmax(dim=1)
    # One row per batch item; also works if ``hard`` is already 1-D (no spatial dims).
    per_item = hard.flatten(start_dim=1) if hard.dim() > 1 else hard.unsqueeze(1)

    n1 = (per_item == 1).sum(dim=1)
    n2 = (per_item == 2).sum(dim=1)
    diff = torch.clamp((n1 - n2).abs() - epsilon, min=0)

    return torch.exp(-gamma * diff ** 2)


def chamfer_distance(pred):
    """Symmetric Chamfer distance between the voxels predicted as label 1 and label 2.

    ``pred`` is reduced to a hard label map with ``argmax`` over dim 1, so the
    result is NOT differentiable with respect to ``pred``.

    For each batch item, let P1 and P2 be the voxel coordinates labelled 1 and 2:
        mean_{p in P1} min_{q in P2} ||p - q||_2 + mean_{q in P2} min_{p in P1} ||p - q||_2

    Args:
    - pred: class scores with the class axis at dim 1, e.g. (B, C, H, W, D).

    Returns:
    - float tensor of shape (B,) on ``pred.device``. Items where either label
      is absent get ``inf``.

    Note: the full |P1| x |P2| distance matrix is materialised per batch item.
    """
    hard = pred.argmax(dim=1)
    mask1 = hard == 1
    mask2 = hard == 2

    batch_size = mask1.shape[0]
    chamfer_dists = torch.zeros(batch_size, device=mask1.device)

    for i in range(batch_size):
        # Coordinates of the foreground voxels of each label, shape (N, spatial_rank).
        coords1 = torch.nonzero(mask1[i], as_tuple=False).float()
        coords2 = torch.nonzero(mask2[i], as_tuple=False).float()

        if coords1.numel() == 0 or coords2.numel() == 0:
            chamfer_dists[i] = float('inf')  # distance to an empty set is undefined
            continue

        dists = torch.cdist(coords1, coords2, p=2)  # (N1, N2)
        chamfer_dists[i] = dists.amin(dim=1).mean() + dists.amin(dim=0).mean()

    return chamfer_dists


# Decathlon Task04 (hippocampus) "training" section, downloaded under the
# current directory if needed and preprocessed with `transforms`.
dataset = DecathlonDataset(
    root_dir="./",
    task="Task04_Hippocampus",
    section="training",
    transform=transforms,
    download=True,
)
# 5 folds; shuffling with a fixed random_state keeps the split reproducible.
kf = KFold(
    n_splits=5,
    random_state=42,
    shuffle=True,
)

from monai.losses import DiceLoss

# Baseline objective: Dice loss computed on softmax probabilities, with the
# integer label map converted to one-hot inside the loss.
baseline_loss_function = DiceLoss(softmax=True, to_onehot_y=True)


def baseline_train(model, train_loader, val_loader, num_epochs=20, loss_fn=None):
    """Train ``model`` with a plain supervised loss, reporting validation Dice each epoch.

    Args:
    - model: segmentation network. Batches are moved to the device of its parameters.
    - train_loader, val_loader: iterables of dicts with "image" and "label" tensors.
    - num_epochs: number of passes over ``train_loader``.
    - loss_fn: ``loss_fn(outputs, labels) -> scalar tensor``. When omitted it
      defaults to ``DiceLoss(to_onehot_y=True, softmax=True)``.

    Returns:
    - ``model``, trained in place.
    """
    from monai.optimizers import WarmupCosineSchedule
    from torch.optim import AdamW

    if loss_fn is None:
        loss_fn = DiceLoss(to_onehot_y=True, softmax=True)

    device = next(model.parameters()).device

    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    # The scheduler is stepped once per epoch, so size it in epochs:
    # 10% linear warm-up followed by cosine decay over the whole run.
    scheduler = WarmupCosineSchedule(
        optimizer, warmup_steps=max(1, num_epochs // 10), t_total=num_epochs
    )

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        for batch in train_loader:
            inputs, labels = batch["image"].to(device), batch["label"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()  # summed (not averaged) over batches

        print(f"BASE Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")
        dice_score = evaluate(model, val_loader, device)
        print(f'val dice score: {dice_score}')
        scheduler.step()
    return model


import ltn

def ltn_train(model, train_loader, val_loader, num_epochs=20):
    """Train ``model`` by maximising the satisfaction of an LTN knowledge base.

    For each batch, four axioms are aggregated with ``SatAgg`` and
    ``1 - sat_agg`` is minimised:

    1. Dice agreement between the network output and the label, quantified over
       the diagonal of (image, label) pairs.
    2. The Chamfer distance between predicted labels 1 and 2 "equals" zero. Eq
       is ``exp(-1e-3 * |u - v|)``.
    3. Predicted labels 1 and 2 have similar voxel counts (``dimension``).
    4. No sampled segment between label-1 voxels passes through label 2
       (``Not(nested)``).

    NOTE(review): axioms 2-4 are evaluated on ``argmax`` predictions, which have
    no gradient. They change the loss value and the relative weighting inside
    ``SatAgg``, but they never push the network directly towards satisfying
    them. A soft, probability-based formulation would be needed for that.

    Args:
    - model: segmentation network. Batches are moved to the device of its parameters.
    - train_loader, val_loader: iterables of dicts with "image" and "label" tensors.
    - num_epochs: number of passes over ``train_loader``.

    Returns:
    - ``model``, trained in place.
    """
    from monai.optimizers import WarmupCosineSchedule
    from torch.optim import AdamW

    device = next(model.parameters()).device

    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    # Stepped once per epoch: 10% warm-up, then cosine decay over the whole run.
    scheduler = WarmupCosineSchedule(
        optimizer, warmup_steps=max(1, num_epochs // 10), t_total=num_epochs
    )

    segmentator = ltn.Function(model)

    Forall = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier='f')
    SatAgg = ltn.fuzzy_ops.SatAgg()
    Not = ltn.Connective(ltn.fuzzy_ops.NotStandard())

    def eq_fn(u, v, alpha=1e-3):
        # Smooth equality: 1 when u == v, decaying exponentially with |u - v|.
        return torch.exp(-alpha * torch.sqrt(torch.square(u - v)))

    Eq = ltn.Predicate(func=eq_fn)

    dice_loss = DiceLoss(to_onehot_y=True, softmax=True, reduction='none')

    def my_dice_loss(outputs, labels):
        # Per-sample truth value: 1 - Dice loss averaged over the class axis.
        return 1. - dice_loss(outputs, labels).mean(dim=1)

    Dl = ltn.Predicate(func=my_dice_loss)

    MinDst = ltn.Function(func=chamfer_distance)
    SimDim = ltn.Function(func=dimension)
    Nested = ltn.Function(func=nested)

    zero = ltn.Constant(torch.tensor(0.))

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        for batch in train_loader:
            inputs, labels = batch["image"].to(device), batch["label"].to(device)
            optimizer.zero_grad()

            # Ground the variables with the current batch.
            x = ltn.Variable("x", inputs)
            y = ltn.Variable("y", labels)
            # Pair each image with its own label. diag must precede applying
            # segmentator so its output shares the diagonal free variable.
            diag_vars = ltn.diag(x, y)
            x, y = diag_vars

            # Single forward pass, reused for the Dice axiom and for the hard
            # predictions that the structural axioms are evaluated on.
            seg = segmentator(x)
            pred = ltn.Variable("pred", torch.argmax(seg.value, dim=1))

            sat_agg = SatAgg(
                Forall(diag_vars, Dl(seg, y)).value,
                Forall(pred, Eq(MinDst(pred), zero)).value,
                Forall(pred, SimDim(pred)).value,
                Forall(pred, Not(Nested(pred))).value,
            )
            loss = 1. - sat_agg

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()  # summed (not averaged) over batches

        print(f"LTN Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")
        dice_score = evaluate(model, val_loader, device)
        print(f'val dice score: {dice_score}')
        scheduler.step()
    return model


# k-fold CV execution
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
baseline_dice_scores = []
ltn_dice_scores = []

num_epochs = 100

# Fraction of each fold's training indices actually used for training.
# Validation always uses the whole held-out fold.
tr_size = 0.25

for fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
    print(f"Fold {fold + 1}/{kf.get_n_splits()}")

    # KFold yields train indices in ascending order, so a prefix slice would always
    # keep the lowest-indexed cases. Draw a seeded random subset instead.
    n_train = max(1, int(tr_size * len(train_idx)))
    rng = np.random.default_rng(42 + fold)
    train_idx = np.sort(rng.choice(train_idx, size=n_train, replace=False)).tolist()

    train_subset = torch.utils.data.Subset(dataset, train_idx)
    val_subset = torch.utils.data.Subset(dataset, val_idx)
    train_loader = DataLoader(train_subset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=4, shuffle=False)

    # LTN Model. The seed is reset before each model so both variants start
    # from identical initial weights (paired comparison).
    torch.manual_seed(fold)
    model_ltn = SwinUNETR(img_size=(64, 64, 64), in_channels=1, out_channels=3, use_checkpoint=True).to(device)
    trained_model_ltn = ltn_train(model_ltn, train_loader, val_loader, num_epochs=num_epochs)
    torch.save(trained_model_ltn.state_dict(), f'model_state_dict_ltn_fold{fold + 1}-tr={tr_size}.pth')
    ltn_dice_scores.append(evaluate(trained_model_ltn, val_loader, device))

    # Baseline Model (same seed, hence the same initial weights as the LTN model).
    torch.manual_seed(fold)
    model = SwinUNETR(img_size=(64, 64, 64), in_channels=1, out_channels=3, use_checkpoint=True).to(device)
    trained_model = baseline_train(model, train_loader, val_loader, num_epochs=num_epochs, loss_fn=baseline_loss_function)
    torch.save(trained_model.state_dict(), f'model_state_dict_baseline_fold{fold+1}-tr={tr_size}.pth')
    baseline_dice_scores.append(evaluate(trained_model, val_loader, device))
    

# Summarise the per-fold validation Dice of each variant as mean ± std
# (np.std is the population standard deviation, ddof=0).
baseline_mean, baseline_std = np.mean(baseline_dice_scores), np.std(baseline_dice_scores)
ltn_mean, ltn_std = np.mean(ltn_dice_scores), np.std(ltn_dice_scores)

print(f"Baseline Model - Mean Dice: {baseline_mean:.4f} ± {baseline_std:.4f}")
print(f"LTN Model - Mean Dice: {ltn_mean:.4f} ± {ltn_std:.4f}")