In [1]:
import torch
from torch import nn
import numpy as np
import pandas as pd
import os, cv2
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
image_folder = "data/train/images"
annotation_folder = "data/train/labels"

In [3]:
class YOLOdataset(Dataset):
    def __init__(self, image_dir, annot_dir, S=4, B=2, C=20, transform=None):
        self.image_dir = image_dir
        self.annot_dir = annot_dir
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith(".jpg")]
        self.S, self.B, self.C = S, B, C
        self.transform = transform

        self.classes = [
            "aeroplane", "bicycle", "bird", "boat", "bottle",
            "bus", "car", "cat", "chair", "cow", "diningtable",
            "dog", "horse", "motorbike", "person", "pottedplant",
            "sheep", "sofa", "train", "tvmonitor"
        ]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_filename = self.image_files[idx]
        image_path = os.path.join(self.image_dir, img_filename)
        annot_path = os.path.join(self.annot_dir, img_filename.replace(".jpg", ".xml"))

        image = Image.open(image_path).convert("RGB")
        boxes, labels = self.parse_voc_xml(annot_path, image.size)
        target = self.encode_target(boxes, labels)

        if self.transform:
            image = self.transform(image)
        return image, target

    def parse_voc_xml(self, xml_path, image_size):
        boxes, labels = [], []
        tree = ET.parse(xml_path)
        root = tree.getroot()
        w, h = image_size

        for obj in root.findall("object"):
            label = obj.find("name").text
            xml_box = obj.find("bndbox")
            xmin = float(xml_box.find("xmin").text)
            ymin = float(xml_box.find("ymin").text)
            xmax = float(xml_box.find("xmax").text)
            ymax = float(xml_box.find("ymax").text)

            x_center = ((xmin + xmax) / 2) / w
            y_center = ((ymin + ymax) / 2) / h
            box_w = (xmax - xmin) / w
            box_h = (ymax - ymin) / h

            boxes.append([x_center, y_center, box_w, box_h])
            labels.append(self.classes.index(label))
        return boxes, labels

    def encode_target(self, boxes, labels):
        S, B, C = self.S, self.B, self.C
        target = torch.zeros((S, S, C + 5 * B))

        for box, label in zip(boxes, labels):
            x, y, w, h = box
            grid_x = min(int(S * x), S - 1)
            grid_y = min(int(S * y), S - 1)
            x_cell = S * x - grid_x
            y_cell = S * y - grid_y

            # fill first box
            target[grid_y, grid_x, 0:5] = torch.tensor([x_cell, y_cell, w, h, 1])
            # class one-hot
            target[grid_y, grid_x, 5 * B + label] = 1

        return target

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [5]:
train_dataset = YOLOdataset(
    image_dir="data/train/Images",
    annot_dir="data/train/labels",
    transform=transform
)

test_dataset = YOLOdataset(
    image_dir="data/test/Images",
    annot_dir="data/test/labels",
    transform=transform
)

image, target = train_dataset[0]
print(image.shape)  # [3, 224, 224]
print(target.shape) # [4, 4, 30]

torch.Size([3, 224, 224])
torch.Size([4, 4, 30])


In [6]:
train_loader = DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [7]:
for imgs, targets in train_loader:
    print("Batch image shape:", imgs.shape)    # [batch_size, 3, 448, 448]
    print("Batch target shape:", targets.shape) # [batch_size, 7, 7, 30]
    break

Batch image shape: torch.Size([4, 3, 224, 224])
Batch target shape: torch.Size([4, 4, 4, 30])


