In [11]:
#!/usr/bin/env python
# coding: utf-8

# ----------------------------
#  IMPORTS
# ----------------------------
import os
import shutil
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision.models.detection.image_list import ImageList
from torchsummary import summary
import io
from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork, AnchorGenerator
import wandb

# ----------------------------
#  DEVICE
# ----------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------------------
#  PATHS & CONSTANTS
# ----------------------------
IMAGE_DIR = '../dataset/images'
ANNOTATIONS_DIR = '../dataset/annotations'
TARGET_SIZE = (224, 224)
BATCH_SIZE = 8
NUM_EPOCHS = 5

# ----------------------------
#  INIT WandB
# ----------------------------
wandb.init(
    project="rpn-training",
    name="rpn_resnet18_run",
    config={
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "lr": 0.001,
        "backbone": "resnet18",
        "input_size": "224x224"
    },
    settings=wandb.Settings(init_timeout=120)
)

# ----------------------------
#  LOAD DATA
# ----------------------------
class_names = [d for d in os.listdir(IMAGE_DIR) if os.path.isdir(os.path.join(IMAGE_DIR, d))]
csv_files = [f for f in os.listdir(ANNOTATIONS_DIR) if f.endswith('.csv')]
label_map = {name: i + 1 for i, name in enumerate(class_names)}

dataset = []

for i in range(len(class_names)):
    class_name = class_names[i]
    class_dir = os.path.join(IMAGE_DIR, class_name)
    csv_file_name = csv_files[i]
    csv_path = os.path.join(ANNOTATIONS_DIR, csv_file_name)
    df_annotations = pd.read_csv(csv_path)

    for image_name in os.listdir(class_dir):
        image_path = os.path.join(class_dir, image_name)
        image = cv2.imread(image_path)
        if image is None:
            continue

        h, w, _ = image.shape
        row = df_annotations[df_annotations['image_name'] == image_name]
        if row.empty:
            continue

        ann = row.iloc[0, 1:].tolist()
        ann[0] = (ann[0] / w) * 224
        ann[1] = (ann[1] / h) * 224
        ann[2] = (ann[2] / w) * 224
        ann[3] = (ann[3] / h) * 224

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, TARGET_SIZE)

        image_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0
        label_tensor = torch.tensor([class_names.index(class_name)], dtype=torch.int64)
        ann_tensor = torch.tensor([ann], dtype=torch.float32)

        target = {'boxes': ann_tensor, 'labels': label_tensor}
        dataset.append((image_tensor, target))

def collate_fn(batch):
    return tuple(zip(*batch))

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

# ----------------------------
#  BACKBONE (ResNet18)
# ----------------------------
resnet_model = torchvision.models.resnet18()
backbone = torch.nn.Sequential(*list(resnet_model.children())[:-2])
backbone.out_channels = 512

# ----------------------------
#  SAVE MODEL SUMMARY
# ----------------------------
# Make WandB output dir
wandb_output_dir = os.path.join(wandb.run.dir, "model_outputs")
os.makedirs(wandb_output_dir, exist_ok=True)

summary_path = os.path.join(wandb_output_dir, "model_summary.txt")
buffer = io.StringIO()
sys_stdout = sys.stdout
sys.stdout = buffer
summary(backbone, (3, 224, 224))
sys.stdout = sys_stdout

with open(summary_path, "w") as f:
    f.write(buffer.getvalue())

# Log summary as artifact
artifact = wandb.Artifact("model_summary", type="summary")
artifact.add_file(summary_path)
wandb.log_artifact(artifact)
print("Model summary saved and logged to WandB:", summary_path)

# Freeze backbone parameters
for param in backbone.parameters():
    param.requires_grad = False

# ----------------------------
#  RPN MODEL
# ----------------------------
anchor_generator = AnchorGenerator(
    sizes=((32, 64, 128),),
    aspect_ratios=((0.5, 1.0, 2.0),)
)
in_channels = backbone.out_channels
num_anchors = anchor_generator.num_anchors_per_location()[0]
rpn_head = RPNHead(in_channels=in_channels, num_anchors=num_anchors)

rpn_model = RegionProposalNetwork(
    anchor_generator,
    rpn_head,
    fg_iou_thresh=0.7,
    bg_iou_thresh=0.3,
    batch_size_per_image=256,
    positive_fraction=0.5,
    pre_nms_top_n={'training': 2000, 'testing': 1000},
    post_nms_top_n={'training': 1000, 'testing': 500},
    nms_thresh=0.7
)

optimizer = torch.optim.Adam(rpn_model.parameters(), lr=0.001)

