ResNetAvgPool1024Extractor

In [None]:
import torch
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import numpy as np

class ResNetAvgPool1024Extractor:
    """Extract a compact descriptor from the ``avgpool`` layer of ResNet-50.

    A forward hook stores the output of ``model.avgpool`` on every forward
    pass. ``extract_feature`` then averages adjacent pairs of that vector,
    halving its length (2048 -> 1024 for the avgpool output used here).
    """

    def __init__(self):
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.model.eval()  # evaluation mode; the model is used for inference only
        self.feature = None  # most recent vector captured by the avgpool hook

        # Hook on the avgpool layer
        def hook_avgpool(_module, _inputs, output):
            # .cpu() is a no-op for CPU tensors and keeps .numpy() valid if
            # the model has been moved to another device.
            self.feature = output.squeeze().detach().cpu().numpy()

        self.model.avgpool.register_forward_hook(hook_avgpool)

        # Preprocessing: resize to 224x224, convert to tensor, then normalize
        # each channel with the mean/std values below.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _pairwise_mean(features):
        """Halve a 1-D vector by averaging elements (0, 1), (2, 3), ..."""
        values = np.asarray(features)
        return 0.5 * (values[::2] + values[1::2])

    def extract_feature(self, image_path):
        """Return the pair-averaged avgpool feature vector of one image.

        Args:
            image_path: Path of an image file that PIL can open.

        Returns:
            1-D NumPy array with half as many entries as the avgpool output
            (1024 values for the 2048-dim avgpool vector).

        Raises:
            RuntimeError: If the avgpool hook did not fire during the
                forward pass, so no feature was captured.
        """
        # Close the file handle as soon as the pixels are decoded; convert()
        # returns an independent in-memory image.
        with Image.open(image_path) as img:
            rgb_image = img.convert('RGB')

        # Run the input on whatever device the model currently lives on.
        device = next(self.model.parameters()).device
        input_tensor = self.transform(rgb_image).unsqueeze(0).to(device)

        # Reset first so a stale vector from a previous call is never returned.
        self.feature = None
        with torch.no_grad():
            self.model(input_tensor)

        if self.feature is None:
            raise RuntimeError("avgpool hook did not capture any feature")

        # self.feature is 2048-dim; reduce it to 1024-dim
        return self._pairwise_mean(self.feature)


# === USAGE EXAMPLE ===
# Extracts the feature vector of a single image and prints its first values.
if __name__ == "__main__":
    image_path = "Part1/brain_glioma/brain_glioma_0001.jpg"  # Replace with your own path
    extractor = ResNetAvgPool1024Extractor()
    features = extractor.extract_feature(image_path)

    print("Feature vector (1024 dim):")
    print(features[:10], "...")  # print the first 10 values

Task 2: batch feature extraction to CSV

In [None]:
import torch
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import numpy as np
import os
import pandas as pd

class ResNetAvgPool1024Extractor:
    """Extract a compact descriptor from the ``avgpool`` layer of ResNet-50.

    A forward hook stores the output of ``model.avgpool`` on every forward
    pass. ``extract_feature`` then averages adjacent pairs of that vector,
    halving its length (2048 -> 1024 for the avgpool output used here).
    """

    def __init__(self):
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.model.eval()  # evaluation mode; the model is used for inference only
        self.feature = None  # most recent vector captured by the avgpool hook

        # Hook on the avgpool layer
        def hook_avgpool(_module, _inputs, output):
            # .cpu() is a no-op for CPU tensors and keeps .numpy() valid if
            # the model has been moved to another device.
            self.feature = output.squeeze().detach().cpu().numpy()

        self.model.avgpool.register_forward_hook(hook_avgpool)

        # Preprocessing: resize to 224x224, convert to tensor, then normalize
        # each channel with the mean/std values below.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _pairwise_mean(features):
        """Halve a 1-D vector by averaging elements (0, 1), (2, 3), ..."""
        values = np.asarray(features)
        return 0.5 * (values[::2] + values[1::2])

    def extract_feature(self, image_path):
        """Return the pair-averaged avgpool feature vector of one image.

        Args:
            image_path: Path of an image file that PIL can open.

        Returns:
            1-D NumPy array with half as many entries as the avgpool output
            (1024 values for the 2048-dim avgpool vector).

        Raises:
            RuntimeError: If the avgpool hook did not fire during the
                forward pass, so no feature was captured.
        """
        # Close the file handle as soon as the pixels are decoded; convert()
        # returns an independent in-memory image.
        with Image.open(image_path) as img:
            rgb_image = img.convert('RGB')

        # Run the input on whatever device the model currently lives on.
        device = next(self.model.parameters()).device
        input_tensor = self.transform(rgb_image).unsqueeze(0).to(device)

        # Reset first so a stale vector from a previous call is never returned.
        self.feature = None
        with torch.no_grad():
            self.model(input_tensor)

        if self.feature is None:
            raise RuntimeError("avgpool hook did not capture any feature")

        # self.feature is 2048-dim; reduce it to 1024-dim
        return self._pairwise_mean(self.feature)


def process_folders(base_folder, subfolders):
    """Extract features for every image found in the given label folders.

    Args:
        base_folder: Directory that contains one sub-directory per label.
        subfolders: Names of the sub-directories to scan; each name is also
            stored as the ``label`` of the images found inside it.

    Returns:
        pandas.DataFrame with one row per successfully processed image:
        ``filename``, ``label`` and one column per feature value
        (``f0``, ``f1``, ...). Images whose processing raises are reported
        on stdout and skipped.
    """
    # Hoisted out of the loops; '.tiff' is accepted alongside '.tif'.
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

    extractor = ResNetAvgPool1024Extractor()
    data = []

    for label in subfolders:
        folder_path = os.path.join(base_folder, label)
        print(f"[INFO] Elaboro cartella: {folder_path}")

        # os.listdir() returns entries in arbitrary order; sorting makes the
        # row order of the resulting table reproducible across runs.
        for filename in sorted(os.listdir(folder_path)):
            if not filename.lower().endswith(image_extensions):
                continue

            image_path = os.path.join(folder_path, filename)
            try:
                features = extractor.extract_feature(image_path)
                entry = {"filename": filename, "label": label}
                entry.update(
                    (f"f{i}", value) for i, value in enumerate(features)
                )
                data.append(entry)
            except Exception as e:
                # Deliberate best-effort: one unreadable image must not
                # abort the whole batch.
                print(f"[ERRORE] Immagine saltata {filename}: {e}")

    return pd.DataFrame(data)


# === EXECUTION ===
# Extracts features for all images in the listed label folders and writes
# them to a CSV file in the current working directory.
if __name__ == "__main__":
    base_folder = "Part1"
    subfolders = ["brain_glioma", "brain_menin", "brain_tumor"]

    df = process_folders(base_folder, subfolders)
    df.to_csv("resnet1024_features.csv", index=False)  # index=False: do not write the row index column

    print(f"[FINE] Salvato CSV con {len(df)} righe in 'resnet1024_features.csv'")
