In [1]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer
import random
from selenium import webdriver
from selenium.webdriver.common.by import By
import os
import time
from tqdm import tqdm
from selenium.webdriver.firefox.options import Options
import math
from Bio.Blast.Applications import NcbiblastpCommandline
import subprocess
from collections import Counter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import py3Dmol
import glob
from sklearn.manifold import TSNE
import seaborn as sns
from evaluate import load
from torch.nn import CrossEntropyLoss
from evaluate import logging

  from .autonotebook import tqdm as notebook_tqdm


## Create DS __(2/3)__

In [None]:
# Prefer the second GPU (the notebook's original target), but fall back to the
# first GPU or the CPU so every later `.to(device)` still works on hosts that
# expose fewer than two CUDA devices.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Load the tokenizer from a local directory.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
tokenizer = AutoTokenizer.from_pretrained('/agh/projects/noelia/NLP/zymCTRL/dataset_preparation/tokenizer')

In [4]:
# Load the GPT-2 LM-head checkpoint stored in the parent directory, then move
# its weights onto the selected device.
model = GPT2LMHeadModel.from_pretrained('..')
model = model.to(device)

In [5]:
# Encode the control tag used as the generation prompt and place the resulting
# PyTorch tensor on the same device as the model.
input_ids = tokenizer.encode('1.1.1.1', return_tensors='pt')
input_ids = input_ids.to(device)

In [96]:
# Sampling configuration, gathered in one mapping so it can be inspected or
# tweaked without touching the call itself.
generation_kwargs = dict(
    top_k=950,
    repetition_penalty=1.2,
    max_length=2000,
    eos_token_id=1,
    pad_token_id=0,
    do_sample=True,
    num_return_sequences=10,
)

# Sample `num_return_sequences` continuations of the prompt.
output = model.generate(input_ids, **generation_kwargs)

In [97]:
# Print every generated sequence, decoded back to text, under a ruled header.
print("Output:\n" + 100 * '-')
for generated_ids in output:
    print(tokenizer.decode(generated_ids))

Output:
----------------------------------------------------------------------------------------------------
1. 1. 1. 1 <sep> <start> M K G A V L H E F G H P W Q I K E T D Q P I P G P G Q A L V R I V A S G I C H S D T H V V R G D D A E V C Q T A G R Q G P P V A L M P A V L G H E I V G E V V P T G P H T V R R K V P A C G K C H P C S T D N E H Q T L C R A F A P D T L D G T Y R R K P H T P L P F A L G G D A A L A E Y C L L N P A T T F E V P P K L R P D L V P P G C R A D V A G L L A T P Y I G V Y G P E A V R L G V R Y E N A L A V I G L G G I G Q C A I K I V Q M A G G R V V V I D R N P E N L A L A A E T L P K A E V L T L N G S R G N G N N P R Y R E L M G G L K A P R I M V Q T A T H A A P L H F Y N A L G D L T I S V T A S V S Q P W G S A P A D L N M I L P M L C E R R V Q G S L I G A Y E P L P L I K G F I A Q H R I G P L E K Y G L D R V V S C A E M N A A N E A F G K L A S G H A T R I R I L D <end> <|endoftext|> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad

In [98]:
# Decode every generated sequence (however many `generate` returned, rather
# than a hard-coded 10) and strip the blanks the tokenizer puts between tokens.
predictions = [tokenizer.decode(generated_ids).replace(' ', '') for generated_ids in output]
# The raw token tensor is no longer needed; release it.
del output

In [99]:
# Per-sequence perplexity of the generated sequences under the same model.
# For each sequence: exp(mean token-level cross-entropy over non-masked positions).
ppls = []
loss_fct = CrossEntropyLoss(reduction="none")
batch_size = 10
add_start_token = False
max_sequence_tokens = 2000  # truncation cap applied when re-tokenizing `predictions`

