# Focus Probe

This notebook uses Grad-CAM to probe which parts of the input Mel spectrogram the model focuses on when it makes a prediction.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch


from model_dataset import ToneDatasetNew as ThisDataset
from model_incremental import *
from model_dataset import TokenMap
from A_3 import configs
from paths import *
from H_1_models import TwoConvNetwork

In [None]:
def grad_cam(model, input_tensor, target_layer, target_class):
    """Compute a Grad-CAM heatmap for one input sample and one target class.

    A single forward/backward pass captures the output of ``target_layer``
    and the gradient of the chosen class score with respect to that output.
    Each activation channel is weighted by its spatially averaged gradient.
    The weighted channels are summed and negative values are clipped. The
    result is then min-max normalised.

    Args:
        model: ``torch.nn.Module`` whose output is indexed as
            ``output[0, class]``.
        input_tensor: Input batch; only sample 0 is explained. It is moved
            to the device of the model's parameters before the forward pass.
        target_layer: Sub-module of ``model`` whose output is explained. Its
            output is expected to be ``(batch, channels, H, W)``.
        target_class: Index of the class score to explain.

    Returns:
        ``numpy.ndarray`` of shape ``(H, W)`` (the target layer's spatial
        size). Values lie in ``[0, 1]``; the array is all zeros if the map
        is flat.
    """
    model.eval()  # Inference mode: affects dropout / batch-norm layers, if any.

    gradients = []
    activations = []

    def forward_hook(module, input, output):
        activations.append(output)
        # Record d(class_score)/d(output) with a hook on the output tensor
        # itself. Module.register_backward_hook is deprecated, and its
        # grad_out is not guaranteed to be the gradient of the module output.
        output.register_hook(gradients.append)

    # Evaluate on the model's device so a CPU batch works with a CUDA model.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    hook_f = target_layer.register_forward_hook(forward_hook)
    try:
        # Gradients are required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = model(input_tensor)
            class_score = output[0, target_class]

            model.zero_grad()
            class_score.backward()
    finally:
        # Remove the hook even on failure so repeated calls don't stack hooks.
        hook_f.remove()

    # Drop the batch dimension: the heatmap is built for sample 0 only.
    grad = gradients[0][0].detach().cpu().numpy()
    act = activations[0][0].detach().cpu().numpy()

    # Channel weights are the spatially averaged gradients of each (H, W) map.
    weights = np.mean(grad, axis=(1, 2))
    cam = np.sum(weights[:, None, None] * act, axis=0)

    # Keep positive evidence only, then min-max normalise to [0, 1].
    cam = np.maximum(cam, 0)
    cam = cam - np.min(cam)
    cam = cam / np.max(cam) if np.max(cam) != 0 else cam

    return cam

def get_one_batch(dataloader):
    """Return the first batch produced by ``dataloader``.

    A fresh iterator is created on every call, so repeated calls restart
    iteration from the beginning. For a shuffling loader, that may yield a
    different batch each time. Raises ``StopIteration`` if the loader yields
    nothing.
    """
    return next(iter(dataloader))

def grad_cam_visualization(input_tensor, cam, title="Grad-CAM Heatmap on Test Input"):
    """Overlay a Grad-CAM heatmap on an input spectrogram and display it.

    Args:
        input_tensor: Input whose two leading singleton dims squeeze away to
            a 2-D ``(H, W)`` image, e.g. ``(1, 1, H, W)``.
        cam: 2-D heatmap (e.g. from ``grad_cam``). It may have a lower
            resolution than the input; it is stretched over the same extent.
        title: Plot title.
    """
    spectrogram = input_tensor.detach().squeeze(0).squeeze(0).cpu().numpy()

    # Both layers share one extent and one origin so that row 0 of the
    # heatmap lines up with row 0 of the spectrogram (both at the bottom).
    # Mixing the default extent with an explicit one would mirror the
    # overlay vertically.
    extent = (0, input_tensor.size(-1), 0, input_tensor.size(-2))

    # Plot the heatmap on the test Mel spectrogram.
    plt.imshow(spectrogram, aspect='auto', cmap='viridis',
               origin='lower', extent=extent)
    # aspect='auto' here too: imshow's default aspect ('equal') would
    # otherwise override the first call and squash the spectrogram.
    plt.imshow(cam, cmap='jet', alpha=0.5, aspect='auto',
               origin='lower', extent=extent)
    plt.colorbar(label='Importance')
    plt.title(title)
    plt.show()

