In [1]:
import json
from collections import defaultdict
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import torch
from safetensors.torch import load_file
from sklearn.metrics import mean_squared_error, r2_score
from torch.utils.data import Dataset
from transformers import PretrainedConfig, Trainer, TrainingArguments

from methformer import MethformerRegressor


In [2]:
meth_adata = ad.read_h5ad("data/methformer_pretraining_dataset.h5ad")
mll_adata = ad.read_h5ad("data/mll_n.h5ad")

# Extract patient/sample identifiers
meth_ids = meth_adata.obs_names.tolist()
ids = [mid.split("-")[2] for mid in meth_ids]

# Filter MLL-N IDs by sample matches
mll_ids = mll_adata.obs_names.tolist()
mll_series = pd.Series(mll_ids)
mll_samples = sorted(mll_series[mll_series.str.contains('|'.join(ids))])

# Build MethFormer sample â†’ MLL-N replicate mapping
meth_to_mll = defaultdict(list)
for mll_id in mll_samples:
    match_id = mll_id.split("-")[1]
    for meth_id in meth_ids:
        if match_id in meth_id:
            meth_to_mll[meth_id].append(mll_id)

meth_to_mll = dict(meth_to_mll)


In [3]:
# --- Hold out test patient 23003 and validation patient 22620 ---
test_dict = {'METH-patient-23003': meth_to_mll['METH-patient-23003']}
val_dict = {'METH-patient-22620': meth_to_mll['METH-patient-22620']}

# Remove from training pool
for key in ['METH-patient-23003', 'METH-patient-22620']:
    meth_to_mll.pop(key)

train_dict = meth_to_mll

# --- Dataset class ---
class MethformerRepAveragedDataset(Dataset):
    def __init__(self, meth_adata, mll_adata, match_dict, input_layer="methylation"):
        self.inputs = []
        self.labels = []

        for meth_id, mll_ids in match_dict.items():
            meth_idx = meth_adata.obs_names.get_loc(meth_id)
            meth_data = meth_adata.layers[input_layer][meth_idx]  # (515400, 2)
            meth_img = np.asarray(meth_data).astype(np.float32).T  # shape: (2, 515400)


            mll_rpks = []
            for mll_id in mll_ids:
                mll_idx = mll_adata.obs_names.get_loc(mll_id)
                rpkm = mll_adata[mll_idx].X.astype(np.float32)
                mll_rpks.append(rpkm)

            mean_rpkm = np.mean(mll_rpks, axis=0)
            log_mean_rpkm = np.log1p(mean_rpkm).mean()

            self.inputs.append(meth_img)
            self.labels.append(log_mean_rpkm)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "pixel_values": torch.tensor(self.inputs[idx], dtype=torch.float32),
            "labels": torch.tensor(self.labels[idx], dtype=torch.float32)
        }

# --- Load datasets ---
train_dataset = MethformerRepAveragedDataset(meth_adata, mll_adata, train_dict)
val_dataset = MethformerRepAveragedDataset(meth_adata, mll_adata, val_dict)
test_dataset = MethformerRepAveragedDataset(meth_adata, mll_adata, test_dict)


In [4]:


# --- Load config ---
config_path = Path("output/methformer_2025-05-30_2327/methformer_pretrained/config.json")
with open(config_path) as f:
    config_dict = json.load(f)

config = PretrainedConfig.from_dict(config_dict)

# --- Instantiate model ---
model = MethformerRegressor(config)

# --- Load safetensors weights ---
weights_path = "output/methformer_2025-05-30_2327/methformer_pretrained/model.safetensors"
state_dict = load_file(weights_path)
model.load_state_dict(state_dict, strict=False)



_IncompatibleKeys(missing_keys=['regression_head.weight', 'regression_head.bias'], unexpected_keys=[])

In [5]:
model


MethformerRegressor(
  (embed): Linear(in_features=2, out_features=128, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_head): Linear(in_features=128, out_features=2, bias=True)
  (regression_head): Linear(in_features=128, out_features=1, bias=True)
)

In [6]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="results",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=20,
    learning_rate=1e-4,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="r2",
    greater_is_better=True,
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.flatten()
    r2 = r2_score(labels, preds)
    mse = mean_squared_error(labels, preds)
    return {"r2": r2, "mse": mse}


In [7]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)


In [8]:
trainer.train()


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33mcatherine-chahrour[0m ([33mcatherine-chahrour-university-of-oxford[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


ValueError: The batch received was empty, your model won't be able to train on it. Double-check that your training dataset contains keys expected by the model: input_values,attention_mask,kwargs,label,label_ids.