In [7]:
import torch
import numpy as np
import cv2
from dcnet import DCNet
from dataset_utility import ToTensor
from torch.autograd import Variable
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import seaborn as sns
import pandas as pd
import torch.nn.functional as F
import re

In [17]:
data_dict = {
    "distribute_four": {
        "correct": [9138, 3148], 
        "wrong": [929]
    },
    "distribute_nine": {
        "correct": [4088, 3149],
        "wrong": [9138, 1228]
    },
    "center_single": {
        "correct": [8558, 2529],
        "wrong": [8608]
    },
    "in_center_single_out_center_single": {
        "correct": [9138, 5878],
        "wrong": [7409, 7899]
    },
    "in_distribute_four_out_center_single": {
        "correct": [5878, 7789, 219],
        "wrong": [2469]
    }, 
    "left_center_single_right_center_single": {
        "correct": [5348, 218],
        "wrong": [7969]
    },
    "up_center_single_down_center_single": {
        "correct": [7029, 219],
        "wrong": [8669]
    }
}

In [4]:
# Load model
device = "mps"
model = DCNet().to(torch.device(device))
model.load_state_dict(torch.load('model_02.pth', map_location=torch.device("mps")))
model.eval()

DCNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (res1): ResBlock(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (downsample): Sequential(
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (choice_contrast1): Sequential(
    (0): Conv2d(128, 128, kernel_size=(

# Find correct/wrong prediction

In [5]:
def find_predictions_by_match(model, base_dir, match_op='==', max_files=20, img_size=96):
    """
    Find files where the model prediction matches or mismatches the label based on `match_op`.

    Args:
        model: Trained DCNet model
        base_dir: Path to RAVEN directory
        match_op: '==' for correct, '!=' for incorrect
        max_files: Max number of files to return
        img_size: Resize shape (img_size, img_size)

    Returns:
        List of matching file paths
    """
    assert match_op in ['==', '!='], "match_op must be '==' or '!='"

    model.eval()
    device = next(model.parameters()).device
    matching_files = []
    pattern = re.compile(r'RAVEN_\d+_test\.npz')

    for root, _, files in os.walk(base_dir):
        for fname in files:
            if not pattern.match(fname):
                continue

            path = os.path.join(root, fname)
            data = np.load(path)
            images = data['image']
            target = int(data['target'])

            # Resize and convert to tensor
            images_resized = np.stack([
                cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_NEAREST)
                for img in images
            ])
            images_tensor = torch.tensor(images_resized, dtype=torch.float32).unsqueeze(0).to(device)

            with torch.no_grad():
                output = model(images_tensor)
                pred = torch.argmax(output, dim=1).item()

            # Check condition
            is_match = pred == target
            if (match_op == '==' and is_match) or (match_op == '!=' and not is_match):
                matching_files.append(path)
                if len(matching_files) == max_files:
                    return matching_files

    return matching_files

In [8]:
data_path = "dataset/RAVEN-10000/distribute_four"
# Find first 20 correct predictions
correct = find_predictions_by_match(model, data_path, match_op='==')

# Find first 20 incorrect predictions
incorrect = find_predictions_by_match(model, data_path, match_op='!=')

In [10]:
# correct
incorrect

['dataset/RAVEN-10000/distribute_four/RAVEN_1228_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_4089_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_2529_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_5879_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_7788_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_2289_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_7029_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_9698_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_4728_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_1918_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_1919_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_929_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_869_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_868_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_4669_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN_1859_test.npz',
 'dataset/RAVEN-10000/distribute_four/RAVEN

# Save RAVEN questions 

In [11]:
def save_raven_question_visual(path, save_dir="raven_visuals"):
    """
    Load a RAVEN .npz file and save two plots:
    1. The 3x3 matrix with the missing panel replaced by '?'
    2. The 8 answer choices with labels

    Args:
        path (str): Path to a .npz file (e.g., RAVEN_x_test.npz)
        save_dir (str): Directory to save output images

    Returns:
        (str, str): Paths to the saved matrix and choices images
    """
    os.makedirs(save_dir, exist_ok=True)
    data = np.load(path)
    images = data['image']  # shape: (16, H, W)
    base_filename = os.path.splitext(os.path.basename(path))[0]

    # --- Save Matrix with "?" ---
    fig, axs = plt.subplots(3, 3, figsize=(6, 6))
    for i in range(8):
        ax = axs[i // 3, i % 3]
        ax.imshow(images[i], cmap='gray')
        ax.set_xticks([]); ax.set_yticks([])
        ax.set_xticks([0, images[i].shape[1]], minor=True)
        ax.set_yticks([0, images[i].shape[0]], minor=True)
        ax.grid(which='minor', color='black', linewidth=1.5)

    ax = axs[2, 2]
    ax.axis('off')
    ax.text(0.5, 0.5, '?', transform=ax.transAxes, fontsize=40, ha='center', va='center')

    for row in range(3):
        axs[row][0].set_ylabel(f"Row {row+1}", fontsize=12)
    for col in range(3):
        axs[0][col].set_title(f"Col {col+1}", fontsize=12)

    plt.suptitle('Matrix (Choose the missing piece)', fontsize=14)
    plt.tight_layout()
    matrix_path = os.path.join(save_dir, base_filename + "_matrix.png")
    plt.savefig(matrix_path, dpi=300)
    plt.close()

    # --- Save Answer Choices ---
    fig, axs = plt.subplots(1, 8, figsize=(16, 2.5))
    for i in range(8):
        axs[i].imshow(images[8 + i], cmap='gray')
        axs[i].set_xticks([]); axs[i].set_yticks([])
        axs[i].set_xticks([0, images[i].shape[1]], minor=True)
        axs[i].set_yticks([0, images[i].shape[0]], minor=True)
        axs[i].grid(which='minor', color='black', linewidth=1.0)
        axs[i].set_title(f"Choice {i}", fontsize=10)

    plt.suptitle('Answer Choices (0 to 7)', fontsize=14)
    plt.tight_layout()
    choices_path = os.path.join(save_dir, base_filename + "_choices.png")
    plt.savefig(choices_path, dpi=300)
    plt.close()

    return matrix_path, choices_path

In [20]:
base_dir = "dataset/RAVEN-10000/"
save_root = "experiment/" 

In [21]:
for category, results in data_dict.items():
    for label in ["correct", "wrong"]:
        for sample_id in results[label]:
            fname = f"RAVEN_{sample_id}_test.npz"
            path = os.path.join(base_dir, category, fname)
            out_dir = os.path.join(save_root, label, category, str(sample_id))
            
            try:
                matrix_path, choices_path = save_raven_question_visual(path, save_dir=out_dir)
                # print(f"Saved: {matrix_path}, {choices_path}")
            except Exception as e:
                print(f"Error with {path}: {e}")

# Occlusion Sensitivity Map

In [22]:
def visualize_occlusion_with_heatmaps(model, images_tensor, original_images, pred, save_dir="outputs", filename_prefix="heatmap"):
    """
    Generates and saves occlusion heatmap visualizations for a Raven matrix.
    
    Args:
        model: Trained DCNet model
        images_tensor: Tensor of shape (1, 16, H, W)
        original_images: Numpy array of shape (16, H, W)
        pred: int, predicted answer index (0–7)
        save_dir: folder to save the image(s)
        filename_prefix: prefix for saved filenames
    """
    os.makedirs(save_dir, exist_ok=True)

    def get_occlusion_sensitivity(model, images_tensor, window_size=20, stride=10):
        with torch.no_grad():
            original_output = model(images_tensor)
            pred_idx = torch.argmax(original_output, dim=1).item()
            original_score = original_output[0, pred_idx].item()

        sensitivity_maps = []
        image_size = images_tensor.shape[-1]

        for img_idx in range(16):
            sensitivity_map = np.zeros((image_size, image_size))
            for i in range(0, image_size - window_size + 1, stride):
                for j in range(0, image_size - window_size + 1, stride):
                    modified_input = images_tensor.clone().detach()
                    modified_input[0, img_idx, i:i+window_size, j:j+window_size] = 0

                    with torch.no_grad():
                        output = model(modified_input)
                        new_score = output[0, pred_idx].item()

                    score_change = abs(original_score - new_score)
                    sensitivity_map[i:i+window_size, j:j+window_size] += score_change

            # Normalize
            map_max = sensitivity_map.max()
            map_min = sensitivity_map.min()
            if map_max == map_min:
                sensitivity_map = np.ones_like(sensitivity_map) * 0.5
            else:
                sensitivity_map = (sensitivity_map - map_min) / (map_max - map_min + 1e-8)

            sensitivity_maps.append(sensitivity_map)

        return sensitivity_maps

    # ---- Generate sensitivity maps
    sensitivity_maps = get_occlusion_sensitivity(model, images_tensor)

    # ---- Plot 3×3 matrix with heatmaps + colorbar
    fig, axs = plt.subplots(3, 3, figsize=(7, 7))
    for i in range(8):
        ax = axs[i // 3, i % 3]
        heat = ax.imshow(original_images[i], cmap='gray', alpha=0.5)
        heatmap = ax.imshow(sensitivity_maps[i], cmap='jet', alpha=0.5)
        ax.set_xticks([])
        ax.set_yticks([])
    
    # Add the missing panel
    ax = axs[2, 2]
    ax.axis('off')
    ax.text(0.5, 0.5, '?', transform=ax.transAxes, fontsize=40, ha='center', va='center')
    
    # Add colorbar to the figure
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    fig.colorbar(heatmap, cax=cbar_ax)
    
    plt.suptitle('Matrix with Occlusion Sensitivity', fontsize=14)
    plt.tight_layout(rect=[0, 0, 0.9, 1])  # Make space for colorbar
    matrix_path = os.path.join(save_dir, f"{filename_prefix}_matrix.png")
    plt.savefig(matrix_path)
    plt.close()

    # ---- Plot choices with heatmaps
    fig, axs = plt.subplots(1, 8, figsize=(16, 2.5))
    for i in range(8):
        ax = axs[i]
        ax.imshow(original_images[8 + i], cmap='gray', alpha=0.5)
        ax.imshow(sensitivity_maps[8 + i], cmap='jet', alpha=0.5)
        ax.axis('off')
        ax.set_title(f"{i}", fontsize=10)

    plt.suptitle(f'Answer Choices (Predicted: {pred})', fontsize=14)
    plt.tight_layout()
    choices_path = os.path.join(save_dir, f"{filename_prefix}_choices.png")
    plt.savefig(choices_path)
    plt.close()

    print(f"Saved matrix heatmap to: {matrix_path}")
    print(f"Saved choices heatmap to: {choices_path}")

In [27]:
for category, results in data_dict.items():
    for label in ["correct", "wrong"]:
        for sample_id in results[label]:
            filename = f"RAVEN_{sample_id}_test.npz"
            path = os.path.join(base_dir, category, filename)
            out_dir = os.path.join(save_root, label, category, str(sample_id), 'occulstion_sensitivity')

            data = np.load(path)
            images = data['image']  # (16, 160, 160)
            images_resized = np.stack([cv2.resize(img, (96, 96), interpolation=cv2.INTER_NEAREST) for img in images])
            
            # Convert to 3-channel tensors
            tf = ToTensor()
            images_tensor = tf(images).unsqueeze(0).to(torch.device("mps"))
            images_tensor_resized = tf(images_resized).unsqueeze(0).to(torch.device("mps"))
            
            # Predict
            with torch.no_grad():
                output = model(images_tensor)
                output_resized = model(images_tensor_resized)
                # pred = torch.argmax(output, dim=1).item()
                pred = torch.argmax(output_resized, dim=1).item()
            
            try:
                visualize_occlusion_with_heatmaps(
                    model=model,
                    images_tensor=images_tensor,
                    original_images=images,
                    pred=pred, 
                    save_dir=out_dir,       
                    filename_prefix=filename 
                )
            except Exception as e:
                print(f"Failed on {path}: {e}")

  plt.tight_layout(rect=[0, 0, 0.9, 1])  # Make space for colorbar


Saved matrix heatmap to: ruti/correct/distribute_four/9138/occulstion_sensitivity/RAVEN_9138_test.npz_matrix.png
Saved choices heatmap to: ruti/correct/distribute_four/9138/occulstion_sensitivity/RAVEN_9138_test.npz_choices.png
Saved matrix heatmap to: ruti/correct/distribute_four/3148/occulstion_sensitivity/RAVEN_3148_test.npz_matrix.png
Saved choices heatmap to: ruti/correct/distribute_four/3148/occulstion_sensitivity/RAVEN_3148_test.npz_choices.png
Saved matrix heatmap to: ruti/wrong/distribute_four/929/occulstion_sensitivity/RAVEN_929_test.npz_matrix.png
Saved choices heatmap to: ruti/wrong/distribute_four/929/occulstion_sensitivity/RAVEN_929_test.npz_choices.png
Saved matrix heatmap to: ruti/correct/distribute_nine/4088/occulstion_sensitivity/RAVEN_4088_test.npz_matrix.png
Saved choices heatmap to: ruti/correct/distribute_nine/4088/occulstion_sensitivity/RAVEN_4088_test.npz_choices.png
Saved matrix heatmap to: ruti/correct/distribute_nine/3149/occulstion_sensitivity/RAVEN_3149_tes