In [None]:
#!/usr/bin/env python3
import os
import torch
import pandas as pd
import logging

# Root logger setup: INFO and above, one timestamped line per record,
# formatted as "YYYY-MM-DD HH:MM:SS LEVEL: message".
logging.basicConfig(
    format="%(asctime)s %(levelname)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

def _index_pt_files(embedding_dirs):
    """Return ``[(directory, sorted .pt filenames), ...]`` for each valid directory.

    Listing is done once up front instead of once per CSV row; sorting makes the
    order in which slides are concatenated deterministic.
    """
    index = []
    for directory_path in embedding_dirs:
        if not os.path.isdir(directory_path):
            logging.warning("'%s' is not a valid directory. Skipping.", directory_path)
            continue
        names = sorted(f for f in os.listdir(directory_path) if f.endswith(".pt"))
        index.append((directory_path, names))
    return index


def _load_bag(path, case_id):
    """Load one ``.pt`` tensor, transposing a [1024, N] (N != 1024) tensor to [N, 1024]."""
    logging.debug("Loading tensor from: %s", path)
    bag = torch.load(path)
    # Heuristic kept from the original pipeline: features are assumed to be
    # 1024-dimensional, so a 2-D tensor with 1024 rows (and not 1024 columns)
    # is treated as feature-major and transposed to instances-first.
    if bag.ndim == 2 and bag.shape[0] == 1024 and bag.shape[1] != 1024:
        logging.debug("Transposing tensor for case_id '%s' from shape %s", case_id, bag.shape)
        bag = bag.transpose(0, 1)
    return bag


def pre_concatenate(embedding_dirs, csv_root, output_dir):
    """
    Build one concatenated feature bag per case_id and save it into ``output_dir``.

    ``csv_root`` must contain subdirectories whose names start with ``fold-``
    (e.g. fold-0, fold-1, fold-2). Each one holds some of ``train.csv``,
    ``tune.csv`` and ``test.csv``, and those CSVs need a ``case_id`` column.
    For every case_id, this function:

    - loads all ``.pt`` files in ``embedding_dirs`` whose filename contains the
      case_id as a substring (the case_id is only part of a slide's filename);
    - concatenates them along dim 0;
    - writes the result to ``output_dir/{case_id}.pt``.

    The layout is flat (not split by fold/split). A case listed in several
    folds/splits is written only once, because its bag depends only on the
    case_id and ``embedding_dirs``.

    Args:
        embedding_dirs: Iterable of directories containing per-slide ``.pt``
            tensors. Entries that are not directories are skipped with a warning.
        csv_root: Directory containing the ``fold-*`` subdirectories.
        output_dir: Destination directory; created if it does not exist.
    """
    logging.info("Starting pre-concatenation process")

    # Ensure the output root exists.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        logging.info("Created output directory: %s", output_dir)

    fold_dirs = sorted(
        d for d in os.listdir(csv_root)
        if d.startswith("fold-") and os.path.isdir(os.path.join(csv_root, d))
    )
    if not fold_dirs:
        logging.error("No fold directories found in csv_root: %s", csv_root)
        return

    pt_index = _index_pt_files(embedding_dirs)
    saved_case_ids = set()

    for fold in fold_dirs:
        fold_path = os.path.join(csv_root, fold)
        for split in ("train", "tune", "test"):
            csv_filename = f"{split}.csv"
            csv_path = os.path.join(fold_path, csv_filename)
            if not os.path.isfile(csv_path):
                logging.warning("CSV file '%s' not found in fold '%s'. Skipping.", csv_filename, fold)
                continue

            logging.info("Processing CSV file: %s", csv_path)
            df = pd.read_csv(csv_path)
            logging.info("CSV loaded. Total rows: %d", len(df))

            if "case_id" not in df.columns:
                logging.error("CSV file '%s' does not contain a 'case_id' column. Skipping.", csv_path)
                continue

            for idx, raw_case_id in df["case_id"].items():
                # A missing value would otherwise become the string "nan" and be
                # substring-matched against filenames.
                if pd.isna(raw_case_id):
                    logging.warning("Empty case_id in row %s of '%s'. Skipping.", idx, csv_path)
                    continue
                case_id = str(raw_case_id)
                if case_id in saved_case_ids:
                    logging.debug("case_id '%s' already saved. Skipping duplicate.", case_id)
                    continue

                # PARTIAL MATCH: the case_id only needs to appear somewhere in the
                # filename. NOTE(review): this can also match longer ids that
                # contain it (e.g. "case_1" inside "case_10_..."); confirm the
                # id format makes that impossible.
                matching_paths = [
                    os.path.join(directory_path, name)
                    for directory_path, names in pt_index
                    for name in names
                    if case_id in name
                ]
                if not matching_paths:
                    logging.warning("No .pt file found for '%s'. Skipping row %s.", case_id, idx)
                    continue
                logging.debug("Found %d matching files for case_id '%s'.", len(matching_paths), case_id)

                concatenated_bag = torch.cat(
                    [_load_bag(path, case_id) for path in matching_paths], dim=0
                )
                output_path = os.path.join(output_dir, f"{case_id}.pt")
                torch.save(concatenated_bag, output_path)
                saved_case_ids.add(case_id)
                logging.info("Saved concatenated tensor for case_id '%s' to '%s'", case_id, output_path)

    logging.info("Pre-concatenation process completed successfully.")

if __name__ == "__main__":
    # Per-organ folders of slide-level .pt embeddings -- adjust to your setup.
    breast_dir = "/data/temporary/amirhosein/modelscript_breast/pt_files"
    bladder_dir = "/data/temporary/amirhosein/model_script_bladder/pt_files"
    prostate_dir = "/data/temporary/amirhosein/model_script_prostate/pt_files"
    embedding_dirs = [breast_dir, bladder_dir, prostate_dir]

    # Must hold fold-0, fold-1, fold-2 subdirectories with the split CSVs.
    csv_root = "/data/temporary/projects/mutation-prediction/csvs/3-folds-prostate-bladder-breast"
    # Where the concatenated per-case tensors are written.
    output_dir = "/data/temporary/amirhosein/model_script_concatenated"

    pre_concatenate(embedding_dirs=embedding_dirs, csv_root=csv_root, output_dir=output_dir)

In [4]:
#!/usr/bin/env python3
from __future__ import print_function
import logging
import argparse
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data_utils
from torch.autograd import Variable
import os
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, roc_auc_score
import torch.utils.data as data
from torch.autograd import Variable
from torch.utils.data import random_split, ConcatDataset, DataLoader

# Import your dataset and models
from dataloader import PreConcatenatedDataset
from model import Attention, GatedAttention

# Timestamped INFO-level logging for the training run. basicConfig does nothing
# if the root logger already has handlers (e.g. an earlier cell configured it).
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(levelname)s: %(message)s",
    level=logging.INFO,
)

