In [2]:
import os
import sys
import torch
import cv2
import numpy as np
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import Dataset, DataLoader, random_split, Subset

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment configuration read by the data loading, CAM and segmentation steps.
config = {
  "data_folder": "data",
  "train_batch_size": 64,
  "val_batch_size": 32,
  "test_batch_size": 64,
  "val_split": 0.2,
  "image_size": 256,
  "classifier": "ResNet50",
  "target_id": "class_id",
  "n_classes": 37,
  "classifier_save_path": "src/models/classifier_ResNet50.pth",
  "train_classifier": False,
  "n_epochs_cls": 5,
  "learning_rate_cls": 0.001,
  "cam_model" : "CAM",
  "cam_threshold": 0.5,
  "cam_visualisation_save_path": "results/cam/cam_examples.png",
  "segmentation_model": "DeepLabV3",
  "weakseg_model_save_path": "src/models/ws_cam_DeepLabV3.pth",
  "train_segmentation": True,
  "n_epochs_seg": 5,
  "learning_rate_seg": 0.0001,
  "segmentation_metrics_save_path": "results/cam/segmentation_metrics.csv",
  "segmentation_visualisation_save_path": "results/cam/segmentation_examples.png"
}

# BASE_DIR is the parent of the current working directory; the project's
# importable modules live in BASE_DIR/src.
BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), ".."))
SRC_PATH = os.path.join(BASE_DIR, "src")
# Only append once: re-running the cell would otherwise grow sys.path with duplicates.
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

from dataset import data_loading, inverse_normalize, load_data_wrapper, PseudoMaskDataset
from cam_utils import CAMGenerator, generate_grad_cam_common, get_cam_generator
from classification import ResNet50, ResNet101, DenseNet121, select_classifier

FILE_PATH = os.path.join(BASE_DIR, "hugo", "data")
split = "trainval"
# Check the same split file that is announced (previously test.txt was checked
# regardless of `split`).
split_file = os.path.join(FILE_PATH, "annotations", f"{split}.txt")
print(f"Looking for {split}.txt in:", split_file)
print("Does the file exists?", os.path.exists(split_file))

def dice_score(output, gt_mask, threshold=0.5, eps=1e-7):
    """Return the Dice coefficient of a thresholded prediction against a mask.

    Args:
        output: prediction as a numpy array or torch tensor; binarised with
            ``output > threshold``.
        gt_mask: ground-truth mask as a numpy array or torch tensor; it is only
            cast to float, not thresholded, so it presumably holds 0/1 values
            (TODO confirm for the dataset's masks).
        threshold: binarisation threshold applied to ``output``.
        eps: smoothing term added to numerator and denominator, so two empty
            masks score 1.0.

    Returns:
        The Dice score as a Python float.
    """
    pred = output if isinstance(output, torch.Tensor) else torch.from_numpy(output)
    target = gt_mask if isinstance(gt_mask, torch.Tensor) else torch.from_numpy(gt_mask)
    # 2-D inputs are lifted to (1, 1, H, W); the sums below are shape-agnostic.
    if pred.ndim == 2:
        pred = pred[None, None]
    if target.ndim == 2:
        target = target[None, None]
    pred_bin = (pred > threshold).float()
    target_f = target.float()
    overlap = (pred_bin * target_f).sum()
    denominator = pred_bin.sum() + target_f.sum()
    return ((2 * overlap + eps) / (denominator + eps)).item()

def iou_score(output, gt_mask, threshold=0.5, eps=1e-7):
    """Return intersection-over-union of a thresholded prediction and a mask.

    ``output`` is binarised with ``output > threshold``; ``gt_mask`` is only
    cast to float. Either argument may be a numpy array or a torch tensor.
    ``eps`` smooths the ratio so two empty masks give 1.0. Returns a float.
    """
    def to_tensor(value):
        if isinstance(value, torch.Tensor):
            return value
        return torch.from_numpy(value)

    pred = to_tensor(output)
    target = to_tensor(gt_mask)
    if pred.ndim == 2:
        pred = pred.unsqueeze(0).unsqueeze(0)
    if target.ndim == 2:
        target = target.unsqueeze(0).unsqueeze(0)

    pred_bin = (pred > threshold).float()
    target_f = target.float()
    inter = (pred_bin * target_f).sum()
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    union_size = pred_bin.sum() + target_f.sum() - inter
    ratio = (inter + eps) / (union_size + eps)
    return ratio.item()

