In [4]:
import os
import shutil
from pathlib import Path

# Define input and output directories
input_dir = "/home/mhs/thesis/fastMRI/fastMRI_breast_IDS_001_150_DCM"
output_dir = "/home/mhs/thesis/fastMRI/fastMRI_breast_filtered_DCM_80"

# Define slices to select (30, 35, 40, ..., 100)
slice_numbers = list(range(80, 101, 5))  # [30, 35, 40, ..., 100]
frame_number = "001"  # Only select frame_001

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Iterate through all sub-folders in the input directory
for subfolder in os.listdir(input_dir):
    subfolder_path = os.path.join(input_dir, subfolder)
    
    # Explicitly process only Seg2 sub-folders (ending with "_2_DCM")
    if os.path.isdir(subfolder_path) and subfolder.endswith("_2_DCM"):
        # Extract patient number (e.g., "001" from "fastMRI_breast_001_2_DCM")
        patient_id = subfolder.split("_")[2]
        # Create corresponding output sub-folder
        output_subfolder = os.path.join(output_dir, f"fastMRI_breast_{patient_id}_2_DCM")
        os.makedirs(output_subfolder, exist_ok=True)
        
        # Iterate through files in the Seg2 sub-folder
        for filename in os.listdir(subfolder_path):
            # Check if file matches the pattern slice_XXX_frame_001
            if filename.startswith("slice_") and f"frame_{frame_number}" in filename:
                # Extract slice number (e.g., "030" from "slice_030_frame_001")
                slice_num_str = filename.split("_")[1]
                try:
                    slice_num = int(slice_num_str)
                    # Check if slice number is in the desired list
                    if slice_num in slice_numbers:
                        # Copy the file to the output sub-folder
                        src_file = os.path.join(subfolder_path, filename)
                        dst_file = os.path.join(output_subfolder, filename)
                        shutil.copy2(src_file, dst_file)
                        print(f"Copied: {filename} to {output_subfolder}")
                except ValueError:
                    # Skip files with invalid slice numbers
                    print(f"Skipping invalid filename: {filename}")
    else:
        # Skip Seg1 sub-folders (ending with "_1_DCM") and non-directory items
        print(f"Skipping non-Seg2 folder: {subfolder}")
        continue

print("Filtering complete. Selected Seg2 slices saved to:", output_dir)

Skipping non-Seg2 folder: fastMRI_breast_082_1_DCM
Skipping non-Seg2 folder: fastMRI_breast_010_1_DCM
Copied: slice_100_frame_001.png:Zone.Identifier to /home/mhs/thesis/fastMRI/fastMRI_breast_filtered_DCM_80/fastMRI_breast_039_2_DCM
Copied: slice_080_frame_001.png:Zone.Identifier to /home/mhs/thesis/fastMRI/fastMRI_breast_filtered_DCM_80/fastMRI_breast_039_2_DCM
Copied: slice_085_frame_001.png to /home/mhs/thesis/fastMRI/fastMRI_breast_filtered_DCM_80/fastMRI_breast_039_2_DCM
Copied: slice_100_frame_001.png to /home/mhs/thesis/fastMRI/fastMRI_breast_filtered_DCM_80/fastMRI_breast_039_2_DCM
Copied: slice_080_frame_001.png to /home/mhs/thesis/fastMRI/fastMRI_breast_filtered_DCM_80/fastMRI_breast_039_2_DCM
Copied: slice_090_frame_001.png to /home/mhs/thesis/fastMRI/fastMRI_breast_filtered_DCM_80/fastMRI_breast_039_2_DCM
Copied: slice_095_frame_001.png:Zone.Identifier to /home/mhs/thesis/fastMRI/fastMRI_breast_filtered_DCM_80/fastMRI_breast_039_2_DCM
Copied: slice_095_frame_001.png to /ho

In [None]:
# !pip install openpyxl

Collecting openpyxl
  Downloading openpyxl-3.1.5-py2.py3-none-any.whl.metadata (2.5 kB)
Collecting et-xmlfile (from openpyxl)
  Downloading et_xmlfile-2.0.0-py3-none-any.whl.metadata (2.7 kB)
