In [None]:
# `(api key)` is a placeholder for your W&B API key. Avoid saving a real key in the
# notebook; logging in once from a terminal keeps the key out of the file.
!wandb login (api key)


In [None]:
import os
import pandas as pd
import numpy as np
from glob import glob
import time
import warnings
import re 
import wandb # Import Weights & Biases

# Neuroimaging libraries
import nibabel as nib
from scipy.ndimage import zoom

# PyTorch for deep learning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models

# Scikit-learn for evaluation
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score

# Silence warnings of category UserWarning only, to keep cell output readable;
# other warning categories are still shown.
warnings.filterwarnings("ignore", category=UserWarning)

# --- Login to W&B ---
# IMPORTANT: It is highly recommended to log in via your terminal first.
# 1. Open your terminal or Anaconda Prompt.
# 2. Type 'wandb login' and paste your API key.
# 3. Once logged in, the call below is not needed and can stay commented out.
#    Uncomment it only if you prefer to log in from inside the notebook.
# wandb.login()


In [None]:
# --- 2.1: Paths ---
# The ABIDE root can be overridden with the ABIDE_BASE_DIR environment variable so the
# notebook runs on other machines without editing this cell. When the variable is not
# set, the original location is used.
BASE_DIR = os.environ.get('ABIDE_BASE_DIR', r'C:\Users\aadis\abide')
DATA_DIR = os.path.join(BASE_DIR, 'abide_data', 'Outputs', 'cpac', 'nofilt_noglobal', 'func_preproc')
PHENOTYPIC_FILE = os.path.join(BASE_DIR, 'Phenotypic_V1_0b_preprocessed1.csv')

# --- 2.2: Model & Training Hyperparameters ---
TARGET_SHAPE = (64, 64, 64)   # spatial size every volume is resampled to
BATCH_SIZE = 32
EPOCHS = 15
LEARNING_RATE = 1e-4
N_SPLITS = 5                  # number of cross-validation folds
RANDOM_STATE = 42             # seed shared by the data splits and the RNGs below
PROJECT_NAME = "ASD-fMRI-Classification" # Name for your W&B project

# --- 2.3: System Configuration ---
NUM_WORKERS = 0
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 2.4: Reproducibility ---
# Seed the numpy and torch generators so weight initialisation and DataLoader shuffling
# are repeatable after "Restart Kernel -> Run All".
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

print(f"Project Base Directory: {BASE_DIR}")
print(f"Data Directory: {DATA_DIR}")
print(f"Using device: {DEVICE}")



In [None]:

def load_phenotypic_data(phenotypic_file, data_dir):
    """Load the phenotypic CSV and keep only subjects that also have a scan on disk.

    Args:
        phenotypic_file: Path to a CSV containing at least the columns
            'SUB_ID' and 'DX_GROUP'.
        data_dir: Directory searched recursively for '*.nii*' files.

    Returns:
        A DataFrame indexed by SUB_ID with a single binary 'DX_GROUP' column
        (an original value of 1 stays 1, every other value becomes 0), restricted
        to subjects whose ID appears in at least one scan file name. Returns None,
        after printing the reason, when the CSV is missing, no scan files are
        found, or no subject has both phenotypic data and a scan.
    """
    print("--- Starting Data Loading and Preparation ---")
    if not os.path.exists(phenotypic_file):
        print(f"ERROR: Phenotypic file not found at {phenotypic_file}")
        return None
    pheno_df = pd.read_csv(phenotypic_file, encoding='latin1')

    # Take an explicit copy of the two required columns: assigning into a column
    # slice of the original frame triggers pandas' chained-assignment warning.
    pheno_df = pheno_df[['SUB_ID', 'DX_GROUP']].copy()

    # Binarise the diagnosis column (vectorised; NaN compares unequal and maps to 0).
    # NOTE(review): presumably 1 = ASD and 2 = control in ABIDE's coding - confirm
    # against the dataset's data legend.
    pheno_df['DX_GROUP'] = (pheno_df['DX_GROUP'] == 1).astype('int64')
    pheno_df = pheno_df.set_index('SUB_ID')

    # Search for files ending in .nii OR .nii.gz
    all_files = glob(os.path.join(data_dir, '**', '*.nii*'), recursive=True)
    if not all_files:
        print("ERROR: No .nii or .nii.gz files found.")
        return None

    # The subject ID is the first run of five or more digits that follows an
    # underscore in the file name; int() drops any leading zeros so it can be
    # compared with the numeric SUB_ID index. A set gives fast membership tests
    # and ignores duplicates.
    subject_ids_from_files = set()
    for f in all_files:
        match = re.search(r'_(\d{5,})', os.path.basename(f))
        if match:
            subject_ids_from_files.add(int(match.group(1)))

    valid_subjects_df = pheno_df.loc[pheno_df.index.isin(subject_ids_from_files)]
    print(f"Found {len(valid_subjects_df)} subjects with BOTH phenotypic data AND an fMRI scan.")
    if len(valid_subjects_df) == 0:
        print("\nCRITICAL ERROR: No matching subjects found.")
        return None
    return valid_subjects_df