def pixel_accuracy(output, gt_mask, threshold=0.5):
    """Return the fraction of pixels where the thresholded prediction equals the mask.

    ``output`` is binarised with ``output > threshold`` and compared for exact
    equality with ``gt_mask`` cast to float, so a mask with values other than
    0/1 would never match. The count of matches is divided by the number of
    elements of ``gt_mask``. Accepts numpy arrays or torch tensors; returns a float.
    """
    pred = torch.from_numpy(output) if not isinstance(output, torch.Tensor) else output
    target = torch.from_numpy(gt_mask) if not isinstance(gt_mask, torch.Tensor) else gt_mask
    pred = pred[None, None] if pred.ndim == 2 else pred
    target = target[None, None] if target.ndim == 2 else target

    target_f = target.float()
    matches = ((pred > threshold).float() == target_f).float()
    return (matches.sum() / target_f.numel()).item()

def precision_recall(output, gt_mask, threshold=0.5, eps=1e-7):
    """Return ``(precision, recall)`` of a thresholded prediction against a mask.

    precision = TP / predicted positives, recall = TP / actual positives, each
    smoothed with ``eps`` in numerator and denominator. ``output`` is binarised
    with ``output > threshold``; ``gt_mask`` is cast to float. Accepts numpy
    arrays or torch tensors and returns a tuple of two Python floats.
    """
    tensors = []
    for value in (output, gt_mask):
        if not isinstance(value, torch.Tensor):
            value = torch.from_numpy(value)
        if value.ndim == 2:
            value = value.unsqueeze(0).unsqueeze(0)
        tensors.append(value)
    pred, target = tensors

    pred_bin = (pred > threshold).float()
    target_f = target.float()
    true_pos = (pred_bin * target_f).sum()
    precision = (true_pos + eps) / (pred_bin.sum() + eps)
    recall = (true_pos + eps) / (target_f.sum() + eps)
    return precision.item(), recall.item()

def eval_pseudo_masks_dice(pseudo_masks, masks):
    """Return the SUM (not the mean) of Dice scores over paired masks.

    Each ground-truth mask is ``.squeeze()``d before scoring; pairing uses
    ``zip``, so the shorter sequence bounds the count. Empty input returns 0.
    """
    running_total = 0
    for pseudo, truth in zip(pseudo_masks, masks):
        running_total += dice_score(pseudo, truth.squeeze())
    return running_total

def eval_pseudo_masks_pixel_accuracy(pseudo_masks, masks):
    """Return the SUM of pixel accuracies over paired (pseudo, ground-truth) masks.

    Ground-truth masks are squeezed first; ``zip`` truncates to the shorter input.
    """
    running_total = 0
    for pseudo, truth in zip(pseudo_masks, masks):
        running_total += pixel_accuracy(pseudo, truth.squeeze())
    return running_total

def eval_pseudo_masks_iou(pseudo_masks, masks):
    """Return the SUM of IoU scores over paired (pseudo, ground-truth) masks.

    Ground-truth masks are squeezed first; ``zip`` truncates to the shorter input.
    """
    running_total = 0
    for pseudo, truth in zip(pseudo_masks, masks):
        running_total += iou_score(pseudo, truth.squeeze())
    return running_total

def eval_pseudo_masks_precision_recall(pseudo_masks, masks):
    """Return ``(sum of precisions, sum of recalls)`` over paired masks.

    Ground-truth masks are squeezed first; ``zip`` truncates to the shorter
    input. Empty input returns ``(0, 0)``.
    """
    precision_total = 0
    recall_total = 0
    for pseudo, truth in zip(pseudo_masks, masks):
        prec, rec = precision_recall(pseudo, truth.squeeze())
        precision_total += prec
        recall_total += rec
    return precision_total, recall_total

