In [None]:
# Colab setup: mount Google Drive (generated PDBs and metadata are persisted there), clone and
# install Genie, and create the local scratch directories used by the generation loop.
# Written so that re-running the cell (or Run All after a kernel restart) does not fail.
import os
import subprocess
import sys

from google.colab import drive

drive.mount('/content/drive')

GENIE_DIR = '/content/genie'
# Clone only once: a second `git clone` into an existing checkout fails.
if not os.path.isdir(GENIE_DIR):
  subprocess.run(['git', 'clone', 'https://github.com/aqlaboratory/genie', GENIE_DIR], check=True)
# Absolute path: a relative `cd ./genie` executed from inside the checkout would descend into
# the inner `genie` package directory on a re-run.
os.chdir(GENIE_DIR)
# Editable install into the kernel's own interpreter (equivalent of `%pip install -e .`).
subprocess.run([sys.executable, '-m', 'pip', 'install', '-e', '.'], check=True)
# Local scratch space: ./output holds temporary coordinate files, ./output/pdbs the PDBs
# before they are moved to Drive.
os.makedirs(os.path.join('output', 'pdbs'), exist_ok=True)

In [None]:
import os
import shutil
import glob
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import uuid
from datetime import datetime
import re
import torch
#from time import time
import time
from genie.config import Config
from genie.diffusion.genie import Genie
from evaluations.pipeline import utils
import sys
import argparse
from tqdm import tqdm, trange

# Generation metadata: one row per generated backbone, persisted on Drive so that the
# generation loop can skip lengths that are already complete when the notebook is re-run.
meta_data_filepath = "/content/drive/MyDrive/Generative_Models/unconditional_generation/genie_unconditional/generation_metadata_genie.csv"

if os.path.exists(meta_data_filepath):
  all_metadata_df = pd.read_csv(meta_data_filepath)
  print("Existing generation metadata read in.")
else:
  all_metadata_df = pd.DataFrame()
  print("Created generation metadata dataframe")


# Target length distribution: a list of [length, n_samples] pairs (one per residue count) whose
# counts sum to N_LENGTH_SAMPLES. It is cached on Drive so every run works towards the same targets.
len_dist_filepath = "/content/drive/MyDrive/Generative_Models/unconditional_generation/genie_unconditional/uniref50_length_dist_genie.json"
N_LENGTH_SAMPLES = 10000   # total number of backbones to generate (ten thousand sequences)
MAX_SAMPLED_LENGTH = 1000  # longest length included in the target distribution
LENGTH_DIST_SEED = 0       # makes the sampled distribution reproducible

if os.path.exists(len_dist_filepath):
  with open(len_dist_filepath, "r") as f:
    uniprot_length_dist = json.load(f)
  print("Loaded length distribution from drive")
