# IMPORTS

In [1]:
import sys
import os

# Restrict CUDA to device index 3 (machine-specific; adjust for the host in use).
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# Make the parent directories importable from the notebook's working directory.
# NOTE(review): sys.path only influences `import`. The relative 'data/...' paths
# passed to open()/torch.load()/torch.save() further down are still resolved
# against the current working directory (the recorded run of the loading cell
# failed with FileNotFoundError on such a path).
sys.path.append(os.path.abspath(".."))       # for 'protonet_mnist_add_utils' folder
sys.path.append(os.path.abspath("../.."))    # for 'data' folder
sys.path.append(os.path.abspath("../../..")) # for 'models' and 'datasets' folders

# Show the resulting search path so the import roots can be checked.
print(sys.path)

['/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation/kandinsky/notebooks', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python38.zip', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8/lib-dynload', '', '/users-1/eleonora/.local/lib/python3.8/site-packages', '/users-1/eleonora/anaconda3/envs/r4rr/lib/python3.8/site-packages', '/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation/kandinsky', '/users-1/eleonora/reasoning-shortcuts/IXShort/shortcut_mitigation', '/users-1/eleonora/reasoning-shortcuts/IXShort']


In [2]:
import cv2
import json
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch.nn.functional as F
import torchvision.transforms as T

from argparse import Namespace
from datasets import get_dataset
from datasets.utils.base_dataset import BaseDataset

# UTILS

*Fetches the bounding boxes of the initial prototypes for cropping the triplets*

In [3]:
def get_proto_imgs_bboxes(annotations_dir="data/kand_annotations/init/bboxes_init", num_images=3):
    """Load the bounding boxes of the initial prototype images.

    Reads ``image_{i}_annotations.json`` for ``i`` in ``range(num_images)`` from
    ``annotations_dir`` and collects the ``"boxes"`` entry of every file.

    Args:
        annotations_dir (str): Directory holding the annotation JSON files.
            Defaults to the location used throughout this notebook.
        num_images (int): Number of prototype images (annotation files) to read.

    Returns:
        Tensor: float32 tensor of shape (num_images, 3, 4): three boxes per
        image, four values per box (consumed as ``x, y, w, h`` by the plotting
        and cropping helpers in this notebook).

    Raises:
        FileNotFoundError: If an annotation file is missing.
        KeyError: If a file has no ``"boxes"`` entry.
        AssertionError: If the stacked boxes do not have the expected shape.
    """
    all_bboxes = []
    for i in range(num_images):
        filename = os.path.join(annotations_dir, f"image_{i}_annotations.json")
        with open(filename, "r") as f:
            all_bboxes.append(json.load(f)["boxes"])

    bbox_tensor = torch.tensor(all_bboxes, dtype=torch.float32)  # Shape: (num_images, 3, 4)
    assert bbox_tensor.shape == (num_images, 3, 4), bbox_tensor.shape
    return bbox_tensor

*Plots the initial prototypes with the bounding boxes*

In [4]:
def plot_prototypes(proto_imgs_shapes, proto_labels_shapes, bbox_tensor):
    """Draw the three initial prototype images side by side with their boxes.

    Args:
        proto_imgs_shapes (Tensor): Channel-first images; entries 0..2 are shown.
        proto_labels_shapes (Tensor): Labels; entry ``i`` is used as the title of panel ``i``.
        bbox_tensor (Tensor): Boxes per image, each box read as ``(x, y, w, h)``.
    """
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))  # one panel per prototype image

    for idx, panel in enumerate(axs):
        # imshow expects channel-last data.
        image_hwc = proto_imgs_shapes[idx].permute(1, 2, 0).cpu().numpy()
        image_boxes = bbox_tensor[idx].cpu().numpy()

        panel.imshow(image_hwc)
        for x, y, w, h in image_boxes:
            outline = patches.Rectangle((x, y), w, h, linewidth=2, edgecolor='red', facecolor='none')
            panel.add_patch(outline)
        panel.axis('off')
        panel.set_title(f"{proto_labels_shapes[idx].tolist()}", fontsize=8)

    plt.tight_layout()
    plt.show()

*Plots the atomic concept images augmented through anti aliasing*

In [5]:
def plot_augmented_results_with_antialiasing(original_tensor, augmented_tensor, augmented_labels, white_thresholds):
    """
    Plots the original images and the augmented images for each white threshold.

    The first row contains the original images.
    Each subsequent row contains the augmented images for one white threshold.

    Args:
        original_tensor (Tensor): Original images of shape (N, C, H, W).
        augmented_tensor (Tensor): Augmented images of shape (T*N, C, H, W), where T is the
                                     number of white_threshold values, stored threshold-major.
        augmented_labels (Tensor): Corresponding labels of shape (T*N, 2).
        white_thresholds (list or iterable): The list of white thresholds used for augmentation.
    """
    # Descriptive names avoid shadowing the module-level `T` (torchvision.transforms) alias.
    n_thresholds = len(white_thresholds)
    n_images = original_tensor.shape[0]
    total_rows = 1 + n_thresholds  # first row for originals, then one row per threshold

    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # row or a single column, so `axes[row, col]` is always valid.
    fig, axes = plt.subplots(total_rows, n_images,
                             figsize=(3 * n_images, 3 * total_rows), squeeze=False)

    # Plot the original images (first row)
    for j in range(n_images):
        ax = axes[0, j]
        ax.imshow(original_tensor[j].permute(1, 2, 0).cpu().numpy())
        ax.set_title("Original")
        ax.axis("off")

    # Regroup the flat batch as (T, N, ...). reshape also accepts non-contiguous
    # input, and the trailing dims come from the data instead of a fixed 3x64x64.
    augmented_tensor = augmented_tensor.reshape(n_thresholds, n_images, *augmented_tensor.shape[1:])
    augmented_labels = augmented_labels.reshape(n_thresholds, n_images, 2)

    # Plot each augmentation row
    for t in range(n_thresholds):
        for j in range(n_images):
            ax = axes[t + 1, j]
            img_aug = augmented_tensor[t, j].permute(1, 2, 0).cpu().numpy()
            label = augmented_labels[t, j].tolist()
            ax.imshow(img_aug)
            ax.set_title(f"Threshold {white_thresholds[t]}, Labels: {label}")
            ax.axis("off")

    plt.tight_layout()
    plt.show()

