In [None]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torch import nn

# Data locations, relative to the notebook's working directory.
img_dir = "./train/images"  # input images
label_dir = "./train/labels"  # one <image stem>.txt label file per image

# Run on the first CUDA GPU when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
class CoordinateDataset(Dataset):
    """Pairs every image in ``img_dir`` with the same-stem ``.txt`` file in ``label_dir``.

    Only the first line of each label file is read. It must hold whitespace-separated
    numbers; in this notebook they are class_id, x_center, y_center, width, height.
    Any further lines in the label file are ignored.
    """

    def __init__(self, img_dir, label_dir, transform=None, img_extensions=(".png",)):
        """
        Args:
            img_dir (string): Directory with all the images
            label_dir (string): Directory with all the label text files
            transform (callable, optional): Optional transform to be applied on images
            img_extensions (str or iterable of str, optional): Case-sensitive file
                suffixes that count as images. Defaults to ``(".png",)``.
        """
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform

        # str.endswith needs a str or a tuple; normalise lists/sets/etc. into a tuple.
        if isinstance(img_extensions, str):
            suffixes = (img_extensions,)
        else:
            suffixes = tuple(img_extensions)

        # Sorted so dataset indices map to the same files on every run.
        self.img_files = sorted(f for f in os.listdir(img_dir) if f.endswith(suffixes))

    def __len__(self):
        return len(self.img_files)

    @staticmethod
    def _read_label(label_path):
        """Parse the first line of ``label_path`` into a 1-D float32 tensor.

        Values are returned exactly as stored (no normalization).

        Raises:
            ValueError: if the first line is empty/whitespace-only or holds a non-number.
        """
        with open(label_path, "r") as f:
            line = f.readline()
        # split() with no argument tolerates repeated spaces, tabs and the trailing newline;
        # split(" ") would yield empty tokens that float() rejects.
        parts = line.split()
        if not parts:
            raise ValueError(f"Label file {label_path!r} has an empty first line")
        return torch.tensor([float(p) for p in parts], dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(image, coords)`` for the ``idx``-th image in sorted filename order."""
        img_name = self.img_files[idx]
        img_path = os.path.join(self.img_dir, img_name)

        # Label shares the image's stem, with a .txt extension, inside label_dir.
        label_name = os.path.splitext(img_name)[0] + ".txt"
        label_path = os.path.join(self.label_dir, label_name)

        # convert() decodes the full image, so the file handle can be released right away.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        coords = self._read_label(label_path)
        return image, coords

In [None]:
# Define a simple CNN model for coordinate prediction
class CoordinateCNN(nn.Module):
    """Small CNN that regresses a fixed-length coordinate vector from an RGB image.

    Three conv -> ReLU -> 2x2 max-pool stages (3 -> 16 -> 32 -> 64 channels) feed a
    fully connected head. The head's first Linear layer is sized from ``input_size``,
    so the model accepts only images of that spatial size.
    """

    def __init__(self, num_coords=5, input_size=256):  # Default: class_id + 4 coordinates
        """
        Args:
            num_coords (int): Length of the predicted vector (default 5: class_id + 4 coordinates).
            input_size (int or tuple[int, int]): Input height/width; an int means a square
                image. The default 256 reproduces the previous hard-coded ``64 * 32 * 32``
                flatten size, so checkpoints saved from the old model still load.

        Raises:
            ValueError: if ``input_size`` is below 8 in either dimension (the three pools
                would shrink the feature map to nothing).
        """
        super(CoordinateCNN, self).__init__()

        if isinstance(input_size, int):
            in_h = in_w = input_size
        else:
            in_h, in_w = input_size
        # Padded 3x3 convs keep H/W unchanged; each of the three 2x2/stride-2 pools
        # floor-halves them, giving H // 8 and W // 8.
        feat_h, feat_w = in_h // 8, in_w // 8
        if feat_h < 1 or feat_w < 1:
            raise ValueError(f"input_size must be at least 8x8, got {in_h}x{in_w}")

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feat_h * feat_w, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_coords),
        )

    def forward(self, x):
        """Map a batch of images ``(N, 3, H, W)`` to predictions of shape ``(N, num_coords)``."""
        x = self.features(x)
        x = self.classifier(x)
        return x

In [None]:
# Training function
def train_model(model, dataloader, criterion, optimizer, num_epochs=10, device=None):
    """Train ``model`` in place and print the sample-weighted mean loss of each epoch.

    Args:
        model (nn.Module): Network to train; it is switched to training mode here.
        dataloader (DataLoader): Yields ``(images, coords)`` batches.
        criterion (callable): ``criterion(outputs, coords)`` returning a scalar loss tensor.
        optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
        num_epochs (int): Number of passes over ``dataloader``.
        device (torch.device or str, optional): Where each batch is moved. Defaults to the
            device of the model's first parameter (CPU for a parameter-less model), so
            inputs always land next to the weights.

    Returns:
        list[float]: Mean loss per epoch (``nan`` for an epoch that yielded no batches).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    history = []

    for epoch in range(num_epochs):
        running_loss = 0.0
        seen = 0

        for images, coords in dataloader:
            images = images.to(device)
            coords = coords.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, coords)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Assumes criterion averages over the batch (e.g. reduction="mean"), so weight
            # by batch size to get a per-sample mean across uneven batches.
            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            seen += batch_size

        # Divide by samples actually processed rather than len(dataset): this stays correct
        # with drop_last / custom samplers and avoids ZeroDivisionError on empty data.
        epoch_loss = running_loss / seen if seen else float("nan")
        history.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    return history