def evaluate(pseudo_masks, mask_batch, batch_size=None):
    """Return per-image mean (dice, pixel accuracy, IoU, precision, recall) for a batch.

    The summed scores from the ``eval_pseudo_masks_*`` helpers are divided by
    the number of (pseudo, ground-truth) pairs actually scored. Dividing by a
    fixed ``batch_size`` under-reports the mean whenever a batch (typically
    the last one) is smaller than that size.

    Args:
        pseudo_masks: sequence of predicted masks.
        mask_batch: sequence (or batched tensor) of ground-truth masks.
        batch_size: kept for backward compatibility; no longer used as the
            divisor.

    Returns:
        Tuple ``(dice, accuracy, iou, precision, recall)`` of floats. All are
        0.0 when there is nothing to score.
    """
    # zip() in the helpers stops at the shorter input, so that is the true count.
    n_pairs = min(len(pseudo_masks), len(mask_batch))
    if n_pairs == 0:
        return 0.0, 0.0, 0.0, 0.0, 0.0
    dice = eval_pseudo_masks_dice(pseudo_masks, mask_batch) / n_pairs
    accuracy_score = eval_pseudo_masks_pixel_accuracy(pseudo_masks, mask_batch) / n_pairs
    iou = eval_pseudo_masks_iou(pseudo_masks, mask_batch) / n_pairs
    precision, recall = eval_pseudo_masks_precision_recall(pseudo_masks, mask_batch)
    return dice, accuracy_score, iou, precision / n_pairs, recall / n_pairs