*Plots the atomic concept images augmented through scalings*

In [6]:
def plot_augmented_results_with_scaling(original_tensor, augmented_tensor, augmented_labels, scaling_factors):
    """
    Plots the original images and the augmented (scaled) images for each scaling factor.

    The first row contains the original images.
    Each subsequent row contains the augmented images for one scaling factor.

    Args:
        original_tensor (Tensor): Original images of shape (N, C, H, W).
        augmented_tensor (Tensor): Augmented images of shape (T*N, C, H, W), where T is the
                                     number of scaling factors, stored factor-major.
        augmented_labels (Tensor): Corresponding labels of shape (T*N, ?).
        scaling_factors (list): List of scaling factors used for augmentation
                                (only its length is used here).
    """
    # Descriptive names avoid shadowing the module-level `T` (torchvision.transforms) alias.
    n_factors = len(scaling_factors)
    n_images = original_tensor.shape[0]
    total_rows = 1 + n_factors  # first row for originals, then one row per factor

    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # row or a single column, so `axes[row, col]` is always valid.
    fig, axes = plt.subplots(total_rows, n_images,
                             figsize=(3 * n_images, 3 * total_rows), squeeze=False)

    # Plot the original images (first row)
    for j in range(n_images):
        ax = axes[0, j]
        ax.imshow(original_tensor[j].permute(1, 2, 0).cpu().numpy())
        ax.set_title("Original")
        ax.axis("off")

    # Regroup the flat batch as (T, N, ...). reshape also accepts non-contiguous
    # input, and the trailing dims come from the data instead of a fixed 3x64x64.
    augmented_tensor = augmented_tensor.reshape(n_factors, n_images, *augmented_tensor.shape[1:])
    augmented_labels = augmented_labels.reshape(n_factors, n_images, -1)

    # Plot each augmentation row
    for t in range(n_factors):
        for j in range(n_images):
            ax = axes[t + 1, j]
            img_aug = augmented_tensor[t, j].permute(1, 2, 0).cpu().numpy()
            label = augmented_labels[t, j].tolist()
            ax.imshow(img_aug)
            ax.set_title(f"Label: {label}")
            ax.axis("off")

    plt.tight_layout()
    plt.show()

*Plots the atomic concept images augmented through translations*

In [7]:
def plot_augmented_results_with_translation(original_tensor, augmented_tensor, augmented_labels, translation_offsets):
    """
    Plots the original images and the augmented (translated) images for each translation offset.

    The first row contains the original images.
    Each subsequent row contains the augmented images for one translation offset.

    Args:
        original_tensor (Tensor): Original images of shape (N, C, H, W).
        augmented_tensor (Tensor): Augmented images of shape (T*N, C, H, W), where T is the
                                     number of translation offsets, stored offset-major.
        augmented_labels (Tensor): Corresponding labels of shape (T*N, ?).
        translation_offsets (list): List of (dx, dy) tuples used for augmentation.
    """
    # Descriptive names avoid shadowing the module-level `T` (torchvision.transforms) alias.
    n_offsets = len(translation_offsets)
    n_images = original_tensor.shape[0]
    total_rows = 1 + n_offsets  # first row for originals, then one row per offset

    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # row or a single column, so `axes[row, col]` is always valid.
    fig, axes = plt.subplots(total_rows, n_images,
                             figsize=(3 * n_images, 3 * total_rows), squeeze=False)

    # Plot the original images (first row)
    for j in range(n_images):
        ax = axes[0, j]
        ax.imshow(original_tensor[j].permute(1, 2, 0).cpu().numpy())
        ax.set_title("Original")
        ax.axis("off")

    # Regroup the flat batch as (T, N, ...). reshape also accepts non-contiguous
    # input, and the trailing dims come from the data instead of a fixed 3x64x64.
    augmented_tensor = augmented_tensor.reshape(n_offsets, n_images, *augmented_tensor.shape[1:])
    augmented_labels = augmented_labels.reshape(n_offsets, n_images, -1)

    # Plot each augmentation row
    for t in range(n_offsets):
        dx, dy = translation_offsets[t]
        for j in range(n_images):
            ax = axes[t + 1, j]
            img_aug = augmented_tensor[t, j].permute(1, 2, 0).cpu().numpy()
            label = augmented_labels[t, j].tolist()
            ax.imshow(img_aug)
            ax.set_title(f"Trans ({dx}, {dy}), Label: {label}")
            ax.axis("off")

    plt.tight_layout()
    plt.show()


*Plots aggregated concepts*

In [8]:
def plot_image_with_boxes(image, boxes, ax):
    """Draw ``image`` on ``ax`` and outline every box in red.

    Args:
        image: Array whose first two dimensions are (height, width).
        boxes: Iterable of ``(x, y, w, h)`` boxes in pixel coordinates.
        ax: Matplotlib axes to draw on; its axis decorations are switched off.
    """
    img_h, img_w = image.shape[:2]

    # Pin extent and limits so one data unit equals one pixel with y growing
    # downward, which lets the box coordinates be used unchanged.
    ax.imshow(image, origin='upper', extent=[0, img_w, img_h, 0])
    ax.set_xlim(0, img_w)
    ax.set_ylim(img_h, 0)

    for bx, by, bw, bh in boxes:
        ax.add_patch(
            patches.Rectangle((bx, by), bw, bh,
                              linewidth=2, edgecolor='red', facecolor='none')
        )

    ax.axis('off')

