In [None]:
# Load dataset
import datasets

# Load the "damlab/uniprot" dataset from the Hugging Face Hub (or the local cache).
ds = datasets.load_dataset("damlab/uniprot")
# Inspect the sequence data. Index the row first so only one record is read,
# instead of materialising the whole "sequence" column to look at its first entry.
ds["train"][0]["sequence"]

In [None]:
# Display the repr of the loaded dataset object.
ds

In [None]:
import os

import torch
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from tqdm.auto import tqdm
from transformers import AdamW, pipeline, RobertaTokenizerFast, DataCollatorForLanguageModeling
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, PreTrainedTokenizerFast

In [None]:
# Make iterator for tokenizing
def generate_iter(dataset=None, batch_size=10000):
    """Yield the sequences of a dataset split in consecutive batches.

    Args:
        dataset: Object exposing ``num_rows`` and row slicing that returns a
            mapping with a ``'sequence'`` entry. Defaults to ``ds['train']``,
            looked up when iteration starts.
        batch_size: Number of sequences per yielded batch (must be >= 1).

    Yields:
        The ``'sequence'`` values of each batch of rows.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (raised on first iteration).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if dataset is None:
        dataset = ds['train']
    for start in range(0, dataset.num_rows, batch_size):
        # Slice the rows first and then pick the column, so only this batch is
        # read; selecting the column first would load it in full for every batch.
        yield dataset[start:start + batch_size]['sequence']


# One-letter codes of the 20 amino acids.
vocab = ['A','R','N','D','C','E','Q','G','H','I','L','K','M','F','P','S','T','W','Y','V']
# Generators are lazy: nothing is read from `ds` until the tokenizer consumes `corpus`.
corpus = generate_iter()

In [None]:
# Train a byte-level BPE tokenizer on the batched sequences and write
# ast3-vocab.json / ast3-merges.txt into the current directory.
# NOTE(review): `vocab` is passed without `merges` — confirm the constructor
# actually uses it in that case.
special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
tokenizer = ByteLevelBPETokenizer(vocab=vocab)
tokenizer.train_from_iterator(
    corpus,
    vocab_size=100,
    show_progress=True,
    special_tokens=special_tokens,
)
tokenizer.save_model(".", "ast3")


In [None]:
# Process tokenizer for Roberta
from tokenizers.implementations import ByteLevelBPETokenizer

# Reload the trained BPE model from the vocab/merges files in the current directory.
# A separate name is used so `tokenizer` does not change type halfway through the cell.
bpe_tokenizer = ByteLevelBPETokenizer("ast3-vocab.json", "ast3-merges.txt")
# BertProcessing takes the separator pair first and the start pair second,
# so encodings are wrapped as <s> ... </s>.
bpe_tokenizer._tokenizer.post_processor = BertProcessing(
    ("</s>", bpe_tokenizer.token_to_id("</s>")),
    ("<s>", bpe_tokenizer.token_to_id("<s>")),
)
print(bpe_tokenizer.get_vocab_size())
bpe_tokenizer.enable_truncation(max_length=256)
bpe_tokenizer.enable_padding()

# The save calls below write into 'token/'; create it first so the cell also
# works on a fresh checkout where the directory does not exist yet.
os.makedirs('token', exist_ok=True)
bpe_tokenizer.save('token/ast3')
bpe_tokenizer.save_model('token', 'ast3')

# Build the transformers-side tokenizer from the saved vocab/merges files and
# persist it so it can be reloaded with from_pretrained('token/').
tokenizer = RobertaTokenizerFast(vocab_file="token/ast3-vocab.json", merges_file="token/ast3-merges.txt")
tokenizer.save_pretrained('token/')

In [None]:
# Reload the tokenizer from the files saved under 'token/' and display its vocabulary size.
tokenizer = RobertaTokenizerFast.from_pretrained('token/')
tokenizer.vocab_size

In [None]:
# Test the tokenizer on a single sequence, truncated/padded to 256 tokens.
# Index the row first so only record 1000 is read, instead of materialising
# the whole 'sequence' column.
inputs = tokenizer.encode_plus(ds['train'][1000]['sequence'], max_length=256,
                               truncation=True, padding='max_length')
inputs

In [None]:
## Special-token ids in this vocabulary: <mask> is 4, <pad> is 1

In [None]:
## Data loader class for pytorch taken directly from a website
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of tokenized sequences, each padded/truncated to a fixed length.

    Only the token ids are kept; the attention mask produced by the tokenizer
    is discarded.

    Args:
        ds: Mapping with a ``'train'`` split that supports row slicing and
            returns a mapping with a ``'sequence'`` entry.
        tokenizer: Callable tokenizer; invoked on a list of sequences with
            ``max_length``, ``truncation`` and ``padding`` keyword arguments and
            expected to return a mapping with an ``'input_ids'`` entry.
        i_start: Index of the first row to include.
        i_end: Index one past the last row to include.
        max_length: Length every encoded sequence is padded/truncated to.
    """

    def __init__(self, ds, tokenizer, i_start, i_end, max_length=256):
        # Slice the rows first, then select the column, so only the requested
        # range is read rather than the whole 'sequence' column.
        sequences = ds['train'][i_start:i_end]['sequence']
        if len(sequences) == 0:
            # Nothing to encode; avoid calling the tokenizer on an empty batch.
            self.encodings = []
        else:
            # One batched call instead of one tokenizer call per sequence.
            batch = tokenizer(sequences, max_length=max_length,
                              truncation=True, padding='max_length')
            self.encodings = list(batch['input_ids'])

    def __len__(self):
        # return the number of samples
        return len(self.encodings)

    def __getitem__(self, i):
        # return the token ids of sample i as a 1-D tensor
        return torch.tensor(self.encodings[i])

