In [None]:
# Mount the user's Google Drive into the Colab filesystem at /content/drive.
# NOTE(review): none of the visible cells read from /content/drive; presumably the
# mount is only needed when the trained model files are kept on Drive — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install torch numpy opencv-python transformers matplotlib albumentations gdown

In [None]:
import gdown
import zipfile
#!pip install transformers torch torchvision datasets evaluate torchmetrics

# Google Drive file ID of the dataset archive
file_id = "1BqMBtsuvb6mTpiZUZ9WKcJA8f1hkI2yX"
url = f"https://drive.google.com/uc?id={file_id}"

# Local name of the downloaded archive
output = "HMDB_simp_clean.zip"

# Download only when no valid archive is present yet, so re-running the cell is cheap.
# zipfile.is_zipfile returns False for a missing, truncated or non-zip file.
if not zipfile.is_zipfile(output):
    downloaded_path = gdown.download(url, output, quiet=False)
    # gdown can signal failure by returning None; stop here with a clear message
    # instead of letting ZipFile fail on a missing or broken file below.
    if downloaded_path is None or not zipfile.is_zipfile(output):
        raise RuntimeError(f"Failed to download a valid zip archive from {url}")

# Unzip the file into the current working directory
with zipfile.ZipFile(output, 'r') as zip_ref:
    zip_ref.extractall(".")

print("Download and extraction complete!")

In [3]:
import torch
import numpy as np
import cv2
import os
from transformers import TimesformerForVideoClassification, AutoImageProcessor
import matplotlib.pyplot as plt
import albumentations as A

In [4]:
# Per-frame preprocessing pipeline; load_frames applies it to every RGB frame
# before the results are stacked into the clip tensor fed to the model.
default_transforms = A.Compose(
    [
        # Spatial size the rest of the notebook works with (224 x 224).
        A.Resize(224, 224),
        # Per-channel mean / std; these values match the widely used ImageNet statistics.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # Alternative normalization left disabled by the author:
        #A.Normalize(mean=[0.5, 0.5,0.5], std=[0.5, 0.5, 0.5]),
        # Convert the array to a torch tensor (presumably channel-first, given the
        # (8, 3, 224, 224) shape noted in load_frames — TODO confirm).
        A.ToTensorV2(),
    ]
)

In [None]:
# Action labels in class-id order: a label's position in this tuple is the
# integer id that CATEGORY_INDEX assigns to it.
_CATEGORY_NAMES = (
    "brush_hair",
    "cartwheel",
    "catch",
    "chew",
    "climb",
    "climb_stairs",
    "draw_sword",
    "eat",
    "fencing",
    "flic_flac",
    "golf",
    "handstand",
    "kiss",
    "pick",
    "pour",
    "pullup",
    "pushup",
    "ride_bike",
    "shoot_bow",
    "shoot_gun",
    "situp",
    "smile",
    "smoke",
    "throw",
    "wave",
)

# Maps each action label to its integer class id (0 .. 24).
CATEGORY_INDEX = {label: class_id for class_id, label in enumerate(_CATEGORY_NAMES)}

To load the trained model, extract the attached `transformer_output` files and set `model_dir` in the next cell to the directory that contains them.

In [None]:
# Folder holding the extracted frames of the video to explain
video_folder = "/content/HMDB_simp_clean/brush_hair/F24EB2B7"
# Label of the class whose score is explained, and its integer id
CLASS = "brush_hair"
target_category = CATEGORY_INDEX[CLASS]

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device_name = "cuda"
elif torch.backends.mps.is_available():
    device_name = "mps"
else:
    device_name = "cpu"
device = torch.device(device_name)

# Load the trained model and its image processor from the same directory
model_dir = "/content/" # path where model.safetensors, optimizer.pt, etc are located
model = TimesformerForVideoClassification.from_pretrained(model_dir)
model = model.to(device)
processor = AutoImageProcessor.from_pretrained(model_dir)

In [None]:
# chosen target layer which worked best
target_layer = model.timesformer.encoder.layer[-1].attention.attention