In [None]:
if __name__ == "__main__":
    # Imported here so the cell is self-contained; only used to choose num_workers below.
    import multiprocessing

    # Seed torch's RNG (weight init, shuffle order, dropout) so Restart & Run All is reproducible.
    SEED = 42
    torch.manual_seed(SEED)

    # Side length every image is resized to. CoordinateCNN() sizes its first Linear layer
    # for 256x256 inputs by default, so keep the two in sync.
    IMG_SIZE = 256

    # Define transformations (without normalization for images)
    transform = transforms.Compose(
        [
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
        ]
    )

    # Create dataset and dataloader
    dataset = CoordinateDataset(
        img_dir=img_dir, label_dir=label_dir, transform=transform
    )
    # DataLoader workers must unpickle CoordinateDataset. Under the "spawn"/"forkserver"
    # start methods (spawn is the default on Windows and macOS) a worker cannot import a
    # class defined in a notebook's __main__ and crashes. Only use worker processes when
    # they are forked; otherwise load batches in the main process.
    num_workers = 4 if multiprocessing.get_start_method() == "fork" else 0
    dataloader = DataLoader(
        dataset, batch_size=32, shuffle=True, num_workers=num_workers
    )

    # Create and move model to device
    model = CoordinateCNN().to(device)

    # Define loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    train_model(model, dataloader, criterion, optimizer, num_epochs=20)

    # Save the trained model
    torch.save(model.state_dict(), "coordinate_prediction_model.pt")

    # Optional: Visualization of predictions
    def visualize_prediction(image, true_coords, pred_coords):
        """Show ``image`` with the ground-truth box (green) and predicted box (red).

        Args:
            image (Tensor): ``C x H x W`` image tensor, e.g. as produced by ``transform``.
            true_coords (Tensor): 5 values: class_id, x_center, y_center, width, height.
            pred_coords (Tensor): Same layout as ``true_coords``; may come straight from
                the model (autograd history is detached here).

        NOTE(review): box values are multiplied by the image size, so they are assumed
        to be fractions of the image (YOLO-style) — confirm this matches your labels.
        """
        # detach() first: model outputs carry autograd history and .numpy() rejects them.
        img = image.detach().permute(1, 2, 0).cpu().numpy()
        # Use the actual displayed size instead of assuming IMG_SIZE.
        img_h, img_w = img.shape[:2]

        # Extract coordinates (class_id, x_center, y_center, width, height)
        true_class, true_x, true_y, true_w, true_h = true_coords.detach().cpu().numpy()
        pred_class, pred_x, pred_y, pred_w, pred_h = pred_coords.detach().cpu().numpy()

        def draw_box(x_c, y_c, box_w, box_h, style):
            """Draw a center/size box, scaled from image fractions to pixels, with ``style``."""
            x1 = (x_c - box_w / 2) * img_w
            y1 = (y_c - box_h / 2) * img_h
            x2 = (x_c + box_w / 2) * img_w
            y2 = (y_c + box_h / 2) * img_h
            plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], style, linewidth=2)

        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        draw_box(true_x, true_y, true_w, true_h, "g-")  # ground truth
        draw_box(pred_x, pred_y, pred_w, pred_h, "r-")  # prediction

        plt.title(
            f"True class: {int(true_class)}, Pred class: {int(round(pred_class))}"
        )
        plt.show()