# Rows 0-9999 are used for training and rows 10000-10999 for evaluation.
trdataset = Dataset(ds, tokenizer, i_start=0, i_end=10000)
evdataset = Dataset(ds, tokenizer, i_start=10000, i_end=11000)
# Despite its name, `loader` is a data collator (masked-LM, 15% masking probability).
loader = DataCollatorForLanguageModeling(
    tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

In [None]:
# Inspect the first evaluation example as returned by Dataset.__getitem__.
evdataset[0]

In [None]:
# Display the tokenizer's vocabulary size.
tokenizer.vocab_size

In [None]:
# Model configuration for a 6-layer RoBERTa masked LM.
# Sneaky: max_position_embeddings must be the max sequence length (256) + 2.
config = RobertaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    max_position_embeddings=258,
    type_vocab_size=1,
)
model = RobertaForMaskedLM(config)

In [None]:
# Use the GPU when torch reports one as available; otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

In [None]:
from transformers import Trainer, TrainingArguments

# Training setup: 2 epochs, evaluation after every epoch, only the latest checkpoint kept.
# The learning rate was 1e-2, which is orders of magnitude above what AdamW-trained
# transformers tolerate and typically makes masked-LM training diverge; 1e-4 is a
# conventional starting point for a small model trained from scratch.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version.
targs = TrainingArguments(
    output_dir='ast',
    overwrite_output_dir=True,
    evaluation_strategy='epoch',
    num_train_epochs=2,
    learning_rate=1e-4,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=32,
    save_total_limit=1,
)
trainer = Trainer(
    model=model,
    args=targs,
    data_collator=loader,
    train_dataset=trdataset,
    eval_dataset=evdataset,
)

In [None]:
# Run training with the Trainer built from `targs`, `loader` and the two datasets.
trainer.train()

In [None]:
# Save the trained model to ./ast so it can be reloaded with from_pretrained.
model.save_pretrained('./ast')

In [None]:
# Build a fill-mask pipeline from the saved model directory and the saved tokenizer directory.
fill = pipeline('fill-mask', model='./ast', tokenizer='token/', config=config)

In [None]:
# Ask the pipeline to fill the single <mask> position in a sample sequence.
fill('MAFSAE<mask>VLKEYDRRRRMEALLLSLYYP')

In [None]:
## Model performance not looking great

In [None]:
# Reload the saved masked-LM model from ./ast.
model = RobertaForMaskedLM.from_pretrained('./ast')

