# In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import os

# --- Configuration ---
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Checkpoint of the CNN produced by the training step (Step 2).
MODEL_PATH = '/kaggle/working/new_model.pth'
# Training rows are visualised so each heatmap can be compared with a known price.
CSV_PATH = '/kaggle/working/processed_train_new.csv'
# How many randomly chosen rows receive a Grad-CAM panel.
NUM_SAMPLES = 10

# --- 1. Re-define Model Architecture (Must Match Training) ---
class MultimodalNet(nn.Module):
    """Two-branch regressor fusing ResNet18 image features with tabular features.

    The layer layout must match the network that produced the checkpoint so
    that ``load_state_dict`` finds the same keys and shapes.

    Args:
        num_tabular_features: Length of the tabular feature vector fed to
            ``tabular_branch``.

    ``forward(image, tabular)`` takes an image batch for the ResNet backbone
    and a ``(batch, num_tabular_features)`` tabular batch, and returns a
    ``(batch, 1)`` tensor from the fusion head.
    """

    def __init__(self, num_tabular_features):
        super().__init__()

        # Image branch: pretrained ResNet18. Newer torchvision uses the
        # ``weights=`` enum API. Older releases either lack ``ResNet18_Weights``
        # (ImportError) or reject the ``weights`` keyword (TypeError); only in
        # those cases fall back to the legacy ``pretrained=True`` flag, instead
        # of swallowing every exception with a bare ``except``.
        try:
            from torchvision.models import ResNet18_Weights
            self.cnn = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        except (ImportError, TypeError):
            self.cnn = models.resnet18(pretrained=True)

        # Replace the classifier with the 64-d projection used in training.
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, 64)

        # Tabular branch. It is not needed for the image heatmap itself, but it
        # must exist with this exact structure for the checkpoint to load.
        self.tabular_branch = nn.Sequential(
            nn.Linear(num_tabular_features, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )

        # Fusion head: input width 128 = 64 (image projection) + 64 (tabular).
        self.fusion = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, image, tabular):
        """Encode both inputs, concatenate along the feature axis and regress."""
        x_img = self.cnn(image)
        x_tab = self.tabular_branch(tabular)
        x_combined = torch.cat((x_img, x_tab), dim=1)
        return self.fusion(x_combined)

# --- 2. Dataset Loader ---
class InferenceDataset(Dataset):
    """Row-wise access to a processed CSV for Grad-CAM visualisation.

    Each item is ``(image_tensor, tabular, image_array, id, price)``:

    - the resized, ImageNet-normalised image tensor for the CNN,
    - the float32 feature vector for the tabular branch,
    - the 224x224 RGB image as a NumPy array for plotting,
    - the row's ``id``,
    - its ``price`` (0 when the column is absent).

    Args:
        csv_file: Path to the processed CSV. It must contain ``id`` and
            ``image_path`` columns. Every column not in the exclusion list is
            treated as a model feature and must be convertible to float32.
    """

    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)

        # Identifier, target, path and date columns are not model inputs. Every
        # remaining column is, in CSV order. Presumably that order must match
        # the feature order used in training; verify against the training step.
        exclude = {'id', 'date', 'price', 'log_price', 'image_path', 'date_int'}
        self.feature_cols = [c for c in self.data.columns if c not in exclude]

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # Missing paths come through pandas as NaN / None and stringify as below.
        img_path = str(row['image_path'])

        if img_path not in ('nan', 'None') and os.path.exists(img_path):
            # The context manager closes the file handle that Image.open keeps;
            # convert() returns a new, fully loaded image, so it is safe to use
            # after the file is closed.
            with Image.open(img_path) as img:
                original_img = img.convert('RGB').resize((224, 224))
        else:
            # No usable image: a black placeholder keeps the row scoreable.
            original_img = Image.new('RGB', (224, 224), color=(0, 0, 0))

        image_tensor = self.transform(original_img)
        tabular = torch.tensor(row[self.feature_cols].values.astype(np.float32))

        return image_tensor, tabular, np.array(original_img), row['id'], row.get('price', 0)

