### **Train the ResNet18 classifier and extract features**

In [None]:
import os
import sys
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.optim import Adam
import numpy as np
from torch.utils.data import DataLoader
from PIL import Image, ImageDraw, ImageOps
from lxml import etree
from torchvision import transforms

os.add_dll_directory(
    r"C:\Program Files\OpenSlide\openslide-bin-4.0.0.8-windows-x64\bin"
)
import openslide
from models.resnet import ResNet18Classifier, ResNet18FeatureExtractor
from datasets.patch_dataset import PatchDataset
from PIL import Image

def parse_xml_mask(xml_path, level_dims, downsample):
    """
    Convert an XML annotation file to a binary mask.

    Every ``Coordinates`` element found under an ``Annotation`` element is
    treated as one polygon: its vertices are divided by ``downsample`` and the
    polygon is filled with 255 on a black (0) background.

    Parameters:
    - xml_path: str, path to the XML file containing annotations.
    - level_dims: tuple, dimensions of the WSI at the specified level (width, height).
    - downsample: float, downsample factor for the WSI level.

    Returns:
    - PIL image in mode "L" with size ``level_dims``, or None if the XML file
      has a syntax error.
    """
    try:
        tree = etree.parse(xml_path)
    except etree.XMLSyntaxError as e:
        print(f"Error parsing XML file {xml_path}: {e}")
        return None

    mask = Image.new("L", level_dims, 0)
    draw = ImageDraw.Draw(mask)

    # The second XPath alternative is a subset of the first; an XPath union
    # returns each node once, so no polygon is drawn twice.
    for coordinates_node in tree.xpath("//Annotation/Coordinates | //Annotations/Annotation/Coordinates"):
        coords = []
        for coord_node in coordinates_node.findall("Coordinate"):
            try:
                x = float(coord_node.get("X"))
                y = float(coord_node.get("Y"))
            except (ValueError, TypeError) as e:
                # Missing attribute (None) or non-numeric text: drop this vertex only.
                print(f"Warning: Could not parse coordinate (X,Y) from XML for {xml_path}: {e}")
                continue
            # Scale coordinates to the target level
            coords.append((int(x / downsample), int(y / downsample)))

        if not coords:
            continue
        if len(coords) < 2:
            # Pillow rejects coordinate lists with fewer than two points; skip
            # the degenerate annotation instead of losing the whole mask.
            print(f"Warning: Skipping annotation with fewer than 2 valid points in {xml_path}")
            continue

        # Draw with 255 for white on a black background
        draw.polygon(coords, outline=255, fill=255)
    return mask


def train_resnet_classifier(level=3):
    """
    Train a ResNet18 classifier on the extracted patches.

    Patches are read from ``../data/camelyon16/patches/level_<level>``
    (relative to the current working directory) and the trained weights are
    written to ``models/resnet18_patch_classifier.pth``.

    Parameters:
    - level: int, WSI level whose patch directory is used for training.

    Returns:
    - None. Returns early (after printing an error) when there are no patches.
    """
    print("[INFO] Training ResNet18 classifier on extracted patches...")
    patch_dir = os.path.join(os.getcwd(), "..", "data", "camelyon16", "patches", f"level_{level}")

    # Fail with a clear message instead of an opaque error from the dataset
    # or the shuffling sampler when patch extraction has not been run yet.
    if not os.path.exists(patch_dir) or not os.listdir(patch_dir):
        print(f"[ERROR] Patch directory '{patch_dir}' does not exist or is empty. Please run patch extraction first.")
        return

    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    dataset = PatchDataset(patch_dir, transform=transform)
    if len(dataset) == 0:
        # Also protects the accuracy computation below from dividing by zero.
        print(f"[ERROR] No patches found in '{patch_dir}'. Please run patch extraction first.")
        return
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet18Classifier().to(device)

    optimizer = Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    num_epochs = 5
    for epoch in range(num_epochs):
        model.train()
        total_loss, correct = 0, 0
        # The dataset yields a third item per sample that is not needed for training.
        for imgs, labels, _ in loader:
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()

        acc = correct / len(dataset)
        # NOTE: the reported loss is the sum of the per-batch losses, not a mean.
        print(
            f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}, Accuracy: {acc:.4f}"
        )

    # Make sure the output directory exists so the training run is not lost at save time.
    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), "models/resnet18_patch_classifier.pth")
    print("[INFO] ResNet18 classifier training complete and saved.")


