> **Requirements and Uploading datasets**

In [None]:
!pip install monai
#!pip install scikit-image==0.19.3
#!pip install networkx
#!pip install gdown --upgrade

> **Train the basic model**

In [None]:
# full_debugged_unet3d_train.py
import os
import glob
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from monai.transforms import (
    LoadImaged,
    Orientationd,
    Spacingd,
    ScaleIntensityRanged,
    CropForegroundd,
    RandCropByPosNegLabeld,
    EnsureChannelFirstd,
    Compose,
)
from monai.data import DataLoader, Dataset, NibabelReader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss
from monai.config import print_config
from monai.utils import set_determinism

import torch.nn.functional as F
from torch.cuda import amp

# ---------------------------

class SingleConv(torch.nn.Sequential):
    """Conv3d plus optional ReLU / normalisation layers, stacked in the order given by ``order``.

    ``order`` is read left to right, one layer per character:
    'c' Conv3d, 'r' ReLU, 'g' GroupNorm (must come after 'c'),
    'i' InstanceNorm3d (affine), 'b' BatchNorm3d. Other characters are ignored.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution and of every norm layer.
        kernel_size: int or 3-sequence. "Same" padding (kernel_size // 2) is used, so odd
            kernel sizes keep the spatial size unchanged (kernel_size=3 -> padding=1).
        order: layer order string, see above.
        GroupNumber: requested number of GroupNorm groups. GroupNorm requires
            out_channels % num_groups == 0, so the largest divisor of out_channels that is
            <= GroupNumber is used.
    """

    def __init__(self, in_channels, out_channels, kernel_size, order='crg', GroupNumber=8):
        super(SingleConv, self).__init__()
        for name, module in self._create_conv(in_channels, out_channels, kernel_size, order, GroupNumber):
            self.add_module(name, module)

    def _create_conv(self, in_channels, out_channels, kernel_size, order, GroupNumber):
        """Return ``(name, module)`` pairs for the layers described by ``order``."""
        assert 'c' in order, 'Convolution must have a conv operation'
        # "Same" padding for any odd kernel instead of a hard-coded 1.
        if isinstance(kernel_size, (tuple, list)):
            padding = tuple(k // 2 for k in kernel_size)
        else:
            padding = kernel_size // 2
        modules = []
        for i, char in enumerate(order):
            if char == 'r':
                modules.append(('ReLU', torch.nn.ReLU(inplace=True)))
            elif char == 'c':
                # A following GroupNorm/BatchNorm re-centres the output, making a conv bias redundant.
                bias = not ('g' in order or 'b' in order)
                modules.append(('conv', torch.nn.Conv3d(in_channels, out_channels, kernel_size, bias=bias,
                                                        padding=padding, stride=1)))
            elif char == 'g':
                is_before_conv = i < order.index('c')
                assert not is_before_conv, 'GroupNorm MUST go after the Conv3d'
                # Largest valid group count <= GroupNumber (GroupNorm needs exact divisibility).
                num_groups = max(1, min(GroupNumber, out_channels))
                while out_channels % num_groups != 0:
                    num_groups -= 1
                modules.append(('groupnorm', torch.nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)))
            elif char == 'i':
                modules.append(('instancenorm', torch.nn.InstanceNorm3d(out_channels, affine=True)))
            elif char == 'b':
                modules.append(('batchnorm', torch.nn.BatchNorm3d(out_channels)))
        return modules

class DoubleConv(torch.nn.Sequential):
    """Two stacked SingleConv layers registered as ``Conv1`` and ``Conv2``.

    Encoder variant: the first conv produces max(out_channels // 2, in_channels) channels and
    the second produces out_channels. Decoder variant: both convs produce out_channels.
    """

    def __init__(self, in_channels, out_channels, kernel_size, encoder, order='crg', GroupNumber=8):
        super().__init__()
        mid_channels = max(out_channels // 2, in_channels) if encoder else out_channels
        stages = ((in_channels, mid_channels), (mid_channels, out_channels))
        for idx, (c_in, c_out) in enumerate(stages, start=1):
            self.add_module(
                name=f'Conv{idx}',
                module=SingleConv(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=kernel_size,
                    order=order,
                    GroupNumber=GroupNumber,
                ),
            )

class UNet3D_Encoder(torch.nn.Module):
    """One U-Net encoder stage: optional 3D pooling, then a ``Basic_Module`` conv block.

    ``pooling`` is a MaxPool3d/AvgPool3d when ``apply_pooling`` is True, otherwise None.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pool_kernelsize=(2, 2, 2),
                 pooling_type='max', apply_pooling=True, Basic_Module=DoubleConv, order='crg', GroupNumber=8):
        super().__init__()
        assert pooling_type in ['max', 'avg'], 'Pooling_Type must be max or avg'
        pool_cls = {'max': torch.nn.MaxPool3d, 'avg': torch.nn.AvgPool3d}[pooling_type]
        self.pooling = pool_cls(kernel_size=pool_kernelsize) if apply_pooling else None
        self.basic_module = Basic_Module(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            encoder=True,
            order=order,
            GroupNumber=GroupNumber,
        )

    def forward(self, x):
        pooled = x if self.pooling is None else self.pooling(x)
        return self.basic_module(pooled)

class UNet3D_Decoder(torch.nn.Module):
    """One U-Net decoder stage.

    Trilinearly resizes ``x`` to the skip feature's spatial size, concatenates
    ``[encoder_feature, x]`` on the channel axis, and applies a ``Basic_Module`` conv block.
    NOTE(review): the default ``order='crb'`` differs from the encoder's ``'crg'``; UNet3D
    always passes ``order`` explicitly.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, Basic_Module=DoubleConv, order='crb', GroupNumber=8):
        super().__init__()
        # Never used; upsampling is done with F.interpolate in forward(). Kept for attribute compatibility.
        self.upsample = None
        self.basic_module = Basic_Module(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            encoder=False,
            order=order,
            GroupNumber=GroupNumber,
        )

    def forward(self, encoder_feature, x):
        target_size = encoder_feature.shape[2:]
        upsampled = F.interpolate(input=x, size=target_size, mode='trilinear', align_corners=True)
        merged = torch.cat([encoder_feature, upsampled], dim=1)
        return self.basic_module(merged)

