In [7]:
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.patches as mpatches
from matplotlib.cm import ScalarMappable
from transformers import AutoModelForMaskedLM, AutoTokenizer


In [8]:
# Notebook-wide configuration: amino-acid alphabet, wild-type reference sequence, ESM-2 tokenizer.

# One-letter codes of the 20 amino acids, in alphabetical order.
AAs = "ACDEFGHIKLMNPQRSTVWY"

# Wild-type CreiLOV sequence; its length defines the context window.
WT = "MAGLRHTFVVADATLPDCPLVYASEGFYAMTGYGPDEVLGHNARFLQGEGTDPKEVQKIRDAIKKGEACSVRLLNYRKDGTPFWNLLTVTPIKTPDGRVSKFVGVQVDVTSKTEGKALA"
sequence_length = len(WT)

# Checkpoint name; the tokenizer is fetched from the "facebook/" namespace.
model_identifier = "esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_identifier)


In [9]:
def load_and_plot_heatmaps(base_dir, num_updates, model_identifier, output_dir):
    """Draw one amino-acid probability heatmap per model update and save it as PNG and SVG.

    For every ``i`` in ``range(num_updates)`` the array stored in
    ``<base_dir>/update_{i}_probabilities.npy`` is loaded, transposed and drawn
    with seaborn on a fixed 0..1 colour scale. The wild-type residue of every
    position is marked with a black dot and the hand-picked mutations listed in
    ``mutations_to_annotate`` with a blue dot.

    Figures are written to ``<output_dir>/heatmap_update_{i}.png`` and ``.svg``.
    These are the file names the non-delta ``create_gif`` helper reads; the
    previous ``delta_heatmap_update_{i}`` names collided with the files written
    by ``load_and_plot_delta_heatmaps``.

    Relies on the notebook-level globals ``tokenizer`` and ``WT``.

    Args:
        base_dir: Directory containing the ``update_{i}_probabilities.npy`` files.
        num_updates: Number of consecutive update indices (starting at 0) to plot.
        model_identifier: Model name forwarded to ``get_model_identifier`` for the title.
        output_dir: Directory the figures are written to (created if missing).

    Note:
        Each array is assumed to be laid out as (sequence position, amino acid)
        with 20 amino-acid columns ordered like ``all_tokens`` — TODO confirm
        against the code that produced the ``.npy`` files.
    """
    # Heatmap rows follow the tokenizer's vocabulary order for entries 4..23.
    all_tokens = list(tokenizer.get_vocab().keys())[4:24]
    all_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in all_tokens]
    # token id -> heatmap row; setdefault keeps the first occurrence, like list.index().
    row_of_token_id = {}
    for row, token_id in enumerate(all_token_ids):
        row_of_token_id.setdefault(token_id, row)

    # Mutations to annotate: 0-based position -> mutant amino acid
    mutations_to_annotate = {
        4: 'D',   # R5D
        6: 'S',   # T7S
        25: 'T',  # G26T
        111: 'I', # K112I
        7: 'I',   # F8I
        15: 'E',  # P16E
        54: 'L',  # E55L
        59: 'K',  # R60K
        71: 'I',  # R72I
        76: 'K'   # R77K
    }

    # Reversed magma whose lowest bin is replaced by an almost transparent black.
    magma_r_colors = plt.cm.magma_r(np.linspace(0, 1, 256))
    magma_r_colors[0] = [0, 0, 0, 0.03]
    cmap = LinearSegmentedColormap.from_list("Modified_Magma_r", magma_r_colors, N=256)

    # Marker coordinates do not depend on the update, so compute them once.
    # (+0.5 places a marker in the centre of its heatmap cell.)
    wt_x, wt_y = [], []
    for pos, residue in enumerate(WT):
        row = row_of_token_id.get(tokenizer.convert_tokens_to_ids(residue))
        if row is not None:
            wt_x.append(pos + 0.5)
            wt_y.append(row + 0.5)

    mut_x, mut_y = [], []
    for pos, mutant_aa in mutations_to_annotate.items():
        token_id = tokenizer.convert_tokens_to_ids(mutant_aa)
        row = row_of_token_id.get(token_id)
        if row is None:
            print('Issue with token_id ', pos, mutant_aa, token_id)
        else:
            mut_x.append(pos + 0.5)
            mut_y.append(row + 0.5)

    # Only markers that are actually drawn get a legend entry. The former red
    # "Mutation" entry belonged to a branch that could never execute, because
    # the residues being compared against WT were taken from WT itself.
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='black', markersize=10, label='WT'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markersize=10, label='Mutations in best sequence variant')
    ]

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for i in range(num_updates):
        model_identifier_name = get_model_identifier(i, model_identifier)
        probabilities = np.load(os.path.join(base_dir, f'update_{i}_probabilities.npy'))

        fig, ax = plt.subplots(figsize=(30, 6))
        heatmap_ax = sns.heatmap(probabilities.T, cmap=cmap, square=True, linewidths=0.003,
                                 linecolor='0.7', vmin=0, vmax=1, ax=ax)
        cbar = heatmap_ax.collections[0].colorbar
        cbar.set_label('Predicted Amino Acid Probabilities', fontsize=16)
        cbar.ax.tick_params(labelsize=12)
        ax.set_title(f'Heatmap for {model_identifier_name}')
        ax.set_yticks(np.arange(20) + 0.5)
        ax.set_yticklabels(all_tokens, fontsize=8, rotation=0)
        ax.set_xlabel("Position in CreiLOV", fontsize=18)
        ax.set_ylabel('Amino Acid', fontsize=18)

        # Black dots: wild-type residues. Blue dots: annotated mutations.
        ax.scatter(wt_x, wt_y, color='black', s=30)
        ax.scatter(mut_x, mut_y, color='blue', s=30)

        # The legend was previously built but never attached to the figure.
        ax.legend(handles=legend_elements, loc='upper right')
        fig.tight_layout()

        fig.savefig(os.path.join(output_dir, f'heatmap_update_{i}.png'))
        fig.savefig(os.path.join(output_dir, f'heatmap_update_{i}.svg'))
        plt.close(fig)  # free the figure; many are created in this loop
    
