### Set parameters. 
`protein_idx` selects the protein to visualize; its value comes from the `train_test_split.csv` file.

In [None]:
# Checkpoint path; when truthy it is passed to torch.load() in the setup cell.
ckpt_path = None
# Device string; converted with torch.device() in the setup cell.
device = "cuda:0"
# YAML configuration file read with OmegaConf.load() in the setup cell.
config_file = 'configs/celle.yaml'
# Index of the sample to analyse; used as dataset[protein_idx].
protein_idx = 901

### Run the following cell once to load the model and set definitions

In [None]:
#run once
import os

if 'scripts' in os.getcwd():
    os.chdir('..')

import torch
import numpy as np
from matplotlib import pyplot as plt
from collections import OrderedDict
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import patches
import glob
from PIL import Image
import cv2
from tqdm import tqdm

from einops import rearrange
from omegaconf import OmegaConf

from scripts.grad_map import ActivationsAndGradients
from celle_main import instantiate_from_config
from dataloader import OpenCellLoader

# Token -> integer index table. The five special tokens take indices 0-4 and
# the residue letters follow consecutively from 5 (the letter "J" is not part
# of this vocabulary).
IUPAC_VOCAB = OrderedDict(
    (token, index)
    for index, token in enumerate(
        ["<pad>", "<mask>", "<cls>", "<sep>", "<unk>"]
        + list("ABCDEFGHIKLMNOPQRSTUVWXYZ")
    )
)

# Reverse table: integer index -> token.
IUPAC_VOCAB_INV = dict(zip(IUPAC_VOCAB.values(), IUPAC_VOCAB.keys()))

def pad_to_square(input):
    """Zero-pad a 1-D sequence to a perfect-square length and reshape it.

    The side length is the smallest integer ``k >= 1`` with
    ``k * k >= len(input)``; the missing entries are filled with zeros at the
    end. Inputs of any length are supported (there is no fixed upper bound on
    the side length).

    :param input: 1-D array-like of numbers
    :return: ``numpy.ndarray`` of shape ``(k, k)``; an empty input yields a
        single zero of shape ``(1, 1)``
    """
    n_values = len(input)

    # Smallest side whose square holds every value; at least 1 so that an
    # empty input still produces a (1, 1) array.
    side = max(1, int(np.ceil(np.sqrt(n_values))))
    # Guard against floating-point rounding in the square root.
    while side * side < n_values:
        side += 1

    padded = np.concatenate([input, np.zeros(side * side - n_values)])
    return np.reshape(padded, (side, side))

def get_key(fp):
    """
    Numeric sort key for a file path
    :param fp: Path whose file name (extension removed) starts with a number
    :return: That number as a float, e.g. 'temp/12.1.jpg' -> 12.1
    """
    stem, _extension = os.path.splitext(os.path.basename(fp))
    # Only the first whitespace-separated token of the stem is parsed.
    leading_token = stem.split(maxsplit=1)[0]
    return float(leading_token)

def write_video(file_path, frames, fps):
    """
    Writes frames to an mp4 video file
    :param file_path: Path to output video, must end with .mp4
    :param frames: Non-empty list of PIL.Image objects; the size of the first
        frame sets the video size and every frame is resized to it
    :param fps: Desired frame rate
    :raises ValueError: If frames is empty
    """
    if not frames:
        raise ValueError("write_video needs at least one frame")

    w, h = frames[0].size
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    writer = cv2.VideoWriter(file_path, fourcc, fps, (w, h))

    try:
        for frame in frames:
            # PIL delivers RGB channel order; OpenCV writes BGR.
            bgr = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
            writer.write(cv2.resize(bgr, (w, h), fx=0, fy=0, interpolation=cv2.INTER_CUBIC))
    finally:
        # Always release the writer, even if a frame fails to convert.
        writer.release()

# color maps for plots

# 'rainbow_alpha': gist_rainbow in reversed order, with alpha rising linearly
# from 0 at the low end of the map to 1 at the high end.
color_array = plt.get_cmap('gist_rainbow')(range(256))
color_array[:,-1] = np.linspace(1.0,0.0,256)
map_object = LinearSegmentedColormap.from_list(name='rainbow_alpha',colors=color_array[::-1])
plt.register_cmap(cmap=map_object)