*Translate aggregated concepts and their bounding boxes*

In [9]:
def translate_bounding_boxes_and_image(image, boxes, desired_translation):
    """Shift all boxed regions of ``image`` by one common, clamped translation.

    The requested translation is limited so that every box stays inside the
    image. The content of each box is then pasted at its new position on a
    canvas filled with ones.

    Args:
        image (ndarray): Image whose first two dimensions are (height, width).
        boxes: Non-empty sequence of ``(x, y, w, h)`` boxes with integer coordinates.
        desired_translation: ``(tx, ty)`` shift requested for all boxes.

    Returns:
        tuple: ``(translated_image, translated_boxes)`` where ``translated_boxes``
        is a list of ``[x, y, w, h]`` lists.
    """
    img_h, img_w = image.shape[:2]

    def _clamp(value, low, high):
        # Same precedence as the nested min(max(...)) form: the upper bound wins.
        return min(max(value, low), high)

    # Tightest shift range that keeps every box inside the frame.
    tx_low = max([-bx for (bx, _, _, _) in boxes])
    tx_high = min([img_w - (bx + bw) for (bx, _, bw, _) in boxes])
    ty_low = max([-by for (_, by, _, _) in boxes])
    ty_high = min([img_h - (by + bh) for (_, by, _, bh) in boxes])

    wanted_tx, wanted_ty = desired_translation
    shift_x = _clamp(wanted_tx, tx_low, tx_high)
    shift_y = _clamp(wanted_ty, ty_low, ty_high)

    canvas = np.ones_like(image)
    moved_boxes = []
    for bx, by, bw, bh in boxes:
        patch = image[by:by + bh, bx:bx + bw]
        # Round to whole pixels, then clamp once more so the paste stays in bounds.
        dest_x = max(0, min(int(round(bx + shift_x)), img_w - bw))
        dest_y = max(0, min(int(round(by + shift_y)), img_h - bh))
        canvas[dest_y:dest_y + bh, dest_x:dest_x + bw] = patch
        moved_boxes.append([dest_x, dest_y, bw, bh])

    return canvas, moved_boxes

*Iteratively applies the function to augment aggregated concepts through translations*

In [10]:
def augment_triplet(proto_imgs, annotations_dir, idx):
    """Create translated copies of one prototype image together with its boxes.

    The image ``proto_imgs[idx]`` and the boxes/labels stored in
    ``image_{idx}_annotations.json`` are translated with every combination of
    5 magnitudes and 16 direction vectors (80 variants, magnitude-major order).

    Args:
        proto_imgs (Tensor): Channel-first prototype images.
        annotations_dir (str): Directory containing the annotation JSON files.
        idx (int): Index of the image to augment.

    Returns:
        tuple: ``(AUG_proto_imgs, AUG_proto_labels, AUG_bbox_tensor)``: the stacked
        translated images, the original labels repeated once per variant, and the
        translated boxes of every variant.
    """
    expected_side = 64
    source_hwc = proto_imgs[idx].permute(1, 2, 0).cpu().numpy()

    annotation_path = os.path.join(annotations_dir, f"image_{idx}_annotations.json")
    with open(annotation_path, 'r') as handle:
        annotation = json.load(handle)
    source_boxes = annotation["boxes"]
    source_labels = annotation["labels"]

    magnitudes = (2, 4, 6, 8, 10)
    # (dx, dy) direction vectors in generation order. Negative y is "north".
    unit_shifts = (
        (0, -1),    # north
        (1, -1),    # north east
        (-1, -1),   # north west
        (-1, 0),    # west
        (1, 0),     # east
        (0, 1),     # south
        (1, 1),     # south east
        (-1, 1),    # south west
        (0, -2),    # north 2
        (0, 2),     # south 2
        (-2, 0),    # west 2
        (2, 0),     # east 2
        (2, -2),    # north east 2
        (-2, -2),   # north west 2
        (2, 2),     # south east 2
        (-2, 2),    # south west 2
    )

    variant_images = []
    variant_boxes = []
    for magnitude in magnitudes:
        for step_x, step_y in unit_shifts:
            shifted_hwc, shifted_boxes = translate_bounding_boxes_and_image(
                source_hwc, source_boxes, (magnitude * step_x, magnitude * step_y)
            )

            # Back to channel-first before converting to a tensor.
            shifted_chw = shifted_hwc.transpose(2, 0, 1)

            assert shifted_chw.shape == (3, expected_side, expected_side), (
                f"Shape mismatch: {shifted_chw.shape}"
            )
            assert shifted_chw.min() >= 0.0 and shifted_chw.max() <= 1.0, (
                "Pixel values out of range"
            )

            variant_images.append(torch.from_numpy(shifted_chw))
            variant_boxes.append(shifted_boxes)

    variant_count = len(magnitudes) * len(unit_shifts)
    return (
        torch.stack(variant_images),
        torch.tensor([source_labels] * variant_count),
        torch.tensor(variant_boxes),
    )

# DATA LOADING

## Prototypes setup

In [11]:
# Load the initial prototype images, their labels and their bounding boxes.
# NOTE(review): these paths are relative to the current working directory; the
# recorded run of this cell failed with FileNotFoundError, so the notebook must
# be started from (or chdir'ed to) the directory that contains 'data/'.
proto_imgs = torch.load('data/kand_annotations/init/concepts_init_aggregated.pt')
proto_labels = torch.load('data/kand_annotations/init/labels_init_aggregated.pt')
bbox_tensor = get_proto_imgs_bboxes()

# Shape check. bbox_tensor is asserted to be (3, 3, 4) by its loader; the image
# and label shapes are presumably (3, 3, 64, 64) and (3, 6) - confirm on a real run.
print(proto_imgs.shape)
print(proto_labels.shape)
print(bbox_tensor.shape)