# --- 3. Grad-CAM Implementation ---
class GradCAM:
    """Grad-CAM heatmaps for a model with one scalar output per sample.

    A forward hook on ``target_layer`` records the layer's output activation.
    It also attaches a tensor hook that records the gradient of the model
    output with respect to that activation.

    To build the heatmap, each activation channel is weighted by its spatially
    averaged gradient and the channels are summed. The result is clipped at
    zero, resized to the input resolution and min-max normalised to [0, 1].

    Args:
        model: Module called as ``model(image_batch, tabular_batch)``.
        target_layer: Submodule of ``model`` to explain. Its output is assumed
            to be ``(batch, channels, height, width)``, which holds for a ResNet
            block; confirm this for other layers.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Module.register_backward_hook is deprecated and documented as
        # unreliable for modules that run several operations, and a ResNet
        # block does. Capture the gradient with a tensor hook on the layer
        # output instead.
        self._hook_handle = self.target_layer.register_forward_hook(self.save_activation)

    def save_activation(self, module, input, output):
        """Forward hook: keep the layer output and hook its gradient."""
        self.activations = output
        # register_hook raises on tensors that do not require grad.
        if output.requires_grad:
            output.register_hook(self._store_gradient)

    def _store_gradient(self, grad):
        """Tensor hook: keep d(model output) / d(layer output)."""
        self.gradients = grad

    def save_gradient(self, module, grad_input, grad_output):
        """Module backward-hook adapter, kept for callers that register it themselves."""
        self.gradients = grad_output[0]

    def remove(self):
        """Detach the forward hook from ``target_layer``."""
        self._hook_handle.remove()

    def __call__(self, image_tensor, tabular_tensor):
        """Compute the heatmap and prediction for one unbatched sample.

        Args:
            image_tensor: One image without a batch dimension, ``(C, H, W)``.
            tabular_tensor: One feature vector without a batch dimension.

        Returns:
            ``(cam, prediction)``: a float32 ``(H, W)`` heatmap in [0, 1] and
            the model output as a Python float.

        Raises:
            RuntimeError: If the hooks did not capture an activation and a gradient.
        """
        self.model.eval()
        # Clear results from a previous call so stale data can never be reused.
        self.activations = None
        self.gradients = None

        # Gradients are needed even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            output = self.model(image_tensor.unsqueeze(0), tabular_tensor.unsqueeze(0))
            self.model.zero_grad()
            output.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks captured nothing; is target_layer used in the model's forward pass?"
            )

        gradients = self.gradients.detach().cpu().numpy()[0]
        activations = self.activations.detach().cpu().numpy()[0]

        # Channel weights are the gradients averaged over the spatial axes.
        weights = gradients.mean(axis=(1, 2))
        # Weighted sum over channels: (C,) . (C, h, w) -> (h, w).
        cam = np.tensordot(weights, activations, axes=1).astype(np.float32)
        cam = np.maximum(cam, 0)

        # Upsample to the input resolution; cv2.resize takes (width, height).
        height, width = image_tensor.shape[-2:]
        cam = cv2.resize(cam, (int(width), int(height)))
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        return cam, output.item()

# --- 4. Main Execution ---
def _load_checkpoint(model, path):
    """Load the state dict at ``path`` into ``model``, falling back to non-strict matching."""
    # Read the file once and reuse it for the fallback instead of loading it twice.
    state_dict = torch.load(path, map_location=DEVICE)
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        print("Warning: Strict loading failed, trying non-strict...")
        # Non-strict loading silently skips mismatched names. Report them so a
        # partly initialised model is not mistaken for the trained one.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            print(f"  Missing keys (not loaded): {result.missing_keys}")
        if result.unexpected_keys:
            print(f"  Unexpected keys (ignored): {result.unexpected_keys}")


def _overlay_heatmap(image_np, heatmap, alpha=0.4):
    """Blend a [0, 1] heatmap, JET-colourised, over an RGB uint8 image."""
    colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    # applyColorMap returns BGR; matplotlib expects RGB.
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    return np.uint8((1 - alpha) * image_np + alpha * colored)


def run_explainability():
    """Render Grad-CAM overlays for random rows and save them as one figure.

    Loads ``MODEL_PATH`` into a fresh ``MultimodalNet`` and draws up to
    ``NUM_SAMPLES`` distinct rows from ``CSV_PATH``. Each row's heatmap
    overlay is plotted in a 3-column grid, titled with its id, true price
    and model output. The figure is written to ``explainability_report.png``.
    """
    print("Initializing Grad-CAM Visualization...")

    dataset = InferenceDataset(CSV_PATH)
    # choice(..., replace=False) raises if asked for more rows than exist.
    num_samples = min(NUM_SAMPLES, len(dataset))
    if num_samples == 0:
        print(f"No rows in {CSV_PATH}; nothing to visualize.")
        return
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    model = MultimodalNet(len(dataset.feature_cols)).to(DEVICE)
    _load_checkpoint(model, MODEL_PATH)

    # Explain the output of the last residual block of the ResNet18 backbone.
    grad_cam = GradCAM(model, model.cnn.layer4[-1])

    print(f"Generating explanations for {num_samples} samples...")

    # Use one row count for both the figure height and the subplot grid.
    # The old code mixed ceil(N/3) and N//3+1, which left an empty row when N % 3 == 0.
    n_cols = 3
    n_rows = -(-num_samples // n_cols)  # ceiling division
    plt.figure(figsize=(15, 4 * n_rows))

    for plot_idx, row_idx in enumerate(indices, start=1):
        img_tensor, tab_tensor, original_img_np, house_id, true_price = dataset[row_idx]
        heatmap, pred_output = grad_cam(img_tensor.to(DEVICE), tab_tensor.to(DEVICE))

        plt.subplot(n_rows, n_cols, plot_idx)
        plt.imshow(_overlay_heatmap(original_img_np, heatmap))
        # Show the model output next to the known price so the two can be compared.
        plt.title(
            f"ID: {house_id}\nPrice: ${true_price:,.0f}\nModel output: {pred_output:.3f}",
            fontsize=10,
        )
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('explainability_report.png')
    print("Done. Visualization saved to 'explainability_report.png'.")

if __name__ == "__main__":
    # Build and save the Grad-CAM report when this module is executed directly.
    run_explainability()