else:
  # Binned sequence-length statistics from https://www.uniprot.org/uniprotkb/statistics#sequence-size
  bins = np.array([13,51,101,151,201,251,301,351,401,451,501,551,601,651,701,751,801,851,901,951,1001,1101,1201,1301,1401,1501,1601,1701,1801,1901,2001,2101,2201,2301,2401,2501,34350])
  swissprot_reviewed = np.array([0,9968,43534,59796,59574,58452,52413,52846,45901,37706,30572,22287,15830,13156,9403,7870,5700,4889,5301,4109,3007,4124,2897,2207,2070,1675,834,642,587,503,395,272,386,340,234,195,1462])
  TrEMBL_unreviewed = np.array([0,2668805,19825275,24705701,23838128,23462438,23225451,21389271,16814580,14287105,11501843,8283150,6266068,4715059,3755005,3186452,2687314,2166878,1843669,1457871,1153537,1975953,1398765,961048,664766,517536,390552,300984,236895,210921,180246,138808,122833,102865,82441,71548,527646])  # not used below; kept for reference

  # Empirical CDF of SwissProt lengths anchored at the bin edges, then linearly interpolated
  # onto every integer length. Shortest protein in UniProt is 14 res, longest is 34350 res.
  ecdf = np.cumsum(swissprot_reviewed) / np.sum(swissprot_reviewed)
  x = np.arange(14, 34350 + 1)
  ecdf = np.interp(x, bins, ecdf)

  # Inverse-transform sampling from the ECDF truncated to lengths <= MAX_SAMPLED_LENGTH:
  # drawing u from [0, F(MAX_SAMPLED_LENGTH + 0.5)) gives exactly N_LENGTH_SAMPLES lengths that
  # round to at most MAX_SAMPLED_LENGTH, instead of over-sampling and filtering afterwards.
  rng = np.random.default_rng(LENGTH_DIST_SEED)
  u_max = np.interp(MAX_SAMPLED_LENGTH + 0.5, x, ecdf)
  random_values = rng.uniform(0.0, u_max, N_LENGTH_SAMPLES)
  sampled_lengths = np.round(np.interp(random_values, ecdf, x)).astype(int)
  # Guard against floating-point round-off at the upper edge.
  sampled_lengths = np.clip(sampled_lengths, x[0], MAX_SAMPLED_LENGTH)

  # One histogram bin per residue count, lengths x[0]..MAX_SAMPLED_LENGTH inclusive.
  lengths = np.arange(x[0], MAX_SAMPLED_LENGTH + 1)
  counts = np.bincount(sampled_lengths, minlength=MAX_SAMPLED_LENGTH + 1)[lengths]
  assert counts.sum() == N_LENGTH_SAMPLES

  fig, ax = plt.subplots()
  ax.bar(lengths, counts, width=1.0, alpha=0.7, label='Sampled lengths')
  ax.set(xlabel='Sequence length (residues)', ylabel='Frequency',
         title='Target length distribution (SwissProt ECDF)')
  ax.legend()
  plt.show()

  uniprot_length_dist = [[int(n_res), int(n)] for n_res, n in zip(lengths, counts)]
  # The Drive folder may not exist yet on a first run.
  os.makedirs(os.path.dirname(len_dist_filepath), exist_ok=True)
  with open(len_dist_filepath, "w") as f:
    json.dump(uniprot_length_dist, f)


Existing generation metadata read in.
Loaded length distribution from drive


In [None]:
# Load the pretrained Genie checkpoint (`swissprot_l_256`) for inference.
device = torch.device("cuda")
torch.cuda.empty_cache()
config = Config('./weights/swissprot_l_256/configuration')
# Put the model on the same device as the sampling masks and switch to eval mode so that no
# training-only behaviour is active while sampling.
model = Genie.load_from_checkpoint('./weights/swissprot_l_256/epoch=99.ckpt', config=config).to(device).eval()
# Padded length of every sampling mask; target lengths above this are skipped.
max_n_res = model.config.io['max_n_res']
# Noise scale passed to model.p_sample_loop. NOTE(review): document why 0.6 was chosen.
noise_scale = 0.6
# One sample per call keeps GPU memory low; the generation cell only reads ts[0], so it assumes 1.
batch_size = 1

In [None]:
# Display the checkpoint's maximum length (model.config.io['max_n_res']); the generation loop
# pads every mask to this length and skips all target lengths above it.
max_n_res

256

In [None]:
# Generation loop. For each (length, target_count) pair in the target length distribution,
# sample the backbones still missing for that length, write each as a PDB under ./output/pdbs,
# and record one metadata row per successful sample. After every length, the PDBs are moved to
# Drive and the metadata CSV is rewritten, so an interrupted run resumes where it stopped.
DRIVE_OUTPUT_DIR = "/content/drive/MyDrive/Generative_Models/unconditional_generation/genie_unconditional"
pdbs_dir = os.path.join("output", 'pdbs')


def _n_recorded(metadata_df, length):
  """Return how many samples of `length` residues are already in the metadata table."""
  # A freshly created pd.DataFrame() has no columns, so `metadata_df.conditions` would raise.
  if 'conditions' not in metadata_df.columns:
    return 0
  return int((metadata_df['conditions'] == 'length = ' + str(length)).sum())


def _flush_outputs():
  """Move generated PDBs to Drive, then delete every remaining file under ./output."""
  os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
  for pdb_path in glob.glob(os.path.join(pdbs_dir, '*.pdb')):
    shutil.move(pdb_path, os.path.join(DRIVE_OUTPUT_DIR, os.path.basename(pdb_path)))
  for root, _dirs, files in os.walk('output'):
    for name in files:
      os.remove(os.path.join(root, name))