FileNotFoundError: [Errno 2] No such file or directory: 'data/kand_annotations/init/concepts_init_aggregated.pt'

In [None]:
# Visual sanity check: draw the loaded prototypes with their annotated boxes.
plot_prototypes(proto_imgs_shapes=proto_imgs, proto_labels_shapes=proto_labels, bbox_tensor=bbox_tensor)

In [None]:
# Inspect the boxes of the first prototype image (one row per box, read as x, y, w, h).
bbox_tensor[0]

# DATA AUGMENTATIONS

## YOLO

### Translations

In [None]:
# Indices of the prototype images to augment (annotation files image_0 .. image_2).
image_indices = range(3)

# Layout of the augmentations returned by `augment_triplet`: 5 magnitudes x 16
# directions, stored magnitude-major. These lists must mirror the ones inside
# `augment_triplet`, otherwise the grid below shows and labels the wrong
# augmentations (the shape assertion after stacking guards the count).
scaling_values = [2, 4, 6, 8, 10]
direction_order = [
    "north", "north east", "north west",
    "west", "east", "south",
    "south east", "south west",
    "north 2", "south 2", "west 2", "east 2",
    "north east 2", "north west 2", "south east 2", "south west 2"
]
num_augmentations = len(scaling_values) * len(direction_order)

# Per-image augmented tensors, collected before stacking
all_augmented_images = []
all_augmented_labels = []
all_augmented_boxes = []

annotations_dir = "data/kand_annotations/init/bboxes_init"

# Generate the augmentations of every prototype image
for selected_idx in image_indices:
    AUG_proto_imgs, AUG_proto_labels, AUG_bbox_tensor = augment_triplet(proto_imgs, annotations_dir, selected_idx)

    all_augmented_images.append(AUG_proto_imgs)
    all_augmented_labels.append(AUG_proto_labels)
    all_augmented_boxes.append(AUG_bbox_tensor)

# Stack across images
AUG_proto_imgs = torch.stack(all_augmented_images)    # Shape: (n_images, n_aug, 3, 64, 64)
AUG_proto_labels = torch.stack(all_augmented_labels)  # Shape: (n_images, n_aug, n_labels)
AUG_bbox_tensor = torch.stack(all_augmented_boxes)    # Shape: (n_images, n_aug, 3, 4)

assert AUG_proto_imgs.shape[1] == num_augmentations, (
    f"Expected {num_augmentations} augmentations per image, got {AUG_proto_imgs.shape[1]}"
)

# Visualise up to 3 randomly ordered images
random_indices = random.sample(image_indices, min(3, len(image_indices)))
check_names = ["FIRST CHECK", "SECOND CHECK", "THIRD CHECK"]

for i, idx in enumerate(random_indices):
    # Original image and annotation for the selected index
    img_tensor = proto_imgs[idx]
    img_np = img_tensor.permute(1, 2, 0).cpu().numpy()

    annotation_filename = os.path.join(annotations_dir, f"image_{idx}_annotations.json")
    with open(annotation_filename, 'r') as f:
        ann = json.load(f)
    original_boxes = ann["boxes"]

    # Plot the original image with bounding boxes
    print(f"IMAGE {i+1} CHECK")
    fig, ax_orig = plt.subplots(1, 1, figsize=(6, 6))
    plot_image_with_boxes(img_np, original_boxes, ax_orig)
    ax_orig.set_title(f"Original Image {idx}")
    plt.show()

    # Grid of augmented images: one row per magnitude, one column per direction.
    fig, axes = plt.subplots(len(scaling_values), len(direction_order),
                             figsize=(2 * len(direction_order), 2 * len(scaling_values)))
    for row_idx, scale in enumerate(scaling_values):
        for col_idx, d in enumerate(direction_order):
            # Use a dedicated name: reusing `i` here would clobber the outer loop index.
            aug_idx = row_idx * len(direction_order) + col_idx
            aug_img = AUG_proto_imgs[idx][aug_idx].numpy().transpose(1, 2, 0)   # shape: (H, W, C)
            aug_boxes = AUG_bbox_tensor[idx][aug_idx].tolist()  # shape: (3, 4)

            ax = axes[row_idx, col_idx]
            plot_image_with_boxes(aug_img, aug_boxes, ax)
            ax.set_title(f"{d}\nScale: {scale}", fontsize=8)
    plt.tight_layout()
    plt.show()

    # Closing marker for the first three inspected images
    if i < len(check_names):
        print(check_names[i])

### Data check

In [None]:
# Insert an "augmentation" axis at dim 1 so each original can be concatenated
# in front of its own augmented variants.
proto_imgs_unsq = proto_imgs.unsqueeze(1)  
proto_labels_unsq = proto_labels.unsqueeze(1) 
bbox_tensor_unsq = bbox_tensor.unsqueeze(1)  

print("Unsqueezed original images: ", proto_imgs_unsq.shape)
print("Unsqueezed original labels: ", proto_labels_unsq.shape)
print("Unsqueezed original bounding boxes: ", bbox_tensor_unsq.shape)
print()

print("Augmented prototypes: ", AUG_proto_imgs.shape)
print("Augmented labels: ", AUG_proto_labels.shape)
print("Augmented bounding boxes: ", AUG_bbox_tensor.shape)
print()

# Stack images, labels and bounding boxes (initial + augmented) along the
# augmentation axis: per image, entry 0 is the original, the rest are its variants.
# NOTE(review): bbox_tensor is float32 while AUG_bbox_tensor is built from the
# integer box lists returned by augment_triplet, so this concatenation relies on
# torch.cat's dtype promotion - confirm the resulting dtype is the intended one.
final_images = torch.cat((proto_imgs_unsq, AUG_proto_imgs), dim=1)  
final_labels = torch.cat((proto_labels_unsq, AUG_proto_labels), dim=1)
final_bboxes = torch.cat((bbox_tensor_unsq, AUG_bbox_tensor), dim=1)  