def get_model_identifier(update_index, model_identifier):
    """Build the display label for one update of the model.

    Index 0 is labelled as the pre-trained model, indices 1 through 11 as
    SFT updates, and every other index as a PPO update numbered
    ``update_index - 11``.
    """
    if update_index == 0:
        return f"Pre-trained {model_identifier}"

    is_sft = 1 <= update_index <= 11
    stage = "SFT" if is_sft else "PPO"
    number = update_index if is_sft else update_index - 11
    return f"{model_identifier}: {stage} Update {number}"

def create_gif(image_dir, output_filename, num_frames=None, frame_pattern='heatmap_update_{i}.png'):
    """Combine the per-update heatmap PNGs into one animated GIF.

    Args:
        image_dir: Directory containing the frame images.
        output_filename: Path of the GIF to write.
        num_frames: Number of frames to read, indexed ``0 .. num_frames - 1``.
            When omitted, the notebook-level global ``num_updates`` is used,
            which keeps existing two-argument calls working.
        frame_pattern: File-name template for a frame; ``{i}`` is replaced by
            the frame index.

    Raises:
        ValueError: If there are no frames to combine.
    """
    if num_frames is None:
        num_frames = num_updates  # notebook-level global, kept for backward compatibility
    if num_frames < 1:
        raise ValueError("create_gif needs at least one frame")

    frames = []
    try:
        for i in range(num_frames):
            frames.append(Image.open(os.path.join(image_dir, frame_pattern.format(i=i))))
        # 500 ms per frame; the animation repeats 3 times.
        frames[0].save(output_filename, save_all=True, append_images=frames[1:],
                       optimize=False, duration=500, loop=3)
    finally:
        # Image.open keeps the file handle open; release it even if a later
        # frame is missing or saving fails.
        for frame in frames:
            frame.close()

# # Set directories and number of updates
# base_directory = 'SM_probability_data'
# output_directory = 'SM_probability_data'
# num_updates = 14  # Total updates including pre-trained, SFT, and PPO

# # Generate heatmaps and save as PNG
# load_and_plot_heatmaps(base_directory, num_updates, model_identifier, output_directory)

# # Create GIF from the saved images
# create_gif(output_directory, 'SM_probability_data/single_mut_probabilities_across_alignment.gif')

