In [2]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torchvision.models import resnet50
import torch.nn.functional as F
import io
import requests
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CustomOxfordDataset(Dataset):
    def __init__(self, images_dir, csv_file, transform=None):
        self.images_dir = images_dir
        self.annotations = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_name = self.annotations.iloc[idx, 0]
        label = int(self.annotations.iloc[idx, 1])
        img_path = os.path.join(self.images_dir, img_name)
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label


# --- Path do Dataset ---
base_url = "https://github.com/Lucas-Junqueira/Pratica_IA/tree/main/oxford_subset"
csv_url = "https://raw.githubusercontent.com/Lucas-Junqueira/Pratica_IA/refs/heads/main/oxford_subset/labels.csv"

# --- Baixar CSV ---
labels_df = pd.read_csv(csv_url)

# --- Dataset em memória ---
class GitHubDataset(Dataset):
    def __init__(self, base_url, labels_df, transform=None):
        self.base_url = base_url
        self.labels_df = labels_df
        self.transform = transform

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        row = self.labels_df.iloc[idx]
        img_name = row["filename"]   # coluna do CSV com nome do arquivo
        label = row["label"]

        # Montar URL da imagem
        img_url = self.base_url + img_name

        # Baixar imagem direto da URL
        response = requests.get(img_url)
        image = Image.open(io.BytesIO(response.content)).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# --- Transformações ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# --- Criar dataset ---
dataset = GitHubDataset(base_url, labels_df, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# --- Modelo ResNet50 pré-treinado ---
model = resnet50(weights="IMAGENET1K_V1")
model.eval()

# --- Captura dos feature maps ---
feature_maps = {}
def get_activation(name):
    def hook(model, input, output):
        feature_maps[name] = output.detach()
    return hook

model.layer1.register_forward_hook(get_activation('layer1'))
model.layer2.register_forward_hook(get_activation('layer2'))
model.layer3.register_forward_hook(get_activation('layer3'))
model.layer4.register_forward_hook(get_activation('layer4'))

# --- Funções auxiliares ---
def show_image(img, ax=None):
    img = img.permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)
    if ax is None:
        plt.imshow(img)
        plt.axis("off")
    else:
        ax.imshow(img)
        ax.axis("off")

def show_feature_map(fmap, num_channels=4):
    fmap = fmap.squeeze(0)  # remove batch
    channels = torch.linspace(0, fmap.shape[0]-1, steps=num_channels).long()
    fig, axes = plt.subplots(1, num_channels, figsize=(12, 4))
    for i, ax in enumerate(axes):
        ax.imshow(fmap[channels[i]].cpu().numpy(), cmap="viridis")
        ax.axis("off")
    plt.show()

# --- Grad-CAM ---
gradients = {}
activations = {}
def forward_hook(module, input, output):
    activations["value"] = output.detach()
def backward_hook(module, grad_in, grad_out):
    gradients["value"] = grad_out[0].detach()

target_layer = model.layer4[-1]
target_layer.register_forward_hook(forward_hook)
target_layer.register_backward_hook(backward_hook)

def generate_gradcam(image_tensor, class_idx=None):
    output = model(image_tensor)
    if class_idx is None:
        class_idx = output.argmax(dim=1).item()
    model.zero_grad()
    loss = output[0, class_idx]
    loss.backward()
    grads = gradients["value"]
    acts = activations["value"]
    weights = grads.mean(dim=(2, 3), keepdim=True)
    cam = (weights * acts).sum(dim=1).squeeze()
    cam = F.relu(cam)
    cam -= cam.min()
    cam /= cam.max()
    return cam.cpu().numpy()

def show_gradcam_on_image(img_tensor, cam, ax, alpha=0.5):
    img = img_tensor.permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)
    cam_tensor = torch.tensor(cam).unsqueeze(0).unsqueeze(0)
    cam_resized = F.interpolate(cam_tensor, size=(img.shape[0], img.shape[1]), mode='bilinear', align_corners=False)
    cam_resized = cam_resized.squeeze().numpy()
    heatmap = plt.cm.jet(cam_resized)[..., :3]
    overlay = heatmap * alpha + img * (1 - alpha)
    overlay = np.clip(overlay, 0, 1)
    ax.imshow(overlay)
    ax.axis("off")

# --- Processar imagens do subset salvo ---
for idx, (image, label) in enumerate(dataloader):
    feature_maps.clear()
    _ = model(image)  # forward

    print(f"\n=== Imagem {idx+1} ===")

    cam = generate_gradcam(image)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    show_image(image[0], axes[0])
    axes[0].set_title("Imagem Original")
    show_gradcam_on_image(image[0], cam, axes[1])
    axes[1].set_title("Grad-CAM (layer4)")
    plt.show()

    for name, fmap in feature_maps.items():
        print(f"Feature map - {name}:")
        show_feature_map(fmap, num_channels=4)


UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x000001F2C06A20C0>