<a href="https://colab.research.google.com/github/MattPlatt/PLATTLINE_WORKING/blob/main/PlattLine.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#connect to Drive
from google.colab import drive
drive.mount('/content/drive')



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Custom DataSet Class for Training

import os
import json
import torch
from torch.utils.data import Dataset
from PIL import Image

class COCODataset(Dataset):
    def __init__(self, image_dir, annotation_file, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        # Load JSON annotations
        with open(annotation_file, 'r') as f:
            self.coco = json.load(f)

        # Create a mapping from image ID to file name
        self.image_id_to_filename = {img['id']: img['file_name'] for img in self.coco['images']}

        # Create a mapping from image ID to annotations
        self.image_id_to_annotations = {}
        for ann in self.coco['annotations']:
            if ann['image_id'] not in self.image_id_to_annotations:
                self.image_id_to_annotations[ann['image_id']] = []
            self.image_id_to_annotations[ann['image_id']].append(ann)

        # List of image IDs
        self.image_ids = list(self.image_id_to_filename.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        # Get the image ID and file name
        image_id = self.image_ids[idx]
        file_name = self.image_id_to_filename[image_id]
        image_path = os.path.join(self.image_dir, file_name)

        # Load the image
        image = Image.open(image_path).convert("RGB")  # Convert to RGB (if grayscale, modify as needed)

        # Get annotations for the image
        annotations = self.image_id_to_annotations.get(image_id, [])

        # Extract bounding boxes and labels
        boxes = []
        labels = []
        for ann in annotations:
            boxes.append(ann['bbox'])  # COCO format: [x, y, width, height]
            labels.append(ann['category_id'])  # Class ID

        # Convert to tensors
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        target = {"boxes": boxes, "labels": labels}

        # Apply transforms if provided
        if self.transform:
            image = self.transform(image)

        return image, target


In [3]:
# Specify Dataset Locations for Train and VAl

train_image_dir = "/content/drive/MyDrive/S2gen/data/train/easy/sliced_images"
train_annotation_file = "/content/drive/MyDrive/S2gen/data/train/easy/sliced_coco_annotations.json"

val_image_dir = "/content/drive/MyDrive/S2gen/data/val/easy/sliced_images"
val_annotation_file = "/content/drive/MyDrive/S2gen/data/val/easy/sliced_coco_annotations.json"


In [4]:
# Data Augmentor and Data Loader for preprocessing

import torch
from torchvision import transforms
from torch.utils.data import DataLoader

# Custom collation function to handle varying number of objects per image
def collate_fn(batch):
    images, targets = zip(*batch)
    images = torch.stack(images)  # Stack images into a batch tensor
    return images, targets  # Targets remain as a tuple of dictionaries

# Define transforms (resize, normalize, etc.)
transform = transforms.Compose([
    transforms.Resize((550, 435)),  # Resize images to the model's input size
    transforms.ToTensor(),          # Convert to PyTorch tensor
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize grayscale image to [-1, 1]
])

# Create datasets
train_dataset = COCODataset(image_dir=train_image_dir, annotation_file=train_annotation_file, transform=transform)
val_dataset = COCODataset(image_dir=val_image_dir, annotation_file=val_annotation_file, transform=transform)

# Create data loaders with custom collation function
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)



In [5]:
# Backbone of Model. This is the NN!

import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomBackbone(nn.Module):
    def __init__(self):
        super(CustomBackbone, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # Input channels = 3 (RGB)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # Output: 550x435 -> 550x435
        x = self.pool(x)           # Output: 550x435 -> 275x217
        x = F.relu(self.conv2(x))  # Output: 275x217 -> 275x217
        x = self.pool(x)           # Output: 275x217 -> 137x108
        return x


In [6]:
# Detection Head Class The RPN is responsible for generating region proposals,
# which are potential areas in the image that may contain objects. It uses anchor boxes of
#various sizes and aspect ratios, evaluates them to predict Objectness Score and Bounding Box

class RegionProposalNetwork(nn.Module):
    def __init__(self, in_channels, num_anchors):
        super(RegionProposalNetwork, self).__init__()
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=3, stride=1, padding=1)
        self.cls_logits = nn.Conv2d(128, num_anchors, kernel_size=1, stride=1)  # Objectness scores
        self.bbox_pred = nn.Conv2d(128, num_anchors * 4, kernel_size=1, stride=1)  # Box deltas

    def forward(self, x):
        x = F.relu(self.conv(x))              # Output: 137x108x128
        logits = self.cls_logits(x)          # Output: 137x108x(num_anchors)
        bbox_deltas = self.bbox_pred(x)      # Output: 137x108x(num_anchors * 4)
        return logits, bbox_deltas


In [11]:
# The Detection Head processes the refined proposals from the RPN. It:

# Classifies Each Proposal: Assigns a specific class label (e.g., "Cables") or background.
# Refines Bounding Boxes: Further adjusts the proposal boxes to fit objects more precisely.

class DetectionHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(DetectionHead, self).__init__()
        self.fc1 = nn.Linear(473472, 256)  # Adjust to match the new flattened size
        self.fc2 = nn.Linear(256, 256)
        self.cls_score = nn.Linear(256, num_classes)  # Class scores
        self.bbox_pred = nn.Linear(256, num_classes * 4)  # Box deltas

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)
        return scores, bbox_deltas


