# In [1]:
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
import random
import albumentations as A
from sklearn.metrics import roc_auc_score
from torch.optim.lr_scheduler import SequentialLR, ConstantLR, CosineAnnealingWarmRestarts
import random
import albumentations as A
from sklearn.metrics import roc_auc_score, average_precision_score
import scipy.ndimage
import plotly.graph_objects as go
import plotly.io as pio
import math
import torch.utils.model_zoo as model_zoo
pio.renderers.default = "notebook"

# (stray cell output)   check_for_updates()


# In [15]:
import os
import random
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import roc_auc_score, average_precision_score
import scipy.ndimage as ndi
# You may need to install PraNet if it's a custom or external library
# from lib.PraNet_Res2Net import PraNet

# =====================================================================================
# 0. Model Placeholder 
# =====================================================================================



# Names exported by `from <this module> import *`.
__all__ = [
    'Res2Net',
    'res2net50_v1b',
    'res2net101_v1b',
    'res2net50_v1b_26w_4s',
]

# ImageNet-pretrained Res2Net-v1b checkpoints, keyed by architecture name.
# Only the 50- and 101-layer 26w_4s variants have an entry here.
model_urls = dict(
    res2net50_v1b_26w_4s='https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
    res2net101_v1b_26w_4s='https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth',
)
class Bottle2neck(nn.Module):
    """Res2Net bottleneck block: 1x1 conv -> split 3x3 convs -> 1x1 conv, plus residual.

    The first 1x1 conv produces ``width * scale`` channels, which ``forward`` splits
    into groups of ``width`` channels. The first ``nums`` groups each go through
    their own 3x3 conv + BN + ReLU. For ``stype='normal'`` every group after the
    first is summed with the previous group's conv output before its own conv; for
    ``stype='stage'`` the groups are processed independently. When ``scale > 1`` the
    last group skips the 3x3 convs: it is appended unchanged ('normal') or
    average-pooled with the block stride ('stage'). The concatenation is projected
    by a 1x1 conv to ``planes * expansion`` channels, added to the residual and
    passed through ReLU.
    """
    expansion = 4  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: base channel count; the block outputs planes * expansion channels
            stride: stride of every 3x3 split conv (and of the AvgPool used when stype='stage')
            downsample: optional module applied to the input to form the residual
                (None means the input itself is added back)
            baseWidth: basic width of conv3x3; per-split width = floor(planes * baseWidth / 64)
            scale: number of channel splits
            stype: 'normal': normal set. 'stage': first block of a new stage.
        """
        super(Bottle2neck, self).__init__()

        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)

        # Number of splits that get their own 3x3 conv: all but the last one,
        # except when there is only a single split.
        if scale == 1:
            self.nums = 1
        else:
            self.nums = scale - 1
        # 'stage' blocks pool the last (conv-free) split with the same stride as the
        # 3x3 convs so all splits end up with matching spatial size.
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Split channel-wise into chunks of self.width channels.
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            # 'normal': hierarchical connection - feed the previous split's output
            # into the current split. 'stage': each split is processed on its own.
            if i == 0 or self.stype == 'stage':
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        # Append the last split, which bypassed the 3x3 convs (pooled in 'stage' blocks).
        if self.scale != 1 and self.stype == 'normal':
            out = torch.cat((out, spx[self.nums]), 1)
        elif self.scale != 1 and self.stype == 'stage':
            out = torch.cat((out, self.pool(spx[self.nums])), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Res2Net(nn.Module):
    """Res2Net-v1b image classifier / backbone.

    The v1b stem is three 3x3 convs (the first with stride 2) instead of a single
    7x7 conv. Stage shortcuts that change resolution or width use
    AvgPool -> 1x1 conv -> BN.

    Args:
        block: residual block class (e.g. ``Bottle2neck``) exposing ``expansion``.
        layers: number of blocks in each of the four stages (indexed 0..3).
        baseWidth: per-split base width forwarded to every block.
        scale: number of channel splits forwarded to every block.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):
        super(Res2Net, self).__init__()
        self.inplanes = 64  # running channel count, updated by _make_layer
        self.baseWidth = baseWidth
        self.scale = scale

        # Deep stem. Keep the flat Sequential so state_dict keys stay conv1.0 ... conv1.6.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages, registered in order as layer1 .. layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_idx, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, layers[stage_idx], stride=stride)
            setattr(self, f'layer{stage_idx + 1}', stage)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks; the first one is a 'stage' block."""
        out_channels = planes * block.expansion
        shortcut = None
        if stride != 1 or self.inplanes != out_channels:
            shortcut = nn.Sequential(
                nn.AvgPool2d(kernel_size=stride, stride=stride,
                             ceil_mode=True, count_include_pad=False),
                nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        stage_blocks = [block(self.inplanes, planes, stride, downsample=shortcut,
                              stype='stage', baseWidth=self.baseWidth, scale=self.scale)]
        self.inplanes = out_channels
        stage_blocks.extend(
            block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        pooled = self.avgpool(x)
        return self.fc(pooled.view(pooled.size(0), -1))


def res2net50_v1b(pretrained=True, **kwargs):
    """Build a Res2Net-50 v1b network (the 26w_4s configuration).

    Args:
        pretrained (bool): if True, download and load the ImageNet weights stored
            under ``model_urls['res2net50_v1b_26w_4s']``.
        **kwargs: forwarded to ``Res2Net`` (e.g. ``num_classes``).

    Returns:
        Res2Net with stage depths [3, 4, 6, 3].
    """
    net = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
    if not pretrained:
        return net
    state = model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])
    net.load_state_dict(state)
    return net
def res2net101_v1b(pretrained=True, **kwargs):
    """Constructs a Res2Net-101_v1b (26w_4s) model.

    Args:
        pretrained (bool): If True, download and load the ImageNet weights stored
            under ``model_urls['res2net101_v1b_26w_4s']``.
        **kwargs: forwarded to ``Res2Net`` (e.g. ``num_classes``).

    Returns:
        Res2Net with stage depths [3, 4, 23, 3].
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
    return model


def res2net50_v1b_26w_4s(pretrained=True, **kwargs):
    """Build Res2Net-50 v1b with base width 26 and 4 scales.

    Args:
        pretrained (bool): when True, download and load the ImageNet weights listed
            under ``model_urls['res2net50_v1b_26w_4s']``.
        **kwargs: passed through to ``Res2Net`` (e.g. ``num_classes``).

    Returns:
        Res2Net with stage depths [3, 4, 6, 3].
    """
    depths = [3, 4, 6, 3]
    net = Res2Net(Bottle2neck, depths, baseWidth=26, scale=4, **kwargs)
    if pretrained:
        weights = model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])
        net.load_state_dict(weights)
    return net