class UNet3D(torch.nn.Module):
    """Four-level 3D U-Net with trilinear-upsampling decoders and skip connections.

    Args:
        in_channels: channels of the input volume.
        out_channels: channels produced by the final 1x1x1 convolution.
        finalsigmoid: if True the eval-mode activation is Sigmoid, otherwise Softmax over dim 1.
        fmaps_degree: feature maps of the first level; level k uses fmaps_degree * 2**k.
        GroupNormNumber: GroupNorm group count passed to every conv block.
        fmaps_layer_number: number of resolution levels. The encoder/decoder wiring below is
            fixed at four levels, so this must be 4.
        layer_order: SingleConv order string (e.g. 'crg').
        device: device each submodule is moved to at construction time.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``fmaps_layer_number`` is not 4.
    """

    def __init__(self, in_channels, out_channels, finalsigmoid, fmaps_degree, GroupNormNumber,
                 fmaps_layer_number, layer_order, device, **kwargs):
        super(UNet3D, self).__init__()
        self.device = device
        assert isinstance(fmaps_degree, int), 'fmaps_degree must be an integer!'
        # The layers below are hard-wired for four levels. Fewer levels used to fail with an
        # IndexError. More levels built decoders whose in_channels do not match the skip
        # concatenation, which only failed on the first forward pass.
        if fmaps_layer_number != 4:
            raise ValueError(
                f"UNet3D supports exactly 4 resolution levels, got fmaps_layer_number={fmaps_layer_number}"
            )
        fmaps_list = [fmaps_degree * 2 ** k for k in range(fmaps_layer_number)]

        # Construction order is kept as-is: it fixes parameter initialisation under a seed.
        self.EncoderLayer1 = UNet3D_Encoder(in_channels=in_channels, out_channels=fmaps_list[0], apply_pooling=False,
                                                 Basic_Module=DoubleConv, order=layer_order,
                                                 GroupNumber=GroupNormNumber).to(self.device)
        self.EncoderLayer2 = UNet3D_Encoder(in_channels=fmaps_list[0], out_channels=fmaps_list[1], apply_pooling=True,
                                                 Basic_Module=DoubleConv, order=layer_order,
                                                 GroupNumber=GroupNormNumber).to(self.device)
        self.EncoderLayer3 = UNet3D_Encoder(in_channels=fmaps_list[1], out_channels=fmaps_list[2], apply_pooling=True,
                                                 Basic_Module=DoubleConv, order=layer_order,
                                                 GroupNumber=GroupNormNumber).to(self.device)
        self.EncoderLayer4 = UNet3D_Encoder(in_channels=fmaps_list[2], out_channels=fmaps_list[3], apply_pooling=True,
                                                 Basic_Module=DoubleConv, order=layer_order,
                                                 GroupNumber=GroupNormNumber).to(self.device)

        DecoderFmapList = list(reversed(fmaps_list))

        # Each decoder consumes [skip feature, upsampled deeper feature] concatenated on channels.
        self.DecoderLayer1 = UNet3D_Decoder(in_channels=DecoderFmapList[0] + DecoderFmapList[1],
                                                 out_channels=DecoderFmapList[1],
                                                 Basic_Module=DoubleConv, order=layer_order, GroupNumber=GroupNormNumber).to(
            self.device)
        self.DecoderLayer2 = UNet3D_Decoder(in_channels=DecoderFmapList[1] + DecoderFmapList[2],
                                                 out_channels=DecoderFmapList[2],
                                                 Basic_Module=DoubleConv, order=layer_order, GroupNumber=GroupNormNumber).to(
            self.device)
        self.DecoderLayer3 = UNet3D_Decoder(in_channels=DecoderFmapList[2] + DecoderFmapList[3],
                                                 out_channels=DecoderFmapList[3],
                                                 Basic_Module=DoubleConv, order=layer_order, GroupNumber=GroupNormNumber).to(
            self.device)

        self.final_conv = torch.nn.Conv3d(in_channels=fmaps_list[0], out_channels=out_channels, kernel_size=1).to(
            self.device)

        if finalsigmoid:
            self.final_activation = torch.nn.Sigmoid().to(self.device)
        else:
            self.final_activation = torch.nn.Softmax(dim=1).to(self.device)

    def forward(self, x):
        """Return raw ``final_conv`` outputs in training mode and activated outputs in eval mode."""
        encoder_features = []
        x1 = self.EncoderLayer1(x)
        encoder_features.insert(0, x1.to(self.device))
        x2 = self.EncoderLayer2(x1)
        encoder_features.insert(0, x2)
        x3 = self.EncoderLayer3(x2)
        encoder_features.insert(0, x3)
        x4 = self.EncoderLayer4(x3)

        # encoder_features is deepest-first: [x3, x2, x1].
        x = self.DecoderLayer1(encoder_features[0], x4)
        x = self.DecoderLayer2(encoder_features[1], x).to(self.device)
        x = self.DecoderLayer3(encoder_features[2], x)

        x = self.final_conv(x)
        # Losses such as DiceCELoss(sigmoid=True) expect logits during training.
        if not self.training:
            x = self.final_activation(x)
        return x

# ---------------------------

set_determinism(seed=0)
print_config()


image_folder = "/kaggle/input/images"
label_folder = "/kaggle/input/labels/Labels"

# Collect every .nii / .nii.gz file below each root (recursively), in sorted path order.
images, labels = (
    sorted(glob.glob(os.path.join(root, "**", "*.nii*"), recursive=True))
    for root in (image_folder, label_folder)
)

def basename_noext(path):
    """Return the file name of *path* without its extension.

    The double extension ``.nii.gz`` is removed as a whole; otherwise only the last
    extension is removed (``os.path.splitext`` semantics).
    """
    name = os.path.basename(path)
    if not name.endswith('.nii.gz'):
        stem, _ext = os.path.splitext(name)
        return stem
    return name[:-len('.nii.gz')]

# Pair images and labels by file name (extension stripped); the last path wins on duplicates.
img_dict = {basename_noext(p): p for p in images}
lbl_dict = {basename_noext(p): p for p in labels}
common_keys = sorted(set(img_dict) & set(lbl_dict))
if len(common_keys) == 0:
    raise ValueError("No matched image/label pairs found. Check paths and filenames.")

# Report files that would otherwise be dropped without notice.
if len(img_dict) < len(images) or len(lbl_dict) < len(labels):
    print("Warning: duplicate file names found; only the last path per name is used.")
n_unpaired_images = len(img_dict) - len(common_keys)
n_unpaired_labels = len(lbl_dict) - len(common_keys)
if n_unpaired_images or n_unpaired_labels:
    print(f"Warning: ignoring {n_unpaired_images} image(s) without a label and "
          f"{n_unpaired_labels} label(s) without an image.")

all_data = [{"image": img_dict[k], "label": lbl_dict[k]} for k in common_keys]
# 80/20 split in sorted-name order. Always keep at least one training case:
# with a single pair, int(0.8 * 1) == 0 would leave training empty.
train_size = max(1, int(0.8 * len(all_data)))
train_data_dicts = all_data[:train_size]
val_data_dicts = all_data[train_size:]


def _shared_preprocessing():
    """Return fresh instances of the deterministic preprocessing used for training and validation.

    Steps: load with NibabelReader, channel-first, reorient to RAS, resample to
    (1.5, 1.5, 2.0) spacing (bilinear for images, nearest for labels), map intensities
    [-1000, 400] to [0, 1] with clipping, then crop both keys to the image foreground.
    """
    return [
        LoadImaged(keys=["image", "label"], reader=NibabelReader()),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys=["image"], a_min=-1000, a_max=400, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
    ]