In [None]:
# Load validation data: target ("valid") and full ("full_valid") splits.
# Run/experiment identifiers; together they select the checkpoint directory below.
pretype = "l"
posttype = "f"
train_name = "A3"
ts = "1126125730"  # presumably a run timestamp — TODO confirm
model_type = "twoconvCNN"
pre_epoch = 50
total_epoch = 300
selection = "full"


model_save_dir = os.path.join(model_save_, f"{train_name}-{ts}")
guides_dir = os.path.join(model_save_dir, "guides")
# Layout: <model_save_dir>/<model_type>-<pre_epoch>-<total_epoch - pre_epoch>/<selection>/<pretype><posttype>
model_save_dir_specific = os.path.join(model_save_dir, f"{model_type}-{pre_epoch}-{total_epoch-pre_epoch}", selection, f"{pretype}{posttype}")

# pretype/posttype are translated via configs["data_type_mapper"] before being given to PoolMessanger.
pool_messanger = PoolMessanger(configs["num_dataset"], configs["data_type_mapper"][pretype], configs["data_type_mapper"][posttype], guides_dir)

# NOTE: SubsetCache manages reading of the dataset subsets and should be transparent to the user.
# One cache per split, both capped at configs["max_cache_size_valid"] and built on ThisDataset.
valid_cache = SubsetCache(max_cache_size=configs["max_cache_size_valid"], dataset_class=ThisDataset)
full_valid_cache = SubsetCache(max_cache_size=configs["max_cache_size_valid"], dataset_class=ThisDataset)

# Label vocabulary "1".."4" wrapped in a TokenMap (presumably the four tone classes — TODO confirm).
mylist = ["1", "2", "3", "4"]
mymap = TokenMap(mylist)

In [None]:
epoch = 20  # epoch number of the checkpoint file "{epoch}.pt" to load
# Load Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TwoConvNetwork()

model_path = os.path.join(model_save_dir_specific, f"{epoch}.pt")
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
# Move once, after the weights are loaded, and switch to inference mode.
model.to(device)
model.eval()

In [None]:
# Index of the dataset in the pool to inspect.
dataset_id = 0

# get_loading_params returns (dataset_id, meta_path, data_path) for the requested split.
# The returned dataset_id overwrites the global and is passed to the second call.
dataset_id, meta_path, data_path = pool_messanger.get_loading_params(dataset_id,
                                                                    eval_type="valid")
valid_loader = valid_cache.get_subset(dataset_id, meta_path, data_path, mymap)


# Same dataset, "full_valid" split, read through its own cache.
dataset_id, meta_path, data_path = pool_messanger.get_loading_params(dataset_id,
                                                                    eval_type="full_valid")
full_valid_loader = full_valid_cache.get_subset(dataset_id, meta_path, data_path, mymap)

In [None]:
# Take only the first batch of each loader; each batch unpacks into (mel inputs, tags).
valid_mel, valid_tag = get_one_batch(valid_loader)
full_valid_mel, full_valid_tag = get_one_batch(full_valid_loader)

In [None]:
# Get one sample randomly. The index must be valid for BOTH batches; the two
# loaders can return batches of different lengths (e.g. when a split has fewer
# samples than the batch size), so bound it by the shorter batch.
# NOTE(review): the batches come from different loaders, so the same index is
# not necessarily the same underlying sample — confirm this is intended.
random_idx = np.random.randint(0, min(len(valid_mel), len(full_valid_mel)))
# unsqueeze(0) restores a batch dimension; .item() gives a plain int class index.
one_valid_mel, one_valid_tag = valid_mel[random_idx].unsqueeze(0), valid_tag[random_idx].item()
one_full_valid_mel, one_full_valid_tag = full_valid_mel[random_idx].unsqueeze(0), full_valid_tag[random_idx].item()

In [None]:
# Layer explained by Grad-CAM: the first entry of model.conv.
# Each sample is explained for its own ground-truth tag as the target class.
target_layer = model.conv[0]    # NOTE(review): model.conv[4] is said to be the second conv layer — verify against TwoConvNetwork
cam_valid = grad_cam(model, one_valid_mel, target_layer, one_valid_tag)
cam_full_valid = grad_cam(model, one_full_valid_mel, target_layer, one_full_valid_tag)

In [None]:
# Overlay each heatmap on its input; the title shows the tag used as the explained class.
grad_cam_visualization(one_valid_mel, cam_valid, title=f"Grad-CAM Valid {one_valid_tag}")
grad_cam_visualization(one_full_valid_mel, cam_full_valid, title=f"Grad-CAM Full Valid {one_full_valid_tag}")