In [1]:
import os
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import train_test_split

from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv

In [68]:
# Location of the raw CSVs. Set PMT_DATA_DIR to point elsewhere instead of
# editing a hard-coded absolute path; the default keeps the original location.
DATA_DIR = Path(os.environ.get('PMT_DATA_DIR', '/home/team/data/pmt_pmt/raw'))
TRAIN_PATH = DATA_DIR / 'training_data.csv'
TEST_PATH = DATA_DIR / 'testing_data.csv'

In [69]:
# Load the raw train / test tables (train first, then test).
train_df, test_df = (pd.read_csv(csv_path) for csv_path in (TRAIN_PATH, TEST_PATH))

In [70]:
# Exact duplicate rows add nothing; remove them from both tables.
train_df, test_df = (frame.drop_duplicates() for frame in (train_df, test_df))

In [71]:
merged_df = pd.merge(test_df, train_df, how='left', on=['Antigen', 'HLA', 'CDR3'], indicator=True)

In [72]:
# Remove test rows whose (Antigen, HLA, CDR3) triplet also occurs in train, so the
# held-out set contains no training triplets. This is a key-based anti-join:
# merged_df has a fresh RangeIndex (and extra rows when a key repeats in train),
# so its index cannot be used to select rows of test_df, whose index has gaps
# after drop_duplicates.
_LEAK_KEYS = ['Antigen', 'HLA', 'CDR3']
_train_triplets = set(train_df[_LEAK_KEYS].itertuples(index=False, name=None))
_is_unseen = np.array(
    [key not in _train_triplets for key in test_df[_LEAK_KEYS].itertuples(index=False, name=None)],
    dtype=bool,
)
test_df = test_df.loc[_is_unseen]

In [73]:
import numpy as np
import pandas as pd
from typing import Tuple

def generate_triplet_negatives(
    df: pd.DataFrame,
    pep_col: str = "Antigen",
    mhc_col: str = "HLA",
    tcr_col: str = "CDR3",
    k: int = 1,
    seed: int = 42,
    max_tries_per_sample: int = 50,
    return_with_labels: bool = True,
) -> pd.DataFrame:
    """
    Sample random negative (peptide, MHC, TCR) triplets.

    For every unique peptide, draw ``k`` random (MHC, TCR) pairs per positive row
    of that peptide. MHC and TCR values are drawn independently and uniformly from
    the unique values of the whole ``df``. A pair is rejected if it already occurs
    with this peptide in ``df`` or was already chosen as a negative for it.

    Parameters
    ----------
    df : pd.DataFrame
        Positive triplets; must contain ``pep_col``, ``mhc_col`` and ``tcr_col``.
    pep_col, mhc_col, tcr_col : str
        Column names of the peptide, MHC and TCR.
    k : int
        Number of negatives requested per positive row.
    seed : int
        Seed for ``np.random.default_rng``. For a given ``df`` and seed the result,
        including its row order, is the same on every run.
    max_tries_per_sample : int
        Rejection-sampling budget per requested negative. When it is exhausted a
        warning is printed and fewer negatives are produced for that peptide.
    return_with_labels : bool
        True: return ``df`` with ``label=1`` concatenated with the negatives
        (``label=0``), with a fresh index.
        False: return only the negatives, columns ``[pep_col, mhc_col, tcr_col]``.

    Returns
    -------
    pd.DataFrame
    """
    rng = np.random.default_rng(seed)

    # Global pools of candidate values for sampling.
    all_mhc = df[mhc_col].unique()
    all_tcr = df[tcr_col].unique()

    neg_rows = []
    for pep, g in df.groupby(pep_col, sort=False):
        need = len(g) * k
        # (mhc, tcr) pairs observed with this peptide must never become negatives.
        seen_pairs = set(zip(g[mhc_col].tolist(), g[tcr_col].tolist()))
        # Use a dict, not a set: dict iteration follows insertion order. Iterating a set
        # of string tuples depends on PYTHONHASHSEED, which made the output (and every
        # split built from it) change between runs despite the fixed seed.
        chosen_pairs = {}

        tries = 0
        max_tries = need * max_tries_per_sample
        # Rejection sampling: draw mhc then tcr until `need` new pairs are collected.
        while len(chosen_pairs) < need and tries < max_tries:
            pair = (rng.choice(all_mhc), rng.choice(all_tcr))
            if pair not in seen_pairs and pair not in chosen_pairs:
                chosen_pairs[pair] = None
            tries += 1

        if len(chosen_pairs) < need:
            # Too little diversity (e.g. a single MHC and TCR overall) to fill the quota.
            print(f"[warn] For peptide '{pep}' generated {len(chosen_pairs)}/{need} negatives "
                  f"(increase max_tries_per_sample or provide more diversity).")

        neg_rows.extend((pep, mhc, tcr) for mhc, tcr in chosen_pairs)

    negatives_df = pd.DataFrame(neg_rows, columns=[pep_col, mhc_col, tcr_col])

    if not return_with_labels:
        return negatives_df

    return pd.concat(
        [df.assign(label=1), negatives_df.assign(label=0)],
        ignore_index=True,
    )


