In [1]:
# %% [markdown]
# # TextBiomarkerModel Training Example
#
# This notebook demonstrates how to train the `TextBiomarkerModel` using the `BiomarkerTrainer`.
# We'll use dummy data for this example.

# %%
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import json
import logging # Added for trainer logging

# Assuming 'biomarkers' is installed or on the python path
from biomarkers.models.text import TextBiomarkerModel
from biomarkers.core.base import BiomarkerConfig
from biomarkers.training.trainer import BiomarkerTrainer # Assuming trainer is in biomarkers.training

print(f"PyTorch version: {torch.__version__}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Configure logging (optional, but helpful for trainer)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)




PyTorch version: 2.9.0+cu130
Using device: cuda


In [3]:

# %% [markdown]
# ## 1. Configuration
#
# Define the configuration for the model and training.

# %%
config_dict = {
    # --- Base Model Keys ---
    'modality': 'text',
    'model_type': 'TextBiomarkerModel',
    'hidden_dim': 128,
    'num_diseases': 5, # Number of output classes
    'dropout': 0.2,
    'device': device,
    # --- TextBiomarkerModel Specific Keys ---
    'embedding_dim': 768,
    'max_seq_length': 128,
    'use_pretrained': False,
    'biomarker_feature_dim': 50,
    # --- Trainer Keys ---
    'learning_rate': 1e-4,
    'weight_decay': 0.01,
    'gradient_clip': 1.0,
    'epochs': 3, # Keep epochs low for a quick example
    'batch_size': 8
}

# %% [markdown]
# ## 2. Dummy Data and DataLoader
#
# Create a simple dummy dataset and dataloader. In a real application, you would use your `MultiModalBiomarkerDataset` and `create_dataloaders`.

# %%
class DummyTextDataset(Dataset):
    def __init__(self, num_samples, seq_len, embedding_dim, num_classes):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate random embeddings
        # The trainer will move these to the correct device
        embeddings = torch.randn(self.seq_len, self.embedding_dim)
        # Generate a random target class
        target = torch.randint(0, self.num_classes, (1,)).squeeze()

        # Minimal batch structure expected by trainer
        # Note: We are only providing 'input' and 'target'
        # If your model *requires* metadata even during training loss calculation,
        # you'll need to adapt this dummy dataset or the model.
        # The current TextBiomarkerModel forward pass uses metadata for biomarker extraction,
        # but the core classification loss only depends on 'logits' and 'target'.
        return {'input': embeddings, 'target': target}


In [4]:

# Dataset parameters
num_train_samples = 100
num_val_samples = 20
seq_len = config_dict['max_seq_length']
embedding_dim = config_dict['embedding_dim']
num_classes = config_dict['num_diseases']

# Create Datasets
train_dataset = DummyTextDataset(num_train_samples, seq_len, embedding_dim, num_classes)
val_dataset = DummyTextDataset(num_val_samples, seq_len, embedding_dim, num_classes)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config_dict['batch_size'],
    shuffle=True,
    num_workers=0 # Set num_workers > 0 for actual data loading
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config_dict['batch_size'],
    shuffle=False,
    num_workers=0
)

print(f"Created dummy train DataLoader with {len(train_loader)} batches.")
print(f"Created dummy validation DataLoader with {len(val_loader)} batches.")


Created dummy train DataLoader with 13 batches.
Created dummy validation DataLoader with 3 batches.


In [5]:


# %% [markdown]
# ## 3. Model Initialization
#
# Instantiate the `TextBiomarkerModel`.

# %%
try:
    model = TextBiomarkerModel(config_dict)
    print("Model initialized successfully!")
    print(f"Model placed on device: {next(model.parameters()).device}")
except Exception as e:
    logger.error(f"Error initializing model: {e}")
    raise e


Model initialized successfully!
Model placed on device: cuda:0


In [6]:

# %% [markdown]
# ## 4. Trainer Initialization
#
# Instantiate the `BiomarkerTrainer`. We'll disable `wandb` logging for this simple example.

# %%
try:
    trainer = BiomarkerTrainer(
        model=model,
        config=config_dict,
        use_wandb=False # Disable wandb for this local example
    )
    print("Trainer initialized successfully!")
except Exception as e:
    logger.error(f"Error initializing trainer: {e}")
    raise e


Trainer initialized successfully!


In [7]:

# %% [markdown]
# ## 5. Training Loop
#
# Run the training using the `trainer.fit()` method.

# %%
num_epochs = config_dict['epochs']
save_directory = "./biomarker_training_output"

print(f"\nStarting training for {num_epochs} epochs...")