Downloading openpyxl-3.1.5-py2.py3-none-any.whl (250 kB)
Downloading et_xmlfile-2.0.0-py3-none-any.whl (18 kB)
Installing collected packages: et-xmlfile, openpyxl
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [openpyxl]
[1A[2KSuccessfully installed et-xmlfile-2.0.0 openpyxl-3.1.5


In [5]:
import os
import shutil
import pandas as pd
import numpy as np
import pydicom
from pathlib import Path

# Define paths
label_file = "/home/mhs/thesis/fastMRI/fastMRI_breast_labels.xlsx"
input_dir = "/home/mhs/thesis/fastMRI/fastMRI_breast_filtered_DCM_80"
output_npz = "/home/mhs/thesis/fastMRI/fastmri_breast_data.npz"

# Define label mapping
label_mapping = {0: "negative", 1: "malignancy", 2: "benign"}

# Read the label Excel file
df = pd.read_excel(label_file, sheet_name="Sheet1")

# Create a dictionary mapping patient IDs to labels
label_dict = dict(zip(df["Patient Coded Name"], df["Lesion status (0 = negative, 1= malignancy, 2= benign)"]))

# Initialize lists to store images and labels
images = []
labels = []
image_filenames = []

# Process each sub-folder in the input directory
for subfolder in sorted(os.listdir(input_dir)):
    subfolder_path = os.path.join(input_dir, subfolder)
    
    # Check if it's a Seg2 directory
    if os.path.isdir(subfolder_path) and subfolder.endswith("_2_DCM"):
        # Extract patient ID (e.g., "001" from "fastMRI_breast_001_2_DCM")
        patient_id = subfolder.split("_")[2]
        patient_name = f"fastMRI_breast_{patient_id}"
        
        # Get the label for this patient
        if patient_name not in label_dict:
            print(f"Warning: No label found for {patient_name}. Skipping.")
            continue
        label = label_dict[patient_name]
        
        # Rename the sub-folder by appending the label
        new_subfolder_name = f"{subfolder}_label_{label}"
        new_subfolder_path = os.path.join(input_dir, new_subfolder_name)
        try:
            shutil.move(subfolder_path, new_subfolder_path)
            print(f"Renamed: {subfolder} to {new_subfolder_name}")
        except Exception as e:
            print(f"Error renaming {subfolder}: {e}")
            continue
        
        # Process images in the sub-folder
        for filename in sorted(os.listdir(new_subfolder_path)):
            if filename.startswith("slice_") and filename.endswith(f"frame_001"):
                file_path = os.path.join(new_subfolder_path, filename)
                try:
                    # Read DICOM image
                    dicom = pydicom.dcmread(file_path)
                    image = dicom.pixel_array
                    # Ensure image is 2D (320x320)
                    if image.shape != (320, 320):
                        print(f"Warning: Unexpected shape {image.shape} for {filename}. Skipping.")
                        continue
                    # Append image, label, and filename
                    images.append(image)
                    labels.append(label)
                    image_filenames.append(f"{new_subfolder_name}/{filename}")
                except Exception as e:
                    print(f"Error reading {filename}: {e}")
                    continue

# # Convert lists to numpy arrays
# images = np.array(images)  # Shape: (N, 320, 320)
# labels = np.array(labels, dtype=np.int64)  # Shape: (N,) - explicitly convert to integers

# # Save to NPZ file
# np.savez(
#     output_npz,
#     images=images,
#     labels=labels,
#     label_mapping=label_mapping,
#     image_filenames=image_filenames
# )

# print(f"NPZ file saved to: {output_npz}")
# print(f"Total images: {len(images)}")
# print(f"Label distribution: {np.bincount(labels)}")

Renamed: fastMRI_breast_001_2_DCM to fastMRI_breast_001_2_DCM_label_2
Renamed: fastMRI_breast_002_2_DCM to fastMRI_breast_002_2_DCM_label_2
Renamed: fastMRI_breast_003_2_DCM to fastMRI_breast_003_2_DCM_label_2
Renamed: fastMRI_breast_004_2_DCM to fastMRI_breast_004_2_DCM_label_2
Renamed: fastMRI_breast_005_2_DCM to fastMRI_breast_005_2_DCM_label_2
Renamed: fastMRI_breast_006_2_DCM to fastMRI_breast_006_2_DCM_label_2
Renamed: fastMRI_breast_007_2_DCM to fastMRI_breast_007_2_DCM_label_2
Renamed: fastMRI_breast_008_2_DCM to fastMRI_breast_008_2_DCM_label_2
Renamed: fastMRI_breast_009_2_DCM to fastMRI_breast_009_2_DCM_label_2
Renamed: fastMRI_breast_010_2_DCM to fastMRI_breast_010_2_DCM_label_1
Renamed: fastMRI_breast_011_2_DCM to fastMRI_breast_011_2_DCM_label_1
Renamed: fastMRI_breast_012_2_DCM to fastMRI_breast_012_2_DCM_label_0
Renamed: fastMRI_breast_013_2_DCM to fastMRI_breast_013_2_DCM_label_2
Renamed: fastMRI_breast_014_2_DCM to fastMRI_breast_014_2_DCM_label_1
Renamed: fastMRI_bre

In [6]:
import os
import numpy as np
from PIL import Image
from pathlib import Path

# Define paths
input_dir = "/home/mhs/thesis/fastMRI/fastMRI_breast_filtered_DCM_80"
output_npz = "/home/mhs/thesis/fastMRI/fastmri_breast_data_80.npz"

# Define label mapping
label_mapping = {0: "negative", 1: "malignancy", 2: "benign"}

# Initialize lists to store images, labels, and filenames
images = []
labels = []
image_filenames = []

# Process each sub-folder in the input directory
for subfolder in sorted(os.listdir(input_dir)):
    subfolder_path = os.path.join(input_dir, subfolder)
    
    # Check if it's a directory and has a label in the name
    if os.path.isdir(subfolder_path) and "_label_" in subfolder:
        # Extract label from sub-folder name (e.g., "label_2" -> 2)
        try:
            label_str = subfolder.split("_label_")[1]
            label = int(label_str)
            if label not in label_mapping:
                print(f"Warning: Invalid label {label} in {subfolder}. Skipping.")
                continue
        except (IndexError, ValueError):
            print(f"Warning: Could not parse label from {subfolder}. Skipping.")
            continue
        
        # Process PNG images in the sub-folder
        for filename in sorted(os.listdir(subfolder_path)):
            if filename.endswith(".png"):
                file_path = os.path.join(subfolder_path, filename)
                try:
                    # Read PNG image
                    image = Image.open(file_path).convert("L")  # Convert to grayscale
                    image_np = np.array(image)
                    # Ensure image is 2D (e.g., 320x320)
                    if len(image_np.shape) != 2:
                        print(f"Warning: Unexpected shape {image_np.shape} for {filename}. Skipping.")
                        continue
                    # Append image, label, and filename
                    images.append(image_np)
                    labels.append(label)
                    image_filenames.append(f"{subfolder}/{filename}")
                except Exception as e:
                    print(f"Error reading {filename}: {e}")
                    continue
    else:
        print(f"Skipping non-subfolder or invalid folder: {subfolder}")
        continue

# Convert lists to numpy arrays
images = np.array(images)  # Shape: (N, height, width)
labels = np.array(labels)  # Shape: (N,)

# Save to NPZ file
np.savez(
    output_npz,
    images=images,
    labels=labels,
    label_mapping=label_mapping,
    image_filenames=image_filenames
)

print(f"NPZ file saved to: {output_npz}")
print(f"Total images: {len(images)}")
print(f"Label distribution: {np.bincount(labels)}")

NPZ file saved to: /home/mhs/thesis/fastMRI/fastmri_breast_data_80.npz
Total images: 1500
Label distribution: [255 450 795]


In [7]:
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# ─── 1) Load the NPZ file ─────────────────────────────────────────────
npz_path = "/home/mhs/thesis/fastMRI/fastmri_breast_data_80.npz"
data = np.load(npz_path, allow_pickle=True)

# images: shape (N, H, W), dtype e.g. uint16 or float32
# labels: shape (N,), dtype int in {0,1,2}
# label_mapping: dict {0:"negative", 1:"malignancy", 2:"benign"}
images_all       = data["images"]
labels_all       = data["labels"]
label_mapping    = data["label_mapping"].item()
image_filenames  = data["image_filenames"]  # (optional – not used below)

print(f"Total images in NPZ: {images_all.shape[0]}")
print(f"Label distribution (all): {np.bincount(labels_all)}")
print(f"Label mapping: {label_mapping}")


# ─── 2) Stratified train/test split ───────────────────────────────────
# We’ll hold out 20% for testing, 80% for training, stratifying by label.
train_imgs, test_imgs, train_lbls, test_lbls = train_test_split(
    images_all,
    labels_all,
    test_size=0.20,
    stratify=labels_all,
    random_state=42,
)

print("\nAfter stratified split:")
print(f"  Train set size: {train_imgs.shape[0]}")
print(f"  Train label counts: {np.bincount(train_lbls)}")
print(f"  Test  set size: {test_imgs.shape[0]}")
print(f"  Test  label counts: {np.bincount(test_lbls)}")


# ─── 3) Define the same transforms as your reference ───────────────────
# (You can omit `Grayscale` if your array is already single‐channel, but we keep it here
#  in case you want to re‐enforce a single channel and then repeat to 3 channels.)
transform = transforms.Compose([
    transforms.ToPILImage(),                         # numpy (H,W) → PIL Image
    transforms.Grayscale(num_output_channels=1),     # ensure 1¬channel
    transforms.Resize((224, 224)),                   # 224×224 for backbone
    transforms.ToTensor(),                           # PIL → [0,1] Tensor C×H×W
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # 1¬channel → 3¬channel
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])


# ─── 4) Create a custom Dataset to wrap (images, labels) pairs ──────────
class FastMRIBreastNPZDataset(Dataset):
    def __init__(self, images_np: np.ndarray, labels_np: np.ndarray, transform=None):
        """
        images_np: np.ndarray of shape (N, H, W), dtype e.g. uint16 or float32
        labels_np: np.ndarray of shape (N,), dtype int (0,1,2)
        """
        assert images_np.ndim == 3, "Images should be (N, H, W)"
        assert labels_np.ndim == 1 and len(labels_np) == images_np.shape[0]
        self.images = images_np
        self.labels = labels_np
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # 1) Grab the raw 2D numpy array (H,W)
        img_np = self.images[idx]

        # 2) Convert to PIL + apply transforms → Tensor of shape (3,224,224)
        if self.transform:
            img_tensor = self.transform(img_np)
        else:
            # If no transform is given, convert directly to Tensor (1, H, W)
            img_tensor = torch.from_numpy(img_np).unsqueeze(0).float()

        # 3) Return (image_tensor, label)
        lbl = int(self.labels[idx])
        return img_tensor, lbl


# ─── 5) Instantiate Datasets for train & test ─────────────────────────
train_dataset = FastMRIBreastNPZDataset(train_imgs, train_lbls, transform=transform)
test_dataset  = FastMRIBreastNPZDataset(test_imgs,  test_lbls,  transform=transform)

print(f"\nDataset objects created:")
print(f"  train_dataset: {len(train_dataset)} samples")
print(f"  test_dataset:  {len(test_dataset)} samples")


# ─── 6) Create DataLoaders ─────────────────────────────────────────────
batch_size = 32

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,       # shuffle training data each epoch
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,      # no need to shuffle test data
    num_workers=4,
    pin_memory=True
)


# ─── 7) Final sanity checks ────────────────────────────────────────────
num_classes = len(label_mapping)
print(f"\nNumber of classes: {num_classes}")
print(f"Training samples (in DataLoader): {len(train_dataset)}")
print(f"Testing  samples (in DataLoader): {len(test_dataset)}")

# Verify one‐batch to ensure shapes are correct
batch_imgs, batch_lbls = next(iter(train_loader))
print(f"\nExample batch shapes: images {batch_imgs.shape}, labels {batch_lbls.shape}")
# → images should be (batch_size, 3, 224, 224), labels (batch_size,)


Total images in NPZ: 1500
Label distribution (all): [255 450 795]
Label mapping: {0: 'negative', 1: 'malignancy', 2: 'benign'}

After stratified split:
  Train set size: 1200
  Train label counts: [204 360 636]
  Test  set size: 300
  Test  label counts: [ 51  90 159]

Dataset objects created:
  train_dataset: 1200 samples
  test_dataset:  300 samples

Number of classes: 3
Training samples (in DataLoader): 1200
Testing  samples (in DataLoader): 300

Example batch shapes: images torch.Size([32, 3, 224, 224]), labels torch.Size([32])


In [8]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, classification_report

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np

class BSpline(nn.Module):
    """B-spline basis function for KAN layer."""
    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        
        # Initialize grid points
        grid = torch.linspace(-1, 1, grid_size)
        self.grid = nn.Parameter(grid, requires_grad=False)
        
        # Spline coefficients
        self.coeff = nn.Parameter(
            torch.randn(out_features, in_features, grid_size + spline_order - 1) * 0.1
        )
        
    def bspline_basis(self, x, grid, order):
        """Compute B-spline basis functions."""
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:-1]) & (x < grid[1:])).float()
        for k in range(1, order + 1):
            bases = ((x - grid[:-k-1]) / (grid[k:-1] - grid[:-k-1]) * bases[..., :-1] +
                    (grid[k+1:] - x) / (grid[k+1:] - grid[1:-k]) * bases[..., 1:])
        return bases
    
    def forward(self, x):
        x = x.view(-1, self.in_features)
        x = (x - x.mean(dim=0, keepdim=True)) / (x.std(dim=0, keepdim=True) + 1e-6)
        basis = self.bspline_basis(x, self.grid, self.spline_order)
        out = torch.einsum('bi...,oi...->bo', basis, self.coeff)
        return out

