In [30]:
import os
import math
import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from io import BytesIO

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

from sklearn.manifold import TSNE

  matplotlib.use("TkAgg")


In [None]:
from model import SimpleVLMSeq2Seq
from data import generate_clock_image, text_to_indices, generate_dataset, ClockVLMDataset

In [None]:
def load_model(pth_path, device, vocab_size, embed_dim, text_hidden_dim, num_classes):
    """Build a ``SimpleVLMSeq2Seq`` and restore its weights from disk.

    Args:
        pth_path: Path to a file readable by ``torch.load`` that holds the
            model's state dict.
        device: Device the weights are mapped to and the model is moved to.
        vocab_size, embed_dim, text_hidden_dim, num_classes: Forwarded
            positionally, in this order, to the ``SimpleVLMSeq2Seq``
            constructor.

    Returns:
        The model, placed on ``device`` and switched to evaluation mode.
    """
    net = SimpleVLMSeq2Seq(vocab_size, embed_dim, text_hidden_dim, num_classes)
    weights = torch.load(pth_path, map_location=device)
    net.load_state_dict(weights)
    # ``Module.to`` and ``Module.eval`` both return the module itself, so the
    # chained call hands back the same object that was just configured.
    return net.to(device).eval()


In [6]:
def extract_intermediate_representation(model, sample_image, sample_text, device):
    """Capture the output of ``model.img_encoder`` for a single sample.

    A temporary forward hook is registered on ``model.img_encoder``, one
    forward pass of the full model is run under ``torch.no_grad()``, and the
    hook is removed again. Removal happens in a ``finally`` clause so a failing
    forward pass cannot leave a stale hook attached to the model.

    Args:
        model: Module exposing an ``img_encoder`` submodule and callable as
            ``model(image_batch, text_batch)``.
        sample_image: Un-batched image tensor; a leading batch axis is added.
        sample_text: Un-batched text tensor; a leading batch axis is added.
        device: Device both inputs are moved to before the forward pass.

    Returns:
        The detached tensor produced by ``model.img_encoder`` during the
        forward pass (the notebook run reports shape ``(1, 64, 1, 1)`` for the
        trained model).

    Raises:
        KeyError: If ``model.img_encoder`` was never invoked by the forward
            pass, so nothing was captured.
    """
    captured = {}

    def hook_fn(module, inputs, output):
        # Store the encoder output, detached from any autograd graph.
        captured["img_encoder"] = output.detach()

    hook_handle = model.img_encoder.register_forward_hook(hook_fn)
    try:
        # Add the batch dimension and move both inputs to the target device.
        batch_image = sample_image.unsqueeze(0).to(device)
        batch_text = sample_text.unsqueeze(0).to(device)
        with torch.no_grad():
            model(batch_image, batch_text)
    finally:
        # Always detach the hook, even when input preparation or the forward
        # pass raised; otherwise every later forward would keep firing it.
        hook_handle.remove()

    return captured["img_encoder"]

In [7]:
def compute_saliency_map(model, sample_image, sample_text, device):
    """Compute a gradient saliency map for the model's top-scoring class.

    The image is batched, moved to ``device`` and detached into a fresh leaf
    tensor so that its ``.grad`` is populated regardless of whether the
    caller's tensor already takes part in an autograd graph. The logit of the
    arg-max class is back-propagated to the image, and the saliency at each
    pixel is the maximum absolute gradient across the channel axis.

    Note: ``score.backward()`` also accumulates gradients into any model
    parameters that require grad; ``model.zero_grad()`` is only called before
    the pass, not after.

    Args:
        model: Module callable as ``model(image_batch, text_batch)`` returning
            per-class scores with the class axis at ``dim=1``.
        sample_image: Un-batched image tensor with the channel axis first.
        sample_text: Un-batched text tensor.
        device: Device the inputs are moved to.

    Returns:
        Tuple ``(saliency, pred_class)``: ``saliency`` is a NumPy array with
        the batch and channel axes removed, ``pred_class`` is the integer index
        of the highest-scoring class.
    """
    # detach() guarantees a leaf tensor; a non-leaf would leave .grad as None.
    batch_image = sample_image.unsqueeze(0).to(device).detach()
    batch_image.requires_grad_()
    batch_text = sample_text.unsqueeze(0).to(device)

    model.zero_grad()
    output = model(batch_image, batch_text)
    # Choose the highest scoring class and back-propagate its logit.
    pred_class = output.argmax(dim=1).item()
    score = output[0, pred_class]
    score.backward()

    # Saliency: maximum absolute gradient over the channel axis -> (1, H, W).
    saliency, _ = torch.max(batch_image.grad.detach().abs(), dim=1)
    # Drop only the batch axis so a size-1 height or width is preserved.
    saliency = saliency.squeeze(0).cpu().numpy()
    return saliency, pred_class


