In [2]:
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BertConfig,
    BertForSequenceClassification
)

class LLMEmbeddingGenerator:
    """Extract fixed-size sequence embeddings from locally stored DNA language models.

    Each ``generate_embeddings_*`` method:

    1. reads a tab-separated file of DNA sequences;
    2. runs the sequences through a local Hugging Face checkpoint;
    3. mean-pools the last hidden layer over the non-padding tokens;
    4. writes one CSV row per input row: ``id_col`` followed by
       ``emb_0 .. emb_{H-1}``.

    Args:
        batch_size: Number of sequences sent through the model per forward pass.
        max_len: Maximum number of tokens per sequence; longer inputs are truncated.
        kmer_size: k used to split sequences into overlapping k-mers (DNABERT only).
    """

    def __init__(self, batch_size=8, max_len=512, kmer_size=6):
        self.batch_size = batch_size
        self.max_len = max_len
        self.kmer_size = kmer_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _seq_to_kmers(seq, k):
        """Return ``seq`` (stripped, upper-cased) as space-separated overlapping k-mers.

        Sequences shorter than ``k`` are returned as a single normalised token.
        """
        seq = str(seq).strip().upper()
        if len(seq) < k:
            return seq
        return " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))

    @staticmethod
    def _masked_mean_pool(hidden, attention_mask):
        """Average ``hidden`` over the token axis, counting only unmasked positions.

        Args:
            hidden: Tensor of shape (batch, seq_len, dim).
            attention_mask: Tensor of shape (batch, seq_len); non-zero marks real tokens.

        Returns:
            Tensor of shape (batch, dim). Rows whose mask is all zero come out as zeros.
        """
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp avoids 0/0 for a row that is entirely padding
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def _load_sequences(self, input_tsv, id_col, seq_col):
        """Read ``input_tsv`` and validate the id / sequence columns before any inference.

        Raises:
            KeyError: ``id_col`` or ``seq_col`` is not a column of the file.
            ValueError: at least one row has an empty/missing sequence.
        """
        df = pd.read_csv(input_tsv, sep="\t", dtype={seq_col: str})
        missing_cols = [c for c in (id_col, seq_col) if c not in df.columns]
        if missing_cols:
            # Fail fast: previously a missing id column only surfaced after the full model run.
            raise KeyError(f"{input_tsv} is missing column(s): {missing_cols}")
        n_missing = int(df[seq_col].isna().sum())
        if n_missing:
            # NaN would otherwise be tokenised as the literal text "nan"/"NAN".
            raise ValueError(f"{input_tsv}: {n_missing} row(s) have an empty '{seq_col}' value")
        print(f"Loaded {len(df)} sequences")
        return df

    def _embed(self, model, tokenizer, sequences):
        """Return a (len(sequences), hidden_size) array of masked-mean-pooled embeddings."""
        model.to(self.device).eval()
        batches = []
        for start in tqdm(range(0, len(sequences), self.batch_size)):
            batch = sequences[start:start + self.batch_size]
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                # Pad only to the longest sequence in the batch; padded positions are
                # excluded from the pooling, so the result does not depend on pad length.
                padding=True,
                truncation=True,
                max_length=self.max_len,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            mask = inputs.get("attention_mask")
            if mask is None:
                mask = torch.ones_like(inputs["input_ids"])
            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True, return_dict=True)
                pooled = self._masked_mean_pool(outputs.hidden_states[-1], mask)
            batches.append(pooled.float().cpu().numpy())
        if not batches:
            # Empty input: np.vstack([]) would raise, so return a correctly shaped empty array.
            return np.empty((0, model.config.hidden_size), dtype=np.float32)
        return np.vstack(batches)

    def _embed_and_save(self, df, sequences, tokenizer, model, id_col, output_csv):
        """Embed ``sequences`` (aligned with ``df`` rows) and write them to ``output_csv``."""
        emb_tensor = self._embed(model, tokenizer, sequences)
        print("Embedding shape:", emb_tensor.shape)
        self._save_embeddings(df, emb_tensor, id_col, output_csv)

    def _save_embeddings(self, df, emb_tensor, id_col, output_csv):
        """Write ``emb_tensor`` as columns ``emb_0..`` prefixed by ``df[id_col]`` to a CSV."""
        emb_cols = [f"emb_{i}" for i in range(emb_tensor.shape[1])]
        emb_df = pd.DataFrame(emb_tensor, columns=emb_cols)
        emb_df.insert(0, id_col, df[id_col].values)
        emb_df.to_csv(output_csv, index=False)
        print(f"Saved embeddings to {output_csv}")

    # ------------------------------------------------------------- public API

    def generate_embeddings_NT(self, model_dir, input_tsv, output_csv, id_col="VariationID", seq_col="Sequence"):
        """Embed raw sequences with a Nucleotide Transformer checkpoint and save them to CSV."""
        print(f"=== [Nucleotide Transformer] {input_tsv} → {output_csv}")
        df = self._load_sequences(input_tsv, id_col, seq_col)
        tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
        model = AutoModelForSequenceClassification.from_pretrained(model_dir, local_files_only=True)
        self._embed_and_save(df, df[seq_col].tolist(), tokenizer, model, id_col, output_csv)

    def generate_embeddings_DNABERT_6(self, model_dir, input_tsv, output_csv, id_col="VariationID", seq_col="Sequence"):
        """Embed sequences with a DNABERT checkpoint after splitting them into k-mers."""
        print(f"=== [DNABERT-6] {input_tsv} → {output_csv}")
        df = self._load_sequences(input_tsv, id_col, seq_col)
        # DNABERT's tokenizer expects space-separated overlapping k-mers, not raw bases.
        kmers = [self._seq_to_kmers(s, self.kmer_size) for s in df[seq_col]]

        config = BertConfig.from_json_file(f"{model_dir}/config.json")
        tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
        model = BertForSequenceClassification.from_pretrained(model_dir, config=config, local_files_only=True)
        self._embed_and_save(df, kmers, tokenizer, model, id_col, output_csv)

    def generate_embeddings_GROVER(self, model_dir, input_tsv, output_csv, id_col="VariationID", seq_col="Sequence"):
        """Embed raw sequences with a GROVER checkpoint and save them to CSV."""
        print(f"=== [GROVER] {input_tsv} → {output_csv}")
        df = self._load_sequences(input_tsv, id_col, seq_col)
        tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
        model = AutoModelForSequenceClassification.from_pretrained(model_dir, local_files_only=True)
        self._embed_and_save(df, df[seq_col].tolist(), tokenizer, model, id_col, output_csv)

