In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import torch
import nibabel as nib
import os
import tensorflow as tf
from tensorflow.keras.utils import to_categorical # type: ignore # type: ignore
from skimage.transform import resize


In [2]:
print(pd.__version__)
print(np.__version__)
print(tf.__version__)
print(torch.__version__)
print("GPU is", "available" if torch.cuda.is_available() else "NOT AVAILABLE")

2.2.3
2.0.2
2.18.0
2.6.0+cpu
GPU is NOT AVAILABLE


In [3]:
import os
import numpy as np
import nibabel as nib
from skimage.transform import resize
from tensorflow.keras.utils import to_categorical

VOLUME_DIR = "train_folder"
SEGMENTATION_DIR = os.path.join("train_folder", "Segmentation")
IMG_SIZE = (128, 128)

def load_nifti(file_path):
    """Load a NIfTI file and return the NumPy array"""
    print(f"Loading file: {file_path}") 
    nifti_img = nib.load(file_path) 
    return nifti_img.get_fdata()

def preprocess_data(volume_path, segmentation_path):
    """Load and preprocess volume and segmentation data"""
    volume_data = load_nifti(volume_path)
    segmentation_data = load_nifti(segmentation_path).astype(int)

    volume_data = (volume_data - np.min(volume_data)) / (np.max(volume_data) - np.min(volume_data))

    volume_resized = np.array([
        resize(slice, IMG_SIZE, mode='constant', preserve_range=True) 
        for slice in volume_data.transpose(2, 0, 1)
    ])
    segmentation_resized = np.array([
        resize(slice, IMG_SIZE, mode='constant', preserve_range=True, order=0) 
        for slice in segmentation_data.transpose(2, 0, 1)
    ])

    segmentation_onehot = to_categorical(segmentation_resized, num_classes=3)

    return volume_resized, segmentation_onehot

X_train, Y_train = [], []

image_files = sorted([f for f in os.listdir(VOLUME_DIR) if f.endswith(".nii")])
mask_files = sorted([f for f in os.listdir(SEGMENTATION_DIR) if f.endswith(".nii")])

for img_file, mask_file in zip(image_files, mask_files):
    volume_path = os.path.join(VOLUME_DIR, img_file)
    segmentation_path = os.path.join(SEGMENTATION_DIR, mask_file)
    vol, seg = preprocess_data(volume_path, segmentation_path)
    X_train.append(vol)
    Y_train.append(seg)

X_train = np.concatenate(X_train, axis=0)
Y_train = np.concatenate(Y_train, axis=0)

print("Training data shape:", X_train.shape)
print("Segmentation mask shape:", Y_train.shape)

Loading file: train_folder\volume-0.nii
Loading file: train_folder\Segmentation\segmentation-0.nii
Loading file: train_folder\volume-1.nii
Loading file: train_folder\Segmentation\segmentation-1.nii
Loading file: train_folder\volume-2.nii
Loading file: train_folder\Segmentation\segmentation-2.nii
Loading file: train_folder\volume-3.nii
Loading file: train_folder\Segmentation\segmentation-3.nii
Loading file: train_folder\volume-4.nii
Loading file: train_folder\Segmentation\segmentation-4.nii
Training data shape: (2090, 128, 128)
Segmentation mask shape: (2090, 128, 128, 3)


In [4]:
import os
import numpy as np
import nibabel as nib
from tensorflow.keras.utils import to_categorical # type: ignore
# Function to load NIfTI files
def load_nifti(file_path):
    """Load a NIfTI file and return the NumPy array"""
    print(f"Loading: {file_path}")  # Debugging: Print file being loaded
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    
    nifti_img = nib.load(file_path)
    return nifti_img.get_fdata()

# Function to preprocess test data
def preprocess_data(test_volume_path, test_segmentation_path):
    """Load and preprocess a single test volume and its corresponding segmentation mask"""
    volume_data = load_nifti(test_volume_path)
    segmentation_data = load_nifti(test_segmentation_path).astype(int)

    # Normalize CT images (0-1)
    volume_data = (volume_data - np.min(volume_data)) / (np.max(volume_data) - np.min(volume_data))

    # Resize to (128, 128)
    from skimage.transform import resize
    IMG_SIZE = (128, 128)
    
    volume_resized = np.array([
        resize(slice, IMG_SIZE, mode='constant', preserve_range=True)
        for slice in volume_data.transpose(2, 0, 1)  # Rearrange to (slices, H, W)
    ])
    
    segmentation_resized = np.array([
        resize(slice, IMG_SIZE, mode='constant', preserve_range=True, order=0) 
        for slice in segmentation_data.transpose(2, 0, 1)
    ])

    # One-hot encode segmentation masks (background=0, liver=1, tumor=2)
    
    segmentation_onehot = to_categorical(segmentation_resized, num_classes=3)

    return volume_resized, segmentation_onehot