Negatives are sampled at a 1:1 ratio to positives (default `k=1`: one negative per positive row).

In [74]:
# Attach label=1 to the positives and append sampled negatives (label=0), train first.
train_df_labeled, test_df_labeled = map(generate_triplet_negatives, (train_df, test_df))

In [79]:
class TripletDataFromDF:
    """
    Build a HeteroData graph for the peptide-MHC-TCR (p-m-t) task from a DataFrame.

    ``df`` must contain the columns Antigen, HLA, CDR3 and label; each row is one
    triplet. Rows provide both the graph edges (pep->mhc, mhc->tcr) and the scored
    triplets stored under ``data["pmt_pairs"]``.

    Parameters
    ----------
    df : pd.DataFrame
        Labelled triplets (copied; the caller's frame is not modified).
    test_size : float, default 0.2
        Fraction of rows put in the ``"test"`` split.
    random_state : int, default 42
        Seed of the stratified ``train_test_split``.
    """
    def __init__(self, df: pd.DataFrame, test_size: float = 0.2, random_state: int = 42):
        self.df = df.copy()
        self.test_size = test_size
        self.random_state = random_state
        self.pid = {}  # Antigen value -> pep node id
        self.mid = {}  # HLA value -> mhc node id
        self.tid = {}  # CDR3 value -> tcr node id
        self.data = None  # HeteroData, set by build_graph()

    def build_id_maps(self):
        """Assign contiguous node ids per node type, in order of first appearance in df."""
        self.pid = {v: i for i, v in enumerate(pd.Index(self.df["Antigen"].unique()))}
        self.mid = {v: i for i, v in enumerate(pd.Index(self.df["HLA"].unique()))}
        self.tid = {v: i for i, v in enumerate(pd.Index(self.df["CDR3"].unique()))}

    @staticmethod
    def _edge_index(src, dst) -> torch.Tensor:
        """Stack source/target node ids into a contiguous (2, E) long tensor (also for E == 0)."""
        return torch.stack([
            torch.as_tensor(src, dtype=torch.long),
            torch.as_tensor(dst, dtype=torch.long),
        ]).contiguous()

    def build_graph(self, train_edges_only: bool = False) -> "HeteroData":
        """
        Build the graph, store it in ``self.data`` and return it.

        Parameters
        ----------
        train_edges_only : bool, default False
            False (the default): pep->mhc and mhc->tcr edges come from every row,
            including label-0 rows and rows of the test split, so message passing
            sees the evaluation triplets.
            True: only training-split rows with label == 1 contribute edges.

        Returns
        -------
        HeteroData
            Node counts for "pep"/"mhc"/"tcr", both edge types, plus two plain dicts:
            ``data["pmt_splits"]`` ({"train", "test"} row indices) and
            ``data["pmt_pairs"]`` (per-row pep/mhc/tcr ids and label ``y``).
        """
        # Without id maps every lookup below would raise KeyError.
        if not (self.pid and self.mid and self.tid):
            self.build_id_maps()

        df = self.df
        p_idx = np.asarray([self.pid[p] for p in df["Antigen"]], dtype=np.int64)
        m_idx = np.asarray([self.mid[m] for m in df["HLA"]], dtype=np.int64)
        t_idx = np.asarray([self.tid[t] for t in df["CDR3"]], dtype=np.int64)
        labels = df["label"].astype("int64").to_numpy()

        # Stratified row split over triplets.
        tr, te = train_test_split(
            np.arange(len(df)),
            test_size=self.test_size, random_state=self.random_state, stratify=df["label"].values
        )

        edge_rows = np.sort(tr[labels[tr] == 1]) if train_edges_only else np.arange(len(df))

        data = HeteroData()
        data["pep"].num_nodes = len(self.pid)
        data["mhc"].num_nodes = len(self.mid)
        data["tcr"].num_nodes = len(self.tid)

        # One edge per contributing row (duplicates kept).
        data["pep", "binds", "mhc"].edge_index = self._edge_index(p_idx[edge_rows], m_idx[edge_rows])
        data["mhc", "presents_to", "tcr"].edge_index = self._edge_index(m_idx[edge_rows], t_idx[edge_rows])

        data["pmt_splits"] = {"train": torch.tensor(tr), "test": torch.tensor(te)}
        data["pmt_pairs"] = {
            "pep": torch.as_tensor(p_idx, dtype=torch.long),
            "mhc": torch.as_tensor(m_idx, dtype=torch.long),
            "tcr": torch.as_tensor(t_idx, dtype=torch.long),
            "y":   torch.tensor(labels),
        }

        self.data = data
        return data