# Flatten (image, variant) into a single sample axis.
# NOTE(review): `final_images` / `final_labels` are reused with a different
# meaning by the PNets "Saving" cell at the end of the notebook; only the
# underscore-suffixed flattened tensors are saved for YOLO.
final_images_ = final_images.view(-1, 3, 64, 64)
final_labels_ = final_labels.view(-1, final_labels.shape[-1])
final_bboxes_ = final_bboxes.view(-1, 3, 4)
print("Final images flattened shape: ", final_images_.shape) 
print("Final labels flattened shape:", final_labels_.shape) 
print("Final bounding boxes flattened shape: :", final_bboxes_.shape)  

In [None]:
# Summary of the flattened YOLO dataset before it is plotted and saved:
# shapes, dtypes and value ranges of images, labels and boxes.
print("Final images shape: ", final_images_.shape)  
print("Final labels shape: ", final_labels_.shape) 
print("Final bounding boxes shape: ", final_bboxes_.shape) 
print("Range of pixel values in final_images:", final_images_.min().item(), "to", final_images_.max().item())
print("Dtype of final_images:", final_images_.dtype)
print("Dtype of final_labels:", final_labels_.dtype)
print("Range of values in final_bboxes:", final_bboxes_.min().item(), "to", final_bboxes_.max().item())
print("Dtype of final_bboxes:", final_bboxes_.dtype)
print()

### Visual check

In [None]:
def plot_flat_images(images_flat, labels_flat, bboxes_flat):
    """Show a flat batch of images with their boxes in a 9-column grid.

    Args:
        images_flat (Tensor): Channel-first images, one per sample.
        labels_flat (Tensor): Labels; entry ``k`` becomes the title of image ``k``.
        bboxes_flat (Tensor): Boxes per sample, each box read as ``(x, y, w, h)``.
    """
    n_cols = 9
    total = images_flat.shape[0]
    n_rows = -(-total // n_cols)  # ceiling division

    fig, grid = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2))
    cells = grid.flatten()

    for k, cell in enumerate(cells[:total]):
        # Channel-first tensor -> channel-last array for imshow.
        picture = images_flat[k].permute(1, 2, 0).cpu().numpy()
        sample_boxes = bboxes_flat[k].cpu().numpy()
        plot_image_with_boxes(picture, sample_boxes, cell)

        cell.axis('off')
        cell.set_title(f"{labels_flat[k].tolist()}", fontsize=8)

    # Hide the unused cells of the last row.
    for spare in cells[total:]:
        spare.axis('off')

    plt.tight_layout()
    plt.show()

# Render every original and augmented YOLO sample with its boxes in one grid.
plot_flat_images(final_images_, final_labels_, final_bboxes_)

### Saving

In [None]:
# torch.save does not create missing parent directories, so make sure the
# output folder exists before writing the flattened YOLO dataset.
os.makedirs('data/kand_annotations/yolo_annotations', exist_ok=True)
torch.save(final_images_, 'data/kand_annotations/yolo_annotations/images.pt')
torch.save(final_labels_, 'data/kand_annotations/yolo_annotations/labels.pt')
torch.save(final_bboxes_, 'data/kand_annotations/yolo_annotations/bboxes.pt')

## PNets

### Prototypes extraction

In [None]:
def extract_shapes(proto_imgs, proto_labels, bbox_tensor):
    """
    Crop every annotated object out of the prototype images and resize it to 64x64.

    Args:
        proto_imgs (Tensor): Shape (N, 3, 64, 64); (3, 3, 64, 64) in this notebook.
        proto_labels (Tensor): Shape (N, 2*B); (3, 6) in this notebook. For box ``j``
            the label pair is ``(labels[j], labels[j + B])``.
        bbox_tensor (Tensor): Shape (N, B, 4) where each box is (x, y, w, h);
            (3, 3, 4) in this notebook. Coordinates are truncated to int.

    Returns:
        shapes_tensor (Tensor): Tensor of cropped and resized shapes with shape (N*B, 3, 64, 64)
        labels_tensor (Tensor): Tensor of corresponding labels with shape (N*B, 2)

    Raises:
        AssertionError: If pixel values leave [0, 1] or the output shapes are unexpected.
    """
    # The transform holds no per-call state, so build it once instead of per crop.
    resize_transform = T.Resize((64, 64))

    shape_crops = []
    label_pairs = []

    for i in range(proto_imgs.shape[0]):
        img = proto_imgs[i]         # (3, 64, 64)
        boxes = bbox_tensor[i]      # (B, 4)
        labels = proto_labels[i]    # (2*B,)
        n_boxes = boxes.shape[0]

        for j in range(n_boxes):
            x, y, w, h = (int(v) for v in boxes[j].tolist())

            crop = img[:, y:y+h, x:x+w]
            shape_crops.append(resize_transform(crop))

            # Second attribute of box j sits n_boxes entries after the first one.
            label_pairs.append(torch.tensor([labels[j], labels[j + n_boxes]]))

    shapes_tensor = torch.stack(shape_crops, dim=0)
    labels_tensor = torch.stack(label_pairs, dim=0)

    min_val = shapes_tensor.min().item()
    max_val = shapes_tensor.max().item()
    expected_n = proto_imgs.shape[0] * bbox_tensor.shape[1]
    assert min_val >= 0 and max_val <= 1, f"Min: {min_val}, Max: {max_val}"
    assert shapes_tensor.shape == (expected_n, 3, 64, 64), f"Shapes: {shapes_tensor.shape}, but expected ({expected_n}, 3, 64, 64)"
    assert labels_tensor.shape == (expected_n, 2), f"Labels: {labels_tensor.shape}, but expected ({expected_n}, 2)"
    return shapes_tensor, labels_tensor


### Plot