# 'gray_alpha': opaque black at the low end, fully transparent from the
# midpoint upward; used to black out masked regions.
colors = [(0, 0, 0, 1), (0, 0, 0, 0), (0, 0, 0, 0)]
map_object = LinearSegmentedColormap.from_list("gray_alpha", colors, N=256)
plt.register_cmap(cmap=map_object)

# 'red_alpha': the list runs from red down to black; it is reversed below so
# the map goes from opaque black (low) to half-transparent red (high).
colors = [(.78, .12, .07, .5),(4*.78/5, 4*.12/5, 4*.07/5, .5),(3*.78/5, 3*.12/5, 3*.07/5, .5), (2*.78/5, 2*.12/5, 2*.07/5, .5),(.78/5, .12/5, .07/5, .5),(0, 0, 0, 1)]
map_object = LinearSegmentedColormap.from_list("red_alpha", colors[::-1], N=256)
# The plots request this map by name (cmap='red_alpha'), so it has to be
# registered like the other two.
plt.register_cmap(cmap=map_object)


device = torch.device(device)

# load model: build it from the YAML config, then optionally restore weights
configs = OmegaConf.load(config_file)
model = instantiate_from_config(configs.model).to(device)

if ckpt_path:
    # NOTE: torch.load unpickles the file -- only open checkpoints you trust.
    state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
    # Drop the 'celle.' prefix so the keys match the inner module's names.
    for key in list(state_dict.keys()):
        state_dict[key.replace('celle.', '')] = state_dict.pop(key)
    # Loading lives inside the branch: without a checkpoint there is no state
    # dict to load and the model keeps the weights it was instantiated with.
    model.celle.load_state_dict(state_dict, strict=False)

# Work with the inner module directly from here on, in inference mode.
model = model.celle
model = model.to(device)
model = model.eval()

# Number of transformer layers, read from the model section of the config.
depth = configs.model.params.depth

# Dataset options, read from the data section of the config and kept under
# their own names as well as in one keyword dictionary.
data_params = configs.data.params
loader_kwargs = {
    option: getattr(data_params, option)
    for option in ("crop_size", "sequence_mode", "vocab", "threshold", "text_seq_len")
}
crop_size = loader_kwargs["crop_size"]
sequence_mode = loader_kwargs["sequence_mode"]
vocab = loader_kwargs["vocab"]
threshold = loader_kwargs["threshold"]
text_seq_len = loader_kwargs["text_seq_len"]

# Dataset object; indexing it yields the 'sequence', 'nucleus' and
# 'threshold' entries consumed by the generation cell.
dataset = OpenCellLoader(**loader_kwargs)

### Run the following cell to generate the prediction and save attention weights

In [None]:
with torch.no_grad():
    # One attention-saving submodule per transformer layer; their outputs are
    # collected by ActivationsAndGradients during the forward pass.
    layers = [model.transformer.layers.layers[d][0].fn.fn.fn.fn.fn.save_attn for d in range(depth)]
    activations_and_grads = ActivationsAndGradients(model, layers, None)

    batch = dataset[protein_idx]
    sequence = batch['sequence'].to(device)
    nucleus = batch['nucleus'].unsqueeze(0).to(device)
    target = batch['threshold'].unsqueeze(0).to(device)
    protein_name = dataset.retrieve_metadata(protein_idx)['protein_name']

    print('Generating Best Estimate')
    output, logits = activations_and_grads([sequence, nucleus, target],filter_thres=.5)

    # Keep the last 256 positions and the last 512 vocabulary entries.
    # NOTE(review): presumably the generated-image positions and the image
    # codebook entries -- confirm against the model definition.
    logits = logits[:,-256:,-512:]
    image_tokens = logits @ model.vae.model.quantize.embedding.weight
    # Lay the 256 positions out as a 16 x 16 grid for the decoder.
    image_tokens = rearrange(image_tokens, "b (h w) c -> b c h w", h=int(np.sqrt(256)))
    pdf = model.vae.model.decode(image_tokens)
    pdf = torch.clip(pdf,0,1)

# Group the recorded activations into runs of `depth` (one entry per layer)
# and stack each run along the first axis.
recorded = activations_and_grads.activations
chunked_list = [torch.cat(recorded[i:i+depth], dim=0) for i in range(0, len(recorded), depth)]

# Median over axis 1 of the concatenated activations; copied to host memory
# first because numpy cannot read tensors that live on an accelerator.
stacked = torch.cat(chunked_list, dim=2).detach().cpu().numpy()
s = torch.tensor(np.median(stacked, axis=1))