def _render_annotation_mask(xml_path, file, level_dims, downsample, pad_w, pad_h):
    """Render one slide's XML annotation as a mask padded on the right/bottom; None if unavailable."""
    if not os.path.exists(xml_path):
        print(f"[INFO] No annotation found for {file}, treating as normal.")
        return None

    mask = None
    try:
        mask = parse_xml_mask(xml_path, level_dims, downsample)
        # parse_xml_mask returns None on a syntax error; only pad a real mask.
        if mask is not None and (pad_w > 0 or pad_h > 0):
            mask = ImageOps.expand(mask, (0, 0, pad_w, pad_h), fill=0)
    except Exception as e:
        # Best effort: an unusable annotation must not stop patch extraction.
        print(f"[WARNING] Failed to parse XML for {file}: {e}")
    return mask


def _save_slide_patches(slide, mask, file, prefix, patch_save_dir, level, patch_size, stride):
    """Tile one slide at ``level``, save the non-empty patches and return how many were saved."""
    downsample = slide.level_downsamples[level]
    width, height = slide.level_dimensions[level]

    patch_count = 0
    # Only top-left corners inside the image are visited; border patches are
    # completed with white below.
    for x in range(0, width, stride):
        for y in range(0, height, stride):
            patch_w = min(patch_size, width - x)
            patch_h = min(patch_size, height - y)

            # read_region takes the location in level-0 coordinates.
            region = slide.read_region(
                (int(x * downsample), int(y * downsample)),
                level,
                (patch_w, patch_h),
            ).convert("RGB")

            # If patch is smaller than patch_size (at border), pad it to patch_size
            if patch_w < patch_size or patch_h < patch_size:
                padded_region = Image.new("RGB", (patch_size, patch_size), (255, 255, 255))
                padded_region.paste(region, (0, 0))
                region = padded_region

            if np.mean(np.array(region)) > 240:  # too white (empty tissue)
                continue

            label = "normal"
            # A patch is "tumor" if it overlaps any positive pixel of the annotation mask.
            if mask is not None:
                mask_patch = mask.crop((x, y, x + patch_size, y + patch_size))
                if np.any(np.array(mask_patch) > 0):
                    label = "tumor"

            patch_name = f"{prefix}_x{x}_y{y}_{label}.png"
            region.save(os.path.join(patch_save_dir, patch_name))
            patch_count += 1
            if patch_count % 100 == 0:
                print(f"Extracted patches {patch_count} for {file}")
    return patch_count