# ---- Load Test Data ----
test_image_path = os.path.join("test_folder") 
test_mask_path = os.path.join("test_folder", "segmentation") 

# Ensure directories exist
if not os.path.exists(test_image_path) or not os.path.exists(test_mask_path):
    raise FileNotFoundError("Test data folder or segmentation folder not found!")

# Get sorted filenames
image_files = sorted([f for f in os.listdir(test_image_path) if f.endswith(".nii")])
mask_files = sorted([f for f in os.listdir(test_mask_path) if f.endswith(".nii")])

# Check if the number of images and masks match
if len(image_files) != len(mask_files):
    raise ValueError("Mismatch between the number of test images and segmentation masks!")

X_test, Y_test = [], []

# Load and preprocess each test image and mask
for img_file, mask_file in zip(image_files, mask_files):
    volume_path = os.path.join(test_image_path, img_file)  # Correct path joining
    segmentation_path = os.path.join(test_mask_path, mask_file)  # Correct path joining

    vol, seg = preprocess_data(volume_path, segmentation_path)
    X_test.append(vol)
    Y_test.append(seg)

# Convert lists to NumPy arrays
X_test = np.array(X_test)  # Shape: (num_volumes, num_slices, 128, 128)
Y_test = np.array(Y_test)  # Shape: (num_volumes, num_slices, 128, 128, 3)

# Reshape to flatten across all slices
X_test = X_test.reshape(-1, 128, 128, 1)  # Add channel dimension
Y_test = Y_test.reshape(-1, 128, 128, 3)  # Keep segmentation masks in one-hot format

# Print shapes
print("Final Test Data Shape:", X_test.shape)  # (total_slices, 128, 128, 1)
print("Final Test Mask Shape:", Y_test.shape)  # (total_slices, 128, 128, 3)


Loading: test_folder\volume-10.nii
Loading: test_folder\segmentation\segmentation-10.nii
Final Test Data Shape: (501, 128, 128, 1)
Final Test Mask Shape: (501, 128, 128, 3)


In [5]:
###Model
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=128, patch_size=16, in_channels=1, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches + 1, embed_dim))

    def forward(self, x):
        B = x.shape[0]
        x = self.proj(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, n_patches, embed_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, n_patches+1, embed_dim)
        x = x + self.pos_embed
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim=768, depth=4, num_heads=8, mlp_ratio=4.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim*mlp_ratio), batch_first=True)
            for _ in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class TransUNet(nn.Module):
    def __init__(self, img_size=128, patch_size=16, in_channels=1, out_channels=3, embed_dim=768):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.transformer = TransformerEncoder(embed_dim=embed_dim)

        self.decoder_dim = 256
        self.linear_decoder = nn.Sequential(
            nn.Linear(embed_dim, self.decoder_dim),
            nn.ReLU(inplace=True)
        )

        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(self.decoder_dim, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, out_channels, kernel_size=2, stride=2),
        )

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)  # (B, n_patches+1, embed_dim)
        x = self.transformer(x)  # (B, n_patches+1, embed_dim)
        x = x[:, 1:, :]  # Remove cls token, shape: (B, n_patches, embed_dim)

        # Reshape to 2D feature map
        h = w = int(math.sqrt(x.shape[1]))
        x = self.linear_decoder(x)  # (B, n_patches, decoder_dim)
        x = x.permute(0, 2, 1).contiguous().view(B, self.decoder_dim, h, w)  # (B, decoder_dim, H, W)

        x = self.upsample(x)  # (B, out_channels, 128, 128)
        return x

In [6]:
import torch
from torch.utils.data import DataLoader, Dataset

# Assuming x_train is (2090, 128, 128) and y_train is (2090, 128, 128, 3)
class SliceDataset(Dataset):
    def __init__(self, x, y):
        super().__init__()
        # Convert to float32 and torch tensors
        self.x = torch.tensor(x, dtype=torch.float32).unsqueeze(1)  # (N, 1, 128, 128)
        self.y = torch.tensor(y, dtype=torch.float32).permute(0, 3, 1, 2)  # (N, 3, 128, 128)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