# Attention rollout: add the identity, renormalize each row, then multiply
# the per-layer matrices cumulatively from the first layer to the last.
s = (s + torch.eye(s.shape[1])[None,...])
s = s / s.sum(axis=2)[...,None]
joint_attentions = torch.zeros(s.shape)
joint_attentions[0] = s[0]
for i in range(1, joint_attentions.shape[0]):
    joint_attentions[i] = torch.matmul(s[i], joint_attentions[i-1])

# Strip trailing padding (index 0) from the sequence; done on a host copy so
# numpy can inspect it whatever the compute device is.
sequence_cpu = sequence.detach().cpu()
last_token = int(np.where(sequence_cpu != 0)[-1][-1])
unpadded_sequence = sequence_cpu[:, :last_token + 1][0]
if unpadded_sequence[-1] == 3:
    # Sequence ends with index 3 ("<sep>" in IUPAC_VOCAB).
    # NOTE(review): this drops the last TWO tokens, not only the separator --
    # confirm that is intended.
    unpadded_sequence = unpadded_sequence[:-2]
protein_letters = [IUPAC_VOCAB_INV[x.item()] for x in unpadded_sequence]

# Rolled-out attention of the final layer; its last 256 rows are kept.
mean_activation_mat = (joint_attentions[-1]).cpu().numpy()

generative_weights = mean_activation_mat[-256:]

### Run the following cell to visualize a single step (0-255)

In [None]:
# Generation step to visualize. generative_weights holds at most 256 rows,
# so valid steps are 0..255 (255 = final step).
step = 255

In [None]:
tok_pos = step
if not 0 <= tok_pos < len(generative_weights):
    raise IndexError(f'step must be between 0 and {len(generative_weights) - 1}, got {tok_pos}')

# Attention row for this step, split by position range:
# [1:1001] sequence text, [1001:1257] nucleus patches, [1257:] image patches.
relevant_weights = generative_weights[tok_pos]
sequence_text_tokens = (relevant_weights[1:min(len(unpadded_sequence)+1,1001)])
condition_tokens = relevant_weights[1001:1257]
image_tokens = np.log(relevant_weights[1257:])

# Host-side numpy copies: numpy and matplotlib cannot read tensors that live
# on an accelerator.
nucleus_img = torch.as_tensor(nucleus[0,0]).detach().cpu().numpy()
output_img = torch.as_tensor(output[0,0]).detach().cpu().numpy()
# NOTE(review): assumes pdf squeezes to a single 2-D (height, width) map,
# i.e. batch size 1 and one channel -- confirm against the decoder output.
pdf_img = torch.as_tensor(pdf).detach().cpu().numpy().squeeze()

# Pixel origin of the 16 x 16 patch generated at this step (256-pixel rows).
tok_x = tok_pos * 16 % 256
tok_y = tok_pos // (256 // 16) * 16

fig = plt.figure(figsize=(17.5, 5), dpi=140, tight_layout=True)
fig.suptitle(f'{protein_name}\nRelative Attention Weights')
axes = []


# .1 SEQUENCE PLOT
axes.append(fig.add_subplot(1, 4, 1))
axes[-1].set_title('Sequence')
axes[-1].axis('off')
text_plot_weights = pad_to_square(sequence_text_tokens)
text_mask = np.zeros(text_plot_weights.shape)
n_cols = text_plot_weights.shape[1]
# Write one residue letter per cell, row by row; cells without a letter stay
# 0 in the mask and are blacked out by 'gray_alpha'.
for count, letter in enumerate(protein_letters[:text_plot_weights.size]):
    row, col = divmod(count, n_cols)
    axes[-1].text(col,row,f'{letter}',ha='center',va='center',color='black')
    text_mask[row,col] = 1
axes[-1].imshow(text_plot_weights,cmap='OrRd')
axes[-1].imshow(text_mask,cmap='gray_alpha',alpha = 1)


# .1 NUCLEUS PLOT
axes.append(fig.add_subplot(1, 4, 2))
axes[-1].set_title('Nucleus Image')
test_nuc = np.ones(nucleus_img.shape)
# Paint each 16 x 16 patch with the weight of its nucleus token.
for i in range(condition_tokens.shape[0]):
    x = i * 16 % 256
    y = i // (256 // 16) * 16
    test_nuc[y:y + 16, x:x + 16] *= condition_tokens[i]