# Training additionally samples 2 random 96^3 patches per volume, balanced 1:1 between
# label-positive and label-negative centres.
train_transforms = Compose(_shared_preprocessing() + [
    RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=(96, 96, 96),
                           pos=1, neg=1, num_samples=2, image_key="image", image_threshold=0),
])
# Validation keeps whole volumes (inference later uses a sliding window).
val_transforms = Compose(_shared_preprocessing())

train_ds, val_ds = (
    Dataset(data=train_data_dicts, transform=train_transforms),
    Dataset(data=val_data_dicts, transform=val_transforms),
)


DEFAULT_BATCH_SIZE = 1
cpu_count = os.cpu_count() or 1
# Leave one core for the main process and cap the loader at 4 workers.
num_workers = min(4, max(0, cpu_count - 1))
# Page-locked host memory only speeds up host-to-GPU copies.
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_ds,
    batch_size=DEFAULT_BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
# Whole (uncropped) validation volumes: one per batch, one worker fewer than training.
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=max(0, num_workers - 1),
    pin_memory=pin_memory,
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}, batch_size={DEFAULT_BATCH_SIZE}, num_workers={num_workers}, pin_memory={pin_memory}")


# Single-output-channel 3D U-Net. In eval mode it returns sigmoid probabilities.
model = UNet3D(
    in_channels=1,
    out_channels=1,
    finalsigmoid=True,
    fmaps_degree=32,
    GroupNormNumber=8,
    fmaps_layer_number=4,
    layer_order='crg',
    device=device,
).to(device)

# In train mode the model emits raw logits, so the loss applies the sigmoid itself.
loss_function = DiceCELoss(to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Mixed precision only on GPU; with enabled=False the scaler is a pass-through.
use_amp = torch.cuda.is_available()
scaler = amp.GradScaler(enabled=use_amp)

epochs = 110  # total number of training epochs

os.makedirs("logs", exist_ok=True)
os.makedirs("models", exist_ok=True)
epoch_losses = []  # mean training loss of each epoch, filled by the training loop


def save_model_cpu(model, epoch, folder="models"):
    """Save ``model.state_dict()`` with every tensor moved to CPU.

    The file is written to ``{folder}/unet3d_epoch_{epoch}.pth``. ``folder`` is created if
    it does not exist. Storing CPU tensors lets the checkpoint load on machines without a GPU.

    Returns:
        The path of the written checkpoint.
    """
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, f"unet3d_epoch_{epoch}.pth")

    state = {k: v.cpu() for k, v in model.state_dict().items()}
    torch.save(state, path)
    print(f"Saved model checkpoint for epoch {epoch} -> {path}")
    return path
   
