In [14]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
from peft import LoraConfig, get_peft_model

from transformers.models.esm.tokenization_esm import EsmTokenizer

import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

import pandas as pd

import random
import numpy as np

In [None]:
def set_seed(seed: int=7) -> None:
    """Seed the random number generators used in this notebook.

    Passes ``seed`` to Python's ``random``, NumPy's global generator,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``, then sets
    ``torch.backends.cudnn.deterministic`` to True and
    ``torch.backends.cudnn.benchmark`` to False.

    Args:
        seed: Value handed to every seeding function. Defaults to 7.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # The two cuDNN flags are independent of each other and of the seeding above.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Seed once, right after the definition, so every later cell starts from it.
set_seed(seed=7)

In [3]:
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# The file name is appended directly to DATA_PATH, so the trailing slash matters.
DATA_PATH = "./Data/"
df = pd.read_csv(f"{DATA_PATH}fine_tuning.csv")

# Pull the three columns out of the frame as standalone Series.
ref_seq, var_seq, label = (
    df[column] for column in ("reference_seq", "variant_seq", "label")
)

In [4]:
# Longest string length (in characters) found in either sequence column.
max_seq_len = max(column.str.len().max() for column in (ref_seq, var_seq))
print(f"Rows = {len(df):,}, Max Sequence Length = {max_seq_len}")

Rows = 397,182, Max Sequence Length = 512


In [5]:
# Number of sequence pairs per batch.
BATCH_SIZE = 64

MODEL_ID = "InstaDeepAI/nucleotide-transformer-v2-100m-multi-species"
# `trust_remote_code=True` executes Python code downloaded from the Hub
# repository. Pin the exact commit so that code (and the weights) cannot change
# between runs. This is the commit hash that appears in the
# `transformers_modules...` path printed by `type(model)` in this notebook.
MODEL_REVISION = "f34324c6fde36a4f635f0f1f06cac5d25acd6798"

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    revision=MODEL_REVISION,
    trust_remote_code=True,
)
model = AutoModelForMaskedLM.from_pretrained(
    MODEL_ID,
    revision=MODEL_REVISION,
    trust_remote_code=True,
)

# Maximum input length advertised by the tokenizer.
MODEL_CAP = tokenizer.model_max_length
# Never ask the tokenizer for more than the model supports or the data needs.
# NOTE(review): `max_seq_len` is measured in characters while the tokenizer's
# `max_length` counts tokens, so this is presumably meant only as a safe upper
# bound -- confirm.
MAX_LEN = min(MODEL_CAP, max_seq_len)

In [15]:
# Display the concrete class of the loaded `model` object.
type(model)

transformers_modules.InstaDeepAI.nucleotide-transformer-v2-100m-multi-species.f34324c6fde36a4f635f0f1f06cac5d25acd6798.modeling_esm.EsmForMaskedLM

In [None]:
class SiameseDataset(Dataset):
    """Dataset of (reference, variant) sequence pairs taken from a data frame.

    Items are returned exactly as stored in the frame; this class performs no
    tokenization itself.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: EsmTokenizer, max_len: int) -> None:
        """Copy the two sequence columns of ``df`` into Python lists.

        Args:
            df: Frame providing the ``reference_seq`` and ``variant_seq`` columns.
            tokenizer: Kept on ``self.tok``; not used by this class's methods.
            max_len: Kept on ``self.max_len``; not used by this class's methods.
        """
        self.tok = tokenizer
        self.max_len = max_len
        self.ref_seq = df["reference_seq"].to_list()
        self.var_seq = df["variant_seq"].to_list()

    def __len__(self) -> int:
        """Return the number of stored pairs."""
        pair_count = len(self.ref_seq)
        return pair_count

    def __getitem__(self, idx: int) -> dict:
        """Return the pair at ``idx`` under the keys ``ref_seq`` and ``var_seq``."""
        reference, variant = self.ref_seq[idx], self.var_seq[idx]
        return dict(ref_seq=reference, var_seq=variant)
    
