In [11]:
import torch
import torch.nn as nn
import torchvision.models as models

class SimpleDetector(nn.Module):
    def __init__(self, num_classes=7, max_objects=7):
        super(SimpleDetector, self).__init__()
        self.num_classes = num_classes
        self.max_objects = max_objects

        # Use a pre-trained ResNet18 as the feature extractor
        resnet = models.resnet18(pretrained=True)
        modules = list(resnet.children())[:-1]  # Remove the last classification layer
        self.feature_extractor = nn.Sequential(*modules)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, self.max_objects * (4 + self.num_classes))  # [B, 7 * (4+7)] = [B, 77]
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.feature_extractor(x)
        x = self.classifier(x)
        x = x.view(batch_size, self.max_objects, 4 + self.num_classes)  # [B, 7, 11]
        return x


In [12]:
import os
import json
from PIL import Image
import torch
from torch.utils.data import Dataset

# Converts [x_min, y_min, x_max, y_max] → [x_center, y_center, width, height] (normalized)
def convert_bbox(x_min, y_min, x_max, y_max, img_width, img_height):
    x_center = (x_min + x_max) / 2 / img_width
    y_center = (y_min + y_max) / 2 / img_height
    width = (x_max - x_min) / img_width
    height = (y_max - y_min) / img_height
    return [x_center, y_center, width, height]

