In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os

In [10]:
def plot_position_wise_attention(weights, path, quarters, positions, xlabel, ylabel, title):
    
    plt.style.use('ggplot')
    cax = plt.matshow(weights.numpy(), cmap='BuGn')
    # plt.rcParams.update({'font.size': 10})
    plt.colorbar(cax)
    plt.grid(
        visible=False,
        axis='both',
        which='both',
    )
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.xticks(ticks = range(0, 6), labels = quarters, rotation = 45)
    plt.yticks(ticks = range(0, len(positions)), labels = positions, rotation = 0)
    plt.title(title, pad=15)
    plt.savefig(path, format='pdf', dpi=1200)
    plt.close()

In [9]:
def plot_attention(weights, path, quarters, xlabel, ylabel, title):
    """
    Plots attention weights in a grid.
    """

    plt.style.use('ggplot')
    cax = plt.matshow(weights.numpy(), cmap='BuGn')
    plt.rcParams.update({'font.size': 10})
    plt.colorbar(cax)
    plt.grid(
        visible=False,
        axis='both',
        which='both',
    )
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.xticks(ticks = range(0, 6), labels = quarters, rotation = 45)
    if ylabel == 'Quarter':
        plt.yticks(ticks = range(0, 6), labels = quarters, rotation = 75)
    plt.title(title, pad=15)
    plt.savefig(path, format='pdf', dpi=1200)
    plt.close()

In [2]:
weights = torch.load('results_cov/attn/new_model/weights.pth')
len(weights)

9

In [3]:
weights.keys()

dict_keys(['channel_1_attn_map', 'channel_2_attn_map', 'channel_3_attn_map', 'channel_4_attn_map', 'channel_1_weights', 'channel_2_weights', 'channel_3_weights', 'channel_4_weights', 'final_weights'])

In [131]:
for i in range(1, 5):
    print(len(weights[f'channel_{i}_attn_map']))
    for layer in weights[f'channel_{i}_attn_map']:
        layer.shape

2
2
2
2


In [132]:
layer_x_year = []
quarters = ['Q3 2020', 'Q4 2020', 'Q1 2021', 'Q2 2021', 'Q3 2021', 'Q4 2021']
seq_idx = [178,  57, 172,  56 , 97]
epitopes = []
hotspots = []
seq_attn_maps = [
    [
        [] for _ in range(4)
    ] for _ in range(5)
]
with open('data/cov_epitopes/epitopes_sorted.txt', 'r') as f:
    for line in f:
        epitopes.append(int(line))

with open('data/cov_epitopes/epitopes_hotspot.txt', 'r') as f:
    for line in f:
        hotspots.append(int(line))

hotspots.sort()

for channel_idx in range(1, 5):
    keys = f'channel_{channel_idx}_attn_map'
    sam = weights[keys]
    for idx, am in enumerate(sam):
        for i, pos in enumerate(seq_idx):
            seq_attn_maps[i][channel_idx-1].append(am[pos*149:(pos+1)*149, : ,:])
        # am_2d = torch.sum(am, dim=0).cpu()
        # am_1d = torch.sum(am_2d, dim=0).cpu()
        # layer_x_year.append(am_1d)

In [104]:
len(seq_attn_maps), len(seq_attn_maps[0]), len(seq_attn_maps[0][0])

(5, 4, 2)

In [133]:
for seq_i, seq_no in enumerate(seq_idx):
    print(f'Printing for seq no {seq_no}')
    for transformer_idx in range(4):
        for layer_idx in range(2):
            dir = f'results_cov/attn/new_model/{seq_no}/transformer_{transformer_idx+1}/layer_{layer_idx+1}/'
            if not os.path.exists(dir):
                os.makedirs(dir)
            attn_map = seq_attn_maps[seq_i][transformer_idx][layer_idx]
            pos_attn = torch.sum(attn_map, dim=1).cpu()
            epitopes_np = np.asarray(epitopes)
            hotspot_idxs = np.in1d(epitopes_np, hotspots)
            pos_attn_filtered = pos_attn[hotspot_idxs]
            np.savetxt(dir+f'attn_map.txt', pos_attn)
            plot_position_wise_attention(pos_attn, dir+f'attn_map.pdf', quarters , epitopes, xlabel='Quarter', ylabel='Mutation Sites', title=f'Layer-{layer_idx+1} Self Attention Map')
            print(pos_attn_filtered)
            np.savetxt(dir+f'attn_map_filtered.txt', pos_attn_filtered)
            plot_position_wise_attention(pos_attn_filtered, dir+f'attn_map_filtered_t_{transformer_idx+1}_l_{layer_idx+1}.pdf', quarters , hotspots, xlabel='Quarter', ylabel='Mutation Sites', title=f'Layer-{layer_idx+1} Attention Map')
        # final_am = torch.cat(layer_x_year).cpu()
        # final_am = final_am.view(8, 6)
        # plot_attention(final_am, f'results_cov/attn/final_attn_map.pdf', quarters, xlabel='Quarter', ylabel='Layer', title='Layerwise Attention Map')