if tokenizer.pad_token is None and batch_size > 1:
    existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
    # check that the model already has at least one special token defined
    assert (
        len(existing_special_tokens) > 0
    ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
    # assign one of the special tokens to also be the pad token
    tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

if add_start_token:
    # leave room for <BOS> token to be added:
    assert (
        tokenizer.bos_token is not None
    ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
    max_tokenized_len = max_sequence_tokens - 1
else:
    max_tokenized_len = max_sequence_tokens

# Tokenize BEFORE reading the tensors out of the result: the cell must not rely
# on an `encodings` object left over from an earlier execution.
encodings = tokenizer(
    predictions,
    add_special_tokens=False,
    padding=True,
    truncation=True,
    max_length=max_tokenized_len,
    return_tensors="pt",
    return_attention_mask=True,
).to(device)
encoded_texts = encodings["input_ids"]
attn_masks = encodings["attention_mask"]

for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
    end_index = min(start_index + batch_size, len(encoded_texts))
    encoded_batch = encoded_texts[start_index:end_index]
    attn_mask = attn_masks[start_index:end_index]

    if add_start_token:
        # Prepend one BOS token per row and extend the mask to cover it.
        bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
        encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
        attn_mask = torch.cat(
            [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
        )

    labels = encoded_batch

    with torch.no_grad():
        out_logits = model(encoded_batch, attention_mask=attn_mask).logits

    # Position t predicts token t+1: drop the last logit and the first label.
    shift_logits = out_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

    # Masked mean of the per-token loss, exponentiated to a perplexity per row.
    perplexity_batch = torch.exp(
        (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
        / shift_attention_mask_batch.sum(1)
    )

    ppls += perplexity_batch.tolist()
    del out_logits

ppl = {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.19s/it]


In [105]:
# Markers to strip from decoded model output. Order matters when this list is
# fed to remove_characters: the blank entry comes before '<sep>', so all spaces
# are removed first and '<sep>' then becomes the only separator left.
special_tokens = [
    '<start>',
    '<end>',
    '<|endoftext|>',
    '<pad>',
    ' ',
    '<sep>',
]

In [102]:
def remove_characters(str, chars_list):
    """Strip every entry of ``chars_list`` from ``str``, in list order.

    The literal ``'<sep>'`` is the one exception: it is replaced by a single
    space instead of being deleted, so the text on either side of it can later
    be separated with ``split()``.

    Args:
        str: text to clean (the parameter name shadows the builtin inside this
            function only).
        chars_list: substrings to remove, applied one after another.

    Returns:
        The cleaned text.
    """
    cleaned = str
    for token in chars_list:
        replacement = ' ' if token == '<sep>' else ''
        cleaned = cleaned.replace(token, replacement)
    return cleaned

## Generation à la carte? ProteInfer __(3.1)__ and save DS(2)

In [103]:
# don't forget to install geckodriver
def proteinfer(seq):
    """Submit a sequence to the ProteInfer web demo and scrape its predicted labels.

    Opens the demo page in Firefox, switches to the custom-sequence input,
    types ``seq`` into it, waits a fixed time for the in-browser prediction and
    collects the text of every element with class ``top-figure-link``.

    Args:
        seq: sequence string typed into the page's input field.

    Returns:
        list[str]: the non-empty texts of the scraped elements.

    Note:
        Requires geckodriver. The waits are fixed sleeps (5 s and 20 s), so a
        slow page load can produce an empty or partial result.
    """
    url = 'https://google-research.github.io/proteinfer/'
    # Ask Firefox to run without a visible window.
    # (An Options object that was built but never handed to the driver has been
    # dropped; it had no effect.)
    os.environ['MOZ_HEADLESS'] = '1'
    driver = webdriver.Firefox()
    try:
        driver.get(url)
        driver.find_element(By.ID, 'yourseq').click()
        time.sleep(5)
        driver.find_element(By.ID, 'input_seq').send_keys(seq)
        time.sleep(20)
        link_texts = [link.text for link in driver.find_elements(By.CLASS_NAME, 'top-figure-link')]
        return [text for text in link_texts if text]
    finally:
        # Always shut the browser down, even if a lookup above raised;
        # otherwise every failed call leaves a Firefox process behind.
        driver.quit()

In [None]:
# Annotate only the generated sequence with the lowest perplexity.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
inf_res = {}
best_idx = ppl['perplexities'].index(min(ppl['perplexities']))
best_decoded = tokenizer.decode(encodings["input_ids"][best_idx])
# After cleaning, the text is "<control tag> <sequence>".
# (Named variable instead of `_`, which IPython reserves for the last cell output.)
fields = remove_characters(best_decoded, special_tokens).split()
cond_tok, seq = fields[0], fields[1]
proteinfer_an = proteinfer(seq)
# sequence -> (control tag, labels starting with 'EC', labels starting with 'GO')
inf_res[seq] = (
    cond_tok,
    [x for x in proteinfer_an if x.startswith('EC')],
    [x for x in proteinfer_an if x.startswith('GO')],
)

In [41]:
# Annotate every generated sequence.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
inf_res = {}
# Iterate over the decoded `predictions`: the raw `output` tensor is deleted
# right after decoding, so it is not available here on a top-to-bottom run.
# The predictions already have blanks stripped, which remove_characters would
# do anyway, so the resulting fields are the same.
for prediction in tqdm(predictions):
    fields = remove_characters(prediction, special_tokens).split()
    cond_tok, seq = fields[0], fields[1]
    proteinfer_an = proteinfer(seq)
    # sequence -> (control tag, labels starting with 'EC', labels starting with 'GO')
    inf_res[seq] = (
        cond_tok,
        [x for x in proteinfer_an if x.startswith('EC')],
        [x for x in proteinfer_an if x.startswith('GO')],
    )

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [07:27<00:00, 44.78s/it]


In [44]:
# Working directory for out.fasta, the predicted *.pdb files and results.tab.
# NOTE(review): absolute, user-specific path - adjust before running elsewhere.
outdir = '/Users/sebastianlindner/Documents/Projects/Noelia_zymCTRL/'

In [45]:
# Write one FASTA record per annotated sequence.
# Header layout: ><control tag>|<EC labels, comma-joined>|<GO labels, comma-joined>
# Fields are joined explicitly: star-unpacking inside the f-string split the
# control tag into single characters and emitted Python tuple reprs (with
# spaces) into the header.
with open(os.path.join(outdir, 'out.fasta'), 'w') as fn:
    for sequence, (cond_label, ec_labels, go_labels) in inf_res.items():
        header = '|'.join([cond_label, ','.join(ec_labels), ','.join(go_labels)])
        fn.write(f'>{header}\n{sequence}\n')

## General quality assessment: Structure Prediction __(2.3)__

In [None]:
# Runs: omegafold <outdir>/out.fasta <outdir> --subbatch_size 27
def omega_fold(outdir):
    """Run OmegaFold on ``<outdir>/out.fasta`` and stream its console output.

    Args:
        outdir: directory that contains ``out.fasta``; also passed to
            ``omegafold`` as its output directory.

    Yields:
        bytes: each line of the process's combined stdout/stderr, in order,
        until the stream reaches end-of-file.
    """
    cmd = ['omegafold', os.path.join(outdir, 'out.fasta'), outdir, '--subbatch_size', '27']
    # stderr is merged into stdout, so there is exactly one pipe to drain.
    # The context manager closes the pipe and waits for the process on exit.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as stream:
        # Read to EOF instead of polling the exit status first: polling could
        # stop after a single extra line once the process had exited, dropping
        # whatever output was still buffered in the pipe.
        for line in iter(stream.stdout.readline, b''):
            yield line
            
# Echo the structure-prediction log into the notebook as it is produced.
for line in omega_fold(outdir):
    print(line)

In [71]:
# Hex colours used by show_pdb when colouring by chain; the n-th chain gets the
# n-th entry (40 colours in total).
pymol_color_list = [
    "#33ff33", "#00ffff", "#ff33cc", "#ffff00",
    "#ff9999", "#e5e5e5", "#7f7fff", "#ff7f00",
    "#7fff7f", "#199999", "#ff007f", "#ffdd5e",
    "#8c3f99", "#b2b2b2", "#007fff", "#c4b200",
    "#8cb266", "#00bfbf", "#b27f7f", "#fcd1a5",
    "#ff7f7f", "#ffbfdd", "#7fffff", "#ffff7f",
    "#00ff7f", "#337fcc", "#d8337f", "#bfff3f",
    "#ff7fff", "#d8d8ff", "#3fffbf", "#b78c4c",
    "#339933", "#66b2b2", "#ba8c84", "#84bf00",
    "#b24c66", "#7f7f7f", "#3f3fa5", "#a5512b",
]

def show_pdb(pdb_str, show_sidechains=False, show_mainchains=False,
             color="pLDDT", chains=None, vmin=50, vmax=90,
             size=(800,480), hbondCutoff=4.0,
             Ls=None,
             animate=False):
    """Build a py3Dmol view of a PDB structure.

    Args:
        pdb_str: PDB file contents as a string.
        show_sidechains: add stick/sphere styles for side-chain atoms.
        show_mainchains: add stick styles for the backbone atoms C, O, N, CA.
        color: "pLDDT" (gradient over the B-factor column between ``vmin`` and
            ``vmax``), "rainbow" (spectrum) or "chain" (one colour per chain
            from ``pymol_color_list``). Any other value leaves the default style.
        chains: number of chains to colour when ``color == "chain"``; defaults
            to ``len(Ls)``, or 1 when ``Ls`` is None.
        vmin, vmax: bounds of the pLDDT colour gradient.
        size: (width, height) of the viewer.
        hbondCutoff: forwarded to py3Dmol as the ``hbondCutoff`` model option.
        Ls: optional list; only its length is used (see ``chains``).
        animate: load the models as frames and start the animation.

    Returns:
        The configured py3Dmol view; call ``.show()`` on it to render.
    """
    # Chain identifiers A-Z followed by a-z, built locally so chain colouring
    # does not depend on a module-level `alphabet_list` being defined.
    from string import ascii_lowercase, ascii_uppercase
    chain_ids = list(ascii_uppercase + ascii_lowercase)

    if chains is None:
        chains = 1 if Ls is None else len(Ls)
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])
    if animate:
        view.addModelsAsFrames(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
    else:
        view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
    if color == "pLDDT":
        view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':vmin,'max':vmax}}})
    elif color == "rainbow":
        view.setStyle({'cartoon': {'color':'spectrum'}})
    elif color == "chain":
        # zip() stops at the shortest input, so at most `chains` chains are
        # styled. A separate loop name keeps the `color` argument intact.
        for _, chain, chain_color in zip(range(chains), chain_ids, pymol_color_list):
            view.setStyle({'chain':chain},{'cartoon': {'color':chain_color}})
    if show_sidechains:
        BB = ['C','O','N']
        # Everything except backbone atoms, for residues other than GLY/PRO.
        view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
        # Glycine has no side chain: mark its CA with a sphere.
        view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                  {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
        # Proline: show everything except C and O.
        view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    if show_mainchains:
        BB = ['C','O','N','CA']
        view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # Always fit the structure in the viewport, not only when the backbone
    # sticks are shown.
    view.zoomTo()
    if animate: view.animate()
    return view

In [None]:
# Render every predicted structure found in `outdir`.
files = glob.glob(f"{outdir}/*.pdb")
# Sequence lengths are the same for every file, so compute them once.
lens = [len(sequence) for sequence in inf_res.keys()]
color = "pLDDT"  # "confidence" colouring maps to show_pdb's pLDDT scheme
show_sidechains = False
show_mainchains = False
for pdb_path in files:
    # glob already returns usable paths; read inside `with` so each handle is closed.
    with open(pdb_path, 'r') as pdb_file:
        pdb_str = pdb_file.read()
    show_pdb(pdb_str, color=color, show_sidechains=show_sidechains, show_mainchains=show_mainchains,
             Ls=lens).show()

## General quality assessment: Homology __(2.4)__

In [20]:
# Query: the generated sequences; results go to a tab-separated file next to them.
seq = os.path.join(outdir, 'out.fasta')
out = os.path.join(outdir, 'results.tab')

# BLASTP of the query against a local FASTA subject file.
# NOTE(review): math.exp(-15) is e**-15 (about 3.06e-07), not 1e-15 - confirm
# which threshold was intended.
blastp = NcbiblastpCommandline(
    query=seq,
    subject='/Users/sebastianlindner/Downloads/brenda_uniprot.faa',
    out=out,
    max_target_seqs=5,
    outfmt="6 qseqid sseqid pident qcovs qlen slen length bitscore evalue",
    evalue=math.exp(-15),
)

# Show the assembled command line, run it, then show what it printed.
print("BLASTP: %s" % blastp)

stdout, stderr = blastp()
print("STDOUT: %s" % stdout)
print("STDERR: %s" % stderr)

BLASTP: blastp -out /Users/sebastianlindner/results.tab -outfmt "6 qseqid sseqid pident qcovs qlen slen length bitscore evalue" -query /Users/sebastianlindner/out.fasta -evalue 3.059023205018258e-07 -max_target_seqs 5 -subject /Users/sebastianlindner/Downloads/brenda_uniprot.faa



KeyboardInterrupt



## Generation à la carte? t-SNE __(3.2)__

In [None]:
#clustal-omega -i out.fasta --distmat-out=<file> --auto --full

In [82]:
# Preview the raw distance-matrix file: one list of whitespace-separated fields
# per line. Read inside `with` so the file handle is closed.
with open('dist.out', 'r') as dist_file:
    dist_rows = [row.split() for row in dist_file]
dist_rows

[['3'],
 ['sp|P69905|HBA_HUMAN', '0.000000', '0.140845', '0.436620'],
 ['sp|P01942|HBA_MOUSE', '0.140845', '0.000000', '0.436620'],
 ['sp|P13786|HBAZ_CAPHI', '0.436620', '0.436620', '0.000000']]

In [136]:
# Parse dist.out once: keep only rows that have values after the first field
# (this skips the leading line that holds just the sequence count).
with open('dist.out', 'r') as dist_file:
    dist_records = [fields for fields in (row.split() for row in dist_file) if fields[1:]]

# First field of each kept row is the sequence label; the rest are distances.
dist_matr = np.array([fields[1:] for fields in dist_records], dtype='float')
label = np.array([fields[0] for fields in dist_records])

In [137]:
# Embed the sequences in 2-D from their pairwise distances.
# `dist_matr` already IS a distance matrix, so it is passed as such
# (metric="precomputed") rather than being treated as a feature matrix;
# init="random" is required in that mode.
# The iteration count is left at the library default of 1000 because the
# keyword for it differs between scikit-learn releases (n_iter / max_iter).
t_sne = TSNE(
    n_components=2,
    metric="precomputed",
    learning_rate="auto",
    perplexity=2,  # must stay below the number of sequences
    init="random",
    random_state=42,
)
tsne_results = t_sne.fit_transform(dist_matr)
# Column-style mapping consumed by the seaborn scatter plot.
t_plot = {"tsne-2d-one":tsne_results[:,0],"tsne-2d-two":tsne_results[:,1], 'y':label}

'The distance matrix was used with the scikit-learn t-SNE module 56 with default settings (early exaggeration 12, learning rate 200, maximum number of iterations: 1000) except that the embedding generation perplexity was set to 7. Coordinates given by t-SNE were used for plotting, the size of a given dot was visualized based on the cluster size it represents.'

In [1]:
# Scatter plot of the 2-D embedding, one hue per sequence label.
fig, ax = plt.subplots(figsize=(16, 10))
sns.scatterplot(
    data=t_plot,
    x="tsne-2d-one",
    y="tsne-2d-two",
    hue="y",
    palette=sns.color_palette("hls", len(label)),
    legend="full",
    alpha=0.3,
    ax=ax,
)
plt.show()

NameError: name 'plt' is not defined