In [4]:
# --------------------------------------------------------------
# Shared embedding generator used by every extraction cell below
# --------------------------------------------------------------
generator = LLMEmbeddingGenerator(
    kmer_size=6,
    max_len=512,
    batch_size=8,
)

In [11]:
# Nucleotide Transformer — training split
generator.generate_embeddings_NT(
    input_tsv="./data/windows_225/clinvar_binary_train_225.tsv",
    output_csv="./data/embeddings/clinvar_binary_train_embeddings_NT_225.csv",
    model_dir="./finetuned_models/nucleotide_transformer_pathogenic_classifier_225",
)

=== [Nucleotide Transformer] ./data/windows_225/clinvar_binary_train_225.tsv → ./data/embeddings/clinvar_binary_train_embeddings_NT_225.csv
Loaded 30000 sequences


  return torch.load(checkpoint_file, map_location="cpu")
100%|██████████| 3750/3750 [23:53<00:00,  2.62it/s]


Embedding shape: (30000, 1280)
Saved embeddings to ./data/embeddings/clinvar_binary_train_embeddings_NT_225.csv


In [12]:
# Nucleotide Transformer — test split
generator.generate_embeddings_NT(
    input_tsv="./data/windows_225/clinvar_binary_test_225.tsv",
    output_csv="./data/embeddings/clinvar_binary_test_embeddings_NT_225.csv",
    model_dir="./finetuned_models/nucleotide_transformer_pathogenic_classifier_225",
)