# Create dataset and loader
batch_size = 16
dataset = SliceDataset(X_train, Y_train)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Fetch one batch from the DataLoader and print details
for images, masks in train_loader:
    print("✅ Images shape:", images.shape)  # Should be (16, 1, 128, 128)
    print("✅ Masks shape:", masks.shape)    # Should be (16, 3, 128, 128)
    
    # Check value ranges
    print("Image min/max:", images.min().item(), images.max().item())
    print("Mask unique values (flattened):", torch.unique(masks))

    # Optional: check one-hot encoding validity
    is_one_hot = torch.all((masks.sum(dim=1) == 1) | (masks.sum(dim=1) == 0))
    print("Is one-hot encoded?", is_one_hot.item())

    break  # Just check the first batch

# Test DataLoader
for batch_idx, (images, masks) in enumerate(train_loader):
    print(f"Batch {batch_idx + 1}")
    print(f"Image batch shape: {images.shape}")  # Should be (16, 1, 128, 128)
    print(f"Mask batch shape: {masks.shape}")    # Should be (16, 3, 128, 128)
    
    # Optionally test forward pass through your model
    model = TransUNet(in_channels=1, out_channels=3)  # Ensure model is defined
    outputs = model(images)  # Forward pass
    print(f"Model output shape: {outputs.shape}")  # Should be (16, 3, 128, 128)
    
    break  # Just check one batch

In [8]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Dice Loss implementation for multi-class segmentation
class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        # logits: (B, C, H, W), targets: (B, H, W)
        probs = F.softmax(logits, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=logits.shape[1]).permute(0, 3, 1, 2).float()

        dims = (0, 2, 3)
        intersection = torch.sum(probs * targets_one_hot, dims)
        union = torch.sum(probs + targets_one_hot, dims)
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

# === Setup ===
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransUNet().to(device)
ce_loss = nn.CrossEntropyLoss()
dice_loss = DiceLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Save directory
save_dir = "checkpoints/transunet_experiment"
os.makedirs(save_dir, exist_ok=True)
log_file = os.path.join(save_dir, "training_log.txt")

# === Training Loop with per-epoch saving and best-epoch tracking ===
num_epochs = 15
best_loss = float("inf")
best_epoch = -1

with open(log_file, "w") as f:
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            targets_idx = torch.argmax(targets, dim=1)

            outputs = model(inputs)
            loss = ce_loss(outputs, targets_idx) + dice_loss(outputs, targets_idx)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)

        # Save model every epoch
        model_path = os.path.join(save_dir, f"epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), model_path)

        # Check if best
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_epoch = epoch + 1
            best_model_path = os.path.join(save_dir, "best_model.pth")
            torch.save(model.state_dict(), best_model_path)

        # Print and log
        log_msg = f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}"
        print(log_msg, flush=True)  # Ensures the message is printed immediately
        f.write(log_msg + "\n")

    final_msg = f"\nBest Epoch: {best_epoch} with Loss: {best_loss:.4f}"
    print(final_msg, flush=True)  # Prints the final best epoch message immediately
    f.write(final_msg + "\n")


Epoch [1/15] Loss: 0.8761
Epoch [2/15] Loss: 0.7868
Epoch [3/15] Loss: 0.7869
Epoch [4/15] Loss: 0.7866
Epoch [5/15] Loss: 0.7886
Epoch [6/15] Loss: 0.7867
Epoch [7/15] Loss: 0.7891
Epoch [8/15] Loss: 0.7863
Epoch [9/15] Loss: 0.7863
Epoch [10/15] Loss: 0.7859
Epoch [11/15] Loss: 0.7875
Epoch [12/15] Loss: 0.7864
Epoch [13/15] Loss: 0.7869
Epoch [14/15] Loss: 0.7874
Epoch [15/15] Loss: 0.7858

Best Epoch: 15 with Loss: 0.7858


In [9]:
# Print shapes
X_test = np.squeeze(X_test, axis=-1)  # Now shape: (501, 128, 128)
print("Final Test Data Shape:", X_test.shape)  # (total_slices, 128, 128, 1)
print("Final Test Mask Shape:", Y_test.shape)  # (total_slices, 128, 128, 3)
print("Training data shape:", X_train.shape)
print("Segmentation mask shape:", Y_train.shape)

Final Test Data Shape: (501, 128, 128)
Final Test Mask Shape: (501, 128, 128, 3)
Training data shape: (2090, 128, 128)
Segmentation mask shape: (2090, 128, 128, 3)


In [10]:
class TestSliceDataset(Dataset):
    def __init__(self, x, y):
        super().__init__()
        # Remove the last dimension if it's singleton (1)
        if x.shape[-1] == 1:
            x = np.squeeze(x, axis=-1)  # Now (N, 128, 128)

        self.x = torch.tensor(x, dtype=torch.float32).unsqueeze(1)  # (N, 1, 128, 128)
        self.y = torch.tensor(y, dtype=torch.float32).permute(0, 3, 1, 2)  # (N, 3, 128, 128)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

