In [1]:
import torch

In [2]:
### baseline 
IMG_SIZE = 416
GRID_SIZE = 52
NUM_CLASSES = 7
BATCH_SIZE = 8
EPOCHS = 50
LR = 1e-3

In [3]:
from torch.utils.data import DataLoader
from dataloader.data_load import UnderwaterDataset  # adjust import if needed

train_dataset = UnderwaterDataset(
    img_dir="/Users/anirudhmamgain/Desktop/Object_detection_from_scratch/Dataset/train/images",
    label_dir="/Users/anirudhmamgain/Desktop/Object_detection_from_scratch/Dataset/train/labels"
)

val_dataset = UnderwaterDataset(
    img_dir="/Users/anirudhmamgain/Desktop/Object_detection_from_scratch/Dataset/valid/images",
    label_dir="/Users/anirudhmamgain/Desktop/Object_detection_from_scratch/Dataset/valid/labels"
)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
)



In [4]:
len(train_loader)

56

In [5]:
import os
print(os.getcwd())


/Users/anirudhmamgain/Desktop/Object_detection_from_scratch


In [6]:
from models.detector import Detector
from models.loss import DetectionLoss

device = torch.device("mps")

model = Detector(num_classes=NUM_CLASSES).to(device)
criterion = DetectionLoss(num_classes=NUM_CLASSES)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)


In [7]:
import os
import torch

os.makedirs("checkpoints/customcnn", exist_ok=True)

best_loss = float("inf")

for epoch in range(EPOCHS):
    model.train()

    tot_loss = 0.0
    tot_box = 0.0
    tot_obj = 0.0
    tot_cls = 0.0

    for imgs, targets in train_loader:
        imgs = imgs.to(device)
        targets = targets.to(device)

        preds = model(imgs)

        box_loss, obj_loss, cls_loss = criterion(preds, targets)

        # cls_loss = torch.clamp(cls_loss, max=10.0)

        loss = 5.0 * box_loss + obj_loss + 0.05 * cls_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tot_loss += loss.item()
        tot_box  += box_loss.item()
        tot_obj  += obj_loss.item()
        tot_cls  += cls_loss.item()

    avg_loss = tot_loss / len(train_loader)

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(
            model.state_dict(),
            "checkpoints/best_model.pth"
        )

    torch.save(
        model.state_dict(),
        "checkpoints/last_model.pth"
    )

    n = len(train_loader)
    print(
        f"Epoch [{epoch+1}/{EPOCHS}] | "
        f"Total: {tot_loss/n:.3f} | "
        f"Box: {tot_box/n:.3f} | "
        f"Obj: {tot_obj/n:.3f} | "
        f"Cls: {tot_cls/n:.3f}"
    )


Epoch [1/50] | Total: 5.117 | Box: 0.966 | Obj: 0.163 | Cls: 2.442
Epoch [2/50] | Total: 5.094 | Box: 0.964 | Obj: 0.059 | Cls: 4.300
Epoch [3/50] | Total: 5.218 | Box: 0.966 | Obj: 0.057 | Cls: 6.637
Epoch [4/50] | Total: 5.285 | Box: 0.964 | Obj: 0.055 | Cls: 8.175
Epoch [5/50] | Total: 5.325 | Box: 0.965 | Obj: 0.050 | Cls: 9.014
Epoch [6/50] | Total: 5.513 | Box: 0.966 | Obj: 0.065 | Cls: 12.371
Epoch [7/50] | Total: 5.716 | Box: 0.964 | Obj: 0.054 | Cls: 16.852
Epoch [8/50] | Total: 5.958 | Box: 0.964 | Obj: 0.058 | Cls: 21.585
Epoch [9/50] | Total: 6.248 | Box: 0.964 | Obj: 0.062 | Cls: 27.358
Epoch [10/50] | Total: 6.284 | Box: 0.963 | Obj: 0.050 | Cls: 28.383
Epoch [11/50] | Total: 6.561 | Box: 0.962 | Obj: 0.044 | Cls: 34.107
Epoch [12/50] | Total: 6.857 | Box: 0.963 | Obj: 0.061 | Cls: 39.576
Epoch [13/50] | Total: 7.329 | Box: 0.963 | Obj: 0.041 | Cls: 49.412
Epoch [14/50] | Total: 7.682 | Box: 0.964 | Obj: 0.058 | Cls: 56.093
Epoch [15/50] | Total: 7.540 | Box: 0.963 | Obj:

KeyboardInterrupt: 

In [8]:
# ---------------- CONFIG ----------------
IMG_DIR = "/Users/anirudhmamgain/Desktop/Object_detection_from_scratch/Dataset/train/images"
LABEL_DIR = "/Users/anirudhmamgain/Desktop/Object_detection_from_scratch/Dataset/train/labels"
CHECKPOINT = "/Users/anirudhmamgain/Desktop/Object_detection_from_scratch/checkpoints/last_model.pth"

BATCH_SIZE = 8
EPOCHS = 40
LR = 3e-4
CLS_WEIGHT = 0.03
NUM_CLASSES = 7
FREEZE_EPOCHS = 10

# ---------------- DEVICE ----------------
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

# ---------------- DATA ----------------
dataset = UnderwaterDataset(
    img_dir=IMG_DIR,
    label_dir=LABEL_DIR
)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# ---------------- MODEL ----------------
model = Detector(num_classes=NUM_CLASSES)
model.load_state_dict(torch.load(CHECKPOINT, map_location=device))
model.to(device)

# ---------------- LOSS ----------------
criterion = DetectionLoss(num_classes=NUM_CLASSES)

# ---------------- OPTIMIZER ----------------
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR,
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=15,
    gamma=0.1
)

# ---------------- TRAINING ----------------
for epoch in range(EPOCHS):
    model.train()

    # ---- Freeze backbone for first few epochs ----
    if epoch < FREEZE_EPOCHS:
        for p in model.backbone.parameters():
            p.requires_grad = False
    else:
        for p in model.backbone.parameters():
            p.requires_grad = True

    tot_loss = tot_box = tot_obj = tot_cls = 0.0

    for imgs, targets in loader:
        imgs = imgs.to(device)
        targets = targets.to(device)

        preds = model(imgs)

        box_loss, obj_loss, cls_loss = criterion(preds, targets)
        total_loss = 5.0 * box_loss + obj_loss + CLS_WEIGHT * cls_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        tot_loss += total_loss.item()
        tot_box  += box_loss.item()
        tot_obj  += obj_loss.item()
        tot_cls  += cls_loss.item()

    scheduler.step()

    n = len(loader)
    print(
        f"Epoch [{epoch+1}/{EPOCHS}] | "
        f"Total: {tot_loss/n:.3f} | "
        f"Box: {tot_box/n:.3f} | "
        f"Obj: {tot_obj/n:.3f} | "
        f"Cls: {tot_cls/n:.3f} | "
        f"LR: {scheduler.get_last_lr()[0]:.2e}"
    )

    torch.save(model.state_dict(), "checkpoints/new_last_model.pth")


Using device: mps
Epoch [1/40] | Total: 39.936 | Box: 0.962 | Obj: 0.037 | Cls: 1169.687 | LR: 3.00e-04
Epoch [2/40] | Total: 40.791 | Box: 0.960 | Obj: 0.035 | Cls: 1198.513 | LR: 3.00e-04
Epoch [3/40] | Total: 41.225 | Box: 0.960 | Obj: 0.035 | Cls: 1212.920 | LR: 3.00e-04
Epoch [4/40] | Total: 41.068 | Box: 0.959 | Obj: 0.035 | Cls: 1207.892 | LR: 3.00e-04
Epoch [5/40] | Total: 41.666 | Box: 0.961 | Obj: 0.032 | Cls: 1227.561 | LR: 3.00e-04
Epoch [6/40] | Total: 42.728 | Box: 0.961 | Obj: 0.033 | Cls: 1263.071 | LR: 3.00e-04
Epoch [7/40] | Total: 41.784 | Box: 0.961 | Obj: 0.034 | Cls: 1231.546 | LR: 3.00e-04
Epoch [8/40] | Total: 42.094 | Box: 0.961 | Obj: 0.033 | Cls: 1241.889 | LR: 3.00e-04
Epoch [9/40] | Total: 42.414 | Box: 0.960 | Obj: 0.034 | Cls: 1252.676 | LR: 3.00e-04
Epoch [10/40] | Total: 43.202 | Box: 0.960 | Obj: 0.032 | Cls: 1278.974 | LR: 3.00e-04
Epoch [11/40] | Total: 43.026 | Box: 0.961 | Obj: 0.031 | Cls: 1273.045 | LR: 3.00e-04
Epoch [12/40] | Total: 43.322 | Bo