In [None]:
import sys
import os
sys.path.append(os.path.abspath("../src/"))
import extract.data_loading as data_loading
import extract.compute_predictions as compute_predictions
import model.util as model_util
import model.binary_models as binary_models
import plot.viz_sequence as viz_sequence
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import matplotlib.animation as animation
from IPython.display import HTML

In [None]:
# Plotting defaults
# Register the fonts found under the custom font directory so that the
# "Roboto" family requested in `plot_params` can be resolved.
font_paths = font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts")
if hasattr(font_manager.fontManager, "addfont"):
    # `createFontList` was deprecated in Matplotlib 3.2 and later removed;
    # `FontManager.addfont` is the supported way to register a font file.
    for font_path in font_paths:
        try:
            font_manager.fontManager.addfont(font_path)
        except (OSError, RuntimeError):
            # Skip font files that cannot be read instead of aborting the cell
            continue
else:
    # Older Matplotlib without `addfont`: keep the original registration path
    font_manager.fontManager.ttflist.extend(
        font_manager.createFontList(font_paths)
    )
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "axes.labelweight": "bold",
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold"
}
plt.rcParams.update(plot_params)
# Path to the ffmpeg binary used when rendering animations to video
plt.rcParams["animation.ffmpeg_path"] = "/users/amtseng/lib/ffmpeg/ffmpeg-git-20200504-amd64-static/ffmpeg"

### Define paths for the model and data of interest

In [None]:
# Shared paths/constants

# Numeric settings
fourier_att_prior_freq_limit = 150  # default frequency limit of `apply_lpf`
input_length = 1000  # handed to the input-function factory

# Chromosomes from which positive bins are collected
chrom_set = ["chr1"]

# Base directories for trained models and processed data
model_base_path = "/users/amtseng/att_priors/models/trained_models/binary/"
data_base_path = "/users/amtseng/att_priors/data/processed/"

# Genome resources
chrom_sizes = "/users/amtseng/genomes/hg38.canon.chrom.sizes"
reference_fasta = "/users/amtseng/genomes/hg38.fasta"

In [None]:
# SPI1
# Settings for the SPI1 condition: the data spec file, the model class, and the
# two checkpoints that are compared in this notebook.
condition_name = "SPI1"
# JSON spec passed to the data_loading helpers below
files_spec_path = os.path.join(data_base_path, "ENCODE_TFChIP/binary/config/SPI1/SPI1_training_paths.json")
num_tasks = 4
# Forwarded as `task_ind` when collecting positive bins
task_index = None
# Class handed to `model_util.restore_model` when loading a checkpoint
model_class = binary_models.BinaryPredictor
# Checkpoint loaded as the "without priors" model
noprior_model_path = os.path.join(model_base_path, "SPI1/4/model_ckpt_epoch_2.pt")
# Checkpoint loaded as the "with priors" model
prior_model_path = os.path.join(model_base_path, "SPI1_prior/16/model_ckpt_epoch_6.pt")

### Import models

In [None]:
# Keep autograd switched on: input gradients are requested later on
# (`return_gradients=True` in `get_grad_signal`).
torch.set_grad_enabled(True)
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def restore_model(model_path):
    """
    Restores a model of class `model_class` from the checkpoint at
    `model_path`, switches it to evaluation mode, and moves it to `device`.

    Arguments:
        `model_path`: path to the saved checkpoint
    Returns the restored model, placed on `device`.
    """
    restored = model_util.restore_model(model_class, model_path)
    restored.eval()
    return restored.to(device)

In [None]:
# Import the model without priors: restore the checkpoint stored at
# `noprior_model_path` through `restore_model`
noprior_model = restore_model(noprior_model_path)

In [None]:
# Import the model with priors: restore the checkpoint stored at
# `prior_model_path` through `restore_model`
prior_model = restore_model(prior_model_path)

### Data preparation
Create an input data loader that maps coordinates to the data needed by the model.

In [None]:
# Function that produces model inputs; it is handed, together with bin indices,
# to `compute_predictions.get_binary_model_predictions` in `get_grad_signal`
input_func = data_loading.get_binary_input_func(
    files_spec_path, input_length, reference_fasta
)
# Positive bins restricted to `chrom_set`
# NOTE(review): `task_index` is None here — presumably "any task"; confirm in
# data_loading. `pos_bins` is not referenced again in the visible cells.
pos_bins = data_loading.get_positive_binary_bins(
    files_spec_path, task_ind=task_index, chrom_set=chrom_set
)