In [8]:
class YoloLoss(nn.Module):
    def __init__(self, S=4, B=2, C=20, λ_coord=5, λ_noobj=0.5):
        super(YoloLoss, self).__init__()
        self.mse = nn.MSELoss(reduction="sum")
        self.S = S
        self.B = B
        self.C = C
        self.lambda_coord = λ_coord
        self.lambda_noobj = λ_noobj

    def forward(self, predictions, target):
        # reshape [N, S, S, C + 5B]
        predictions = predictions.reshape(-1, self.S, self.S, self.C + self.B * 5)

        # --- 1️⃣ CLASS LOSS ---
        class_pred = predictions[..., :self.C]
        class_target = target[..., self.B * 5:]
        obj_mask = target[..., 4].unsqueeze(-1)  # only where object exists
        class_loss = self.mse(obj_mask * class_pred, obj_mask * class_target)

        # --- 2️⃣ COORDINATE LOSS ---
        box_pred = predictions[..., self.C:self.C + 5]  # first box
        box_target = target[..., 0:5]

        # x, y loss
        box_pred_xy = box_pred[..., 0:2]
        box_target_xy = box_target[..., 0:2]
        coord_loss_xy = self.mse(obj_mask * box_pred_xy, obj_mask * box_target_xy)

        # w, h loss (sqrt)
        box_pred_wh = torch.sign(box_pred[..., 2:4]) * torch.sqrt(torch.abs(box_pred[..., 2:4] + 1e-6))
        box_target_wh = torch.sqrt(box_target[..., 2:4])
        coord_loss_wh = self.mse(obj_mask * box_pred_wh, obj_mask * box_target_wh)

        coord_loss = self.lambda_coord * (coord_loss_xy + coord_loss_wh)

        # --- 3️⃣ OBJECT CONFIDENCE LOSS ---
        conf_pred = box_pred[..., 4]
        conf_target = box_target[..., 4]
        obj_conf_loss = self.mse(obj_mask.squeeze(-1) * conf_pred, obj_mask.squeeze(-1) * conf_target)

        # --- 4️⃣ NO OBJECT CONFIDENCE LOSS ---
        noobj_mask = 1 - obj_mask
        noobj_conf_loss = self.mse(noobj_mask.squeeze(-1) * conf_pred, noobj_mask.squeeze(-1) * conf_target)
        noobj_conf_loss = self.lambda_noobj * noobj_conf_loss

        # --- TOTAL LOSS ---
        total_loss = coord_loss + obj_conf_loss + noobj_conf_loss + class_loss
        return total_loss

# Lekin YOLOv1 ka structure hota hai:
# [..., C + 0: C + 5] → first bbox [x, y, w, h, conf]
# [..., C + 5: C + 10] → second bbox [x, y, w, h, conf]
# [..., :C] → class probabilities

In [9]:
loss_fn = YoloLoss(S=4, B=2, C=20)

In [10]:
class YOLO(nn.Module):
    def __init__(self, B, C):
        super().__init__()

        self.B = B
        self.C = C
        self.S = 4
        self.dropout = nn.Dropout(0.5)

        
        self.conv_1 = nn.Conv2d(in_channels = 3, out_channels=64, kernel_size=7, stride=2, padding=3) #1
        self.leaky1 = nn.LeakyReLU(0.1)
        self.max_pool_1 = nn.MaxPool2d(kernel_size=2,stride=2, padding=0) #2

        self.conv_2 = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, stride=1, padding=1) #3
        self.leaky2 = nn.LeakyReLU(0.1)
        self.max_pool_2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) #4

        self.conv_3 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=1, stride=1, padding=0) #5
        self.leaky3 = nn.LeakyReLU(0.1)

        self.conv_4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1) #6
        self.leaky4 = nn.LeakyReLU(0.1)

        self.conv_5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1, stride=1, padding=0) #7
        self.leaky5 = nn.LeakyReLU(0.1)

        self.conv_6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1) #8
        self.leaky6 = nn.LeakyReLU(0.1)
        self.max_pool_3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) #9


        # 10–17
        self.conv_7 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0) #10
        self.leaky7 = nn.LeakyReLU(0.1)

        self.conv_8 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1) #11
        self.leaky8 = nn.LeakyReLU(0.1)

        self.conv_9 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0) #12
        self.leaky9 = nn.LeakyReLU(0.1)

        self.conv_10 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1) #13
        self.leaky10 = nn.LeakyReLU(0.1)

        self.conv_11 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0) #14
        self.leaky11 = nn.LeakyReLU(0.1)

        self.conv_12 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1) #15
        self.leaky12 = nn.LeakyReLU(0.1)

        self.conv_13 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=0) #16
        self.leaky13 = nn.LeakyReLU(0.1)

        self.conv_14 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1) #17
        self.leaky14 = nn.LeakyReLU(0.1)
        self.max_pool_4 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)



        self.conv_15 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0) #18
        self.leaky15 = nn.LeakyReLU(0.1)

        self.conv_16 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1) #19
        self.leaky16 = nn.LeakyReLU(0.1)

        self.conv_17 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0) #20
        self.leaky17 = nn.LeakyReLU(0.1)

        self.conv_18 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1) #21
        self.leaky18 = nn.LeakyReLU(0.1)

        self.conv_19 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1) #22
        self.leaky19 = nn.LeakyReLU(0.1)

        self.conv_20 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=2, padding=1) #23
        self.leaky20 = nn.LeakyReLU(0.1)

        self.conv_21 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1) #24
        self.leaky21 = nn.LeakyReLU(0.1)



        # self.fc1 = nn.Linear(in_features=3*3*1024, out_features=4096)
        # self.leaky22 = nn.LeakyReLU(0.1)
        # self.fc2 = nn.Linear(4096, 3*3*(self.B*5 + self.C))

        # self.fc1 = nn.Linear(in_features=self.S*self.S*1024, out_features=4096)
        self.fc1 = nn.Linear(1024 * self.S * self.S, 4096)
        self.leaky22 = nn.LeakyReLU(0.1)
        self.fc2 = nn.Linear(4096, self.S*self.S*(self.B*5 + self.C))

        # B = 2 (bounding boxes per grid cell)
        # C = 20 (for Pascal VOC 20 classes)


    def forward(self, x):

        out1 = self.leaky1(self.conv_1(x))
        out2 = self.max_pool_1(out1)
        
        out3 = self.leaky2(self.conv_2(out2))
        out4 = self.max_pool_2(out3)
        
        out5 = self.leaky3(self.conv_3(out4))
        out6 = self.leaky4(self.conv_4(out5))
        out7 = self.leaky5(self.conv_5(out6))
        out8 = self.leaky6(self.conv_6(out7))
        out9 = self.max_pool_3(out8)
        out10 = self.leaky7(self.conv_7(out9))
        out11 = self.leaky8(self.conv_8(out10))
        out12 = self.leaky9(self.conv_9(out11))
        out13 = self.leaky10(self.conv_10(out12))
        out14 = self.leaky11(self.conv_11(out13))
        out15 = self.leaky12(self.conv_12(out14))
        out16 = self.leaky13(self.conv_13(out15))
        out17 = self.leaky14(self.conv_14(out16))
        out18 = self.max_pool_4(out17)
        out19 = self.leaky15(self.conv_15(out18))
        out20 = self.leaky16(self.conv_16(out19))
        out21 = self.leaky17(self.conv_17(out20))
        out22 = self.leaky18(self.conv_18(out21))
        out23 = self.leaky19(self.conv_19(out22))
        out24 = self.leaky20(self.conv_20(out23))
        out25 = self.leaky21(self.conv_21(out24))
        
        out26 = out25.view(out25.size(0), -1)
        out27 = self.leaky22(self.fc1(out26))
        out28 = self.dropout(out27)
        out29 = self.fc2(out28)
        out29 = out29.view(-1, self.S, self.S, self.C + 5*self.B)
        
        return out29