In [94]:
class TripletOnlyGNN(nn.Module):
    """
    Heterogeneous GraphSAGE encoder plus an MLP head that scores (pep, mhc, tcr) triplets.

    Every node starts from a learned ``nn.Embedding``. Each layer is a ``HeteroConv``
    (aggr="mean") holding one lazily-sized ``SAGEConv`` per edge type, pep->mhc
    ("binds") and mhc->tcr ("presents_to"), followed by ReLU and dropout.
    A triplet's logit is ``head(cat[pep, mhc, tcr])``, each part projected to ``hidden``.

    Parameters
    ----------
    n_pep, n_mhc, n_tcr : int
        Number of nodes of each type (embedding table sizes).
    emb_dim : int
        Width of the initial node embeddings.
    hidden : int
        Output width of every SAGEConv and of the head's hidden layer.
    layers : int
        Number of message-passing layers; 0 means embeddings only.
    dropout : float
        Dropout after each conv layer and inside the head.
    """
    def __init__(self, n_pep, n_mhc, n_tcr, emb_dim=128, hidden=256, layers=2, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.emb = nn.ModuleDict({
            "pep": nn.Embedding(n_pep, emb_dim),
            "mhc": nn.Embedding(n_mhc, emb_dim),
            "tcr": nn.Embedding(n_tcr, emb_dim),
        })
        self.layers = nn.ModuleList()
        for _ in range(layers):
            self.layers.append(HeteroConv({
                ("pep","binds","mhc"): SAGEConv((-1,-1), hidden),
                ("mhc","presents_to","tcr"): SAGEConv((-1,-1), hidden),
            }, aggr="mean"))

        # Project every node type to `hidden` before the head.
        # "pep" is never a destination type, so the convs never update it and it
        # stays emb_dim wide.
        self.proj_pep = nn.Linear(emb_dim, hidden)
        if layers > 0:
            # mhc / tcr are already `hidden` wide after the convs.
            self.proj_mhc = nn.Identity()
            self.proj_tcr = nn.Identity()
        else:
            # No convs: mhc / tcr are raw emb_dim embeddings and must be projected,
            # otherwise the head's 3*hidden input size does not match.
            self.proj_mhc = nn.Linear(emb_dim, hidden)
            self.proj_tcr = nn.Linear(emb_dim, hidden)

        self.head = nn.Sequential(
            nn.Linear(3*hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1)
        )

    def _forward_node_embeddings(self, data):
        """Run all conv layers over the whole graph; return {node_type: representations}."""
        device = self.emb["pep"].weight.device
        h = {
            "pep": self.emb["pep"](torch.arange(data["pep"].num_nodes, device=device)),
            "mhc": self.emb["mhc"](torch.arange(data["mhc"].num_nodes, device=device)),
            "tcr": self.emb["tcr"](torch.arange(data["tcr"].num_nodes, device=device)),
        }
        edge_index_dict = {
            ("pep","binds","mhc"): data["pep","binds","mhc"].edge_index,
            ("mhc","presents_to","tcr"): data["mhc","presents_to","tcr"].edge_index,
        }
        for conv in self.layers:
            out = conv(h, edge_index_dict)
            out = {k: F.dropout(F.relu(v), p=self.dropout, training=self.training) for k,v in out.items()}
            # Node types absent from `out` (pep, which has no incoming edges) keep their previous representation.
            h = {k: out.get(k, h[k]) for k in h.keys()}
        return h

    def forward_scores(self, data, pairs):
        """
        Return one logit per triplet.

        ``pairs`` maps "pep"/"mhc"/"tcr" to equally long index tensors of node ids.
        """
        h = self._forward_node_embeddings(data)
        # Bring all three node types to width `hidden`, then gather the requested nodes.
        hp = self.proj_pep(h["pep"])[pairs["pep"]]
        hm = self.proj_mhc(h["mhc"])[pairs["mhc"]]
        ht = self.proj_tcr(h["tcr"])[pairs["tcr"]]
        logits = self.head(torch.cat([hp, hm, ht], dim=-1)).squeeze(-1)
        return logits

In [95]:
@dataclass
class Config:
    """Hyper-parameters and runtime settings for one training run."""

    # --- reproducibility ---
    seed: int = 42

    # --- model architecture (TripletOnlyGNN) ---
    emb_dim: int = 128     # width of the learned node embeddings
    hidden: int = 256      # width of the SAGEConv outputs and of the head's hidden layer
    layers: int = 2        # number of HeteroConv message-passing layers
    dropout: float = 0.2

    # --- optimisation (full-batch AdamW) ---
    epochs: int = 10
    lr: float = 0.002
    weight_decay: float = 0.0001

    # --- hardware ---
    device: str = "cuda"   # preferred device; the Trainer uses CPU when CUDA is unavailable

In [96]:
def set_seed(seed: int):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA devices)."""
    import random

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def metrics_bin(y_true: np.ndarray, y_score: np.ndarray) -> Dict[str, float]:
    """Threshold-free binary metrics: ROC-AUC and average precision (reported as PR-AUC)."""
    roc = roc_auc_score(y_true, y_score)
    pr = average_precision_score(y_true, y_score)
    return dict(roc_auc=float(roc), pr_auc=float(pr))

In [97]:
class Trainer:
    """
    Full-batch trainer for TripletOnlyGNN.

    Each optimisation step runs message passing over the whole graph and scores
    all triplets of one split, taken from ``data["pmt_splits"]`` / ``data["pmt_pairs"]``.
    Construction moves the graph tensors of ``data`` to the training device in place.
    """
    def __init__(self, cfg: "Config", data: "HeteroData"):
        self.cfg = cfg
        # The configured device is used only when CUDA is available; otherwise CPU.
        self.device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
        self.data = data

        self.model = TripletOnlyGNN(
            data["pep"].num_nodes, data["mhc"].num_nodes, data["tcr"].num_nodes,
            emb_dim=cfg.emb_dim, hidden=cfg.hidden, layers=cfg.layers, dropout=cfg.dropout
        ).to(self.device)

        self.opt = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

        # Move static graph tensors (edge indices and the per-triplet dicts) in place.
        for et in [("pep", "binds", "mhc"), ("mhc", "presents_to", "tcr")]:
            data[et].edge_index = data[et].edge_index.to(self.device)
        for group in ("pmt_pairs", "pmt_splits"):
            for kk, v in data[group].items():
                data[group][kk] = v.to(self.device)

    def _bce(self, logits, y):
        return F.binary_cross_entropy_with_logits(logits, y.float())

    def _split_batch(self, split: str):
        """Return ``(pairs, y)`` for every triplet of ``split`` ("train" or "test")."""
        idx = self.data["pmt_splits"][split]
        pairs_full = self.data["pmt_pairs"]
        pairs = {key: pairs_full[key][idx] for key in ("pep", "mhc", "tcr")}
        return pairs, pairs_full["y"][idx]

    def _run_split(self, split: str) -> Dict[str, float]:
        """One full-batch pass over ``split``; for "train" also takes an optimiser step."""
        is_train = split == "train"
        self.model.train(is_train)
        pairs, y = self._split_batch(split)

        with torch.set_grad_enabled(is_train):
            logits = self.model.forward_scores(self.data, pairs)
            loss = self._bce(logits, y)

            if is_train:
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

        return {"loss": loss.item()}

    @torch.no_grad()
    def _eval_split(self, split: str = "test"):
        """Single eval-mode forward pass; returns ``({"loss": ...}, metrics)``."""
        self.model.eval()
        pairs, y = self._split_batch(split)
        logits = self.model.forward_scores(self.data, pairs)
        loss = {"loss": self._bce(logits, y).item()}
        metrics = metrics_bin(y.cpu().numpy(), torch.sigmoid(logits).cpu().numpy())
        return loss, metrics

    @torch.no_grad()
    def evaluate(self) -> Dict[str, float]:
        """ROC-AUC / PR-AUC on the "test" split."""
        return self._eval_split("test")[1]

    def fit(self) -> list:
        """
        Train for ``cfg.epochs`` epochs, printing one JSON line per epoch.

        Returns
        -------
        list of dict
            Per-epoch records ``{"epoch", "train", "val", "metrics"}`` (same as printed).
        """
        history = []
        for epoch in range(1, self.cfg.epochs + 1):
            tl = self._run_split("train")
            # Validation loss and metrics share one forward pass. Dropout is off in
            # eval mode, so a second pass would only repeat the same computation.
            vl, m = self._eval_split("test")
            record = {"epoch": epoch, "train": tl, "val": vl, "metrics": m}
            print(json.dumps(record, ensure_ascii=False))
            history.append(record)
        return history

In [98]:
def main(df: Optional[pd.DataFrame] = None, cfg: Optional["Config"] = None) -> "Trainer":
    """
    Build the p-m-t graph from a labelled DataFrame and train a TripletOnlyGNN on it.

    Parameters
    ----------
    df : pd.DataFrame, optional
        Triplets with Antigen, HLA, CDR3 and label columns. Defaults to the
        notebook-level ``train_df_labeled``.
    cfg : Config, optional
        Training configuration; defaults to ``Config()``.

    Returns
    -------
    Trainer
        The fitted trainer, so its model and graph stay available for further
        evaluation.
    """
    cfg = Config() if cfg is None else cfg
    df = train_df_labeled if df is None else df
    set_seed(cfg.seed)

    # load data & build graph
    ds = TripletDataFromDF(df)
    ds.build_id_maps()
    data = ds.build_graph()

    # train
    trainer = Trainer(cfg, data)
    trainer.fit()
    return trainer

In [99]:
# "high" lets PyTorch use faster, lower-precision float32 matmul kernels where the
# hardware supports them (trading some precision for speed).
torch.set_float32_matmul_precision("high")
# Build the graph from train_df_labeled and train the model with the default Config.
main()



{"epoch": 1, "train": {"loss": 0.6944091320037842}, "val": {"loss": 0.6412291526794434}, "metrics": {"roc_auc": 0.942475703596567, "pr_auc": 0.9320037671124302}}
{"epoch": 2, "train": {"loss": 0.6362224221229553}, "val": {"loss": 0.5510789752006531}, "metrics": {"roc_auc": 0.9547503072394423, "pr_auc": 0.9482793968513583}}
{"epoch": 3, "train": {"loss": 0.5571649074554443}, "val": {"loss": 0.45060911774635315}, "metrics": {"roc_auc": 0.9510077042109826, "pr_auc": 0.9517305436737268}}
{"epoch": 4, "train": {"loss": 0.4625527858734131}, "val": {"loss": 0.3511071503162384}, "metrics": {"roc_auc": 0.9612361028900442, "pr_auc": 0.9540453536434291}}
{"epoch": 5, "train": {"loss": 0.35303109884262085}, "val": {"loss": 0.27104517817497253}, "metrics": {"roc_auc": 0.963222032890434, "pr_auc": 0.9573624416574291}}
{"epoch": 6, "train": {"loss": 0.2813451588153839}, "val": {"loss": 0.2215682715177536}, "metrics": {"roc_auc": 0.9692280150858612, "pr_auc": 0.9627079597090946}}
{"epoch": 7, "train":