def plot_prediction_vs_gt(val_images, val_labels, prediction, epoch, show=True, save_fig=False, out_folder="logs"):
    """Plot the central axial, sagittal and coronal slices of image, ground truth and prediction.

    Args:
        val_images, val_labels, prediction: tensors each holding one 3D volume, optionally
            wrapped in leading singleton batch/channel axes (e.g. shape (1, 1, X, Y, Z)).
            The volumes are assumed to be RAS-oriented, as produced by the
            ``Orientationd(axcodes="RAS")`` transforms in this notebook. Spatial axes are
            then (x: left->right, y: posterior->anterior, z: inferior->superior).
        epoch: epoch number used in the title and in the saved file name.
        show: call ``plt.show()``.
        save_fig: save the figure as ``prediction_vs_gt_epoch_{epoch}.png`` in ``out_folder``.
        out_folder: output directory, created if missing.

    Raises:
        ValueError: if an input does not reduce to a single 3D volume.
    """
    def _as_volume(tensor):
        # Strip only leading singleton axes so a spatial axis of size 1 is never squeezed away.
        arr = tensor.detach().cpu().numpy()
        while arr.ndim > 3 and arr.shape[0] == 1:
            arr = arr[0]
        if arr.ndim != 3:
            raise ValueError(f"Expected a single 3D volume, got array of shape {arr.shape}")
        return arr

    image_np = _as_volume(val_images)
    label_np = _as_volume(val_labels)
    prediction_np = _as_volume(prediction)
    size_x, size_y, size_z = image_np.shape

    # In RAS (x, y, z) order, fixing z gives an axial plane, fixing x a sagittal plane
    # and fixing y a coronal plane.
    views = (
        ("Axial", lambda vol: vol[:, :, size_z // 2]),
        ("Sagittal", lambda vol: vol[size_x // 2, :, :]),
        ("Coronal", lambda vol: vol[:, size_y // 2, :]),
    )
    columns = (("Image", image_np, 'gray'), ("GT", label_np, 'jet'), ("Pred", prediction_np, 'jet'))

    fig, axes = plt.subplots(3, 3, figsize=(12, 12))
    fig.suptitle(f'Prediction vs Ground Truth - Epoch {epoch}', fontsize=14)
    for row, (plane, take_slice) in enumerate(views):
        for col, (prefix, volume, cmap) in enumerate(columns):
            axes[row, col].imshow(take_slice(volume), cmap=cmap)
            axes[row, col].set_title(f'{prefix} ({plane})')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    if save_fig:
        os.makedirs(out_folder, exist_ok=True)
        figpath = os.path.join(out_folder, f"prediction_vs_gt_epoch_{epoch}.png")
        fig.savefig(figpath)
        print(f"Saved comparison figure: {figpath}")
    if show:
        plt.show()
    plt.close(fig)


if torch.cuda.is_available():
    torch.cuda.empty_cache()


# NOTE(review): MONAI's set_determinism() (called earlier) presumably disables cuDNN
# benchmarking for reproducibility. Re-enabling it here trades bit-exact reruns for speed.
torch.backends.cudnn.benchmark = True

print("Starting training...")
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    step = 0  # batches that trained successfully in this epoch
    pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
    for batch_idx, batch_data in enumerate(pbar):
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)

        optimizer.zero_grad(set_to_none=True)

        oom = False
        try:
            with amp.autocast(enabled=use_amp):
                outputs = model(inputs)  # logits: the model skips its activation in train mode
                loss = loss_function(outputs, labels)

            # scale & backward
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        except RuntimeError as e:
            if 'out of memory' not in str(e).lower():
                raise
            oom = True

        if oom:
            # Recover *outside* the except clause. While it is active, the exception's
            # traceback keeps the failed frames (and their activations) alive, so
            # empty_cache() there cannot release that memory.
            print(f"WARNING: CUDA OOM on batch {batch_idx+1} of epoch {epoch+1}. Skipping this batch.")
            outputs = loss = None  # drop any graph left over from a failed backward/step
            optimizer.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
            continue

        step += 1
        loss_value = loss.item()
        epoch_loss += loss_value
        pbar.set_postfix({"loss": f"{loss_value:.4f}"})

        del outputs, loss
        torch.cuda.empty_cache()

    if step == 0:
        avg_epoch_loss = float("nan")
        print("Warning: no successful training steps in this epoch (all batches failed?).")
    else:
        avg_epoch_loss = epoch_loss / step

    print(f"Epoch {epoch + 1} average loss: {avg_epoch_loss:.4f}")

    # Rewrite the full loss history every epoch so progress survives an interrupted run.
    epoch_losses.append(avg_epoch_loss)
    df_partial = pd.DataFrame({"epoch": list(range(1, len(epoch_losses) + 1)), "loss": epoch_losses})
    df_partial.to_csv("logs/epoch_losses.csv", index=False)

    save_model_cpu(model, epoch + 1, folder="models")

    # Dice on the full validation set every 10 epochs (sliding-window inference on whole volumes).
    if (epoch + 1) % 10 == 0:
        model.eval()  # eval mode: the model returns sigmoid probabilities
        dice_metric = DiceMetric(include_background=False, reduction="mean")
        with torch.no_grad():
            for vbatch in tqdm(val_loader, desc="Validation"):
                val_images = vbatch["image"].to(device)
                val_labels = vbatch["label"].to(device)
                with amp.autocast(enabled=use_amp):
                    val_outputs = sliding_window_inference(val_images, (96, 96, 96), 2, model, overlap=0.5)
                val_outputs = (val_outputs > 0.5).float()
                dice_metric(y_pred=val_outputs, y=val_labels)
                del val_outputs
                torch.cuda.empty_cache()
            try:
                mean_dice = dice_metric.aggregate().item()
                print(f"Validation Dice Score at epoch {epoch + 1}: {mean_dice:.4f}")
            except Exception as e:
                # Best-effort reporting (e.g. an empty validation set): say why, keep training.
                print("Validation metric aggregation failed or no valid predictions.")
                print(f"  reason: {e!r}")
            finally:
                dice_metric.reset()

    # Visual check on the first validation case every 5 epochs.
    if (epoch + 1) % 5 == 0:
        if len(val_loader) == 0:
            print("No validation data for visual comparison.")
        else:
            model.eval()
            with torch.no_grad():
                try:
                    val_batch = next(iter(val_loader))
                    val_images = val_batch["image"].to(device)
                    val_labels = val_batch["label"].to(device)
                    with amp.autocast(enabled=use_amp):
                        val_outputs = sliding_window_inference(val_images, (96, 96, 96), 2, model, overlap=0.5)
                    prediction = (val_outputs > 0.5).float()

                    show_img = val_images[0] if val_images.shape[0] > 1 else val_images
                    show_lbl = val_labels[0] if val_labels.shape[0] > 1 else val_labels
                    show_pred = prediction[0] if prediction.shape[0] > 1 else prediction

                    plot_prediction_vs_gt(show_img, show_lbl, show_pred, epoch + 1, show=True, save_fig=True, out_folder="logs")
                except Exception as e:
                    print("Visual comparison failed:", e)
            # Release outside the except clause so a failure's traceback no longer pins tensors.
            val_outputs = prediction = None
            torch.cuda.empty_cache()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

df = pd.DataFrame({"epoch": list(range(1, len(epoch_losses) + 1)), "loss": epoch_losses})
df.to_csv("logs/epoch_losses.csv", index=False)
print("Training finished. Saved epoch losses to logs/epoch_losses.csv")


# Keep the "final" checkpoint written by save_model_cpu (models/unet3d_epoch_final.pth).
save_model_cpu(model, "final", folder="models")
# final_model_path used to be computed but nothing was saved there; write the same
# CPU state dict to it so both file names are valid.
final_model_path = os.path.join("models", "unet3d_last_epoch.pth")
torch.save({k: v.cpu() for k, v in model.state_dict().items()}, final_model_path)
print(f"Saved final model weights -> {final_model_path}")
print("Done.")


In [None]:
!pip install scikit-image==0.19.3
!pip install networkx
!pip install gdown --upgrade
!pip install skimage

> **Transformer model**

In [None]:
import torch
from torch import nn
from einops import rearrange
import numpy as np
import torch.nn.functional as F


def pair(t):
    """Return *t* unchanged if it is already a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then call the wrapped ``fn`` on the result.

    Extra keyword arguments to ``forward`` are passed through to ``fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> hidden_dim), GELU, Dropout, Linear(hidden_dim -> dim)."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class Attention_spd(nn.Module):
    """Multi-head self-attention with an additive attention bias and random head dropping.

    forward(x, spd, p, attn_mask=None):
        x: tokens, (b, n, dim).
        spd: additive bias added to the scaled q·k scores before softmax; must broadcast to
            (b, heads, n, n).
        p: probability of dropping each attention head during training (0 <= p < 1).
            At least one head is always kept, and kept heads are rescaled by
            heads / kept_heads. Nothing is dropped in eval mode.
        attn_mask: optional multiplicative mask applied to the post-softmax weights
            (no renormalisation).

    Raises:
        ValueError: if p is outside [0, 1); p == 1 would never keep a head.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def _split_heads(self, t):
        # (b, n, h*d) -> (b, h, n, d); equivalent to einops 'b n (h d) -> b h n d'.
        return t.reshape(t.shape[0], t.shape[1], self.heads, -1).transpose(1, 2)

    def forward(self, x, spd, p, attn_mask=None):
        if not 0 <= p < 1:
            raise ValueError(f"head drop probability p must be in [0, 1), got {p}")
        q, k, v = (self._split_heads(t) for t in self.to_qkv(x).chunk(3, dim=-1))
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + spd
        attn = self.dropout(self.attend(dots))

        # Head dropping is a training-time regulariser (inverted-dropout style), so skip it
        # in eval mode and when p == 0 (where it would be the identity anyway).
        if self.training and p > 0:
            num_heads = attn.shape[1]
            keep = np.random.binomial(1, 1 - p, size=num_heads)
            while np.sum(keep) == 0:  # always keep at least one head
                keep = np.random.binomial(1, 1 - p, size=num_heads)
            head_mask = torch.as_tensor(keep, dtype=attn.dtype, device=attn.device).view(1, -1, 1, 1)
            attn = attn * head_mask * (float(num_heads) / float(np.sum(keep)))  # normalization

        if attn_mask is not None:
            attn = attn * attn_mask

        out = torch.matmul(attn, v)
        # (b, h, n, d) -> (b, n, h*d); equivalent to einops 'b h n d -> b n (h d)'.
        out = out.transpose(1, 2).reshape(out.shape[0], out.shape[2], -1)
        return self.to_out(out)


class Transformer_postnorm_spd(nn.Module):
    """Stack of ``depth`` post-norm blocks: biased self-attention, then a feed-forward MLP.

    A single LayerNorm instance (``self.norm``) is shared by every residual connection of
    every block. ``spd``, ``p`` and ``attn_mask`` are passed unchanged to each Attention_spd.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Attention_spd(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout),
            ])
            for _ in range(depth)
        )

    def forward(self, x, spd, p, attn_mask=None):
        for self_attn, feed_forward in self.layers:
            x = self.norm(x + self_attn(x, spd, p, attn_mask))
            x = self.norm(x + feed_forward(x))
        return x


class learnabel_mask(nn.Module):
    """Learnable pairwise mask over the N tokens of a single sample.

    Every ordered token pair (i, j) is scored by PairwiseProcessing, and a softmax over its
    two outputs gives a link probability. Where ``toplogy_mask`` is 1, that probability is
    lifted towards 1 by a per-row gate from ``self.mlp``.
    NOTE(review): assumes x is (1, N, dim) and toplogy_mask broadcasts to (1, N, N) — confirm.
    """

    def __init__(self, dim=128):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.pairwise_processor = PairwiseProcessing(dim)
        self.mlp = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)  # not used in forward

    def forward(self, x, toplogy_mask):
        # Per-token gate in (0, 1), tiled across columns.
        gate = self.mlp(x).repeat(1, 1, x.shape[1])
        tokens = x.squeeze(0)
        count = tokens.shape[0]
        grid_cols = tokens.unsqueeze(0).repeat(count, 1, 1)  # [i, j] -> token j
        grid_rows = tokens.unsqueeze(1).repeat(1, count, 1)  # [i, j] -> token i
        pair_feats = torch.cat([grid_cols, grid_rows], 2)
        pair_feats = pair_feats.unsqueeze(0).permute(0, 3, 1, 2).contiguous()  # channels first for Conv2d
        pair_logits = self.pairwise_processor(pair_feats).permute(0, 2, 3, 1).contiguous()
        link_prob = self.softmax(pair_logits)[:, :, :, 0]
        lifted = link_prob + (1 - link_prob) * gate
        return toplogy_mask * lifted + (1 - toplogy_mask) * link_prob


class Attention_cross(nn.Module):
    """Multi-head cross-attention: queries from ``x1``, keys/values from ``x2``.

    If ``spd`` is given it is added to the scaled q·k scores before the softmax.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)

        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x1, x2, spd=None):
        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h=self.heads)

        q = split_heads(self.to_q(x1))
        k, v = map(split_heads, self.to_kv(x2).chunk(2, dim=-1))
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if spd is not None:
            dots = dots + spd
        attn = self.dropout(self.attend(dots))
        merged = rearrange(torch.matmul(attn, v), 'b h n d -> b n (h d)')
        return self.to_out(merged)


class Transformer_postnorm_cross_spd(nn.Module):
    """Post-norm stack of biased self-attention, feed-forward and cross-attention to ``x2``.

    A single cross-attention module (``self.cross``) and a single LayerNorm are shared
    across all ``depth`` layers.
    NOTE(review): when ``spd`` is None no layer runs and ``x`` is returned unchanged.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Attention_spd(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout),
            ])
            for _ in range(depth)
        )
        self.cross = Attention_cross(dim, heads=heads, dim_head=dim_head, dropout=dropout)

    def forward(self, x, x2, spd=None, p=0, attn_mask=None):
        if spd is None:
            return x
        for self_attn, feed_forward in self.layers:
            x = self.norm(x + self_attn(x, spd, p, attn_mask))
            x = self.norm(x + feed_forward(x))
            x = self.norm(x + self.cross(x, x2, spd))
        return x


class Attention_cross_base(nn.Module):
    """Plain multi-head cross-attention (no bias): queries from ``x1``, keys/values from ``x2``."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)

        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x1, x2):
        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h=self.heads)

        q = split_heads(self.to_q(x1))
        k, v = map(split_heads, self.to_kv(x2).chunk(2, dim=-1))
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        merged = rearrange(torch.matmul(weights, v), 'b h n d -> b n (h d)')
        return self.to_out(merged)