# Filled in by the hooks below: the target layer's output during the forward pass
# and the gradient w.r.t. that output during the backward pass.
activations = None
gradients = None


def save_activation_hook(module, input, output):
    """Forward hook: store a detached copy of the target layer's output.

    Args:
        module: The module the hook is attached to (unused).
        input: Positional inputs of the module's forward call (unused).
        output: Return value of the module's forward call.
    """
    global activations
    # A bare tensor is stored as-is; any container output (tuple, list, ...) is
    # indexed with [0] exactly as before. Indexing a bare tensor with [0] would
    # silently drop its first dimension instead of selecting the output.
    layer_output = output if isinstance(output, torch.Tensor) else output[0]
    activations = layer_output.detach()


def save_gradient_hook(module, grad_input, grad_output):
    """Backward hook: store a detached copy of the gradient w.r.t. the layer output.

    Args:
        module: The module the hook is attached to (unused).
        grad_input: Gradients w.r.t. the module inputs (unused).
        grad_output: Gradients w.r.t. the module outputs; the first entry is kept.
    """
    global gradients
    gradients = grad_output[0].detach()


# Re-running this cell must not stack duplicate hooks on the layer, so hooks
# registered by an earlier run are removed before new ones are attached.
for _stale_handle in globals().get("_cam_hook_handles", []):
    _stale_handle.remove()
_cam_hook_handles = [
    target_layer.register_forward_hook(save_activation_hook),
    target_layer.register_full_backward_hook(save_gradient_hook),
]


def load_frames(folder_path, num_frames=8, stride=8):
    """Sample evenly spaced frames from a folder of frame images.

    Entries of ``folder_path`` are sorted by name and every ``stride``-th one is
    taken until ``num_frames`` frames are collected. Each frame is resized to
    224 x 224 and returned in two forms: preprocessed for the model and as a
    grey-looking 3-channel image to draw heatmaps on.

    Args:
        folder_path: Directory containing one image file per video frame.
        num_frames: Number of frames to sample (default 8, the original behavior).
        stride: Step between sampled files in the sorted listing (default 8).

    Returns:
        A tuple ``(frames, original_frames)`` where ``frames`` is the stacked
        output of ``default_transforms`` for the sampled frames (shape
        ``(8, 3, 224, 224)`` with the defaults) and ``original_frames`` is a list
        of 224 x 224 three-channel grayscale images in OpenCV channel order.

    Raises:
        AssertionError: If the folder holds fewer than ``num_frames * stride`` entries.
        ValueError: If a sampled file cannot be decoded as an image.
    """
    # Hidden entries (e.g. ".DS_Store", ".ipynb_checkpoints") are never frames and
    # would otherwise be sampled like any other file.
    frame_files = sorted(
        name for name in os.listdir(folder_path) if not name.startswith(".")
    )
    required = num_frames * stride
    assert len(frame_files) >= required, (
        f"{folder_path} has {len(frame_files)} frames, need at least {required}"
    )

    original_frames = []
    frames = []
    for file in frame_files[:required:stride]:
        img_path = os.path.join(folder_path, file)
        frame = cv2.imread(img_path)
        # cv2.imread reports failure by returning None rather than raising.
        if frame is None:
            raise ValueError(f"Could not read frame image: {img_path}")
        frame = cv2.resize(frame, (224, 224))
        # Convert to grayscale but keep 3 dimensions for RGB compatibility
        gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        rgb_compatible_frame = cv2.cvtColor(gray_frame, cv2.COLOR_GRAY2BGR)
        original_frames.append(rgb_compatible_frame)
        # The model input uses RGB channel order, OpenCV loads BGR.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = default_transforms(image=frame)["image"]
        frames.append(frame)
    frames = torch.stack(frames)  # (num_frames, 3, 224, 224)
    return frames, original_frames


clip, original_frames = load_frames(video_folder)  # video tensor (8, 3, 224, 224)
pixel_values = clip.unsqueeze(0).to(device).float()  # matching size - (1, 8, 3, 224, 224)
inputs = {"pixel_values": pixel_values}