=== [Nucleotide Transformer] ./data/windows_225/clinvar_binary_test_225.tsv → ./data/embeddings/clinvar_binary_test_embeddings_NT_225.csv
Loaded 3000 sequences


  return torch.load(checkpoint_file, map_location="cpu")
100%|██████████| 375/375 [02:23<00:00,  2.62it/s]


Embedding shape: (3000, 1280)
Saved embeddings to ./data/embeddings/clinvar_binary_test_embeddings_NT_225.csv


In [7]:
# GROVER — training split
generator.generate_embeddings_GROVER(
    input_tsv="./data/windows_225/clinvar_binary_train_225.tsv",
    output_csv="./data/embeddings/clinvar_binary_train_embeddings_GROVER_225.csv",
    model_dir="./finetuned_models/grover_pathogenic_classifier_225",
)

=== [GROVER] ./data/windows_225/clinvar_binary_train_225.tsv → ./data/embeddings/clinvar_binary_train_embeddings_GROVER_225.csv
Loaded 30000 sequences


  return torch.load(checkpoint_file, map_location="cpu")
100%|██████████| 3750/3750 [05:12<00:00, 12.01it/s]


Embedding shape: (30000, 768)
Saved embeddings to ./data/embeddings/clinvar_binary_train_embeddings_GROVER_225.csv


In [8]:
# GROVER — test split
generator.generate_embeddings_GROVER(
    input_tsv="./data/windows_225/clinvar_binary_test_225.tsv",
    output_csv="./data/embeddings/clinvar_binary_test_embeddings_GROVER_225.csv",
    model_dir="./finetuned_models/grover_pathogenic_classifier_225",
)


=== [GROVER] ./data/windows_225/clinvar_binary_test_225.tsv → ./data/embeddings/clinvar_binary_test_embeddings_GROVER_225.csv
Loaded 3000 sequences


  return torch.load(checkpoint_file, map_location="cpu")
100%|██████████| 375/375 [00:31<00:00, 11.75it/s]


Embedding shape: (3000, 768)
Saved embeddings to ./data/embeddings/clinvar_binary_test_embeddings_GROVER_225.csv


In [9]:
# DNABERT-6 — training split (sequences are converted to 6-mers inside the method)
generator.generate_embeddings_DNABERT_6(
    input_tsv="./data/windows_225/clinvar_binary_train_225.tsv",
    output_csv="./data/embeddings/clinvar_binary_train_embeddings_DNABERT6_225.csv",
    model_dir="./finetuned_models/dnabert6_pathogenic_classifier_225",
)

=== [DNABERT-6] ./data/windows_225/clinvar_binary_train_225.tsv → ./data/embeddings/clinvar_binary_train_embeddings_DNABERT6_225.csv
Loaded 30000 sequences


  return torch.load(checkpoint_file, map_location="cpu")
100%|██████████| 3750/3750 [05:25<00:00, 11.53it/s]


Embedding shape: (30000, 768)
Saved embeddings to ./data/embeddings/clinvar_binary_train_embeddings_DNABERT6_225.csv


In [7]:
# DNABERT-6 — test split (sequences are converted to 6-mers inside the method)
generator.generate_embeddings_DNABERT_6(
    input_tsv="./data/windows_225/clinvar_binary_test_225.tsv",
    output_csv="./data/embeddings/clinvar_binary_test_embeddings_DNABERT6_225.csv",
    model_dir="./finetuned_models/dnabert6_pathogenic_classifier_225",
)

=== [DNABERT-6] ./data/windows_225/clinvar_binary_test_225.tsv → ./data/embeddings/clinvar_binary_test_embeddings_DNABERT6_225.csv
Loaded 3000 sequences


  return torch.load(checkpoint_file, map_location="cpu")
100%|██████████| 375/375 [00:32<00:00, 11.52it/s]


Embedding shape: (3000, 768)
Saved embeddings to ./data/embeddings/clinvar_binary_test_embeddings_DNABERT6_225.csv