Printing for seq no 178
tensor([[1.9087, 1.9995, 0.8909, 0.0272, 0.7130, 0.4608],
        [1.9699, 2.0785, 0.7165, 0.0247, 0.6886, 0.5217],
        [1.8767, 1.9912, 1.1035, 0.0525, 0.5665, 0.4096],
        [1.6150, 1.7264, 0.9673, 0.0530, 0.4780, 1.1603],
        [1.5462, 1.6176, 1.0413, 0.6657, 0.5792, 0.5500],
        [1.4750, 1.5197, 1.0190, 0.7167, 0.6742, 0.5955],
        [1.5213, 1.5648, 0.9663, 0.6806, 0.6119, 0.6551],
        [1.6209, 1.6786, 0.9127, 0.5851, 0.7196, 0.4832]])
tensor([[0.8486, 0.7173, 0.9659, 1.3884, 0.5962, 1.4836],
        [0.8784, 0.7756, 0.9144, 1.3484, 0.6740, 1.4093],
        [1.4781, 1.2507, 0.4865, 1.0939, 0.7721, 0.9186],
        [1.1829, 0.9047, 0.6892, 1.0688, 0.8234, 1.3310],
        [1.3279, 1.4296, 0.5630, 0.6995, 0.8508, 1.1292],
        [0.7981, 0.8903, 1.0677, 1.2001, 0.9574, 1.0863],
        [1.2787, 1.1495, 0.6749, 0.7578, 1.2586, 0.8805],
        [0.8595, 0.9628, 0.7645, 0.7998, 1.1036, 1.5099]])
tensor([[1.1528, 0.9474, 1.0606, 0.5048, 0.972

In [9]:
def plot_channel_weights(arr, title, channel = None):
    print(arr)
    dir = f'results_cov/attn/new_model/channels/'
    print(title)

    plt.style.use('ggplot')
    cax = plt.matshow(arr, cmap='BuGn')
    plt.rcParams.update({'font.size': 10})
    plt.colorbar(cax)
    plt.grid(
        visible=False,
        axis='both',
        which='both',
    )

    plt.title(title, pad=15)
    if channel:
        print('In if')
        path = dir+f'channel_{channel}.pdf'
        plt.xticks(ticks = range(0, 2), labels=['Attention\nApplied', 'Attention\nNot Applied'], rotation = 45)
    else:
        print('in else')
        path = dir+'final_weights.pdf'
        plt.xticks(ticks = range(0, 4), labels=['Channel 1', 'Channel 2', 'Channel 3', 'Channel 4'],
                   rotation=45)
    print(path)
    plt.savefig(path, format='pdf', dpi=1200)
    plt.close()
    

In [10]:
channel_weights = []
for i in range(1, 5):
    key = f'channel_{i}_weights'
    print(key)
    weight = weights[key].cpu().detach().numpy()
    print(weight)
    channel_weights.append(np.asarray([weight[0], 1-weight[0]]))

channel_weights.append(weights['final_weights'].cpu().detach().numpy())

for ch_idx, channel_weight in enumerate(channel_weights):
    # print(channel_weight.reshape(1,-1))
    print(ch_idx)
    print(channel_weight)
    if 0 <= ch_idx < 4:
        print('In if')
        print(ch_idx)
        plot_channel_weights(channel_weight.reshape(1,-1), f'Channel {ch_idx+1} weights', ch_idx+1)
    else:
        print('In else')
        print(ch_idx)
        plot_channel_weights(channel_weight.reshape(1,-1), 'Final weights')

channel_1_weights
[0.984743]
channel_2_weights
[0.41686982]
channel_3_weights
[-0.02163854]
channel_4_weights
[0.17250489]
0
[0.984743 0.015257]
In if
0
[[0.984743 0.015257]]
Channel 1 weights
In if
results_cov/attn/new_model/channels/channel_1.pdf
1
[0.41686982 0.58313018]
In if
1
[[0.41686982 0.58313018]]
Channel 2 weights
In if
results_cov/attn/new_model/channels/channel_2.pdf
2
[-0.02163854  1.02163854]
In if
2
[[-0.02163854  1.02163854]]
Channel 3 weights
In if
results_cov/attn/new_model/channels/channel_3.pdf
3
[0.17250489 0.82749511]
In if
3
[[0.17250489 0.82749511]]
Channel 4 weights
In if
results_cov/attn/new_model/channels/channel_4.pdf
4
[1.0182537 0.8773319 1.0965743 1.0100062]
In else
4
[[1.0182537 0.8773319 1.0965743 1.0100062]]
Final weights
in else
results_cov/attn/new_model/channels/final_weights.pdf


In [128]:
torch.rand(1)

tensor([0.2494])