In [None]:
!git clone https://github.com/ayushnangia/CAM-Back-Again.git
!pip install timm

In [None]:
import os
import glob
import sys
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import timm
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import warnings
warnings.filterwarnings('ignore')


In [None]:
# @title
# Directory expected to hold the repository's helper modules imported right below
# (convnext_func, replknet_func, utils, dataset_func).
# NOTE(review): the clone URL in the first cell ends in "CAM-Back-Again" (hyphens) while this
# path ends in "CAM_Back_Again" (underscores) — confirm the folder name on disk.
project_path = "/path/repo/CAM_Back_Again"
# Put that directory on the import search path so the following `from ... import` lines resolve.
sys.path.append(project_path)
from convnext_func import Net2Head
from replknet_func import *
from utils import *
from dataset_func import *

In [None]:
def load_pretrained_model(model, pretrained_path, num_classes):
    """Warm-start ``model`` from a checkpoint, then give it a fresh classifier.

    Checkpoint entries are copied into the model only when the model already
    has a parameter/buffer under the same key and the key does not contain the
    substring ``'head'``. Afterwards ``model.head`` is replaced by a newly
    initialised ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        model: Module exposing a ``head`` attribute with ``in_features``.
        pretrained_path: Path of a file readable by ``torch.load`` that yields
            a mapping of parameter names to tensors.
        num_classes: Output width of the new classification layer.

    Returns:
        The same ``model`` object, modified in place.
    """
    checkpoint = torch.load(pretrained_path, map_location='cpu')
    current_state = model.state_dict()

    # Collect the transferable entries: known to the model and not part of the head.
    transferable = {}
    for key, tensor in checkpoint.items():
        if 'head' in key or key not in current_state:
            continue
        transferable[key] = tensor

    current_state.update(transferable)
    model.load_state_dict(current_state, strict=False)

    # The new head reuses the old head's input width.
    model.head = nn.Linear(model.head.in_features, num_classes)
    return model

### High-activation heatmap generation code

In [None]:
import os
import glob
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import timm
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Set up paths
# NOTE(review): these are placeholders — replace them with a real checkpoint file and real folders.
# load_model() chooses the architecture from model_path: it must contain "convnext" or
# "replknet" (case-insensitive), otherwise load_model raises ValueError.
model_path = "/modelpath"
# Folder whose files (pattern '*.*') are processed by process_directory().
train_dir = "/input/path"
# Folder that receives the saved "<image name>_heatmap.png" figures.
heatmap_output_dir = '/output/path'

# Create the output folder up front; exist_ok=True keeps re-runs from failing.
os.makedirs(heatmap_output_dir, exist_ok=True)

# Model configuration
# "class_n" is passed to timm as num_classes and "input_size" is the square resize edge used by
# the preprocessing code. For RepLKNet the whole dict is handed to build_model(); how the
# remaining keys ("unit_n", "size", "lr", "weight_decay") are consumed is not visible in this
# notebook — presumably by the imported project helpers; verify there.
model_config = {
    "class_n": 2,
    "unit_n": 1024,
    "input_size": 384,
    "size": 12,
    "lr": 1e-05,
    "weight_decay": 0.0005,
    "channels": [128, 256, 512, 1024]
}

