In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.models.detection.retinanet import RetinaNet, RetinaNetHead
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList
from torchvision.ops import boxes as box_ops
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.ops import box_iou
from tqdm import tqdm
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision.transforms import functional as TF
from torchvision.models.detection import RetinaNet
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torch.cuda.amp import autocast, GradScaler

In [3]:
class YOLODataset(Dataset):
    """Detection dataset of images paired with YOLO-format ``.txt`` label files.

    Each image ``<stem>.jpg`` / ``<stem>.png`` in ``images_dir`` is paired with
    ``<stem>.txt`` in ``labels_dir``. Each label line is expected to be
    ``class_id cx cy w h`` with centre/size normalised to [0, 1]. Valid lines are
    converted to absolute ``[x1, y1, x2, y2]`` pixel boxes in the *resized* image.
    Images without a label file yield an empty target.

    Args:
        images_dir: Directory containing the images.
        labels_dir: Directory containing the YOLO label files.
        transforms: Optional callable applied to the image tensor only (boxes are
            not transformed, so geometric transforms would desynchronise them).
        num_classes: Label lines whose class id is outside ``[0, num_classes)``
            are dropped.
        img_size: ``(width, height)`` passed to ``cv2.resize`` for every image.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, images_dir, labels_dir, transforms=None, num_classes=12, img_size=(224, 224)):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transforms = transforms
        self.num_classes = num_classes
        self.img_size = img_size
        # Sorted because os.listdir order is arbitrary; a stable index -> file
        # mapping is required for a reproducible (seeded) random_split.
        self.images = sorted(
            f for f in os.listdir(images_dir) if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.labels = sorted(f for f in os.listdir(labels_dir) if f.endswith('.txt'))

    def __len__(self):
        return len(self.images)

    def _parse_labels(self, label_path, width, height):
        """Read a YOLO label file and return ``(boxes, labels)`` as plain lists.

        ``boxes`` holds ``[x1, y1, x2, y2]`` in pixels of a ``width`` x ``height``
        image; ``labels`` holds the matching integer class ids. Lines are skipped
        when they do not have exactly five fields, when the class id is outside
        ``[0, num_classes)``, when any normalised value is outside [0, 1] (NaN
        included), or when the box has zero width or height. A missing file yields
        two empty lists.
        """
        boxes, labels = [], []
        if not os.path.exists(label_path):
            return boxes, labels

        with open(label_path) as f:
            for line in f:
                parts = line.split()
                if len(parts) != 5:
                    continue
                class_id, cx, cy, w, h = map(float, parts)
                # Negative ids must be rejected too: a label of -1 would silently
                # index the last class when the detector builds one-hot targets.
                if not 0 <= class_id < self.num_classes:
                    continue
                # Chained comparisons are False for NaN, so NaN lines are dropped.
                if not all(0 <= v <= 1 for v in (cx, cy, w, h)):
                    continue
                x1 = (cx - w / 2) * width
                y1 = (cy - h / 2) * height
                x2 = (cx + w / 2) * width
                y2 = (cy + h / 2) * height
                # Degenerate boxes produce invalid regression targets.
                if x2 <= x1 or y2 <= y1:
                    continue
                boxes.append([x1, y1, x2, y2])
                labels.append(int(class_id))
        return boxes, labels

    def __getitem__(self, idx):
        img_path = os.path.join(self.images_dir, self.images[idx])
        label_path = os.path.join(self.labels_dir, os.path.splitext(self.images[idx])[0] + '.txt')

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None (rather than raising) for unreadable files.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, tuple(self.img_size))
        height, width = img.shape[:2]

        # YOLO coordinates are normalised, so they are scaled by the resized size.
        boxes, labels = self._parse_labels(label_path, width, height)

        target = {
            'path': label_path,
            # reshape keeps the (0, 4) shape torchvision expects for images with no boxes.
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.tensor(labels, dtype=torch.int64),
        }

        img = TF.to_tensor(img)
        if self.transforms:
            img = self.transforms(img)

        return img, target


In [4]:
# --------------- Model -----------------
def get_retinanet_model(num_classes, backbone_name='resnet18', pretrained=True, **retinanet_kwargs):
    """Build a torchvision ``RetinaNet`` on a ResNet + FPN backbone.

    Args:
        num_classes: Size of the classification head; dataset labels must lie in
            ``[0, num_classes)``.
        backbone_name: ResNet variant passed to ``resnet_fpn_backbone``.
        pretrained: Whether to load ImageNet weights for the backbone (legacy
            torchvision flag, still accepted with a deprecation warning).
        **retinanet_kwargs: Forwarded to ``RetinaNet``. For example,
            ``min_size``/``max_size`` control the model's internal resize. The
            torchvision defaults are 800/1333, which upsample small inputs such as
            224x224 images.

    Returns:
        The untrained-head ``RetinaNet`` model (on CPU).
    """
    # Keyword arguments: newer torchvision makes backbone_name keyword-only.
    backbone = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=pretrained)
    return RetinaNet(backbone, num_classes=num_classes, **retinanet_kwargs)

In [14]:
# --------------- Training -----------------


def _batch_to_device(images, targets, device):
    """Move a batch of images and per-image target dicts to ``device``.

    Values exposing ``.to`` (tensors) are moved. Anything else, such as the
    ``'path'`` strings stored in each target, is passed through unchanged.
    """
    images = [img.to(device) for img in images]
    targets = [{k: v.to(device) if hasattr(v, 'to') else v for k, v in t.items()} for t in targets]
    return images, targets


def _run_epoch(model, loader, device, desc, optimizer=None, postfix_key='loss'):
    """Make one pass over ``loader`` and return the mean per-batch summed loss.

    When ``optimizer`` is given, each batch is back-propagated and an optimizer
    step is taken. Otherwise the pass runs under ``torch.no_grad()``. Returns 0.0
    for an empty loader instead of dividing by zero.
    """
    # torchvision detection models only return their loss dict in training mode,
    # so validation also runs in train mode. NOTE(review): this assumes no dropout
    # or updating BatchNorm in the model (resnet_fpn_backbone presumably defaults to
    # frozen BatchNorm) -- confirm if the backbone or head changes.
    model.train()
    total_loss = 0.0
    progress = tqdm(loader, desc=desc)
    grad_context = torch.enable_grad() if optimizer is not None else torch.no_grad()
    with grad_context:
        for images, targets in progress:
            images, targets = _batch_to_device(images, targets, device)
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())

            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            progress.set_postfix(**{postfix_key: batch_loss})
    return total_loss / max(len(loader), 1)


def train(model, train_loader, val_loader, device, num_epochs=100, lr=1e-4, checkpoint_path=None):
    """Train ``model`` with AdamW and report mean train/validation loss per epoch.

    Args:
        model: Detection model returning a loss dict from ``model(images, targets)``
            in train mode; it must already be on ``device``.
        train_loader: Loader yielding ``(images, targets)`` training batches.
        val_loader: Loader used only to measure validation loss (no updates).
        device: Device each batch is moved to.
        num_epochs: Number of passes over ``train_loader``.
        lr: AdamW learning rate.
        checkpoint_path: If given, ``model.state_dict()`` is saved there after every
            epoch, so progress survives an interrupted long run.

    Returns:
        dict with per-epoch lists under ``'train_loss'`` and ``'val_loss'``.
    """
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(num_epochs):
        train_loss = _run_epoch(
            model, train_loader, device, f"Epoch {epoch+1} [Train]",
            optimizer=optimizer, postfix_key='loss',
        )
        val_loss = _run_epoch(
            model, val_loader, device, f"Epoch {epoch+1} [Val]",
            postfix_key='val_loss',
        )
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        print(f"Epoch {epoch+1}/{num_epochs} - train_loss: {train_loss:.4f} - val_loss: {val_loss:.4f}")

        if checkpoint_path is not None:
            torch.save(model.state_dict(), checkpoint_path)

    return history
        

In [15]:
# --------------- Main -----------------
if __name__ == "__main__":
    IMAGES_DIR = "/kaggle/input/dataset-retina/clean_images/clean_images"
    LABELS_DIR = "/kaggle/input/dataset-retina/clean_label/clean_label"
    NUM_CLASSES = 12  # must agree with the class-id filter applied by YOLODataset
    BATCH_SIZE = 16
    TRAIN_FRACTION = 0.8
    # Fixed seed so the train/validation split is identical across kernel restarts;
    # otherwise validation images leak into training when a run is resumed.
    SPLIT_SEED = 42

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    full_dataset = YOLODataset(IMAGES_DIR, LABELS_DIR)
    train_size = int(TRAIN_FRACTION * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(SPLIT_SEED),
    )

    def collate_fn(batch):
        """Group a batch into (images, targets) tuples without stacking.

        Detection targets have a variable number of boxes, so the default
        stacking collate cannot be used.
        """
        return tuple(zip(*batch))

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    model = get_retinanet_model(num_classes=NUM_CLASSES)
    model.to(device)

    

In [None]:
# Train, then persist the weights. The save runs in `finally` so that interrupting
# the (very long) run -- e.g. via "Interrupt kernel", which raises KeyboardInterrupt --
# still writes the current weights; the exception is re-raised afterwards.
try:
    train(model, train_loader, test_loader, device)
finally:
    torch.save(model.state_dict(), '/kaggle/working/model_weights.pth')

Epoch 1/100
[Train]: 100%|██████████| 687/687 [25:56<00:00, 2.27s/it, loss=1.166]
[Val]: 100%|██████████| 156/156 [03:24<00:00, 1.31s/it, loss=1.091]

Epoch 2/100
[Train]: 100%|██████████| 687/687 [27:37<00:00, 2.41s/it, loss=1.155]
[Val]: 100%|██████████| 156/156 [03:25<00:00, 1.31s/it, loss=1.079]

Epoch 3/100
[Train]: 100%|██████████| 687/687 [26:22<00:00, 2.30s/it, loss=1.125]
[Val]: 100%|██████████| 156/156 [03:14<00:00, 1.24s/it, loss=1.029]

Epoch 4/100
[Train]: 100%|██████████| 687/687 [29:36<00:00, 2.59s/it, loss=1.135]
[Val]: 100%|██████████| 156/156 [03:53<00:00, 1.50s/it, loss=1.022]

Epoch 5/100
[Train]: 100%|██████████| 687/687 [25:53<00:00, 2.26s/it, loss=1.133]
[Val]: 100%|██████████| 156/156 [03:40<00:00, 1.41s/it, loss=1.010]

Epoch 6/100
[Train]: 100%|██████████| 687/687 [29:51<00:00, 2.61s/it, loss=1.091]
[Val]: 100%|██████████| 156/156 [04:06<00:00, 1.58s/it, loss=0.979]

Epoch 7/100
[Train]: 100%|██████████| 687/687 [26:49<00:00, 2.34s/it, loss=1.075]
[Val]: 100%|