def res2net101_v1b_26w_4s(pretrained=True, **kwargs):
    """Constructs a Res2Net-101_v1b_26w_4s model.

    Args:
        pretrained (bool): If True, download and load the ImageNet weights stored
            under ``model_urls['res2net101_v1b_26w_4s']``.
        **kwargs: forwarded to ``Res2Net`` (e.g. ``num_classes``).

    Returns:
        Res2Net with stage depths [3, 4, 23, 3].
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
    return model


def res2net152_v1b_26w_4s(pretrained=True, **kwargs):
    """Constructs a Res2Net-152_v1b_26w_4s model.

    Args:
        pretrained (bool): If True, load ImageNet weights from
            ``model_urls['res2net152_v1b_26w_4s']``. That key is not registered in
            ``model_urls`` in this file, so add it before requesting pretrained weights.
        **kwargs: forwarded to ``Res2Net`` (e.g. ``num_classes``).

    Returns:
        Res2Net with stage depths [3, 8, 36, 3].

    Raises:
        KeyError: if ``pretrained`` is True and ``model_urls`` has no
            'res2net152_v1b_26w_4s' entry. This is raised before the network is built.
    """
    arch = 'res2net152_v1b_26w_4s'
    if pretrained and arch not in model_urls:
        # Fail fast with an actionable message instead of constructing the whole
        # network first and then failing on a bare dictionary lookup.
        raise KeyError(
            f"No pretrained-weights URL registered for {arch!r}; "
            "add it to model_urls or call with pretrained=False."
        )
    model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls[arch]))
    return model

class BasicConv2d(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d.

    ``self.relu`` is registered but is not applied in ``forward``; the output is the
    normalized conv response, which can be negative. Callers that want an
    activation apply it themselves (e.g. ``F.relu(self.ra4_conv2(x))`` in PraNet).
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, dilation=dilation, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.bn(self.conv(x))


class RFB_modified(nn.Module):
    """Receptive-field-block-style module with four parallel conv branches.

    Branch 0 is a single 1x1 conv. Branches 1-3 are 1x1 -> 1xK -> Kx1 -> 3x3 with
    dilation K, for K = 3, 5, 7 respectively. The concatenated branch outputs are
    fused by a 3x3 conv, added to a 1x1 projection of the input, and passed through
    ReLU. All convs are padded so the spatial size is preserved.
    """

    def __init__(self, in_channel, out_channel):
        super(RFB_modified, self).__init__()
        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = self._dilated_branch(in_channel, out_channel, 3)
        self.branch2 = self._dilated_branch(in_channel, out_channel, 5)
        self.branch3 = self._dilated_branch(in_channel, out_channel, 7)
        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    @staticmethod
    def _dilated_branch(in_channel, out_channel, k):
        """1x1 reduce, separable 1xk / kx1 convs, then a 3x3 conv dilated by k."""
        half = k // 2
        return nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, k), padding=(0, half)),
            BasicConv2d(out_channel, out_channel, kernel_size=(k, 1), padding=(half, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=k, dilation=k),
        )

    def forward(self, x):
        branch_outputs = [branch(x) for branch in (self.branch0, self.branch1, self.branch2, self.branch3)]
        fused = self.conv_cat(torch.cat(branch_outputs, 1))
        return self.relu(fused + self.conv_res(x))


class aggregation(nn.Module):
    """Partial decoder: fuse three feature maps into a single-channel logit map.

    dense aggregation, it can be replaced by other aggregation previous, such as
    DSS, amulet, and so on. Used after MSF.

    ``forward(x1, x2, x3)`` expects ``channel``-channel maps where x2 has twice, and
    x3 four times, the spatial size of x1 (they are combined element-wise after 2x
    bilinear upsampling). The output has x3's spatial size and 1 channel.
    """

    def __init__(self, channel):
        super(aggregation, self).__init__()
        c = channel
        self.relu = nn.ReLU(True)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # 3x3 refinements applied to upsampled features.
        self.conv_upsample1 = BasicConv2d(c, c, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(c, c, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(c, c, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(c, c, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2 * c, 2 * c, 3, padding=1)

        # Fusion convs after each concatenation, then the 1-channel head.
        self.conv_concat2 = BasicConv2d(2 * c, 2 * c, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3 * c, 3 * c, 3, padding=1)
        self.conv4 = BasicConv2d(3 * c, 3 * c, 3, padding=1)
        self.conv5 = nn.Conv2d(3 * c, 1, 1)

    def forward(self, x1, x2, x3):
        up = self.upsample
        x1_up = up(x1)
        x1_up2 = up(x1_up)

        # Gate each finer map multiplicatively with refined, upsampled coarser maps.
        gated2 = self.conv_upsample1(x1_up) * x2
        gated3 = self.conv_upsample2(x1_up2) * self.conv_upsample3(up(x2)) * x3

        merged2 = self.conv_concat2(torch.cat((gated2, self.conv_upsample4(x1_up)), 1))
        merged3 = self.conv_concat3(torch.cat((gated3, self.conv_upsample5(up(merged2))), 1))

        return self.conv5(self.conv4(merged3))


class PraNet(nn.Module):
    """PraNet segmentation network.

    Components: a Res2Net-50 encoder, RFB-style blocks on stages 2-4, a partial
    decoder (``aggregation``), and three reverse-attention refinement branches.

    Args:
        channel: channel width of the RFB outputs fed to the partial decoder.
        num_classes: output channels of the final branch-2 head (``ra2_conv4``).
            All other side outputs are single-channel.

    Constructing the model calls ``res2net50_v1b_26w_4s(pretrained=True)``, which
    downloads ImageNet weights through ``model_zoo.load_url``.

    ``forward(x)`` returns only ``lateral_map_2``: logits from the branch-2 output
    (1/8 of the input resolution), upsampled 8x. The other side outputs are
    computed but not returned.

    NOTE(review): the fixed ``scale_factor`` interpolations appear to require H and W
    divisible by 32 (e.g. 352 or 384) for the feature maps to line up — confirm.
    The shape comments in ``forward`` are written for a 352x352 input.
    """
    # res2net based encoder decoder
    def __init__(self, channel=32,num_classes=1):
        super(PraNet, self).__init__()
        # ---- ResNet Backbone ----
        # Builds Res2Net-50 and loads its ImageNet weights (network download).
        self.resnet = res2net50_v1b_26w_4s(pretrained=True)
        # ---- Receptive Field Block like module ----
        self.rfb2_1 = RFB_modified(512, channel)
        self.rfb3_1 = RFB_modified(1024, channel)
        self.rfb4_1 = RFB_modified(2048, channel)
        # ---- Partial Decoder ----
        self.agg1 = aggregation(channel)
        # ---- reverse attention branch 4 ----
        self.ra4_conv1 = BasicConv2d(2048, 256, kernel_size=1)
        self.ra4_conv2 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv3 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv4 = BasicConv2d(256, 256, kernel_size=5, padding=2)
        self.ra4_conv5 = BasicConv2d(256, 1, kernel_size=1)
        # ---- reverse attention branch 3 ----
        self.ra3_conv1 = BasicConv2d(1024, 64, kernel_size=1)
        self.ra3_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra3_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra3_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)
        # ---- reverse attention branch 2 ----
        self.ra2_conv1 = BasicConv2d(512, 64, kernel_size=1)
        self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        # Only this head depends on num_classes.
        self.ra2_conv4 = BasicConv2d(64, num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)      # bs, 64, 88, 88
        # ---- low-level features ----
        x1 = self.resnet.layer1(x)      # bs, 256, 88, 88
        x2 = self.resnet.layer2(x1)     # bs, 512, 44, 44

        x3 = self.resnet.layer3(x2)     # bs, 1024, 22, 22
        x4 = self.resnet.layer4(x3)     # bs, 2048, 11, 11
        x2_rfb = self.rfb2_1(x2)        # channel -> 32
        x3_rfb = self.rfb3_1(x3)        # channel -> 32
        x4_rfb = self.rfb4_1(x4)        # channel -> 32

        # Coarse global map from the partial decoder, at x2's resolution.
        ra5_feat = self.agg1(x4_rfb, x3_rfb, x2_rfb)
        lateral_map_5 = F.interpolate(ra5_feat, scale_factor=8, mode='bilinear')    # NOTES: Sup-1 (bs, 1, 44, 44) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_4 ----
        # Bring the coarse map down to x4's resolution.
        crop_4 = F.interpolate(ra5_feat, scale_factor=0.25, mode='bilinear')
        # Reverse attention weight 1 - sigmoid(map): large where the current prediction is low.
        x = -1*(torch.sigmoid(crop_4)) + 1
        # Broadcast the 1-channel weight over all feature channels and apply it.
        x = x.expand(-1, 2048, -1, -1).mul(x4)
        x = self.ra4_conv1(x)
        x = F.relu(self.ra4_conv2(x))
        x = F.relu(self.ra4_conv3(x))
        x = F.relu(self.ra4_conv4(x))
        ra4_feat = self.ra4_conv5(x)
        # Residual refinement of the coarser map.
        x = ra4_feat + crop_4
        lateral_map_4 = F.interpolate(x, scale_factor=32, mode='bilinear')  # NOTES: Sup-2 (bs, 1, 11, 11) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_3 ----
        crop_3 = F.interpolate(x, scale_factor=2, mode='bilinear')
        x = -1*(torch.sigmoid(crop_3)) + 1
        x = x.expand(-1, 1024, -1, -1).mul(x3)
        x = self.ra3_conv1(x)
        x = F.relu(self.ra3_conv2(x))
        x = F.relu(self.ra3_conv3(x))
        ra3_feat = self.ra3_conv4(x)
        x = ra3_feat + crop_3
        lateral_map_3 = F.interpolate(x, scale_factor=16, mode='bilinear')  # NOTES: Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352)

        # ---- reverse attention branch_2 ----
        crop_2 = F.interpolate(x, scale_factor=2, mode='bilinear')
        x = -1*(torch.sigmoid(crop_2)) + 1
        x = x.expand(-1, 512, -1, -1).mul(x2)
        x = self.ra2_conv1(x)
        x = F.relu(self.ra2_conv2(x))
        x = F.relu(self.ra2_conv3(x))
        ra2_feat = self.ra2_conv4(x)
        # ra2_feat has num_classes channels; the 1-channel crop_2 broadcasts across them.
        x = ra2_feat + crop_2
        lateral_map_2 = F.interpolate(x, scale_factor=8, mode='bilinear')   # NOTES: Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352)

        # return lateral_map_2,lateral_map_3, lateral_map_4, lateral_map_5
        return lateral_map_2
def test():
    """Smoke-test PraNet: build it, print parameter counts, run one forward pass.

    Runs on CUDA when available and falls back to CPU, so it also works on
    CPU-only machines. Building PraNet downloads the pretrained Res2Net weights.
    """
    print('testing the code')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn((2, 3, 224, 224))  # batch, channel, H, W
    model = PraNet(num_classes=6)
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    pytorch_total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('total number of parameters: {}'.format(pytorch_total_params))
    print('total number of trainable parameters: {}'.format(pytorch_total_trainable_params))
    model.eval()
    model.to(device)
    with torch.no_grad():
        preds = model(x.to(device))
    print(f'Input shape : {x.shape}')
    print(f'Output shape : {preds.shape}')
    
if __name__ == "__main__":
    test()

# =====================================================================================
# 1. Configuration
# =====================================================================================

# --- Core Settings ---
# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
IMAGE_SIZE = 384          # slices are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 8
LEARNING_RATE_S1 = 1e-4   # stage 1 (supervised)
LEARNING_RATE_S2 = 5e-5   # stage 2: lower learning rate for fine-tuning
NUM_EPOCHS_STAGE_1 = 1
NUM_EPOCHS_STAGE_2 = 1
NUM_WORKERS = 2
PIN_MEMORY = True
VALIDATION_SPLIT = 0.2
MIN_LESION_AREA = 50      # post-processing threshold for small predicted components

# --- Path Settings ---
# Note: Replace with your actual data paths
ORIGINAL_DATA_DIR = "/kaggle/input/picai-processed/with_mask/kaggle/working/picai_processed_resampled"
UNLABELED_DATA_DIR = "/kaggle/input/picai-processed/no_mask/kaggle/working/picai_processed_incomplete_cases"
OUTPUT_DIR = "/kaggle/working/semisupervised_output"

# --- Derived Paths ---
PSEUDO_MASK_DIR = os.path.join(OUTPUT_DIR, "pseudo_masks")
STAGE_1_MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, "best_supervised_model.pth")
FINAL_MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, "best_semisupervised_model.pth")

# Create the output folders; exist_ok makes re-running this cell harmless.
for _required_dir in (PSEUDO_MASK_DIR, OUTPUT_DIR):
    os.makedirs(_required_dir, exist_ok=True)


# =====================================================================================
# 2. Model, Loss, and Data Handling
# =====================================================================================

class DiceBCELoss(nn.Module):
    """Binary cross-entropy on logits plus soft Dice loss.

    Both terms are computed over the whole batch at once. The Dice term pools all
    pixels of all samples instead of averaging per-sample scores.

    Args:
        smooth: constant added to the Dice numerator and denominator so the ratio
            stays defined when prediction and target are both empty.
    """

    def __init__(self, smooth=1e-6):
        super(DiceBCELoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return BCE(inputs, targets) + (1 - soft Dice).

        ``inputs`` are raw logits; ``targets`` are binary masks of the same shape.
        """
        probs = torch.sigmoid(inputs)
        overlap = torch.sum(probs * targets)
        total = torch.sum(probs) + torch.sum(targets)
        soft_dice = (2. * overlap + self.smooth) / (total + self.smooth)
        bce = F.binary_cross_entropy_with_logits(inputs, targets)
        return bce + (1 - soft_dice)


