# Visualizations

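This notebook loads the omics dataset and a trained SWEEM model checkpoint, renders the pathway-module visualizations, and experiments with gradient-based interpretability methods (integrated gradients and Captum layer attributions).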

In [22]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import checkpoint

import numpy as np
import matplotlib.pyplot as plt
from model import SWEEM
from visualization import visualizePathwayModules, visualizeModelAttention
import seaborn as sns

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
data = pd.read_csv('./Data/OmicsData/data.csv')

# Split positives and negatives separately so that both the train and the test
# set keep a comparable ratio of 1s and 0s in the last column.
data_ones = data[data.iloc[:, -1] == 1]
data_zeros = data[data.iloc[:, -1] == 0]

# Split each class into train and test sets (80/20, fixed seed).
# Column 0 is skipped, the last two columns are the labels, and everything
# in between is used as the feature matrix.
train_data_ones, test_data_ones, train_labels_ones, test_labels_ones = train_test_split(
    data_ones.iloc[:, 1:-2], data_ones.iloc[:, -2:], test_size=0.2, random_state=42)
train_data_zeros, test_data_zeros, train_labels_zeros, test_labels_zeros = train_test_split(
    data_zeros.iloc[:, 1:-2], data_zeros.iloc[:, -2:], test_size=0.2, random_state=42)

# Concatenate the per-class splits to make the final train and test sets.
train_data = pd.concat((train_data_ones, train_data_zeros))
train_labels = pd.concat((train_labels_ones, train_labels_zeros))
test_data = pd.concat((test_data_ones, test_data_zeros))
test_labels = pd.concat((test_labels_ones, test_labels_zeros))

# Create Tensor datasets
train_dataset = TensorDataset(torch.tensor(train_data.values, dtype=torch.float32), torch.tensor(train_labels.values, dtype=torch.float32))
test_dataset  = TensorDataset(torch.tensor(test_data.values, dtype=torch.float32), torch.tensor(test_labels.values, dtype=torch.float32))

# Create DataLoader objects
batch_size = 16
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# The test loader is NOT shuffled: the interpretation cells below only look at
# the first test batch, and an unseeded shuffle would make that batch (and
# therefore every attribution / figure) different on each run.
test_dataloader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)

In [35]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Default configuration. NOTE: `checkpoint.load` below returns its own
# `settings` object and rebinds this name, so the values that actually reach
# `SWEEM(...)` are the ones coming back from the checkpoint.
settings = dict(
    model=dict(
        rna_dim=5540,
        scna_dim=5507,
        methy_dim=4846,
        use_rna=False,
        use_scna=False,
        use_methy=True,
        hidden_dim=32,
        self_att=True,
        cross_att=False,
        device=device,
    ),
    train=dict(
        lr=1e-5,
        l2=1e-3,
        epochs=1001,
        epoch_mod=25,
    ),
)

### Model trained on methylation data.
# Restore the training checkpoint (model, optimizer, loss histories, settings).
(model, optimizer, epoch_train_losses, epoch_val_losses, settings) = checkpoint.load(
    "./sweem.model", SWEEM, optim.Adam
)

# Build a fresh model from the restored settings and load the inference
# weights into it; the tensors are mapped onto `device` while loading.
model = SWEEM(**settings["model"])
model.load_state_dict(
    torch.load("./sweem_inference.model", map_location=device)
)

<All keys matched successfully>

In [None]:
# Run the project's pathway-module visualization with all three omics flags
# (RNA, SCNA, methylation) enabled.
visualizePathwayModules(rna=True, scna=True, methy=True)
# Close the current matplotlib figure afterwards.
# NOTE(review): presumably this prevents the figure from being rendered a
# second time / kept in memory — confirm against `visualization.py`.
plt.close()

In [100]:
def integrated_gradients(model, event, kwarg_dict, baseline=None, num_steps=50):
    """Integrated-gradients attribution of the BCE loss w.r.t. the methylation input.

    The methylation input is interpolated on a straight line from ``baseline``
    to ``kwarg_dict['methy']``; at every interpolation point the gradient of
    ``BCELoss(model(...), event)`` with respect to the interpolated input is
    taken, and the gradients are integrated along the path with the
    trapezoidal rule.

    Args:
        model: Callable invoked as
            ``model(event=..., rna=..., scna=..., methy=...)``; its output is
            fed to ``torch.nn.BCELoss`` together with ``event``.
        event: Target tensor for the loss; also forwarded to ``model``.
        kwarg_dict: Mapping with the keys ``'rna'``, ``'scna'`` and ``'methy'``.
            It is only read, never modified.
        baseline: Reference input with the same shape as
            ``kwarg_dict['methy']``. Defaults to all zeros.
        num_steps: Number of intervals of the path integral
            (``num_steps + 1`` gradient evaluations). Must be >= 1.

    Returns:
        A detached tensor shaped like ``kwarg_dict['methy']`` holding the
        attribution of every input element. Because ``BCELoss`` uses its
        default mean reduction, the attributions sum (approximately) to
        ``loss(input) - loss(baseline)``.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be a positive integer")

    # Work on detached copies so that the caller's tensors are never touched.
    input_data = kwarg_dict['methy'].detach()
    if baseline is None:
        baseline = torch.zeros_like(input_data)
    else:
        baseline = baseline.detach()
    delta = input_data - baseline

    criterion = torch.nn.BCELoss()
    total_gradients = torch.zeros_like(input_data)

    # Gradients are required here even when the caller sits inside a
    # `torch.no_grad()` block, so autograd is re-enabled locally.
    with torch.enable_grad():
        for step in range(num_steps + 1):
            alpha = float(step) / num_steps
            # Fresh leaf tensor for every interpolation point.
            scaled_input = (baseline + alpha * delta).detach().requires_grad_(True)

            outputs = model(event=event, rna=kwarg_dict['rna'],
                            scna=kwarg_dict['scna'], methy=scaled_input)
            loss = criterion(outputs, event)

            # Differentiate w.r.t. the interpolated methylation input only.
            # `allow_unused` covers a model whose output ignores `methy`:
            # the attribution is then simply zero.
            (grads,) = torch.autograd.grad(loss, scaled_input, allow_unused=True)
            if grads is None:
                continue

            # Trapezoidal rule: the two end points carry half weight.
            weight = 0.5 if step in (0, num_steps) else 1.0
            total_gradients += weight * grads.detach()

    return delta * total_gradients / num_steps


In [None]:
# Integrated Gradients Interpretability.
model.device = device

# Only the first test batch is attributed (see the `break` at the end).
# The attribution itself must NOT run under `torch.no_grad()`: integrated
# gradients needs the autograd graph from the input to the loss.
for i, (batchX, batchY) in enumerate(test_dataloader):
    batchX = batchX.to(device)
    # Column layout of the feature matrix: RNA | SCNA | methylation.
    rna = batchX[:, :5540].to(device)
    scna = batchX[:, 5540:11047].to(device)
    methy = batchX[:, 11047:].to(device)
    # Label columns: 0 -> `time`, 1 -> `event`.
    time = batchY[:, 0].reshape(-1, 1).to(device)
    event = batchY[:, 1].reshape(-1, 1).to(device)

    # A plain forward pass needs no gradient bookkeeping.
    with torch.no_grad():
        outputs = model(event=event, rna=rna, scna=scna, methy=methy)

    kwarg_dict = {'rna':rna, 'scna':scna, 'methy':methy}
    baseline = torch.zeros_like(methy)
    integrated_grad = integrated_gradients(model, event, kwarg_dict, baseline=baseline)

    # Show the computed attributions (not the function object).
    print(integrated_grad)
    break


In [None]:
# Attempted Model Attention Perturbation-based Interpretation.
## The captum attr library does not seem to support our input format of **kwargs,
## so we will have to rework the model to take in a single input. For now, we focus
## on other attention-based interpretability methods that we can use.
## NOTE(review): as stated above, this cell is an unfinished experiment and is
## not expected to run successfully in its current form.
from captum.attr import LayerGradientXActivation
model.device = device

# Attribute with respect to the activations of the model's `methyl_att` layer.
layer_grad_x_activation = LayerGradientXActivation(model, model.methyl_att)

# NOTE(review): the whole loop runs under `torch.no_grad()` although a
# gradient-based attribution is requested — confirm this is intended.
with torch.no_grad():
    for i, (batchX, batchY) in enumerate(test_dataloader):
        batchX = batchX.to(device)
        # Column layout of the feature matrix: RNA | SCNA | methylation.
        rna = batchX[:, :5540].to(device)
        scna = batchX[:, 5540:11047].to(device)
        methy = batchX[:, 11047:].to(device)
        # Label columns: 0 -> `time`, 1 -> `event`.
        time = batchY[:, 0].reshape(-1, 1).to(device)
        event = batchY[:, 1].reshape(-1, 1).to(device)
        outputs = model(event=event, rna=rna, scna=scna, methy=methy)

        # NOTE(review): a dict is passed as the attribution input here, which is
        # the unsupported format mentioned at the top of this cell. Also,
        # `LayerGradientXActivation.attribute` appears to return attributions
        # only (no convergence delta), so the two-way unpacking likely needs
        # revisiting — confirm against the captum documentation.
        attributions, delta = layer_grad_x_activation.attribute({'methy':methy})
        # attributions = attributions.squeeze().cpu().detach().numpy()

        # fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        # axs[0].imshow(methy[0].cpu().detach().numpy(), cmap='viridis', aspect='auto')
        # axs[0].set_title('Original Methy Input')

        # axs[1].imshow((methy + delta)[0].cpu().detach().numpy(), cmap='viridis', aspect='auto')
        # axs[1].set_title('Perturbed Methy Input')

        # axs[2].imshow(attributions[0], cmap='seismic', aspect='auto')
        # axs[2].set_title('Attributions')

        # plt.show()
        # Only the first test batch is inspected.
        break
