In [34]:
# === IMPORTS === #
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight
import joblib
from sentence_transformers import SentenceTransformer
import os

In [17]:
# === DEVICE === #
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [18]:
# === LOAD DATASET === #
df = pd.read_parquet("clinvar_dmd_embeddings.parquet")

In [19]:
# === SAVE LABEL ENCODERS FOR STRUCTURED CATEGORICAL === #
for col in ["Variant type"]:
    le = LabelEncoder()
    df[col] = df[col].fillna("unknown").astype(str)
    df[col] = le.fit_transform(df[col])
    joblib.dump(le, f"label_encoder_{col.replace(' ', '_')}.pkl")

In [20]:
# === ENCODE TARGET === #
le_target = LabelEncoder()
df["Germline classification"] = df["Germline classification"].astype(str)
df["target"] = le_target.fit_transform(df["Germline classification"])
joblib.dump(le_target, "label_encoder_germline.pkl")

['label_encoder_germline.pkl']

In [21]:
# === EMBEDDINGS === #
llm_features = [col for col in df.columns if "_llm_" in col]
X_llm = df[llm_features].values.astype(np.float32)

In [22]:
# === STRUCTURED FEATURES === #
structured_cols = ['Start_Pos', 'End_Pos', 'Variant_Length', 'has_fs', 'has_stop', 'n_protein_changes', 'Variant type']
X_struct = df[structured_cols].fillna(0).values.astype(np.float32)

In [23]:
# === LIST-BASED FEATURES === #
from collections import defaultdict

def fit_label_encoder_for_lists(column):
    all_items = set(item for sublist in df[column] for item in sublist)
    le = LabelEncoder()
    le.fit(list(all_items) + ["unknown"])
    joblib.dump(le, f"label_encoder_{column}.pkl")
    return le

def list_to_index_tensor(list_col, le, max_len):
    index_seqs = [torch.tensor([le.transform([item])[0] if item in le.classes_ else le.transform(["unknown"])[0] for item in row]) for row in list_col]
    padded = pad_sequence(index_seqs, batch_first=True, padding_value=0)
    return padded[:, :max_len]

# Fit encoders
le_gene = fit_label_encoder_for_lists("Gene_list")
le_cond = fit_label_encoder_for_lists("Condition_list")
le_cons = fit_label_encoder_for_lists("Consequence_list")
le_allele = fit_label_encoder_for_lists("Allele_list")

# Convert to tensors
gene_tensor = list_to_index_tensor(df['Gene_list'], le_gene, 10)
cond_tensor = list_to_index_tensor(df['Condition_list'], le_cond, 5)
cons_tensor = list_to_index_tensor(df['Consequence_list'], le_cons, 5)
allele_tensor = list_to_index_tensor(df['Allele_list'], le_allele, 10)

target = df['target'].values

In [35]:
# === COMPUTE CLASS WEIGHTS === #
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(target),
    y=target
)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)


In [36]:
# === DEFINE MODEL === #
class DMDMultiInputModel(nn.Module):
    def __init__(self,
                 llm_dim=768*4,
                 structured_dim=7,
                 gene_vocab_size=2640,
                 cond_vocab_size=80,
                 cons_vocab_size=20,
                 allele_vocab_size=10325,
                 output_dim=8):
        super().__init__()

        self.gene_embed = nn.Embedding(gene_vocab_size, 64, padding_idx=0)
        self.cond_embed = nn.Embedding(cond_vocab_size, 32, padding_idx=0)
        self.cons_embed = nn.Embedding(cons_vocab_size, 16, padding_idx=0)
        self.allele_embed = nn.Embedding(allele_vocab_size, 64, padding_idx=0)

        self.llm_fc = nn.Linear(llm_dim, 256)
        self.struct_fc = nn.Linear(structured_dim, 64)

        self.final_fc = nn.Sequential(
            nn.Linear(256 + 64 + 64 + 32 + 16 + 64, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, output_dim)
        )

    def forward(self, llm_x, struct_x, gene_ids, cond_ids, cons_ids, allele_ids):
        gene_vec = self.gene_embed(gene_ids).mean(dim=1)
        cond_vec = self.cond_embed(cond_ids).mean(dim=1)
        cons_vec = self.cons_embed(cons_ids).mean(dim=1)
        allele_vec = self.allele_embed(allele_ids).mean(dim=1)

        llm_feat = F.relu(self.llm_fc(llm_x))
        struct_feat = F.relu(self.struct_fc(struct_x))

        x = torch.cat([llm_feat, struct_feat, gene_vec, cond_vec, cons_vec, allele_vec], dim=1)
        return self.final_fc(x)