class SemiSupPiCaiDataset(Dataset):
    """2D-slice dataset built from 3D PI-CAI volumes and their (pseudo-)masks.

    Each item is one slice, taken along ``slice_axis`` (the smallest dimension of
    the first available mask), of the t2w/adc/hbv volumes stacked on the last axis,
    together with the matching mask slice binarized to float32 {0, 1}.

    Args:
        base_dir: folder with ``t2w/``, ``adc/`` and ``hbv/`` sub-folders of ``<id>.npy`` volumes.
        sample_ids: case ids to include; ids without a mask file are skipped.
        mask_dir: folder of ``<id>.npy`` masks (ground truth or pseudo-labels).
        is_validation: if True keep every slice. Otherwise keep every slice with a
            positive mask sum, plus an equal number (fewer if unavailable) of
            randomly chosen empty slices from the same volume.
    """

    def __init__(self, base_dir, sample_ids, mask_dir, is_validation=False):
        self.base_dir = base_dir
        self.sample_ids = sample_ids
        self.mask_dir = mask_dir
        self.slice_infos = []  # (index into sample_ids, slice index) pairs

        if not self.sample_ids:
            return

        # Determine slice axis from the first sample that has a mask file
        first_mask_path = os.path.join(mask_dir, f'{self.sample_ids[0]}.npy')
        if not os.path.exists(first_mask_path):
            print(f"Warning: First mask file not found at {first_mask_path}. Searching for another mask to infer the slice axis.")
            for s_id in self.sample_ids:
                potential_path = os.path.join(mask_dir, f'{s_id}.npy')
                if os.path.exists(potential_path):
                    first_mask_path = potential_path
                    print(f"Found an existing mask instead: {first_mask_path}")
                    break
            else:  # No masks found at all
                print("Error: No mask files found in the specified directory. Aborting dataset creation.")
                return

        # Only the shape is needed, so memory-map instead of reading the whole volume.
        data_shape = np.load(first_mask_path, mmap_mode='r').shape
        self.slice_axis = np.argmin(data_shape)

        desc = "Finding validation slices" if is_validation else "Finding training slices"
        for sample_idx, sample_id in enumerate(tqdm(self.sample_ids, desc=desc)):
            mask_path = os.path.join(self.mask_dir, f'{sample_id}.npy')
            if not os.path.exists(mask_path):
                continue  # Skip if a pseudo-mask failed to generate for this ID

            mask_3d = np.load(mask_path)
            num_slices = mask_3d.shape[self.slice_axis]

            if is_validation:
                # For final evaluation, we need all slices.
                self.slice_infos.extend((sample_idx, i) for i in range(num_slices))
                continue

            # Per-slice mask totals in a single reduction over the volume, instead of
            # re-summing every slice separately for the positive and negative lists.
            in_plane_axes = tuple(ax for ax in range(mask_3d.ndim) if ax != self.slice_axis)
            slice_sums = mask_3d.sum(axis=in_plane_axes)
            pos_slices = [i for i, total in enumerate(slice_sums) if total > 0]
            neg_slices = [i for i, total in enumerate(slice_sums) if total == 0]

            self.slice_infos.extend((sample_idx, i) for i in pos_slices)
            if neg_slices:  # Balance with an equal number of empty slices
                neg_samples = random.sample(neg_slices, min(len(pos_slices), len(neg_slices)))
                self.slice_infos.extend((sample_idx, i) for i in neg_samples)

        random.shuffle(self.slice_infos)
        print(f"Loaded {len(self.slice_infos)} 2D slices for this dataset partition.")

    def __len__(self):
        return len(self.slice_infos)

    def __getitem__(self, idx):
        sample_idx, slice_idx = self.slice_infos[idx]
        sample_id = self.sample_ids[sample_idx]

        # Read just the requested slice of each modality, then stack as channels. This is
        # equivalent to slicing the stacked volume, but it avoids reading and stacking
        # three full volumes for every 2D item.
        image_slice = np.stack(
            [self._load_slice(os.path.join(self.base_dir, m, f'{sample_id}.npy'), slice_idx)
             for m in ['t2w', 'adc', 'hbv']],
            axis=-1,
        )
        mask_slice = self._load_slice(os.path.join(self.mask_dir, f'{sample_id}.npy'), slice_idx)

        # Ensure mask is binary
        mask_slice = (mask_slice > 0).astype(np.float32)

        return image_slice, mask_slice

    def _load_slice(self, path, slice_idx):
        """Return slice ``slice_idx`` along ``self.slice_axis`` of the .npy volume at ``path``."""
        volume = np.load(path, mmap_mode='r')  # memory-mapped: only the slice is read
        # np.take copies into memory; asarray drops the memmap subclass.
        return np.asarray(np.take(volume, slice_idx, axis=self.slice_axis))


