# Import package

In [24]:
import os
import numpy as np
import torch
import random
from evodiff.pretrained import OA_DM_640M
from datasets import load_dataset
from evodiff.conditional_generation import get_intervals

torch.cuda.set_device('cuda:3')
device = torch.device("cuda:3")

## Load fine-tuned model

In [35]:
checkpoint = OA_DM_640M()
model, collater, tokenizer, scheme = checkpoint

model.eval().cuda()
model.to(device)

ByteNetLMTime(
  (embedder): ByteNetTime(
    (time_encoding): PositionalEncoding1D()
    (embedder): Embedding(31, 8, padding_idx=28)
    (up_embedder): PositionFeedForward(
      (conv): Conv1d(8, 1280, kernel_size=(1,), stride=(1,))
    )
    (layers): ModuleList(
      (0): ByteNetBlock(
        (conv): MaskedConv1d(1280, 1280, kernel_size=(5,), stride=(1,), padding=(2,))
        (sequence1): Sequential(
          (0): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (1): GELU(approximate=none)
          (2): PositionFeedForward(
            (conv): Conv1d(1280, 1280, kernel_size=(1,), stride=(1,))
          )
          (3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (4): GELU(approximate=none)
        )
        (sequence2): Sequential(
          (0): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (1): GELU(approximate=none)
          (2): PositionFeedForward(
            (conv): Conv1d(1280, 1280, kernel_size=(1,), stride=(1,

## Load dataset

In [20]:
ds = load_dataset("lun610200/detergent-enzyme")
test_ds = ds['test']

## Target

In [34]:
# target
target = test_ds[0]
print(target)
motif_start, motif_end = int(target['motif_start']), int(target['motif_end'])


{'id': 'A0A0S1M121', 'ec': '3.2.1.1', 'protein': 'Alpha-amylase (EC 3.2.1.1)', 'species': 'Bacillus licheniformis', 'organism': 'Bacillus licheniformis', 'motif_start': 60.0, 'motif_end': 400.0, 'pH_left': 7.0, 'pH_right': 9.0, 'temp_left': 38.0, 'temp_right': 42.0, 'sequence': 'tr|A0A0S1M121|A0A0S1M121_BACLI\nMKQQKRLYARLLPLLFALIFLLPHSAAAAANLKGTLMQYFEWYMPNDGQHWKRLQNDSAYLAEHGITAVWIPPAYKGTSQADVGYGAYDLYDLGEFHQKGTVRTKYGTKGELQSAIKSLHSRDINVYGDVVINHKGGADATEDVTAVEVDPADRNRVISGEHRIKAWTHFHFPGRGSTYSDFKWHWYHFDGTDWDESRKLNRIYKFQGKAWDWEVSNENGNYDYLMYADIDYDHPDVAAEIKRWGTWYANELQLDGFRLDAVKHIKFSFLRDWVNHVREKTGKEMFTVAEYWQNDLGALENYLNKTNFNHSVFDVPLHYQFHAASTQGGGYDMRKLLNGTVVSKHPLKAVTFVDNHDTQPGQSLESTVQTWFKPLAYAFILTRESGYPQVFYGDMYGTKGDSQREIPALKHKIEPILKARKQYAYGAQHDYFDHHDIVGWTREGDSSVANSGLAALITDGPGGAKRMYVGRQNAGETWHDITGNRSEPVVINSEGWGEFHVNGGSVSIYVQR', 'pdb_data': 'HEADER                                            01-JUN-22                     \nTITLE     ALPHAFOLD MONOMER V2.0 PREDICTION FOR ALPHA-AMYLASE (A0A0S1M121)    

## Parameter

In [26]:
batch_size = 1
scaffold_min = 100
scaffold_max = 1024
random_baseline=False
single_res_domain=False
scaffold_length = random.randint(scaffold_min, scaffold_max)

## Conditional Generation

In [33]:


mask = tokenizer.mask_id
print("sequence length", len(target['sequence']))
motif_seq = target['sequence'][motif_start: motif_end]
print("motif extracted from indexes supplied:", motif_seq)
motif_tokenized = tokenizer.tokenize((motif_seq,))

# Create input motif + scaffold
seq_len = scaffold_length + len(motif_seq)
sample = torch.zeros((batch_size, seq_len)) + mask # start from all mask
new_start = np.random.choice(scaffold_length) # randomly place motif in scaffold
sample[:, new_start:new_start+len(motif_seq)] = torch.tensor(motif_tokenized)
nonmask_locations = (sample[0] != mask).nonzero().flatten()

new_start_idxs, new_end_idxs = get_intervals(nonmask_locations, single_res_domain=single_res_domain)
print(f'new start index: {new_start_idxs} and new end index: {new_end_idxs}')
value, loc = (sample == mask).long().nonzero(as_tuple=True) # locations that need to be unmasked
loc = np.array(loc)
np.random.shuffle(loc)
sample = sample.long().to(device)
with torch.no_grad():
    for i in loc:
        timestep = torch.tensor([0] * batch_size)  # placeholder but not called in model
        timestep = timestep.to(device)
        if random_baseline:
            p_sample = torch.multinomial(torch.tensor(train_prob_dist), num_samples=1)
        else:
            prediction = model(sample, timestep)
            p = prediction[:, i, :len(tokenizer.all_aas) - 6]  # only canonical
            p = torch.nn.functional.softmax(p, dim=1)  # softmax over categorical probs
            p_sample = torch.multinomial(p, num_samples=1)
        sample[:, i] = p_sample.squeeze()
print("Generated sequence:", [tokenizer.untokenize(s) for s in sample])





sequence length 543
motif extracted from indexes supplied: ANLKGTLMQYFEWYMPNDGQHWKRLQNDSAYLAEHGITAVWIPPAYKGTSQADVGYGAYDLYDLGEFHQKGTVRTKYGTKGELQSAIKSLHSRDINVYGDVVINHKGGADATEDVTAVEVDPADRNRVISGEHRIKAWTHFHFPGRGSTYSDFKWHWYHFDGTDWDESRKLNRIYKFQGKAWDWEVSNENGNYDYLMYADIDYDHPDVAAEIKRWGTWYANELQLDGFRLDAVKHIKFSFLRDWVNHVREKTGKEMFTVAEYWQNDLGALENYLNKTNFNHSVFDVPLHYQFHAASTQGGGYDMRKLLNGTVVSKHPLKAVTFVDNHDTQPGQSLESTVQ
new start index: [787] and new end index: [1126]
Generated sequence: ['MATVEVTNTVDYEPRHLTQVKMSIGLDHNIRVWKKHMDSDEIRELHGTDADVDEYNQARNAVTGLSFVAREERELIDIFNPENSKDAEKAYKRLDDDVNQTTYARTTVGDSDDFGIDDELGNIFGGGGGAESDSVGNITNTPIPSTGNFKEYKKKNYDKTSKDLGASIGDVFKRRLKINLAIPVQQALLGIPSSAPHTAGLLDPDKLDTIIRDKVVQFAKTVIDEFDEKDGQNRELNYDAGMNEVTVGNRWQSTRDRVNYIMLNLPEYVGTDMMGEVPAGLSTGSMGAGAVQNRFTGEIQVATADYFIQSAYASITAQAGTELVRIDIDPEEATEYGVGDFLSFLLLKDEIGSVICLLGIDGIDQSGDRRLMTTDPETRVYDIRLSLIYVEIEKRTANTTLNIPLTTPEEDEEATIGSGEQQVLPAAVVISLPDKMVGRIRDILKDVVEDVIEEVVEEALITPKMRIVGPNDPTSNGTAFRRSKTPDPATWWESVYPSVEEDTNYHEEGGRDGALTKGTVNIKKDRLDR

## Evaluation

## Conditional Generation MSA

## Visualization

In [4]:
import py3Dmol
import os

dir = 'data/scaffolding-pdbs'
pdb = 'L8ETE9'

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(open(os.path.join(dir, f'{pdb}.pdb'),'r').read(),'pdb')

view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':0.5,'max':0.9}}}) # as color is set to lDDT
view.setStyle({'cartoon': {'color':'spectrum'}})

view.zoomTo()
view.show()