In [11]:
device = "cuda"  if torch.cuda.is_available() else "cpu"
device

'cuda'

In [12]:
import gc
# device = "cpu"
gc.collect()
torch.cuda.empty_cache()

model = YOLO(B=2, C=20)
model.to(device)

YOLO(
  (dropout): Dropout(p=0.5, inplace=False)
  (conv_1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (leaky1): LeakyReLU(negative_slope=0.1)
  (max_pool_1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_2): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (leaky2): LeakyReLU(negative_slope=0.1)
  (max_pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_3): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
  (leaky3): LeakyReLU(negative_slope=0.1)
  (conv_4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (leaky4): LeakyReLU(negative_slope=0.1)
  (conv_5): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
  (leaky5): LeakyReLU(negative_slope=0.1)
  (conv_6): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (leaky6): LeakyReLU(negative_slope=0.1)
  (max_pool_3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1,

In [13]:
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# optimizer = torch.optim.SGD(model.parameters(),lr=1e-2,momentum=0.9, weight_decay=0.0005)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

In [14]:
import torch
import logging
import os
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt

# ==== CONFIG ====
epochs = 200
learning_rate = 1e-4
save_checkpoint = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img_scale = 224
amp = False

# ==== LOGGING SETUP ====
os.makedirs("logs", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("detections", exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
log_file = f"logs/train_{timestamp}.log"

logging.basicConfig(
    filename=log_file,
    filemode='w',
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

logging.info(f"Starting training at {timestamp}")
logging.info(f"Device: {device}")

# ==== MODEL / LOSS / OPTIMIZER ====
model = YOLO(B=2, C=20).to(device)
loss_fn = YoloLoss(S=4, B=2, C=20)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# ==== CHECKPOINT HANDLING ====
# checkpoint_path = "checkpoints/best_model.pth"
checkpoint_path = "checkpoints/last_checkpoint.pth"
best_loss = float('inf')
start_epoch = 0

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_loss = checkpoint.get('best_loss', float('inf'))
    logging.info(f"✅ Resumed from epoch {start_epoch}, best loss so far: {best_loss:.4f}")

# ==== TRAINING LOOP ====
for epoch in range(start_epoch, epochs):
    model.train()
    running_loss = 0.0

    loop = tqdm(train_loader, total=len(train_loader), desc=f"Epoch [{epoch+1}/{epochs}]")
    for imgs, targets in loop:
        imgs, targets = imgs.to(device), targets.to(device)

        # forward
        preds = model(imgs)
        loss = loss_fn(preds, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    avg_loss = running_loss / len(train_loader)
    scheduler.step()

    # ==== LOGGING ====
    logging.info(f"Epoch [{epoch+1}/{epochs}] | Avg Loss: {avg_loss:.4f}")
    print(f"✅ Epoch [{epoch+1}/{epochs}] | Avg Loss: {avg_loss:.4f}")

    # ==== SAVE CHECKPOINT ====
    if save_checkpoint:
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_loss': best_loss
        }
        torch.save(state, "checkpoints/last_checkpoint.pth")

    # ==== SAVE BEST MODEL ====
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(state, checkpoint_path)
        logging.info(f"💾 Saved Best Model at epoch {epoch+1} (Loss: {best_loss:.4f})")

    # ==== SAVE SAMPLE DETECTION ====
    # model.eval()
    # with torch.no_grad():
    #     sample_img, _ = next(iter(val_loader))
    #     sample_img = sample_img.to(device)
    #     pred = model(sample_img)

    #     # convert predictions to bounding boxes (you should have a decode function)
    #     # boxes, labels, scores = decode_predictions(pred)

    #     # visualize (placeholder example)
    #     save_path = f"detections/epoch_{epoch+1}.png"
    #     img_np = sample_img[0].permute(1,2,0).cpu().numpy()
    #     plt.imshow(img_np)
    #     plt.title(f"Epoch {epoch+1} Predictions")
    #     plt.axis("off")
    #     plt.savefig(save_path, bbox_inches="tight")
    #     plt.close()
    #     logging.info(f"📸 Saved detection example at {save_path}")

Epoch [135/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [22:13<00:00,  4.15it/s, loss=15.9]


✅ Epoch [135/200] | Avg Loss: 18.2605


Epoch [136/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:03<00:00,  4.60it/s, loss=26.6]


✅ Epoch [136/200] | Avg Loss: 18.2566


Epoch [137/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [19:43<00:00,  4.67it/s, loss=8.62]


✅ Epoch [137/200] | Avg Loss: 18.2649


Epoch [138/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:07<00:00,  4.58it/s, loss=30.6]


✅ Epoch [138/200] | Avg Loss: 18.2579


Epoch [139/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:32<00:00,  4.49it/s, loss=21.2]


✅ Epoch [139/200] | Avg Loss: 18.2628


Epoch [140/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:37<00:00,  4.47it/s, loss=31.5]


✅ Epoch [140/200] | Avg Loss: 18.2517


Epoch [141/200]: 100%|████████████████████████████████████████████████████| 5534/5534 [20:26<00:00,  4.51it/s, loss=15]


✅ Epoch [141/200] | Avg Loss: 18.2601


Epoch [142/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:33<00:00,  4.49it/s, loss=36.2]


✅ Epoch [142/200] | Avg Loss: 18.2561


Epoch [143/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:25<00:00,  4.52it/s, loss=10.2]


✅ Epoch [143/200] | Avg Loss: 18.2709


Epoch [144/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:20<00:00,  4.53it/s, loss=38.6]


✅ Epoch [144/200] | Avg Loss: 18.2539


Epoch [145/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:22<00:00,  4.53it/s, loss=13.3]


✅ Epoch [145/200] | Avg Loss: 18.2616


Epoch [146/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:40<00:00,  4.46it/s, loss=24.7]


✅ Epoch [146/200] | Avg Loss: 18.2460


Epoch [147/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:34<00:00,  4.48it/s, loss=13.5]


✅ Epoch [147/200] | Avg Loss: 18.2493


Epoch [148/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:40<00:00,  4.46it/s, loss=19.7]


✅ Epoch [148/200] | Avg Loss: 18.2557


Epoch [149/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:42<00:00,  4.45it/s, loss=7.19]


✅ Epoch [149/200] | Avg Loss: 18.2552


Epoch [150/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:30<00:00,  4.50it/s, loss=17.9]


✅ Epoch [150/200] | Avg Loss: 18.2626


Epoch [151/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:37<00:00,  4.47it/s, loss=12.1]


✅ Epoch [151/200] | Avg Loss: 18.2598


Epoch [152/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:45<00:00,  4.44it/s, loss=7.81]


✅ Epoch [152/200] | Avg Loss: 18.2526


Epoch [153/200]: 100%|████████████████████████████████████████████████████| 5534/5534 [20:44<00:00,  4.45it/s, loss=20]


✅ Epoch [153/200] | Avg Loss: 18.2539


Epoch [154/200]: 100%|██████████████████████████████████████████████████| 5534/5534 [20:44<00:00,  4.45it/s, loss=13.6]


✅ Epoch [154/200] | Avg Loss: 18.2537


Epoch [155/200]:   3%|█▌                                                  | 172/5534 [00:36<19:09,  4.66it/s, loss=9.5]


KeyboardInterrupt: 