class AugmentationWrapper(Dataset):
    """Wrap a dataset of (image HxWxC, mask HxW) numpy pairs: augment, normalize, tensorize.

    Normalization is done per channel after augmentation:
      1. If the channel has values above 1e-6, clip it to the 1st-99th percentile of those values.
      2. Min-max scale to [0, 1]. A constant channel becomes all zeros.

    The image is returned as a float32 CxHxW tensor and the mask as a float32 1xHxW tensor.

    Args:
        dataset: indexable object returning ``(image_np, mask_np)`` pairs.
        transform: optional albumentations-style callable taking ``image=`` and
            ``mask=`` and returning a dict with ``'image'`` and ``'mask'``.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image_np, mask_np = self.dataset[idx]

        if self.transform:
            augmented = self.transform(image=image_np.astype(np.float32), mask=mask_np)
            image_np, mask_np = augmented['image'], augmented['mask']

        # Always work on a private float32 copy. Without a transform, image_np would
        # otherwise be the wrapped dataset's own array, mutated in place below. For
        # integer inputs, the [0, 1] values written back would also be truncated to 0/1.
        image_np = np.array(image_np, dtype=np.float32)

        # CRITICAL: Apply percentile normalization AFTER augmentations but BEFORE ToTensor
        for i in range(image_np.shape[2]):
            channel = image_np[:, :, i]
            non_zero = channel[channel > 1e-6]
            if non_zero.size > 0:
                p1, p99 = np.percentile(non_zero, [1, 99])
                channel = np.clip(channel, p1, p99)
            min_val, max_val = channel.min(), channel.max()
            image_np[:, :, i] = (channel - min_val) / (max_val - min_val) if max_val > min_val else 0

        image = torch.from_numpy(image_np.transpose(2, 0, 1)).float()
        mask = torch.from_numpy(np.asarray(mask_np)).unsqueeze(0).float()  # Add channel dim
        return image, mask


# =====================================================================================
# 3. Training and Mid-Training Validation
# =====================================================================================

def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch, with mixed precision when training on CUDA.

    Args:
        loader: iterable of (images, masks) batches.
        model: network returning logits shaped like the masks.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(logits, targets)``, e.g. DiceBCELoss.
        scaler: gradient scaler exposing ``scale``/``step``/``update``
            (e.g. ``torch.amp.GradScaler``).
    """
    # Evaluation helpers in this pipeline call model.eval() (e.g. pseudo-label
    # generation, final 2D metrics). Make sure BatchNorm trains on batch statistics.
    model.train()
    device_type = DEVICE.split(':')[0]
    use_amp = device_type == "cuda"
    loop = tqdm(loader, desc="Training")
    for data, targets in loop:
        data, targets = data.to(DEVICE), targets.to(DEVICE)
        with torch.amp.autocast(device_type=device_type, enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        loop.set_postfix(loss=loss.item())

def check_accuracy(loader, model, device="cuda"):
    """Fast check on a balanced validation set for model selection.

    Computes one Dice score pooled over every pixel of every batch, using
    sigmoid(logits) > 0.5 as the prediction. Returns 0.0 when ``loader`` yields no
    batches, so an empty validation split never looks like a perfect model. Mixed
    precision is used for any CUDA device string ("cuda", "cuda:1", ...). The model
    is put back in train mode before returning.
    """
    model.eval()
    device_type = device.split(':')[0]
    dice_num, dice_den = 0, 0
    n_batches = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            with torch.amp.autocast(device_type=device_type, enabled=(device_type == "cuda")):
                preds = torch.sigmoid(model(x))
            preds_binary = (preds > 0.5).float()
            dice_num += 2 * (preds_binary * y).sum()
            dice_den += preds_binary.sum() + y.sum()
            n_batches += 1
    model.train()
    if n_batches == 0:
        return 0.0
    dice_score = (dice_num + 1e-8) / (dice_den + 1e-8)
    return float(dice_score)

# =====================================================================================
# 4. Final Evaluation (2D and 3D)
# =====================================================================================

def calculate_final_2d_metrics(val_loader_all_slices, model, device):
    """Slice-level metrics over every validation slice.

    Returns a dict with:
        "2d_dice": mean per-slice Dice of the 0.5-thresholded prediction. A slice
            with empty prediction and empty target scores 1.
        "2d_auroc" / "2d_ap": pixel-level ROC-AUC / average precision with all
            slices pooled together. Both are 0 when the targets contain only one
            class or there are no slices.
    The model is put back in train mode before returning.
    """
    print("\nCalculating Final 2D Metrics (on all validation slices)...")
    model.eval()
    device_type = device.split(':')[0]
    all_preds_flat, all_targets_flat = [], []
    total_dice_score, slice_count = 0, 0

    with torch.no_grad():
        for x, y in tqdm(val_loader_all_slices, desc="2D Evaluation"):
            x, y = x.to(device), y.to(device)
            with torch.amp.autocast(device_type=device_type, enabled=(device_type == "cuda")):
                preds_prob = torch.sigmoid(model(x))

            for pred_prob_slice, target_slice in zip(preds_prob, y):
                pred_prob_slice = pred_prob_slice.squeeze()
                target_slice = target_slice.squeeze()
                pred_binary_slice = (pred_prob_slice > 0.5).float()

                # Dice Score
                intersection = (pred_binary_slice * target_slice).sum()
                union = pred_binary_slice.sum() + target_slice.sum()
                total_dice_score += (2. * intersection + 1e-8) / (union + 1e-8)
                slice_count += 1

                # Pixel-level scores / labels for AUROC and AP
                all_preds_flat.append(pred_prob_slice.reshape(-1).cpu().numpy())
                all_targets_flat.append(target_slice.reshape(-1).cpu().numpy().astype(bool))

    model.train()

    if slice_count == 0:
        # Nothing to score; np.concatenate([]) would raise below.
        return {"2d_dice": 0.0, "2d_auroc": 0, "2d_ap": 0}

    avg_dice = total_dice_score / slice_count
    y_true_all = np.concatenate(all_targets_flat)
    y_pred_all = np.concatenate(all_preds_flat)

    # The ranking metrics need both classes present. For a boolean array an O(n)
    # any/all check replaces np.unique, which sorts every pixel.
    both_classes = bool(y_true_all.any()) and not bool(y_true_all.all())
    auroc = roc_auc_score(y_true_all, y_pred_all) if both_classes else 0
    ap = average_precision_score(y_true_all, y_pred_all) if both_classes else 0

    return {"2d_dice": float(avg_dice), "2d_auroc": auroc, "2d_ap": ap}


def calculate_final_3d_metrics(val_ids, base_dir, model, device):
    """Volume-level Dice / AUROC / AP on validation cases, averaged over cases.

    Only ids with a ground-truth mask at ``base_dir/mask/<id>.npy`` are scored. For
    each such case, every slice along the mask's smallest axis is:
      1. built from the t2w/adc/hbv volumes;
      2. resized to IMAGE_SIZE and normalized exactly as in training
         (``AugmentationWrapper``);
      3. predicted, resized back to the slice's native in-plane size, and written
         into a probability volume.
    Dice uses a 0.5 threshold. AUROC/AP are voxel-level and are 0 for cases whose
    ground truth contains a single class. Cases that raise are reported and
    skipped. The model is put back in train mode afterwards.

    Returns:
        dict with "3d_dice", "3d_auroc", "3d_ap" (each 0 if no case was scored).
    """
    print("\nCalculating Final 3D Metrics (on validation volumes)...")
    model.eval()
    device_type = device.split(':')[0]
    use_amp = device_type == "cuda"

    # Define the exact same transform used for training/validation
    eval_transform = AugmentationWrapper(dataset=None, transform=A.Compose([A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE)]))

    patient_metrics = []

    with torch.no_grad():
        for patient_id in tqdm(val_ids, desc="3D Evaluation"):
            try:
                # Load ground truth volume
                gt_path = os.path.join(base_dir, 'mask', f'{patient_id}.npy')
                if not os.path.exists(gt_path):
                    continue
                gt_volume = np.load(gt_path)
                gt_volume_binary = gt_volume > 0

                # Determine slice axis and prepare for reconstruction
                slice_axis = np.argmin(gt_volume.shape)
                pred_volume_prob = np.zeros_like(gt_volume, dtype=float)

                # Load each modality once per case rather than once per slice.
                modalities = [np.load(os.path.join(base_dir, m, f'{patient_id}.npy')) for m in ['t2w', 'adc', 'hbv']]

                for slice_idx in range(gt_volume.shape[slice_axis]):
                    image_slice_np = np.stack([np.take(vol, slice_idx, axis=slice_axis) for vol in modalities], axis=-1)

                    # Apply the same resize/normalization as training
                    eval_transform.dataset = [(image_slice_np, np.zeros_like(image_slice_np[..., 0]))]
                    image_tensor, _ = eval_transform[0]
                    image_tensor = image_tensor.unsqueeze(0).to(device)

                    with torch.amp.autocast(device_type=device_type, enabled=use_amp):
                        pred_prob = torch.sigmoid(model(image_tensor)).squeeze()

                    # Index of this slice in the volume. The indexed view's shape is the
                    # in-plane size to restore, which stays correct for any slice axis
                    # (not only axis 0).
                    slicer = [slice(None)] * pred_volume_prob.ndim
                    slicer[slice_axis] = slice_idx
                    slicer = tuple(slicer)
                    original_dims = pred_volume_prob[slicer].shape
                    pred_prob_resized = F.interpolate(
                        pred_prob.unsqueeze(0).unsqueeze(0), size=original_dims,
                        mode='bilinear', align_corners=False,
                    )[0, 0].cpu().numpy()
                    pred_volume_prob[slicer] = pred_prob_resized

                # Calculate metrics for the entire volume
                pred_volume_binary = pred_volume_prob > 0.5

                # 3D Dice
                intersection = (pred_volume_binary & gt_volume_binary).sum()
                union = pred_volume_binary.sum() + gt_volume_binary.sum()
                dice_3d = (2. * intersection + 1e-8) / (union + 1e-8)

                # 3D AUROC & AP (need both classes; O(n) check instead of np.unique's sort)
                gt_flat = gt_volume_binary.ravel()
                pred_flat = pred_volume_prob.ravel()
                both_classes = bool(gt_flat.any()) and not bool(gt_flat.all())
                auroc_3d = roc_auc_score(gt_flat, pred_flat) if both_classes else 0
                ap_3d = average_precision_score(gt_flat, pred_flat) if both_classes else 0

                patient_metrics.append({"dice": dice_3d, "auroc": auroc_3d, "ap": ap_3d})

            except Exception as e:
                # Best effort: report and skip this case instead of aborting the evaluation.
                print(f"Warning: Could not process patient {patient_id} for 3D eval. Error: {e}")

    # Average metrics across all patients
    avg_dice = np.mean([m['dice'] for m in patient_metrics]) if patient_metrics else 0
    avg_auroc = np.mean([m['auroc'] for m in patient_metrics]) if patient_metrics else 0
    avg_ap = np.mean([m['ap'] for m in patient_metrics]) if patient_metrics else 0

    model.train()
    return {"3d_dice": avg_dice, "3d_auroc": avg_auroc, "3d_ap": avg_ap}


def generate_pseudo_labels(unlabeled_ids, base_dir, out_dir, model, device):
    """Predict binary pseudo-masks for unlabeled patient volumes and save them as ``.npy``.

    A slimmed-down variant of the 3D evaluation loop. Each slice is processed as follows:
    - Slices are taken along the smallest axis of the T2W volume.
    - Each slice is resized to ``IMAGE_SIZE`` and segmented.
    - The prediction is thresholded at 0.5 and resized back with nearest-neighbour
      interpolation.
    - It is cleaned with ``remove_small_lesions`` and written into a uint8 volume.

    Args:
        unlabeled_ids: Patient identifiers; each needs ``<base_dir>/<modality>/<id>.npy``
            for the ``t2w``, ``adc`` and ``hbv`` modalities.
        base_dir: Root directory of the unlabeled data.
        out_dir: Directory receiving ``<id>.npy`` pseudo-masks (created if missing).
        model: Segmentation network; its sigmoid output is thresholded at 0.5.
        device: Device string (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``) or ``torch.device``.

    Patients without a T2W file are skipped silently. Patients that raise are reported
    and skipped. The model is switched back to training mode at the end.
    """
    model.eval()
    os.makedirs(out_dir, exist_ok=True)
    eval_transform = AugmentationWrapper(dataset=None, transform=A.Compose([A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE)]))
    # autocast wants a bare device type ("cuda"/"cpu"); this also accepts "cuda:0" and torch.device.
    device_type = torch.device(device).type
    with torch.no_grad():
        for patient_id in tqdm(unlabeled_ids, desc="Generating Pseudo-Masks"):
            try:
                ref_vol_path = os.path.join(base_dir, 't2w', f'{patient_id}.npy')
                if not os.path.exists(ref_vol_path):
                    continue
                # Load every modality once per patient instead of once per slice.
                modalities = [np.load(os.path.join(base_dir, m, f'{patient_id}.npy')) for m in ['t2w', 'adc', 'hbv']]
                ref_vol = modalities[0]
                slice_axis = int(np.argmin(ref_vol.shape))
                pred_volume = np.zeros_like(ref_vol, dtype=np.uint8)

                # In-plane size of a slice = the two axes other than slice_axis
                # (hard-coding shape[1], shape[2] is only right when slice_axis == 0).
                original_dims = tuple(dim for axis, dim in enumerate(ref_vol.shape) if axis != slice_axis)
                # interpolation=0 is OpenCV's INTER_NEAREST, so the mask stays binary.
                resizer = A.Resize(height=original_dims[0], width=original_dims[1], interpolation=0)

                for slice_idx in range(ref_vol.shape[slice_axis]):
                    image_slice_np = np.stack([np.take(vol, slice_idx, axis=slice_axis) for vol in modalities], axis=-1)

                    # Reuse the wrapper's preprocessing by pointing it at a one-item dataset.
                    eval_transform.dataset = [(image_slice_np, np.zeros_like(image_slice_np[..., 0]))]
                    image_tensor, _ = eval_transform[0]
                    image_tensor = image_tensor.unsqueeze(0).to(device)

                    with torch.amp.autocast(device_type=device_type, enabled=(device_type == "cuda")):
                        pred_prob = torch.sigmoid(model(image_tensor)).squeeze()

                    pred_binary = (pred_prob > 0.5).cpu().numpy().astype(np.uint8)
                    pred_resized = resizer(image=pred_binary)['image']

                    # Post-process to remove small artifacts
                    pred_processed = remove_small_lesions(pred_resized, MIN_LESION_AREA)

                    # Write the processed slice back into the 3D volume along slice_axis.
                    slicer = [slice(None)] * pred_volume.ndim
                    slicer[slice_axis] = slice_idx
                    pred_volume[tuple(slicer)] = pred_processed

                np.save(os.path.join(out_dir, f"{patient_id}.npy"), pred_volume)
            except Exception as e:
                print(f"Warning: Could not pseudo-label patient {patient_id}. Error: {e}")
    model.train()