In [8]:
class PlattLine(nn.Module):
    def __init__(self, num_classes, num_anchors):
        super(PlattLine, self).__init__()
        self.backbone = CustomBackbone()  # Updated backbone
        self.rpn = RegionProposalNetwork(in_channels=32, num_anchors=num_anchors)  # Match backbone output channels
        self.detection_head = DetectionHead(in_channels=472896, num_classes=num_classes)  # Updated flattened size

    def forward(self, images):
        # Feature extraction
        feature_map = self.backbone(images)  # Output: [8, 32, 137, 108]

        # Region proposals
        rpn_logits, rpn_bbox_deltas = self.rpn(feature_map)

        # Flatten feature map for detection head
        flattened_features = torch.flatten(feature_map, start_dim=1)
        detection_scores, detection_bbox_deltas = self.detection_head(flattened_features)

        return rpn_logits, rpn_bbox_deltas, detection_scores, detection_bbox_deltas


In [14]:
# TRAIN THE MODEL, PRINT THE LOSSES, SAVE THE WEIGHTS, PLOT THE RESULTS

import torch.optim as optim
import matplotlib.pyplot as plt
import time

# Initialize model, optimizer, and loss functions
num_classes = 1  # One class ("Cables") + background
num_anchors = 9  # 3 scales x 3 aspect ratios
model = PlattLine(num_classes=num_classes, num_anchors=num_anchors)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
classification_loss = torch.nn.CrossEntropyLoss()
regression_loss = torch.nn.SmoothL1Loss()

# Training loop
num_epochs = 10
train_losses = []  # To store total loss per epoch
rpn_cls_losses = []  # To store RPN classification loss
rpn_reg_losses = []  # To store RPN regression loss
det_cls_losses = []  # To store detection classification loss
det_reg_losses = []  # To store detection regression loss

# Training loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    epoch_rpn_cls_loss = 0
    epoch_rpn_reg_loss = 0
    epoch_det_cls_loss = 0
    epoch_det_reg_loss = 0

    start_time = time.time()  # Track time for the epoch

    print(f"\n[Epoch {epoch+1}/{num_epochs}] Starting training...")

    for batch_idx, (images, targets) in enumerate(train_loader):
        images = images.to(device)  # Move images to GPU/CPU directly
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]  # Move annotations to device

        # Forward pass
        rpn_logits, rpn_bbox_deltas, detection_scores, detection_bbox_deltas = model(images)

        print(f"rpn_logits shape: {rpn_logits.shape}")
        print(f"rpn_bbox_deltas shape: {rpn_bbox_deltas.shape}")
        print(f"detection_scores shape: {detection_scores.shape}")
        print(f"detection_bbox_deltas shape: {detection_bbox_deltas.shape}")
        print(f"Ground truth labels shape: {torch.cat([t['labels'] for t in targets]).shape}")
        print(f"Ground truth boxes shape: {torch.cat([t['boxes'] for t in targets]).shape}")


        # Compute losses
        rpn_cls_loss = classification_loss(rpn_logits, torch.cat([t['labels'] for t in targets]))
        rpn_reg_loss = regression_loss(rpn_bbox_deltas, torch.cat([t['boxes'] for t in targets]))
        det_cls_loss = classification_loss(detection_scores, torch.cat([t['labels'] for t in targets]))
        det_reg_loss = regression_loss(detection_bbox_deltas, torch.cat([t['boxes'] for t in targets]))

        loss = rpn_cls_loss + rpn_reg_loss + det_cls_loss + det_reg_loss

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update epoch totals
        epoch_loss += loss.item()
        epoch_rpn_cls_loss += rpn_cls_loss.item()
        epoch_rpn_reg_loss += rpn_reg_loss.item()
        epoch_det_cls_loss += det_cls_loss.item()
        epoch_det_reg_loss += det_reg_loss.item()

        # Print per-batch loss (optional, for detailed monitoring)
        print(
            f"Batch {batch_idx + 1}/{len(train_loader)} | "
            f"Total Loss: {loss.item():.4f} | "
            f"RPN (Cls: {rpn_cls_loss.item():.4f}, Reg: {rpn_reg_loss.item():.4f}) | "
            f"Det (Cls: {det_cls_loss.item():.4f}, Reg: {det_reg_loss.item():.4f})"
        )

    # Track losses for the epoch
    train_losses.append(epoch_loss)
    rpn_cls_losses.append(epoch_rpn_cls_loss)
    rpn_reg_losses.append(epoch_rpn_reg_loss)
    det_cls_losses.append(epoch_det_cls_loss)
    det_reg_losses.append(epoch_det_reg_loss)

    end_time = time.time()  # End epoch timer

    # Print concise epoch summary
    print(
        f"Epoch {epoch + 1}/{num_epochs} | "
        f"Total Loss: {epoch_loss:.4f} | "
        f"RPN (Cls: {epoch_rpn_cls_loss:.4f}, Reg: {epoch_rpn_reg_loss:.4f}) | "
        f"Det (Cls: {epoch_det_cls_loss:.4f}, Reg: {epoch_det_reg_loss:.4f}) | "
        f"Time: {end_time - start_time:.2f}s"
    )


