In [1]:
%load_ext autoreload
%autoreload 2

# ALIAS Training Pipeline Demo

This notebook demonstrates the complete workflow from loading scRNA-seq data to training a sentence-transformer model.

## Pipeline Overview

1. **Load Data**: Load `.h5ad` file containing scRNA-seq data
2. **Configure**: Set up configurations for data processing
3. **Build Datasets**: Convert scRNA data to text-based datasets
4. **Generate Triplets**: Create training pairs for contrastive learning
5. **Train Model**: Fine-tune a sentence transformer model

## Prerequisites

**Option A:** Place your own `.h5ad` file at `data/demo.h5ad`  
**Option B:** Use scanpy's built-in PBMC dataset (no file needed!) 👈 **Recommended for demo**


In [2]:
# Imports
import scanpy as sc
import numpy as np
from pathlib import Path

# ALIAS imports
from alias.data import (
    DatascRNAConfig,
    TripletGenerationConfig,
    build_datasets,
    build_triplets
)
from alias.model import TrainingSTConfig, train_model




## Step 1: Load scRNA-seq Data

You have two options:
- **Option A**: Load your own `.h5ad` file
- **Option B**: Use scanpy's PBMC 3k dataset (2,638 cells, 8 cell types) ← **Easiest!**

The data needs:
- Cell type annotations in `obs`
- Gene expression matrix in `X`
- Gene names in `var`
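
Before building datasets, it can help to check these requirements up front. The sketch below is a hypothetical helper (`check_inputs` is not part of ALIAS) that relies only on attribute access, so it should work on an `AnnData` object as well as on the small stand-in used here. The default threshold of 5 cells per label is an assumption taken from the "fewer than 5 cells" message in the preprocessing log of this notebook.

```python
from collections import Counter
from types import SimpleNamespace


def check_inputs(adata, label_key, min_cells_per_label=5):
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    if adata.X is None:
        problems.append("expression matrix `X` is missing")
    if len(adata.var_names) == 0:
        problems.append("no gene names found in `var`")
    if label_key not in adata.obs:
        problems.append(f"annotation column {label_key!r} not found in `obs`")
        return problems
    # Count cells per label and report labels that are too rare to be useful.
    counts = Counter(adata.obs[label_key])
    rare = sorted(label for label, n in counts.items() if n < min_cells_per_label)
    if rare:
        problems.append(f"labels with fewer than {min_cells_per_label} cells: {rare}")
    return problems


# Tiny stand-in object so the sketch runs without scanpy or a data file.
toy = SimpleNamespace(
    X=[[1, 0], [0, 2], [3, 1]],
    var_names=["MALAT1", "LYZ"],
    obs={"cell_type": ["T cell", "T cell", "Platelet"]},
)
print(check_inputs(toy, "cell_type", min_cells_per_label=2))
print(check_inputs(toy, "missing_column"))
```

With real data, the call would look like `check_inputs(adata, label_key)`; an empty list suggests the three requirements above are met.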

In [3]:
# Option A: Load your own .h5ad file (adjust the path and label key to match your data)
data_path = Path("../data/HIHA_sub1000.h5ad")
adata = sc.read_h5ad(data_path)

label_key = "AIFI_L1"

print(f"\n✓ Loaded AnnData: {adata.shape[0]} cells x {adata.shape[1]} genes")
print(f"Available obs columns: {list(adata.obs.columns)}")
print(f"\nCell type distribution:")
print(adata.obs[label_key].value_counts())



✓ Loaded AnnData: 1000 cells x 33538 genes
Available obs columns: ['cohort.cohortGuid', 'sample.sampleKitGuid', 'specimen.specimenGuid', 'pipeline.fileGuid', 'subject.subjectGuid', 'subject.biologicalSex', 'subject.birthYear', 'subject.ageAtFirstDraw', 'subject.ageGroup', 'subject.race', 'subject.ethnicity', 'subject.cmv', 'subject.bmi', 'sample.visitName', 'sample.drawYear', 'sample.subjectAgeAtDraw', 'batch_id', 'pool_id', 'chip_id', 'well_id', 'barcodes', 'original_barcodes', 'cell_name', 'n_reads', 'n_umis', 'n_genes', 'total_counts_mito', 'pct_counts_mito', 'doublet_score', 'AIFI_L1', 'AIFI_L2', 'AIFI_L3', 'original_AIFI_L1', 'original_AIFI_L2']

Cell type distribution:
AIFI_L1
T cell                   628
Monocyte                 184
B cell                    94
NK cell                   74
dendritic cell, human     15
Platelet                   2
Progenitor cell            2
Erythrocyte                1
Name: count, dtype: int64


## Step 2: Configure Data Processing

Set up configuration for scRNA dataset generation:
- `annotation_column`: Column in `obs` containing cell type labels
- `preprocessing`: Whether to run scanpy preprocessing
- `test_size`: Fraction of data to use for testing
- `cs_length`: Length of cell sentences (number of genes)
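
To make `test_size` concrete, here is a minimal sketch of how a fractional hold-out translates into split sizes. It assumes the test set is rounded up, which is the convention scikit-learn's `train_test_split` uses; whether ALIAS splits the same way internally is an assumption. The helper name `split_sizes` is hypothetical.

```python
import math


def split_sizes(n_cells, test_size):
    """Return (n_train, n_test) for a fractional hold-out, rounding the test set up."""
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be a fraction strictly between 0 and 1")
    n_test = math.ceil(n_cells * test_size)
    return n_cells - n_test, n_test


# 995 cells remain after preprocessing in this notebook's run.
print(split_sizes(995, 0.1))
```

Under that assumption, 995 cells with `test_size=0.1` give 895 training and 100 test cells, which agrees with the split sizes reported in Step 3.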


In [4]:
# Create scRNA data configuration
scrna_config = DatascRNAConfig(
    annotation_column=label_key,  # Change this to match your data
    preprocessing=True,  # Set to True if data needs preprocessing
    test_size=0.1,  # 10% for testing
    cs_length=(10, 20),  # Create cell sentences with 10 and 20 genes
    highly_variable_genes=True,
    semantic=True,  # Use semantic templates
    random_state=42
)

print("scRNA Configuration:")
print(f"  - Annotation column: {scrna_config.annotation_column}")
print(f"  - Test size: {scrna_config.test_size}")
print(f"  - Cell sentence lengths: {scrna_config.cs_length}")


scRNA Configuration:
  - Annotation column: AIFI_L1
  - Test size: 0.1
  - Cell sentence lengths: (10, 20)


## Step 3: Build Datasets

Convert the AnnData object into HuggingFace Dataset format.
This creates "cell sentences" - text representations of gene expression.
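
As a rough illustration of the idea, the sketch below ranks genes by expression and renders the top few as text. This is an assumption about the general technique, not the ALIAS implementation: the function `cell_sentence` is hypothetical, and the template only reuses the "Leading genes expressed:" prefix that appears in the pair output later in this notebook.

```python
def cell_sentence(expression, gene_names, n_genes,
                  template="Leading genes expressed: {genes}"):
    """Rank genes by expression (highest first) and render the top n_genes as text."""
    ranked = sorted(zip(gene_names, expression), key=lambda pair: pair[1], reverse=True)
    # Keep only genes that are actually expressed in this cell.
    top = [gene for gene, value in ranked[:n_genes] if value > 0]
    return template.format(genes=", ".join(top))


genes = ["CD74", "MALAT1", "LYZ", "ACTB", "CST3"]
counts = [3.1, 9.4, 0.0, 2.2, 5.0]
for n in (2, 4):  # two sentence lengths, analogous to cs_length=(10, 20)
    print(cell_sentence(counts, genes, n))
```

Producing sentences of more than one length for the same cell, as `cs_length=(10, 20)` requests, gives the model several textual views of one expression profile.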


In [5]:
# Build datasets from AnnData
dataset_dict, adata_test = build_datasets(
    adata=adata,
    datasets=['scrna'],  # Only scRNA for this demo
    scrna_config=scrna_config
)

print(f"\nDatasets created:")
for dataset_name, splits in dataset_dict.items():
    print(f"\n{dataset_name}:")
    for split_name, split_data in splits.items():
        print(f"  - {split_name}: {len(split_data)} samples")
        if len(split_data) > 0:
            print(f"    Features: {list(split_data.features.keys())}")

# Preview first example
print("\n" + "="*50)
print("Example cell sentence:")
print("="*50)
example = dataset_dict['scrna']['data'][0]
print(f"Sentence: {example['sentence1'][:100]}...")
print(f"Label: {example['label']}")


Preprocessing AnnData object...
Filtering cells with fewer than 100 expressed genes...
Filtering cells with a percentage of mitochondrial genes expressed over 15...
Filtering genes expressed in fewer than 5 cells...
Analyzing cell type sizes...
Initial number of celtypes: 8
Batches with fewer than 5 cells: 3
Filtering small batches...
Normalizing total counts per cell...
Performing log transformation...
Identifying 3000 highly variable genes...
Subsetting to highly variable genes...
Preprocessing complete. Your AnnData object is ready for further analysis.
AnnData object with n_obs × n_vars = 995 × 3000
    obs: 'cohort.cohortGuid', 'sample.sampleKitGuid', 'specimen.specimenGuid', 'pipeline.fileGuid', 'subject.subjectGuid', 'subject.biologicalSex', 'subject.birthYear', 'subject.ageAtFirstDraw', 'subject.ageGroup', 'subject.race', 'subject.ethnicity', 'subject.cmv', 'subject.bmi', 'sample.visitName', 'sample.drawYear', 'sample.subjectAgeAtDraw', 'batch_id', 'pool_id', 'chip_id', 'well

Building dataset: 100%|██████████| 995/995 [00:00<00:00, 33649.94it/s]
Map: 100%|██████████| 895/895 [00:00<00:00, 35486.15 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 24758.30 examples/s]


Datasets created:

scrna:
  - data: 895 samples
    Features: ['index', 'gene_list', 'AIFI_L1', 'label', 'sentence1']
  - test: 100 samples
    Features: ['index', 'gene_list', 'AIFI_L1', 'label', 'sentence1']

Example cell sentence:
Sentence: MALAT1, LYZ, HLA-DRA, ACTB, CD74, VIM, FTH1, CST3, HLA-DPA1, S100A4 expression places this cell in t...
Label: dendritic cell, human





## Step 4: Generate Triplets

Create training triplets for contrastive learning:
- Anchor: Original cell sentence
- Positive: Similar cell (same cell type)
- Negative: Dissimilar cell (different cell type)
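
The anchor/positive/negative scheme with random negatives can be sketched as follows. This is a simplified stand-in, not what `build_triplets` does internally: the function `sample_triplets` is hypothetical, it draws one positive and one negative per anchor, and it reuses the `sentence1` / `sentence2` / `negative` field names that the notebook reads from the generated triplets.

```python
import random
from collections import defaultdict


def sample_triplets(sentences, labels, seed=42):
    """Build one (anchor, positive, negative) triplet per sentence where possible."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    triplets = []
    for idx, label in enumerate(labels):
        positives = [j for j in by_label[label] if j != idx]
        negatives = [j for j, other in enumerate(labels) if other != label]
        if not positives or not negatives:
            continue  # a label with a single cell cannot supply a positive
        triplets.append({
            "sentence1": sentences[idx],
            "sentence2": sentences[rng.choice(positives)],
            "negative": sentences[rng.choice(negatives)],
        })
    return triplets


sentences = ["t1", "t2", "b1", "b2", "e1"]
labels = ["T cell", "T cell", "B cell", "B cell", "Erythrocyte"]
triplets = sample_triplets(sentences, labels)
print(len(triplets), triplets[0])
```

The singleton "Erythrocyte" entry yields no triplet, which illustrates why very rare cell types contribute little to contrastive training.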


In [6]:
# Configure triplet generation
triplet_config = TripletGenerationConfig(
    annotation_column=label_key,
    loss='MNR',  # MultipleNegativesRanking loss
    eval_split=0.1,  # 10% for evaluation
    random_negative_mining=True,  # Use random negatives
    hard_negative_mining=False,  # Set True for hard negative mining (slower)
    testrun=False,
    seed=42
)

print("Triplet Configuration:")
print(f"  - Loss type: {triplet_config.loss}")
print(f"  - Eval split: {triplet_config.eval_split}")
print(f"  - Random negative mining: {triplet_config.random_negative_mining}")
print(f"  - Hard negative mining: {triplet_config.hard_negative_mining}")


Triplet Configuration:
  - Loss type: MNR
  - Eval split: 0.1
  - Random negative mining: True
  - Hard negative mining: False
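
For readers unfamiliar with the `MNR` setting, the following sketch shows the core of a multiple-negatives ranking objective in plain Python: each anchor should score its own positive higher than every other positive in the batch. It is an illustration of the general technique, not the code that runs during training; `mnr_loss` is a hypothetical name, and `scale=20.0` mirrors the default of sentence-transformers' `MultipleNegativesRankingLoss`.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def mnr_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where positives[i] is the correct match for anchors[i]."""
    total = 0.0
    for i, anchor in enumerate(anchors):
        scores = [scale * cosine(anchor, candidate) for candidate in positives]
        # Numerically stable log-sum-exp over all candidates in the batch.
        top = max(scores)
        log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
        total += log_norm - scores[i]
    return total / len(anchors)


anchors = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[0.9, 0.1], [0.1, 0.9]]
swapped = [[0.1, 0.9], [0.9, 0.1]]
print(mnr_loss(anchors, aligned) < mnr_loss(anchors, swapped))
```

Because the other items in a batch act as negatives, larger batches generally give this kind of loss more contrast to learn from; explicit negatives, when present, serve as additional candidates.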


In [8]:
# Generate triplets
triplet_dict = build_triplets(
    dataset_dict=dataset_dict,
    triplet_config=triplet_config
)

print(f"\nTriplets created:")
for dataset_name, splits in triplet_dict.items():
    print(f"\n{dataset_name}:")
    for split_name, split_data in splits.items():
        print(f"  - {split_name}: {len(split_data)} triplets")
        if len(split_data) > 0:
            print(f"    Features: {list(split_data.features.keys())}")

# Preview a triplet
print("\n" + "="*50)
print("Example triplet:")
print("="*50)
example_triplet = triplet_dict['scrna']['train_MNR_rnm'][0]
print(f"Anchor: {example_triplet['sentence1'][:80]}...")
print(f"Positive: {example_triplet['sentence2'][:80]}...")
print(f"Negative: {example_triplet['negative'][:80]}...")


Building triplets for scrna...
Generating semantic training pairs for training with MultipleNgeativesRanking loss...
Processing scrna data...


Generating pairs by label: 100%|██████████| 5/5 [00:00<00:00, 1299.03it/s]


✅ Created 4475 total semantic pairs.
                                            sentence1  \
0   A B cell signature includes: MALAT1, IGKC, CD7...   
1   A B cell signature includes: MALAT1, IGKC, CD7...   
2   A B cell signature includes: MALAT1, IGKC, CD7...   
3   A B cell signature includes: MALAT1, IGKC, CD7...   
4   A B cell signature includes: MALAT1, IGKC, CD7...   
5   This profile resembles B cell cells, based on ...   
6   This profile resembles B cell cells, based on ...   
7   This profile resembles B cell cells, based on ...   
8   This profile resembles B cell cells, based on ...   
9   This profile resembles B cell cells, based on ...   
10  Leading genes expressed: MALAT1, CD74, HLA-DRA...   
11  Leading genes expressed: MALAT1, CD74, HLA-DRA...   
12  Leading genes expressed: MALAT1, CD74, HLA-DRA...   
13  Leading genes expressed: MALAT1, CD74, HLA-DRA...   
14  Leading genes expressed: MALAT1, CD74, HLA-DRA...   
15  Expression profile includes: MALAT1, IGKC, CD

## Step 5: Configure Training

Set up the model training configuration:
- Choose base model from HuggingFace
- Set training hyperparameters
- Choose where to save the model
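
When choosing hyperparameters, it is useful to estimate how many optimizer steps a run will take, since `warmup_steps` and `logging_steps` are expressed in steps. The sketch below is a back-of-the-envelope calculation with a hypothetical helper, `training_schedule`. It assumes that the 4475 pairs reported in Step 4 are counted before the evaluation split, that the last partial batch is kept, and that no gradient accumulation is used; none of this is confirmed by the ALIAS source.

```python
import math


def training_schedule(n_pairs, eval_split, batch_size, epochs, warmup_steps):
    """Estimate step counts for a run; all rounding choices are assumptions."""
    n_eval = math.ceil(n_pairs * eval_split)
    n_train = n_pairs - n_eval
    steps_per_epoch = math.ceil(n_train / batch_size)
    total_steps = steps_per_epoch * epochs
    return {
        "train_pairs": n_train,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_fraction": warmup_steps / total_steps,
    }


schedule = training_schedule(
    n_pairs=4475, eval_split=0.1, batch_size=32, epochs=2, warmup_steps=100
)
print(schedule)
```

Under these assumptions, a warmup of 100 steps covers a large share of a two-epoch demo run, so a smaller warmup or more epochs may be worth trying on small datasets.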


In [9]:
# Training configuration
training_config = TrainingSTConfig(
    model="neuml/pubmedbert-base-embeddings",  # Base model
    loss='MNR',  # Must match triplet_config.loss
    new_model_name="alias_demo_model",
    
    # Training hyperparameters
    batch_size=32,  # Adjust based on your GPU memory
    epochs=2,  # Increase for better results
    warmup_steps=100,
    weight_decay=0.01,
    logging_steps=10,
    
    # Save options
    save_to_local=True,
    save_to_hf=False,  # Set True to push to HuggingFace Hub
    
    # Hardware
    fp16=False,  # Set True if you have a GPU with fp16 support
    
    # Testing
    testrun=False,  # Set True for quick test
    seed=73
)

print("Training Configuration:")
print(f"  - Base model: {training_config.model}")
print(f"  - Loss: {training_config.loss}")
print(f"  - Batch size: {training_config.batch_size}")
print(f"  - Epochs: {training_config.epochs}")
print(f"  - Save locally: {training_config.save_to_local}")
print(f"  - Save to HF: {training_config.save_to_hf}")


Training Configuration:
  - Base model: neuml/pubmedbert-base-embeddings
  - Loss: MNR
  - Batch size: 32
  - Epochs: 2
  - Save locally: True
  - Save to HF: False


## Step 6: Train the Model

⚠️ **Note**: Training may take a while depending on:
- Dataset size
- Number of epochs
- Available hardware (GPU/CPU)

The model is saved to the `models/` directory, relative to the notebook's working directory.


In [13]:
# Train the model
trained_model = train_model(
    dataset_dict=triplet_dict,
    datasets='scrna',
    train_config=training_config
)

print("\n" + "="*50)
print("Training completed!")
print("="*50)


Model 'neuml/pubmedbert-base-embeddings' is public. No token needed.
Loaded model neuml/pubmedbert-base-embeddings on mps
MultipleNegativesRanking Loss loaded!




Starting fine-tuning on scrna






Training completed.
 Model saved to /Users/mengerj/repos/alias/notebooks/models/alias_demo_model

Training completed!


## Step 7: Verify the Trained Model

Test that the model can generate embeddings.


In [15]:
# Load the trained model
from sentence_transformers import SentenceTransformer

# The model is saved under models/ relative to the notebook (see the training log above)
model_path = Path("models") / training_config.new_model_name
print(f"Loading model from: {model_path}")

model = SentenceTransformer(str(model_path))

# Test embedding generation
test_sentences = [
    dataset_dict['scrna']['test'][0]['sentence1'],
    dataset_dict['scrna']['test'][1]['sentence1']
]

embeddings = model.encode(test_sentences)
print(f"\nGenerated embeddings shape: {embeddings.shape}")
print(f"Embedding dimension: {embeddings.shape[1]}")

# Compute similarity
from sklearn.metrics.pairwise import cosine_similarity
similarity = cosine_similarity([embeddings[0]], [embeddings[1]])[0][0]
print(f"Cosine similarity between two test cells: {similarity:.4f}")


Loading model from: models/alias_demo_model

Generated embeddings shape: (2, 768)
Embedding dimension: 768
Cosine similarity between two test cells: 0.8159
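
A single similarity value is hard to interpret on its own. One simple follow-up, sketched here under the assumption that labels are available for the embedded sentences, is to compare the average similarity of same-label pairs with that of different-label pairs. The helper `mean_similarity_by_pair_type` is hypothetical and uses toy vectors so that it runs without the trained model.

```python
import math
from itertools import combinations


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def _mean(values):
    return sum(values) / len(values) if values else float("nan")


def mean_similarity_by_pair_type(embeddings, labels):
    """Return (mean same-label similarity, mean different-label similarity)."""
    same, different = [], []
    for (i, u), (j, v) in combinations(enumerate(embeddings), 2):
        bucket = same if labels[i] == labels[j] else different
        bucket.append(cosine(u, v))
    return _mean(same), _mean(different)


toy_embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
toy_labels = ["T cell", "T cell", "B cell", "B cell"]
same_mean, diff_mean = mean_similarity_by_pair_type(toy_embeddings, toy_labels)
print(f"same-label: {same_mean:.3f}, different-label: {diff_mean:.3f}")
```

With the trained model, the inputs would be the output of `model.encode(...)` for a batch of test sentences and the matching `label` values; a clear gap between the two means would suggest that the embeddings separate cell types.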


## Summary

You've successfully:
1. ✅ Loaded scRNA-seq data from `.h5ad` file
2. ✅ Converted cells to text-based "cell sentences"
3. ✅ Generated training triplets for contrastive learning
4. ✅ Fine-tuned a sentence transformer model
5. ✅ Verified the model can generate embeddings

## Next Steps

- Use the model to generate embeddings for your full dataset
- Evaluate the model on downstream tasks
- Integrate with NCBI literature data for multi-modal training
- Experiment with different hyperparameters