In [11]:
def load_and_plot_delta_heatmaps(base_dir, num_updates, output_dir, WT):
    """Plot how each update's amino-acid probabilities differ from update 0.

    For every ``i`` in ``range(num_updates)`` two figures are saved to
    ``output_dir`` (each as PNG and SVG):

    * ``delta_heatmap_update_{i}``: heatmap of
      ``update_{i}_probabilities - update_0_probabilities`` on a fixed -1..1
      red/white/blue scale, with black dots on the wild-type residues and blue
      dots on the mutations listed in ``mutations_to_annotate``.
    * ``cumulative_delta_heatmap_update_{i}``: a one-row heatmap holding, per
      position, the sum of the deltas over all non-wild-type amino acids.

    Relies on the notebook-level globals ``tokenizer`` and ``model_identifier``.

    Args:
        base_dir: Directory containing the ``update_{i}_probabilities.npy`` files.
        num_updates: Number of consecutive update indices (starting at 0) to plot.
        output_dir: Directory the figures are written to (created if missing).
        WT: Wild-type amino-acid sequence; one heatmap column per residue.

    Note:
        Each array is assumed to be laid out as (sequence position, amino acid)
        with 20 amino-acid columns ordered like ``all_tokens`` — TODO confirm
        against the code that produced the ``.npy`` files.
    """
    # Heatmap rows follow the tokenizer's vocabulary order for entries 4..23.
    all_tokens = list(tokenizer.get_vocab().keys())[4:24]
    all_token_ids = [tokenizer.convert_tokens_to_ids(token) for token in all_tokens]
    # token id -> heatmap row; setdefault keeps the first occurrence, like list.index().
    row_of_token_id = {}
    for row, token_id in enumerate(all_token_ids):
        row_of_token_id.setdefault(token_id, row)

    # Mutations to annotate: 0-based position -> mutant amino acid
    mutations_to_annotate = {
        4: 'D',   # R5D
        6: 'S',   # T7S
        25: 'T',  # G26T
        111: 'I', # K112I
        7: 'I',   # F8I
        15: 'E',  # P16E
        54: 'L',  # E55L
        59: 'K',  # R60K
        71: 'I',  # R72I
        76: 'K'   # R77K
    }

    # Reference distribution every update is compared against.
    base_probabilities = np.load(os.path.join(base_dir, 'update_0_probabilities.npy'))

    # Diverging map: red at the low end, white in the middle, blue at the high end.
    colors = [(0, '#B2182B'), (0.5, 'white'), (1, '#2166AC')]
    custom_cmap = LinearSegmentedColormap.from_list('custom', colors)

    # Heatmap row of the wild-type residue at each position (None if the residue
    # is not one of the plotted tokens). Independent of the update, so computed once.
    wt_rows = [row_of_token_id.get(tokenizer.convert_tokens_to_ids(residue)) for residue in WT]
    # +0.5 places a marker in the centre of its heatmap cell.
    wt_x = [pos + 0.5 for pos, row in enumerate(wt_rows) if row is not None]
    wt_y = [row + 0.5 for row in wt_rows if row is not None]

    mut_x, mut_y = [], []
    for pos, mutant_aa in mutations_to_annotate.items():
        token_id = tokenizer.convert_tokens_to_ids(mutant_aa)
        row = row_of_token_id.get(token_id)
        if row is None:
            print('Issue with token_id ', pos, mutant_aa, token_id)
        else:
            mut_x.append(pos + 0.5)
            mut_y.append(row + 0.5)

    # Only markers that are actually drawn get a legend entry. The former red
    # "Mutation" entry belonged to a branch that could never execute, because
    # the residues being compared against WT were taken from WT itself.
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='black', markersize=10, label='WT'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markersize=10, label='Mutations in best sequence variant')
    ]

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for i in range(num_updates):
        model_identifier_name = get_model_identifier(i, model_identifier)
        current_probabilities = np.load(os.path.join(base_dir, f'update_{i}_probabilities.npy'))
        delta_probabilities = current_probabilities - base_probabilities

        # --- Full delta heatmap -------------------------------------------------
        # The colour scale is fixed to [-1, 1] so that white always means
        # "no change" and all updates are directly comparable.
        fig, ax = plt.subplots(figsize=(30, 6))
        heatmap_ax = sns.heatmap(delta_probabilities.T, cmap=custom_cmap, square=True,
                                 linewidths=0.003, linecolor='0.7', vmin=-1, vmax=1, ax=ax)
        cbar = heatmap_ax.collections[0].colorbar
        cbar.set_label('Predicted Amino Acid Probabilities Relative to ESM2 (650M)', fontsize=16)
        cbar.ax.tick_params(labelsize=12)
        ax.set_title(f'Change in Probabilities for {model_identifier_name}')
        ax.set_yticks(np.arange(20) + 0.5)
        ax.set_yticklabels(all_tokens, fontsize=8, rotation=0)
        ax.set_xlabel("Position in CreiLOV", fontsize=18)
        ax.set_ylabel('Amino Acid', fontsize=18)

        # Black dots: wild-type residues. Blue dots: annotated mutations.
        ax.scatter(wt_x, wt_y, color='black', s=30)
        ax.scatter(mut_x, mut_y, color='blue', s=30)

        ax.legend(handles=legend_elements, loc='upper right')
        fig.tight_layout()

        fig.savefig(os.path.join(output_dir, f'delta_heatmap_update_{i}.png'))
        fig.savefig(os.path.join(output_dir, f'delta_heatmap_update_{i}.svg'))
        plt.close(fig)  # free the figure; many are created in this loop

        # --- One-row heatmap of cumulative non-WT delta probabilities ------------
        cumulative_delta = []
        for pos, wt_row in enumerate(wt_rows):
            if wt_row is None:
                cumulative_delta.append(0)
                continue
            delta_at_pos = delta_probabilities[pos, :]
            # Sum over every amino-acid column except the wild-type one.
            non_wt = np.arange(delta_at_pos.shape[0]) != wt_row
            cumulative_delta.append(delta_at_pos[non_wt].sum())

        cum_fig, cum_ax = plt.subplots(figsize=(30, 1.5))
        sns.heatmap(np.array(cumulative_delta).reshape(1, -1), cmap=custom_cmap,
                    linewidths=0.003, linecolor='0.7', vmin=-1, vmax=1, cbar=True,
                    xticklabels=np.arange(1, len(WT) + 1), yticklabels=["Δ Non-WT P"],
                    ax=cum_ax)
        cum_ax.set_title(f'Cumulative Δ Non-WT Probabilities: {model_identifier_name}')
        cum_ax.set_xlabel("Position in CreiLOV", fontsize=18)
        plt.setp(cum_ax.get_yticklabels(), rotation=0)
        cum_fig.tight_layout()

        cum_fig.savefig(os.path.join(output_dir, f'cumulative_delta_heatmap_update_{i}.png'))
        cum_fig.savefig(os.path.join(output_dir, f'cumulative_delta_heatmap_update_{i}.svg'))
        plt.close(cum_fig)