class KANLayer(nn.Module):
    """KAN layer replacing linear layer with spline-based transformation."""
    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.spline = BSpline(in_features, out_features, grid_size, spline_order)
        self.norm = nn.LayerNorm(out_features)
        
    def forward(self, x):
        x = self.spline(x)
        x = self.norm(x)
        return x

class KernelAttention(nn.Module):
    def __init__(self, in_dim, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, in_dim, kernel_size=kernel_size, 
                             padding=kernel_size//2, groups=in_dim)
        self.spatial_gate = nn.Sequential(
            nn.Conv2d(in_dim, 1, kernel_size=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        local_feat = self.conv(x)
        attn = self.spatial_gate(local_feat)
        return x * attn

class KANModel(nn.Module):
    def __init__(self, num_classes, backbone='resnet18'):
        super().__init__()
        
        # 1. Feature Extraction Backbone
        if backbone == 'resnet18':
            base = models.resnet18(pretrained=True)
            self.feature_dim = 512
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")
            
        self.features = nn.Sequential(*list(base.children())[:-2])
        
        # 2. Kernel Attention Module
        self.attention = KernelAttention(self.feature_dim)
        
        # 3. Global Average Pooling
        self.gap = nn.AdaptiveAvgPool2d(1)
        
        # 4. KAN-based Classifier
        self.classifier = nn.Sequential(
            KANLayer(self.feature_dim, 256, grid_size=5, spline_order=3),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            KANLayer(256, num_classes, grid_size=5, spline_order=3)
        )
        
    def forward(self, x):
        x = self.features(x)  # [B, 512, H', W']
        x = self.attention(x)
        x = self.gap(x)      # [B, 512, 1, 1]
        x = x.view(x.size(0), -1)  # [B, 512]
        return self.classifier(x)

In [22]:
# class KernelAttention(nn.Module):
#     def __init__(self, in_dim, kernel_size=7):
#         super().__init__()
#         self.conv = nn.Conv2d(in_dim, in_dim, kernel_size=kernel_size, 
#                              padding=kernel_size//2, groups=in_dim)
#         self.spatial_gate = nn.Sequential(
#             nn.Conv2d(in_dim, 1, kernel_size=1),
#             nn.Sigmoid()
#         )
        
#     def forward(self, x):
#         # Local feature aggregation
#         local_feat = self.conv(x)
#         # Generate attention weights
#         attn = self.spatial_gate(local_feat)
#         return x * attn

# class KANModel(nn.Module):
#     def __init__(self, num_classes, backbone='resnet18'):
#         super().__init__()
        
#         # 1. Feature Extraction Backbone
#         if backbone == 'resnet18':
#             base = models.resnet18(pretrained=True)
#             self.feature_dim = 512
#         else:
#             raise ValueError(f"Unsupported backbone: {backbone}")
            
#         # Remove the final FC layer
#         self.features = nn.Sequential(*list(base.children())[:-2])
        
#         # 2. Kernel Attention Module
#         self.attention = KernelAttention(self.feature_dim)
        
#         # 3. Global Average Pooling
#         self.gap = nn.AdaptiveAvgPool2d(1)
        
#         # 4. Classifier
#         self.classifier = nn.Sequential(
#             nn.Linear(self.feature_dim, 256),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.5),
#             nn.Linear(256, num_classes)
#         )
        
#     def forward(self, x):
#         # Extract features
#         x = self.features(x)  # [B, 512, H', W']
        
#         # Apply kernel attention
#         x = self.attention(x)
        
#         # Global average pooling
#         x = self.gap(x)      # [B, 512, 1, 1]
#         x = x.view(x.size(0), -1)  # [B, 512]
        
#         # Classification
#         return self.classifier(x)

In [10]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize model
model = KANModel(num_classes=num_classes).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Learning rate scheduler (removed verbose parameter)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='max',
    factor=0.5,
    patience=5
)

Using device: cuda




In [11]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for images, labels in tqdm(loader, desc='Training'):
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track metrics
        running_loss += loss.item()
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    # Compute epoch metrics
    epoch_loss = running_loss / len(loader)
    epoch_acc = accuracy_score(all_labels, all_preds)
    
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Evaluating'):
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Compute metrics
    avg_loss = running_loss / len(loader)
    accuracy = accuracy_score(all_labels, all_preds)
    
    # Use label_mapping values as class names
    class_names = [label_mapping[i] for i in range(len(label_mapping))]
    report = classification_report(
        all_labels, 
        all_preds, 
        target_names=class_names
    )
    
    return avg_loss, accuracy, report

# Training loop
num_epochs = 10
best_acc = 0.0

for epoch in range(1, num_epochs + 1):
    # Training phase
    train_loss, train_acc = train_epoch(model, train_loader, criterion, 
                                      optimizer, device)
    
    # Evaluation phase
    val_loss, val_acc, val_report = evaluate(model, test_loader, criterion, device)
    
    # Learning rate scheduling
    scheduler.step(val_acc)
    
    # Save best model
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'kan_best_model.pth')
    
    # Print epoch results
    print(f"\nEpoch {epoch}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    
    # Print detailed validation report every 5 epochs
    if epoch % 5 == 0:
        print("\nValidation Report:")
        print(val_report)

# Final evaluation
print("\nLoading best model for final evaluation...")
model.load_state_dict(torch.load('kan_best_model.pth'))
test_loss, test_acc, test_report = evaluate(model, test_loader, criterion, device)

print("\nFinal Test Results:")
print(f"Test Accuracy: {test_acc:.4f}")
print("\nDetailed Classification Report:")
print(test_report)

Training: 100%|██████████| 38/38 [00:02<00:00, 16.79it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 44.96it/s]



Epoch 1/10
Train Loss: 1.1568, Train Acc: 0.5142
Val Loss: 1.2375, Val Acc: 0.4133


Training: 100%|██████████| 38/38 [00:54<00:00,  1.44s/it]
Evaluating: 100%|██████████| 10/10 [-1:59:07<00:00, -0.19it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 2/10
Train Loss: 1.0882, Train Acc: 0.5175
Val Loss: 1.1574, Val Acc: 0.4900


Training: 100%|██████████| 38/38 [00:01<00:00, 25.46it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 43.52it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 3/10
Train Loss: 1.0928, Train Acc: 0.5200
Val Loss: 1.1275, Val Acc: 0.5167


Training: 100%|██████████| 38/38 [00:01<00:00, 25.65it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 46.98it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 4/10
Train Loss: 1.0783, Train Acc: 0.5200
Val Loss: 1.0947, Val Acc: 0.5233


Training: 100%|██████████| 38/38 [00:01<00:00, 25.61it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 46.36it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 5/10
Train Loss: 1.0679, Train Acc: 0.5292
Val Loss: 1.1106, Val Acc: 0.5400

Validation Report:
              precision    recall  f1-score   support

    negative       0.00      0.00      0.00        51
  malignancy       0.50      0.12      0.20        90
      benign       0.54      0.95      0.69       159

    accuracy                           0.54       300
   macro avg       0.35      0.36      0.30       300
weighted avg       0.44      0.54      0.43       300



Training: 100%|██████████| 38/38 [00:01<00:00, 25.46it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 50.11it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 6/10
Train Loss: 1.0694, Train Acc: 0.5242
Val Loss: 1.1013, Val Acc: 0.5067


Training: 100%|██████████| 38/38 [00:01<00:00, 26.09it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 43.94it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 7/10
Train Loss: 1.0654, Train Acc: 0.5258
Val Loss: 1.1283, Val Acc: 0.5067


Training: 100%|██████████| 38/38 [00:01<00:00, 25.44it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 44.53it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 8/10
Train Loss: 1.0652, Train Acc: 0.5275
Val Loss: 1.1362, Val Acc: 0.4767


Training: 100%|██████████| 38/38 [00:01<00:00, 26.04it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 44.80it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 9/10
Train Loss: 1.0637, Train Acc: 0.5308
Val Loss: 1.1108, Val Acc: 0.5067


Training: 100%|██████████| 38/38 [00:01<00:00, 25.78it/s]
Evaluating: 100%|██████████| 10/10 [00:00<00:00, 45.35it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 10/10
Train Loss: 1.0587, Train Acc: 0.5300
Val Loss: 1.1072, Val Acc: 0.5233

Validation Report:
              precision    recall  f1-score   support

    negative       0.00      0.00      0.00        51
  malignancy       0.31      0.04      0.08        90
      benign       0.53      0.96      0.69       159

    accuracy                           0.52       300
   macro avg       0.28      0.34      0.25       300
weighted avg       0.37      0.52      0.39       300


Loading best model for final evaluation...


Evaluating: 100%|██████████| 10/10 [00:00<00:00, 46.66it/s]


Final Test Results:
Test Accuracy: 0.5400

Detailed Classification Report:
              precision    recall  f1-score   support

    negative       0.00      0.00      0.00        51
  malignancy       0.50      0.12      0.20        90
      benign       0.54      0.95      0.69       159

    accuracy                           0.54       300
   macro avg       0.35      0.36      0.30       300
weighted avg       0.44      0.54      0.43       300




  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [12]:
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split, StratifiedKFold

import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms

# ─── 1) Load the NPZ file ─────────────────────────────────────────────
npz_path = "/home/mhs/thesis/fastMRI/fastmri_breast_data_80.npz"
data = np.load(npz_path, allow_pickle=True)

images_all = data["images"]      # shape: (N, H, W)
labels_all = data["labels"]      # shape: (N,)
label_mapping = data["label_mapping"].item()
# (Optional) image_filenames = data["image_filenames"]

print(f"Total images in NPZ: {images_all.shape[0]}")
print(f"Original label distribution: {np.bincount(labels_all)}")
print(f"Label mapping: {label_mapping}")


# ─── 2) Stratified 80/20 train/test split ─────────────────────────────
train_imgs, test_imgs, train_lbls, test_lbls = train_test_split(
    images_all,
    labels_all,
    test_size=0.20,
    stratify=labels_all,
    random_state=42
)

print("\nAfter stratified train/test split:")
print(f"  Train set size: {train_imgs.shape[0]}")
print(f"  Train label counts: {np.bincount(train_lbls)}")
print(f"  Test  set size: {test_imgs.shape[0]}")
print(f"  Test  label counts: {np.bincount(test_lbls)}")


# ─── 3) Define transforms (same as your reference) ────────────────────
transform = transforms.Compose([
    transforms.ToPILImage(),                         # numpy (H,W) → PIL Image
    transforms.Grayscale(num_output_channels=1),     # ensure single‐channel
    transforms.Resize((224, 224)),                   # resize for backbone
    transforms.ToTensor(),                           # PIL → [0,1] Tensor C×H×W
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # 1‐channel → 3‐channel
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])


# ─── 4) Create a custom Dataset class ─────────────────────────────────
class FastMRIBreastNPZDataset(Dataset):
    def __init__(self, images_np: np.ndarray, labels_np: np.ndarray, transform=None):
        """
        images_np: np.ndarray of shape (N, H, W)
        labels_np: np.ndarray of shape (N,)
        """
        assert images_np.ndim == 3, "Expected images_np of shape (N, H, W)"
        assert labels_np.ndim == 1 and len(labels_np) == images_np.shape[0]
        self.images = images_np
        self.labels = labels_np
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_np = self.images[idx]            # 2D array (H, W)
        if self.transform:
            img_tensor = self.transform(img_np)
        else:
            # Fallback: convert to FloatTensor (1, H, W) if no transform
            img_tensor = torch.from_numpy(img_np).unsqueeze(0).float()

        label = int(self.labels[idx])
        return img_tensor, label


# ─── 5) Instantiate Datasets for train & test ─────────────────────────
train_dataset = FastMRIBreastNPZDataset(train_imgs, train_lbls, transform=transform)
test_dataset  = FastMRIBreastNPZDataset(test_imgs,  test_lbls,  transform=transform)

print(f"\nDataset sizes:")
print(f"  train_dataset: {len(train_dataset)} samples")
print(f"  test_dataset:  {len(test_dataset)} samples")


# ─── 6) Stratified 10-Fold on the training set ────────────────────────
NUM_FOLDS = 10
batch_size = 32

# Prepare the StratifiedKFold splitter
skf = StratifiedKFold(
    n_splits=NUM_FOLDS,
    shuffle=True,
    random_state=42
)

# We need the array of labels for the training set
all_train_labels = train_lbls.copy()

# To store (train_loader, val_loader) for each fold:
fold_data_loaders = []

for fold_idx, (train_idx, val_idx) in enumerate(
        skf.split(np.zeros(len(all_train_labels)), all_train_labels)
    ):

    print(f"\n=== Fold {fold_idx+1}/{NUM_FOLDS} ===")
    print(f"  Train indices: {len(train_idx)},  Val indices: {len(val_idx)}")

    # Create Subset objects for this fold
    train_subset = Subset(train_dataset, train_idx)
    val_subset   = Subset(train_dataset, val_idx)

    # Create DataLoaders for this fold
    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,    # shuffle the train subset each epoch
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,   # no need to shuffle validation
        num_workers=4,
        pin_memory=True
    )

    # Inspect label distribution in this fold (optional)
    fold_train_labels = all_train_labels[train_idx]
    fold_val_labels   = all_train_labels[val_idx]
    print("  Fold Train label counts:", np.bincount(fold_train_labels))
    print("  Fold Val   label counts:", np.bincount(fold_val_labels))

    fold_data_loaders.append((train_loader, val_loader))


# ─── 7) (Optional) Example Training Loop Sketch ───────────────────────


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for fold_idx, (train_loader, val_loader) in enumerate(fold_data_loaders):
    print(f"\n>> Training on Fold {fold_idx+1}/{NUM_FOLDS}")

    # 1) Initialize a fresh model
    model = KANModel(num_classes=len(label_mapping)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = torch.nn.CrossEntropyLoss()

    NUM_EPOCHS = 5
    for epoch in range(NUM_EPOCHS):
        # ---- Training Loop ----
        model.train()
        running_loss = 0.0
        running_corrects = 0
        total_samples = 0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)       # (batch, 3, 224, 224)
            targets = targets.to(device)     # (batch,)
            optimizer.zero_grad()
            outputs = model(inputs)          # (batch, num_classes)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == targets).item()
            total_samples += inputs.size(0)

        epoch_loss = running_loss / total_samples
        epoch_acc  = running_corrects / total_samples
        print(f" Fold {fold_idx+1} | Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f}")

        # ---- Validation Loop ----
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        val_samples = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                val_corrects += torch.sum(preds == targets).item()
                val_samples += inputs.size(0)

        val_epoch_loss = val_loss / val_samples
        val_epoch_acc  = val_corrects / val_samples
        print(f"             | Epoch {epoch+1}/{NUM_EPOCHS} | "
              f" Val Loss: {val_epoch_loss:.4f} | Val Acc: {val_epoch_acc:.4f}")

    # Optionally: save fold checkpoint
    # torch.save(model.state_dict(), f"model_fold{fold_idx+1}.pth")


# ─── 8) Final Test-Set Evaluation ───────────────────────────────────────


final_test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

# Suppose `model` is your final, fully-trained network on all training data:
model.eval()
test_loss = 0.0
test_corrects = 0
test_total = 0
with torch.no_grad():
    for inputs, targets in final_test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        test_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        test_corrects += torch.sum(preds == targets).item()
        test_total += inputs.size(0)

test_loss = test_loss / test_total
test_acc  = test_corrects / test_total
print(f"\nFinal Test Loss: {test_loss:.4f} | Final Test Acc: {test_acc:.4f}")


Total images in NPZ: 1500
Original label distribution: [255 450 795]
Label mapping: {0: 'negative', 1: 'malignancy', 2: 'benign'}

After stratified train/test split:
  Train set size: 1200
  Train label counts: [204 360 636]
  Test  set size: 300
  Test  label counts: [ 51  90 159]

Dataset sizes:
  train_dataset: 1200 samples
  test_dataset:  300 samples

=== Fold 1/10 ===
  Train indices: 1080,  Val indices: 120
  Fold Train label counts: [183 324 573]
  Fold Val   label counts: [21 36 63]

=== Fold 2/10 ===
  Train indices: 1080,  Val indices: 120
  Fold Train label counts: [183 324 573]
  Fold Val   label counts: [21 36 63]

=== Fold 3/10 ===
  Train indices: 1080,  Val indices: 120
  Fold Train label counts: [183 324 573]
  Fold Val   label counts: [21 36 63]

=== Fold 4/10 ===
  Train indices: 1080,  Val indices: 120
  Fold Train label counts: [183 324 573]
  Fold Val   label counts: [21 36 63]

=== Fold 5/10 ===
  Train indices: 1080,  Val indices: 120
  Fold Train label counts:



 Fold 1 | Epoch 1/5 | Train Loss: 1.4013 | Train Acc: 0.3769
             | Epoch 1/5 |  Val Loss: 1.1628 | Val Acc: 0.4583
 Fold 1 | Epoch 2/5 | Train Loss: 1.0918 | Train Acc: 0.5185
             | Epoch 2/5 |  Val Loss: 1.0977 | Val Acc: 0.5000
 Fold 1 | Epoch 3/5 | Train Loss: 1.0767 | Train Acc: 0.5250
             | Epoch 3/5 |  Val Loss: 1.1375 | Val Acc: 0.4250
 Fold 1 | Epoch 4/5 | Train Loss: 1.0800 | Train Acc: 0.5278
             | Epoch 4/5 |  Val Loss: 1.1655 | Val Acc: 0.3667
 Fold 1 | Epoch 5/5 | Train Loss: 1.0736 | Train Acc: 0.5213
             | Epoch 5/5 |  Val Loss: 1.0892 | Val Acc: 0.4917

>> Training on Fold 2/10
 Fold 2 | Epoch 1/5 | Train Loss: 1.7782 | Train Acc: 0.2630
             | Epoch 1/5 |  Val Loss: 1.4234 | Val Acc: 0.3167
 Fold 2 | Epoch 2/5 | Train Loss: 1.1634 | Train Acc: 0.4796
             | Epoch 2/5 |  Val Loss: 1.1844 | Val Acc: 0.4833
 Fold 2 | Epoch 3/5 | Train Loss: 1.1010 | Train Acc: 0.5167
             | Epoch 3/5 |  Val Loss: 1.1036 