In [None]:
# Encode every amino-acid letter on its own, padded to 256 tokens, and stack
# the token ids and attention masks into two tensors with one row per letter.
encoded = [
    tokenizer.encode_plus(aa, max_length=256, truncation=True, padding='max_length')
    for aa in vocab
]
ids = torch.tensor([enc.input_ids for enc in encoded])
masks = torch.tensor([enc.attention_mask for enc in encoded])

In [None]:
## Getting vocab embeddings 

In [None]:
# `model` is a RobertaForMaskedLM, whose first output is the vocabulary logits
# of the LM head, not an embedding. To get per-token hidden states, call the
# underlying encoder (`model.roberta`), whose first output is the last hidden state.
model.eval()  # inference: make sure dropout is disabled
with torch.no_grad():  # no gradients needed, avoids building the autograd graph
    out = model.roberta(ids, attention_mask=masks)

In [None]:
# Take the first element of the forward-pass output and display its shape.
lhs = out[0]
lhs.shape

In [None]:
# Keep only sequence position 0 of every row, detached from the autograd graph.
# NOTE(review): position 0 is presumably the <s> token the tokenizer prepends — confirm.
cls = lhs[:,0,:].detach()
cls.shape

In [None]:
from sklearn.decomposition import  PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Project the vectors in `cls` onto their first two principal components for plotting.
pca = PCA(n_components=2)
vis = pca.fit_transform(cls)
vis.shape

In [None]:
## Setting colors and markers same as example figure
# Amino acids (one-letter codes) grouped by the category used for color/marker.
# 'H' is deliberately listed as both positively charged and polar.
pcharge = list('KRH')
ncharge = list('DE')
hydrophobic = list('AILMV')
aromatic = list('FWY')
polar = list('STNQH')
unique = list('CGP')
# Size classes, used to scale the marker size of each residue.
small = list('AGPSTV')
med = list('CIMLNQKDE')
large = list('HRFWY')

In [None]:
def _residue_style(aa):
    """Return the (color, marker, legend label) used to draw amino acid `aa`.

    Raises:
        ValueError: If `aa` belongs to none of the category lists, instead of
            silently reusing the style of the previously plotted residue.
    """
    if aa in polar and aa in pcharge:
        # A residue that is both polar and positively charged keeps the square
        # marker of the charged group but gets its own color and its own label.
        return 'purple', 's', aa
    categories = (
        (pcharge, 'red', 's', 'Positively charged'),
        (ncharge, 'red', 'x', 'Negatively charged'),
        (hydrophobic, 'green', 'o', 'Hydrophobic'),
        (aromatic, 'green', '+', 'Aromatic'),
        (polar, 'blue', 'o', 'Polar'),
        (unique, 'orange', 'o', 'Unique'),
    )
    for members, color, marker, label in categories:
        if aa in members:
            return color, marker, label
    raise ValueError(f"amino acid {aa!r} is not assigned to any category")


def _residue_size(aa):
    """Return the marker size for amino acid `aa` according to its size class.

    Raises:
        ValueError: If `aa` is in none of the size classes.
    """
    for members, size in ((small, 30), (med, 60), (large, 90)):
        if aa in members:
            return size
    raise ValueError(f"amino acid {aa!r} is not assigned to any size class")


fig, ax = plt.subplots(figsize=(6, 4))
# First scatter handle seen for each legend label, so every category appears once.
legend_handles = {}
# Row j of `vis` is the 2-D projection of vocab[j].
for j, aa in enumerate(vocab):
    color, marker, label = _residue_style(aa)
    handle = ax.scatter(vis[j, 0], vis[j, 1], label=label, color=color,
                        marker=marker, s=_residue_size(aa))
    legend_handles.setdefault(label, handle)
ax.set_xlabel('PCA 0')
ax.set_ylabel('PCA 1')
# Alphabetical legend order, placed just outside the right edge of the axes.
legend_labels = sorted(legend_handles)
ax.legend(handles=[legend_handles[name] for name in legend_labels],
          labels=legend_labels, loc=(1.01, 0.01))