In [None]:
def get_grad_signal(model, bin_index):
    """
    Computes a one-dimensional importance track for a single bin.

    Runs `compute_predictions.get_binary_model_predictions` on the given bin
    with gradients requested, multiplies the returned input gradients by the
    returned input sequence, and sums the product over axis 1.

    Arguments:
        `model`: model to compute the input gradients with
        `bin_index`: index of the bin to evaluate
    Returns a NumPy array holding the summed gradient-times-input values.
    """
    results = compute_predictions.get_binary_model_predictions(
        model, np.array([bin_index]), input_func,
        return_losses=False, return_gradients=True, show_progress=False
    )
    # Exactly one bin was requested, so use the first entry of each result
    input_seq = results["input_seqs"][0]
    input_grad = results["input_grads"][0]
    # NOTE(review): presumably axis 1 is the base axis of a one-hot encoded
    # sequence, leaving one value per position — confirm against data_loading
    return np.sum(input_grad * input_seq, axis=1)

### Plotting functions

In [None]:
def dft(signal):
    """
    Computes the discrete Fourier transform of `signal` and keeps the first
    half of the spectrum.

    Arguments:
        `signal`: NumPy array to transform
    Returns a pair `(freqs, mags)`: the angular frequencies (radians per
    sample) and the matching coefficient magnitudes, each truncated to the
    first half of its entries (DC term included).
    """
    magnitudes = np.abs(np.fft.fft(signal))
    angular_freqs = np.fft.fftfreq(signal.size) * (2 * np.pi)
    # The second half holds the negative frequencies, whose magnitudes mirror
    # the positive ones, so only the first half is returned
    return (
        angular_freqs[:len(angular_freqs) // 2],
        magnitudes[:len(magnitudes) // 2]
    )

In [None]:
def apply_lpf(signal, frequency_limit=fourier_att_prior_freq_limit):
    """
    Low-pass filters a real-valued signal by zeroing its high-frequency
    Fourier coefficients and transforming back.

    Arguments:
        `signal`: real-valued array to filter
        `frequency_limit`: all `rfft` coefficients at index
            `frequency_limit // 2` and above are set to zero
    Returns the filtered signal, with the same length as `signal`.

    NOTE(review): the halving suggests `frequency_limit` is expressed on the
    scale of the full signal length (the animation passes `len(signal)` for
    its unfiltered frames) — confirm the intended units.
    """
    coeffs = np.fft.rfft(signal)
    cutoff = frequency_limit // 2
    coeffs[cutoff:] = 0
    # Pass the original length explicitly: without `n`, `irfft` assumes an even
    # output length and returns one sample too few for odd-length signals
    return np.fft.irfft(coeffs, n=np.shape(signal)[-1])

In [None]:
def plot_fft(signal, include_dc=False, pos_limit=None, title=None, color="red"):
    """
    Plots the Fourier magnitudes of the absolute value of `signal`.

    Arguments:
        `signal`: array to analyze; its absolute value is transformed
        `include_dc`: if False, the first (zero-frequency) entry is dropped
        `pos_limit`: if given, a black vertical line is drawn at
            `pos_limit * 2 * pi / len(signal)` radians
        `title`: optional plot title
        `color`: color of the magnitude curve
    """
    freqs, mags = dft(np.abs(signal))
    first = 0 if include_dc else 1
    freqs, mags = freqs[first:], mags[first:]

    _, axis = plt.subplots(figsize=(20, 2))
    axis.plot(freqs, mags, color=color)
    axis.set_xlabel("Frequency (radians)")
    axis.set_ylabel("|Frequency component|")
    if pos_limit is not None:
        # Convert the limit from a frequency index to radians
        axis.axvline(x=(pos_limit * 2 * np.pi / len(signal)), color="black")
    if title:
        axis.set_title(title)
    plt.show()

In [None]:
def plot_signal(signal, title=None, color=None):
    """
    Plots `signal` as a line on a wide, short figure and shows it.

    Arguments:
        `signal`: sequence of values to plot against their index
        `title`: optional plot title
        `color`: optional line color
    """
    _, axis = plt.subplots(figsize=(20, 2))
    axis.plot(signal, color=color)
    if title:
        axis.set_title(title)
    plt.show()

### View tracks

In [None]:
# Bin whose importance track is computed with both models
bin_index = 5734
noprior_imp = get_grad_signal(noprior_model, bin_index)
prior_imp = get_grad_signal(prior_model, bin_index)
# Plot the two tracks one after the other: no-prior model first, then prior model
plot_signal(noprior_imp, color="coral")
plot_signal(prior_imp, color="royalblue")

In [None]:
# No-prior track and its spectrum, before and after low-pass filtering
plot_signal(noprior_imp, color="coral")
plot_fft(noprior_imp, color="darkmagenta")
# Use the shared cutoff constant instead of repeating its literal value
cutoff_imp = apply_lpf(noprior_imp, fourier_att_prior_freq_limit)
plot_signal(cutoff_imp, color="coral")
plot_fft(cutoff_imp, color="darkmagenta")

In [None]:
# Track from the model with priors, followed by its Fourier magnitudes
plot_signal(prior_imp, color="royalblue")
plot_fft(prior_imp, color="darkmagenta")

In [None]:
def make_fft_cutoff_animation(signal, max_frequency_limit):
    """
    Builds an animation of `signal` being progressively low-pass filtered.

    The top panel shows the filtered signal and the bottom panel shows its
    Fourier magnitudes (zero-frequency term dropped). The frames pass these
    frequency limits to `apply_lpf`: 10 frames at `len(signal)`, then limits
    stepping down by 5 from `len(signal)` while above `max_frequency_limit`,
    then 10 frames at `max_frequency_limit`.

    Arguments:
        `signal`: 1D array to filter and plot
        `max_frequency_limit`: frequency limit used for the final frames
    Returns the `matplotlib.animation.FuncAnimation` object.
    """
    fig, ax = plt.subplots(2, 1, figsize=(20, 8))

    # Spectrum of the unfiltered signal without the zero-frequency term; used
    # both for the axis limits and for the initial frame
    orig_freqs, orig_mags = dft(signal)
    orig_freqs, orig_mags = orig_freqs[1:], orig_mags[1:]

    # Fix the axis limits from the original signal so they stay constant
    # across all frames
    ax[0].set_xlim((0, len(signal)))
    max_val = np.max(np.abs(signal)) * 1.05
    ax[0].set_ylim(-max_val, max_val)
    ax[1].set_xlim((0, np.pi))
    ax[1].set_ylim((0, np.max(orig_mags) * 1.05))

    signal_line, = ax[0].plot([], [], color="coral")
    fft_line, = ax[1].plot([], [], color="darkmagenta")

    ax[0].set_ylabel("Attribution/importance")
    ax[1].set_xlabel("Frequency (radians)")
    ax[1].set_ylabel("Fourier magnitude")

    def init():
        # Initial frame: the unfiltered signal and its spectrum
        signal_line.set_data(np.arange(len(signal)), signal)
        fft_line.set_data(orig_freqs, orig_mags)
        return signal_line, fft_line

    def animate(frame_index):
        # Each frame value is the frequency limit to filter with
        cutoff_signal = apply_lpf(signal, frame_index)
        signal_line.set_data(np.arange(len(cutoff_signal)), cutoff_signal)

        freqs, mags = dft(cutoff_signal)
        fft_line.set_data(freqs[1:], mags[1:])
        return signal_line, fft_line

    frame_range = np.concatenate([
        np.ones(10, dtype=int) * len(signal),
        np.arange(len(signal), max_frequency_limit, -5),
        np.ones(10, dtype=int) * max_frequency_limit
    ])
    return animation.FuncAnimation(
        fig, animate, init_func=init, frames=frame_range, interval=50, blit=True
    )

In [None]:
# Animate the no-prior track down to the shared cutoff constant (rather than
# repeating its literal value) and embed the result as an HTML5 video
anim = make_fft_cutoff_animation(noprior_imp, fourier_att_prior_freq_limit)
HTML(anim.to_html5_video())

In [None]:
# Write the animation to "animation.gif" in the working directory using the
# "pillow" writer
anim.save("animation.gif", writer="pillow")

In [None]:
# Static plot of the with-prior track, with the x-axis spanning the full
# length of the track
fig, ax = plt.subplots(figsize=(20, 4))
ax.plot(prior_imp, color="royalblue")
ax.set_xlim((0, len(prior_imp)))