In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.models as models
import torch.nn.functional as F
from PIL import Image
import os
import numpy as np
import cv2
from tqdm import tqdm


from matplotlib.pyplot import imshow
import math
import json
import re
from scipy.ndimage import gaussian_filter1d, gaussian_filter

  Referenced from: <FB2FD416-6C4D-3621-B677-61F07C02A3C5> /Users/aravjain/miniforge3/envs/spark_env/lib/python3.9/site-packages/torchvision/image.so
  warn(


In [2]:
# Configuration shared by the cells below.
# Use the GPU when CUDA is available and fall back to the CPU otherwise
# (assign torch.device("cpu") here to force CPU execution).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model_path = "best_VGG_model_1.pth"  # Update with your model path

In [3]:
class VGG16BinaryClassifier(nn.Module):
    """VGG16 network whose classifier is replaced by a single-output head.

    The replacement head takes 25088 input features and passes them through
    two hidden linear layers (4096 and 512 units, each followed by ReLU and
    dropout with p=0.5) before a final linear layer with one output. No
    sigmoid is applied inside the model.

    Args:
        pretrained: Forwarded unchanged to ``torchvision.models.vgg16``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = models.vgg16(pretrained=pretrained)

        # Mark every parameter of the convolutional feature extractor trainable.
        for weight in backbone.features.parameters():
            weight.requires_grad = True

        # The order of the layers determines the state-dict key names of the
        # head, so it is kept exactly as the checkpoint expects.
        head_layers = [
            nn.Linear(25088, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),
        ]
        backbone.classifier = nn.Sequential(*head_layers)
        self.vgg16 = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped VGG16 and return its output."""
        logits = self.vgg16(x)
        return logits

In [9]:
class OcclusionSensitivity:
    """Occlusion-sensitivity explainer for a single-output binary classifier.

    A square window is slid over the input; at each position the window is
    set to zero and the change in the sigmoid of the model output is recorded.

    Args:
        model: Callable applied to the (occluded) input tensor. It must yield
            exactly one output value, because ``.item()`` is called on its
            sigmoid.
        window_size: Side length, in pixels, of the occluded square.
        stride: Step, in pixels, between successive window positions.
    """

    def __init__(self, model, window_size=32, stride=16):
        self.model = model
        self.window_size = window_size
        self.stride = stride

    def generate_heatmap(self, input_tensor, target_class=None):
        """Generate occlusion sensitivity heatmap.

        Args:
            input_tensor: 4-D tensor whose last two dimensions are height and
                width. It is not modified; each occlusion works on a clone.
            target_class: ``1`` or ``0``. When ``None`` the class predicted for
                the unoccluded input (probability > 0.5 -> 1) is used. An
                explicit ``0`` is honoured.

        Returns:
            ``numpy.ndarray`` of shape ``(height, width)`` min-max normalised
            to ``[0, 1]``. Larger values mean that occluding the window whose
            top-left corner lies in that stride cell moved the prediction
            further away from ``target_class``.
        """
        with torch.no_grad():
            # Reference prediction on the unoccluded input.
            original_prob = torch.sigmoid(self.model(input_tensor)).item()
            # Compare against None explicitly: `target_class or ...` would
            # silently discard a requested target_class of 0.
            if target_class is None:
                target_class = 1 if original_prob > 0.5 else 0

            _, _, height, width = input_tensor.shape
            # Allocate on the input's device instead of a notebook-level global.
            heatmap = torch.zeros((height, width), device=input_tensor.device)

            for y in tqdm(range(0, height, self.stride), desc="Generating occlusion map"):
                for x in range(0, width, self.stride):
                    # Zero the window directly on a copy of the input. Slicing
                    # clamps at the image border, so no padding is required and
                    # window sizes below 2 work as well.
                    occluded = input_tensor.clone()
                    occluded[..., y:y + self.window_size, x:x + self.window_size] = 0

                    current_prob = torch.sigmoid(self.model(occluded)).item()

                    # Positive score: the occlusion hurt the target class.
                    if target_class == 1:
                        score = original_prob - current_prob
                    else:
                        score = current_prob - original_prob

                    # The score is attributed to the stride cell that starts at
                    # the window's top-left corner.
                    y_end = min(y + self.stride, height)
                    x_end = min(x + self.stride, width)
                    heatmap[y:y_end, x:x_end] += score

            # Min-max normalise; the epsilon guards against a constant map.
            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
            return heatmap.cpu().numpy()

In [10]:
def load_model(model_path, device):
    """Build a ``VGG16BinaryClassifier`` and load trained weights into it.

    Args:
        model_path: Path to a checkpoint holding the model's state dict.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    # The checkpoint is loaded with the default strict key matching, so it
    # supplies every parameter; requesting ImageNet weights first would only
    # trigger a needless download that is overwritten straight away.
    model = VGG16BinaryClassifier(pretrained=False)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources (consider weights_only=True where the installed
    # PyTorch version supports it).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model

In [11]:
# Process a directory of B-scans with occlusion sensitivity
def process_bscans_with_occlusion(model, bscan_dir, output_dir, transform, window_size=32, stride=16):
    """
    Process all B-scan images in a directory with occlusion sensitivity and save results

    For every image two files are written to ``output_dir``: the original
    image blended with the colour-mapped heatmap (same file name as the input)
    and the grayscale heatmap itself (file name prefixed with ``heatmap_``).

    Args:
        model: The trained model
        bscan_dir: Directory containing B-scan images (.png/.jpg/.jpeg,
            matched case-insensitively)
        output_dir: Directory to save occlusion heatmaps (created if missing)
        transform: Image transformation for model input
        window_size: Size of occlusion window
        stride: Stride for occlusion analysis
    """
    os.makedirs(output_dir, exist_ok=True)

    # Create occlusion sensitivity analyzer
    occluder = OcclusionSensitivity(model, window_size=window_size, stride=stride)

    # Feed inputs on the device the model actually lives on; fall back to the
    # notebook-level `device` for models that expose no parameters.
    try:
        model_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        model_device = device

    # Get list of images; compare lower-cased names so ".PNG"/".JPG" are found too
    image_extensions = (".png", ".jpg", ".jpeg")
    img_filenames = sorted(
        f for f in os.listdir(bscan_dir)
        if f.lower().endswith(image_extensions)
    )

    for img_name in tqdm(img_filenames, desc="Processing B-scans"):
        img_path = os.path.join(bscan_dir, img_name)

        # Load and process image; the context manager closes the file handle
        # as soon as the RGB copy has been made.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')
        input_tensor = transform(image).unsqueeze(0).to(model_device)

        # Generate heatmap (values in [0, 1] at the transformed resolution)
        heatmap = occluder.generate_heatmap(input_tensor)

        # Resize the heatmap back to the original image size (cv2 takes width, height)
        image_np = np.array(image)
        heatmap_resized = cv2.resize(heatmap, (image_np.shape[1], image_np.shape[0]))

        # Create heatmap overlay; applyColorMap returns BGR, the image is RGB
        heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_INFERNO)
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

        # Blend with original image (70% image, 30% heatmap)
        blended = cv2.addWeighted(image_np, 0.7, heatmap_colored, 0.3, 0)

        # Save results (cv2.imwrite expects BGR channel order)
        output_path = os.path.join(output_dir, img_name)
        cv2.imwrite(output_path, cv2.cvtColor(blended, cv2.COLOR_RGB2BGR))

        # Save raw heatmap for further processing
        heatmap_path = os.path.join(output_dir, f"heatmap_{img_name}")
        cv2.imwrite(heatmap_path, np.uint8(255 * heatmap_resized))

    print(f"✅ Processed {len(img_filenames)} B-scans with occlusion sensitivity")

In [None]:
def main():
    """Load the trained model and write occlusion heatmaps for every B-scan.

    Reads images from ``img``, writes overlays and heatmaps to
    ``img_masked_img``, and uses the notebook-level ``device``.
    """
    # Configuration
    model_path = "best_VGG_model_1.pth"
    bscan_dir = "img"  # Directory containing B-scan images
    occlusions_output_dir = "img_masked_img"  # Directory to save occlusion heatmaps

    # Define transforms: fixed 512x512 input, then per-channel normalisation
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load model
    model = load_model(model_path, device)

    # Process B-scans with occlusion sensitivity
    process_bscans_with_occlusion(
        model=model,
        bscan_dir=bscan_dir,
        output_dir=occlusions_output_dir,
        transform=transform,
        window_size=32,
        stride=16
    )

In [12]:
# Run the full pipeline only when this code is executed as the main module.
if __name__ == "__main__":
    main()