def generate_pseudo_mask(bbox, image, variant="GrabCut"):
    """Turn a bounding box into a binary pseudo-mask for ``image``.

    Args:
        bbox: ``(x_min, y_min, x_max, y_max)`` pixel coordinates, or ``None``
            (``generate_bbox`` returns ``None`` when no CAM pixel passes its
            threshold). ``None`` yields an all-background mask.
        image: H x W x 3 float array. It is scaled by 255 and clipped to uint8
            before use (presumably RGB in [0, 1]; TODO confirm).
        variant: ``"GrabCut"`` runs OpenCV GrabCut seeded with the box.
            ``"Super-Pixel"`` keeps SLIC superpixels whose centroid lies inside
            the box; it needs ``cv2.ximgproc`` (opencv-contrib).

    Returns:
        H x W uint8 array with 255 for foreground and 0 for background.

    Raises:
        ValueError: if ``variant`` is not one of the supported names.
    """
    if variant not in ("GrabCut", "Super-Pixel"):
        # Previously an unknown variant fell through to an UnboundLocalError.
        raise ValueError(f"Unknown pseudo-mask variant: {variant!r}")
    if bbox is None:
        return np.zeros(image.shape[:2], dtype=np.uint8)

    x_min, y_min, x_max, y_max = bbox
    image_uint8 = (image * 255).clip(0, 255).astype(np.uint8)

    if variant == "GrabCut":
        grabcut_mask = np.zeros(image_uint8.shape[:2], np.uint8)
        bgd_model = np.zeros((1, 65), np.float64)
        fgd_model = np.zeros((1, 65), np.float64)
        # rect is (x, y, width, height); pixels outside it start as definite background.
        rect = (x_min, y_min, x_max - x_min, y_max - y_min)
        cv2.grabCut(image_uint8, grabcut_mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
        # Keep definite and probable foreground labels.
        foreground = (grabcut_mask == cv2.GC_FGD) | (grabcut_mask == cv2.GC_PR_FGD)
        return np.where(foreground, 255, 0).astype(np.uint8)

    # variant == "Super-Pixel"
    image_cv = cv2.cvtColor(image_uint8, cv2.COLOR_RGB2BGR)
    image_cv = cv2.GaussianBlur(image_cv, (5, 5), sigmaX=1, sigmaY=1)
    h, w = image_cv.shape[:2]
    # Region area ~ h*w/200, i.e. roughly 200 superpixels; max(1, ...) avoids
    # region_size == 0 on very small images.
    region_size = max(1, int(np.sqrt((h * w) / 200)))
    slic = cv2.ximgproc.createSuperpixelSLIC(image_cv, cv2.ximgproc.SLIC, region_size=region_size, ruler=10)
    slic.iterate(10)
    segments = slic.getLabels()

    # Centroids of all superpixels in one pass, instead of one full-image scan per label.
    labels = segments.ravel()
    rows, cols = np.indices(segments.shape)
    counts = np.bincount(labels)
    with np.errstate(divide="ignore", invalid="ignore"):
        # Label ids with no pixels get a NaN centroid, which never passes the box test.
        y_centers = np.bincount(labels, weights=rows.ravel()) / counts
        x_centers = np.bincount(labels, weights=cols.ravel()) / counts
        inside = (x_centers >= x_min) & (x_centers <= x_max) & (y_centers >= y_min) & (y_centers <= y_max)
    return np.where(inside[segments], 255, 0).astype(np.uint8)

def generate_pseudo_masks(bboxs, images, variant="GrabCut"):
    """Apply ``generate_pseudo_mask`` to every (bbox, image) pair and return a list.

    Each image is de-normalised with ``inverse_normalize``, permuted with
    ``(1, 2, 0)`` (assumed CHW -> HWC) and moved to a CPU numpy array first.
    ``zip`` truncates to the shorter input.
    """
    masks = []
    for box, img in zip(bboxs, images):
        img_hwc = inverse_normalize(img).permute(1, 2, 0).cpu().numpy()
        masks.append(generate_pseudo_mask(box, img_hwc, variant=variant))
    return masks

def generate_bbox(cam, image, threshold=0.5):
    """Return the tight bounding box of CAM pixels strictly above ``threshold``.

    Args:
        cam: 2-D array of CAM scores (``generate_cam`` min-max normalises its
            maps to [0, 1]).
        image: unused; kept so existing callers such as ``generate_bboxs``
            continue to work. The original computed a uint8 copy of it that
            was never used.
        threshold: pixels with ``cam > threshold`` are considered foreground.
            The default 0.5 matches the previously hard-coded value.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as Python ints with inclusive pixel
        coordinates, or ``None`` when no pixel exceeds ``threshold``.
    """
    coords = np.argwhere(np.asarray(cam) > threshold)  # rows of (y, x)
    if coords.size == 0:
        return None
    y_min, x_min = coords.min(axis=0)
    y_max, x_max = coords.max(axis=0)
    return int(x_min), int(y_min), int(x_max), int(y_max)

def generate_bboxs(cams, images):
    """Return ``generate_bbox`` results (tuple or ``None``) for each (cam, image) pair.

    Images are de-normalised with ``inverse_normalize``, permuted with
    ``(1, 2, 0)`` and converted to CPU numpy arrays before being passed on.
    """
    boxes = []
    for cam_map, img in zip(cams, images):
        img_hwc = inverse_normalize(img).permute(1, 2, 0).cpu().numpy()
        boxes.append(generate_bbox(cam_map, img_hwc))
    return boxes

# Lazily populated so the ImageNet ResNet-50 and its preprocessing pipeline are
# built once per process rather than once per image.
_CAM_CACHE = {}


def _get_cam_components():
    """Return ``(model, preprocess)`` used by ``generate_cam``, building them on first use."""
    if "model" not in _CAM_CACHE:
        if hasattr(models, "ResNet50_Weights"):
            # torchvision >= 0.13: IMAGENET1K_V1 is what the legacy pretrained=True loads.
            model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        else:
            model = models.resnet50(pretrained=True)
        model.eval()
        _CAM_CACHE["model"] = model
        _CAM_CACHE["preprocess"] = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    return _CAM_CACHE["model"], _CAM_CACHE["preprocess"]


def generate_cam(image):
    """Compute a Grad-CAM heat-map for ``image`` with an ImageNet ResNet-50.

    The map is taken at ``model.layer4[-1]`` for the model's own top-1 ImageNet
    class. Channels are weighted by their spatially averaged gradients, summed,
    passed through ReLU and min-max normalised to [0, 1] (all zeros if the map
    is constant). The result is bilinearly resized to the input's size.

    Args:
        image: H x W x 3 float array. It is scaled by 255 and clipped to uint8
            (presumably RGB in [0, 1]; TODO confirm).

    Returns:
        H x W float32 numpy array.
    """
    image_uint8 = (image * 255).clip(0, 255).astype(np.uint8)
    pil_img = Image.fromarray(image_uint8)
    model, preprocess = _get_cam_components()
    input_tensor = preprocess(pil_img).unsqueeze(0)

    captured = {}

    def forward_hook(module, inputs, output):
        # Keep the non-detached output so we can differentiate with respect to it.
        captured["activations"] = output

    handle = model.layer4[-1].register_forward_hook(forward_hook)
    try:
        logits = model(input_tensor)
    finally:
        # The model is cached, so hooks must not accumulate across calls.
        handle.remove()

    pred_class = logits.argmax(dim=1).item()
    activations = captured["activations"]
    # d(top-1 logit)/d(layer4 output). Unlike .backward(), this leaves parameter .grad untouched.
    (gradients,) = torch.autograd.grad(logits[0, pred_class], activations)

    weights = gradients.mean(dim=(2, 3), keepdim=True)  # one weight per channel
    cam = F.relu((weights * activations.detach()).sum(dim=1, keepdim=True))  # (1, 1, h, w)
    cam = cam - cam.min()
    cam_max = cam.max()
    if cam_max != 0:
        cam = cam / cam_max
    cam = F.interpolate(cam, size=(pil_img.height, pil_img.width), mode='bilinear', align_corners=False)
    return cam[0, 0].cpu().numpy()

def generate_cams(images):
    """Return one ``generate_cam`` heat-map per image, in input order.

    Each image is de-normalised with ``inverse_normalize``, permuted with
    ``(1, 2, 0)`` and converted to a CPU numpy array before CAM generation.
    """
    heatmaps = []
    for img in images:
        img_hwc = inverse_normalize(img).permute(1, 2, 0).cpu().numpy()
        heatmaps.append(generate_cam(img_hwc))
    return heatmaps

def visualise_results(image, bbox, mask, pseudo_mask, variant="GrabCut", save_path=None):
    """Build a side-by-side panel: image with box | ground-truth mask | pseudo-mask.

    Args:
        image: image tensor, de-normalised with ``inverse_normalize`` and
            permuted with ``(1, 2, 0)`` for display.
        bbox: ``(x_min, y_min, x_max, y_max)``, or ``None`` (as returned by
            ``generate_bbox``), in which case no box is drawn.
        mask: ground-truth mask tensor. It is squeezed and scaled by 255, so
            values are presumably in [0, 1] (TODO confirm).
        pseudo_mask: 2-D uint8 array such as ``generate_pseudo_mask`` returns.
        variant: name shown in the third panel's caption.
        save_path: if given, the composite is written to this path.

    Returns:
        The composite PIL image, three panels wide.
    """
    # Convert tensors to numpy arrays
    image_np = inverse_normalize(image).permute(1, 2, 0).cpu().numpy()
    mask_np = mask.squeeze().cpu().numpy()

    # Clip before casting: values outside [0, 255] are not preserved by a raw uint8 cast.
    img_pil = Image.fromarray((image_np * 255).clip(0, 255).astype(np.uint8))
    mask_pil = Image.fromarray((mask_np * 255).clip(0, 255).astype(np.uint8)).convert('L').convert('RGB')
    pseudo_mask_pil = Image.fromarray(pseudo_mask).convert('L').convert('RGB')

    composite = Image.new('RGB', (img_pil.width * 3, img_pil.height))

    # Draw the bounding box only when one exists.
    if bbox is not None:
        x_min, y_min, x_max, y_max = bbox
        ImageDraw.Draw(img_pil).rectangle([x_min, y_min, x_max, y_max], outline='red', width=3)

    composite.paste(img_pil, (0, 0))
    composite.paste(mask_pil, (img_pil.width, 0))
    composite.paste(pseudo_mask_pil, (img_pil.width * 2, 0))

    draw = ImageDraw.Draw(composite)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # Font file not available on this system: fall back to PIL's built-in font.
        font = ImageFont.load_default()

    draw.text((10, 10), "Image with Bounding Box", fill='white', font=font)
    draw.text((img_pil.width + 10, 10), "Ground Truth Mask", fill='white', font=font)
    draw.text((img_pil.width * 2 + 10, 10), f"Pseudo Mask {variant}", fill='white', font=font)

    if save_path:
        composite.save(save_path)
    return composite

def save_pseudo_masks(pseudo_masks, batch_idx, output_dir):
    """Write each pseudo-mask to ``output_dir`` as a PNG.

    Files are named ``pseudo_mask_batch{batch_idx}_idx{i}.png``. Masks whose
    maximum is <= 1.0 are treated as 0/1 and scaled to 0/255. Any other mask
    is cast to uint8 unchanged. Prints a confirmation once all files are written.
    """
    for position, mask in enumerate(pseudo_masks):
        squeezed = np.squeeze(mask)
        pixels = (squeezed * 255).astype(np.uint8) if squeezed.max() <= 1.0 else squeezed.astype(np.uint8)
        target = os.path.join(output_dir, f"pseudo_mask_batch{batch_idx}_idx{position}.png")
        Image.fromarray(pixels).save(target)
    print('Images saved to file')


train_loader, val_loader, test_loader = load_data_wrapper(config=config)

# Generate CAM -> bbox -> pseudo-mask for every training batch, score the
# pseudo-masks against the ground truth and save them to disk.
pseudo_mask_variant = "GrabCut"
total_dice = 0.0
total_accuracy = 0.0
total_iou = 0.0
total_precision = 0.0
total_recall = 0.0
n_batches = len(train_loader)
n_images = 0
batch_size = config["train_batch_size"]
output_dir = "saved_pseudo_masks"
os.makedirs(output_dir, exist_ok=True)

for batch_idx, (image_batch, mask_batch, info_batch) in enumerate(train_loader, start=1):
    print(f"Batch number: {batch_idx}")
    cams = generate_cams(image_batch)
    bboxs = generate_bboxs(cams, image_batch)
    pseudo_masks = generate_pseudo_masks(bboxs, image_batch, variant=pseudo_mask_variant)
    # The last batch can be smaller than batch_size, so average over the real count.
    current_batch_size = len(pseudo_masks)
    dice, accuracy_score, iou, precision, recall = evaluate(pseudo_masks, mask_batch, current_batch_size)
    print(f'Dice score is {dice} || Pixel score is {accuracy_score} || IOU score is {iou} || Precision score is {precision} || Recall score is {recall}')
    # Accumulate per-image sums (batch mean x batch size) so the final
    # averages are per image even when batch sizes differ.
    total_dice += dice * current_batch_size
    total_accuracy += accuracy_score * current_batch_size
    total_iou += iou * current_batch_size
    total_precision += precision * current_batch_size
    total_recall += recall * current_batch_size
    n_images += current_batch_size
    # Save one example per batch (the result was previously computed and discarded);
    # skip it when the CAM produced no box.
    if bboxs[0] is not None:
        visualise_results(image_batch[0], bboxs[0], mask_batch[0], pseudo_masks[0],
                          variant=pseudo_mask_variant,
                          save_path=os.path.join(output_dir, f"visualisation_batch{batch_idx}.png"))
    save_pseudo_masks(pseudo_masks, batch_idx, output_dir)

n_scored = max(n_images, 1)  # avoid division by zero if no batch was processed
print(f'Average dice score per image is {total_dice / n_scored}')
print(f'Average accuracy score per image is {total_accuracy / n_scored}')
print(f'Average IOU score per image is {total_iou / n_scored}')
print(f'Average precision score per image is {total_precision / n_scored}')
print(f'Average recall score per image is {total_recall / n_scored}')

Looking for trainval.txt in: /Users/paulchainieux/Documents/UCL/ADL/WeakSegNet/hugo/data/annotations/trainval.txt
Does the file exists? True

----Loading data...
Loading split from: /Users/paulchainieux/Documents/UCL/ADL/WeakSegNet/hugo/data/annotations/trainval.txt
Loading split from: /Users/paulchainieux/Documents/UCL/ADL/WeakSegNet/hugo/data/annotations/trainval.txt
Loading split from: /Users/paulchainieux/Documents/UCL/ADL/WeakSegNet/hugo/data/annotations/test.txt
[Data loaded succesfully]

Training set: 2944 samples
Validation set: 736 samples
Test set: 3669 samples
Batch number: 1
Dice score is 0.505004857471083 || Pixel score is 0.7352609634399414 || IOU score is 0.3946094413756258 || Precision score is 0.8754364333226909 || Recall score is 0.4150360746486448
Images saved to file
Batch number: 2


KeyboardInterrupt: 