In [57]:
import os,random,time,itertools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from collections import defaultdict
from itertools import combinations
from types import SimpleNamespace as _NS
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

In [58]:
# Run-wide configuration namespace read by the training and scoring cells.
CFG = _NS(
    # Files: two torch-saved blobs that are read with torch.load, and the CSV
    # the per-seed scores are written to. The .pt paths are machine-specific.
    pooled_pt="/Users/xshan/Desktop/PLM/model/esmC_600m_training5.pt",
    score_pt="/Users/xshan/Desktop/PLM/model/esmC_600m_Pagg.pt",
    out_csv="testrun_scores_multiseed.csv",

    # Adapter training: epoch count, AdamW learning rate, layer widths, batch size.
    epochs=10,
    adapter_lr=1e-4,
    adapter_hidden=512,
    adapter_middle=128,
    adapter_latent=64,
    batch_size=512,

    # How many cross-group pairs to request from each pair sampler.
    n_pp_diff=60_000,
    n_nn_diff=1_000,
    n_pn_diff=1_000,

    # Contrastive loss margin and multiplicative scale.
    margin=1,
    loss_scale=1_000,

    # One independent training + scoring run is performed per seed.
    seeds=list(range(2, 7)),
    verbose=True,
)

# Lookup applied to the group token of each sequence id: raw group codes are
# collapsed to an enzyme-family name, while "test" and the family names
# themselves map to themselves so already-normalised values pass through.
mapping = {
    **dict.fromkeys(("G01",), "amidotransferase"),
    **dict.fromkeys(("G02", "G03", "G04"), "methyltransferase"),
    **dict.fromkeys(("G05", "G06", "G07", "G08"), "monooxygenase"),
    **dict.fromkeys(("G09", "G10"), "prenyltransferase"),
    **dict.fromkeys(("G11",), "adenylation"),
    **dict.fromkeys(("G12", "G13", "G14"), "dehydrogenase"),
    "test": "test",
    **{
        family: family
        for family in (
            "amidotransferase",
            "methyltransferase",
            "monooxygenase",
            "prenyltransferase",
            "adenylation",
            "dehydrogenase",
        )
    },
}

