In [None]:
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
import json
import csv
import cv2
from pytorch_grad_cam import GradCAM
from tqdm import tqdm

In [None]:
# Hugging Face Hub identifier of the LLaVA-NeXT checkpoint used for both the
# processor (prompt + image preprocessing) and the model weights.
model_path = "llava-hf/llava-v1.6-mistral-7b-hf"

processor = LlavaNextProcessor.from_pretrained(model_path)

# Weights are loaded in float16; device_map="auto" leaves device placement to
# the library instead of pinning the model to a fixed device here.
model = LlavaNextForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16)

In [None]:
# helper functions

def reshape_transform(tensor):
    """
    Turns a token sequence of shape (B, N, C) into a square feature map (B, C, S, S).

    The token at index 0 along dim 1 is dropped (presumably the CLS token -
    confirm against the vision tower) and the remaining N - 1 tokens are laid
    out on an S x S grid, where S * S must equal N - 1.
    """
    batch_size, num_tokens, channels = tensor.shape
    num_patches = num_tokens - 1
    side = int(num_patches ** 0.5)
    assert side * side == num_patches, f"Cannot reshape {num_patches} tokens into square feature map"

    patch_tokens = tensor[:, 1:, :]
    grid = patch_tokens.reshape(batch_size, side, side, channels)
    # channels-last grid -> channels-first feature map
    return grid.permute(0, 3, 1, 2)

class TokenLogitWrapper(torch.nn.Module):
    """
    Adapts a multimodal model so that it maps a single 4D image tensor to the
    vocabulary logits at one fixed sequence position.

    The text-side inputs are frozen at construction time; each forward call only
    supplies the image, which is placed in view 0 of a zero-filled
    (B, num_image_views, C, H, W) tensor before the wrapped model is called.
    """

    def __init__(self, base_model, inputs_template, token_index, num_image_views=5):
        """
        base_model:       model called as base_model(**inputs); its output must expose `.logits`.
        inputs_template:  mapping of model inputs; it is not modified, and any
                          "pixel_values" / "image_sizes" entries are left out of the stored copy.
        token_index:      position along dim 1 of the logits to return.
        num_image_views:  size of the view dimension of the tensor built in forward().
        """
        super().__init__()
        self.model = base_model
        image_keys = ("pixel_values", "image_sizes")
        self.base_inputs = {
            key: value for key, value in inputs_template.items() if key not in image_keys
        }
        self.token_index = token_index
        self.num_image_views = num_image_views

    def forward(self, pixel_values_4d):
        """Returns logits[:, token_index, :] for the given (B, C, H, W) image batch."""
        batch, channels, height, width = pixel_values_4d.shape

        # Only view 0 carries the image; every other view stays all-zero.
        view_shape = (batch, self.num_image_views, channels, height, width)
        padded_views = torch.zeros(
            view_shape,
            dtype=pixel_values_4d.dtype,
            device=pixel_values_4d.device,
        )
        padded_views[:, 0] = pixel_values_4d

        # The reported image size is the (H, W) of the tensor itself, once per batch item.
        sizes = torch.tensor([(height, width) for _ in range(batch)], device=pixel_values_4d.device)

        model_inputs = dict(self.base_inputs)
        model_inputs["pixel_values"] = padded_views
        model_inputs["image_sizes"] = sizes

        logits = self.model(**model_inputs).logits

        # shape: [batch_size, vocab_size]
        target_logits = logits[:, self.token_index, :]
        if target_logits.ndim == 1:
            return target_logits.unsqueeze(0)
        return target_logits

def upscale_image_if_needed(img, target_size=336):
    """
    Returns `img` unchanged when its shorter side is at least `target_size`;
    otherwise returns a bicubic resize to exactly (target_size, target_size).

    Note that the resize is to a square, so the aspect ratio of a small
    non-square image is not preserved.
    """
    shortest_side = min(img.size)
    if shortest_side >= target_size:
        return img
    square = (target_size, target_size)
    return img.resize(square, Image.BICUBIC)

def rect_to_mask(image_shape, rect):
    """
    Builds a uint8 mask of shape image_shape[:2] that is 1 inside `rect` and 0 elsewhere.

    `rect` is read as corner coordinates (x1, y1, x2, y2); x2 and y2 are exclusive.
    Coordinates are clamped to the image bounds first, so a rectangle that
    partly (or fully) lies outside the image only marks its visible part.
    Without the clamp a negative coordinate would be treated by numpy as an
    index counted from the end and produce an empty or misplaced region.
    """
    height, width = image_shape[:2]
    x1, y1, x2, y2 = map(int, rect)

    x1, x2 = (max(0, min(v, width)) for v in (x1, x2))
    y1, y2 = (max(0, min(v, height)) for v in (y1, y2))

    mask = np.zeros(image_shape[:2], dtype=np.uint8)
    mask[y1 : y2, x1 : x2] = 1
    return mask

