In [1]:
protein_sequence = 'MDDDIAALVVDNGSGMCKAGFAGDDAPRAVFPSIVGRPRHQGVMVGMGQKDSYVGDEAQSKRGILTLKYPIEHGIVTNWDDMEKIWHHTFYNELRVAPEEHPVLLTEAPLNPKANREKMTQIMFETFNTPAMYVAIQAVLSLYASGRTTGIVMDSGDGVTHTVPIYEGYALPHAILRLDLAGRDLTDYLMKILTERGYSFTTTAEREIVRDIKEKLCYVALDFEQEMATAASSSSLEKSYELPDGQVITIGNERFRCPEALFQPSFLGMESCGIHETTFNSIMKCDVDIRKDLYANTVLSGGTTMYPGIADRMQKEITALAPSTMKIKIIAPPERKYSVWIGGSILASLSTFQQMWISKQEYDESGPSIVHRKCF'
nucleus_image = 'images/nucleus.jpg'
protein_name = "Actin, cytoplasmic 1 (Beta-actin) [Cleaved into: Actin, cytoplasmic 1, N-terminally processed]"
device = "cuda:0"
config_file = 'configs/celle.yaml'
ckpt_path = 'logs/2022-09-14T00-40-31_celle/checkpoints/last.ckpt'

In [3]:
#run once
import os

if 'notebooks' in os.getcwd():
    os.chdir('..')

import torch
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib.colors import LinearSegmentedColormap
import torchvision

from einops import rearrange
from omegaconf import OmegaConf

from celle_main import instantiate_from_config
from dataloader import OpenCellLoader

# color map for plot
color_array = plt.get_cmap('gist_rainbow')(range(256))
color_array[:,-1] = np.linspace(1.0,0.0,256)
map_object = LinearSegmentedColormap.from_list(name='rainbow_alpha',colors=color_array[::-1])
plt.register_cmap(cmap=map_object)

device = torch.device(device)

# #load model
configs = OmegaConf.load(config_file);
model = instantiate_from_config(configs.model).to(device);
if ckpt_path:
    t = torch.load(ckpt_path,map_location = 'cpu')['state_dict'];
    for key in list(t.keys()):
        t[key.replace('celle.','')] = t.pop(key);
model.celle.load_state_dict(t,strict=False);
model = model.celle
model = model.to(device)
model = model.eval()

# get some params
crop_size = configs.data.params.crop_size
sequence_mode = configs.data.params.sequence_mode
vocab = configs.data.params.vocab
threshold = configs.data.params.threshold
text_seq_len = configs.data.params.text_seq_len

# convert string to numbered index
dataset = OpenCellLoader(data_csv=configs.data.params.data_csv,crop_size=crop_size, sequence_mode=sequence_mode, vocab=vocab, threshold=threshold, text_seq_len=text_seq_len)

  plt.register_cmap(cmap=map_object)


In [4]:
protein_sequence = ''.join(filter(str.isalpha, protein_sequence)) 
protein_sequence = dataset.tokenize_seqeuence(protein_sequence)

# import nucleus, scale and crop
nucleus = torch.tensor(plt.imread(nucleus_image)).float()
nucleus /= 255
nucleus = torchvision.transforms.RandomCrop(256)(nucleus).unsqueeze(0).unsqueeze(0)

# generate image
with torch.no_grad():
    output = model.generate_images(text=protein_sequence.to(device), condition = nucleus.to(device), return_logits=True, use_cache=True, progress=True)
            
    logits = output[-1][:,-256:,-512:]
    image_tokens = logits @ model.vae.model.quantize.embedding.weight
    image_tokens = rearrange(image_patches, "b (h w) c -> b c h w", h=int(np.sqrt(256)))
    pdf = model.vae.model.decode(image_tokens)
    pdf = torch.clip(pdf,0,1)
    
    plt.figure(dpi=300, clear=True)      
    plt.axis('off')
    plt.imshow(nucleus[0,0],cmap='gray',interpolation='bicubic')
    plt.imshow(pdf.cpu()[0,0],cmap='rainbow_alpha',alpha = .75,interpolation='bicubic')
    plt.colorbar(mappable=cm.ScalarMappable(cmap='rainbow_alpha'))
    
    if protein_name:
        plt.title(protein_name)

  protein_vector = torch.tensor(


NameError: name 'model' is not defined