# Save the model weights after all epochs are completed
torch.save(model.state_dict(), "PlattLine_final.pth")
print("\nTraining complete. Model weights saved to 'PlattLine_final.pth'")

# Plot individual loss components
plt.figure(figsize=(12, 6))
plt.plot(range(1, num_epochs + 1), train_losses, label="Total Loss")
plt.plot(range(1, num_epochs + 1), rpn_cls_losses, label="RPN Classification Loss")
plt.plot(range(1, num_epochs + 1), rpn_reg_losses, label="RPN Regression Loss")
plt.plot(range(1, num_epochs + 1), det_cls_losses, label="Detection Classification Loss")
plt.plot(range(1, num_epochs + 1), det_reg_losses, label="Detection Regression Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Loss Components Over Epochs")
plt.legend()
plt.show()



[Epoch 1/10] Starting training...
rpn_logits shape: torch.Size([8, 9, 137, 108])
rpn_bbox_deltas shape: torch.Size([8, 36, 137, 108])
detection_scores shape: torch.Size([8, 1])
detection_bbox_deltas shape: torch.Size([8, 4])
Ground truth labels shape: torch.Size([87])
Ground truth boxes shape: torch.Size([87, 4])


ValueError: Expected input batch_size (8) to match target batch_size (87).

In [None]:
# Plot predictions

def plot_predictions(image, ground_truth, predictions):
    """
    Visualize ground truth and predicted bounding boxes on the image.
    """
    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.imshow(image.permute(1, 2, 0).cpu().numpy())  # Convert tensor to NumPy for plotting

    # Plot ground truth
    for bbox in ground_truth["boxes"]:
        x, y, w, h = bbox
        rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor='g', facecolor='none')
        ax.add_patch(rect)
        ax.text(x, y - 5, "GT: Cables", color='green', fontsize=12)

    # Plot predictions
    for bbox, score in zip(predictions["boxes"], predictions["scores"]):
        x, y, w, h = bbox
        rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        ax.text(x, y + h + 5, f"Pred: {score:.2f}", color='red', fontsize=12)

    plt.show()


In [None]:
# evaluate model on Validation data.

model.eval()

for images, targets in val_loader:
    images = [img.to(device) for img in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

    # Forward pass
    with torch.no_grad():
        rpn_logits, rpn_bbox_deltas, detection_scores, detection_bbox_deltas = model(images)

    # Convert outputs to usable predictions
    predictions = {
        "boxes": detection_bbox_deltas[0].cpu().numpy(),  # Placeholder for processed boxes
        "scores": torch.softmax(detection_scores[0], dim=1)[:, 1].cpu().numpy(),  # Confidence for "Cables"
    }

    # Visualize the first image in the batch
    plot_predictions(images[0].cpu(), targets[0], predictions)
    break  # Only visualize one batch for now


In [None]:
# plot precision and recall

from sklearn.metrics import precision_score, recall_score

# Assume ground_truths and predictions are available in the following format:
# ground_truths = [{'boxes': [[x1, y1, w1, h1], ...], 'labels': [class_id, ...]}, ...]
# predictions = [{'boxes': [[x1, y1, w1, h1], ...], 'labels': [class_id, ...], 'scores': [conf, ...]}, ...]

# Flatten ground truth and predicted labels
all_gt_labels = []
all_pred_labels = []

for gt, pred in zip(ground_truths, predictions):
    all_gt_labels.extend(gt["labels"])
    all_pred_labels.extend(pred["labels"])

# Compute precision and recall
precision = precision_score(all_gt_labels, all_pred_labels, average="macro")
recall = recall_score(all_gt_labels, all_pred_labels, average="macro")

print(f"Precision: {precision:.4f}, Recall: {recall:.4f}")