class ABIDEDataset(Dataset):
    """PyTorch Dataset that reads one ABIDE scan per subject lazily, at access time.

    Each item is a ``(volume, label, subject_id)`` triple. ``volume`` is a float
    tensor of shape ``(1, *target_shape)``. When a subject has no scan file, or the
    file cannot be read, the item is an all-zero volume with label and subject ID
    both set to -1 so that callers can detect and skip it.
    """

    def __init__(self, subject_df, data_dir, target_shape):
        """Index the scan files under ``data_dir`` by subject ID.

        Args:
            subject_df: DataFrame indexed by subject ID with a 'DX_GROUP' column.
            data_dir: Directory searched recursively for '*.nii*' files.
            target_shape: Spatial shape every volume is resampled to.
        """
        self.subject_df = subject_df
        self.target_shape = target_shape
        self.subjects = self.subject_df.index.tolist()

        # Map subject ID -> scan path. The ID is the first run of 5+ digits after an
        # underscore in the file name; if several files share an ID the last one
        # returned by glob wins.
        self.file_paths = {}
        scan_pattern = os.path.join(data_dir, '**', '*.nii*')
        for scan_path in glob(scan_pattern, recursive=True):
            found = re.search(r'_(\d{5,})', os.path.basename(scan_path))
            if found is None:
                continue
            self.file_paths[int(found.group(1))] = scan_path

    def __len__(self):
        """Number of subjects in this dataset."""
        return len(self.subjects)

    def _placeholder(self):
        """Numeric stand-in returned for a missing or unreadable scan."""
        return torch.zeros((1, *self.target_shape)), -1, -1

    def __getitem__(self, idx):
        """Load, resample and z-score the scan of the ``idx``-th subject."""
        subject_id = self.subjects[idx]
        filepath = self.file_paths.get(subject_id)
        if filepath is None:
            return self._placeholder()

        try:
            volume = nib.load(filepath).get_fdata()
            # A 4-D array is reduced to 3-D by averaging over its last axis.
            if volume.ndim == 4:
                volume = volume.mean(axis=-1)
        except Exception as e:
            print(f"Error loading {filepath}: {e}")
            return self._placeholder()

        # Per-axis scale factors that bring the volume to target_shape
        # (order=1 spline interpolation).
        scale = [wanted / current for wanted, current in zip(self.target_shape, volume.shape)]
        resized = zoom(volume, scale, order=1)

        # Z-score the whole volume; a constant volume (std == 0) is left untouched.
        mu = np.mean(resized)
        sigma = np.std(resized)
        if sigma > 0:
            normalized = (resized - mu) / sigma
        else:
            normalized = resized

        # Add the leading channel axis expected by Conv3d.
        tensor_data = torch.from_numpy(normalized).float().unsqueeze(0)
        label = self.subject_df.loc[subject_id, 'DX_GROUP']
        return tensor_data, label, subject_id



