# Single Simulation Training Demo
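Generate a single simulation dataset (skipped if it already exists on disk), train a BINN on it while printing per-epoch metrics, and plot the training and validation AUC curves.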

In [None]:
import sys, subprocess
from pathlib import Path

# Make the local `openbinn` package importable when the notebook is started
# from a directory that contains it.
ROOT = Path.cwd()
if (ROOT / 'openbinn').exists():
    sys.path.insert(0, str(ROOT))

# Simulation parameters. `beta`, `gamma` and `sim_id` are forwarded to the
# generator script below and also determine the dataset directory name.
beta = 2.0
gamma = 2.0
sim_id = 1
nonlinear = True
# NOTE(review): assumes the generator writes to ./data/b{beta}_g{gamma}/{sim_id}
# relative to the current working directory - confirm against the script.
data_root = Path(f'./data/b{beta}_g{gamma}/{sim_id}')

if data_root.exists():
    print('Dataset exists at', data_root)
else:
    # Use the interpreter that is running this notebook: a bare 'python' is
    # resolved through PATH and may point at a different environment that
    # lacks the project's dependencies.
    cmd = [
        sys.executable, 'analysis/generate_simulations.py',
        '--beta', str(beta), '--gamma', str(gamma),
        '--n_sim', '1', '--start_sim', str(sim_id), '--end_sim', str(sim_id),
    ]
    if nonlinear:
        cmd.append('--pathway_nonlinear')
    # check=True raises CalledProcessError if the generator exits non-zero,
    # so a failed generation stops the notebook here instead of later.
    subprocess.run(cmd, check=True)

In [None]:
from openbinn.binn.data import PnetSimDataSet, ReactomeNetwork, get_layer_maps
from torch.utils.data.sampler import SubsetRandomSampler
from torch_geometric.loader import DataLoader as GeoLoader

# Load the simulated dataset and split it into train/validation/test using
# the index files found in the dataset's `splits` directory.
ds = PnetSimDataSet(root=str(data_root), num_features=3)
splits_dir = data_root / 'splits'
ds.split_index_by_file(
    train_fp=splits_dir / 'training_set_0.csv',
    valid_fp=splits_dir / 'validation_set.csv',
    test_fp=splits_dir / 'test_set.csv',
)

# Pathway hierarchy files used to build the layer connectivity maps.
reactome_config = {
    'reactome_base_dir': 'biological_knowledge/simulation',
    'relations_file_name': 'SimulationPathwaysRelation.txt',
    'pathway_names_file_name': 'SimulationPathways.txt',
    'pathway_genes_file_name': 'SimulationPathways.gmt',
}
reactome = ReactomeNetwork(reactome_config)
maps = get_layer_maps(
    genes=list(ds.node_index),
    reactome=reactome,
    n_levels=3,
    direction='root_to_leaf',
    add_unk_genes=False,
)

# Keep only the genes that appear in the index of the first layer map.
first_layer_genes = maps[0].index
ds.node_index = [gene for gene in ds.node_index if gene in first_layer_genes]

bs = 16


def _subset_loader(indices):
    """Return a loader over `ds` restricted to `indices` via SubsetRandomSampler."""
    return GeoLoader(ds, bs, sampler=SubsetRandomSampler(indices), num_workers=0)


tr_loader = _subset_loader(ds.train_idx)
va_loader = _subset_loader(ds.valid_idx)
te_loader = _subset_loader(ds.test_idx)

In [None]:
from openbinn.binn import PNet
from openbinn.binn.util import (
    InMemoryLogger, eval_metrics, EpochMetricsPrinter, GradNormPrinter, MetricsRecorder
)
import pytorch_lightning as pl

# `deterministic=True` on the Trainer only selects deterministic algorithms;
# it does not fix the random state. Seed the RNGs before the model is built
# so weight initialisation and the random samplers are repeatable too.
SEED = 42
pl.seed_everything(SEED, workers=True)

# The input width is the number of rows of the first layer map.
model = PNet(layers=maps, num_genes=maps[0].shape[0], lr=1e-3)

# Baseline validation metrics of the untrained model, for comparison with
# the values printed after training.
init_loss, init_acc, init_auc = eval_metrics(model, va_loader)
print(f'Start: loss={init_loss:.4f} acc={init_acc:.4f} auc={init_auc:.4f}')

# NOTE(review): presumably the recorder writes per-epoch metrics under
# data_root/'results'; the plotting cell reads
# results/performance/metrics.csv - confirm against MetricsRecorder.
recorder = MetricsRecorder(data_root/'results', tr_loader, va_loader, te_loader)
trainer = pl.Trainer(
    accelerator='auto', deterministic=True, max_epochs=50,
    callbacks=[
        # Stop when val_loss has not improved by at least 0.01 for 10 checks.
        pl.callbacks.EarlyStopping(
            monitor='val_loss', patience=10, mode='min', verbose=True, min_delta=0.01
        ),
        EpochMetricsPrinter(tr_loader, va_loader),
        GradNormPrinter(),
        recorder,
    ],
    logger=InMemoryLogger(), enable_progress_bar=False
)
trainer.fit(model, tr_loader, va_loader)

# Validation metrics after training.
fin_loss, fin_acc, fin_auc = eval_metrics(model, va_loader)
print(f'End: loss={fin_loss:.4f} acc={fin_acc:.4f} auc={fin_auc:.4f}')

In [None]:
import pandas as pd, matplotlib.pyplot as plt
# Load the per-epoch metrics file from the results directory of this run.
metrics = pd.read_csv(data_root/'results/performance/metrics.csv')
# A bare `metrics.head()` in the middle of a cell is evaluated and discarded
# (only a cell's last expression is displayed), so print it explicitly.
print(metrics.head())

# Training vs. validation AUC per epoch.
plt.plot(metrics['epoch'], metrics['train_auc'], label='train')
plt.plot(metrics['epoch'], metrics['val_auc'], label='val')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.legend()
plt.show()