def adjust_cmap(cmap, midpoint):
    """Return ``cmap`` itself when ``midpoint`` is 0.5, else a 256-colour resampled copy.

    Args:
        cmap: A Matplotlib colormap (callable on values in [0, 1], with a ``name``).
        midpoint: Requested centre of the colour scale as a fraction in [0, 1].

    Returns:
        ``cmap`` unchanged if ``midpoint == 0.5``; otherwise a new
        ``LinearSegmentedColormap`` named ``'Adjusted <cmap.name>'`` built from
        256 evenly spaced samples of ``cmap``.

    NOTE(review): ``midpoint`` is only compared with 0.5. The copy is sampled
    evenly, so its centre colour is NOT moved. To re-centre a diverging scale,
    pass a ``matplotlib.colors.TwoSlopeNorm`` as ``norm=`` at plot time.
    """
    # The former local import of ``DivergingNorm`` was unused and raises
    # ImportError on current Matplotlib (the class is now ``TwoSlopeNorm``),
    # which made every call to this function fail.
    if midpoint == 0.5:
        return cmap
    return LinearSegmentedColormap.from_list(
        'Adjusted ' + cmap.name,
        cmap(np.linspace(0, 1, 256)),
        N=256
    )

# Probability arrays are read from, and figures written to, the same folder
base_directory = output_directory = 'SM_probability_data'
num_updates = 14  # Total updates including update_0

# Generate the delta heatmaps (PNG and SVG) for every update
load_and_plot_delta_heatmaps(
    base_dir=base_directory,
    num_updates=num_updates,
    output_dir=output_directory,
    WT=WT,
)

def create_gif(image_dir, output_filename, num_frames=None, frame_pattern='delta_heatmap_update_{i}.png'):
    """Combine the per-update delta heatmap PNGs into one animated GIF.

    Args:
        image_dir: Directory containing the frame images.
        output_filename: Path of the GIF to write.
        num_frames: Number of frames to read, indexed ``0 .. num_frames - 1``.
            When omitted, the notebook-level global ``num_updates`` is used,
            which keeps existing two-argument calls working.
        frame_pattern: File-name template for a frame; ``{i}`` is replaced by
            the frame index.

    Raises:
        ValueError: If there are no frames to combine.
    """
    if num_frames is None:
        num_frames = num_updates  # notebook-level global, kept for backward compatibility
    if num_frames < 1:
        raise ValueError("create_gif needs at least one frame")

    frames = []
    try:
        for i in range(num_frames):
            frames.append(Image.open(os.path.join(image_dir, frame_pattern.format(i=i))))
        # 500 ms per frame; the animation repeats 3 times.
        frames[0].save(output_filename, save_all=True, append_images=frames[1:],
                       optimize=False, duration=500, loop=3)
    finally:
        # Image.open keeps the file handle open; release it even if a later
        # frame is missing or saving fails.
        for frame in frames:
            frame.close()

# Assemble the saved delta heatmap PNGs into an animated GIF
create_gif(
    image_dir=output_directory,
    output_filename='SM_probability_data/delta_single_mut_probabilities_across_alignment.gif',
)