In [39]:
# === TRAIN MODEL === #
idx_train, idx_test = train_test_split(
    np.arange(len(df)), test_size=0.2, stratify=target, random_state=42)

model = DMDMultiInputModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)

EPOCHS = 15
BATCH_SIZE = 64

for epoch in range(EPOCHS):
    model.train()
    permutation = torch.randperm(len(idx_train))
    total_loss = 0

    for i in range(0, len(permutation), BATCH_SIZE):
        batch_indices = permutation[i:i + BATCH_SIZE]
        batch_ids = idx_train[batch_indices]  # this gives real data row indices

        batch_llm = torch.tensor(X_llm[batch_ids]).to(device)
        batch_struct = torch.tensor(X_struct[batch_ids]).to(device)
        batch_gene = gene_tensor[batch_ids].to(device)
        batch_cond = cond_tensor[batch_ids].to(device)
        batch_cons = cons_tensor[batch_ids].to(device)
        batch_allele = allele_tensor[batch_ids].to(device)
        batch_target = torch.tensor(target[batch_ids]).to(device)

        optimizer.zero_grad()
        output = model(batch_llm, batch_struct, batch_gene, batch_cond, batch_cons, batch_allele)
        loss = loss_fn(output, batch_target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {total_loss:.4f}")
    if epoch == 0:
        print("Sample target batch:", batch_target[:10])
        print("Sample output:", output[:2].cpu().detach().numpy())


Epoch 1/15, Loss: 204.5886
Sample target batch: tensor([7, 7])
Sample output: [[-1.665524    1.2894893   1.012018    0.04047731 -1.6004074  -3.1071808
  -2.14112     2.0404668 ]
 [-1.6721618   0.24505603  0.18369399 -0.46111047 -1.5689144  -1.1376799
  -1.6126863   2.0893607 ]]
Epoch 2/15, Loss: 153.6431
Epoch 3/15, Loss: 130.0872
Epoch 4/15, Loss: 116.7089
Epoch 5/15, Loss: 104.1498
Epoch 6/15, Loss: 96.8221
Epoch 7/15, Loss: 87.8316
Epoch 8/15, Loss: 79.9915
Epoch 9/15, Loss: 70.7795
Epoch 10/15, Loss: 60.2807
Epoch 11/15, Loss: 51.7645
Epoch 12/15, Loss: 46.0761
Epoch 13/15, Loss: 38.8717
Epoch 14/15, Loss: 30.5196
Epoch 15/15, Loss: 26.3582


In [40]:
# === EVALUATE === #
from sklearn.utils.multiclass import unique_labels
model.eval()
with torch.no_grad():
    batch_llm = torch.tensor(X_llm[idx_test]).to(device)
    batch_struct = torch.tensor(X_struct[idx_test]).to(device)
    batch_gene = gene_tensor[idx_test].to(device)
    batch_cond = cond_tensor[idx_test].to(device)
    batch_cons = cons_tensor[idx_test].to(device)
    batch_allele = allele_tensor[idx_test].to(device)
    batch_target = torch.tensor(target[idx_test]).to(device)

    output = model(batch_llm, batch_struct, batch_gene, batch_cond, batch_cons, batch_allele)
    pred = torch.argmax(output, dim=1).cpu().numpy()
    y_true = batch_target.cpu().numpy()
# Compute all expected labels
all_labels = list(range(len(le_target.classes_)))

# Safe classification report
print(classification_report(
    y_true, pred,
    labels=all_labels,
    target_names=le_target.classes_,
    zero_division=0
))

                                              precision    recall  f1-score   support

                                      Benign       0.38      0.54      0.45        96
                        Benign/Likely benign       0.15      0.29      0.20        31
Conflicting classifications of pathogenicity       0.48      0.59      0.53       128
                               Likely benign       0.79      0.78      0.78       611
                           Likely pathogenic       0.22      0.42      0.29       101
                                  Pathogenic       0.81      0.73      0.77       542
                Pathogenic/Likely pathogenic       0.10      0.22      0.13        23
                      Uncertain significance       0.85      0.62      0.72       533

                                    accuracy                           0.67      2065
                                   macro avg       0.47      0.52      0.48      2065
                                weighted avg       0

In [42]:
# === SAVE MODEL === #
torch.save(model.state_dict(), "latest_dmd_model.pt")

In [32]:
# === LOAD & DEMO PREDICTION === #
# Load encoders
le_target = joblib.load("label_encoder_germline.pkl")
le_gene = joblib.load("label_encoder_Gene_list.pkl")
le_cond = joblib.load("label_encoder_Condition_list.pkl")
le_cons = joblib.load("label_encoder_Consequence_list.pkl")
le_allele = joblib.load("label_encoder_Allele_list.pkl")
le_vtype = joblib.load("label_encoder_Variant_type.pkl")
# Load model
model = DMDMultiInputModel().to(device)
model.load_state_dict(torch.load("best_dmd_model.pt"))
model.eval()

# Example sample for testing
def prepare_single_sample(sample_dict):
    llm_feat = torch.tensor(sample_dict['llm'], dtype=torch.float).unsqueeze(0).to(device)

    struct_feat = torch.tensor(sample_dict['structured'], dtype=torch.float).unsqueeze(0).to(device)
    gene_ids = list_to_index_tensor([sample_dict['Gene_list']], le_gene, 10).to(device)
    cond_ids = list_to_index_tensor([sample_dict['Condition_list']], le_cond, 5).to(device)
    cons_ids = list_to_index_tensor([sample_dict['Consequence_list']], le_cons, 5).to(device)
    allele_ids = list_to_index_tensor([sample_dict['Allele_list']], le_allele, 10).to(device)
    return llm_feat, struct_feat, gene_ids, cond_ids, cons_ids, allele_ids

# Sample prediction input
sample_input = {
    'llm': X_llm[0],
    'structured': X_struct[0],
    'Gene_list': df['Gene_list'].iloc[0],
    'Condition_list': df['Condition_list'].iloc[0],
    'Consequence_list': df['Consequence_list'].iloc[0],
    'Allele_list': df['Allele_list'].iloc[0]
}
print(sample_input)
llm_x, struct_x, gene_ids, cond_ids, cons_ids, allele_ids = prepare_single_sample(sample_input)
with torch.no_grad():
    pred = model(llm_x, struct_x, gene_ids, cond_ids, cons_ids, allele_ids)
    label = le_target.inverse_transform([torch.argmax(pred).item()])[0]
    print("Predicted Label:", label)


{'llm': array([ 0.24032561,  0.2860818 ,  0.23787445, ..., -0.47412282,
       -0.12398335, -0.52267736], dtype=float32), 'structured': array([0., 0., 0., 0., 0., 1., 2.], dtype=float32), 'Gene_list': array(['LOC130068431', 'LOC130068432', 'LOC130068480', ..., 'LINC01281',
       'LINC01282', 'LINC01283'], dtype=object), 'Condition_list': array(['Autism', 'Schizophrenia'], dtype=object), 'Consequence_list': array(['missense variant'], dtype=object), 'Allele_list': array(['481029'], dtype=object)}
Predicted Label: Pathogenic
