In [1]:
import torch
import numpy as np
from pathlib import Path
import json

# Assuming 'biomarkers' is installed or on the python path
from biomarkers.models.text import TextBiomarkerModel
from biomarkers.core.base import BiomarkerConfig

print(f"PyTorch version: {torch.__version__}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")



PyTorch version: 2.9.0+cu130
Using device: cuda


In [2]:
config_dict = {
    # --- Base Keys ---
    'modality': 'text',
    'model_type': 'TextBiomarkerModel', # Must match the class name potentially
    'hidden_dim': 128,          # Smaller hidden dim for faster example
    'num_diseases': 5,
    'dropout': 0.1,
    'device': device,           # Use detected device
    # --- Custom Keys (will go into metadata) ---
    'embedding_dim': 768,       # Dimension of input embeddings
    'max_seq_length': 128,
    'use_pretrained': False,    # Not loading a real pretrained model here
    'pretrained_model': 'bert-base-uncased',
    'biomarker_feature_dim': 50 # Dimension for aggregated biomarker vector
}

In [3]:
try:
    model = TextBiomarkerModel(config_dict)
    print("Model initialized successfully!")
    print(f"Model placed on device: {next(model.parameters()).device}")
    print(f"Number of parameters: {model.num_parameters:,}")
    # print("\nModel Config:")
    # print(model.config) 
except Exception as e:
    print(f"Error initializing model: {e}")
    # If initialization fails, stop the notebook execution for this path
    raise e

Model initialized successfully!
Model placed on device: cuda:0
Number of parameters: 11,403,355


In [4]:
model.eval() 
print("\nModel set to evaluation mode.")


Model set to evaluation mode.


In [5]:
batch_size = 4
seq_len = 100 # Should be <= max_seq_length, but model handles variable length
embedding_dim = config_dict['embedding_dim']

# Generate random embeddings tensor
# The input tensor can be on CPU; the model's forward pass will move it to the correct device.
dummy_embeddings = torch.randn(batch_size, seq_len, embedding_dim)

print(f"Dummy embeddings shape: {dummy_embeddings.shape}")

Dummy embeddings shape: torch.Size([4, 100, 768])