In [None]:
# Crop the individual shapes out of the prototype images; `proto_tensor` and
# `label_tensor` are the base data for every PNet augmentation below.
proto_tensor, label_tensor = extract_shapes(proto_imgs, proto_labels, bbox_tensor)
fig, axes = plt.subplots(3, 3, figsize=(12, 12))

# 3x3 grid: row i holds the three crops taken from prototype image i.
for i in range(3):
    for j in range(3):
        idx = i * 3 + j
        img_np = proto_tensor[idx].permute(1, 2, 0).cpu().numpy()
        label = label_tensor[idx].tolist()
        
        axes[i, j].imshow(img_np)
        axes[i, j].axis("off")
        axes[i, j].set_title(f"Labels: {label}", fontsize=10, fontweight="bold")

plt.tight_layout()
plt.show()

### Anti-aliasing

*antialias_image*: applies anti-aliasing to a single image by blending it with white through a Gaussian-blurred shape mask; the given threshold decides which pixels count as white background

*augment_images_with_antialiasing*: calls the function above for each image and for each threshold

In [None]:
def antialias_image(image_tensor, kernel_size, sigma, white_threshold):
    """
    Soften the edges of the shape in one image by blending it with white.

    A binary shape mask (pixels that are not background) is Gaussian-blurred and
    used as the blending weight between the image and a pure white canvas.

    Args:
        image_tensor (Tensor): A tensor of shape (3, 64, 64) with pixel values in [0,1].
        kernel_size (int): Kernel size for the Gaussian blur (should be odd).
        sigma (float): Standard deviation for the Gaussian blur.
        white_threshold (int): Value on the 0-255 scale; a pixel is treated as
            background only when all of its channels are above it.

    Returns:
        Tensor: Augmented image tensor of shape (3, 64, 64) with values in [0,1].
    """
    # Channel-last uint8 copy of the image.
    rgb = image_tensor.permute(1, 2, 0).cpu().numpy()  # shape (64, 64, 3)
    rgb = (np.clip(rgb, 0, 1) * 255).astype(np.uint8)

    is_background = np.all(rgb > white_threshold, axis=-1)
    shape_weight = (~is_background).astype(np.float32)

    # The blurred mask becomes a per-pixel blending weight in [0, 1].
    alpha = cv2.GaussianBlur(shape_weight, (kernel_size, kernel_size), sigma)
    alpha = np.clip(alpha, 0, 1)[..., None]

    white = np.full_like(rgb, 255, dtype=np.uint8).astype(np.float32)
    mixed = rgb.astype(np.float32) * alpha + white * (1 - alpha)
    mixed = np.clip(mixed, 0, 255).astype(np.uint8)

    return torch.tensor(mixed).permute(2, 0, 1).float() / 255.0


def augment_images_with_antialiasing(shapes_tensor, labels_tensor, kernel_size, sigma, white_thresholds):
    """
    Run the anti-aliasing augmentation for every (threshold, image) combination.

    Args:
        shapes_tensor (Tensor): Input tensor of shape (N, 3, 64, 64)
        labels_tensor (Tensor): Corresponding labels of shape (N, 2)
        kernel_size (int): Gaussian blur kernel size.
        sigma (float): Gaussian blur sigma.
        white_thresholds (list): List of white background thresholds to apply.

    Returns:
        Tuple[Tensor, Tensor]: Augmented images of shape (N * len(white_thresholds), 3, 64, 64),
                               Corresponding labels of shape (N * len(white_thresholds), 2).
                               Results are ordered threshold-major.
    """
    n_images = shapes_tensor.shape[0]
    # Threshold-major work list: every image for the first threshold, then the next, ...
    jobs = [(level, k) for level in white_thresholds for k in range(n_images)]

    images = [
        antialias_image(shapes_tensor[k], kernel_size=kernel_size, sigma=sigma, white_threshold=level)
        for level, k in jobs
    ]
    labels = [labels_tensor[k] for _, k in jobs]  # labels are unchanged by the augmentation

    return torch.stack(images, dim=0), torch.stack(labels, dim=0)

Plot the results of anti aliasing augmentations

In [None]:
# White-background thresholds on the 0-255 scale: 50, 55, ..., 245.
white_thresholds = list(np.arange(50, 250, 5))
augmented_tensor_anti_aliasing, augmented_labels_anti_aliasing = augment_images_with_antialiasing(
    proto_tensor, label_tensor, kernel_size=15, sigma=7, white_thresholds=white_thresholds)

# Sanity checks: pixel range and one augmented image per (threshold, crop) pair.
min_val = min(shape.min().item() for shape in augmented_tensor_anti_aliasing)
max_val = max(shape.max().item() for shape in augmented_tensor_anti_aliasing)
assert min_val >= 0 and max_val <= 1, f"Min: {min_val}, Max: {max_val}"
# The message must reference the tensor defined in this cell; the previous
# `augmented_tensor` name does not exist at notebook scope, so a failing check
# raised NameError instead of reporting the shape.
assert augmented_tensor_anti_aliasing.shape == (9*len(white_thresholds), 3, 64, 64), f"Expected ({9*len(white_thresholds)}, 3, 64, 64), but got {augmented_tensor_anti_aliasing.shape}"

# Visual check: originals in the first row, one row per threshold below.
plot_augmented_results_with_antialiasing(proto_tensor, augmented_tensor_anti_aliasing, augmented_labels_anti_aliasing, white_thresholds)
print(augmented_tensor_anti_aliasing.shape)

### Scalings

*scale_image*: rescales the shape found in a single image by the given factor and re-centres it on a white canvas

*augment_images_with_scaling*: calls the function above for each image and scaling factor