# torchvision pipeline: resize to input_size x input_size, convert to tensor, normalise with
# mean 0.5 / std 0.5.
# NOTE(review): none of the functions defined in this cell reference img_transform; they use
# preprocess_image() instead.
img_transform = transforms.Compose([
    transforms.Resize((model_config["input_size"], model_config["input_size"])),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

def load_model(model_path):
    """Build the network named by ``model_path`` and load its weights from that file.

    The architecture is picked from the (case-insensitive) path text:
    ``"convnext"`` builds timm's ``convnext_base_384_in22ft1k`` with
    ``model_config["class_n"]`` outputs; ``"replknet"`` calls ``build_model``
    with the shared configuration extended by the RepLKNet-31B settings.

    Args:
        model_path: Path of a checkpoint that ``torch.load`` turns into a
            mapping of parameter names to tensors. Keys may carry the
            ``"module."`` prefix added by ``nn.DataParallel``.

    Returns:
        The constructed model with the checkpoint loaded (strict key matching).

    Raises:
        ValueError: If the path names neither supported architecture.
    """
    path_lower = model_path.lower()
    if 'convnext' in path_lower:
        model = timm.create_model("convnext_base_384_in22ft1k", pretrained=False, num_classes=model_config["class_n"])
    elif 'replknet' in path_lower:
        # Extend a copy: building a RepLKNet must not rewrite the shared model_config
        # that the rest of the notebook reads.
        replk_config = dict(model_config)
        replk_config["model_name"] = "RepLKNet-31B"
        replk_config["channels"] = [128, 256, 512, 1024]
        model = build_model(replk_config)  # Make sure you have this function defined or imported
    else:
        raise ValueError(f"Unsupported model type: {model_path}")

    state_dict = torch.load(model_path, map_location='cpu')

    # Drop only the *leading* "module." wrapper prefix. str.replace would also remove
    # "module." from the middle of a key and corrupt names of genuinely nested submodules.
    prefix = "module."
    cleaned_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(cleaned_state)
    return model

def get_feature_maps(x, model):
    """Run ``model`` on ``x`` and return the output of its last ``nn.Conv2d``.

    "Last" means last in ``model.named_modules()`` order, i.e. registration
    order, which is not necessarily the last convolution executed.
    NOTE(review): confirm that the selected layer's channel count matches the
    classifier weights consumed by get_cam_heatmap.

    Args:
        x: Input batch accepted by ``model``.
        model: Module containing at least one ``nn.Conv2d``.

    Returns:
        The tensor produced by the selected convolution during the forward pass.

    Raises:
        ValueError: If the model contains no ``nn.Conv2d``.
        RuntimeError: If the selected layer was not executed by the forward pass.
    """
    target_layer = None
    for _, module in reversed(list(model.named_modules())):
        if isinstance(module, nn.Conv2d):
            target_layer = module
            break
    if target_layer is None:
        raise ValueError("Model has no nn.Conv2d layer to take feature maps from.")

    captured = []

    def hook_fn(module, inputs, output):
        captured.append(output)

    handle = target_layer.register_forward_hook(hook_fn)
    try:
        model(x)
    finally:
        # Always detach the hook, even when the forward pass fails, so repeated
        # calls never accumulate stale hooks on the model.
        handle.remove()

    if not captured:
        raise RuntimeError("The hooked nn.Conv2d layer was not executed during the forward pass.")
    return captured[0]

def get_cam_heatmap(features, model, target_class):
    """Compute a normalised class activation map for one image.

    The map is the classifier-weighted sum of the feature channels, passed
    through ReLU and rescaled to ``[0, 1]`` via ``(v - min) / (max - min + 1e-8)``.

    Args:
        features: Tensor of shape ``(1, C, H, W)``.
        model: Model whose classifier is reachable as ``model.head`` (a layer
            with ``weight``, or a container exposing that layer as ``.fc``)
            or as ``model.fc``.
        target_class: Row index into the classifier weight matrix.

    Returns:
        ``numpy.ndarray`` of shape ``(H, W)``.

    Raises:
        ValueError: If ``features`` holds more than one image.
        AttributeError: If the model has neither ``head`` nor ``fc``.
    """
    b, c, h, w = features.shape
    if b != 1:
        raise ValueError(f"Expected feature maps for a single image, got batch size {b}.")
    flat_features = features[0].reshape(c, h * w)

    if hasattr(model, 'head'):
        classifier = model.head
        # Some models keep the final linear layer inside a head container as `.fc`
        # (NOTE(review): believed to be the case for timm's ConvNeXt — verify for the
        # installed timm version); fall through to it when `head` has no weights itself.
        if not hasattr(classifier, 'weight') and hasattr(classifier, 'fc'):
            classifier = classifier.fc
    elif hasattr(model, 'fc'):
        classifier = model.fc
    else:
        raise AttributeError("Model doesn't have 'head' or 'fc' attribute. Please check the model architecture.")
    weights = classifier.weight.data[target_class]

    heatmap = torch.mm(weights.unsqueeze(0), flat_features).reshape(h, w)
    heatmap = F.relu(heatmap)
    heatmap = heatmap.detach().cpu().numpy()
    lowest = np.min(heatmap)
    # The epsilon keeps a constant map from dividing by zero.
    heatmap = (heatmap - lowest) / (np.max(heatmap) - lowest + 1e-8)
    return heatmap

def preprocess_image(img):
    """Convert a BGR image array into a 1 x C x H x W float tensor for the model.

    The image is resized to ``model_config["input_size"]`` on both sides,
    reordered from BGR to RGB, moved to channel-first layout, divided by 255
    and then shifted/scaled with ``(v - 0.5) / 0.5``.
    """
    side = model_config["input_size"]
    resized = cv2.resize(img, (side, side))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    channel_first = rgb.transpose((2, 0, 1))  # HWC to CHW
    unit_range = channel_first / 255.0
    normalised = (unit_range - 0.5) / 0.5
    batch = torch.FloatTensor(normalised).unsqueeze(0)  # add the batch dimension
    return batch

def generate_heatmap(model, image_path):
    """Produce a CAM overlay for the class the model predicts for one image.

    Args:
        model: Classifier compatible with get_feature_maps / get_cam_heatmap.
            It is moved to CUDA when available and switched to eval mode.
        image_path: File to read with ``cv2.imread``.

    Returns:
        ``(original_img, superimposed_img)``: the image as read by OpenCV (BGR,
        native size) and an ``input_size`` x ``input_size`` RGB blend of the
        resized image (weight 0.6) with the JET-coloured heatmap (weight 0.4).
        ``(None, None)`` when the file cannot be read as an image.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    original_img = cv2.imread(image_path)
    if original_img is None:
        print(f"Error reading image: {image_path}")
        return None, None

    side = model_config["input_size"]
    img = cv2.resize(original_img, (side, side))

    x = preprocess_image(img).to(device)

    # Pure inference: keep both forward passes out of autograd so no graph is
    # built (and kept alive) for the hooked feature maps.
    with torch.no_grad():
        logit = model(x)
        # Softmax is monotonic, so the predicted class is simply the largest logit.
        target_class = int(logit.argmax(dim=1)[0])
        features = get_feature_maps(x, model)

    heatmap = get_cam_heatmap(features, model, target_class)

    # Upsample the coarse map to the network input size and colour it.
    heatmap = cv2.resize(heatmap, (side, side))
    heatmap = np.uint8(255 * heatmap)
    heatmap_colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # Blend in BGR (both operands are BGR), then convert once for matplotlib.
    superimposed_img = cv2.addWeighted(img, 0.6, heatmap_colored, 0.4, 0)
    superimposed_img = cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)

    return original_img, superimposed_img

def process_directory(model, input_dir, output_dir):
    """Save an "Original | CAM" figure for each readable image in ``input_dir``.

    Every path matching ``*.*`` is tried; files for which generate_heatmap
    returns ``None`` are skipped. Figures are written to ``output_dir`` as
    ``<file name>_heatmap.png``.
    """
    for img_path in glob.glob(os.path.join(input_dir, '*.*')):
        img_name = os.path.basename(img_path)
        print(f"Processing {img_name}")

        original, heatmap = generate_heatmap(model, img_path)
        if original is not None and heatmap is not None:
            fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(12, 6))

            # The original comes back in OpenCV's BGR order; matplotlib expects RGB.
            panels = (
                (left_ax, cv2.cvtColor(original, cv2.COLOR_BGR2RGB), "Original"),
                (right_ax, heatmap, "CAM"),
            )
            for axis, picture, caption in panels:
                axis.imshow(picture)
                axis.set_title(caption)
                axis.axis('off')

            fig.tight_layout()
            fig.savefig(os.path.join(output_dir, f"{img_name}_heatmap.png"))
            # Release the figure so a long directory does not accumulate open figures.
            plt.close(fig)



def main():
    """Load the checkpoint at ``model_path`` and write CAM figures for ``train_dir``."""
    net = load_model(model_path)
    process_directory(net, train_dir, heatmap_output_dir)
    print("Heatmap generation completed.")


# Run the pipeline when this cell/script is executed directly.
if __name__ == "__main__":
    main()

### High-activation-based crops

In [None]:
import os
import glob
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import timm
from skimage.measure import label, regionprops
import random
import warnings
warnings.filterwarnings('ignore')

# Set up paths
# NOTE(review): placeholders — replace with a real checkpoint file and real folders.
# This cell re-binds model_path, train_dir, model_config and img_transform, which the
# heat map cell above also defines; whichever cell ran last wins.
# load_model() chooses the architecture from model_path: it must contain "convnext" or
# "replknet" (case-insensitive), otherwise load_model raises ValueError.
model_path = "/modelpath"
# Folder whose files (pattern '*.*') are processed by process_directory().
train_dir = "/input/path"
# Folder that receives the "<image stem>_crop_<j>.jpg" files.
crops_output_dir = '/output/path'

# Create the output folder up front; exist_ok=True keeps re-runs from failing.
os.makedirs(crops_output_dir, exist_ok=True)

# Model configuration
# "class_n" is passed to timm as num_classes and "input_size" is the square resize edge used by
# the preprocessing code. For RepLKNet the whole dict is handed to build_model().
# NOTE(review): "lr" and "weight_decay" differ from the heat map cell (1e-05 / 0.0005 there);
# neither is read by the code visible in this notebook — confirm whether the difference matters.
model_config = {
    "class_n": 2,
    "unit_n": 1024,
    "input_size": 384,
    "size": 12,
    "lr": 1e-04,
    "weight_decay": 0.0001,
    "channels": [128, 256, 512, 1024]
}

# torchvision pipeline: resize to input_size x input_size, convert to tensor, normalise with
# mean 0.5 / std 0.5.
# NOTE(review): none of the functions defined in this cell reference img_transform; they use
# preprocess_image() instead.
img_transform = transforms.Compose([
    transforms.Resize((model_config["input_size"], model_config["input_size"])),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

def load_model(model_path):
    """Build the network named by ``model_path`` and load its weights from that file.

    The architecture is picked from the (case-insensitive) path text:
    ``"convnext"`` builds timm's ``convnext_base_384_in22ft1k`` with
    ``model_config["class_n"]`` outputs; ``"replknet"`` calls ``build_model``
    with the shared configuration extended by the RepLKNet-31B settings.

    Args:
        model_path: Path of a checkpoint that ``torch.load`` turns into a
            mapping of parameter names to tensors. Keys may carry the
            ``"module."`` prefix added by ``nn.DataParallel``.

    Returns:
        The constructed model with the checkpoint loaded (strict key matching).

    Raises:
        ValueError: If the path names neither supported architecture.
    """
    path_lower = model_path.lower()
    if 'convnext' in path_lower:
        model = timm.create_model("convnext_base_384_in22ft1k", pretrained=False, num_classes=model_config["class_n"])
    elif 'replknet' in path_lower:
        # Extend a copy: building a RepLKNet must not rewrite the shared model_config
        # that the rest of the notebook reads.
        replk_config = dict(model_config)
        replk_config["model_name"] = "RepLKNet-31B"
        replk_config["channels"] = [128, 256, 512, 1024]
        model = build_model(replk_config)  # Make sure you have this function defined or imported
    else:
        raise ValueError(f"Unsupported model type: {model_path}")

    state_dict = torch.load(model_path, map_location='cpu')

    # Drop only the *leading* "module." wrapper prefix. str.replace would also remove
    # "module." from the middle of a key and corrupt names of genuinely nested submodules.
    prefix = "module."
    cleaned_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(cleaned_state)
    return model

def get_feature_maps(x, model):
    """Run ``model`` on ``x`` and return the output of its last ``nn.Conv2d``.

    "Last" means last in ``model.named_modules()`` order, i.e. registration
    order, which is not necessarily the last convolution executed.
    NOTE(review): confirm that the selected layer's channel count matches the
    classifier weights consumed by get_cam_heatmap.

    Args:
        x: Input batch accepted by ``model``.
        model: Module containing at least one ``nn.Conv2d``.

    Returns:
        The tensor produced by the selected convolution during the forward pass.

    Raises:
        ValueError: If the model contains no ``nn.Conv2d``.
        RuntimeError: If the selected layer was not executed by the forward pass.
    """
    target_layer = None
    for _, module in reversed(list(model.named_modules())):
        if isinstance(module, nn.Conv2d):
            target_layer = module
            break
    if target_layer is None:
        raise ValueError("Model has no nn.Conv2d layer to take feature maps from.")

    captured = []

    def hook_fn(module, inputs, output):
        captured.append(output)

    handle = target_layer.register_forward_hook(hook_fn)
    try:
        model(x)
    finally:
        # Always detach the hook, even when the forward pass fails, so repeated
        # calls never accumulate stale hooks on the model.
        handle.remove()

    if not captured:
        raise RuntimeError("The hooked nn.Conv2d layer was not executed during the forward pass.")
    return captured[0]

def get_cam_heatmap(features, model, target_class):
    """Compute a normalised class activation map for one image.

    The map is the classifier-weighted sum of the feature channels, passed
    through ReLU and rescaled to ``[0, 1]`` via ``(v - min) / (max - min + 1e-8)``.

    Args:
        features: Tensor of shape ``(1, C, H, W)``.
        model: Model whose classifier is reachable as ``model.head`` (a layer
            with ``weight``, or a container exposing that layer as ``.fc``)
            or as ``model.fc``.
        target_class: Row index into the classifier weight matrix.

    Returns:
        ``numpy.ndarray`` of shape ``(H, W)``.

    Raises:
        ValueError: If ``features`` holds more than one image.
        AttributeError: If the model has neither ``head`` nor ``fc``.
    """
    b, c, h, w = features.shape
    if b != 1:
        raise ValueError(f"Expected feature maps for a single image, got batch size {b}.")
    flat_features = features[0].reshape(c, h * w)

    if hasattr(model, 'head'):
        classifier = model.head
        # Some models keep the final linear layer inside a head container as `.fc`
        # (NOTE(review): believed to be the case for timm's ConvNeXt — verify for the
        # installed timm version); fall through to it when `head` has no weights itself.
        if not hasattr(classifier, 'weight') and hasattr(classifier, 'fc'):
            classifier = classifier.fc
    elif hasattr(model, 'fc'):
        classifier = model.fc
    else:
        raise AttributeError("Model doesn't have 'head' or 'fc' attribute. Please check the model architecture.")
    weights = classifier.weight.data[target_class]

    heatmap = torch.mm(weights.unsqueeze(0), flat_features).reshape(h, w)
    heatmap = F.relu(heatmap)
    heatmap = heatmap.detach().cpu().numpy()
    lowest = np.min(heatmap)
    # The epsilon keeps a constant map from dividing by zero.
    heatmap = (heatmap - lowest) / (np.max(heatmap) - lowest + 1e-8)
    return heatmap

def preprocess_image(img):
    """Convert a BGR image array into a 1 x C x H x W float tensor for the model.

    The image is resized to ``model_config["input_size"]`` on both sides,
    reordered from BGR to RGB, moved to channel-first layout, divided by 255
    and then shifted/scaled with ``(v - 0.5) / 0.5``.
    """
    side = model_config["input_size"]
    resized = cv2.resize(img, (side, side))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    channel_first = rgb.transpose((2, 0, 1))  # HWC to CHW
    unit_range = channel_first / 255.0
    normalised = (unit_range - 0.5) / 0.5
    batch = torch.FloatTensor(normalised).unsqueeze(0)  # add the batch dimension
    return batch

def generate_heatmap(model, image_path):
    """Compute the CAM of the predicted class for one image.

    Args:
        model: Classifier compatible with get_feature_maps / get_cam_heatmap.
            It is moved to CUDA when available and switched to eval mode.
        image_path: File to read with ``cv2.imread``.

    Returns:
        ``(original_img, heatmap)``: the image as read by OpenCV (BGR, native
        size) and the 2-D array returned by get_cam_heatmap, which has the
        spatial size of the hooked feature map rather than of the image.
        ``(None, None)`` when the file cannot be read as an image.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    original_img = cv2.imread(image_path)
    if original_img is None:
        print(f"Error reading image: {image_path}")
        return None, None

    side = model_config["input_size"]
    img = cv2.resize(original_img, (side, side))

    x = preprocess_image(img).to(device)

    # Pure inference: keep both forward passes out of autograd so no graph is
    # built (and kept alive) for the hooked feature maps.
    with torch.no_grad():
        logit = model(x)
        # Softmax is monotonic, so the predicted class is simply the largest logit.
        target_class = int(logit.argmax(dim=1)[0])
        features = get_feature_maps(x, model)

    heatmap = get_cam_heatmap(features, model, target_class)

    return original_img, heatmap

def create_focused_crops(image, heatmap, num_crops=8, min_crop_size=224):
    """Cut ``num_crops`` square patches of ``image`` around the strongest heatmap peaks.

    A peak is a heatmap cell with a positive value that is at least as large
    as each of its (up to eight) neighbours; border cells are eligible. Peaks
    are visited from strongest to weakest. Each peak position is mapped from
    heatmap resolution to image pixels, and the window edge grows with the
    peak value: ``int(min_crop_size * (1 + value))``, capped by the image
    size. If fewer than ``num_crops`` peaks exist, the remainder is filled
    with randomly placed ``min_crop_size`` windows (``random`` module).

    Args:
        image: Array of shape ``(H, W[, C])``.
        heatmap: 2-D array of activations; may be coarser than ``image``.
        num_crops: Number of crops to return.
        min_crop_size: Edge length of every returned crop.

    Returns:
        List of ``num_crops`` arrays, each ``min_crop_size`` x ``min_crop_size``
        (windows of another size are resized with ``cv2.resize``).
    """
    img_h, img_w = image.shape[:2]
    map_h, map_w = heatmap.shape
    out_shape = (min_crop_size, min_crop_size)

    def _fit(patch):
        # Guarantee the advertised output size for every crop.
        if patch.shape[:2] != out_shape:
            patch = cv2.resize(patch, out_shape)
        return patch

    # Largest value among each cell's neighbours, *excluding* the cell itself.
    # (Comparing a cell with a window that contains it using ">" can never succeed.)
    values = np.asarray(heatmap, dtype=np.float64)
    padded = np.pad(values, 1, mode='constant', constant_values=-np.inf)
    neighbour_max = np.full(values.shape, -np.inf)
    for dy in range(3):
        for dx in range(3):
            if dy == 1 and dx == 1:
                continue
            np.maximum(neighbour_max, padded[dy:dy + map_h, dx:dx + map_w], out=neighbour_max)

    is_peak = (values > 0) & (values >= neighbour_max)
    peak_rows, peak_cols = np.nonzero(is_peak)
    intensities = values[peak_rows, peak_cols]
    strongest_first = np.argsort(-intensities, kind='stable')

    crops = []
    for idx in strongest_first[:num_crops]:
        # Map the centre of the heatmap cell to image pixel coordinates.
        center_y = int((peak_rows[idx] + 0.5) * img_h / map_h)
        center_x = int((peak_cols[idx] + 0.5) * img_w / map_w)

        crop_size = int(min_crop_size + (min_crop_size * intensities[idx]))
        crop_size = min(crop_size, img_h, img_w)

        # Centre the window on the peak, then slide it back inside the image.
        y1 = min(max(0, center_y - crop_size // 2), img_h - crop_size)
        x1 = min(max(0, center_x - crop_size // 2), img_w - crop_size)

        crops.append(_fit(image[y1:y1 + crop_size, x1:x1 + crop_size]))

    # Top up with random windows; max(0, ...) keeps randint valid for small images.
    while len(crops) < num_crops:
        y = random.randint(0, max(0, img_h - min_crop_size))
        x = random.randint(0, max(0, img_w - min_crop_size))
        crops.append(_fit(image[y:y + min_crop_size, x:x + min_crop_size]))

    return crops

def process_directory(model, input_dir, output_dir):
    """Write eight 224 x 224 activation-guided crops for each readable image.

    Every path matching ``*.*`` in ``input_dir`` is tried; files for which
    generate_heatmap returns ``None`` are skipped. Crops are saved to
    ``output_dir`` as ``<image stem>_crop_<j>.jpg``.
    """
    for img_path in glob.glob(os.path.join(input_dir, '*.*')):
        img_name = os.path.basename(img_path)
        print(f"Processing {img_name}")

        original, heatmap = generate_heatmap(model, img_path)
        if original is not None and heatmap is not None:
            patches = create_focused_crops(original, heatmap, num_crops=8, min_crop_size=224)

            for j, patch in enumerate(patches):
                # Name each crop after the source file without its extension.
                stem = os.path.splitext(img_name)[0]
                destination = os.path.join(output_dir, f"{stem}_crop_{j}.jpg")
                cv2.imwrite(destination, patch)

def main():
    """Load the checkpoint at ``model_path`` and write crops for ``train_dir``."""
    net = load_model(model_path)
    process_directory(net, train_dir, crops_output_dir)
    print("High activation crop generation completed.")


# Run the pipeline when this cell/script is executed directly.
if __name__ == "__main__":
    main()