In [None]:
class Simple3DCNN(nn.Module):
    """A simple 3D CNN to serve as a performance baseline.

    Three Conv3d -> BatchNorm3d -> ReLU -> MaxPool3d(2) blocks (1 -> 8 -> 16 -> 32
    channels) feed a two-layer classifier with dropout. Every block halves each
    spatial dimension, so the classifier input has
    ``32 * (D // 8) * (H // 8) * (W // 8)`` features.

    Args:
        num_classes: Number of output logits.
        input_shape: Spatial shape ``(D, H, W)`` of the input volumes. When None
            (the default) the notebook-level ``TARGET_SHAPE`` is used, which is
            what the network was previously hard-wired to.

    Raises:
        ValueError: If ``input_shape`` does not have three dimensions, or any
            dimension is smaller than 8 (the volume would vanish after pooling).
    """
    def __init__(self, num_classes=2, input_shape=None):
        super().__init__()
        if input_shape is None:
            # Resolved at construction time so the class itself does not depend on
            # the config cell unless the default is actually requested.
            input_shape = TARGET_SHAPE
        if len(input_shape) != 3:
            raise ValueError(f"input_shape must be (D, H, W), got {tuple(input_shape)}")
        depth, height, width = input_shape
        if min(depth, height, width) < 8:
            raise ValueError(f"every spatial dimension must be >= 8, got {tuple(input_shape)}")

        self.conv_block1 = self._create_conv_block(1, 8)
        self.conv_block2 = self._create_conv_block(8, 16)
        self.conv_block3 = self._create_conv_block(16, 32)
        # Three MaxPool3d(2) layers shrink each axis by a factor of 8 (floor division).
        self.flattened_size = 32 * (depth // 8) * (height // 8) * (width // 8)
        self.classifier = nn.Sequential(
            nn.Linear(self.flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

    def _create_conv_block(self, in_channels, out_channels):
        """Conv3d(3x3x3, padding 1) -> BatchNorm3d -> ReLU -> MaxPool3d(2)."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.MaxPool3d(2),
        )

    def forward(self, x):
        """Map a ``(N, 1, D, H, W)`` batch to ``(N, num_classes)`` logits."""
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        # flatten() keeps the batch axis and also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        return self.classifier(x)

def get_pretrained_3d_resnet(num_classes=2):
    """Build an R3D-18 with KINETICS400_V1 weights that accepts 1-channel volumes.

    Two layers are swapped out:
      * ``stem[0]`` becomes a 1-input-channel Conv3d whose kernel is the mean of
        the pre-trained kernel over its input-channel dimension;
      * ``fc`` becomes a freshly initialised Linear layer with ``num_classes``
        outputs.

    Args:
        num_classes: Number of output logits of the replaced ``fc`` layer.

    Returns:
        The adapted model.
    """
    print("Loading and adapting pre-trained 3D ResNet...")
    pretrained = models.video.R3D_18_Weights.KINETICS400_V1
    model = models.video.r3d_18(weights=pretrained)

    # Collapse the input-channel axis (dim 1) of the stem kernel to a single channel.
    single_channel_kernel = model.stem[0].weight.data.mean(dim=1, keepdim=True)

    # Same output channels / kernel / stride / padding as before, but with one input
    # channel; the averaged pre-trained kernel replaces the random initialisation.
    stem_conv = nn.Conv3d(
        1, 64,
        kernel_size=(3, 7, 7),
        stride=(1, 2, 2),
        padding=(1, 3, 3),
        bias=False,
    )
    stem_conv.weight.data = single_channel_kernel
    model.stem[0] = stem_conv

    # New classification head sized for this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model



In [None]:
def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, epochs, device):
    """Train ``model`` for ``epochs`` epochs, validating and logging to W&B after each.

    Both loaders must yield ``(inputs, labels, subject_id)`` batches. Samples whose
    label is -1 (the placeholder the dataset returns for a missing or unreadable
    scan) are dropped individually, so one bad scan no longer discards its whole
    batch.

    After every epoch the validation ROC AUC is computed from the softmax
    probability of class 1. Whenever it beats the best value so far, the weights
    are saved to 'best_model_in_fold.pth'. The AUC is reported as 0.0 for an epoch
    whose validation labels do not contain both classes.

    Args:
        model: Network producing ``(N, 2+)`` logits.
        train_loader: DataLoader for the training split.
        val_loader: DataLoader for the validation split.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss taking ``(logits, long labels)``.
        epochs: Number of epochs to run.
        device: Device the batches are moved to.

    Returns:
        ``(model, best_val_auc)``. The model carries the best checkpoint written
        during *this* call; if no epoch improved on 0.0 it keeps its final weights.
    """
    checkpoint_path = 'best_model_in_fold.pth'
    best_val_auc = 0.0
    # Only reload a checkpoint this call produced. The file name is shared by all
    # folds, so a leftover file from an earlier fold/run must never be loaded.
    checkpoint_written = False

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        batches_trained = 0
        for inputs, labels, _ in train_loader:
            keep = labels != -1
            if not keep.any():
                continue
            inputs = inputs[keep].to(device)
            labels = labels[keep].to(device).long()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            batches_trained += 1

        model.eval()
        all_labels, all_preds = [], []
        with torch.no_grad():
            for inputs, labels, _ in val_loader:
                keep = labels != -1
                if not keep.any():
                    continue
                outputs = model(inputs[keep].to(device))
                preds = torch.softmax(outputs, dim=1)[:, 1]
                all_labels.extend(labels[keep].cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        val_auc = 0.0
        # roc_auc_score raises ValueError unless both classes are present.
        if len({int(label) for label in all_labels}) > 1:
            val_auc = roc_auc_score(all_labels, all_preds)
            print(f"Epoch {epoch+1}/{epochs} | Validation AUC: {val_auc:.4f}")
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_written = True
        else:
            print(f"Epoch {epoch+1}/{epochs} | Validation AUC undefined "
                  f"(validation labels need both classes); reporting 0.0")

        # --- W&B Logging ---
        # Average over the batches that actually ran; NaN when none did, instead of
        # a ZeroDivisionError on an empty loader.
        mean_train_loss = train_loss / batches_trained if batches_trained else float('nan')
        wandb.log({
            "epoch": epoch + 1,
            "validation_auc": val_auc,
            "training_loss": mean_train_loss
        })

    if checkpoint_written:
        # map_location keeps the reload working when the device differs from the
        # one the checkpoint was saved on.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model, best_val_auc



In [None]:
full_pheno_df = load_phenotypic_data(PHENOTYPIC_FILE, DATA_DIR)

# --- OPTIONAL: SCALE DOWN THE DATASET FOR DRAFTING RESULTS ---
# Fraction of subjects to keep for a quick draft run (e.g. 0.04 keeps ~4%).
# Leave as None to run on the full dataset.
SAMPLE_FRACTION = None

if full_pheno_df is not None and SAMPLE_FRACTION is not None:
    # train_test_split returns (train, test); the FIRST element is the `train_size`
    # fraction we want to keep, so the second (larger) part is discarded.
    full_pheno_df, _ = train_test_split(
        full_pheno_df,
        train_size=SAMPLE_FRACTION,
        stratify=full_pheno_df['DX_GROUP'], # Stratify by diagnosis to keep the class ratio
        random_state=RANDOM_STATE
    )
    print(f"\n--- SCALING DOWN DATASET FOR FASTER DRAFTING ---")
    print(f"Using a {int(SAMPLE_FRACTION*100)}% stratified sample: {len(full_pheno_df)} subjects total.")
# ---------------------------------------------

if full_pheno_df is not None:
    subject_ids = full_pheno_df.index.values
    labels = full_pheno_df['DX_GROUP'].values # Use the 1D diagnosis column for splitting
    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)
    fold_results = []  # best validation AUC of every completed fold

    print(f"\nStarting {N_SPLITS}-Fold Cross-Validation...")

    for fold, (train_idx, val_idx) in enumerate(skf.split(subject_ids, labels)):
        print(f"\n===== FOLD {fold+1}/{N_SPLITS} =====")

        run = None # Initialize run to None
        try:
            run = wandb.init(
                project=PROJECT_NAME,
                group="Cross-Validation-Draft", # Use a different group for draft runs
                name=f"Fold-{fold+1}",
                reinit=True
            )
        except wandb.errors.CommError as e:
            print(f"W&B Error: Could not initialize run. Please ensure you are logged in.")
            print(f"To log in, run 'wandb login' in your terminal and paste your API key.")
            print(f"Original error: {e}")
            # Training logs to W&B every epoch, so stop rather than continue without it.
            break

        try:
            train_subjects_df = full_pheno_df.iloc[train_idx]
            val_subjects_df = full_pheno_df.iloc[val_idx]
            train_dataset = ABIDEDataset(train_subjects_df, DATA_DIR, TARGET_SHAPE)
            val_dataset = ABIDEDataset(val_subjects_df, DATA_DIR, TARGET_SHAPE)
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
            val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

            # A fresh model and optimizer per fold so no weights carry over between folds.
            model = get_pretrained_3d_resnet().to(DEVICE)
            optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
            criterion = nn.CrossEntropyLoss()

            _, best_auc = train_and_evaluate(
                model=model, train_loader=train_loader, val_loader=val_loader,
                optimizer=optimizer, criterion=criterion, epochs=EPOCHS, device=DEVICE
            )

            print(f"Fold {fold+1} Best AUC: {best_auc:.4f}")
            fold_results.append(best_auc)

            if run: # Check if wandb.init was successful
                run.summary["best_fold_auc"] = best_auc
        finally:
            # Always close the W&B run, even if this fold raised, so the next
            # fold/run does not inherit a dangling run.
            if run:
                run.finish()

    if fold_results:
        print(f"\nCross-validation AUC over {len(fold_results)} fold(s): "
              f"{np.mean(fold_results):.4f} +/- {np.std(fold_results):.4f}")

else:
    print("\nExecution stopped due to data loading errors.")