def remove_small_lesions(mask_np, min_size):
    """Remove connected components that are too small from a binary mask.

    Components are labelled with ``scipy.ndimage.label`` using its default
    (face-connected, no diagonals) structure over the non-zero entries.
    A component is kept only if its pixel count is strictly greater than ``min_size``.

    Args:
        mask_np: Array whose non-zero entries are foreground.
        min_size: Components with ``<= min_size`` pixels are dropped.

    Returns:
        A new ``uint8`` array of the same shape with values 0/1. The dtype is uint8
        for every input, including masks without any foreground.
    """
    labeled_array, num_features = ndi.label(mask_np)
    if num_features == 0:
        # No foreground: return a fresh all-zero uint8 mask so the dtype never
        # depends on the input and the caller's array is never aliased.
        return np.zeros(np.shape(mask_np), dtype=np.uint8)

    component_sizes = np.bincount(labeled_array.ravel())
    keep = component_sizes > min_size
    keep[0] = False  # label 0 is background, never a lesion
    return keep[labeled_array].astype(np.uint8)


# =====================================================================================
# 5. Main Execution Block
# =====================================================================================
def main():
    """Run the two-stage semi-supervised training pipeline and report final metrics.

    1. Split the labeled patients (seeded shuffle) into train/validation.
    2. Stage 1: supervised training. The checkpoint with the best Dice on the balanced
       validation loader is saved to ``STAGE_1_MODEL_SAVE_PATH``.
    3. Use that checkpoint to write pseudo-masks for the unlabeled patients.
    4. Stage 2: fine-tune on labeled + pseudo-labeled slices with ``LEARNING_RATE_S2``.
       A checkpoint is written to ``FINAL_MODEL_SAVE_PATH`` only when it beats the best
       Dice seen in either stage.
    5. Reload the checkpoint that achieved that best Dice. Compute 2D metrics (all
       validation slices) and 3D metrics (per-patient volumes).
    """
    # --- Data Setup ---
    all_original_ids = sorted([f.replace('.npy', '') for f in os.listdir(os.path.join(ORIGINAL_DATA_DIR, 'mask')) if f.endswith('.npy')])
    random.seed(42)
    random.shuffle(all_original_ids)
    split_idx = int(len(all_original_ids) * (1 - VALIDATION_SPLIT))
    original_train_ids, val_ids = all_original_ids[:split_idx], all_original_ids[split_idx:]
    print(f"Labeled Data Split: {len(original_train_ids)} training, {len(val_ids)} validation patients.")

    # --- Transforms ---
    train_transform = A.Compose([
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE), A.Rotate(limit=15, p=0.5),
        A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.1),
        A.RandomBrightnessContrast(p=0.3), A.GaussNoise(p=0.2)
    ])
    val_transform = A.Compose([A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE)])

    # --- STAGE 1: Supervised Training ---
    print("\n--- STAGE 1: Starting Supervised Training ---")
    model = PraNet().to(DEVICE)
    loss_fn = DiceBCELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE_S1)
    # torch.cuda.amp.GradScaler is deprecated; torch.amp.GradScaler("cuda", ...) is its replacement.
    scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE == "cuda"))

    # Create datasets for Stage 1
    train_base_s1 = SemiSupPiCaiDataset(ORIGINAL_DATA_DIR, original_train_ids, mask_dir=os.path.join(ORIGINAL_DATA_DIR, 'mask'))
    val_base_s1 = SemiSupPiCaiDataset(ORIGINAL_DATA_DIR, val_ids, mask_dir=os.path.join(ORIGINAL_DATA_DIR, 'mask'), is_validation=True)  # Use full validation set here
    train_dataset_s1 = AugmentationWrapper(train_base_s1, transform=train_transform)
    val_dataset_s1 = AugmentationWrapper(val_base_s1, transform=val_transform)

    train_loader_s1 = DataLoader(train_dataset_s1, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    # For model selection, use a balanced loader from the validation set to be consistent
    val_base_s1_balanced = SemiSupPiCaiDataset(ORIGINAL_DATA_DIR, val_ids, mask_dir=os.path.join(ORIGINAL_DATA_DIR, 'mask'), is_validation=False)
    val_dataset_s1_balanced = AugmentationWrapper(val_base_s1_balanced, transform=val_transform)
    val_loader_s1_balanced = DataLoader(val_dataset_s1_balanced, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    best_val_dice = -1.0
    # Checkpoint that currently holds best_val_dice; updated on every new best save.
    best_model_path = STAGE_1_MODEL_SAVE_PATH
    for epoch in range(NUM_EPOCHS_STAGE_1):
        print(f"\n--- Stage 1, Epoch {epoch+1}/{NUM_EPOCHS_STAGE_1} ---")
        train_fn(train_loader_s1, model, optimizer, loss_fn, scaler)

        current_dice = check_accuracy(val_loader_s1_balanced, model, device=DEVICE)
        print(f"Validation Dice (Balanced Set): {current_dice:.4f}")

        if current_dice > best_val_dice:
            best_val_dice = current_dice
            torch.save(model.state_dict(), STAGE_1_MODEL_SAVE_PATH)
            best_model_path = STAGE_1_MODEL_SAVE_PATH
            print(f"==> New best Stage 1 model saved with Lesion Dice: {best_val_dice:.4f}")

    # --- Pseudo-Label Generation ---
    print("\n--- Generating Pseudo-Labels for Stage 2 ---")
    model.load_state_dict(torch.load(STAGE_1_MODEL_SAVE_PATH, map_location=DEVICE))  # Load best model
    # Only .npy volumes are patients; other files in the folder would fail to load later.
    unlabeled_ids = sorted([f.replace('.npy', '') for f in os.listdir(os.path.join(UNLABELED_DATA_DIR, 't2w')) if f.endswith('.npy')])

    os.makedirs(PSEUDO_MASK_DIR, exist_ok=True)
    generate_pseudo_labels(unlabeled_ids, UNLABELED_DATA_DIR, PSEUDO_MASK_DIR, model, DEVICE)

    # --- STAGE 2: Semi-Supervised Training ---
    print("\n--- STAGE 2: Starting Semi-Supervised Training ---")
    # Load best model from Stage 1 to fine-tune
    model.load_state_dict(torch.load(STAGE_1_MODEL_SAVE_PATH, map_location=DEVICE))
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE_S2)  # Reset optimizer with lower LR

    # Create dataset for pseudo-labeled data
    pseudo_train_base = SemiSupPiCaiDataset(UNLABELED_DATA_DIR, unlabeled_ids, mask_dir=PSEUDO_MASK_DIR)

    # Combine original labeled data and pseudo-labeled data
    if len(pseudo_train_base) > 0:
        combined_train_dataset = ConcatDataset([train_base_s1, pseudo_train_base])
        print(f"Combining {len(train_base_s1)} labeled slices with {len(pseudo_train_base)} pseudo-labeled slices.")
    else:
        print("No pseudo-labeled slices were generated. Proceeding with only labeled data for Stage 2.")
        combined_train_dataset = train_base_s1

    train_dataset_s2 = AugmentationWrapper(combined_train_dataset, transform=train_transform)
    train_loader_s2 = DataLoader(train_dataset_s2, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    for epoch in range(NUM_EPOCHS_STAGE_2):
        print(f"\n--- Stage 2, Epoch {epoch+1}/{NUM_EPOCHS_STAGE_2} ---")
        train_fn(train_loader_s2, model, optimizer, loss_fn, scaler)

        current_dice = check_accuracy(val_loader_s1_balanced, model, device=DEVICE)
        print(f"Validation Dice (Balanced Set): {current_dice:.4f}")

        if current_dice > best_val_dice:
            best_val_dice = current_dice
            torch.save(model.state_dict(), FINAL_MODEL_SAVE_PATH)
            best_model_path = FINAL_MODEL_SAVE_PATH
            print(f"==> New best Stage 2 model saved with Lesion Dice: {best_val_dice:.4f}")

    # --- Final Comprehensive Evaluation ---
    print("\n--- FINAL EVALUATION ---")
    # Load the checkpoint that actually achieved best_val_dice. Testing whether
    # FINAL_MODEL_SAVE_PATH exists could pick up a stale file from an earlier run
    # when Stage 2 never improved on Stage 1.
    final_model_path = best_model_path
    print(f"Loading best model from: {final_model_path}")
    model.load_state_dict(torch.load(final_model_path, map_location=DEVICE))

    # Create a validation loader with ALL slices for realistic metrics
    val_loader_all_slices = DataLoader(val_dataset_s1, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    metrics_2d = calculate_final_2d_metrics(val_loader_all_slices, model, DEVICE)
    metrics_3d = calculate_final_3d_metrics(val_ids, ORIGINAL_DATA_DIR, model, DEVICE)

    print("\n--- Final 2D Slice-Level Metrics ---")
    print(f"  Dice Score: {metrics_2d['2d_dice']:.4f}")
    print(f"  AUROC:      {metrics_2d['2d_auroc']:.4f}")
    print(f"  AP:         {metrics_2d['2d_ap']:.4f}")

    print("\n--- Final 3D Patient-Level Metrics ---")
    print(f"  Dice Score: {metrics_3d['3d_dice']:.4f}")
    print(f"  AUROC:      {metrics_3d['3d_auroc']:.4f}")
    print(f"  AP:         {metrics_3d['3d_ap']:.4f}")

    print(f"\nTraining finished. Best model at: {final_model_path}")
    # shutil.rmtree(PSEUDO_MASK_DIR)
    # print("Cleanup complete.")

if __name__ == '__main__':
    # Script entry point: Stage 1 supervised training, pseudo-labelling,
    # Stage 2 fine-tuning, then final 2D/3D evaluation.
    # AMP deprecation warnings (e.g. about `torch.cuda.amp.GradScaler`) are informational only.
    main()

testing the code
total number of parameters: 32550209
total number of trainable parameters: 32550209
Input shape : torch.Size([2, 3, 224, 224])
Output shape : torch.Size([2, 6, 224, 224])
Labeled Data Split: 176 training, 44 validation patients.

--- STAGE 1: Starting Supervised Training ---



`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.

Finding training slices: 100%|██████████| 176/176 [00:01<00:00, 113.72it/s]


Loaded 2052 2D slices for this dataset partition.


Finding validation slices: 100%|██████████| 44/44 [00:00<00:00, 458.71it/s]


Loaded 1056 2D slices for this dataset partition.


Finding training slices: 100%|██████████| 44/44 [00:00<00:00, 109.56it/s]


Loaded 508 2D slices for this dataset partition.

--- Stage 1, Epoch 1/1 ---


Training: 100%|██████████| 257/257 [01:02<00:00,  4.12it/s, loss=0.713]


Validation Dice (Balanced Set): 0.4663
==> New best Stage 1 model saved with Lesion Dice: 0.4663

--- Generating Pseudo-Labels for Stage 2 ---


Generating Pseudo-Masks: 100%|██████████| 205/205 [04:07<00:00,  1.21s/it]



--- STAGE 2: Starting Semi-Supervised Training ---


Finding training slices: 100%|██████████| 205/205 [00:01<00:00, 139.47it/s]


Loaded 4458 2D slices for this dataset partition.
Combining 2052 labeled slices with 4458 pseudo-labeled slices.

--- Stage 2, Epoch 1/1 ---


Training: 100%|██████████| 814/814 [03:16<00:00,  4.13it/s, loss=0.366]


Validation Dice (Balanced Set): 0.4812
==> New best Stage 2 model saved with Lesion Dice: 0.4812

--- FINAL EVALUATION ---
Loading best model from: /kaggle/working/semisupervised_output/best_semisupervised_model.pth

Calculating Final 2D Metrics (on all validation slices)...


2D Evaluation: 100%|██████████| 132/132 [00:20<00:00,  6.29it/s]



Calculating Final 3D Metrics (on validation volumes)...


3D Evaluation: 100%|██████████| 44/44 [02:42<00:00,  3.70s/it]


--- Final 2D Slice-Level Metrics ---
  Dice Score: 0.6403
  AUROC:      0.9328
  AP:         0.3407

--- Final 3D Patient-Level Metrics ---
  Dice Score: 0.2723
  AUROC:      0.9069
  AP:         0.3331

Training finished. Best model at: /kaggle/working/semisupervised_output/best_semisupervised_model.pth





In [None]:
def calculate_2d_slice_metrics(loader, model, device="cuda", threshold=0.5):
    """Compute slice-level Dice, AUROC and average precision (PR-AUC) over a loader.

    Dice is computed per batch from probabilities thresholded at ``threshold`` and
    averaged over batches. AUROC and PR-AUC are computed once over every pixel of every
    batch. Targets are binarised at 0.5 for all three metrics, so they agree on what
    counts as foreground.

    Args:
        loader: Iterable of ``(images, masks)`` tensor pairs.
        model: Segmentation network returning logits broadcast-compatible with ``masks``.
        device: Device string (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``) or ``torch.device``.
        threshold: Probability cut-off used for the Dice computation.

    Returns:
        ``{"dice": float, "auroc": float, "pr_auc": float}``. AUROC and PR-AUC are 0.0
        when only one class is present. All three are 0.0 for an empty loader.
    """
    model.eval()
    # autocast needs a bare device type; str(device) would be "cuda:0" for indexed devices.
    device_type = torch.device(device).type
    all_targets_flat = []
    all_preds_prob_flat = []
    total_dice_score = 0.0
    num_batches = 0

    with torch.no_grad():
        for x, y in tqdm(loader, desc="Calculating 2D Slice Metrics"):
            x, y = x.to(device), y.to(device)
            with torch.amp.autocast(device_type=device_type):
                preds = model(x)
            # Leave autocast's reduced precision before thresholding / exporting probabilities.
            preds_prob = torch.sigmoid(preds.float())
            y_binary = (y > 0.5).float()

            preds_binary = (preds_prob > threshold).float()
            intersection = (preds_binary * y_binary).sum()
            dice_score = (2. * intersection) / (preds_binary.sum() + y_binary.sum() + 1e-6)
            total_dice_score += dice_score.item()
            num_batches += 1

            # uint8 targets keep the pooled pixel arrays small.
            all_targets_flat.append(y_binary.cpu().numpy().ravel().astype(np.uint8))
            all_preds_prob_flat.append(preds_prob.cpu().numpy().ravel())

    if num_batches == 0:
        return {"dice": 0.0, "auroc": 0.0, "pr_auc": 0.0}

    avg_dice = total_dice_score / num_batches
    y_true = np.concatenate(all_targets_flat)
    y_pred = np.concatenate(all_preds_prob_flat)

    # Both classes present <=> some but not all pixels are foreground (avoids sorting via np.unique).
    if y_true.any() and not y_true.all():
        auroc = roc_auc_score(y_true, y_pred)
        pr_auc = average_precision_score(y_true, y_pred)
    else:
        auroc = 0.0
        pr_auc = 0.0

    return {"dice": avg_dice, "auroc": auroc, "pr_auc": pr_auc}
def calculate_3d_volume_metrics(model, val_ids, base_dir, val_transform, device, threshold=0.5):
    """Compute per-patient 3D Dice plus pooled voxel-level AUROC and PR-AUC.

    For each patient, the ``t2w``/``adc``/``hbv`` volumes are stacked channel-last and
    sliced along the smallest axis of the mask. Each slice goes through these steps:
    - resize it with ``val_transform``;
    - clip each channel to the 1st/99th percentile of its non-zero pixels and min-max
      scale it;
    - segment it;
    - resize the probabilities back with bilinear interpolation and restack them into
      a volume.

    Ground truth is binarised at 0.5 for both Dice and AUROC/PR-AUC.

    Args:
        model: Segmentation network returning logits for a ``(1, 3, H, W)`` input.
        val_ids: Patient identifiers with ``<base_dir>/{t2w,adc,hbv,mask}/<id>.npy`` files.
        base_dir: Root data directory.
        val_transform: Albumentations transform applied to each ``(H, W, 3)`` slice.
        device: Device string (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``) or ``torch.device``.
        threshold: Probability cut-off used for the Dice computation.

    Returns:
        ``{"3d_dice", "3d_auroc", "3d_pr_auc"}``. Dice is the mean over patients.
        AUROC/PR-AUC pool all voxels and are 0.0 when only one class is present.
        All three are 0.0 when ``val_ids`` is empty.
    """
    print("\nStarting 3D volume evaluation...")
    model.eval()
    # autocast needs a bare device type; str(device) would be "cuda:0" for indexed devices.
    device_type = torch.device(device).type
    patient_dice_scores = []
    # Per-patient flattened GT and probabilities, pooled for the overall AUROC / PR-AUC.
    all_patient_gt_flat = []
    all_patient_pred_flat = []

    def normalize_slice(image_np):
        """Percentile-clip (non-zero pixels, 1st/99th) and min-max scale each channel in place."""
        for i in range(image_np.shape[2]):
            channel = image_np[:, :, i]
            non_zero_pixels = channel[channel > 1e-6]
            if non_zero_pixels.size > 0:
                p1, p99 = np.percentile(non_zero_pixels, 1), np.percentile(non_zero_pixels, 99)
                channel = np.clip(channel, p1, p99)
            min_val, max_val = channel.min(), channel.max()
            if max_val > min_val:
                image_np[:, :, i] = (channel - min_val) / (max_val - min_val)
            else:
                image_np[:, :, i] = np.zeros_like(channel)
        return image_np

    with torch.no_grad():
        for patient_id in tqdm(val_ids, desc="Calculating 3D Volume Metrics"):
            modalities = [np.load(os.path.join(base_dir, modality, f'{patient_id}.npy'))
                          for modality in ['t2w', 'adc', 'hbv']]
            img_3d = np.stack(modalities, axis=-1)
            mask_3d_gt = np.load(os.path.join(base_dir, 'mask', f'{patient_id}.npy'))
            # Binarise once so Dice and AUROC/PR-AUC use the same ground truth.
            gt_binary = torch.from_numpy((mask_3d_gt > 0.5).astype(np.float32))

            slice_axis = int(np.argmin(mask_3d_gt.shape))
            num_slices = mask_3d_gt.shape[slice_axis]
            # In-plane size = the two axes other than slice_axis, in their original order.
            original_h, original_w = (dim for axis, dim in enumerate(mask_3d_gt.shape) if axis != slice_axis)

            pred_slices_list = []
            for slice_idx in range(num_slices):
                slicer = [slice(None)] * 3
                slicer[slice_axis] = slice_idx
                image_slice_np = img_3d[tuple(slicer)]

                augmented = val_transform(image=image_slice_np.astype(np.float32))
                image_slice_np_norm = normalize_slice(augmented['image'])
                image_tensor = torch.from_numpy(image_slice_np_norm.transpose(2, 0, 1)).float().unsqueeze(0).to(device)

                with torch.amp.autocast(device_type=device_type):
                    pred_logit = model(image_tensor)
                # Leave autocast's reduced precision before resizing and thresholding.
                pred_prob = torch.sigmoid(pred_logit.float())

                pred_prob_resized = F.interpolate(pred_prob, size=(original_h, original_w), mode='bilinear', align_corners=False)
                pred_slices_list.append(pred_prob_resized.squeeze().cpu())

            pred_3d_volume_prob = torch.stack(pred_slices_list, dim=slice_axis)
            pred_3d_binary = (pred_3d_volume_prob > threshold).float()

            intersection = (pred_3d_binary * gt_binary).sum()
            dice_3d = (2. * intersection) / (pred_3d_binary.sum() + gt_binary.sum() + 1e-6)
            patient_dice_scores.append(dice_3d.item())

            all_patient_gt_flat.append(gt_binary.numpy().ravel().astype(np.uint8))
            all_patient_pred_flat.append(pred_3d_volume_prob.numpy().ravel())

    if not patient_dice_scores:
        print("Warning: No validation volumes were evaluated. Returning zero metrics.")
        return {"3d_dice": 0.0, "3d_auroc": 0.0, "3d_pr_auc": 0.0}

    avg_3d_dice = float(np.mean(patient_dice_scores))

    y_true_3d = np.concatenate(all_patient_gt_flat)
    y_pred_3d = np.concatenate(all_patient_pred_flat)

    # Both classes present <=> some but not all voxels are foreground (avoids sorting via np.unique).
    if y_true_3d.any() and not y_true_3d.all():
        auroc_3d = roc_auc_score(y_true_3d, y_pred_3d)
        pr_auc_3d = average_precision_score(y_true_3d, y_pred_3d)
    else:
        print("Warning: Only one class present in 3D validation volumes. AUROC and PR-AUC cannot be computed.")
        auroc_3d = 0.0
        pr_auc_3d = 0.0

    return {"3d_dice": avg_3d_dice, "3d_auroc": auroc_3d, "3d_pr_auc": pr_auc_3d}


def main():
    """Evaluate a saved PraNet checkpoint on the validation patients.

    Rebuilds the seeded patient split used for training and loads ``MODEL_SAVE_PATH``.
    It then prints 2D slice-level and 3D volume-level Dice / AUROC / PR-AUC.
    It prints an error and returns early if no mask files are found or the checkpoint
    cannot be loaded.
    """
    val_transform = A.Compose([A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE)])
    print(f"Using device: {DEVICE}")
    model = PraNet().to(DEVICE)

    mask_dir = os.path.join(DATA_DIR, 'mask')
    try:
        mask_files = os.listdir(mask_dir)
    except FileNotFoundError:
        mask_files = []
    all_sample_ids = sorted(name.replace('.npy', '') for name in mask_files if name.endswith('.npy'))
    if not all_sample_ids:
        print(f"\nERROR: No mask files (.npy) found in {mask_dir}. Check DATA_DIR.")
        return

    # Same seed and split fraction as training, so the validation patients are identical.
    random.seed(42)
    random.shuffle(all_sample_ids)
    n_train = int(len(all_sample_ids) * (1 - VALIDATION_SPLIT))
    train_ids = all_sample_ids[:n_train]
    val_ids = all_sample_ids[n_train:]
    print(f"\nTotal patient volumes: {len(all_sample_ids)}, Training: {len(train_ids)}, Validation: {len(val_ids)}\n")

    # Evaluation only needs the validation partition.
    val_loader = DataLoader(
        AugmentationWrapper(PiCai2DDataset(DATA_DIR, sample_ids=val_ids), transform=val_transform),
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        shuffle=False,
    )

    print("--- Starting Evaluation on Pre-trained Model ---")

    try:
        checkpoint = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
        model.load_state_dict(checkpoint)
    except FileNotFoundError:
        print(f"ERROR: Model file not found at '{MODEL_SAVE_PATH}'. Cannot evaluate.")
        return
    except Exception as e:
        print(f"An error occurred while loading the model: {e}")
        return

    # Slice-level (2D) metrics.
    slice_metrics = calculate_2d_slice_metrics(val_loader, model, device=DEVICE)
    print("\n--- Final 2D Slice-Based Metrics ---")
    print(f"Dice Score: {slice_metrics['dice']:.4f}")
    print(f"AUROC: {slice_metrics['auroc']:.4f}")
    print(f"PR-AUC (Average Precision): {slice_metrics['pr_auc']:.4f}")

    # Volume-level (3D) metrics.
    volume_metrics = calculate_3d_volume_metrics(model, val_ids, DATA_DIR, val_transform, DEVICE)
    print("\n--- Final 3D Volume-Based Metrics ---")
    print(f"Average Patient Dice Score: {volume_metrics['3d_dice']:.4f}")
    print(f"Overall AUROC: {volume_metrics['3d_auroc']:.4f}")
    print(f"Overall PR-AUC (Average Precision): {volume_metrics['3d_pr_auc']:.4f}")

if __name__ == '__main__':
    # Run the evaluation-only entry point (the `main` defined in this cell).
    main()

In [None]:
val_transform = A.Compose([A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE)])

print(f"Using device: {DEVICE}")
model = PraNet().to(DEVICE)

# Collect patient IDs from the mask folder. A notebook cell cannot `return`, so the
# cell is stopped with an exception rather than continuing with missing/undefined IDs.
mask_dir = os.path.join(DATA_DIR, 'mask')
try:
    all_sample_ids = sorted([f.replace('.npy', '') for f in os.listdir(mask_dir) if f.endswith('.npy')])
except FileNotFoundError:
    all_sample_ids = []
if not all_sample_ids:
    print(f"\nERROR: No mask files (.npy) found in {mask_dir}. Check DATA_DIR.")
    raise FileNotFoundError(f"No mask files (.npy) found in {mask_dir}")

MODEL_SAVE_PATH = "best_unet_model_balanced_eval.pth"
# Same seed and split fraction as training so the validation patients match.
random.seed(42)
random.shuffle(all_sample_ids)
split_idx = int(len(all_sample_ids) * (1 - VALIDATION_SPLIT))
train_ids, val_ids = all_sample_ids[:split_idx], all_sample_ids[split_idx:]
print(f"\nTotal samples: {len(all_sample_ids)}, Training: {len(train_ids)}, Validation: {len(val_ids)}\n")

val_base_dataset = PiCai2DDataset(DATA_DIR, sample_ids=val_ids)
val_dataset = AugmentationWrapper(val_base_dataset, transform=val_transform)

val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, shuffle=False)

def visualize_all_predictions(model, loader, device):
    """Show input, ground truth and predicted mask side by side for every image in ``loader``.

    Loads weights from the module-level ``MODEL_SAVE_PATH`` into ``model``. It returns
    early with a message if that file is missing or the dataset is empty. Otherwise it
    draws one matplotlib figure per batch with one row per image: the T2W input channel,
    the ground-truth mask, and the prediction thresholded at 0.5.

    Args:
        model: Segmentation network to load the checkpoint into.
        loader: DataLoader yielding ``(images, masks)`` batches.
        device: Device string (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``) or ``torch.device``.
    """
    print("\nLoading best model for visualization...")
    try:
        model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
        model.to(device)
    except FileNotFoundError:
        print(f"ERROR: Model file not found at '{MODEL_SAVE_PATH}'. Cannot visualize.")
        print("Please ensure the model was trained and saved successfully.")
        return

    model.eval()

    if len(loader.dataset) == 0:
        print("Validation loader is empty. Cannot visualize predictions.")
        return

    # autocast expects a bare device type ("cuda"/"cpu"), not an indexed device like "cuda:0".
    device_type = torch.device(device).type
    column_titles = ("Input Image (T2W)", "Ground Truth Mask", "Predicted Mask")

    print(f"Generating predictions for all {len(loader.dataset)} validation images...")
    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(tqdm(loader, desc="Visualizing Batches")):
            images, masks = images.to(device), masks.to(device)

            with torch.amp.autocast(device_type=device_type):
                preds = torch.sigmoid(model(images))
            preds_binary = (preds > 0.5).float()

            num_images_in_batch = len(images)
            # squeeze=False keeps `axes` 2-D (rows x 3) even for a single-image batch.
            fig, axes = plt.subplots(num_images_in_batch, 3, figsize=(15, num_images_in_batch * 5), squeeze=False)
            fig.suptitle(f"Batch {batch_idx + 1}/{len(loader)}: Predictions vs. Ground Truth", fontsize=18)

            for i in range(num_images_in_batch):
                panels = (
                    images[i][0].cpu().numpy(),  # first input channel (T2W)
                    masks[i].squeeze().cpu().numpy(),
                    preds_binary[i].squeeze().cpu().numpy(),
                )
                for ax, panel, title in zip(axes[i], panels, column_titles):
                    ax.imshow(panel, cmap='gray')
                    ax.set_title(title)
                    ax.axis("off")

            plt.tight_layout()
            plt.subplots_adjust(top=0.95)  # leave room for the suptitle
            plt.show()
            # Release the figure so one-per-batch figures do not accumulate in non-inline backends.
            plt.close(fig)

# Render every validation batch using the checkpoint at MODEL_SAVE_PATH.
visualize_all_predictions(model, val_loader, device=DEVICE)

In [None]:
val_transform = A.Compose([A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE)])

print(f"Using device: {DEVICE}")
model = PraNet().to(DEVICE)

# Collect patient IDs from the mask folder. A notebook cell cannot `return`, so the
# cell is stopped with an exception rather than continuing with missing/undefined IDs.
mask_dir = os.path.join(DATA_DIR, 'mask')
try:
    all_sample_ids = sorted([f.replace('.npy', '') for f in os.listdir(mask_dir) if f.endswith('.npy')])
except FileNotFoundError:
    all_sample_ids = []
if not all_sample_ids:
    print(f"\nERROR: No mask files (.npy) found in {mask_dir}. Check DATA_DIR.")
    raise FileNotFoundError(f"No mask files (.npy) found in {mask_dir}")

MODEL_SAVE_PATH = "best_unet_model_balanced_eval.pth"
# Same seed and split fraction as training so the validation patients match.
random.seed(42)
random.shuffle(all_sample_ids)
split_idx = int(len(all_sample_ids) * (1 - VALIDATION_SPLIT))
train_ids, val_ids = all_sample_ids[:split_idx], all_sample_ids[split_idx:]
print(f"\nTotal samples: {len(all_sample_ids)}, Training: {len(train_ids)}, Validation: {len(val_ids)}\n")

val_base_dataset = PiCai2DDataset(DATA_DIR, sample_ids=val_ids)
val_dataset = AugmentationWrapper(val_base_dataset, transform=val_transform)

val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, shuffle=False)

def visualize_3d_with_image_overlay(
    image_3d,
    ground_truth_mask_3d,
    predicted_mask_3d,
    sample_id="",
    downsample_factor=0.25,
    show_gt=True,
    show_pred=True
):
    """Render an interactive Plotly 3D view of an image volume with mask isosurfaces.

    When ``downsample_factor < 1`` all three volumes are shrunk with
    ``scipy.ndimage.zoom``: linear interpolation for the image, nearest for the masks.
    The masks are then binarised at 0.5. The image is drawn as a translucent grey
    volume. The masks are drawn according to the flags:
    - with both flags set, colour-coded overlap (green), prediction-only (red) and
      ground-truth-only (blue) surfaces;
    - with one flag set, only that mask;
    - with neither, only the base image.

    The figure is displayed inline through the "notebook" renderer.
    """
    print(f"Generating interactive 3D plot for {sample_id}...")

    # --- Optional downsampling keeps the plot responsive ---
    if downsample_factor < 1.0:
        def shrink(volume, order):
            return scipy.ndimage.zoom(volume, downsample_factor, order=order)

        vis_image_3d = shrink(image_3d, 1)
        vis_gt_mask = shrink(ground_truth_mask_3d, 0)
        vis_pred_mask = shrink(predicted_mask_3d, 0)
    else:
        vis_image_3d = image_3d
        vis_gt_mask = ground_truth_mask_3d
        vis_pred_mask = predicted_mask_3d

    gt_binary = vis_gt_mask > 0.5
    pred_binary = vis_pred_mask > 0.5

    pio.renderers.default = "notebook"
    grid = np.mgrid[:vis_image_3d.shape[0], :vis_image_3d.shape[1], :vis_image_3d.shape[2]]
    # Flattened voxel coordinates, shared by every trace.
    xs, ys, zs = (coords.flatten() for coords in grid)
    fig = go.Figure()

    # --- Base image volume ---
    fig.add_trace(go.Volume(
        x=xs, y=ys, z=zs, value=vis_image_3d.flatten(),
        isomin=np.min(vis_image_3d), isomax=np.max(vis_image_3d), opacity=0.15,
        surface_count=15, colorscale='Greys', name="T2W Image", showscale=False
    ))

    def add_surface(mask, label, colorscale, alpha):
        # Only draw a surface when the mask has at least one voxel.
        if not np.any(mask):
            return
        fig.add_trace(go.Isosurface(
            x=xs, y=ys, z=zs, value=mask.flatten().astype(np.uint8),
            isomin=0.5, isomax=1.0, surface_count=1, opacity=alpha,
            name=label, colorscale=colorscale, showscale=False
        ))

    if show_gt and show_pred:
        # True positives, false positives and false negatives as separate colours.
        add_surface(gt_binary & pred_binary, "Overlap (GT & Pred)", [[0, 'green'], [1, 'green']], 0.8)
        add_surface(~gt_binary & pred_binary, "Prediction Only", [[0, 'red'], [1, 'red']], 0.7)
        add_surface(gt_binary & ~pred_binary, "Ground Truth Only", [[0, 'blue'], [1, 'blue']], 0.5)
        title = "3D Analysis: Overlap (Green), Prediction Only (Red), GT Only (Blue)"
    elif show_gt:
        add_surface(gt_binary, "Ground Truth", 'blues', 0.5)
        title = "3D View: Ground Truth"
    elif show_pred:
        add_surface(pred_binary, "Prediction", 'reds', 0.7)
        title = "3D View: Prediction"
    else:
        title = "3D View: Base Image Only"

    fig.update_layout(
        title=f"{title}<br>Sample: {sample_id}",
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z', aspectratio=dict(x=1, y=1, z=1), bgcolor='rgb(10, 10, 10)'),
        legend=dict(x=0, y=1)
    )

    print("Displaying plot inline...")
    fig.show()

def calculate_3d_metrics(ground_truth, prediction):
    """Return voxel-wise Dice, precision and recall for two arrays binarised at 0.5.

    If neither array has any foreground, every metric is 1.0 (an empty prediction of
    an empty target counts as perfect). A tiny epsilon keeps each division finite.
    """
    truth = (ground_truth > 0.5).astype(bool)
    guess = (prediction > 0.5).astype(bool)
    if not (np.any(truth) or np.any(guess)):
        return {"dice": 1.0, "precision": 1.0, "recall": 1.0}

    eps = 1e-8
    hits = np.sum(truth & guess)           # true positives
    false_alarms = np.sum(~truth & guess)  # false positives
    misses = np.sum(truth & ~guess)        # false negatives
    return {
        "dice": (2 * hits) / (2 * hits + false_alarms + misses + eps),
        "precision": hits / (hits + false_alarms + eps),
        "recall": hits / (hits + misses + eps),
    }
    
def _normalize_slice_for_inference(slice_hwc):
    """Clip each channel to its foreground 1st/99th percentiles, then min-max scale to [0, 1].

    Foreground is taken to be values > 1e-6; a channel with no such values is
    not clipped. A constant channel becomes all zeros. The input is not
    modified; a normalized copy is returned.
    NOTE(review): assumed to mirror the training-time preprocessing — confirm.

    Args:
        slice_hwc: float array of shape (H, W, C).

    Returns:
        Float array of the same shape.
    """
    normalized = slice_hwc.copy()
    for c in range(normalized.shape[2]):
        channel = normalized[:, :, c]
        foreground = channel[channel > 1e-6]
        if foreground.size > 0:
            p1, p99 = np.percentile(foreground, 1), np.percentile(foreground, 99)
            channel = np.clip(channel, p1, p99)
        min_val, max_val = channel.min(), channel.max()
        normalized[:, :, c] = (channel - min_val) / (max_val - min_val) if max_val > min_val else np.zeros_like(channel)
    return normalized


def predict_evaluate_and_visualize_3d(model, base_dir, sample_id, device, image_size):
    """Run slice-wise 2D inference over one 3D case, print metrics and plot it.

    Loads ``{base_dir}/{t2w,adc,hbv}/{sample_id}.npy`` (stacked as channels in that
    order) and ``{base_dir}/mask/{sample_id}.npy``. The volume is sliced along its
    smallest spatial axis. Each slice is resized to ``image_size x image_size`` and
    normalized per channel. The model output is passed through a sigmoid,
    thresholded at 0.5, and resized back to the slice's original in-plane size.

    Args:
        model: network mapping a (1, 3, image_size, image_size) tensor to logits.
            NOTE(review): ``torch.sigmoid`` is applied directly to the output, so a
            model whose forward returns a tuple of side outputs would fail here —
            confirm the model returns a single tensor.
        base_dir: root directory containing the per-modality and mask folders.
        sample_id: case identifier used as the ``.npy`` file stem.
        device: torch device (string or ``torch.device``) to run inference on.
        image_size: square input size expected by the model.

    Returns:
        dict from ``calculate_3d_metrics`` with ``dice``, ``precision`` and ``recall``.
    """
    print(f"\n--- Processing Sample ID: {sample_id} ---")
    model.eval()
    modalities = [np.load(os.path.join(base_dir, m, f'{sample_id}.npy')) for m in ['t2w', 'adc', 'hbv']]
    image_3d_multichannel = np.stack(modalities, axis=-1)
    mask_3d = np.load(os.path.join(base_dir, 'mask', f'{sample_id}.npy'))
    prediction_3d = np.zeros_like(mask_3d, dtype=np.float32)

    spatial_shape = image_3d_multichannel.shape[:-1]
    # Treat the thinnest spatial axis as the slice axis.
    slice_axis = int(np.argmin(spatial_shape))
    num_slices = spatial_shape[slice_axis]
    # The in-plane size is the two spatial axes other than the slice axis,
    # whichever axis that is (0, 1 or 2).
    h, w = [dim for axis, dim in enumerate(spatial_shape) if axis != slice_axis]

    # autocast expects a bare device type ("cuda"/"cpu"), not an indexed
    # device string such as "cuda:0".
    autocast_device_type = torch.device(device).type

    val_transform = A.Compose([A.Resize(height=image_size, width=image_size)])
    resize_back_transform = A.Compose([A.Resize(height=h, width=w)])
    with torch.no_grad():
        for slice_idx in tqdm(range(num_slices), desc=f"Predicting {sample_id}"):
            slicer = [slice(None)] * 3
            slicer[slice_axis] = slice_idx
            slicer = tuple(slicer)
            # Indexing the 4D array with 3 entries keeps the channel axis: (h, w, 3).
            image_slice_np = image_3d_multichannel[slicer]
            resized_image_np = val_transform(image=image_slice_np.astype(np.float32))['image']
            normalized_slice_np = _normalize_slice_for_inference(resized_image_np)
            image_tensor = torch.from_numpy(normalized_slice_np.transpose(2, 0, 1)).float().unsqueeze(0).to(device)
            with torch.amp.autocast(device_type=autocast_device_type):
                pred_prob = torch.sigmoid(model(image_tensor))
            pred_mask = (pred_prob > 0.5).float().cpu().squeeze().numpy()
            # A.Resize uses linear interpolation by default, so edges of the binary
            # mask may become fractional; calculate_3d_metrics re-thresholds at 0.5.
            resized_pred_mask = resize_back_transform(image=pred_mask)['image']
            prediction_3d[slicer] = resized_pred_mask

    metrics = calculate_3d_metrics(mask_3d, prediction_3d)
    print(f"Metrics for {sample_id}: Dice={metrics['dice']:.4f}, Precision={metrics['precision']:.4f}, Recall={metrics['recall']:.4f}")

    # Channel 0 is t2w (first entry of the modality list above).
    t2w_image_3d = image_3d_multichannel[:, :, :, 0]
    visualize_3d_with_image_overlay(t2w_image_3d, mask_3d, prediction_3d, sample_id, downsample_factor=0.25)
    return metrics

# Load the trained weights and run 3D inference + visualization on a few validation cases.
model = PraNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
model.eval()

num_images_to_predict = 3
# Offset into val_ids of the first case to predict; slicing past the end just yields fewer ids.
start_index = 10

ids_to_predict = val_ids[start_index:start_index + num_images_to_predict]

# len() works for both lists and NumPy arrays (``not array`` raises for arrays of size > 1).
if len(ids_to_predict) == 0:
    print("No validation IDs available for 3D prediction.")
else:
    print(f"Will run 3D prediction on the following sample IDs: {ids_to_predict}")
    per_sample_metrics = []
    for sample_id in ids_to_predict:
        per_sample_metrics.append(predict_evaluate_and_visualize_3d(
            model=model,
            base_dir=DATA_DIR,
            sample_id=sample_id,
            device=DEVICE,
            image_size=IMAGE_SIZE
        ))
    # Aggregate the per-case metrics returned above instead of discarding them.
    summary = ", ".join(
        f"{key.capitalize()}={np.mean([m[key] for m in per_sample_metrics]):.4f}"
        for key in ("dice", "precision", "recall")
    )
    print(f"Mean over {len(per_sample_metrics)} samples: {summary}")