def scale_bbox(bbox, orig_size, new_size):
    """
    Rescales a 4-component box from an image of `orig_size` to one of `new_size`.

    Both sizes are (width, height). Components 0 and 2 are multiplied by the
    horizontal factor, components 1 and 3 by the vertical factor, and each
    result is truncated with int(). Returns a list of four ints.
    """
    x_scale = new_size[0] / orig_size[0]
    y_scale = new_size[1] / orig_size[1]
    x, y, w, h = bbox
    factors = (x_scale, y_scale, x_scale, y_scale)
    return [int(value * factor) for value, factor in zip((x, y, w, h), factors)]

def average_activation(cam_map: np.ndarray, region_mask: np.ndarray) -> float:
    """
    Mean of `cam_map` weighted by `region_mask`: the sum of cam_map * region_mask
    divided by the sum of region_mask. Returns 0.0 when the mask sums to zero.
    """
    mask_total = np.sum(region_mask)
    if mask_total == 0:
        return 0.0
    weighted_total = np.sum(np.multiply(cam_map, region_mask))
    return weighted_total / mask_total

def save_batch_to_csv(log_data, output_path, fieldnames, append=True):
    """
    Writes a batch of row dicts to a CSV file.

    log_data:    iterable of dicts whose keys are covered by `fieldnames`.
    output_path: destination file; its parent directory is created if missing.
    fieldnames:  column order passed to csv.DictWriter.
    append:      True appends to an existing file, False overwrites it.

    The header row is written whenever the file starts out without content:
    when overwriting, when the file does not exist yet, and when it exists
    but is empty (e.g. left behind by an earlier failed write).
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    mode = 'a' if append else 'w'
    file_has_content = os.path.exists(output_path) and os.path.getsize(output_path) > 0
    write_header = not append or not file_has_content

    with open(output_path, mode=mode, newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerows(log_data)
    print(f"Saved {len(log_data)} entries to {output_path}")

In [None]:
# Dataset root on the mounted Google Drive; every other path below is derived
# from it with os.path.join so that no separator can be forgotten.
root_dir = "/content/drive/MyDrive/raf_anger_images"

# Outputs: per-run CSV of activations and the resume log of processed images.
root_output_dir = os.path.join(root_dir, "anger_activation_results")
output_csv_filename = "activation_results.csv"
output_csv_path = os.path.join(root_output_dir, output_csv_filename)
log_file_path = os.path.join(root_output_dir, "processed_images.txt")

# Inputs: images, face/hair annotations, and per-face hair masks.
image_dir = os.path.join(root_dir, "all_images")
# Previously built as root_dir + "anger_all_..." (no separator), which pointed
# at a sibling of root_dir instead of a file inside it.
annotation_json_path = os.path.join(root_dir, "anger_all_filtered_faces_and_hair.json")
hair_mask_dir = os.path.join(root_dir, "hair_segmentation", "hair_masks_fullsize")

In [None]:
# Lower-case words that mark a caption (and later a single token) as
# referring to the target emotion.
keywords = ["anger", "angry"]

# emotion prompt text
# (adjacent string literals are concatenated into one prompt string)
prompt_txt = (
    "<image>\n"
    "Focusing only on the expressions and emotions of the person or people shown, "
    "describe the emotions and expressions of the person or people in the image "
    "from one of the following emotions: "
    "[happiness, sadness, fear, anger, neutral, unsure]. "
    "Keep the description to at most 50-60 words."
)


# activity prompt text (alternative experiment, currently disabled)
# prompt_txt = (
#     "<image>\nFocusing only on the activities that the person or people shown are doing, describe the activities that the person or people shown are doing from one of the following categories of activities: [helping and caring, eating, household, dance and music, personal care, posing, sports, transportation, work, other, unsure]. Keep the description to at most 50-60 words."
# )

In [None]:
# Main pipeline: caption every not-yet-processed image, and for captions that
# mention one of `keywords`, run GradCAM for the matching token(s) and log the
# normalized average activation inside the hair / face / background regions.

activation_log = []
csv_fieldnames = ["image", "token_index", "token", "face_id", "gender", "hair", "face", "background", "non_background"]

# Every output of this cell (resume log, CSV, overlays) lives under
# root_output_dir, so create it up-front instead of relying on the per-image
# makedirs further down.
os.makedirs(root_output_dir, exist_ok=True)

with open(annotation_json_path, "r") as f:
    annotation_data = json.load(f)

# Resume support: the log file holds one image name (without extension) per line.
# The file is only ever read here and appended to later; it is never rewritten,
# so a failed read cannot wipe the record of earlier runs.
processed_images_set = set()
if os.path.exists(log_file_path):
    try:
        with open(log_file_path, mode='r') as f:
            processed_images_set = {line.strip() for line in f if line.strip()}
        print(f"Resuming run. Found {len(processed_images_set)} images already processed in the log file.")
    except (IOError, ValueError) as e:
        print(f"Warning: Could not read existing log file at {log_file_path}. Starting from scratch. Error: {e}")

# skip images that a previous (possibly interrupted) run already handled;
# sorted() makes the processing order reproducible across runs
all_image_paths = [os.path.join(image_dir, fname) for fname in sorted(os.listdir(image_dir)) if fname.endswith((".png", ".jpg"))]
remaining_image_paths = [path for path in all_image_paths if os.path.splitext(os.path.basename(path))[0] not in processed_images_set]

print(f"Total images to process in this run: {len(remaining_image_paths)}")


def _append_to_processed_log(image_names):
    """Appends image names to the resume log; failures are reported, not raised."""
    try:
        with open(log_file_path, 'a') as f:
            for name in image_names:
                f.write(f"{name}\n")
    except IOError as e:
        print(f"Error appending to log file for image(s) {list(image_names)}: {e}")


total_identified_images = 0
total_unidentified_images = 0
processed_image_count = 0
SAVE_INTERVAL = 25

# Images whose rows are still only in `activation_log` (not yet in the CSV).
# They are written to the resume log only after the CSV flush, so an
# interrupted run re-processes them instead of silently losing their results.
pending_image_names = []

for image_path in tqdm(remaining_image_paths, desc = "Processing Images"):

    image_name = os.path.splitext(os.path.basename(image_path))[0]
    image_output_dir = os.path.join(root_output_dir + "/image_results", image_name + "_test")

    if image_name not in annotation_data:
        print("Continuing because image is not in annotation data")
        # nothing will ever be produced for this image, so it can be logged right away
        _append_to_processed_log([image_name])
        continue

    image = Image.open(image_path).convert("RGB")
    orig_size = image.size      # store original size before potential upscale
    image = upscale_image_if_needed(image)      # ensure image is at least target_size

    # prepare inputs using the processor (keywords avoid depending on the
    # positional order of text/images, which differs between library versions)
    inputs = processor(text=prompt_txt, images=image, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    input_ids = inputs["input_ids"]

    # the pixel_values tensor from LLaVA's processor will be [1, num_views, C, H, W];
    # only view 0 is used as the differentiable GradCAM input
    pixel_values_for_gradcam = inputs["pixel_values"][:, 0, :, :, :].clone().detach().requires_grad_(True)

    # get the number of image views expected by the model for the TokenLogitWrapper
    # this is the second dimension of the original pixel_values
    num_image_views_in_input = inputs["pixel_values"].shape[1]

    # generate output
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=120)

    caption = processor.tokenizer.decode(output[0], skip_special_tokens=True)
    # NOTE(review): when the decoded text contains no blank line, split() returns
    # the whole caption; if that text includes the prompt (which itself lists
    # "anger"), the keyword check below would always pass - confirm.
    assistant_response = caption.split("\n\n")[-1].strip()

    if not any(word in assistant_response.lower() for word in keywords):
        print(f"Skipping {image_name}: caption does not include any of the keywords: {caption}")
        total_unidentified_images += 1
        # no rows will be produced for this image, so it can be logged right away
        _append_to_processed_log([image_name])
        continue
    else:
        print(f"{image_name}: {caption}")
        total_identified_images += 1
        os.makedirs(image_output_dir, exist_ok = True)

    # NOTE(review): caption_ids is a re-tokenization of the decoded caption, so
    # index i below is only assumed to line up with positions in input_ids and
    # in the logits returned by TokenLogitWrapper - confirm the alignment.
    caption_ids = processor.tokenizer(caption, return_tensors="pt").input_ids[0]
    generated_token_ids = caption_ids.to(model.device)

    # GradCAM is hooked on the second layer norm of vision-encoder layer 9
    target_layer = model.vision_tower.vision_model.encoder.layers[9].layer_norm2

    # determine the target size for GradCAM's output
    # this should be the H, W of the single view image input to the vision tower
    _, _, target_h, target_w = pixel_values_for_gradcam.shape
    gradcam_target_size = (target_w, target_h)

    for i, token_id in enumerate(generated_token_ids):
        # skip input prompt tokens, only consider generated tokens
        if i < input_ids.shape[1]:
            continue

        token_str = processor.tokenizer.decode([token_id]).lower()
        if token_str not in keywords:
            continue
        token_str_safe = "".join(c if c.isalnum() else "_" for c in token_str)

        wrapped_base_inputs = {
            "input_ids": input_ids,
            "attention_mask": inputs["attention_mask"],
        }

        # pass the expected num_image_views to the wrapper
        wrapped = TokenLogitWrapper(model, wrapped_base_inputs, token_index=i, num_image_views=num_image_views_in_input)
        torch.cuda.empty_cache()
        cam = GradCAM(model=wrapped, target_layers=[target_layer], reshape_transform=reshape_transform)

        # pass the single extracted view to GradCAM's input_tensor
        grayscale = cam(input_tensor=pixel_values_for_gradcam)[0]

        torch.cuda.empty_cache()

        # get the heatmap for the image
        vis_image = image.resize((grayscale.shape[1], grayscale.shape[0]))
        rgb = np.array(vis_image).astype(np.float32) / 255.0
        heatmap = plt.cm.jet(grayscale)[..., :3]
        overlay = np.clip(0.5 * heatmap + 0.5 * rgb, 0, 1)
        overlay_img = Image.fromarray(np.uint8(overlay * 255))

        # get activations via the heatmap for the relevant people
        for face_id, info in annotation_data[image_name].items():
            # NOTE(review): scale_bbox names the components (x, y, w, h) while
            # rect_to_mask reads them as (x1, y1, x2, y2) - confirm which
            # format "face_coords" actually uses.
            face_rect = info["face_coords"]
            hair_path = os.path.join(hair_mask_dir, info["hair_mask_path"])

            if not os.path.exists(hair_path):
                print(f"Hair path doesn't exist ({hair_path}), continuing")
                continue

            gender = info.get("gender", "unsure")

            hair_mask = np.load(hair_path)
            hair_mask = (hair_mask > 0).astype(np.uint8)
            hair_mask = cv2.resize(hair_mask, (grayscale.shape[1], grayscale.shape[0]), interpolation = cv2.INTER_NEAREST)

            resized_bbox_face = scale_bbox(face_rect, orig_size, vis_image.size)
            face_mask = rect_to_mask(grayscale.shape, resized_bbox_face)

            # prevent overlap between hair and face
            overlap = np.logical_and(face_mask, hair_mask).astype(np.uint8)
            face_mask_clean = np.clip(face_mask - overlap, 0, 1)

            # define background and combined masks
            bg_mask = 1 - np.clip(hair_mask + face_mask_clean, 0, 1)
            non_background_mask = np.clip(hair_mask + face_mask_clean, 0, 1)

            activations = {
                "hair": average_activation(grayscale, hair_mask),
                "face": average_activation(grayscale, face_mask_clean),
                "background": average_activation(grayscale, bg_mask),
                "non_background": average_activation(grayscale, non_background_mask)
            }

            total = sum(activations.values())
            if total == 0:
                print(f"Skipping {image_name} due to zero regional activation")
                continue

            # normalize so the four regional values sum to 1
            activations = {k: v / total for k, v in activations.items()}

            activation_log.append({
                "image": image_name,
                "token_index": i,
                "token": token_str,
                "face_id": face_id,
                "gender": gender,
                **activations
            })

            overlay_save_path = os.path.join(image_output_dir, f"{i}_{token_str_safe}.png")
            overlay_img.save(overlay_save_path)

            # NOTE(review): this break leaves the face loop (one logged face per
            # matching token), not the token loop - confirm this matches the
            # intent stated in the original comment below.
            break  # only consider the first admissible token

    processed_image_count += 1
    pending_image_names.append(image_name)
    if processed_image_count % SAVE_INTERVAL == 0:
        save_batch_to_csv(activation_log, output_csv_path, csv_fieldnames, append=True)
        activation_log.clear()
        # rows are on disk now, so these images may be skipped by a later run
        _append_to_processed_log(pending_image_names)
        pending_image_names.clear()

if activation_log: # check if there's any data left in the log
    save_batch_to_csv(activation_log, output_csv_path, csv_fieldnames, append=True)
    activation_log.clear()

if pending_image_names: # images finished since the last periodic flush
    _append_to_processed_log(pending_image_names)
    pending_image_names.clear()

print("Done.")