## Sequence Generation

We use BA.2.1 as the parent node for protein language model (PLM) fine-tuning and sequence generation. For a quick demo, we provide a sampled initial set of 10,000 sequences.

In [None]:
!git clone https://github.com/Kevinatil/GenPreMut.git

In [None]:
!pip install datasets

In [None]:
import os
import re
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset as Dataset_

In [None]:
# collator from sequence_generate/sequence_generate/collator.py
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from collections.abc import Mapping

import numpy as np
import torch

from transformers.data.data_collator import PreTrainedTokenizerBase


@dataclass
class DataCollatorForMaskedGeneration:
    """Collate tokenized copies of a parent sequence into masked generation batches.

    Unlike ``DataCollatorForLanguageModeling`` this collator builds no MLM
    labels. It only replaces tokens with the mask token, sampling each
    position independently with probability ``mlm_probability``. It keeps the
    rows whose number of masks lies in ``[1, max_mask)``. The returned batch
    contains:

    * ``input_ids`` -- the kept, masked rows (model input);
    * ``token_ids`` -- a separate copy of the same masked rows. Callers fill
      the mask positions of this copy with predicted tokens;
    * every other key returned by ``tokenizer.pad`` (e.g. ``attention_mask``),
      restricted to the same kept rows.

    Every returned tensor is moved to ``device``.
    """

    # Quoted so the dataclass can be created even where transformers is not
    # importable; the annotation is not used at runtime.
    tokenizer: "PreTrainedTokenizerBase"
    # Either a scalar masking probability shared by all positions, or a 1-D
    # per-position probability vector whose length equals the padded row length.
    mlm_probability: Any = 0.15
    # Exclusive upper bound on the number of masked tokens in a kept row.
    max_mask: int = 5
    device: Any = 'cuda'
    pad_to_multiple_of: Optional[int] = None
    # Not read by this class; kept for interface compatibility.
    tf_experimental_compile: bool = False
    # Not read by this class (torch tensors are always produced).
    return_tensors: str = "pt"

    def __call__(self, features):
        return self.torch_call(features)

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Pad ``examples``, mask them and keep only the acceptable rows.

        All keys are indexed with the same boolean row selector, so
        ``input_ids``, ``token_ids`` and ``attention_mask`` stay row-aligned.
        """
        batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)

        special_tokens_mask = batch.pop("special_tokens_mask", None)
        masked, keep = self._mask_batch(batch["input_ids"], special_tokens_mask)
        batch["input_ids"] = masked
        for key, value in list(batch.items()):
            batch[key] = value[keep].to(self.device)
        # Same masked rows as input_ids, in separate storage, so the caller can
        # write predictions into it without touching the model inputs.
        batch["token_ids"] = batch["input_ids"].clone()
        return batch

    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """Mask ``inputs`` and return ``(kept_masked_rows, number_of_kept_rows)``.

        Each non-special position is replaced by the mask token with its
        probability from ``mlm_probability``. Rows holding at least 1 and fewer
        than ``max_mask`` masks are kept. No random or original-token
        replacement is performed. ``inputs`` itself is not modified.
        """
        masked, keep = self._mask_batch(inputs, special_tokens_mask)
        return masked[keep], int(keep.sum().item())

    def _mask_batch(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """Resample masks until some row is acceptable; return ``(masked_copy, keep)``.

        ``masked_copy`` has the same shape as ``inputs`` with the sampled
        positions set to the mask token. ``keep`` is a boolean row selector.
        Every attempt starts from the untouched ``inputs``, so masks from a
        rejected attempt never leak into the next one.

        Raises:
            ValueError: if no row can ever be accepted because ``max_mask <= 1``
                or every masking probability is zero.
        """
        if special_tokens_mask is None:
            special_tokens_mask = torch.tensor(
                [self.tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True) for row in inputs.tolist()],
                dtype=torch.bool,
            )
        else:
            special_tokens_mask = special_tokens_mask.bool()

        # isinstance also accepts numpy float scalars (np.float64 subclasses float).
        if isinstance(self.mlm_probability, (int, float)):
            probability_matrix = torch.full(inputs.shape, float(self.mlm_probability))
        else:
            probability_matrix = torch.as_tensor(self.mlm_probability).repeat(inputs.shape[0], 1)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        if self.max_mask <= 1 or not bool((probability_matrix > 0).any()):
            raise ValueError("no row can receive between 1 and max_mask - 1 masks")

        mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        while True:
            masked_indices = torch.bernoulli(probability_matrix).bool()
            counts = masked_indices.sum(dim=1)
            keep = (counts >= 1) & (counts < self.max_mask)
            if keep.any():
                masked = inputs.clone()
                masked[masked_indices] = mask_token_id
                return masked, keep


In [None]:
# Parent lineage used for fine-tuning and generation.
rbd_name = 'BA.2.1'

# Locations inside the cloned GenPreMut repository.
data_root = 'GenPreMut/data'
model_root = 'GenPreMut/ckpt'

# Trainer output directory and the initial training sequences for the parent.
save_folder = os.path.join(model_root, "finetune")
init_data_path = os.path.join(data_root, f"finetune/sample_{rbd_name}.txt")
batch_size = 2

In [None]:
# model finetuning
def get_training_sequences(path):
    """Read training sequences from *path*, one sequence per line.

    Surrounding whitespace is stripped. Blank lines are skipped so they do not
    become empty training examples. The file is closed even if reading fails.
    """
    with open(path, 'r') as f:
        return [line.strip() for line in f if line.strip()]


# Pretrained ESM-2 (650M) checkpoint that is fine-tuned on the parent's sequences.
esm_checkpoint = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(esm_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(esm_checkpoint)
# Standard random-masking MLM collator for fine-tuning.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

train_sequences = get_training_sequences(init_data_path)
train_tokenized = tokenizer(train_sequences)
train_dataset = Dataset_.from_dict(train_tokenized)

# One epoch; a checkpoint is written to save_folder at the end of each epoch.
train_args = TrainingArguments(
    output_dir=save_folder,
    report_to="none",
    save_strategy="epoch",
    num_train_epochs=1,
    learning_rate=1e-4,
    warmup_steps=1000,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()

In [None]:
# Parent RBD amino-acid sequences keyed by lineage name. The generation cell
# looks up rbd_dict[rbd_name], and _check compares generated sequences with an
# entry here position by position.
# NOTE(review): index 0 appears to correspond to spike residue 331 (e.g. the
# 'D' of BA.2.1's G339D sits at index 8) — confirm before relying on it.
rbd_dict={
    'BA.2.1':   'NITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVIRGNEVSQIAPGQTGNIADYNYKLPDDFTGCVIAWNSNKLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGNKPCNGVAGFNCYFPLRSYGFRPTYGVGHQPYRVVVLSFELLHAPATVCGPKKST',
    'BA.5.1':   'NITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVIRGNEVSQIAPGQTGNIADYNYKLPDDFTGCVIAWNSNKLDSKVGGNYNYRYRLFRKSNLKPFERDISTEIYQAGNKPCNGVAGVNCYFPLQSYGFRPTYGVGHQPYRVVVLSFELLHAPATVCGPKKST',
    'XBB.1.5':  'NITNLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVIRGNEVSQIAPGQTGNIADYNYKLPDDFTGCVIAWNSNKLDSKPSGNYNYLYRLFRKSKLKPFERDISTEIYQAGNKPCNGVAGPNCYSPLQSYGFRPTYGVGHQPYRVVVLSFELLHAPATVCGPKKST',
}
# Substitutions per lineage, written as residue letter, position, new residue
# letter (e.g. 'G339D'); get_index / get_mut below parse this format.
mutation_dict={
    'BA.2.1':   ['G339D','S371F','S373P','S375F','T376A','D405N','R408S','K417N','N440K','S477N','T478K','E484A','Q493R','Q498R','N501Y','Y505H'],
    'BA.5.1':   ['G339D','S371F','S373P','S375F','T376A','D405N','R408S','K417N','N440K','L452R','S477N','T478K','E484A','F486V','Q498R','N501Y','Y505H'],
    'XBB.1.5':  ['G339H','R346T','L368I','S371F','S373P','S375F','T376A','D405N','R408S','K417N','N440K','V445P','G446S','N460K','S477N','T478K','E484A','F486P','F490S','Q498R','N501Y','Y505H'],
}

def get_index(line):
    """Return the position encoded in a mutation string, e.g. ``'G339D'`` -> ``339``."""
    position_matches = re.findall(r'[A-Z]([0-9]+)[A-Z]', line)
    return int(position_matches[0])
def get_mut(line):
    """Return the substituted residue letter of a mutation string, e.g. ``'G339D'`` -> ``'D'``."""
    residue_matches = re.findall(r'[A-Z][0-9]+([A-Z])', line)
    return residue_matches[0]

def numpy_mask_tokens(inputs, probility_mutation, mask_token_id):
    """Randomly overwrite entries of *inputs* with *mask_token_id*, in place.

    Each entry ``i`` is selected independently with probability
    ``probility_mutation[i]``, drawn with one ``np.random.binomial`` call.

    Returns:
        The mutated *inputs* array and the indices (first axis of the
        ``nonzero`` result) that were selected.
    """
    draws = np.random.binomial(1, probility_mutation, size=probility_mutation.shape)
    selected = draws.astype(bool)
    # np.nonzero(m) is what single-argument np.where(m) returns.
    positions = np.nonzero(selected)[0]
    inputs[positions] = mask_token_id
    return inputs, positions


class GenerateDataset(Dataset):
    """Dataset of ``num_samples`` identical items: the tokenized ``rbd_seq``.

    Every index returns the same tokenizer output. Any per-sample variation
    has to come from the DataLoader's collate function.
    """

    def __init__(self, num_samples, rbd_seq, tokenizer):
        super().__init__()
        self.num_samples, self.rbd_seq, self.tokenizer = num_samples, rbd_seq, tokenizer

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # The index is deliberately ignored; the parent sequence is re-tokenized.
        tokenize = self.tokenizer
        return tokenize(self.rbd_seq)

def dump(seqs, name, path):
    """Write *seqs* to ``<path>/mutation_<rbd_name>_<name>.txt``, one per line.

    ``rbd_name`` is read from module scope. The file is overwritten and is
    closed even if a write raises. The order follows iteration over *seqs*
    (arbitrary for a set).
    """
    file_name = os.path.join(path, "mutation_{}_{}.txt".format(rbd_name, name))
    print('seq num: {}, save path: {}'.format(len(seqs), file_name))
    with open(file_name, "w") as f:
        f.writelines(seq + "\n" for seq in seqs)


In [None]:
def get_latest_ckpt(path):
    """Return the largest ``N`` among entries of *path* named like ``checkpoint-N``.

    Returns 0 when no entry matches.
    """
    steps = (
        int(found[0])
        for found in (re.findall(r'checkpoint-([0-9]+)', entry) for entry in os.listdir(path))
        if found
    )
    return max(steps, default=0)

In [None]:
# sequence generate
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ft_path = os.path.join(model_root, 'finetune/checkpoint-{}'.format(get_latest_ckpt(os.path.join(model_root, 'finetune'))))
site_freq_path = os.path.join(model_root, 'site_mutation_frequency/{}_mutation_frequency_203.npy'.format(rbd_name))
mutation_save_path = os.path.join(data_root, 'raw_seqs')

# Stop after total_number distinct sequences; dump a file every `step` new ones.
total_number = 1000
step = 100

tokenizer = AutoTokenizer.from_pretrained(ft_path)
model = AutoModelForMaskedLM.from_pretrained(ft_path).to(device).eval()

max_len = 203
max_mask = 5
topk = 10
batch_size = 4
# Random top-k samplings drawn from each model forward pass.
n_resample = 20
# NOTE(review): assumes ids 4..23 are the 20 standard amino acids of the ESM-2
# vocabulary; confirm if a different tokenizer is used.
aa_id_min, aa_id_max = 4, 23

rbd_seq = rbd_dict[rbd_name]
rbd_id = np.array(tokenizer(rbd_seq)['input_ids'])

save_steps = np.arange(0, total_number + 1, step)[1:]
os.makedirs(mutation_save_path, exist_ok=True)

probility_mutation = np.load(site_freq_path)
collator = DataCollatorForMaskedGeneration(tokenizer, torch.tensor(probility_mutation), max_mask, device=device)

dataset = GenerateDataset(num_samples=total_number * 100,
                          rbd_seq=rbd_seq,
                          tokenizer=tokenizer)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collator)
mask_id = tokenizer.mask_token_id

with torch.no_grad():
    output_seq = set()     # every distinct sequence generated so far
    output_seq_tp = set()  # sequences generated since the last dump
    process = 0            # number of files dumped so far
    for i, data in enumerate(dataloader):
        if i % 500 == 0:
            print('>>>>> {} loops, current sequence num: {}'.format(i, len(output_seq)))
        # Masked rows to fill in; removed so the model only receives its own inputs.
        token_ids_ = data.pop('token_ids').cpu()

        out = model(**data)
        indices = torch.topk(out['logits'], topk, dim=-1).indices.cpu()
        bs = indices.shape[0]
        indices = indices.reshape(-1, topk)
        for _ in range(n_resample):
            token_ids = token_ids_.clone()
            # Choose one of the top-k candidates uniformly at random per position.
            index_ran = np.random.randint(0, topk, size=(indices.shape[0]))
            predict_id = indices[range(indices.shape[0]), index_ran].reshape(bs, -1)

            is_mask = token_ids == mask_id
            fill = is_mask & (predict_id >= aa_id_min) & (predict_id <= aa_id_max)
            token_ids[fill] = predict_id[fill]
            # A mask that received a non-amino-acid prediction would be removed
            # by skip_special_tokens below, producing a shorter sequence; keep
            # only rows in which every mask was filled.
            complete = ~(is_mask & ~fill).any(dim=1)

            sequences = tokenizer.batch_decode(token_ids[complete], skip_special_tokens=True, clean_up_tokenization_spaces=True)

            for sequence in sequences:
                sequence = re.sub(r'\s', '', sequence)
                if sequence not in output_seq:
                    output_seq.add(sequence)
                    output_seq_tp.add(sequence)

                if process >= len(save_steps):
                    break
                if len(output_seq) >= save_steps[process]:
                    print('process {}, output_seq: {}, output_seq_tp: {}'.format(process, len(output_seq), len(output_seq_tp)))
                    dump(output_seq_tp, process, mutation_save_path)
                    process += 1
                    output_seq_tp = set()

            if process >= len(save_steps):
                break
        if process >= len(save_steps):
            break
    print(total_number, len(output_seq))

In [None]:
def _check(path, num, name='BA.2.1'):
    """Plot, per site, how many generated sequences differ from the parent.

    Reads ``<path>/mutation_<name>_<i>.txt`` for ``i`` in ``range(num)`` and
    compares every line, position by position, with ``rbd_dict[name]``.
    Lines whose length differs from the parent cannot be compared
    elementwise. They are skipped, and their count is printed.
    """
    raw = np.array(list(rbd_dict[name]))
    nums = np.zeros(len(raw))
    skipped = 0
    for i in range(num):
        path_ = os.path.join(path, 'mutation_{}_{}.txt'.format(name, i))
        with open(path_, 'r') as f:
            for line in f:
                seq = np.array(list(line.strip()))
                if seq.shape != raw.shape:
                    skipped += 1
                    continue
                nums += (seq != raw)
    if skipped:
        print('skipped {} sequences whose length differs from the parent'.format(skipped))
    plt.title('site mutation frequency of generated sequences')
    plt.plot(nums)


In [None]:
# site mutation frequency
%matplotlib inline
# Per-site probabilities loaded from site_freq_path (the same array the
# generation cell passes to the masking collator).
site_freq = np.load(site_freq_path)
plt.title('site mutation frequency of initial sequences')
plt.plot(site_freq)

In [None]:
# One dumped file per save step; derived from the generation settings instead
# of hard-coding 10 so changing total_number/step keeps this cell consistent.
_check(mutation_save_path, len(save_steps))