def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass of ``model`` over ``loader``.

    For each (bag, label) batch, both tensors are moved to ``device``. The loss
    from ``model.calculate_objective`` is back-propagated and ``optimizer``
    steps. The value returned by ``model.calculate_classification_error`` is
    accumulated alongside the loss.

    Returns:
        ``(mean_loss, mean_error)`` over the batches, or ``(0.0, 0.0)`` when
        ``loader`` has no batches.
    """
    logging.info("Starting training epoch")
    model.train()
    total_loss, total_error = 0.0, 0.0

    for step, (bag, label) in enumerate(loader):
        logging.debug("Training batch %d - bag shape: %s", step, bag.shape)
        bag, label = bag.to(device), label.to(device)

        optimizer.zero_grad()
        loss, _ = model.calculate_objective(bag, label)
        error, _ = model.calculate_classification_error(bag, label)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_error += error

    num_batches = len(loader)
    if num_batches > 0:
        avg_loss = total_loss / num_batches
        avg_error = total_error / num_batches
    else:
        avg_loss = avg_error = 0.0
    logging.info("Finished training epoch: Avg Loss: %.4f, Avg Error: %.4f", avg_loss, avg_error)
    return avg_loss, avg_error

def evaluate(model, loader, device):
    """Score ``model`` on ``loader`` without tracking gradients.

    For each (bag, label) batch, this function:

    - accumulates the loss from ``model.calculate_objective`` and the error
      from ``model.calculate_classification_error``;
    - records the label (via ``label.item()``, so one label per batch);
    - records the mean of the first output of ``model(bag)`` as that batch's
      probability.

    Returns:
        ``(mean_loss, mean_error, labels, probabilities)``. The means are 0.0
        when ``loader`` has no batches.
    """
    logging.info("Starting evaluation")
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    labels = []
    probabilities = []

    with torch.no_grad():
        for step, (bag, label) in enumerate(loader):
            logging.debug("Evaluating batch %d - bag shape: %s", step, bag.shape)
            bag = bag.to(device)
            label = label.to(device)

            loss, _ = model.calculate_objective(bag, label)
            error, _ = model.calculate_classification_error(bag, label)
            total_loss += loss.item()
            total_error += error

            # Averaging collapses multiple attention branches into one score.
            y_prob, _, _ = model(bag)
            batch_prob = y_prob.mean().item()

            labels.append(label.item())
            probabilities.append(batch_prob)

    num_batches = len(loader)
    avg_loss = total_loss / num_batches if num_batches else 0.0
    avg_error = total_error / num_batches if num_batches else 0.0
    logging.info("Finished evaluation: Avg Loss: %.4f, Avg Error: %.4f", avg_loss, avg_error)
    return avg_loss, avg_error, labels, probabilities

def _build_model(model_type, input_dim, hidden_dim, attn_dim, attn_branches):
    """Return an ``Attention`` model for ``"attention"``, else a ``GatedAttention`` model."""
    model_cls = Attention if model_type == "attention" else GatedAttention
    return model_cls(
        input_dim=input_dim,
        M=hidden_dim,
        L=attn_dim,
        attention_branches=attn_branches,
    )


def _snapshot_state(model):
    """Return a detached, cloned copy of ``model.state_dict()``.

    ``state_dict()`` returns references to the live parameter tensors, and the
    optimizer keeps updating those in place. Without cloning, the saved "best"
    weights would silently become the latest weights.
    """
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def _fit_with_early_stopping(model, train_loader, val_loader, optimizer, device,
                             epochs, patience, fold_idx):
    """Train ``model`` for up to ``epochs`` epochs with early stopping.

    Training stops once ``patience`` consecutive epochs fail to lower the
    validation loss. Afterwards the weights with the lowest validation loss are
    loaded back into ``model``.

    Returns:
        The lowest validation loss observed (``inf`` if none was recorded).
    """
    best_val_loss = float("inf")
    best_state_dict = None
    no_improvement_counter = 0

    for epoch_idx in range(1, epochs + 1):
        logging.info("[Fold %d | Epoch %d] Starting training", fold_idx, epoch_idx)
        train_loss, train_error = train_one_epoch(model, train_loader, optimizer, device)
        logging.info("[Fold %d | Epoch %d] Training complete. Loss: %.4f, Error: %.4f", fold_idx, epoch_idx, train_loss, train_error)

        logging.info("[Fold %d | Epoch %d] Starting validation", fold_idx, epoch_idx)
        val_loss, val_error, _, _ = evaluate(model, val_loader, device)
        logging.info("[Fold %d | Epoch %d] Validation complete. Loss: %.4f, Error: %.4f", fold_idx, epoch_idx, val_loss, val_error)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state_dict = _snapshot_state(model)
            no_improvement_counter = 0
            logging.info("[Fold %d | Epoch %d] New best model found (Validation Loss: %.4f)", fold_idx, epoch_idx, best_val_loss)
        else:
            no_improvement_counter += 1
            logging.info("[Fold %d | Epoch %d] No improvement. Counter: %d", fold_idx, epoch_idx, no_improvement_counter)

        if no_improvement_counter >= patience:
            logging.info("[Fold %d] Early stopping triggered after %d epochs with no improvement.", fold_idx, patience)
            break

    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)
        logging.info("Best model loaded for fold %d", fold_idx)
    return best_val_loss


def _classification_metrics(y_true, y_scores, threshold=0.5):
    """Return ``(precision, recall, auc)`` for binary labels and probability scores.

    AUC is NaN when ``roc_auc_score`` raises ``ValueError`` (e.g. only one class
    is present in ``y_true``).
    """
    y_pred = [1 if p >= threshold else 0 for p in y_scores]
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    try:
        auc_val = roc_auc_score(y_true, y_scores)
    except ValueError:
        auc_val = float("nan")
    return precision, recall, auc_val


def main():
    """Run k-fold cross-validated attention-MIL training on pre-concatenated bags.

    For each fold, the function:

    - loads train/tune/test CSVs from ``data_root/fold-{i}``;
    - trains a fresh model with early stopping on the tune split;
    - evaluates the best-validation weights on the test split.

    Precision, recall and AUC are averaged across folds; folds with an
    undefined AUC are excluded from the AUC average.
    """
    import math

    logging.info("Script started")

    # Hardcoded configuration
    use_cuda = False
    device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
    logging.info("Using device: %s", device)
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    pin_memory = device.type == "cuda"

    data_root = "/data/temporary/projects/mutation-prediction/csvs/3-folds-prostate-only"
    #breast_dir = "/data/temporary/amirhosein/modelscript_breast/pt_files"
    #bladder_dir = "/data/temporary/amirhosein/model_script_bladder/pt_files"
    prostate_dir = "/data/temporary/amirhosein/model_script_concatenated"

    model_type = "attention"  # or "gated_attention"
    input_dim = 1024
    hidden_dim = 500
    attn_dim = 128
    attn_branches = 1
    epochs = 10
    lr = 0.0005
    weight_decay = 1e-5
    patience = 7  # early stopping patience
    num_folds = 3

    embedding_dirs = prostate_dir

    fold_metrics = {'precision': [], 'recall': [], 'auc': []}

    for fold_idx in range(num_folds):
        logging.info("=== Starting FOLD %d ===", fold_idx)
        fold_path = os.path.join(data_root, f"fold-{fold_idx}")

        train_csv = os.path.join(fold_path, "train.csv")
        val_csv = os.path.join(fold_path, "tune.csv")
        test_csv = os.path.join(fold_path, "test.csv")

        logging.info("Loading datasets for fold %d", fold_idx)
        train_dataset = PreConcatenatedDataset(embedding_dirs, train_csv)
        val_dataset = PreConcatenatedDataset(embedding_dirs, val_csv)
        test_dataset = PreConcatenatedDataset(embedding_dirs, test_csv)

        train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=pin_memory)
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=pin_memory)
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, pin_memory=pin_memory)

        logging.info("Initializing model for fold %d", fold_idx)
        model = _build_model(model_type, input_dim, hidden_dim, attn_dim, attn_branches)
        model.to(device)

        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        _fit_with_early_stopping(
            model, train_loader, val_loader, optimizer, device,
            epochs=epochs, patience=patience, fold_idx=fold_idx,
        )

        logging.info("Starting testing for fold %d", fold_idx)
        test_loss, test_error, y_true, y_scores = evaluate(model, test_loader, device)
        logging.info("[Fold %d] Testing complete. Loss: %.4f, Error: %.4f", fold_idx, test_loss, test_error)

        precision, recall, auc_val = _classification_metrics(y_true, y_scores)
        logging.info("[Fold %d] Metrics - Precision: %.4f, Recall: %.4f, AUC: %.4f", fold_idx, precision, recall, auc_val)

        fold_metrics['precision'].append(precision)
        fold_metrics['recall'].append(recall)
        fold_metrics['auc'].append(auc_val)

    # Average over the folds that actually ran; ignore undefined (NaN) AUCs.
    avg_precision = sum(fold_metrics['precision']) / len(fold_metrics['precision'])
    avg_recall = sum(fold_metrics['recall']) / len(fold_metrics['recall'])
    valid_aucs = [a for a in fold_metrics['auc'] if not math.isnan(a)]
    avg_auc = sum(valid_aucs) / len(valid_aucs) if valid_aucs else float('nan')

    logging.info("=== 3-FOLD CROSS-VALIDATION RESULTS ===")
    logging.info("Average Precision: %.4f", avg_precision)
    logging.info("Average Recall:    %.4f", avg_recall)
    logging.info("Average AUC:       %.4f", avg_auc)
    logging.info("Script finished successfully")

if __name__ == "__main__":
    # Run the cross-validation experiment when executed as a script (in a
    # notebook cell, __name__ is "__main__", so it runs when the cell runs).
    main()


2025-02-27 09:04:57 INFO: Script started
2025-02-27 09:04:57 INFO: Using device: cpu
2025-02-27 09:04:57 INFO: === Starting FOLD 0 ===
2025-02-27 09:04:57 INFO: Loading datasets for fold 0
2025-02-27 09:04:57 INFO: Reading CSV file: /data/temporary/projects/mutation-prediction/csvs/3-folds-prostate-only/fold-0/train.csv
2025-02-27 09:04:57 INFO: CSV loaded. Total rows: 211
2025-02-27 09:05:00 INFO: Dataset initialized with 208 valid samples.
2025-02-27 09:05:00 INFO: Reading CSV file: /data/temporary/projects/mutation-prediction/csvs/3-folds-prostate-only/fold-0/tune.csv
2025-02-27 09:05:00 INFO: CSV loaded. Total rows: 53
2025-02-27 09:05:01 INFO: Dataset initialized with 53 valid samples.
2025-02-27 09:05:01 INFO: Reading CSV file: /data/temporary/projects/mutation-prediction/csvs/3-folds-prostate-only/fold-0/test.csv
2025-02-27 09:05:01 INFO: CSV loaded. Total rows: 133
2025-02-27 09:05:02 INFO: Dataset initialized with 131 valid samples.
2025-02-27 09:05:02 INFO: Initializing model

KeyboardInterrupt: 