In [1]:
import os
import time
import json
import random
import numpy as np
from PIL import Image
from tabulate import tabulate
from tqdm import tqdm
from typing import List, Tuple
from sklearn.metrics import (
    f1_score,
    accuracy_score,
    precision_score,
    recall_score,
)
import psutil
from pynvml import (
    nvmlInit,
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo,
    nvmlShutdown,
)
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import matplotlib.pyplot as plt

# --- SETTINGS --- #
SEED = 42
TILE_SIZE = 64
INPUT_SIZE = (640, 640)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_PARTS = 22  # Adjust if needed

# --- SEEDING --- #
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(SEED)

# --- TRANSFORMS --- #
transform = transforms.Compose([
    transforms.Resize((TILE_SIZE, TILE_SIZE)),
    transforms.ToTensor(),
])

def visualize_tile_locations(image_path, center, all_parts, offsets):
    image = Image.open(image_path).convert("RGB")
    fig, ax = plt.subplots()
    ax.imshow(image)
    cx, cy = center
    ax.scatter(cx, cy, c='r', label="Anchor")
    for part in all_parts:
        dx, dy = offsets.get(part, (0, 0))
        px = cx + dx
        py = cy + dy
        ax.scatter(px, py, alpha=0.7, label=part)
        rect = plt.Rectangle((px - TILE_SIZE//2, py - TILE_SIZE//2), TILE_SIZE, TILE_SIZE,
                             linewidth=1, edgecolor='blue', facecolor='none')
        ax.add_patch(rect)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title("Estimated Tile Positions")
    plt.tight_layout()
    plt.show()

def visualize_model_decision(model, dataset, all_parts, image_dir, sample_idx=0):
    import matplotlib.patches as patches

    model.eval()
    img_name = dataset.images[sample_idx]
    img_path = os.path.join(image_dir, img_name)
    image = Image.open(img_path).convert("RGB")

    # Ground truth
    true_labels = dataset.annotations[img_name].get("missing_parts", [])
    true_mask = {p: 1 for p in true_labels}

    # Get tiles and prediction
    tiles, label_tensor = dataset[sample_idx]
    tiles = tiles.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(tiles)
        probs = torch.sigmoid(output).cpu().squeeze().numpy()
        preds = (probs > 0.5).astype(int)

    # Anchor is at image center
    cx, cy = image.size[0] // 2, image.size[1] // 2
    tile_centers = estimate_other_tiles((cx, cy), all_parts, dataset.average_offsets)

    # Plot image and predictions
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(image)

    for i, (part, (px, py)) in enumerate(zip(all_parts, tile_centers)):
        pred = preds[i]
        true = true_mask.get(part, 0)

        # Determine color
        if pred == 1 and true == 0:
            color = "blue"    # False positive
        elif pred == 1 and true == 1:
            color = "green"   # True positive
        elif pred == 0 and true == 1:
            color = "red"     # False negative
        else:
            color = "black"   # True negative

        rect = patches.Rectangle(
            (px - TILE_SIZE//2, py - TILE_SIZE//2),
            TILE_SIZE, TILE_SIZE,
            linewidth=2, edgecolor=color, facecolor='none'
        )
        ax.add_patch(rect)
        ax.text(px, py, part, fontsize=8, color=color, ha='center')

    plt.title(f"Prediction vs Ground Truth for '{img_name}'")
    plt.axis("off")
    plt.tight_layout()
    plt.show()

# --- UTILITIES TO FIND ANCHOR AND AVG OFFSETS --- #

def part_in_center_tile(part_bbox, image_size, tile_size=TILE_SIZE):
    """Check if part bbox center lies inside center tile area."""
    img_w, img_h = image_size
    cx, cy = img_w // 2, img_h // 2
    tile_left = cx - tile_size // 2
    tile_top = cy - tile_size // 2
    tile_right = tile_left + tile_size
    tile_bottom = tile_top + tile_size

    part_cx = part_bbox['left'] + part_bbox['width'] / 2
    part_cy = part_bbox['top'] + part_bbox['height'] / 2

    return (tile_left <= part_cx <= tile_right) and (tile_top <= part_cy <= tile_bottom)

def find_anchor_part(images_subset, all_parts, image_dir):
    anchor_counts = {part: 0 for part in all_parts}
    for img_id, img_data in images_subset.items():
        img_path = os.path.join(image_dir, img_id)
        with Image.open(img_path) as img:
            img_w, img_h = img.size

        for part in img_data.get('available_parts', []):
            name = part['part_name']
            bbox = part['absolute_bounding_box']
            center_x = bbox['left'] + bbox['width'] / 2
            center_y = bbox['top'] + bbox['height'] / 2
            # Check if inside center tile
            left = img_w // 2 - TILE_SIZE // 2
            top = img_h // 2 - TILE_SIZE // 2
            if left <= center_x <= left + TILE_SIZE and top <= center_y <= top + TILE_SIZE:
                anchor_counts[name] += 1

    # Return the part with max count
    anchor_part = max(anchor_counts, key=anchor_counts.get)
    return anchor_part

def load_avg_positions_from_subset_with_anchor(train_images_subset, all_parts, anchor_part, image_dir):
    anchor_positions = []
    part_positions = {part: [] for part in all_parts}

    for img_id, img_data in train_images_subset.items():
        # Load image size
        img_path = os.path.join(image_dir, img_id)
        with Image.open(img_path) as img:
            img_w, img_h = img.size

        # Find anchor part position in this image
        anchor_pos = None
        for part in img_data.get('available_parts', []):
            name = part['part_name']
            bbox = part['absolute_bounding_box']
            center_x = bbox['left'] + bbox['width'] / 2
            center_y = bbox['top'] + bbox['height'] / 2
            if name == anchor_part:
                anchor_pos = (center_x, center_y)
                break
        if anchor_pos is None:
            # Skip image if anchor part missing
            continue
        anchor_positions.append(anchor_pos)
        # Record positions relative to anchor
        for part in img_data.get('available_parts', []):
            name = part['part_name']
            bbox = part['absolute_bounding_box']
            cx = bbox['left'] + bbox['width'] / 2
            cy = bbox['top'] + bbox['height'] / 2
            rel_x = cx - anchor_pos[0]
            rel_y = cy - anchor_pos[1]
            part_positions[name].append((rel_x, rel_y))

    avg_anchor_pos = np.mean(anchor_positions, axis=0)
    avg_offsets = {}
    for part in all_parts:
        if part_positions[part]:
            avg_offsets[part] = np.mean(part_positions[part], axis=0)
        else:
            avg_offsets[part] = (0, 0)
    return avg_anchor_pos, avg_offsets

# --- TILE CROPPING BASED ON ANCHOR OFFSETS --- #

def crop_tile(image: Image.Image, center: Tuple[int, int]) -> Image.Image:
    cx, cy = center
    left = max(cx - TILE_SIZE // 2, 0)
    top = max(cy - TILE_SIZE // 2, 0)
    right = left + TILE_SIZE
    bottom = top + TILE_SIZE
    # Ensure crop box inside image bounds
    right = min(right, image.width)
    bottom = min(bottom, image.height)
    left = right - TILE_SIZE
    top = bottom - TILE_SIZE
    return image.crop((left, top, right, bottom))

def estimate_other_tiles(center: Tuple[int, int], all_parts: List[str], average_offsets: dict):
    cx, cy = center
    tile_centers = []
    for part in all_parts:
        dx, dy = average_offsets.get(part, (0, 0))
        est_x = int(cx + dx)
        est_y = int(cy + dy)
        tile_centers.append((est_x, est_y))
    return tile_centers

# --- DATASET --- #
class BikeTileDataset(Dataset):
    def __init__(self, annotations, image_dir, image_ids, all_parts, average_offsets, target_size=(640, 640)):
        self.image_dir = image_dir
        self.images = image_ids
        self.annotations = annotations["images"]
        self.part_to_idx = {part: i for i, part in enumerate(all_parts)}
        self.all_parts = all_parts
        self.average_offsets = average_offsets
        self.target_size = target_size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert("RGB")

        original_w, original_h = image.size
        image = image.resize(self.target_size, Image.BILINEAR)
        resized_w, resized_h = self.target_size

        scale_x = resized_w / original_w
        scale_y = resized_h / original_h

        parts = self.annotations[img_name].get("available_parts", [])
        for part in parts:
            bbox = part["absolute_bounding_box"]
            bbox["left"] *= scale_x
            bbox["top"] *= scale_y
            bbox["width"] *= scale_x
            bbox["height"] *= scale_y

        cx, cy = resized_w // 2, resized_h // 2
        tile_centers = estimate_other_tiles((cx, cy), self.all_parts, self.average_offsets)

        tiles = [transform(crop_tile(image, center)) for center in tile_centers]
        tiles = torch.stack(tiles)

        missing_parts = self.annotations[img_name].get("missing_parts", [])
        label = torch.zeros(len(self.all_parts))
        for part in missing_parts:
            idx_part = self.part_to_idx[part]
            label[idx_part] = 1

        return tiles, label
    
# --- MODEL --- #
class TileMobileNet(nn.Module):
    def __init__(self, num_parts=NUM_PARTS):
        super().__init__()
        backbone = models.mobilenet_v2(pretrained=True)
        self.feature_extractor = backbone.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.Linear(512, num_parts),
        )

    def forward(self, x):
        B, T, C, H, W = x.shape  # B=batch, T=tiles
        x = x.view(B * T, C, H, W)
        feats = self.feature_extractor(x)
        pooled = self.pool(feats).view(B, T, -1)
        aggregated = pooled.mean(dim=1)
        out = self.classifier(aggregated)
        return out

# --- TRAINING & EVAL --- #

if torch.cuda.is_available():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)


def train(model, dataloader, optimizer, criterion):
    model.train()
    total_loss, total_time, total_pixels = 0, 0, 0
    global gpu_memories, cpu_memories
    gpu_memories, cpu_memories = [], []

    for tiles, labels in tqdm(dataloader):
        tiles = tiles.to(DEVICE)
        labels = labels.to(DEVICE)

        start_time = time.time()

        preds = model(tiles)
        loss = criterion(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time = time.time() - start_time
        total_loss += loss.item()
        total_time += batch_time
        total_pixels += tiles.numel()

        if torch.cuda.is_available():
            mem_info = nvmlDeviceGetMemoryInfo(handle)
            gpu_mem_used = mem_info.used / (1024**2)
            gpu_memories.append(gpu_mem_used)
        else:
            gpu_mem_used = 0
        
        cpu_mem_used = psutil.virtual_memory().used / (1024**2)
        cpu_memories.append(cpu_mem_used)

    avg_loss = total_loss / len(dataloader)
    print(f"Train Loss: {avg_loss:.4f}, Time: {total_time:.2f}s, Pixels: {total_pixels}")

def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for tiles, labels in tqdm(dataloader):
            tiles = tiles.to(DEVICE)
            labels = labels.to(DEVICE)
            preds = model(tiles)
            loss = criterion(preds, labels)
            total_loss += loss.item()

            preds_bin = (torch.sigmoid(preds) > 0.5).cpu().numpy()
            labels_np = labels.cpu().numpy()

            all_preds.append(preds_bin)
            all_labels.append(labels_np)

    avg_loss = total_loss / len(dataloader)
    print(f"Loss: {avg_loss:.4f}")

    Y_pred = np.vstack(all_preds)
    Y_true = np.vstack(all_labels)

    micro_f1 = f1_score(Y_true, Y_pred, average="micro", zero_division=0)
    macro_f1 = f1_score(Y_true, Y_pred, average="macro", zero_division=0)

    FN = np.logical_and(Y_true == 1, Y_pred == 0).sum()
    TP = np.logical_and(Y_true == 1, Y_pred == 1).sum()
    FP = np.logical_and(Y_true == 0, Y_pred == 1).sum()

    N_images = Y_true.shape[0]
    miss_rate = FN / (FN + TP) if (FN + TP) > 0 else 0
    fppi = FP / N_images

    overall_acc = accuracy_score(Y_true.flatten(), Y_pred.flatten())
    overall_prec = precision_score(Y_true.flatten(), Y_pred.flatten(), zero_division=0)
    overall_rec = recall_score(Y_true.flatten(), Y_pred.flatten(), zero_division=0)
    overall_f1 = f1_score(Y_true.flatten(), Y_pred.flatten(), zero_division=0)

    print(f"Micro F1: {micro_f1:.4f}, Macro F1: {macro_f1:.4f}")
    print(f"Miss Rate: {miss_rate:.4f}, FPPI: {fppi:.4f}")
    print(f"Overall Acc: {overall_acc:.4f}, Precision: {overall_prec:.4f}, Recall: {overall_rec:.4f}, F1: {overall_f1:.4f}")

# --- MAIN SCRIPT --- #
if __name__ == "__main__":
    # Load annotations
    json_path = "../data/processed/final_annotations_without_occluded.json"
    image_dir = "../data/images"
    with open(json_path) as f:
        annotations = json.load(f)

    # Split dataset keys
    image_ids = list(annotations["images"].keys())
    random.shuffle(image_ids)

    n = len(image_ids)
    n_train = int(0.8 * n)
    n_test = n - n_train
    n_val = int(0.1 * n_train)
    n_train = n_train - n_val  # adjust train count after val split

    train_ids = image_ids[:n_train]
    val_ids = image_ids[n_train:n_train + n_val]
    test_ids = image_ids[n_train + n_val:]

    print(f"Train: {len(train_ids)}, Val: {len(val_ids)}, Test: {len(test_ids)}")

    # Extract training subset for anchor computations
    train_images_subset = {k: annotations["images"][k] for k in train_ids}
    all_parts = annotations["all_parts"]

    # Find anchor part based on training set
    anchor_part = find_anchor_part(train_images_subset, all_parts, image_dir)
    avg_anchor_pos, average_offsets = load_avg_positions_from_subset_with_anchor(train_images_subset, all_parts, anchor_part, image_dir)

    print(f"Anchor part: {anchor_part}")
    print(f"Average anchor position: {avg_anchor_pos}")
    print(f"Average offsets for parts (example): {dict(list(average_offsets.items())[:3])}")

    # Create datasets and loaders
    train_dataset = BikeTileDataset(annotations, image_dir, train_ids, all_parts, average_offsets)
    val_dataset = BikeTileDataset(annotations, image_dir, val_ids, all_parts, average_offsets)
    test_dataset = BikeTileDataset(annotations, image_dir, test_ids, all_parts, average_offsets)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

    # Initialize model, loss, optimizer
    model = TileMobileNet().to(DEVICE)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Training loop
    for epoch in range(5):
        print(f"Epoch {epoch + 1}")
        train(model, train_loader, optimizer, criterion)
        evaluate(model, val_loader, criterion)
        max_gpu_mem = max(gpu_memories) if gpu_memories else 0
        max_cpu_mem = max(cpu_memories)

        table = [
            ["Epoch", epoch],
            ["Maximum GPU Memory Usage (MB)", f"{max_gpu_mem:.2f}"],
            ["Maximum CPU Memory Usage (MB)", f"{max_cpu_mem:.2f}"],
        ]

        print(tabulate(table, headers=["Metric", "Value"], tablefmt="pretty"))

    print("Training complete.")


evaluate(model, test_loader, criterion)

Train: 5742, Val: 638, Test: 1596




Anchor part: back_pedal
Average anchor position: [323.4758003 289.5394436]
Average offsets for parts (example): {'back_pedal': array([0., 0.]), 'back_hand_break': array([ -41.00623624, -164.98789435]), 'front_pedal': array([28.69489194, 19.22082515])}
Epoch 1


100%|██████████| 359/359 [00:16<00:00, 21.23it/s]


Train Loss: 0.4867, Time: 3.42s, Pixels: 1552269312


100%|██████████| 40/40 [00:00<00:00, 40.39it/s]


Loss: 0.4571
Micro F1: 0.6432, Macro F1: 0.3644
Miss Rate: 0.4080, FPPI: 1.8339
Overall Acc: 0.7801, Precision: 0.7039, Recall: 0.5920, F1: 0.6432
+-------------------------------+---------+
|            Metric             |  Value  |
+-------------------------------+---------+
|             Epoch             |    0    |
| Maximum GPU Memory Usage (MB) | 3744.56 |
| Maximum CPU Memory Usage (MB) | 9403.96 |
+-------------------------------+---------+
Epoch 2


100%|██████████| 359/359 [00:16<00:00, 21.50it/s]


Train Loss: 0.4427, Time: 3.15s, Pixels: 1552269312


100%|██████████| 40/40 [00:00<00:00, 40.65it/s]


Loss: 0.4411
Micro F1: 0.6648, Macro F1: 0.4176
Miss Rate: 0.3831, FPPI: 1.7618
Overall Acc: 0.7917, Precision: 0.7206, Recall: 0.6169, F1: 0.6648
+-------------------------------+---------+
|            Metric             |  Value  |
+-------------------------------+---------+
|             Epoch             |    1    |
| Maximum GPU Memory Usage (MB) | 3724.88 |
| Maximum CPU Memory Usage (MB) | 9203.69 |
+-------------------------------+---------+
Epoch 3


100%|██████████| 359/359 [00:16<00:00, 21.44it/s]


Train Loss: 0.4173, Time: 3.13s, Pixels: 1552269312


100%|██████████| 40/40 [00:00<00:00, 40.33it/s]


Loss: 0.4339
Micro F1: 0.6742, Macro F1: 0.4549
Miss Rate: 0.3682, FPPI: 1.7868
Overall Acc: 0.7955, Precision: 0.7226, Recall: 0.6318, F1: 0.6742
+-------------------------------+---------+
|            Metric             |  Value  |
+-------------------------------+---------+
|             Epoch             |    2    |
| Maximum GPU Memory Usage (MB) | 3711.94 |
| Maximum CPU Memory Usage (MB) | 9225.93 |
+-------------------------------+---------+
Epoch 4


100%|██████████| 359/359 [00:16<00:00, 21.42it/s]


Train Loss: 0.3927, Time: 3.17s, Pixels: 1552269312


100%|██████████| 40/40 [00:01<00:00, 36.61it/s]


Loss: 0.4315
Micro F1: 0.6791, Macro F1: 0.4636
Miss Rate: 0.3709, FPPI: 1.6473
Overall Acc: 0.8009, Precision: 0.7377, Recall: 0.6291, F1: 0.6791
+-------------------------------+---------+
|            Metric             |  Value  |
+-------------------------------+---------+
|             Epoch             |    3    |
| Maximum GPU Memory Usage (MB) | 3719.69 |
| Maximum CPU Memory Usage (MB) | 9154.45 |
+-------------------------------+---------+
Epoch 5


100%|██████████| 359/359 [00:16<00:00, 21.43it/s]


Train Loss: 0.3676, Time: 3.17s, Pixels: 1552269312


100%|██████████| 40/40 [00:01<00:00, 39.18it/s]


Loss: 0.4349
Micro F1: 0.6887, Macro F1: 0.4973
Miss Rate: 0.3384, FPPI: 1.9122
Overall Acc: 0.7998, Precision: 0.7182, Recall: 0.6616, F1: 0.6887
+-------------------------------+---------+
|            Metric             |  Value  |
+-------------------------------+---------+
|             Epoch             |    4    |
| Maximum GPU Memory Usage (MB) | 3734.69 |
| Maximum CPU Memory Usage (MB) | 9161.47 |
+-------------------------------+---------+
Training complete.


100%|██████████| 100/100 [00:02<00:00, 44.78it/s]

Loss: 0.4327
Micro F1: 0.6910, Macro F1: 0.4988
Miss Rate: 0.3335, FPPI: 1.9204
Overall Acc: 0.8019, Precision: 0.7173, Recall: 0.6665, F1: 0.6910