# ----------------------------
#  TRAINING LOOP
# ----------------------------
log_path = os.path.join(wandb_output_dir, "training_log.txt")

for epoch in range(NUM_EPOCHS):
    epoch_losses = []

    for images, targets in dataloader:
        optimizer.zero_grad()
        images_gpu = torch.stack([img.to(DEVICE) for img in images])
        targets_gpu = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        with torch.no_grad():
            features = backbone(images_gpu)

        image_list = ImageList(images_gpu, [img.shape[-2:] for img in images])
        _, loss_dict = rpn_model(image_list, {'0': features}, targets_gpu)
        loss = loss_dict['loss_objectness'] + loss_dict['loss_rpn_box_reg']

        if torch.isfinite(loss):
            epoch_losses.append(loss.item())
            loss.backward()
            optimizer.step()

    mean_loss = sum(epoch_losses) / len(epoch_losses)
    log_line = f"Epoch {epoch+1} | Loss: {mean_loss:.4f}\n"
    print(log_line.strip())

    with open(log_path, "a") as f:
        f.write(log_line)

    wandb.log({"epoch": epoch + 1, "loss": mean_loss})

    # Save checkpoint and log as artifact
    checkpoint_path = os.path.join(wandb_output_dir, f"rpn_epoch_{epoch+1}.pth")
    torch.save(rpn_model.state_dict(), checkpoint_path)
    ckpt_artifact = wandb.Artifact(f"rpn_checkpoint_epoch_{epoch+1}", type="model")
    ckpt_artifact.add_file(checkpoint_path)
    wandb.log_artifact(ckpt_artifact)

# ----------------------------
#  VISUALIZATION FUNCTION
# ----------------------------
def visualize_rpn_proposals(image_path, rpn_model_trained, backbone_model_trained):
    proposals_dir = os.path.join(wandb.run.dir, "rpn_proposals")
    os.makedirs(proposals_dir, exist_ok=True)

    rpn_model_trained.to(DEVICE)
    backbone_model_trained.to(DEVICE)
    rpn_model_trained.eval()
    backbone_model_trained.eval()

    img_bgr = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    image_rgb_resized = cv2.resize(img_rgb, TARGET_SIZE)
    img_tensor = (torch.tensor(image_rgb_resized, dtype=torch.float32).permute(2, 0, 1) / 255.0).to(DEVICE)

    with torch.no_grad():
        features = backbone_model_trained(img_tensor.unsqueeze(0))
        image_list = ImageList(img_tensor.unsqueeze(0), [tuple(img_tensor.shape[-2:])])
        proposals, _ = rpn_model_trained(image_list, {'0': features})

    top_proposals = proposals[0][:5].cpu().numpy()
    img_display = image_rgb_resized.copy()

    for i, box in enumerate(top_proposals):
        x1, y1, x2, y2 = map(int, box)
        color = (0,255,0) if i==0 else ((255,255,0) if i < 3 else (255,0,0))
        width = 3 if i==0 else (2 if i<3 else 1)
        cv2.rectangle(img_display, (x1, y1), (x2, y2), color, width)

    base_name = os.path.basename(image_path).split('.')[0]
    output_path = os.path.join(proposals_dir, f"{base_name}_rpn_proposals.png")
    plt.figure(figsize=(8, 8))
    plt.imshow(img_display)
    plt.axis('off')
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.close()

    # Log image as artifact
    img_artifact = wandb.Artifact(f"rpn_proposals_{base_name}", type="image")
    img_artifact.add_file(output_path)
    wandb.log_artifact(img_artifact)

    print(f"Saved and logged proposal visualization to: {output_path}")

# ----------------------------
#  RUN VISUALIZATION
# ----------------------------
TEST_IMAGE_PATH = '../dataset/images/airplane/image_0002.jpg'
visualize_rpn_proposals(TEST_IMAGE_PATH, rpn_model, backbone)


0,1
epoch,▁
loss,▁

0,1
epoch,1.0
loss,0.04534


Model summary saved and logged to WandB: /home/hossein-simchi/ML/computer-vision/Object detection/Region Proposal Netwrok/notebooks/wandb/run-20251208_101520-dnh70r7b/files/model_outputs/model_summary.txt
Epoch 1 | Loss: 0.0453
Epoch 2 | Loss: 0.0212
Epoch 3 | Loss: 0.0178
Epoch 4 | Loss: 0.0160
Epoch 5 | Loss: 0.0147
Saved and logged proposal visualization to: /home/hossein-simchi/ML/computer-vision/Object detection/Region Proposal Netwrok/notebooks/wandb/run-20251208_101520-dnh70r7b/files/rpn_proposals/image_0002_rpn_proposals.png