In [59]:
def set_seed(seed: int=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with the same value.

    The three seeders are called in that order; nothing is returned.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

class Adapter(nn.Module):
    """Small MLP that projects `in_dim` features down to a `latent`-dim embedding.

    Layout: Linear(in_dim, hidden) -> ReLU -> Dropout -> Linear(hidden, middle)
    -> ReLU -> Dropout -> Linear(middle, latent). The final projection has no
    activation or dropout.
    """

    def __init__(self,in_dim,hidden=256,middle=128,latent=64,dropout=0.1):
        super().__init__()
        # Two hidden projections followed by the output projection.
        self.fc1=nn.Linear(in_features=in_dim,out_features=hidden)
        self.fc2=nn.Linear(in_features=hidden,out_features=middle)
        self.mu=nn.Linear(in_features=middle,out_features=latent)
        # Shared activation / dropout modules, reused after both hidden layers.
        self.act=nn.ReLU(inplace=True)
        self.drop=nn.Dropout(p=dropout)

    def forward(self,x):
        """Return the latent embedding for a batch `x`."""
        out=x
        for hidden_layer in (self.fc1,self.fc2):
            # The in-place ReLU only touches the fresh Linear output, never `x`.
            out=self.drop(self.act(hidden_layer(out)))
        return self.mu(out)

class ContrastiveLoss(nn.Module):
    """Scaled pairwise contrastive loss on Euclidean distance.

    For a pair at distance d: label == 1 contributes d**2, any other label
    contributes relu(margin - d)**2. The batch mean is multiplied by `scale`.
    """

    def __init__(self,margin=1.0,scale=10000):
        super().__init__()
        self.margin,self.scale=margin,scale

    def forward(self,z1,z2,labels):
        """Return the scalar loss for embedding batches `z1`, `z2` and their pair labels."""
        dist=F.pairwise_distance(z1,z2,p=2)
        pull_term=dist.pow(2)
        push_term=F.relu(self.margin-dist).pow(2)
        is_similar=labels==1
        per_pair=torch.where(is_similar,pull_term,push_term)
        return self.scale * per_pair.mean()

class IndexPairDataset(Dataset):
    """Thin `Dataset` wrapper around an indexable collection of pair records.

    Items are returned exactly as stored; no tensors are created here.
    """

    def __init__(self,pairs):
        # Kept by reference, not copied.
        self.pairs=pairs

    def __len__(self):
        n_pairs=len(self.pairs)
        return n_pairs

    def __getitem__(self,k):
        record=self.pairs[k]
        return record

def collate_fetch(batch,X_train_device,X_train):
    """Turn a batch of (i, j, label) records into two feature batches plus labels.

    Positions 0 and 1 of every record are row indices into `X_train`; position 2
    is the pair label. Index and label tensors are moved to `X_train_device`
    before use, and rows are gathered with `index_select` along dim 0.

    Returns (rows for the first indices, rows for the second indices, labels).
    """
    def column(position):
        # One int64 tensor per record field.
        return torch.tensor([record[position] for record in batch],dtype=torch.long)

    first_idx,second_idx,pair_labels=column(0),column(1),column(2)
    first_rows=X_train.index_select(0,first_idx.to(X_train_device))
    second_rows=X_train.index_select(0,second_idx.to(X_train_device))
    return first_rows,second_rows,pair_labels.to(X_train_device)

@torch.no_grad()
def extract_features(encoder,X,device,batch_size=128):
    """Run `encoder` over `X` in mini-batches and return the outputs as one NumPy array.

    The encoder is switched to eval mode (and left there). Each slice of
    `batch_size` rows is moved to `device`, encoded, brought back to the CPU and
    converted to NumPy; the pieces are concatenated along axis 0.
    """
    encoder.eval()
    pieces=[
        encoder(X[start:start+batch_size].to(device)).cpu().numpy()
        for start in range(0,len(X),batch_size)
    ]
    return np.concatenate(pieces,axis=0)

In [60]:
def _rng(seed):
    return np.random.default_rng(seed)

def _canon(i,j):
    return (i,j) if i<j else (j,i)

def build_PN_same_all(group_to_pos,group_to_neg):
    """List every (positive, negative) index pair that shares a group.

    Only groups present as keys in both mappings are used. Groups are visited in
    sorted order; within a group, pairs follow the order of the positive list and
    then the negative list.
    """
    common_groups=sorted(set(group_to_pos).intersection(group_to_neg))
    return [
        (pos_idx,neg_idx)
        for grp in common_groups
        for pos_idx in group_to_pos[grp]
        for neg_idx in group_to_neg[grp]
    ]

def _sample_cross_group_pairs(group_to_members,k,seed):
    """Draw up to `k` distinct unordered pairs whose two members come from different groups.

    Parameters
    ----------
    group_to_members : mapping of group -> list of member indices
        Groups with an empty list are ignored.
    k : int
        Number of pairs requested; `k <= 0` yields an empty list.
    seed : int
        Seed for a fresh `np.random.default_rng`.

    Returns
    -------
    list of (i, j) tuples with the smaller value first, in draw order and without
    repeats. Elements are whatever `Generator.choice` returns (NumPy scalars for
    integer lists). Fewer than `k` pairs are returned when fewer than two
    non-empty groups exist or when the duplicate budget is used up.
    """
    if k<=0:
        return []
    rng=np.random.default_rng(seed)
    groups=[g for g,members in group_to_members.items() if members]
    if len(groups)<2:
        return []

    # Convert the Python lists to arrays once. `Generator.choice` would otherwise
    # rebuild an array from the list on every single draw. The sequence of RNG
    # calls is unchanged, so the sampled pairs for a given seed are unchanged too.
    group_array=np.asarray(groups)
    member_arrays={g:np.asarray(group_to_members[g]) for g in groups}

    seen=set()
    out=[]
    # Only rejected (duplicate) draws count against the budget, so the loop is
    # guaranteed to stop even when fewer than `k` distinct pairs exist.
    max_attempts=max(10000,20*k)
    attempts=0
    while len(out)<k and attempts<max_attempts:
        g1,g2=rng.choice(group_array,size=2,replace=False)
        i=rng.choice(member_arrays[g1])
        j=rng.choice(member_arrays[g2])
        # Store pairs in a canonical order so (i, j) and (j, i) count as the same pair.
        key=(i,j) if i<j else (j,i)
        if key in seen:
            attempts+=1
            continue
        seen.add(key)
        out.append(key)
    return out

def sample_PP_diff_stream(group_to_pos,k,seed):
    """Sample up to `k` distinct positive-positive pairs drawn from two different groups.

    See `_sample_cross_group_pairs` for the sampling rules and return format.
    """
    return _sample_cross_group_pairs(group_to_pos,k,seed)

def sample_NN_diff_stream(group_to_neg,k,seed):
    """Sample up to `k` distinct negative-negative pairs drawn from two different groups.

    See `_sample_cross_group_pairs` for the sampling rules and return format.
    """
    return _sample_cross_group_pairs(group_to_neg,k,seed)

def sample_PN_diff_stream(group_to_pos,group_to_neg,k,seed):
    """Sample up to `k` distinct (positive, negative) pairs taken from two different groups.

    Parameters
    ----------
    group_to_pos, group_to_neg : mapping of group -> list of member indices
        Groups with an empty list are ignored.
    k : int
        Number of pairs requested; `k <= 0` yields an empty list.
    seed : int
        Seed for a fresh `np.random.default_rng`.

    Returns
    -------
    list of (positive index, negative index) tuples in draw order, without
    repeats. Elements are whatever `Generator.choice` returns (NumPy scalars for
    integer lists). Fewer than `k` pairs are returned when no positive/negative
    combination from different groups exists or when the rejection budget is
    used up.
    """
    if k<=0:
        return []
    rng=np.random.default_rng(seed)
    pos_groups=[g for g,v in group_to_pos.items() if v]
    neg_groups=[g for g,v in group_to_neg.items() if v]
    if not pos_groups or not neg_groups:
        return []
    # With a single group overall, every draw would be rejected as same-group;
    # return immediately instead of burning through the whole rejection budget.
    if len(set(pos_groups)|set(neg_groups))<2:
        return []

    # Convert the Python lists to arrays once. `Generator.choice` would otherwise
    # rebuild an array from the list on every draw. The sequence of RNG calls is
    # unchanged, so the sampled pairs for a given seed are unchanged too.
    pos_group_array=np.asarray(pos_groups)
    neg_group_array=np.asarray(neg_groups)
    pos_members={g:np.asarray(group_to_pos[g]) for g in pos_groups}
    neg_members={g:np.asarray(group_to_neg[g]) for g in neg_groups}

    seen=set()
    out=[]
    # Both same-group draws and duplicate pairs count against this budget.
    max_attempts=max(20000,30*k)
    attempts=0
    while len(out)<k and attempts<max_attempts:
        gp=rng.choice(pos_group_array)
        gn=rng.choice(neg_group_array)
        if gp==gn:
            attempts+=1
            continue
        i=rng.choice(pos_members[gp])
        j=rng.choice(neg_members[gn])
        # Order is meaningful here (positive first), so pairs are not re-ordered.
        key=(i,j)
        if key in seen:
            attempts+=1
            continue
        seen.add(key)
        out.append(key)
    return out

In [61]:
def run_train_seed(cfg,seed,verbose=False):
    """Train a contrastive adapter and a linear SVM on its embeddings for one seed.

    Steps: seed all RNGs, load the pooled training embeddings, derive a binary
    label and a group for every id, build the pair set, train `Adapter` with
    `ContrastiveLoss`, then fit a standardised linear `SVC` on the adapter
    outputs of the training set.

    Parameters
    ----------
    cfg : dict or attribute-style namespace
        Must provide `pooled_pt`, `epochs`, `adapter_lr`, `adapter_hidden`,
        `adapter_middle`, `adapter_latent`, `batch_size`, `n_pp_diff`,
        `n_nn_diff`, `n_pn_diff`, `margin` and `loss_scale`. All of these are
        read from `cfg`, so a modified config actually takes effect.
    seed : int
        Seeds the global RNGs, the pair samplers (seed, seed+1, seed+2), the
        pair shuffle (seed+3) and the SVC.
    verbose : bool
        When true, print the size of each pair category. The per-epoch loss line
        is printed regardless.

    Returns
    -------
    (encoder, clf) : the trained `Adapter` (left in eval mode by
    `extract_features`) and the fitted scikit-learn pipeline.

    Raises
    ------
    ValueError
        If the loaded feature tensor is not 2-D, or if an id carries a group
        code that has no entry in the module-level `mapping`.

    Notes
    -----
    Ids are parsed as "<label>_<group>_...": the label is 1 when the first
    token is "positive" (case-insensitive) and 0 otherwise; the second token is
    the raw group code, translated through the module-level `mapping`.
    """
    def opt(name):
        # Accept both a plain dict and an attribute-style namespace.
        return cfg[name] if isinstance(cfg,dict) else getattr(cfg,name)

    epochs=opt("epochs")
    n_pp_diff,n_nn_diff,n_pn_diff=opt("n_pp_diff"),opt("n_nn_diff"),opt("n_pn_diff")

    set_seed(seed)
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # map_location="cpu" lets a blob saved on a GPU machine load on a CPU-only
    # one; the features are moved to `device` explicitly below.
    blob=torch.load(opt("pooled_pt"),map_location="cpu",weights_only=True)
    X_all,ids_all=blob["X"],blob["ids"]
    if X_all.dim()!=2:
        raise ValueError(f"expected a 2-D feature tensor, got shape {tuple(X_all.shape)}")

    def parse_label_group(sid):
        parts=sid.split("_")
        return (1 if parts[0].lower()=="positive" else 0, parts[1])

    labels,groups=zip(*(parse_label_group(sid) for sid in ids_all))
    meta_df=pd.DataFrame({"id":ids_all,"label":labels,"group":groups})

    # An unmapped code would turn into NaN and silently corrupt the grouping
    # used for pair construction, so fail loudly instead.
    mapped_group=meta_df["group"].map(mapping)
    unknown=sorted(set(meta_df.loc[mapped_group.isna(),"group"]))
    if unknown:
        raise ValueError(f"group codes missing from mapping: {unknown}")
    meta_df["group"]=mapped_group

    df_train=meta_df
    row_idx=torch.as_tensor(df_train.index.to_numpy(),dtype=torch.long)
    X_train=X_all[row_idx].to(device)
    y_train=torch.tensor(df_train["label"].to_numpy(),dtype=torch.long)
    groups_train=df_train["group"].tolist()

    # Row positions of positives / negatives, keyed by group.
    group_to_pos,group_to_neg=defaultdict(list),defaultdict(list)
    for i,(lab,grp) in enumerate(zip(y_train.numpy(),groups_train)):
        (group_to_pos if lab==1 else group_to_neg)[grp].append(i)

    PN_same_all=build_PN_same_all(group_to_pos,group_to_neg)
    PP_diff_sel=sample_PP_diff_stream(group_to_pos,n_pp_diff,seed)
    NN_diff_sel=sample_NN_diff_stream(group_to_neg,n_nn_diff,seed+1)
    PN_diff_sel=sample_PN_diff_stream(group_to_pos,group_to_neg,n_pn_diff,seed+2)

    def warn_short(name,got,want):
        if got<want:
            print(f"[warn] {name}: requested {want}, got {got}",flush=True)
    warn_short("PP_diff",len(PP_diff_sel),n_pp_diff)
    warn_short("NN_diff", len(NN_diff_sel), n_nn_diff)
    warn_short("PN_diff", len(PN_diff_sel), n_pn_diff)

    # Pair label 1: same class (ContrastiveLoss pulls the embeddings together).
    # Pair label -1: opposite class (pushed apart up to the margin).
    pairs=([(i,j,-1) for (i,j) in PN_same_all]+
           [(i,j,1) for (i,j) in PP_diff_sel]+
           [(i,j,1) for (i,j) in NN_diff_sel]+
           [(i,j,-1) for (i,j) in PN_diff_sel])

    _rng_local=np.random.default_rng(seed+3)
    _rng_local.shuffle(pairs)

    if verbose:
        print(
            f"[seed {seed}] pairs: "
            f"PN_same={len(PN_same_all)}, "
            f"PP_diff={len(PP_diff_sel)}, "
            f"NN_diff={len(NN_diff_sel)}, "
            f"PN_diff={len(PN_diff_sel)}, "
            f"total={len(pairs)}",
            flush=True
        )

    pair_ds=IndexPairDataset(pairs)
    pair_loader=DataLoader(
        pair_ds,
        batch_size=opt("batch_size"),
        shuffle=True,
        num_workers=0,
        # Features already live on `device`; the collate function only gathers rows.
        collate_fn=lambda b: collate_fetch(b,device,X_train),
    )

    encoder=Adapter(
        in_dim=X_train.shape[1],
        hidden=opt("adapter_hidden"),
        middle=opt("adapter_middle"),
        latent=opt("adapter_latent")
    ).to(device)

    optim=torch.optim.AdamW(encoder.parameters(),lr=opt("adapter_lr"))
    criterion=ContrastiveLoss(margin=opt("margin"),scale=opt("loss_scale"))

    for ep in range(epochs):
        # extract_features() is not called inside the loop, but set the mode
        # explicitly every epoch so dropout is always active during training.
        encoder.train()

        total_loss=0.0
        n_batches=0

        for x1,x2,target in pair_loader:
            z1,z2=encoder(x1),encoder(x2)
            loss=criterion(z1,z2,target.long())

            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()

            total_loss+=loss.item()
            n_batches+=1

        epoch_loss=total_loss/max(1,n_batches)

        print(f"    Epoch {ep+1}/{epochs} | ContrastiveLoss = {epoch_loss:.4f}")

    Xz_train=extract_features(encoder,X_train,device,batch_size=128)

    clf=make_pipeline(
        StandardScaler(),
        SVC(kernel="linear",C=1.0,class_weight="balanced",probability=True,random_state=seed)
    )

    clf.fit(Xz_train,y_train.numpy())

    return encoder,clf

In [None]:
# Embeddings and ids of the sequences to be scored.
blob_score=torch.load(CFG.score_pt,weights_only=True)
X_score_raw=blob_score["X"]
ids_score=blob_score["ids"]

# Plain-dict snapshot of the training settings handed to run_train_seed.
base_cfg={
    key: getattr(CFG,key)
    for key in (
        "pooled_pt",
        "epochs",
        "adapter_lr",
        "adapter_hidden",
        "adapter_middle",
        "adapter_latent",
        "batch_size",
        "n_pp_diff",
        "n_nn_diff",
        "n_pn_diff",
        "margin",
        "loss_scale",
    )
}

# seed -> SVM decision-function values for every scored sequence.
all_seed_scores={}

for seed in CFG.seeds:
    # Train a fresh adapter + classifier for this seed.
    encoder,clf=run_train_seed(base_cfg,seed,verbose=CFG.verbose)

    # Embed the scoring set on whatever device the trained encoder lives on.
    device=next(encoder.parameters()).device
    X_score=X_score_raw.to(device)
    Xz_score=extract_features(encoder,X_score,device,batch_size=128)

    scores=clf.decision_function(Xz_score)
    all_seed_scores[seed]=scores

# One row per sequence id, one "seed_<n>" column per seed, in CFG.seeds order.
df_score=(
    pd.DataFrame({"id":list(ids_score)})
    .assign(**{f"seed_{s}":all_seed_scores[s] for s in CFG.seeds})
    .set_index("id")
)
df_score.to_csv(CFG.out_csv)
print(f"\nSaved scores for {len(df_score)} sequences and {len(CFG.seeds)} seeds to: {CFG.out_csv}")