def extract_patches(patch_size=224, level=3, stride=None, pad=True):
    """
    Extract patches from WSIs at a specified level, apply mask overlays, and save tumor vs normal labels.
    Only extracts patches if they have not already been extracted for a given image.

    Slides are read from ``../data/camelyon16/train/img`` and patches are
    written to ``../data/camelyon16/patches/level_<level>/<slide name>/``
    (both relative to the current working directory). The label is encoded in
    the patch file name.

    Parameters:
    - patch_size: int, size of the patches to extract. Only used for levels
      other than 0-3; for those levels the size is fixed to 1792/896/448/224.
    - level: int, level of the WSI to extract patches from.
    - stride: int, stride for patch extraction. Defaults to the effective patch size.
    - pad: bool, if True, the annotation mask and the reported size are padded
      to a multiple of the patch size. Border patches are always completed
      with white up to the full patch size.
    """
    print(f"[INFO] Extracting patches at level {level}...")

    # Set patch size according to level; unknown levels keep the caller's value.
    patch_sizes = {0: 1792, 1: 896, 2: 448, 3: 224}
    patch_size = patch_sizes.get(level, patch_size)
    # Resolve the default stride only once the effective patch size is known,
    # otherwise levels 0-2 would be tiled with a 224 px stride.
    stride = stride or patch_size

    data_root = os.path.join(os.getcwd(), "..", "data", "camelyon16")
    wsi_dir = os.path.join(data_root, "train", "img")
    annot_dir_train = os.path.join(data_root, "train", "mask", "annotations")
    annot_dir_test = os.path.join(data_root, "test", "mask", "annotations")
    level_dir = os.path.join(data_root, "patches", f"level_{level}")
    os.makedirs(level_dir, exist_ok=True)

    for file in os.listdir(wsi_dir):
        if not file.endswith(".tif"):
            continue
        prefix = file[: -len(".tif")]

        # Check if patches for this image already exist
        patch_save_dir = os.path.join(level_dir, prefix)
        if os.path.exists(patch_save_dir) and len(os.listdir(patch_save_dir)) > 0:
            print(f"[INFO] Patches for {file} already extracted, skipping.")
            continue
        os.makedirs(patch_save_dir, exist_ok=True)

        wsi_path = os.path.join(wsi_dir, file)
        # Always bind xml_path for the current slide so that an unexpected
        # file name can never reuse the previous slide's annotation.
        annot_dir = annot_dir_test if file.startswith("test_") else annot_dir_train
        xml_path = os.path.join(annot_dir, prefix + ".xml")
        print(f"[DEBUG] Processing file: {wsi_path} with XML: {xml_path}")
        try:
            slide = openslide.OpenSlide(wsi_path)
        except Exception as e:
            print(f"[ERROR] Could not open {wsi_path}: {e}")
            continue

        try:
            if level >= len(slide.level_dimensions):
                print(f"[ERROR] {file} has no level {level}, skipping.")
                continue

            downsample = slide.level_downsamples[level]
            width, height = slide.level_dimensions[level]

            # Calculate padded size if needed
            pad_w = pad_h = 0
            if pad:
                pad_w = (patch_size - width % patch_size) % patch_size
                pad_h = (patch_size - height % patch_size) % patch_size
            padded_width = width + pad_w
            padded_height = height + pad_h

            # Load and render XML mask
            mask = _render_annotation_mask(
                xml_path, file, (width, height), downsample, pad_w, pad_h
            )

            print(f"[INFO] Processing {file} at level {level} (size: {width}x{height}, padded: {padded_width}x{padded_height})")

            patch_count = _save_slide_patches(
                slide, mask, file, prefix, patch_save_dir, level, patch_size, stride
            )
        finally:
            # Release the slide handle even when a slide is skipped or fails.
            slide.close()

        print(
            f"[INFO] Patch extraction complete for {file} at level {level}. Total patches: {patch_count}"
        )