In [26]:
def visualize_tsne(model, dataset, device, save_path="tsne_plot.png", n=5, perplexity=5):
    """Plot a 2-D t-SNE embedding of ``model.img_encoder`` features.

    Every image in ``dataset`` is passed through ``model.img_encoder`` (the
    text branch is not used), the outputs are flattened per sample, embedded
    with t-SNE, and shown as a scatter plot coloured by ``label // n``.

    Args:
        model: Module exposing an ``img_encoder`` submodule; it is switched to
            evaluation mode.
        dataset: Dataset whose items are ``(image, text, label)`` triples that
            the default ``DataLoader`` collation can batch.
        device: Device the image batches are moved to.
        save_path: Currently unused; the figure is shown with ``plt.show()``
            rather than written to disk. Kept so existing callers continue to
            work.
        n: Divisor that turns a class label into the value used for colouring
            (``label // n``, titled "Hour" on the plot). Defaults to 5, the
            value that was previously hard-coded here.
        perplexity: Requested t-SNE perplexity. It is reduced when necessary
            so it stays below the number of samples.
    """
    model.eval()
    feature_batches = []
    labels = []

    loader = DataLoader(dataset, batch_size=4, shuffle=False)
    with torch.no_grad():
        for img, _text, label in loader:
            # Features come from the image branch only, before any later layer.
            feat = model.img_encoder(img.to(device))
            feature_batches.append(feat.view(feat.size(0), -1).cpu())
            labels.extend(label.tolist())

    features = torch.cat(feature_batches, dim=0).numpy()

    # Keep the perplexity below the sample count so small datasets still
    # embed; the default is unchanged whenever there are more than 5 samples.
    effective_perplexity = min(perplexity, max(features.shape[0] - 1, 1))
    tsne = TSNE(n_components=2, perplexity=effective_perplexity, random_state=42)
    features_2d = tsne.fit_transform(features)

    # Integer division groups labels into the buckets used for colouring.
    hour_labels = np.array(labels) // n

    # Create the scatter plot using the hour_labels for coloring.
    plt.figure(figsize=(6, 6))
    scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], c=hour_labels, cmap='tab10')
    plt.colorbar(scatter, label='Hour')
    plt.title("t-SNE of Vision Encoder Features (Colored by Hour)")
    plt.show()
    plt.close()


In [None]:
# Interpretability walkthrough for the trained clock model: load the checkpoint,
# build one sample input, then (1) plot the vision-encoder activations,
# (2) overlay a gradient saliency map, and (3) run t-SNE over the full dataset.

# Configuration parameters (should match training settings)
n = 5
image_size = 128
# presumably one class per (hour, minute) position pair — TODO confirm against data.py
num_classes = n * n

# Text prompt and vocabulary (must match training)
prompt = "Tell me the time on the clock"
# NOTE(review): keys are lowercase while the prompt starts with "Tell";
# assumes text_to_indices normalizes case — confirm in data.py.
vocab = {"tell": 0, "me": 1, "the": 2, "time": 3, "on": 4, "clock": 5}

# Define image transformation (must match training)
# ToTensor followed by Normalize(mean=0.5, std=0.5) on each of the three channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
])

# Device configuration: use CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model hyperparameters for text branch.
vocab_size = len(vocab)
embed_dim = 32
text_hidden_dim = 128

# Load the saved model.
# NOTE(review): the file name looks like it encodes training settings
# (n, image size, ...) — confirm it matches the values configured above.
model_path = "clock_vlm_model_5_128_100_4_0.0005.pth"
model = load_model(model_path, device, vocab_size, embed_dim, text_hidden_dim, num_classes)

# Generate one sample image for interpretability.
sample_img_np, time_str = generate_clock_image(hour=3, minute=2, n=n, size=image_size)
# Draw the sample image on the current figure for reference
# (writing it to disk is disabled below).
plt.imshow(sample_img_np)
# Image.fromarray(sample_img_np).save("sample_clock.png")
# print("Sample clock image saved as sample_clock.png")

# Apply transform.
sample_img = transform(Image.fromarray(sample_img_np))
# Prepare text indices tensor.
sample_text_indices = torch.tensor(text_to_indices(prompt, vocab), dtype=torch.long)

# ---------------------------
# 1. Intermediate Representation
# ---------------------------
rep = extract_intermediate_representation(model, sample_img, sample_text_indices, device)
# Flatten the captured activation (printed below as (1, 64, 1, 1)) into a 1-D vector.
rep_vector = rep.view(-1).cpu().numpy()
print("Extracted intermediate representation (vision encoder output) shape:", rep.shape)
# For visualization, plot a bar graph with one bar per flattened feature.
plt.figure(figsize=(8,4))
plt.bar(np.arange(len(rep_vector)), rep_vector)
plt.title("Intermediate Features from Vision Encoder (flattened)")
plt.xlabel("Feature index")
plt.ylabel("Activation")
plt.show()
# plt.savefig("intermediate_representation.png")
plt.close()
# print("Intermediate representation plot saved as intermediate_representation.png")

# ---------------------------
# 2. Saliency Map Computation
# ---------------------------
saliency, pred_class = compute_saliency_map(model, sample_img, sample_text_indices, device)
print("Predicted class for sample image:", pred_class)
# Plot and overlay the saliency map on the original image.
plt.figure(figsize=(6,6))
# Show original image: move channels last, as imshow expects (H, W, C).
orig_img = sample_img.cpu().permute(1,2,0).numpy()
# Un-normalize the image (inverse of the Normalize step above: (x-0.5)/0.5)
orig_img = (orig_img * 0.5) + 0.5
plt.imshow(orig_img)
# Semi-transparent heat map drawn on top of the image on the same axes.
plt.imshow(saliency, cmap='jet', alpha=0.5)
plt.title("Saliency Map Overlay")
plt.axis('off')
plt.show()
# plt.savefig("saliency_map.png")
plt.close()
# print("Saliency map saved as saliency_map.png")

# ---------------------------
# 3. t-SNE Visualization of the Entire Dataset's Image Features
# ---------------------------
# For t-SNE, generate the full dataset.
images, labels, time_strings = generate_dataset(n=n, size=image_size)
dataset = ClockVLMDataset(images, labels, prompt, vocab, transform=transform)
visualize_tsne(model, dataset, device)


Extracted intermediate representation (vision encoder output) shape: torch.Size([1, 64, 1, 1])
Predicted class for sample image: 16


KeyboardInterrupt: 

: 

In [20]:
# Generate the dataset and print its labels to inspect the label encoding.
images, labels, time_strings = generate_dataset(size=image_size, n=n)
print(labels)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
