In [7]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pydicom
import nibabel as nib
import cv2

# ----------------------------
# 1. SCAN LOADING & PREPROCESSING
# ----------------------------
def load_scan(file_path):
    if file_path.endswith('.nii') or file_path.endswith('.nii.gz'):
        img = nib.load(file_path)
        return img.get_fdata()
    elif os.path.isdir(file_path):
        slices = [pydicom.dcmread(os.path.join(file_path, f)) for f in os.listdir(file_path) if f.endswith('.dcm')]
        slices.sort(key=lambda x: float(x.ImagePositionPatient[2]))
        volume = np.stack([s.pixel_array for s in slices], axis=-1)
        return volume
    else:
        raise ValueError("Unsupported format")

def normalize_volume(volume):
    volume = volume.astype(np.float32)
    volume -= np.min(volume)
    volume /= np.max(volume)
    return volume

def extract_12_slices(volume):
    slices = []
    z_len = volume.shape[2]
    axial_indices = np.linspace(0, z_len - 1, 10, dtype=int)
    axial_slices = [volume[:, :, idx] for idx in axial_indices]

    coronal_idx = volume.shape[1] // 2
    coronal_slice = volume[:, coronal_idx, :]

    sagittal_idx = volume.shape[0] // 2
    sagittal_slice = volume[sagittal_idx, :, :]

    slices.extend(axial_slices)
    slices.append(coronal_slice)
    slices.append(sagittal_slice)

    return slices

def preprocess_scan(path, target_size=(224, 224)):
    volume = load_scan(path)
    volume = normalize_volume(volume)
    slices = extract_12_slices(volume)
    resized = [cv2.resize(s, target_size, interpolation=cv2.INTER_AREA) for s in slices]
    return np.stack(resized, axis=0)  # (12, 224, 224)

# ----------------------------
# 2. DATASET CLASS
# ----------------------------
class CTADataset(Dataset):
    def __init__(self, labels_csv, root_dir, transform=None):
        self.data = pd.read_csv(labels_csv)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data.iloc[idx]
        file_path = os.path.join(self.root_dir, entry['filename'])
        slices = preprocess_scan(file_path)  # (12, 224, 224)
        slices = torch.tensor(slices).unsqueeze(1).float()  # (12, 1, 224, 224)

        if self.transform:
            slices = self.transform(slices)

        label = torch.tensor(entry['label']).float()
        return slices, label

# ----------------------------
# 3. MODEL DEFINITION (2D CNN + Aggregation)
# ----------------------------
class SliceCNN(nn.Module):
    def __init__(self, feature_dim=128):
        super(SliceCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Linear(128, feature_dim)

    def forward(self, x):  # (B, 1, 224, 224)
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

class CTAQuality2D(nn.Module):
    def __init__(self, feature_dim=128):
        super(CTAQuality2D, self).__init__()
        self.slice_cnn = SliceCNN(feature_dim)
        self.regressor = nn.Sequential(
            nn.Linear(feature_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):  # (B, 12, 1, 224, 224)
        B, S, C, H, W = x.shape
        x = x.view(B * S, C, H, W)
        feats = self.slice_cnn(x)  # (B×12, F)
        feats = feats.view(B, S, -1)
        pooled = feats.mean(dim=1)
        out = self.regressor(pooled)
        return out.squeeze(1)

# ----------------------------
# 4. TRAINING AND EVALUATION LOOPS
# ----------------------------
def train_one_epoch(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * inputs.size(0)
    return total_loss / len(dataloader.dataset)

def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    preds, targets = [], []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * inputs.size(0)
            preds.append(outputs.cpu())
            targets.append(labels.cpu())
    preds = torch.cat(preds)
    targets = torch.cat(targets)
    return total_loss / len(dataloader.dataset), preds, targets

# ----------------------------
# 5. TRAINING DRIVER CODE
# ----------------------------
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error

# Settings
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
root_dir = 'root_dir = "C:\\Users\\HenryLi\\Downloads\\Scans"'
labels_csv = "C:/Users/HenryLi/Desktop/Python Projects/Primitive Image Classifier/Labels.csv"
batch_size = 8
num_epochs = 30

# Load and split CSV
df = pd.read_csv(labels_csv)
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
train_df.to_csv('train.csv', index=False)
val_df.to_csv('val.csv', index=False)

train_loader = DataLoader(CTADataset('train.csv', root_dir), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(CTADataset('val.csv', root_dir), batch_size=batch_size)

model = CTAQuality2D().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

best_val_loss = float('inf')
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_preds, val_targets = evaluate(model, val_loader, criterion)
    scheduler.step()

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt')

    r2 = r2_score(val_targets.numpy(), val_preds.numpy())
    mae = mean_absolute_error(val_targets.numpy(), val_preds.numpy())
    print(f"Val R^2: {r2:.3f}, MAE: {mae:.3f}")


FileNotFoundError: No such file or no access: 'root_dir = "C:/Users/HenryLi/Downloads/Scans"/PREFFIR-11115 (Cor CTAAdapt 0.75 B26f 75%).nii.gz'