try:
    trainer.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=num_epochs,
        save_dir=save_directory
    )
    print("\nTraining finished!")
except Exception as e:
    logger.error(f"Error during training: {e}")
    raise e



Starting training for 3 epochs...


Epoch 0: 100%|██████████| 13/13 [00:00<00:00, 21.91it/s, loss=1.62, acc=0.183]
Evaluating: 100%|██████████| 3/3 [00:00<00:00, 80.92it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



Epoch 0/3
Train Loss: 1.6191, Train Acc: 0.1827
Val Acc: 0.2500, Val F1: 0.0800
Saved best model with F1 score: 0.0800


Epoch 1: 100%|██████████| 13/13 [00:00<00:00, 38.38it/s, loss=1.57, acc=0.163]
Evaluating: 100%|██████████| 3/3 [00:00<00:00, 89.21it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



Epoch 1/3
Train Loss: 1.6038, Train Acc: 0.1635
Val Acc: 0.2000, Val F1: 0.0667


Epoch 2: 100%|██████████| 13/13 [00:00<00:00, 38.31it/s, loss=1.64, acc=0.154]
Evaluating: 100%|██████████| 3/3 [00:00<00:00, 88.26it/s]


Epoch 2/3
Train Loss: 1.6157, Train Acc: 0.1538
Val Acc: 0.1000, Val F1: 0.0364

Training finished!



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [8]:

# %% [markdown]
# ## 6. Post-Training (Optional)
#
# After training, you can load the best model saved by the trainer and perform evaluation or inference.

# %%
best_model_path = Path(save_directory) / 'best_model.pt'

if best_model_path.exists():
    print(f"\nLoading best model from: {best_model_path}")
    try:
        # Load checkpoint (contains model state, config, etc.)
        checkpoint = torch.load(best_model_path, map_location=device)

        # Re-initialize model with saved config
        loaded_model_config = checkpoint['config']
        loaded_model = TextBiomarkerModel(loaded_model_config)

        # Load state dict
        loaded_model.load_state_dict(checkpoint['model_state_dict'])
        loaded_model.eval() # Set to evaluation mode

        print("Best model loaded successfully.")
        print(f"Model was saved at epoch {checkpoint.get('epoch', 'N/A')}")
        print(f"Validation metrics at save: {checkpoint.get('metrics', {})}")

        # Example: Evaluate the loaded model on the validation set
        print("\nEvaluating loaded model on validation set...")
        # Need a new trainer instance for evaluation if desired, or call evaluate directly
        eval_trainer = BiomarkerTrainer(loaded_model, loaded_model_config, use_wandb=False)
        final_val_metrics = eval_trainer.evaluate(val_loader, clinical_metrics=False) # Disable clinical for dummy data
        print("Evaluation results:")
        print(json.dumps(final_val_metrics, indent=2))

    except Exception as e:
        logger.error(f"Error loading or evaluating best model: {e}")
else:
    print(f"\nBest model checkpoint not found at {best_model_path}. Training might have failed or not saved.")


# %% [markdown]
# ---
# End of Training Example
# ---


Loading best model from: biomarker_training_output/best_model.pt
Best model loaded successfully.
Model was saved at epoch 0
Validation metrics at save: {'accuracy': 0.25, 'precision': 0.05, 'recall': 0.2, 'f1_score': 0.08, 'auc': 0.3823249299719888, 'disease_0_sensitivity': 0.0, 'disease_0_specificity': 1.0, 'disease_0_ppv': 0, 'disease_0_npv': 0.7, 'disease_1_sensitivity': 1.0, 'disease_1_specificity': 0.0, 'disease_1_ppv': 0.25, 'disease_1_npv': 0, 'disease_2_sensitivity': 0.0, 'disease_2_specificity': 1.0, 'disease_2_ppv': 0, 'disease_2_npv': 0.85, 'disease_3_sensitivity': 0.0, 'disease_3_specificity': 1.0, 'disease_3_ppv': 0, 'disease_3_npv': 0.75, 'disease_4_sensitivity': 0.0, 'disease_4_specificity': 1.0, 'disease_4_ppv': 0, 'disease_4_npv': 0.95, 'mean_uncertainty': 0.9767093658447266}

Evaluating loaded model on validation set...


Evaluating: 100%|██████████| 3/3 [00:00<00:00, 79.67it/s]

Evaluation results:
{
  "accuracy": 0.3,
  "precision": 0.06,
  "recall": 0.2,
  "f1_score": 0.09230769230769231,
  "auc": 0.46050420168067224,
  "mean_uncertainty": 0.9760211706161499
}



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
