In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import os
from torchvision.datasets import VOCDetection
import torchvision
import torchvision.datasets as dset
import matplotlib.pyplot as plt
import matplotlib.patches as patches


In [2]:
def voc_target_to_detection(target):
    objs = target['annotation']['object']
    if not isinstance(objs, list):
        objs = [objs]
    boxes, labels = [], []
    for obj in objs:
        bndbox = obj['bndbox']
        xmin, ymin = float(bndbox['xmin']), float(bndbox['ymin'])
        xmax, ymax = float(bndbox['xmax']), float(bndbox['ymax'])
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(1)  # לפשט: כל האובייקטים = class 1
    return {
        "boxes": torch.tensor(boxes, dtype=torch.float32),
        "labels": torch.tensor(labels, dtype=torch.int64)
    }

def collate_fn(batch):
    images = [item[0] for item in batch]
    targets = [voc_target_to_detection(item[1]) for item in batch]
    return images, targets

project_root = '/Users/ortalhanuna/my-code'
voc_expected = os.path.join(project_root, 'VOCdevkit', 'VOC2012')

if not os.path.isdir(voc_expected):
    print(f"VOC2012 not found at {voc_expected}.")
    raise FileNotFoundError("VOC dataset not found")

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor()
])

train_ds = VOCDetection(root=project_root, year="2012", image_set="train",
                        download=False, transform=transform)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)


In [3]:
class SimpleDetector(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1)
        )
        self.fc_box = nn.Linear(64, 4)   # bbox: x1, y1, x2, y2
        self.fc_class = nn.Linear(64, 2) # class: background / object

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        boxes = self.fc_box(x)
        class_logits = self.fc_class(x)
        return boxes, class_logits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleDetector().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
bbox_loss_fn = nn.SmoothL1Loss()
class_loss_fn = nn.CrossEntropyLoss()

In [None]:
num_epochs = 2
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for imgs, targets in train_loader:
        imgs = torch.stack(imgs).to(device)
        boxes_pred, class_logits = model(imgs)

        # לפשט: נשתמש רק באובייקט הראשון בכל תמונה
        boxes_target = torch.stack([t["boxes"][0] for t in targets]).to(device)
        labels_target = torch.stack([t["labels"][0] for t in targets]).to(device)

        loss_bbox = bbox_loss_fn(boxes_pred, boxes_target)
        loss_class = class_loss_fn(class_logits, labels_target)
        loss = loss_bbox + loss_class

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        print(loss.item())

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}")


224.28736877441406
216.0354766845703
240.70848083496094
257.1263427734375
260.1449279785156
184.2801055908203
223.84193420410156
245.0272979736328
180.079345703125
238.35850524902344
255.6036376953125
234.8205108642578
167.4027557373047
226.9445343017578
254.82354736328125
208.81832885742188
189.5770721435547
215.69094848632812
238.7852325439453
261.0223388671875
185.87332153320312
202.6027069091797
222.59121704101562
248.22369384765625
229.00390625
221.24874877929688
190.29171752929688
177.00836181640625
225.88821411132812
208.6705780029297
211.54945373535156
143.74365234375
197.296142578125
158.43125915527344
177.5292510986328
196.52261352539062
149.5364532470703
189.84524536132812
131.84841918945312
122.22174072265625
145.59303283691406
134.4705810546875
159.24826049804688
180.6016845703125
123.11812591552734
122.66941833496094
91.75692749023438
133.00306701660156
95.82991027832031
114.77603149414062
98.04903411865234
108.56822204589844
101.47492218017578
101.12481689453125
105.1462

In [None]:
model.eval()
img, raw_target = train_ds[50]
target = voc_target_to_detection(raw_target)

with torch.no_grad():
    pred_box, pred_class_logits = model(img.unsqueeze(0).to(device))
    pred_box = pred_box[0].cpu().numpy()
    pred_label = pred_class_logits.argmax(-1).item()

fig, ax = plt.subplots(1)
ax.imshow(img.permute(1,2,0))

true_box = target["boxes"][0].numpy()
rect_true = patches.Rectangle((true_box[0], true_box[1]),
                              true_box[2]-true_box[0],
                              true_box[3]-true_box[1],
                              linewidth=2, edgecolor='g', facecolor='none')
ax.add_patch(rect_true)

rect_pred = patches.Rectangle((pred_box[0], pred_box[1]),
                              pred_box[2]-pred_box[0],
                              pred_box[3]-pred_box[1],
                              linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rect_pred)

ax.set_title(f"Prediction label={pred_label}")
plt.show()