# Forward pass to get class scores
outputs = model(**inputs)
logits = outputs.logits

# Backpropagate only the target class to compute gradients
model.zero_grad()
logits[0, target_category].backward()

# Grad-CAM calculation
# `activations` / `gradients` were filled in by the hooks on the target layer.
# The author notes both as (1, 8, 197, 768) — TODO confirm; the code below also
# works if the layer folds the frames into the batch dimension, i.e. (8, 197, 768).
pooled_gradients = torch.mean(gradients, dim=[0, 1])
weighted_activations = activations * pooled_gradients.unsqueeze(0).unsqueeze(0)
# Average over the feature dimension to get one score per token and frame.
heatmap = torch.mean(weighted_activations, dim=-1).squeeze().cpu().numpy()
heatmap = heatmap[:, 1:]  # remove CLS token
# The remaining 196 tokens per frame form a 14 x 14 patch grid.
heatmap = heatmap.reshape(len(original_frames), 14, 14)
# Upsample every patch-level map to the 224 x 224 frame size.
heatmap_resized = np.array([cv2.resize(h, (224, 224)) for h in heatmap])
# Min-max normalize each frame independently to [0, 1]; the epsilon avoids a
# division by zero for a constant map.
heatmap_resized = np.array(
    [(h - np.min(h)) / (np.max(h) - np.min(h) + 1e-8) for h in heatmap_resized]
)


In [None]:
# Overlay and save each frame with Grad-CAM heatmap
# Blend weight of the frame; the colored heatmap gets 1 - FRAME_PERCENTAGE.
FRAME_PERCENTAGE = 0.6


def overlay_heatmap(frame, heatmap, frame_weight=FRAME_PERCENTAGE):
    """Blend a [0, 1] heatmap over a frame and return the result as an RGB image.

    Args:
        frame: Three-channel uint8 image in OpenCV channel order.
        heatmap: 2-D array with values in [0, 1], same height/width as ``frame``.
        frame_weight: Weight of ``frame`` in the blend; the heatmap gets the rest.

    Returns:
        The blended image converted to RGB channel order for matplotlib.
    """
    heatmap_uint8 = np.uint8(255 * heatmap)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    blended = cv2.addWeighted(frame, frame_weight, heatmap_color, 1 - frame_weight, 0)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)


num_cam_frames = heatmap_resized.shape[0]
# Compute every overlay once; both the per-frame images and the strip reuse them.
overlays_rgb = [
    overlay_heatmap(original_frames[i], heatmap_resized[i])
    for i in range(num_cam_frames)
]

os.makedirs("cam", exist_ok=True)

# One image file per frame
for idx, overlay_rgb in enumerate(overlays_rgb):
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(overlay_rgb)
    ax.set_title(f"Frame {idx}")
    ax.axis("off")
    fig.savefig(f"cam/cam_{idx}.png")
    plt.close(fig)

# One strip with all frames side by side plus a colormap legend
if num_cam_frames > 0:
    # squeeze=False keeps `axes` two-dimensional, so a single-frame clip can be
    # indexed the same way as a multi-frame one.
    fig, axes = plt.subplots(1, num_cam_frames, figsize=(20, 4), squeeze=False)
    for i, (ax, overlay_rgb) in enumerate(zip(axes[0], overlays_rgb)):
        ax.imshow(overlay_rgb)
        ax.set_title(f"Frame {i}")
        ax.axis("off")

    # Lay out the image axes first and keep the bottom 20% of the figure free:
    # tight_layout cannot account for the manually placed colorbar axes added below.
    fig.tight_layout(rect=(0, 0.2, 1, 1))

    # colormap legend
    cmap = plt.cm.jet
    norm = plt.Normalize(vmin=0, vmax=1)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar_ax = fig.add_axes([0.15, 0.15, 0.7, 0.02])
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation="horizontal")
    cbar.set_label("Activation Intensity")

    fig.savefig("cam/all_frames_strip.png")
    plt.close(fig)