class Transformer_postnorm_cross(nn.Module):
    """Post-norm stack of cross-attention (queries from ``x``, keys/values from ``x2``) + MLP.

    One LayerNorm instance is shared by every residual connection.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Attention_cross_base(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout),
            ])
            for _ in range(depth)
        )

    def forward(self, x, x2):
        for cross_attn, feed_forward in self.layers:
            x = self.norm(x + cross_attn(x, x2))
            x = self.norm(x + feed_forward(x))
        return x


class PairwiseProcessing(nn.Module):
    """Per-pair MLP implemented with 1x1 Conv2d layers on a (1, 2*dim, H, W) pair grid.

    Channel widths: 2d -> 2d -> d -> d -> d/2 -> d/4 -> d/8 -> 2, with BatchNorm + ReLU
    between stages. The output has 2 channels per pair.
    """

    def __init__(self, dim):
        super().__init__()
        # Normalisation for the input and after each reduction stage.
        self.pairbn0 = nn.BatchNorm2d(dim * 2)
        self.pairbn1 = nn.BatchNorm2d(dim)
        self.pairbn2 = nn.BatchNorm2d(dim // 2)
        self.pairbn3 = nn.BatchNorm2d(dim // 4)
        self.pairbn4 = nn.BatchNorm2d(dim // 8)

        def conv1x1(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)

        self.pairconv10 = conv1x1(dim * 2, dim * 2)
        self.pairconv11 = conv1x1(dim * 2, dim)
        self.pairconv20 = conv1x1(dim, dim)
        self.pairconv21 = conv1x1(dim, dim // 2)
        self.pairconv3 = conv1x1(dim // 2, dim // 4)
        self.pairconv4 = conv1x1(dim // 4, dim // 8)
        self.pairconv5 = conv1x1(dim // 8, 2)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, nodepair):
        h = self.pairbn0(nodepair)
        # Two double-conv stages: conv -> ReLU -> conv -> BN -> ReLU.
        for first, second, norm in ((self.pairconv10, self.pairconv11, self.pairbn1),
                                    (self.pairconv20, self.pairconv21, self.pairbn2)):
            h = self.relu(norm(second(self.relu(first(h)))))
        # Two single-conv stages: conv -> BN -> ReLU.
        for conv, norm in ((self.pairconv3, self.pairbn3), (self.pairconv4, self.pairbn4)):
            h = self.relu(norm(conv(h)))
        return self.pairconv5(h)


class Outlier_detect(nn.Module):
    """Score each token with a sigmoid outlier probability relative to class prototypes.

    Prototypes are soft-assignment-weighted sums of tokens (weights = softmax(logits) ** alpha).
    They are refined by cross-attention to the tokens. Every (token, prototype) pair is then
    scored by PairwiseProcessing, and the 2*C pair scores of each token are reduced by
    ``self.mlp`` to a value in (0, 1).
    NOTE(review): assumes x is (1, N, dim) and logits is (1, N, prototype_class) — confirm.
    """

    def __init__(self, depth, dim, heads, mlp_dim, dim_head=64, prototype_class=22, dropout=0.):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1))
        self.pairwise_processor = PairwiseProcessing(dim)
        self.softmax = nn.Softmax(dim=-1)  # not used in forward
        score_width = prototype_class * 2  # two pairwise outputs per prototype
        self.mlp = nn.Sequential(
            nn.LayerNorm(score_width),
            nn.Linear(score_width, 1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)  # not used in forward
        self.Trans_cross = Transformer_postnorm_cross(dim, depth, heads, dim_head, mlp_dim, dropout)

    def forward(self, x, logits):
        assignment = F.softmax(logits, dim=-1) ** self.alpha
        prototypes = torch.matmul(assignment.transpose(-1, -2), x)
        prototypes = self.Trans_cross(prototypes, x)

        tokens = x.squeeze(0)
        protos = prototypes.squeeze(0)
        proto_grid = protos.unsqueeze(0).repeat(tokens.shape[0], 1, 1)  # [i, c] -> prototype c
        token_grid = tokens.unsqueeze(1).repeat(1, protos.shape[0], 1)  # [i, c] -> token i
        pair_feats = torch.cat([proto_grid, token_grid], 2)
        pair_feats = pair_feats.unsqueeze(0).permute(0, 3, 1, 2).contiguous()
        scores = self.pairwise_processor(pair_feats).permute(0, 2, 3, 1).contiguous()
        scores = scores.view(scores.shape[0], scores.shape[1], -1)  # flatten (C, 2) per token
        return self.mlp(scores)


class Stage_independent(nn.Module):
    """Hierarchical transformer stage producing three levels of per-token predictions.

    ``forward`` runs three densely connected steps. Each step's input is
    ``dense_linear[i]`` applied to the concatenation of the stage input and all previous
    step outputs:
      step 0: transformer[0] -> mlp_head1 (num_classes1 logits)
      step 1: transformer[1] -> mlp_head2 logits, mixed by softmax(att_mask(...)) and
              scored per token by the outlier detector
      step 2: transformer[2] with attention masked by nodepair * outlier_mask,
              then transformer[3] -> mlp_head3 (num_classes3 logits)
    Returns (x0, x1, x2, pred0, pred1, pred2, nodepair, outlier).

    NOTE(review): dense_linear[3] is built but never used in forward. The outlier detector
    is always built with dim_head=64 and dropout=0, whatever arguments are passed. Module
    construction order fixes parameter initialisation and state_dict layout, so reorder
    nothing here.
    """
    def __init__(self, depth, outlier_depth, num_classes1, num_classes2, num_classes3, dim, heads, mlp_dim, dim_head=64,
                 dropout=0., ):
        super().__init__()
        hierarchy = [depth, depth, depth, depth]
        self.transformer = nn.ModuleList([])
        # dense_linear[i] maps the concatenation of (i + 1) feature tensors back to dim.
        self.dense_linear = nn.ModuleList(
            [nn.Linear((i + 1) * dim, dim) for i in range(len(hierarchy))]
        )
        for d in hierarchy:
            self.transformer.append(
                Transformer_postnorm_spd(dim, d, heads, dim_head, mlp_dim, dropout)
            )
        self.att_mask = learnabel_mask(dim=dim)
        self.mlp_head1 = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes1)
        )

        self.mlp_head2 = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes2)
        )
        self.mlp_head3 = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes3)
        )
        # Prototypes correspond to the level-2 classes (num_classes2).
        self.outlier = Outlier_detect(outlier_depth, dim, heads, mlp_dim, dim_head=64, dropout=0.,
                                      prototype_class=num_classes2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, spd, p, toplogy_mask):
        """Run the three prediction steps.

        Args:
            x: token features; presumably (1, N, dim) — learnabel_mask/Outlier_detect squeeze dim 0.
            spd: indexable sequence of additive attention biases; spd[0], spd[1], spd[2] are used.
            p: head-drop probability passed to every Attention_spd.
            toplogy_mask: prior adjacency passed to ``att_mask``; presumably (1, N, N) — confirm.
        """

        x_ = []
        list = []  # NOTE(review): shadows the builtin `list`; holds the dense-connection inputs
        list.append(x)
        pred_ = []
        for i in range(len(self.transformer) - 1):  # i = 0, 1, 2 (transformer[3] is used inside step 2)
            x = self.dense_linear[i](torch.cat(list, dim=-1))
            if i == 0:
                x = self.transformer[i](x, spd[i], p, None)
                pred = self.mlp_head1(x)
            if i == 1:
                x = self.transformer[i](x, spd[i], p, None)
                pred = self.mlp_head2(x)
                # Learned pairwise mask; its row-softmax mixes the level-2 logits across tokens.
                nodepair = self.att_mask(x, toplogy_mask)
                prior = self.softmax(nodepair)
                pred = torch.matmul(prior, pred)
                outlier = self.outlier(x, pred)
                # outlier_mask[i, j] = 1 - (outlier_i - outlier_j) ** 2: pairs with similar
                # outlier scores keep more attention weight.
                outlier_mask = 1 - (
                        outlier.repeat(1, 1, x.shape[1]) - outlier.transpose(1, 2).repeat(1, x.shape[1], 1)) ** 2

            if i == 2:
                # Step 2 reuses spd[2] for both transformer[2] (masked) and transformer[3].
                x = self.transformer[i](x, spd[i], p, nodepair * outlier_mask)
                x = self.transformer[i + 1](x, spd[i], p, None)
                pred = self.mlp_head3(x)
            x_.append(x)
            list.append(x)
            pred_.append(pred)
        return x_[0], x_[1], x_[2], pred_[0], pred_[1], pred_[2], nodepair, outlier


class Stage_guided(nn.Module):
    """Second-stage hierarchical transformer that also consumes first-stage features.

    Four transformer layers are built. Layers 0 and 1 are
    ``Transformer_postnorm_cross_spd`` and receive an extra cross-stage tensor
    in ``forward``; layers 2 and 3 are ``Transformer_postnorm_spd``. Inputs are
    densely connected: step ``i`` sees the concatenation of the original input
    and all earlier step outputs, projected back to ``dim`` by
    ``dense_linear[i]``.

    Args:
        input_depth: depth passed to every transformer layer.
        outlier_depth: depth passed to ``Outlier_detect``.
        num_classes1, num_classes2, num_classes3: output sizes of the three
            prediction heads (``num_classes2`` is also the outlier prototype count).
        dim: feature size of every step.
        heads, mlp_dim, dim_head, dropout: forwarded to the transformer layers.
    """

    def __init__(self, input_depth, outlier_depth, num_classes1, num_classes2, num_classes3, dim, heads, mlp_dim,
                 dim_head=64,
                 dropout=0.):
        super().__init__()
        hierarchy = [input_depth] * 4
        # Construction order (empty ModuleList, dense_linear, then the layers) is
        # kept as before so parameter order and seeded initialisation are unchanged.
        self.transformer = nn.ModuleList([])
        # NOTE(review): four projections are created but forward() only uses the
        # first three; kept so existing checkpoints still load.
        self.dense_linear = nn.ModuleList(
            [nn.Linear((i + 1) * dim, dim) for i in range(len(hierarchy))]
        )
        for layer_idx, depth in enumerate(hierarchy):
            # Layers 0-1 take an extra cross-stage input in forward(); layers 2-3 do not.
            layer_cls = Transformer_postnorm_spd if layer_idx >= 2 else Transformer_postnorm_cross_spd
            self.transformer.append(layer_cls(dim, depth, heads, dim_head, mlp_dim, dropout))
        self.att_mask = learnabel_mask(dim=dim)
        self.mlp_head1 = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes1))
        self.mlp_head2 = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes2))
        self.mlp_head3 = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes3))

        # NOTE(review): dim_head and dropout are hard-coded here instead of using
        # the constructor arguments; confirm this is intended.
        self.outlier = Outlier_detect(outlier_depth, dim, heads, mlp_dim, dim_head=64, dropout=0.,
                                      prototype_class=num_classes2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, spd, x2, p, toplogy_mask, outlier):
        """Run the three guided hierarchy steps.

        Args:
            x: input features for step 0.
            spd: per-step spatial encodings; ``spd[i]`` is used at step ``i``
                (step 2 reuses ``spd[2]`` for both of its layers).
            x2: cross-stage inputs; ``x2[0]`` feeds layer 0 and ``x2[1]`` layer 1.
            p: forwarded unchanged to every transformer layer.
            toplogy_mask: passed to ``att_mask`` at step 1.
            outlier: never read; it is replaced by ``self.outlier(...)`` at
                step 1. Kept for call compatibility.

        Returns:
            Tuple ``(feat0, feat1, feat2, pred0, pred1, pred2, nodepair,
            outlier)``, where ``nodepair`` and ``outlier`` come from step 1.
        """
        features, preds = [], []
        dense_inputs = [x]  # original input followed by every step's output
        for i in range(len(self.transformer) - 1):
            x = self.dense_linear[i](torch.cat(dense_inputs, dim=-1))
            if i == 0:
                x = self.transformer[i](x, x2[i], spd[i], p, None)
                pred = self.mlp_head1(x)
            elif i == 1:
                x = self.transformer[i](x, x2[i], spd[i], p, None)
                pred = self.mlp_head2(x)
                nodepair = self.att_mask(x, toplogy_mask)
                # Mix predictions across nodes, weighted by the softmax-normalised node-pair scores.
                pred = torch.matmul(self.softmax(nodepair), pred)
                outlier = self.outlier(x, pred)
                num_nodes = x.shape[1]
                # Entry (j, k) is 1 - (outlier[j] - outlier[k]) ** 2, assuming outlier
                # has a trailing size-1 axis (TODO confirm against Outlier_detect).
                outlier_mask = 1 - (
                    outlier.repeat(1, 1, num_nodes) - outlier.transpose(1, 2).repeat(1, num_nodes, 1)
                ) ** 2
            elif i == 2:
                x = self.transformer[i](x, spd[i], p, nodepair * outlier_mask)
                x = self.transformer[i + 1](x, spd[i], p, None)
                pred = self.mlp_head3(x)
            features.append(x)
            dense_inputs.append(x)
            preds.append(pred)

        return features[0], features[1], features[2], preds[0], preds[1], preds[2], nodepair, outlier


class our_net(nn.Module):
    """Two-stage graph model: ``give`` (Stage_independent) then ``accecpt`` (Stage_guided).

    Node features are embedded to ``dim``, and ``spd`` is encoded by three
    embedding tables. The first stage runs on its own; its second and third
    feature outputs are passed to the second stage as cross-stage inputs.
    Attribute names (including the ``accecpt`` spelling) are kept because they
    are state-dict keys.
    """

    def __init__(self, input_dim, num_classes1, num_classes2, num_classes3, dim, heads, mlp_dim, dim_head=64,
                 dropout=0., trans_depth=2, outlier_depth=2):
        super().__init__()
        # Submodules are created in their original order so parameter order and
        # seeded initialisation do not change.
        self.accecpt = Stage_guided(trans_depth, outlier_depth, num_classes1, num_classes2, num_classes3, dim, heads,
                                    mlp_dim,
                                    dim_head=dim_head,
                                    dropout=dropout)

        self.give = Stage_independent(trans_depth, outlier_depth, num_classes1, num_classes2, num_classes3, dim, heads,
                                      mlp_dim, dim_head=dim_head,
                                      dropout=dropout)

        self.to_embedding = nn.Sequential(nn.Linear(input_dim, dim))
        # Three tables with 30 entries each (index 0 is padding), so spd values must lie in [0, 30).
        self.spatial_pos_encoders = nn.ModuleList([nn.Embedding(30, heads, padding_idx=0) for _ in range(3)])
        # NOTE(review): not used in forward(); kept for attribute compatibility.
        self.softmax = nn.Softmax(dim=-1)

    def _get_dict(self, spd):
        """Encode ``spd`` with each of the three spatial embedding tables.

        Each embedding is permuted with ``(0, 3, 1, 2)``, moving the embedding
        axis (size ``heads``) to dim 1, which requires ``spd`` to be a 3-D
        integer index tensor.
        """
        return [encoder(spd).permute(0, 3, 1, 2) for encoder in self.spatial_pos_encoders]

    def forward(self, x, toplogy_mask, spd, p):
        """Run both stages on a single, unbatched graph.

        Args:
            x: node features whose last dim is ``input_dim``; a size-1 batch
                axis is added internally.
            toplogy_mask: forwarded to both stages.
            spd: 2-D integer index tensor; a batch axis is added before encoding.
            p: forwarded to both stages.

        Returns:
            ``(x1_1, x2_1, x3_1, x1_2, x2_2, x3_2, node_pair1, node_pair2,
            outlier_1, outlier_2)``: the stage-1 and stage-2 predictions,
            node-pair scores and outlier scores. The leading size-1 axis is
            removed from each, and the outliers also lose their trailing size-1 axis.
        """
        x = self.to_embedding(x).unsqueeze(0)
        spd_encodings = self._get_dict(spd.unsqueeze(0))

        # First stage runs on its own.
        (_, feature2_1, feature3_1,
         x1_1, x2_1, x3_1, node_pair1, outlier_1) = self.give(x, spd_encodings, p, toplogy_mask)

        # Second stage receives the first stage's 2nd and 3rd feature outputs; its outlier argument is never read.
        (_, _, _,
         x1_2, x2_2, x3_2, node_pair2, outlier_2) = self.accecpt(x, spd_encodings, [feature2_1, feature3_1], p,
                                                                 toplogy_mask, None)

        # Drop the batch axis added above.
        x1_1, x2_1, x3_1, x1_2, x2_2, x3_2, node_pair1, node_pair2 = (
            t.squeeze(0) for t in (x1_1, x2_1, x3_1, x1_2, x2_2, x3_2, node_pair1, node_pair2)
        )
        outlier_1 = outlier_1.squeeze(0).squeeze(-1)
        outlier_2 = outlier_2.squeeze(0).squeeze(-1)
        return x1_1, x2_1, x3_1, x1_2, x2_2, x3_2, node_pair1, node_pair2, outlier_1, outlier_2


In [None]:
import networkx as nx
import torch
import numpy as np

def extract_graph_features_for_transformer(airway_graph):
    """Return every node's ``'pos'`` attribute as a float32 tensor, one row per node.

    Rows follow ``sorted(airway_graph.nodes())``, the same node order
    ``generate_topology_mask`` uses for its adjacency. Row ``i`` therefore
    describes the node at row/column ``i`` of the mask. The previous
    insertion-order iteration could silently misalign the two.

    An empty graph yields ``torch.tensor([])`` (shape ``(0,)``).

    Raises:
        KeyError: if a node has no ``'pos'`` attribute.
    """
    node_attrs = airway_graph.nodes
    node_features = [list(node_attrs[node]['pos']) for node in sorted(airway_graph.nodes())]
    return torch.tensor(node_features, dtype=torch.float32)

def generate_topology_mask(airway_graph, max_nodes):
    """Return the graph adjacency as a ``(max_nodes, max_nodes)`` boolean tensor.

    Rows and columns follow ``sorted(airway_graph.nodes())``. An entry is True
    where ``nx.to_numpy_array`` reports a non-zero value. A graph with more than
    ``max_nodes`` nodes is truncated to its first ``max_nodes`` sorted nodes;
    a smaller graph is zero-padded.
    """
    adj_matrix = nx.to_numpy_array(airway_graph, nodelist=sorted(airway_graph.nodes()))
    # Truncate before padding: a matrix larger than the buffer cannot be assigned into it.
    adj_matrix = adj_matrix[:max_nodes, :max_nodes]
    num_nodes = adj_matrix.shape[0]

    padded_adj_matrix = np.zeros((max_nodes, max_nodes), dtype=np.float32)
    padded_adj_matrix[:num_nodes, :num_nodes] = adj_matrix

    return torch.tensor(padded_adj_matrix, dtype=torch.bool)

def process_with_transformer(airway_graph, transformer_model, max_nodes=100, spd=None, p=None):
    """Build padded graph inputs and run ``transformer_model`` on them once, without gradients.

    Node features come from ``extract_graph_features_for_transformer``. At most
    ``max_nodes`` rows are kept, and the result is zero-padded to ``max_nodes``
    rows. The adjacency mask comes from ``generate_topology_mask``.

    If ``transformer_model`` is an ``nn.Module``, inputs are moved to the
    device of its first parameter, and it runs in eval mode (no dropout). Its
    previous training mode is restored afterwards.

    Args:
        airway_graph: graph whose nodes carry a ``'pos'`` attribute.
        transformer_model: callable invoked as
            ``transformer_model(x=..., toplogy_mask=..., spd=..., p=...)``.
        max_nodes: padded node count (default 100, the previously fixed value).
        spd: tensor passed as ``spd``. Defaults to the previous placeholder
            ``torch.rand(num_nodes, max_nodes, 4)``.
        p: tensor passed as ``p``. Defaults to
            ``torch.zeros(num_nodes, dtype=torch.long)``.

    Returns:
        Whatever ``transformer_model`` returns.

    Raises:
        ValueError: if ``airway_graph`` has no nodes.
    """
    if airway_graph.number_of_nodes() == 0:
        raise ValueError("airway_graph has no nodes; nothing to feed the transformer.")

    node_features = extract_graph_features_for_transformer(airway_graph)[:max_nodes]
    num_nodes = node_features.shape[0]

    padded_node_features = torch.zeros(max_nodes, node_features.shape[1], dtype=torch.float32)
    padded_node_features[:num_nodes] = node_features

    topology_mask = generate_topology_mask(airway_graph, max_nodes)

    if spd is None:
        # NOTE(review): random placeholder, not real shortest-path distances, so
        # results are not reproducible. It is float-valued, whereas our_net feeds
        # spd to nn.Embedding, which needs integer indices — supply a real spd.
        spd = torch.rand(num_nodes, max_nodes, 4)
    if p is None:
        p = torch.zeros(num_nodes, dtype=torch.long)

    is_module = isinstance(transformer_model, torch.nn.Module)
    # Put inputs on the model's device; a model moved to CUDA would otherwise receive CPU tensors.
    first_param = next(iter(transformer_model.parameters()), None) if is_module else None
    device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = is_module and transformer_model.training
    if is_module:
        transformer_model.eval()  # disable dropout for inference
    try:
        with torch.no_grad():
            output = transformer_model(
                x=padded_node_features.to(device),
                toplogy_mask=topology_mask.to(device),
                spd=spd.to(device),
                p=p.to(device),
            )
    finally:
        if was_training:
            transformer_model.train()  # restore the caller's mode

    return output

if 'airway_graph' in locals():
    print("Preparing graph data for transformer...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): this needs a separate `transformer` module and shadows any
    # Transformer_postnorm_spd defined earlier in this notebook; confirm which
    # implementation is intended.
    from transformer import Transformer_postnorm_spd

    # NOTE(review): elsewhere in this file Transformer_postnorm_spd is built as
    # (dim, depth, heads, dim_head, mlp_dim, dropout) and called positionally as
    # (x, spd, p, mask). This keyword construction (no depth, extra
    # attention_dropout) and process_with_transformer's keyword call
    # (x=, toplogy_mask=, spd=, p=) only work if the imported class differs.
    # That call shape matches our_net.forward — confirm which model is meant.
    transformer_model = Transformer_postnorm_spd(
        dim=3, heads=4, dim_head=64, mlp_dim=128, dropout=0.1, attention_dropout=0.1
    ).to(device)
    # Inference only: turn dropout off.
    transformer_model.eval()

    graph_output = process_with_transformer(airway_graph, transformer_model)

    # The model may return a single tensor or a tuple/list of tensors.
    outputs = graph_output if isinstance(graph_output, (tuple, list)) else (graph_output,)
    print("Transformer output shape:", [o.shape for o in outputs])
else:
    print("The variable 'airway_graph' was not found. Please run the graph generation code first.")