In [1]:
import re
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW
import random
import pandas as pd
import numpy as np
import warnings
import torch
from transformers import BertTokenizer, BertLMHeadModel, AdamW, BertForSequenceClassification

In [2]:
def random_replace_residue(sequence):
    """
    purpose: overwrite half of the positions (rounded down) of the target sequence
             with residues drawn uniformly from the 20 standard amino acids; a draw
             may coincide with the residue that was already there
    :param sequence: target sequence
    :return: perturbed sequence with residues separated by single spaces
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'
    residues = list(sequence)
    n_swaps = int(0.5 * len(residues))
    # random.sample yields distinct positions, so exactly n_swaps slots are overwritten.
    for position in random.sample(range(len(residues)), n_swaps):
        residues[position] = random.choice(alphabet)
    return ' '.join(residues)


def random_mask_sequence(sequence):
    """
    purpose: mask residues of the target sequence; each residue is replaced by
             '[MASK]' independently with probability 0.5, so roughly (not exactly)
             half of the residues end up masked
    :param sequence: target sequence
    :return: new sequence in which every token ('[MASK]' or the original residue)
             is preceded by one space, so a non-empty result starts with a space;
             an empty input gives an empty string
    """
    # Collect the pieces and join once instead of growing a string in the loop.
    pieces = []
    for aa in sequence:
        if random.uniform(0, 1) < 0.5:
            pieces.append(' [MASK]')
        else:
            pieces.append(' ' + aa)
    return ''.join(pieces)


def random_generate_sequence(sequence):
    """
    purpose: draw a brand-new sequence with one uniformly random standard residue
             per position of the target sequence (only the target's length is used)
    :param sequence: target sequence
    :return: random sequence of the same length with residues separated by spaces
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'
    drawn = []
    for _ in range(len(sequence)):
        drawn.append(random.choice(alphabet))
    return ' '.join(drawn)


class Generator(nn.Module):
    """Generator: a pretrained BERT LM-head model that predicts a token id per position.

    The checkpoint name is read from the notebook-level global ``model_name``, which
    must be defined before the class is instantiated. The wrapped network is stored
    as ``self.model`` so state-dict keys keep the ``model.`` prefix.
    """
    def __init__(self):
        super().__init__()
        # Use the GPU when one is present, otherwise stay on the CPU; an
        # unconditional .cuda() raises on CPU-only hosts.
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): transformers warns that BertLMHeadModel used standalone expects
        # is_decoder=True; it is deliberately left unset here so behaviour matches the
        # existing checkpoints -- confirm before changing.
        self.model = BertLMHeadModel.from_pretrained(model_name, return_dict=False).to(target_device)

    def forward(self, input_ids, attention_mask):
        """
        purpose: run the language model and pick the highest-scoring token per position
        :param input_ids: token id tensor fed to the model
        :param attention_mask: attention mask tensor matching input_ids
        :return: (output_ids, output) where output is the raw tuple returned by the
                 model (return_dict=False) and output_ids is the argmax of its first
                 element over the last dimension
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # output[0] holds the prediction scores when no labels are passed.
        output_ids = torch.argmax(output[0], -1)
        return output_ids, output


class SequenceDataset(Dataset):
    """Dataset that yields a tokenised real sequence and a randomly perturbed copy.

    Relies on the notebook-level ``tokenizer`` and ``random_replace_residue``.
    The perturbation is re-drawn on every ``__getitem__`` call.

    :param sequences: indexable collection of residue strings
    :param max_length: length every tokenised sequence is padded/truncated to
    """
    def __init__(self, sequences, max_length=70):
        self.sequences = sequences
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def _encode(self, spaced_sequence):
        """Tokenise a space-separated sequence and convert every field to a tensor."""
        encoded = tokenizer(spaced_sequence, truncation=True, padding='max_length',
                            max_length=self.max_length)
        return {key: torch.tensor(val) for key, val in encoded.items()}

    def __getitem__(self, index):
        """
        :param index: position of the sequence in ``self.sequences``
        :return: (real input_ids, fake input_ids, fake attention_mask)
        """
        # Normalise BEFORE perturbing: strip all whitespace and map the rare residues
        # U/Z/O/B to X, so the perturbation only ever sees residues (previously any
        # whitespace in the raw string was counted as a residue).
        compact_seq = "".join(self.sequences[index].split())
        compact_seq = re.sub(r"[UZOB]", "X", compact_seq)

        # random_replace_residue already returns a space-separated string.
        fake_seq = random_replace_residue(compact_seq)
        real_seq = " ".join(compact_seq)

        real_sample = self._encode(real_seq)
        fake_sample = self._encode(fake_seq)
        return (real_sample['input_ids'], fake_sample['input_ids'], fake_sample['attention_mask'])

In [3]:
model_name = "../Rostlab/prot_bert_bfd"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained(model_name)

# The dataset perturbs every sequence with the `random` module; seed it so a
# Restart & Run All reproduces the same generated peptides.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Load dataset and create Dataloader
file = '../dataset/CPPCase-2.csv'
seqs = pd.read_csv(file)['Sequence'].tolist()
batch_size = 9
dataset = SequenceDataset(seqs)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

# Trained generator weights. The default is the original author's location; set the
# CPP_GENERATOR_CKPT environment variable to point at the checkpoint elsewhere.
checkpoint_path = os.environ.get(
    'CPP_GENERATOR_CKPT',
    '/home/qfchen/CPPCGM/CPPGenerator/model/replace_generator.pt',
)

model = Generator()
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust
# (or pass weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)

model.eval()
generated_sequences = []
with torch.no_grad():
    # Each batch is (real ids, perturbed ids, perturbed attention mask); only the
    # perturbed sequence is fed to the generator.
    for _real_ids, fake_ids, fake_attention_mask in dataloader:
        gene_data, _ = model(fake_ids.to(device), fake_attention_mask.to(device))
        decoded_seqs = tokenizer.batch_decode(gene_data, skip_special_tokens=True)
        generated_sequences.extend(decoded_seqs)

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
  model.load_state_dict(torch.load('/home/qfchen/CPPCGM/CPPGenerator/model/replace_generator.pt'))


In [4]:
# Show every decoded peptide (1-based numbering) with the tokenizer's spaces removed.
for rank, decoded in enumerate(generated_sequences, start=1):
    compact = decoded.replace(' ', '')
    print(f"Generated peptide {rank}: {compact}")

Generated peptide 1: RRRPWWLKKLKKLLKWLLKWLKWLLKLLWLRPWWLKLLKWRRPWWLKWLKKRRRRRRRRPWRLLKLLKWL
Generated peptide 2: AFWYFHKKFHKKFKYFHYKHHYHHKFWYFHHYFHKKFHYFFHYFHKKHHKHAAAAAAAFYFAAAHHKFHK
Generated peptide 3: XMKKNKIKKKKKKKKKKKKKKKKIIKKKKKKKNKIKKKKKKKKNKIKKIKKKXXXXXKKKNXKKKIKKKK
Generated peptide 4: GIRKRWRRRRRRRRRRRKRWRRWRRIRRRRRKRWRRRRRRIRKRWRRWRRGGGGGGGGRRGGRRRRRRWR
Generated peptide 5: RYIAWVDLIQIIIIVWIVWVDLIIDWIDLVIVWVDLIIVIYIIWVDLIILIIRXRRRAIARRDVIIIIID


In [5]:
# Drop the spaces the tokenizer puts between residues (safe to re-run: a second
# pass finds nothing left to remove).
generated_sequences = [peptide.replace(' ', '') for peptide in generated_sequences]
start_len = 8  # shortest prefix kept as a candidate peptide

# Set lookup instead of scanning the training list once per prefix.
known_sequences = set(seqs)

all_sequences = []
for peptide in generated_sequences:
    # Every prefix of length start_len..len(peptide) is a candidate; numbering
    # restarts at 1 for each generated peptide.
    for item, length in enumerate(range(start_len, len(peptide) + 1)):
        sequence = peptide[:length]
        # Only prefixes that are not already in the input dataset are exported;
        # the progress line below is printed for every prefix either way.
        if sequence not in known_sequences:
            all_sequences.append(sequence)
        print(f"Generated peptide {item + 1}: {sequence}")


df = pd.DataFrame(all_sequences, columns=["Peptide"])
output_path = 'results/replace_generated_peptides_1.csv'
# to_csv does not create missing directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
df.to_csv(output_path, index=False)

Generated peptide 1: RRRPWWLK
Generated peptide 2: RRRPWWLKK
Generated peptide 3: RRRPWWLKKL
Generated peptide 4: RRRPWWLKKLK
Generated peptide 5: RRRPWWLKKLKK
Generated peptide 6: RRRPWWLKKLKKL
Generated peptide 7: RRRPWWLKKLKKLL
Generated peptide 8: RRRPWWLKKLKKLLK
Generated peptide 9: RRRPWWLKKLKKLLKW
Generated peptide 10: RRRPWWLKKLKKLLKWL
Generated peptide 11: RRRPWWLKKLKKLLKWLL
Generated peptide 12: RRRPWWLKKLKKLLKWLLK
Generated peptide 13: RRRPWWLKKLKKLLKWLLKW
Generated peptide 14: RRRPWWLKKLKKLLKWLLKWL
Generated peptide 15: RRRPWWLKKLKKLLKWLLKWLK
Generated peptide 16: RRRPWWLKKLKKLLKWLLKWLKW
Generated peptide 17: RRRPWWLKKLKKLLKWLLKWLKWL
Generated peptide 18: RRRPWWLKKLKKLLKWLLKWLKWLL
Generated peptide 19: RRRPWWLKKLKKLLKWLLKWLKWLLK
Generated peptide 20: RRRPWWLKKLKKLLKWLLKWLKWLLKL
Generated peptide 21: RRRPWWLKKLKKLLKWLLKWLKWLLKLL
Generated peptide 22: RRRPWWLKKLKKLLKWLLKWLKWLLKLLW
Generated peptide 23: RRRPWWLKKLKKLLKWLLKWLKWLLKLLWL
Generated peptide 24: RRRPWWLKKLKKLLKWLLKWL