def extract_features(level=3, model_path="resnet18_patch_classifier.pth"):
    """
    Extract features from the patches using a ResNet18 model.

    The classifier weights are loaded from ``models/<model_path>`` (relative
    to the current working directory) when that file exists, and every
    parameter whose name also exists in the feature extractor (except the
    ``model.fc`` head) is copied into it. Results are written to the current
    working directory as ``patch_features_<level>.npy``,
    ``patch_labels_<level>.npy`` and ``patch_paths_<level>.txt``.

    Parameters:
    - level: int, WSI level to extract patches from (0, 1, 2, 3).
    - model_path: str, path to the pre-trained ResNet18 model.

    Returns:
    - None. Returns early (after printing an error) when there are no patches
      or no features could be extracted.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    model_path = os.path.join(os.getcwd(), "models", model_path)
    patch_dir = os.path.join(
        os.getcwd(), "..", "data", "camelyon16", "patches", f"level_{level}"
    )

    if not os.path.exists(patch_dir) or not os.listdir(patch_dir):
        print(f"[ERROR] Patch directory '{patch_dir}' does not exist or is empty. Please run patch extraction first.")
        return

    dataset = PatchDataset(patch_dir, transform=transform)
    # Use higher batch size and num_workers for feature extraction as it's typically I/O bound
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=128, shuffle=False, num_workers=os.cpu_count() or 1
    )

    print(
        f"[INFO] Extracting features from patches at level {level} with patch directory: {patch_dir}, which exists: {os.path.exists(patch_dir)}"
    )
    # patch_dir is known to exist and be non-empty at this point.
    print(
        "[INFO] Listing first 5 subdirectories in patch_dir:",
        os.listdir(patch_dir)[:5],
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    full_classifier_model = ResNet18Classifier().to(device)
    if os.path.exists(model_path):
        print(f"[INFO] Loading trained classifier weights from {model_path}")
        full_classifier_model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        print(f"[WARNING] Trained classifier model not found at {model_path}. "
              "Extracting features with ImageNet pre-trained weights only. "
              "Consider running `train_resnet_classifier()` first.")

    # Build the feature extractor once and copy over every classifier weight
    # it shares by name, leaving out the classification head ('model.fc').
    model = ResNet18FeatureExtractor().to(device)
    model_dict = model.state_dict()
    pretrained_dict = {
        k: v
        for k, v in full_classifier_model.state_dict().items()
        if k in model_dict and not k.startswith('model.fc')
    }
    if not pretrained_dict:
        # Without this the mismatch would go unnoticed and the features would
        # silently come from the extractor's own initial weights.
        print("[WARNING] No classifier weights matched the feature extractor's parameter names; "
              "the feature extractor keeps its initial weights.")

    # copy params from pretrained_dict to model_dict
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    model.eval()

    features = []
    labels = []
    paths = []

    with torch.no_grad():
        for imgs, lbls, img_paths in tqdm(loader, desc="Extracting Features"):
            feats = model(imgs.to(device))
            features.append(feats.cpu())
            labels.extend(lbls.tolist())  # Convert tensor to list for extend
            paths.extend(img_paths)

    if not features:
        print(
            "[ERROR] No features were extracted. Check your patch directory and dataset. "
            "It might be that PatchDataset found no images, or data loader was empty."
        )
        return

    # Presumably (num_patches, 512) for a ResNet18 backbone; confirm against
    # ResNet18FeatureExtractor.
    features = torch.cat(features, dim=0)

    # Save features, labels, and paths
    features_save_path = f"patch_features_{level}.npy"
    labels_save_path = f"patch_labels_{level}.npy"
    paths_save_path = f"patch_paths_{level}.txt"

    np.save(features_save_path, features.numpy())
    np.save(labels_save_path, np.array(labels))
    with open(paths_save_path, "w") as f:
        f.writelines(f"{p}\n" for p in paths)
    print(f"[INFO] Features saved to {features_save_path}, labels to {labels_save_path}, paths to {paths_save_path}")



In [3]:
# Pipeline driver. The steps depend on each other and must run in this order:
#   1. extract_patches         -> writes the labelled patches for the level
#   2. train_resnet_classifier -> trains on those patches and saves the weights
#   3. extract_features        -> loads the weights and saves the features
# Steps 1 and 2 are commented out here, presumably because their outputs
# already exist from an earlier run; uncomment them to rebuild from scratch.
# extract_patches(patch_size=224, level=3, stride=224, pad=True)
# train_resnet_classifier(level=3)
extract_features(level=3, model_path="resnet18_patch_classifier.pth")

[INFO] Extracting features from patches at level 3 with patch directory: c:\Users\anaca\Documents\sexto.curso\tfg info\fresh-clone\ss25_Hierarchical_Multiscale_Image_Classification\src\..\data\camelyon16\patches\level_3, which exists: True
[INFO] Listing first 5 subdirectories in patch_dir: ['normal_001', 'tumor_001']
[INFO] Loading trained classifier weights from c:\Users\anaca\Documents\sexto.curso\tfg info\fresh-clone\ss25_Hierarchical_Multiscale_Image_Classification\src\models\resnet18_patch_classifier.pth


Extracting Features: 100%|██████████| 35/35 [01:38<00:00,  2.82s/it]

[INFO] Features saved to patch_features_3.npy, labels to patch_labels_3.npy, paths to patch_paths_3.txt