In [None]:
def scale_image(image_tensor, scale_factor):
    """
    Rescale the shape found in one image and paste it, centred, on a white canvas.

    Args:
        image_tensor (Tensor): A tensor of shape (3, 64, 64) with pixel values in [0,1].
        scale_factor (float): Factor by which to scale the shape.

    Returns:
        Tensor: Augmented image tensor of shape (3, 64, 64) with values in [0,1].
        A completely white input is returned unchanged.
    """
    # Channel-last uint8 copy of the image.
    pixels = image_tensor.permute(1, 2, 0).cpu().numpy()
    pixels = (np.clip(pixels, 0, 1) * 255).astype(np.uint8)

    # Any channel below 255 marks a pixel as part of the shape.
    foreground = np.any(pixels < 255, axis=-1)
    if not foreground.any():
        return image_tensor

    # Tight bounding box of the shape (the upper bounds are exclusive).
    rows_cols = np.argwhere(foreground)
    top, left = rows_cols.min(axis=0)
    bottom, right = rows_cols.max(axis=0) + 1
    tight_crop = pixels[top:bottom, left:right]

    target_w = max(1, int(tight_crop.shape[1] * scale_factor))
    target_h = max(1, int(tight_crop.shape[0] * scale_factor))
    resized = cv2.resize(tight_crop, (target_w, target_h), interpolation=cv2.INTER_AREA)

    canvas = np.full_like(pixels, 255, dtype=np.uint8)
    canvas_h, canvas_w = canvas.shape[:2]

    # Top-left corner that centres the resized shape. A negative offset means
    # the shape is larger than the canvas and must be cropped symmetrically.
    offset_x = (canvas_w - target_w) // 2
    offset_y = (canvas_h - target_h) // 2
    src_x, dst_x = max(0, -offset_x), max(0, offset_x)
    src_y, dst_y = max(0, -offset_y), max(0, offset_y)

    copy_w = min(target_w - src_x, canvas_w - dst_x)
    copy_h = min(target_h - src_y, canvas_h - dst_y)

    canvas[dst_y:dst_y + copy_h, dst_x:dst_x + copy_w] = \
        resized[src_y:src_y + copy_h, src_x:src_x + copy_w]

    return torch.tensor(canvas).permute(2, 0, 1).float() / 255.0


def augment_images_with_scaling(shapes_tensor, labels_tensor, scaling_factors, record_factor=0.8):
    """
    Run the scaling augmentation for every (factor, image) combination.

    Args:
        shapes_tensor (Tensor): Input tensor of shape (N, 3, 64, 64)
        labels_tensor (Tensor): Corresponding labels tensor of shape (N, ?)
        scaling_factors (list): List of scaling factors to apply.
        record_factor (float): Not used by this function; kept in the signature
            for existing callers.

    Returns:
        Tuple[Tensor, Tensor]: Augmented images of shape (N * len(scaling_factors), 3, 64, 64),
                               Corresponding labels of shape (N * len(scaling_factors), ?).
                               Results are ordered factor-major.
    """
    n_images = shapes_tensor.shape[0]
    # Factor-major work list: every image for the first factor, then the next, ...
    jobs = [(factor, k) for factor in scaling_factors for k in range(n_images)]

    images = [scale_image(shapes_tensor[k], scale_factor=factor) for factor, k in jobs]
    labels = [labels_tensor[k] for _, k in jobs]  # labels are unchanged by the augmentation

    return torch.stack(images, dim=0), torch.stack(labels, dim=0)

Plot the results of scaling augmentations

In [None]:
# Scaling factors from 0.3 upward in steps of 0.05.
# NOTE(review): np.arange with a float step can include or exclude the value
# next to the stop (1.4) depending on rounding, and yields values such as
# 0.35000000000000003; np.linspace would pin the count - confirm the intended list.
scaling_factors = list(np.arange(0.3, 1.4, 0.05))
augmented_tensor_scaling, augmented_labels_scaling = augment_images_with_scaling(proto_tensor, label_tensor, scaling_factors)

# Sanity checks: pixel range and one augmented image per (factor, crop) pair.
min_val = min(shape.min().item() for shape in augmented_tensor_scaling)
max_val = max(shape.max().item() for shape in augmented_tensor_scaling)
assert min_val >= 0 and max_val <= 1, f"Min: {min_val}, Max: {max_val}"
assert augmented_tensor_scaling.shape == (9*len(scaling_factors), 3, 64, 64), f"Expected ({9*len(scaling_factors)}, 3, 64, 64), but got {augmented_tensor_scaling.shape}"

# Visual check: originals in the first row, one row per scaling factor below.
plot_augmented_results_with_scaling(proto_tensor, augmented_tensor_scaling, augmented_labels_scaling, scaling_factors)
print("Scaling augmented tensor shape:", augmented_tensor_scaling.shape)

### Translations

*translate_image*: translates a single image

*augment_images_with_translation*: calls the function above for each image and each translation offset

In [None]:
def translate_image(image_tensor, dx, dy):
    """
    Shift one image by (dx, dy) pixels, filling the uncovered border with white.

    Args:
        image_tensor (Tensor): A tensor of shape (3, 64, 64) with pixel values in [0,1].
        dx (int): Translation offset in the x-direction (pixels).
        dy (int): Translation offset in the y-direction (pixels).
            NOTE(review): the notebook also passes fractional offsets here; they go
            straight into the affine matrix - confirm sub-pixel shifts are intended.

    Returns:
        Tensor: Augmented image tensor of shape (3, 64, 64) with values in [0,1].
    """
    # Channel-last uint8 copy of the image.
    hwc = image_tensor.permute(1, 2, 0).cpu().numpy()
    hwc = (np.clip(hwc, 0, 1) * 255).astype(np.uint8)
    height, width = hwc.shape[:2]

    # 2x3 affine matrix of a pure translation.
    shift = np.float32([[1, 0, dx], [0, 1, dy]])
    moved = cv2.warpAffine(hwc, shift, (width, height), borderValue=(255, 255, 255))

    # Back to a channel-first float tensor in [0, 1].
    return torch.tensor(moved).permute(2, 0, 1).float() / 255.0