def _generate_sample(length, sample_idx):
  """Sample one backbone of `length` residues and save it as a PDB in `pdbs_dir`.

  Args:
    length: number of residues to generate (callers ensure length <= max_n_res).
    sample_idx: index of the sample within this length; only used to name the
      temporary coordinate file.

  Returns:
    The metadata dict for the sample, or None if sampling produced NaN coordinates
    or the PDB conversion failed (no PDB is left behind in that case).
  """
  meta_data = {
    'entity_id': str(uuid.uuid4()),
    'batch_id': None,
    'batch_size': None,
    'Timestamp': str(datetime.now()),
    'model': 'Genie',
    'task': 'backbone_pdb_generation',
    'conditions': 'length = ' + str(length),
    'wall_time_batch': None,
    'gpu': 'T4 GPU',  # NOTE(review): hard-coded label, not queried from the runtime
  }
  start_time = time.time()
  # 1s for the `length` real residues, 0s for padding up to the model's max_n_res.
  mask = torch.cat([torch.ones((batch_size, length)), torch.zeros((batch_size, max_n_res - length))], dim=1).to(device)
  # Pure inference: no_grad keeps autograd from holding the sampling trajectory's graph on the
  # GPU (in case p_sample_loop does not disable autograd itself).
  with torch.no_grad():
    ts = model.p_sample_loop(mask, noise_scale, verbose=False)[-1]
  # Final step of the trajectory, first (only) batch element, cropped to the real residues.
  coords = ts[0].trans.detach().cpu().numpy()[:length]

  coord_filepath = os.path.join('output', 'len' + str(length) + '_sample' + str(sample_idx) + '.csv')
  domain_name = os.path.splitext(os.path.basename(coord_filepath))[0]
  # Text round trip (3 decimals): the PDB conversion below was only observed to work on the
  # reloaded array. The temporary file is deleted by _flush_outputs().
  np.savetxt(coord_filepath, coords, fmt='%.3f', delimiter=',')
  coords = np.loadtxt(coord_filepath, delimiter=',')
  if np.isnan(coords).any():
    print(f'Error: {domain_name} (NaN coordinates)')
    return None

  # PDB conversion modified from evaluations.pipeline.pipeline.py
  new_file_name = 'Genie_len' + str(length) + '_' + meta_data['entity_id'] + '.pdb'
  pdb_filepath = os.path.join(pdbs_dir, new_file_name)
  try:
    seq = 'A' * coords.shape[0]  # poly-alanine placeholder sequence for the backbone
    utils.save_as_pdb(seq, coords, pdb_filepath)
  except Exception as exc:
    # Remove a partially written file so it is not moved to Drive without a metadata row.
    if os.path.exists(pdb_filepath):
      os.remove(pdb_filepath)
    print(f'Error: {domain_name} ({exc!r})')
    return None

  meta_data['output_file_name'] = new_file_name
  meta_data['wall_time_task'] = str(time.time() - start_time) + " Seconds"
  return meta_data


for length, target_count in tqdm(uniprot_length_dist, desc='lengths'):
  torch.cuda.empty_cache()
  if length > max_n_res:
    continue
  # Resume support: only generate the samples still missing for this length.
  n_missing = target_count - _n_recorded(all_metadata_df, length)
  if n_missing <= 0:
    continue
  records = []
  for sample_idx in range(n_missing):
    record = _generate_sample(length, sample_idx)
    if record is not None:
      records.append(record)
  if records:
    new_rows = pd.DataFrame(records)
    # Concatenating onto an empty frame is deprecated in recent pandas; start from the new rows.
    if all_metadata_df.empty:
      all_metadata_df = new_rows
    else:
      all_metadata_df = pd.concat([all_metadata_df, new_rows], ignore_index=True)
  # Move PDBs to Drive before writing the CSV, so the CSV never lists PDBs that are not on Drive.
  _flush_outputs()
  all_metadata_df.to_csv(meta_data_filepath, index=False)


Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  b = torch.cross(t[:, :-1], t[:, 1:])


In [None]:
# End the Colab session once generation is finished: sends SIGKILL to every process the
# current user may signal (pid -1), which also kills this kernel, so nothing after this runs.
# NOTE(review): presumably meant to stop consuming GPU runtime; google.colab.runtime.unassign()
# is the documented way to release the runtime.
!kill -9 -1