axes[-1].axis('off')
axes[-1].imshow(nucleus_img,cmap='gray', interpolation= 'bilinear')
axes[-1].imshow(test_nuc,cmap='red_alpha',alpha=.90)
axes[-1].add_patch(patches.Rectangle((tok_x, tok_y), 16, 16, linewidth=2, edgecolor='blue', facecolor='none'))


# .1 OUTPUT PLOT
axes.append(fig.add_subplot(1, 4, 3))
axes[-1].set_title('Predicted Threshold Image')
image_tokens_u = np.concatenate([image_tokens,np.zeros(1)])
test_im = np.ones(output_img.shape)
im_mask = np.ones(test_im.shape)

# Patches up to and including the current step show their weight; later
# patches are zeroed and masked out.
for i in range(test_im.shape[0]):
    x = i * 16 % 256
    y = i // (256 // 16) * 16

    if i <= tok_pos:
        test_im[y:y + 16, x:x + 16] *= image_tokens_u[i]
        im_mask[y:y + 16, x:x + 16] = 1
    else:
        test_im[y:y + 16, x:x + 16] = 0
        im_mask[y:y + 16, x:x + 16] = 0

axes[-1].axis('off')
axes[-1].imshow(np.clip(output_img,0,1),cmap='gray',alpha=1, interpolation= 'bilinear')
axes[-1].imshow(test_im,cmap='red_alpha',alpha=.90)
axes[-1].imshow(im_mask,cmap='gray_alpha',alpha=1)
axes[-1].add_patch(patches.Rectangle((tok_x, tok_y), 16, 16, linewidth=2, edgecolor='blue', facecolor='none'))


# .1 PDF PLOT
axes.append(fig.add_subplot(1, 4, 4))
axes[-1].set_title('Predicted PDF')
test_im = np.ones(output_img.shape)

# Same reveal rule as the output plot, showing the decoded map instead.
for i in range(test_im.shape[0]):
    x = i * 16 % 256
    y = i // (256 // 16) * 16

    if i <= tok_pos:
        test_im[y:y + 16, x:x + 16] *= pdf_img[y:y + 16, x:x + 16]
    else:
        test_im[y:y + 16, x:x + 16] = 0

axes[-1].axis('off')
axes[-1].imshow(np.clip(nucleus_img,0,1),cmap='gray',alpha=1, interpolation= 'bilinear')
axes[-1].imshow(test_im,cmap='rainbow_alpha',alpha=.80)
axes[-1].add_patch(patches.Rectangle((tok_x, tok_y), 16, 16, linewidth=2, edgecolor='blue', facecolor='none'))

plt.show()

### Run the following to generate a video of all steps

In [None]:
# Host-side numpy copies, made once: numpy and matplotlib cannot read tensors
# that live on an accelerator.
nucleus_img = torch.as_tensor(nucleus[0, 0]).detach().cpu().numpy()
output_img = torch.as_tensor(output[0, 0]).detach().cpu().numpy()
# NOTE(review): assumes pdf squeezes to a single 2-D (height, width) map,
# i.e. batch size 1 and one channel -- confirm against the decoder output.
pdf_img = torch.as_tensor(pdf).detach().cpu().numpy().squeeze()


def _patch_origin(index):
    """Return the (x, y) pixel origin of 16 x 16 patch `index` in 256-pixel rows."""
    return index * 16 % 256, index // (256 // 16) * 16