class CollegeFacilitiesDataset(Dataset):
    def __init__(self, annotations_file, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        # Load annotations
        with open(annotations_file, 'r') as f:
            self.annotations = json.load(f)

        # Extract unique class labels dynamically
        all_classes = set()
        for item in self.annotations:
            for obj in item['objects']:
                all_classes.add(obj['class_label'])
        
        self.classes = sorted(list(all_classes))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        print(f"Detected classes: {self.classes}")

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        record = self.annotations[idx]
        img_path = os.path.join(self.image_dir, record['filename'])
        image = Image.open(img_path).convert("RGB")
        img_width, img_height = image.size

        bboxes = []
        class_ids = []

        for obj in record['objects']:
            bbox = obj['bounding_box']
            bbox_norm = convert_bbox(
                bbox['x_min'], bbox['y_min'], bbox['x_max'], bbox['y_max'], img_width, img_height
            )
            bboxes.append(bbox_norm)
            class_ids.append(self.class_to_idx[obj['class_label']])

        bboxes = torch.tensor(bboxes, dtype=torch.float32)
        class_ids = torch.tensor(class_ids, dtype=torch.int64)

        if self.transform:
            image = self.transform(image)

        return image, bboxes, class_ids


In [13]:
# train.py
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset import CollegeFacilitiesDataset
from model import SimpleDetector

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CLASS_NAMES = [
    # 'chair', 'bench', 'desk', 'table', 'sofa', 'podium', 'cupboard', 'whiteboard', 'blackboard', 'notice board',
    # 'projector', 'screen', 'computer', 'monitor', 'keyboard', 'mouse', 'CPU', 'smart board', 'chalk', 'duster',
    # 'lab table', 'test tube', 'beaker', 'microscope', 'chemical bottle', 'lab coat', 'fire extinguisher', 'fume hood',
    # 'fan', 'AC', 'switchboard', 'tube light', 'window', 'curtain', 'door', 'clock', 'dustbin',
    # 'wash basin', 'toilet seat', 'urinal', 'mirror', 'soap dispenser', 'hand dryer', 'water cooler', 'bucket', 'mug',
    # 'bookshelf', 'book', 'newspaper stand', 'magazine rack',
    # 'CCTV camera', 'security guard', 'fire alarm', 'student', 'teacher'
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
  "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
  "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
  "backpack", "umbrella", "handbag", "tie", "suitcase",
  "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
  "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
  "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
  "chair", "couch", "potted plant", "bed", "dining table", "toilet",
  "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
  "microwave", "oven", "toaster", "sink", "refrigerator",
  "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
]

NUM_CLASSES = len(CLASS_NAMES)
MAX_OBJECTS = 20  

model = SimpleDetector(num_classes=NUM_CLASSES, max_objects=MAX_OBJECTS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
bbox_loss_fn = torch.nn.MSELoss()
cls_loss_fn = torch.nn.CrossEntropyLoss()

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

train_dataset = CollegeFacilitiesDataset(
   annotations_file="data/train/coco2017/annotations/labels.json",
    image_dir="data/train/coco2017/train2017",
    transform=transform
)

def collate_fn(batch):
    images, all_bboxes, all_class_ids = [], [], []

    for image, bboxes, class_ids in batch:
        images.append(image)

        n = bboxes.shape[0]
        if n < MAX_OBJECTS:
            bboxes = torch.cat([bboxes, torch.zeros((MAX_OBJECTS - n, 4))], dim=0)
            class_ids = torch.cat([class_ids, torch.zeros(MAX_OBJECTS - n, dtype=torch.long)], dim=0)
        elif n > MAX_OBJECTS:
            bboxes = bboxes[:MAX_OBJECTS]
            class_ids = class_ids[:MAX_OBJECTS]

        all_bboxes.append(bboxes)
        all_class_ids.append(class_ids)

    return torch.stack(images), torch.stack(all_bboxes), torch.stack(all_class_ids)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# Training loop
for epoch in range(500):
    model.train()
    total_loss = 0.0

    for images, bboxes, class_ids in train_loader:
        images = images.to(device)
        bboxes = bboxes.to(device)
        class_ids = class_ids.to(device)

        preds = model(images)

        pred_bboxes = preds[..., :4]
        pred_class_logits = preds[..., 4:]

        loss_bbox = bbox_loss_fn(pred_bboxes, bboxes)
        loss_cls = cls_loss_fn(pred_class_logits.view(-1, NUM_CLASSES), class_ids.view(-1))

        loss = loss_bbox + loss_cls

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Loss: {total_loss:.4f}")

torch.save(model.state_dict(), "simple_detector.pth")


Detected classes: ['vase with flowers']
Epoch 1 | Loss: 4.4904
Epoch 2 | Loss: 4.1011
Epoch 3 | Loss: 3.7320
Epoch 4 | Loss: 3.2816
Epoch 5 | Loss: 2.7164
Epoch 6 | Loss: 2.0891
Epoch 7 | Loss: 1.4637
Epoch 8 | Loss: 0.9080
Epoch 9 | Loss: 0.4847
Epoch 10 | Loss: 0.2342
Epoch 11 | Loss: 0.1227
Epoch 12 | Loss: 0.0775
Epoch 13 | Loss: 0.0616
Epoch 14 | Loss: 0.0644
Epoch 15 | Loss: 0.0693
Epoch 16 | Loss: 0.0649
Epoch 17 | Loss: 0.0556
Epoch 18 | Loss: 0.0476
Epoch 19 | Loss: 0.0406
Epoch 20 | Loss: 0.0320
Epoch 21 | Loss: 0.0258
Epoch 22 | Loss: 0.0275
Epoch 23 | Loss: 0.0330
Epoch 24 | Loss: 0.0314
Epoch 25 | Loss: 0.0244
Epoch 26 | Loss: 0.0194
Epoch 27 | Loss: 0.0171
Epoch 28 | Loss: 0.0149
Epoch 29 | Loss: 0.0142
Epoch 30 | Loss: 0.0148
Epoch 31 | Loss: 0.0139
Epoch 32 | Loss: 0.0116
Epoch 33 | Loss: 0.0094
Epoch 34 | Loss: 0.0085
Epoch 35 | Loss: 0.0084
Epoch 36 | Loss: 0.0081
Epoch 37 | Loss: 0.0069
Epoch 38 | Loss: 0.0053
Epoch 39 | Loss: 0.0041
Epoch 40 | Loss: 0.0042
Epoch 41 

In [15]:
import torch
from torchvision import transforms
from PIL import Image, ImageDraw
from model import SimpleDetector

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CLASS_NAMES = [
    # 'chair', 'bench', 'desk', 'table', 'sofa', 'podium', 'cupboard', 'whiteboard', 'blackboard', 'notice board',
    # 'projector', 'screen', 'computer', 'monitor', 'keyboard', 'mouse', 'CPU', 'smart board', 'chalk', 'duster',
    # 'lab table', 'test tube', 'beaker', 'microscope', 'chemical bottle', 'lab coat', 'fire extinguisher', 'fume hood',
    # 'fan', 'AC', 'switchboard', 'tube light', 'window', 'curtain', 'door', 'clock', 'dustbin',
    # 'wash basin', 'toilet seat', 'urinal', 'mirror', 'soap dispenser', 'hand dryer', 'water cooler', 'bucket', 'mug',
    # 'bookshelf', 'book', 'newspaper stand', 'magazine rack',
    # 'CCTV camera', 'security guard', 'fire alarm', 'student', 'teacher'
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
  "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
  "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
  "backpack", "umbrella", "handbag", "tie", "suitcase",
  "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
  "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
  "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
  "chair", "couch", "potted plant", "bed", "dining table", "toilet",
  "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
  "microwave", "oven", "toaster", "sink", "refrigerator",
  "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
]

NUM_CLASSES = len(CLASS_NAMES)
MAX_OBJECTS = 20 

model = SimpleDetector(num_classes=NUM_CLASSES, max_objects=MAX_OBJECTS)
model.load_state_dict(torch.load("simple_detector.pth", map_location=device))
model.to(device)
model.eval()

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# def denormalize_bbox(box, w, h):
#     x_c, y_c, bw, bh = box.tolist()

#     x1 = max(0, (x_c - bw / 2) * w)
#     y1 = max(0, (y_c - bh / 2) * h)
#     x2 = min(w, (x_c + bw / 2) * w)
#     y2 = min(h, (y_c + bh / 2) * h)

#     # Ensure y2 >= y1 and x2 >= x1
#     x1, x2 = sorted([x1, x2])
#     y1, y2 = sorted([y1, y2])

#     return [x1, y1, x2, y2]
def denormalize_bbox(box, img_w, img_h):
    x_c, y_c, w, h = box.tolist()
    x1 = (x_c - w/2) * img_w
    y1 = (y_c - h/2) * img_h
    x2 = (x_c + w/2) * img_w
    y2 = (y_c + h/2) * img_h

    # Ensure box coordinates are valid
    x1, x2 = sorted([x1, x2])
    y1, y2 = sorted([y1, y2])
    return [x1, y1, x2, y2]


def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    orig_w, orig_h = image.size
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)[0]  # [MAX_OBJECTS, 4+NUM_CLASSES]
        bboxes = output[:, :4]
        class_logits = output[:, 4:]
        class_preds = class_logits.argmax(dim=1)
        scores = class_logits.softmax(dim=1).max(dim=1)[0]

    draw = ImageDraw.Draw(image)

    for i in range(MAX_OBJECTS):
        if scores[i] < 0.5:
            continue
        label = CLASS_NAMES[class_preds[i]]
        box = denormalize_bbox(bboxes[i].cpu(), orig_w, orig_h)
        draw.rectangle(box, outline="red", width=2)
        draw.text((box[0], box[1]), label, fill="yellow")

    image.show()

# Example usage
predict("000000000030.jpg")