def augment_images_with_translation(shapes_tensor, labels_tensor, translation_offsets):
    """
    Run the translation augmentation for every (offset, image) combination.

    Args:
        shapes_tensor (Tensor): Input tensor of shape (N, 3, 64, 64).
        labels_tensor (Tensor): Corresponding labels tensor of shape (N, ?).
        translation_offsets (list): List of (dx, dy) tuples to apply.

    Returns:
        Tuple[Tensor, Tensor]: Augmented images of shape (N * len(translation_offsets), 3, 64, 64),
                               Corresponding labels of shape (N * len(translation_offsets), ?).
                               Results are ordered offset-major.
    """
    n_images = shapes_tensor.shape[0]
    # Offset-major work list: every image for the first offset, then the next, ...
    jobs = [(dx, dy, k) for (dx, dy) in translation_offsets for k in range(n_images)]

    images = [translate_image(shapes_tensor[k], dx=dx, dy=dy) for dx, dy, k in jobs]
    labels = [labels_tensor[k] for _, _, k in jobs]  # labels are unchanged by the augmentation

    return torch.stack(images, dim=0), torch.stack(labels, dim=0)

Plot the results of translation augmentations

In [None]:
# (dx, dy) offsets: for every magnitude 0.1, 0.2, ..., 2.5 one shift to the
# left, up, right and down (negative x is left, negative y is up).
# NOTE(review): translate_image documents dx/dy as integer pixel counts, but all
# magnitudes here are at most 2.5 px and mostly fractional, so the augmented
# images differ from the originals only by sub-pixel shifts - confirm intended.
translation_offsets = [
    (-0.1, 0), (0, -0.1), (0.1, 0), (0, 0.1),
    (-0.2, 0), (0, -0.2), (0.2, 0), (0, 0.2),
    (-0.3, 0), (0, -0.3), (0.3, 0), (0, 0.3),
    (-0.4, 0), (0, -0.4), (0.4, 0), (0, 0.4),
    (-0.5, 0), (0, -0.5), (0.5, 0), (0, 0.5),
    (-0.6, 0), (0, -0.6), (0.6, 0), (0, 0.6),
    (-0.7, 0), (0, -0.7), (0.7, 0), (0, 0.7),
    (-0.8, 0), (0, -0.8), (0.8, 0), (0, 0.8),
    (-0.9, 0), (0, -0.9), (0.9, 0), (0, 0.9),
    (-1, 0), (0, -1), (1, 0), (0, 1),
    (-1.1, 0), (0, -1.1), (1.1, 0), (0, 1.1),
    (-1.2, 0), (0, -1.2), (1.2, 0), (0, 1.2),
    (-1.3, 0), (0, -1.3), (1.3, 0), (0, 1.3),
    (-1.4, 0), (0, -1.4), (1.4, 0), (0, 1.4),
    (-1.5, 0), (0, -1.5), (1.5, 0), (0, 1.5),
    (-1.6, 0), (0, -1.6), (1.6, 0), (0, 1.6),
    (-1.7, 0), (0, -1.7), (1.7, 0), (0, 1.7),
    (-1.8, 0), (0, -1.8), (1.8, 0), (0, 1.8),
    (-1.9, 0), (0, -1.9), (1.9, 0), (0, 1.9),
    (-2, 0), (0, -2), (2, 0), (0, 2),
    (-2.1, 0), (0, -2.1), (2.1, 0), (0, 2.1),
    (-2.2, 0), (0, -2.2), (2.2, 0), (0, 2.2),
    (-2.3, 0), (0, -2.3), (2.3, 0), (0, 2.3),
    (-2.4, 0), (0, -2.4), (2.4, 0), (0, 2.4),
    (-2.5, 0), (0, -2.5), (2.5, 0), (0, 2.5),
] 
augmented_tensor_translation, augmented_labels_translation = augment_images_with_translation(
    proto_tensor, label_tensor, translation_offsets)

# Sanity checks: pixel range and one augmented image per (offset, crop) pair.
min_val = min(shape.min().item() for shape in augmented_tensor_translation)
max_val = max(shape.max().item() for shape in augmented_tensor_translation)
assert min_val >= 0 and max_val <= 1, f"Min: {min_val}, Max: {max_val}"
assert augmented_tensor_translation.shape == (9*len(translation_offsets), 3, 64, 64), f"Expected ({9*len(translation_offsets)}, 3, 64, 64), but got {augmented_tensor_translation.shape}"

# Visual check: originals in the first row, one row per offset below
# (100 offsets produce a very tall figure).
plot_augmented_results_with_translation(proto_tensor, augmented_tensor_translation, augmented_labels_translation, translation_offsets)
print("Translation augmented tensor shape:", augmented_tensor_translation.shape)

# Saving

In [None]:
# Merge all augmented image tensors
final_images = torch.cat([
    proto_tensor,
    augmented_tensor_anti_aliasing,
    augmented_tensor_scaling,
    augmented_tensor_translation
], dim=0)

# Merge all corresponding labels
final_labels = torch.cat([
    label_tensor,
    augmented_labels_anti_aliasing,
    augmented_labels_scaling,
    augmented_labels_translation
], dim=0)

# Print final shape to verify
print(f"Final dataset shape: {final_images.shape}")  # Should be (total_samples, 3, 64, 64)
print(f"Final labels shape: {final_labels.shape}")  # Should be (total_samples, 6)
print("Range of pixel values in final_images:", final_images.min().item(), "to", final_images.max().item())
print("Dtype of final_images:", final_images.dtype)
print("Dtype of final_labels:", final_labels.dtype)
print()

torch.save(final_images, 'data/kand_annotations/pnet_proto/concept_prototypes.pt')
torch.save(final_labels, 'data/kand_annotations/pnet_proto/labels_prototypes.pt')