In [6]:
dummy_metadata = {
    'tokens': [
        ['this', 'is', 'sample', 'one', '.', 'it', 'has', 'um', 'pauses', '.'] * (seq_len // 10) 
        for _ in range(batch_size)
    ],
    'pos_tags': [
        ['DT', 'VBZ', 'NN', 'CD', '.', 'PRP', 'VBZ', 'UH', 'NNS', '.'] * (seq_len // 10)
        for _ in range(batch_size)
    ],
    'parse_trees': [
        '(S (NP (DT this)) (VP (VBZ is) (NP (NN sample) (CD one))))' 
        for _ in range(batch_size)
    ],
    'timestamps': [
        np.linspace(0, seq_len * 0.5, seq_len) 
        for _ in range(batch_size)
     ],
     'content_words': [
         ['sample', 'one', 'pauses'] * (seq_len // 10)
         for _ in range(batch_size)
     ]
}

print(f"Metadata generated for {len(dummy_metadata['tokens'])} samples.")

Metadata generated for 4 samples.


In [7]:
with torch.no_grad(): # Disable gradient calculation for inference
    output = model(
        dummy_embeddings, 
        text_metadata=dummy_metadata,
        return_biomarkers=True,
        return_uncertainty=True,
        return_clinical=True,
        return_cognitive=True
    )

print("Forward pass completed.")
print("\nOutput keys:", output.keys())

# %% [markdown]
# ## 5. Examine Outputs

# %%
print("\n--- Basic Outputs ---")
print(f"Logits shape: {output['logits'].shape}") # (batch_size, num_diseases)
print(f"Probabilities shape: {output['probabilities'].shape}") # (batch_size, num_diseases)
print(f"Predictions shape: {output['predictions'].shape}") # (batch_size,)
print(f"Predictions: {output['predictions']}")
print(f"Features (mean-pooled input) shape: {output['features'].shape}") # (batch_size, embedding_dim)
print(f"Linguistic Pattern Probs shape: {output['linguistic_pattern'].shape}") # (batch_size, 8)

Forward pass completed.

Output keys: dict_keys(['logits', 'probabilities', 'features', 'predictions', 'linguistic_pattern', 'biomarkers', 'cognitive_scores', 'clinical_scores', 'change_prediction', 'uncertainty', 'confidence'])

--- Basic Outputs ---
Logits shape: torch.Size([4, 5])
Probabilities shape: torch.Size([4, 5])
Predictions shape: torch.Size([4])
Predictions: tensor([4, 4, 4, 4], device='cuda:0')
Features (mean-pooled input) shape: torch.Size([4, 768])
Linguistic Pattern Probs shape: torch.Size([4, 8])


In [8]:
print("\n--- Biomarkers ---")
biomarkers = output['biomarkers']
print(f"Number of biomarkers extracted: {len(biomarkers)}")
# Print a few example biomarker values for the first sample
print("Examples (first sample):")
for i, (key, value) in enumerate(biomarkers.items()):
    if i < 5: # Print first 5
         # Move to CPU for printing if needed
        val_cpu = value[0].cpu().numpy() if isinstance(value, torch.Tensor) else value[0]
        print(f"  - {key}: {val_cpu:.4f} (shape: {value.shape})")


--- Biomarkers ---
Number of biomarkers extracted: 29
Examples (first sample):
  - lexical_diversity: 0.4745 (shape: torch.Size([4]))
  - vocabulary_richness: 0.5013 (shape: torch.Size([4]))
  - semantic_diversity: 0.4974 (shape: torch.Size([4]))
  - lexical_density: 0.3000 (shape: torch.Size([4]))
  - syntactic_complexity: 0.5299 (shape: torch.Size([4]))


In [10]:
print("\n--- Cognitive Scores ---")
cognitive_scores = output['cognitive_scores']
print(f"Cognitive scores keys: {cognitive_scores.keys()}")
print(f"MMSE scores shape: {cognitive_scores['MMSE'].shape}") 
print(f"MMSE scores (first sample): {cognitive_scores['MMSE'][0].item():.2f}")
print(f"MoCA scores shape: {cognitive_scores['MoCA'].shape}")
print(f"MoCA scores (first sample): {cognitive_scores['MoCA'][0].item():.2f}")
print(f"CDR scores shape: {cognitive_scores['CDR'].shape}") # (batch_size, 5)
print(f"CDR scores (first sample): {cognitive_scores['CDR'][0].cpu().numpy()}")
print("\n--- Cognitive Scores ---")
cognitive_scores = output['cognitive_scores']
print(f"Cognitive scores keys: {cognitive_scores.keys()}")
print(f"MMSE scores shape: {cognitive_scores['MMSE'].shape}") 
print(f"MMSE scores (first sample): {cognitive_scores['MMSE'][0].item():.2f}")
print(f"MoCA scores shape: {cognitive_scores['MoCA'].shape}")
print(f"MoCA scores (first sample): {cognitive_scores['MoCA'][0].item():.2f}")
print(f"CDR scores shape: {cognitive_scores['CDR'].shape}") # (batch_size, 5)
print(f"CDR scores (first sample): {cognitive_scores['CDR'][0].cpu().numpy()}")


--- Cognitive Scores ---
Cognitive scores keys: dict_keys(['MMSE', 'MoCA', 'CDR'])
MMSE scores shape: torch.Size([4])
MMSE scores (first sample): 14.93
MoCA scores shape: torch.Size([4])
MoCA scores (first sample): 15.02
CDR scores shape: torch.Size([4, 5])
CDR scores (first sample): [0.18965001 0.20985328 0.2053401  0.2029706  0.19218606]

--- Cognitive Scores ---
Cognitive scores keys: dict_keys(['MMSE', 'MoCA', 'CDR'])
MMSE scores shape: torch.Size([4])
MMSE scores (first sample): 14.93
MoCA scores shape: torch.Size([4])
MoCA scores (first sample): 15.02
CDR scores shape: torch.Size([4, 5])
CDR scores (first sample): [0.18965001 0.20985328 0.2053401  0.2029706  0.19218606]


In [11]:
print("\n--- Clinical Scores ---")
clinical_scores = output['clinical_scores']
print(f"Clinical scores keys: {clinical_scores.keys()}")
print(f"Language Severity shape: {clinical_scores['language_severity'].shape}")
print(f"Language Severity (first sample): {clinical_scores['language_severity'][0].item():.2f}")
print(f"Decline Rate shape: {clinical_scores['decline_rate'].shape}")
print(f"Decline Rate (first sample): {clinical_scores['decline_rate'][0].item():.4f}")
print(f"Communication Eff. shape: {clinical_scores['communication_effectiveness'].shape}")
print(f"Communication Eff. (first sample): {clinical_scores['communication_effectiveness'][0].item():.2f}")


--- Clinical Scores ---
Clinical scores keys: dict_keys(['language_severity', 'decline_rate', 'communication_effectiveness'])
Language Severity shape: torch.Size([4])
Language Severity (first sample): 2.40
Decline Rate shape: torch.Size([4])
Decline Rate (first sample): 0.5174
Communication Eff. shape: torch.Size([4])
Communication Eff. (first sample): 4.93


In [12]:
print("\n--- Uncertainty ---")
print(f"Uncertainty shape: {output['uncertainty'].shape}") # (batch_size, num_diseases)
print(f"Confidence shape: {output['confidence'].shape}") # (batch_size, num_diseases)
# Confidence for the predicted class for the first sample
predicted_class_0 = output['predictions'][0].item()
print(f"Confidence for predicted class ({predicted_class_0}) (first sample): {output['confidence'][0, predicted_class_0].item():.4f}")

# %%



--- Uncertainty ---
Uncertainty shape: torch.Size([4, 5])
Confidence shape: torch.Size([4, 5])
Confidence for predicted class (4) (first sample): 0.4830


In [13]:
print("\n--- Change Prediction ---")
print(f"Change Prediction shape: {output['change_prediction'].shape}") # (batch_size, 5)
print(f"Change Prediction (first sample): {output['change_prediction'][0].cpu().numpy()}")


--- Change Prediction ---
Change Prediction shape: torch.Size([4, 5])
Change Prediction (first sample): [0.17957321 0.19120947 0.22037056 0.21395439 0.19489242]


In [15]:
with torch.no_grad():
    predictions_output = model.predict(
        dummy_embeddings, 
        # text_metadata=dummy_metadata, 
        return_confidence=True
    )

print("\n--- Output from model.predict() ---")
print("Keys:", predictions_output.keys())
print(f"Predictions: {predictions_output['predictions']}")
print(f"Confidence shape: {predictions_output['confidence'].shape}")


--- Output from model.predict() ---
Keys: dict_keys(['logits', 'probabilities', 'features', 'predictions', 'linguistic_pattern', 'biomarkers', 'cognitive_scores', 'clinical_scores', 'change_prediction', 'uncertainty', 'confidence'])
Predictions: tensor([4, 4, 4, 4], device='cuda:0')
Confidence shape: torch.Size([4, 5])


In [17]:
with torch.no_grad():
    biomarkers_only = model.get_biomarkers(dummy_embeddings)

print("\n--- Output from model.get_biomarkers() ---")
print(f"Number of biomarkers: {len(biomarkers_only)}")
print("Keys:", list(biomarkers_only.keys())[:10], "...") # Show first 10 keys


--- Output from model.get_biomarkers() ---
Number of biomarkers: 29
Keys: ['lexical_diversity', 'vocabulary_richness', 'semantic_diversity', 'lexical_density', 'syntactic_complexity', 'subordination_index', 'dependency_distance', 'grammar_accuracy', 'semantic_coherence', 'topic_consistency'] ...


In [19]:
with torch.no_grad():
    biomarkers_sample_0 = model.extract_biomarkers(dummy_embeddings[0:1], {k: [v[0]] for k, v in dummy_metadata.items() if isinstance(v, list)})

interpretation = model.get_clinical_interpretation(biomarkers_sample_0)
print(json.dumps(interpretation, indent=2))


{
  "linguistic_decline_syndrome": {
    "detected": true,
    "severity": "moderate",
    "composite_score": "0.50",
    "clinical_note": "Multiple linguistic decline markers detected. Comprehensive cognitive evaluation recommended."
  }
}


In [20]:

# %%
# Compare to normative data (using biomarkers for the first sample)
print("\n--- Normative Comparison (First Sample, Age 65, Edu 12) ---")
comparison = model.compare_to_normative_data(biomarkers_sample_0, age=65, education=12)
print(json.dumps(comparison, indent=2))




--- Normative Comparison (First Sample, Age 65, Edu 12) ---
{
  "lexical_diversity": {
    "value": 0.4745039939880371,
    "normative_mean": 0.6,
    "z_score": -1.0457999629496937,
    "interpretation": "mildly impaired"
  },
  "syntactic_complexity": {
    "value": 0.5299345254898071,
    "normative_mean": 0.5499999999999999,
    "z_score": -0.13376982114996397,
    "interpretation": "average"
  },
  "semantic_coherence": {
    "value": 0.5189153552055359,
    "normative_mean": 0.7,
    "z_score": -1.810846266860014,
    "interpretation": "moderately impaired"
  },
  "idea_density": {
    "value": 0.47677335143089294,
    "normative_mean": 0.5,
    "z_score": -0.1935553886129431,
    "interpretation": "average"
  },
  "information_content": {
    "value": 0.4920281767845154,
    "normative_mean": 0.6499999999999999,
    "z_score": -1.215167777413898,
    "interpretation": "mildly impaired"
  }
}


In [21]:

# %% [markdown]
# ## 8. Longitudinal Trajectory Prediction (Example)

# %%
# Create dummy data for 3 time points
num_time_points = 3
longitudinal_embeddings = [torch.randn(1, seq_len, embedding_dim) for _ in range(num_time_points)]
time_points_months = [0.0, 6.0, 12.0] # Example: Baseline, 6 months, 12 months

print(f"\n--- Predicting Trajectory over {num_time_points} time points ---")
trajectory_output = model.predict_cognitive_trajectory(longitudinal_embeddings, time_points_months)

print(f"Trajectory Class: {trajectory_output.get('trajectory_class', 'N/A')}")
print(f"Recommendation: {trajectory_output.get('recommendation', 'N/A')}")
# print("Metric Trends:")
# print(json.dumps(trajectory_output.get('metric_trends', {}), indent=2))




--- Predicting Trajectory over 3 time points ---
Trajectory Class: stable
Recommendation: Linguistic abilities remain stable. Continue routine monitoring as appropriate for age and risk factors.


In [22]:

# %% [markdown]
# ## 9. Save and Load Model
# 
# Use the `save` and `load` methods inherited from `BiomarkerModel`.

# %%
save_dir = Path("./temp_model_save")
save_dir.mkdir(exist_ok=True)
model_path = save_dir / "text_biomarker_model.pth"

# Save the model
model.save(model_path)
print(f"\nModel saved to {model_path}")

# %%
# Load the model
try:
    loaded_model = TextBiomarkerModel.load(model_path)
    loaded_model.eval() # Set to eval mode after loading
    print("\nModel loaded successfully!")
    print(f"Loaded model device: {next(loaded_model.parameters()).device}")

    # Verify consistency (optional)
    with torch.no_grad():
        output_original = model(dummy_embeddings)
        # Ensure input is on the correct device for the loaded model
        output_loaded = loaded_model(dummy_embeddings.to(loaded_model.device)) 
    
    # Compare logits (move loaded output back to CPU if necessary for comparison)
    logits_match = torch.allclose(
        output_original['logits'].cpu(), 
        output_loaded['logits'].cpu(), 
        atol=1e-6 # Add tolerance for potential float precision differences
    )
    print(f"Logits match after loading: {logits_match}")
    assert logits_match

except Exception as e:
    print(f"Error loading model: {e}")


Model saved to temp_model_save/text_biomarker_model.pth

Model loaded successfully!
Loaded model device: cuda:0
Logits match after loading: True