# Use it
test_dataset = TestSliceDataset(X_test, Y_test)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
for images, masks in test_loader:
    print("Fixed input shape:", images.shape)  # ✅ Should be (16, 1, 128, 128)
    print("Fixed mask shape:", masks.shape)    # ✅ Should be (16, 3, 128, 128)
    
    # Forward pass
    outputs = model(images)
    print("Model output shape:", outputs.shape)
    break


Fixed input shape: torch.Size([16, 1, 128, 128])
Fixed mask shape: torch.Size([16, 3, 128, 128])
Model output shape: torch.Size([16, 3, 128, 128])


In [None]:
# # Recreate the model architecture first
# model = UNetPlusPlus(in_channels=1, out_channels=3)

# # Load the weights
# model.load_state_dict(torch.load('unetplusplus_epoch_15.pth'))

# # Put it in eval mode if you're using it for inference
# model.eval()

In [11]:
def dice_score(preds, targets, epsilon=1e-6):
    # Assumes preds are one-hot or softmax outputs (batch, C, H, W)
    preds = torch.argmax(preds, dim=1)  # (batch, H, W)
    targets = torch.argmax(targets, dim=1)  # (batch, H, W)

    dice_total = 0
    for cls in range(3):  # For each class
        pred_cls = (preds == cls).float()
        target_cls = (targets == cls).float()
        
        intersection = (pred_cls * target_cls).sum()
        union = pred_cls.sum() + target_cls.sum()
        
        dice = (2. * intersection + epsilon) / (union + epsilon)
        dice_total += dice

    return dice_total / 3  # Average over 3 classes

total_dice = 0
num_batches = 0

with torch.no_grad():
    for images, masks in test_loader:
        outputs = model(images)  # Forward pass
        batch_dice = dice_score(outputs, masks)
        total_dice += batch_dice.item()
        num_batches += 1

avg_dice = total_dice / num_batches
print(f"Average Dice Score on Test Set: {avg_dice:.4f}")

Average Dice Score on Test Set: 0.7775


In [12]:
def dice_score_per_class(preds, targets, epsilon=1e-6):
    # Convert from softmax/one-hot to label maps
    preds = torch.argmax(preds, dim=1)    # (batch, H, W)
    targets = torch.argmax(targets, dim=1)  # (batch, H, W)

    class_dice_scores = []

    for cls in range(3):  # 3 classes
        pred_cls = (preds == cls).float()
        target_cls = (targets == cls).float()

        intersection = (pred_cls * target_cls).sum()
        union = pred_cls.sum() + target_cls.sum()

        dice = (2. * intersection + epsilon) / (union + epsilon)
        class_dice_scores.append(dice.item())

    return class_dice_scores  # List: [dice_class_0, dice_class_1, dice_class_2]

In [13]:
total_dice = [0.0, 0.0, 0.0]
num_batches = 0

with torch.no_grad():
    for images, masks in test_loader:
        outputs = model(images)
        dice_scores = dice_score_per_class(outputs, masks)
        for i in range(3):
            total_dice[i] += dice_scores[i]
        num_batches += 1

avg_dice_per_class = [d / num_batches for d in total_dice]
for i, score in enumerate(avg_dice_per_class):
    print(f"Average Dice Score for Class {i}: {score:.4f}")

Average Dice Score for Class 0: 0.9888
Average Dice Score for Class 1: 0.6250
Average Dice Score for Class 2: 0.7188


In [14]:
def dice_score_foreground(preds, targets, epsilon=1e-6):
    preds = torch.argmax(preds, dim=1)  # (B, H, W)
    targets = torch.argmax(targets, dim=1)  # (B, H, W)

    dice_scores = []

    for cls in [1, 2]:  # Only foreground classes
        pred_cls = (preds == cls).float()
        target_cls = (targets == cls).float()
        
        intersection = (pred_cls * target_cls).sum()
        union = pred_cls.sum() + target_cls.sum()
        
        dice = (2. * intersection + epsilon) / (union + epsilon)
        dice_scores.append(dice)

    return sum(dice_scores) / len(dice_scores)

# Compute combined Dice for foreground (classes 1 and 2)
total_dice_fg = 0
num_batches = 0

with torch.no_grad():
    for images, masks in test_loader:
        outputs = model(images)
        batch_dice_fg = dice_score_foreground(outputs, masks)
        total_dice_fg += batch_dice_fg.item()
        num_batches += 1

avg_dice_fg = total_dice_fg / num_batches
print(f"Average Dice Score for Foreground (Classes 1 & 2): {avg_dice_fg:.4f}")

Average Dice Score for Foreground (Classes 1 & 2): 0.6719