def collate_fn(batch: list, tok: EsmTokenizer=tokenizer, max_len: int=MAX_LEN) -> dict:
    """Tokenize a batch of sequence pairs into padded tensors.

    The reference and variant sequences are tokenized in two separate calls,
    each padded to the longest sequence of its own side, so the two sides may
    come back with different sequence lengths.

    Args:
        batch: List of items, each a mapping with ``ref_seq`` and ``var_seq``
            entries.
        tok: Tokenizer used for both sides. The default is bound when this
            cell runs, so re-run the cell after replacing ``tokenizer``.
        max_len: Truncation limit passed to the tokenizer as ``max_length``.
            The default is likewise bound when this cell runs.

    Returns:
        Dict with ``ref_input_ids``, ``ref_attention_mask``,
        ``var_input_ids`` and ``var_attention_mask`` tensors.
    """
    def _encode(sequences: list):
        # Call the tokenizer directly: `batch_encode_plus` is deprecated in
        # favour of `__call__`, which handles a list of strings the same way.
        return tok(
            sequences,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=max_len,
        )

    ref_enc = _encode([item["ref_seq"] for item in batch])
    var_enc = _encode([item["var_seq"] for item in batch])

    return {
        "ref_input_ids": ref_enc["input_ids"],
        "ref_attention_mask": ref_enc["attention_mask"],
        "var_input_ids": var_enc["input_ids"],
        "var_attention_mask": var_enc["attention_mask"],
    }

In [8]:
# Wrap the data frame so each item is one (reference, variant) pair.
dataset = SiameseDataset(df=df, tokenizer=tokenizer, max_len=MAX_LEN)

# Batches are shuffled and turned into tensors by `collate_fn`.
loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

In [11]:
# Display the `hidden_size` value stored in the model's config.
model.config.hidden_size

512

In [16]:
# LoRA settings: rank-8 adapters with scaling alpha 16 and 10% adapter dropout,
# attached to the modules named "query", "key" and "value" (presumably the
# attention projections of the backbone -- confirm against its module names).
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["query", "key", "value"],
    lora_dropout=0.1,
    bias="none",
    # The backbone is a masked-LM encoder that is only used here to produce
    # hidden states for pooling, not for autoregressive generation, so the
    # feature-extraction task type is the matching PEFT wrapper.
    task_type="FEATURE_EXTRACTION",
)

# NOTE: this rebinds `model` to the PEFT-wrapped model; re-running the cell
# without reloading the base model would wrap an already-wrapped model.
model = get_peft_model(model, lora_config)

In [None]:
class BackboneModel(nn.Module):
    """Encode a tokenized sequence into one fixed-size embedding.

    The backbone's last hidden states are averaged over the positions where
    ``attention_mask`` is non-zero, and the result is passed through a linear
    layer of output size ``reconstruction_dim``.
    """

    def __init__(self, backbone, reconstruction_dim: int=2048) -> None:
        """Store the backbone and build the output projection.

        Args:
            backbone: Callable model exposing ``config.hidden_size`` and
                returning an object with ``hidden_states`` when called with
                ``output_hidden_states=True``.
            reconstruction_dim: Output size of the linear projection.
        """
        super().__init__()
        self.backbone = backbone
        hidden_size = backbone.config.hidden_size
        self.reconstruction_layer = nn.Linear(hidden_size, reconstruction_dim)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the projected, mask-aware mean of the last hidden states.

        Args:
            input_ids: Token ids; assumed to be (batch, seq_len) -- TODO confirm.
            attention_mask: Non-zero at positions to include in the mean;
                same leading shape as ``input_ids``.

        Returns:
            Tensor whose last dimension is ``reconstruction_dim``.
        """
        outs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )

        last_hidden = outs.hidden_states[-1]
        # Add a trailing axis so the mask broadcasts across the hidden dimension.
        mask_exp = attention_mask.unsqueeze(-1)

        summed = (last_hidden * mask_exp).sum(dim=1)
        # clamp(min=1) keeps a fully masked row from dividing by zero.
        # (The original called the non-existent `.clam`, which raised
        # AttributeError on every forward pass.)
        counts = mask_exp.sum(dim=1).clamp(min=1)
        seq_emb = summed / counts

        seq_emb = self.reconstruction_layer(seq_emb)

        return seq_emb

In [None]:
class SiameseModel(nn.Module):
    """Run a reference and a variant input through one shared encoder."""

    def __init__(self, backbone, reconstruction_dim: int=2048) -> None:
        """Build the single ``BackboneModel`` used for both inputs.

        Args:
            backbone: Forwarded to ``BackboneModel``.
            reconstruction_dim: Forwarded to ``BackboneModel``.
        """
        super().__init__()
        self.encoder = BackboneModel(backbone, reconstruction_dim)

    def forward(self, ref_input_ids, ref_attention_mask,
                      var_input_ids, var_attention_mask):
        """Encode both inputs and return ``(ref_emb, var_emb)``.

        The reference input is encoded first, then the variant, both with
        ``self.encoder``.
        """
        reference_embedding = self.encoder(
            input_ids=ref_input_ids,
            attention_mask=ref_attention_mask,
        )
        variant_embedding = self.encoder(
            input_ids=var_input_ids,
            attention_mask=var_attention_mask,
        )
        return reference_embedding, variant_embedding