def _render_attention_frame(tok_pos, include_current, out_path):
    """
    Draws the four-panel attention figure for one step and saves it
    :param tok_pos: Generation step (row of generative_weights), 0-255
    :param include_current: If True, the patch at tok_pos counts as already
        generated in the output/PDF panels; if False only earlier patches do
    :param out_path: Image file the figure is written to
    """
    # Attention row for this step, split by position range:
    # [1:1001] sequence text, [1001:1257] nucleus patches, [1257:] image patches.
    relevant_weights = generative_weights[tok_pos]
    sequence_text_tokens = relevant_weights[1:min(len(unpadded_sequence) + 1, 1001)]
    condition_tokens = relevant_weights[1001:1257]
    image_tokens_u = np.concatenate([np.log(relevant_weights[1257:]), np.zeros(1)])
    tok_x, tok_y = _patch_origin(tok_pos)

    def highlight():
        # A patch object can belong to one axes only, so build a new one each time.
        return patches.Rectangle((tok_x, tok_y), 16, 16, linewidth=2, edgecolor='blue', facecolor='none')

    def revealed(i):
        return i <= tok_pos if include_current else i < tok_pos

    fig = plt.figure(figsize=(17.5, 5), dpi=140, tight_layout=True)
    fig.suptitle(f'{protein_name}\nRelative Attention Weights')

    # SEQUENCE PLOT: one residue letter per cell, row by row; cells without a
    # letter stay 0 in the mask and are blacked out by 'gray_alpha'.
    ax = fig.add_subplot(1, 4, 1)
    ax.set_title('Sequence')
    ax.axis('off')
    text_plot_weights = pad_to_square(sequence_text_tokens)
    text_mask = np.zeros(text_plot_weights.shape)
    n_cols = text_plot_weights.shape[1]
    for count, letter in enumerate(protein_letters[:text_plot_weights.size]):
        row, col = divmod(count, n_cols)
        ax.text(col, row, f'{letter}', ha='center', va='center', color='black')
        text_mask[row, col] = 1
    ax.imshow(text_plot_weights, cmap='OrRd')
    ax.imshow(text_mask, cmap='gray_alpha', alpha=1)

    # NUCLEUS PLOT: each 16 x 16 patch painted with its nucleus-token weight.
    ax = fig.add_subplot(1, 4, 2)
    ax.set_title('Nucleus Image')
    test_nuc = np.ones(nucleus_img.shape)
    for i in range(condition_tokens.shape[0]):
        x, y = _patch_origin(i)
        test_nuc[y:y + 16, x:x + 16] *= condition_tokens[i]
    ax.axis('off')
    ax.imshow(nucleus_img, cmap='gray', interpolation='bilinear')
    ax.imshow(test_nuc, cmap='red_alpha', alpha=.90)
    ax.add_patch(highlight())

    # OUTPUT PLOT: revealed patches show their weight, the rest is masked out.
    ax = fig.add_subplot(1, 4, 3)
    ax.set_title('Predicted Threshold Image')
    test_im = np.ones(output_img.shape)
    im_mask = np.ones(test_im.shape)
    for i in range(test_im.shape[0]):
        x, y = _patch_origin(i)
        if revealed(i):
            test_im[y:y + 16, x:x + 16] *= image_tokens_u[i]
            im_mask[y:y + 16, x:x + 16] = 1
        else:
            test_im[y:y + 16, x:x + 16] = 0
            im_mask[y:y + 16, x:x + 16] = 0
    ax.axis('off')
    ax.imshow(np.clip(output_img, 0, 1), cmap='gray', alpha=1, interpolation='bilinear')
    ax.imshow(test_im, cmap='red_alpha', alpha=.90)
    ax.imshow(im_mask, cmap='gray_alpha', alpha=1)
    ax.add_patch(highlight())

    # PDF PLOT: same reveal rule, showing the decoded map over the nucleus.
    ax = fig.add_subplot(1, 4, 4)
    ax.set_title('Predicted PDF')
    test_im = np.ones(output_img.shape)
    for i in range(test_im.shape[0]):
        x, y = _patch_origin(i)
        if revealed(i):
            test_im[y:y + 16, x:x + 16] *= pdf_img[y:y + 16, x:x + 16]
        else:
            test_im[y:y + 16, x:x + 16] = 0
    ax.axis('off')
    ax.imshow(np.clip(nucleus_img, 0, 1), cmap='gray', alpha=1, interpolation='bilinear')
    ax.imshow(test_im, cmap='rainbow_alpha', alpha=.80)
    ax.add_patch(highlight())

    fig.savefig(out_path)
    plt.close(fig)


# savefig does not create missing directories.
os.makedirs('temp', exist_ok=True)

# Two frames per step: '<step>.jpg' before the current patch is revealed and
# '<step>.1.jpg' with it revealed; get_key sorts them as step, step + 0.1.
for tok_pos in tqdm(range(256)):
    _render_attention_frame(tok_pos, False, f'temp/{tok_pos}.jpg')
    _render_attention_frame(tok_pos, True, f'temp/{tok_pos}.1.jpg')

imgs = [Image.open(f) for f in sorted(glob.glob("temp/*.jpg"), key=get_key)]
try:
    write_video(f"{protein_name}.mp4", imgs, 15)
finally:
    # Release the image files whether or not the video was written.
    for img in imgs:
        img.close()