In [None]:
import torch
import requests
import cv2
import json
import numpy as np
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset, DataLoader
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics import YOLO

In [None]:
# 1. Custom Dataset that streams from URLs
class YOLOStreamingDataset(Dataset):
    """Map-style dataset that downloads each image from its URL on access.

    Images are stretched (not letterboxed) to ``imgsz x imgsz`` and their COCO
    boxes are converted to normalized YOLO ``[cls, cx, cy, w, h]`` rows. Each
    sample is a dict meant to be batched by ``custom_collate_fn``.
    """

    def __init__(self, data_list, imgsz=640, timeout=10):
        """
        data_list: List of dicts:
        {'url': '...', 'bboxes': [[x,y,w,h], ...], 'classes': [0, ...]}
        where boxes are COCO pixel coordinates of the original image.
        imgsz: side length (pixels) of the square image returned to the model.
        timeout: seconds passed to ``requests.get`` for each download.
        """
        self.data = data_list
        self.imgsz = imgsz
        self.timeout = timeout

    def __len__(self):
        return len(self.data)

    @staticmethod
    def coco_to_yolo_labels(bboxes, classes, w0, h0):
        """Convert COCO ``[x, y, w, h]`` pixel boxes into normalized YOLO rows.

        Args:
            bboxes: iterable of ``[x_min, y_min, width, height]`` in pixels.
            classes: class id for each box, in the same order as ``bboxes``.
            w0, h0: width and height (pixels) of the image the boxes refer to.

        Returns:
            list of ``[cls, cx, cy, w, h]`` with coordinates in [0, 1]. Boxes are
            clipped to the image; boxes with no area inside it are dropped.

        Raises:
            ValueError: if ``bboxes`` and ``classes`` differ in length.
        """
        if len(bboxes) != len(classes):
            raise ValueError(
                f"got {len(bboxes)} bboxes but {len(classes)} classes"
            )
        labels = []
        for (x, y, w, h), cls in zip(bboxes, classes):
            # Clip corners to the image so normalized values stay within [0, 1];
            # out-of-range targets are invalid for YOLO training.
            x1, y1 = max(x, 0.0), max(y, 0.0)
            x2, y2 = min(x + w, w0), min(y + h, h0)
            if x2 <= x1 or y2 <= y1:
                continue  # nothing of this box lies inside the image
            labels.append([
                cls,
                (x1 + x2) / 2 / w0,
                (y1 + y2) / 2 / h0,
                (x2 - x1) / w0,
                (y2 - y1) / h0,
            ])
        return labels

    def _download(self, url):
        """Fetch ``url`` and decode it as an RGB PIL image; return None on failure."""
        try:
            response = requests.get(url, timeout=self.timeout)
            # Surface HTTP errors (404, 500, ...) instead of decoding an error page.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            # Best-effort: one bad URL should not abort the whole epoch.
            print(f"Error downloading {url}: {e}")
            return None

    def __getitem__(self, idx):
        """Return a dict with 'img', 'cls_bboxes', 'ori_shape', 'ratio_pad', 'im_file'."""
        item = self.data[idx]

        pil_img = self._download(item['url'])
        if pil_img is not None:
            w0, h0 = pil_img.size
            labels = self.coco_to_yolo_labels(item['bboxes'], item['classes'], w0, h0)
        else:
            # Blank fallback frame WITHOUT labels: keeping the item's boxes would
            # teach the model to find objects in an empty black image.
            pil_img = Image.new('RGB', (self.imgsz, self.imgsz), (0, 0, 0))
            w0, h0 = pil_img.size
            labels = []

        # Stretch to imgsz x imgsz (no letterbox). Normalized labels remain valid
        # because each axis of the image and of the boxes is scaled independently.
        img = cv2.resize(np.array(pil_img), (self.imgsz, self.imgsz))
        img = np.ascontiguousarray(img.transpose(2, 0, 1))  # HWC -> CHW

        # Ultralytics convention: (height, width) of the source image.
        ori_shape = (h0, w0)
        # ((h_gain, w_gain), (pad_w, pad_h)): stretching gives per-axis gains, no pad.
        # NOTE(review): Ultralytics' scale_boxes appears to use only ratio_pad[0][0]
        # as a single gain, so mapping back to ori_shape may be off for
        # non-square originals — confirm against the installed version.
        ratio_pad = ((self.imgsz / h0, self.imgsz / w0), (0, 0))

        return {
            # Raw uint8 pixels: Ultralytics' DetectionTrainer.preprocess_batch and
            # DetectionValidator.preprocess already divide batch['img'] by 255, so
            # scaling here too would feed the model near-black images.
            'img': torch.from_numpy(img),
            'cls_bboxes': (torch.tensor(labels, dtype=torch.float32)
                           if labels else torch.zeros((0, 5))),
            'ori_shape': ori_shape,
            'ratio_pad': ratio_pad,
            'im_file': item['url'],  # identifies the sample in validator outputs
        }

# 2. Custom Collate to handle variable numbers of objects per image
def custom_collate_fn(batch):
    """Collate YOLOStreamingDataset samples into an Ultralytics-style batch dict.

    Args:
        batch: list of sample dicts with keys 'img', 'cls_bboxes', 'ori_shape',
            'ratio_pad', 'im_file'. 'cls_bboxes' is an (N, 5) tensor of
            ``[cls, cx, cy, w, h]`` rows; N may be 0.

    Returns:
        dict with
          'img'       -- images stacked along a new batch dim 0,
          'batch_idx' -- (M,) float32, index of the image each box belongs to,
          'cls'       -- (M, 1) float32 class ids,
          'bboxes'    -- (M, 4) float32 normalized (cx, cy, w, h),
          'ori_shape', 'ratio_pad', 'im_file' -- per-image lists in batch order,
        where M is the total number of boxes in the batch.
    """
    imgs, batch_idx, cls, bboxes = [], [], [], []
    ori_shapes, ratio_pads, im_files = [], [], []

    for i, item in enumerate(batch):
        imgs.append(item['img'])
        ori_shapes.append(item['ori_shape'])
        ratio_pads.append(item['ratio_pad'])
        im_files.append(item['im_file'])

        labels = item['cls_bboxes'].to(torch.float32)
        # Always float32 (Ultralytics' own collate uses float batch_idx), so the
        # dtype no longer depends on whether this batch happens to contain boxes.
        batch_idx.append(torch.full((labels.shape[0],), float(i), dtype=torch.float32))
        cls.append(labels[:, :1])      # (N, 1) even when N == 0
        bboxes.append(labels[:, 1:])   # (N, 4)

    return {
        'img': torch.stack(imgs, 0),
        'batch_idx': torch.cat(batch_idx, 0),
        'cls': torch.cat(cls, 0),
        'bboxes': torch.cat(bboxes, 0),
        # Metadata consumed by the validator.
        'ori_shape': ori_shapes,
        'ratio_pad': ratio_pads,
        'im_file': im_files,
    }

# 3. Custom Trainer to override the data loading logic
class URLStreamTrainer(DetectionTrainer):
    """DetectionTrainer whose datasets stream images from URLs.

    Assign ``URLStreamTrainer.train_data_list`` and ``URLStreamTrainer.val_data_list``
    (lists of dicts in the YOLOStreamingDataset format) before training. The
    ``img_path`` / ``dataset_path`` arguments passed in by the trainer are ignored.
    """

    # Class-level placeholders; the notebook assigns the real lists on the class
    # itself, so every instance of this trainer sees the same data.
    train_data_list = []
    val_data_list = []

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Overrides the dataset builder.

        'mode' == "train" selects train_data_list; any other mode selects
        val_data_list. ``img_path`` and ``batch`` are accepted for API
        compatibility and ignored.

        Raises:
            ValueError: if the training list is empty (e.g. never assigned).
        """
        data_source = self.train_data_list if mode == "train" else self.val_data_list
        if mode == "train" and len(data_source) == 0:
            # Otherwise DataLoader's RandomSampler fails later with an opaque
            # "num_samples should be a positive integer" error.
            raise ValueError(
                "URLStreamTrainer.train_data_list is empty; assign the training "
                "samples to it before calling train()."
            )

        return YOLOStreamingDataset(
            data_source,
            imgsz=self.args.imgsz
        )

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        """Return a DataLoader over the streaming dataset for ``mode``.

        Shuffles only in train mode. ``rank`` is accepted for API compatibility
        but unused (no distributed sampler is set up).
        """
        dataset = self.build_dataset(dataset_path, mode=mode)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(mode == "train"),
            num_workers=self.args.workers,  # parallel workers hide download latency
            collate_fn=custom_collate_fn,
            # Pinned memory only speeds up host->GPU copies; without CUDA it
            # just triggers a PyTorch warning.
            pin_memory=torch.cuda.is_available(),
        )


In [None]:
# Proposed dataset format: one dict per image

# my_data = [
#     {
#         'url': 'https://example.com/street_scene.jpg',    # URL of the image
#         'bboxes': [[100, 200, 50, 80]],                   # Bounding boxes in COCO pixel format [x_min, y_min, width, height]
#         'classes': [1]                                    # Class ID for each bounding box, in the same order as 'bboxes'
#     },
#     ...
# ]

In [None]:
# Annotation list: one dict per image, in the format shown above.
DATASET_JSON = "../final_dataset.json"

with open(DATASET_JSON, 'r', encoding="utf-8") as file:
    dataset = json.load(file)

# Stratification key for the train/val split: the first class of each image.
# Images without any boxes get -1 instead of raising IndexError.
dataset_labels = [item['classes'][0] if item['classes'] else -1 for item in dataset]

In [None]:
from sklearn.model_selection import train_test_split

# Reproducible 80/20 split, stratified on dataset_labels (first class per image).
train_data, val_data = train_test_split(
    dataset,
    stratify=dataset_labels,
    test_size=0.2,
    random_state=42,
)

# build_dataset() reads these class attributes, so assign them on the class itself.
URLStreamTrainer.train_data_list, URLStreamTrainer.val_data_list = train_data, val_data

In [None]:
# Load the pretrained YOLO26 nano checkpoint
model = YOLO("yolo26n.pt")

# Train with the URL-streaming trainer. `data` must still point at a YAML
# (the original note: it is needed for class names); the image paths are
# ignored by URLStreamTrainer.build_dataset.
model.train(
    data="wcs_dataset.yaml",
    trainer=URLStreamTrainer,
    imgsz=640,
    batch=64,
    epochs=5,
    workers=4,  # extra DataLoader workers hide per-image download latency
    plots=False,
)

# Visualize a batch

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def _to_display_image(img):
    """Convert one CHW image tensor into an HWC uint8 numpy array for ``imshow``.

    Floating-point tensors are treated as [0, 1] intensities and scaled to 0-255
    (clipped); integer tensors are treated as raw 0-255 pixels and only cast.
    """
    arr = img.detach().cpu().permute(1, 2, 0).numpy()
    if np.issubdtype(arr.dtype, np.floating):
        arr = np.clip(arr * 255.0, 0, 255)
    return arr.astype(np.uint8)


def test_trainer_visualize(trainer, mode="train", max_images=8, batch_size=16):
    """
    Pull one batch from the trainer's dataloader and plot its ground-truth boxes inline.

    Args:
        trainer: object exposing ``get_dataloader(dataset_path=..., batch_size=..., mode=...)``
            that yields custom_collate_fn-style batches ('img', 'batch_idx', 'bboxes', 'cls').
        mode: dataloader mode forwarded to ``get_dataloader`` ("train" or "val").
        max_images: maximum number of images from the batch to draw.
        batch_size: batch size requested from the dataloader.

    Handles both float [0, 1] and uint8 image batches.
    """
    # Dummy path: URLStreamTrainer.get_dataloader ignores it.
    loader = trainer.get_dataloader(dataset_path="dummy", batch_size=batch_size, mode=mode)
    batch = next(iter(loader))

    imgs = batch['img']             # [B, 3, H, W]
    batch_idx = batch['batch_idx']  # [N] image index of each box
    bboxes = batch['bboxes']        # [N, 4] normalized (cx, cy, w, h)
    cls = batch['cls']              # [N, 1]

    num_images = min(imgs.shape[0], max_images)
    if num_images <= 0:
        print("No images to display")
        return

    cols = min(2, num_images)
    rows = (num_images + cols - 1) // cols
    # squeeze=False always returns a 2-D array of Axes, so a single code path
    # covers every grid shape.
    _, axes = plt.subplots(rows, cols, figsize=(15, 7.5 * rows), squeeze=False)
    axes = axes.ravel()

    for i in range(num_images):
        ax = axes[i]
        img = _to_display_image(imgs[i])
        h, w = img.shape[:2]

        # Boxes belonging to this image within the flattened batch.
        obj_indices = (batch_idx == i).nonzero(as_tuple=True)[0].tolist()

        ax.imshow(img)
        ax.set_title(f"Batch Item {i} - {mode} ({len(obj_indices)} objects)")
        ax.axis('off')

        for obj_idx in obj_indices:
            # Plain Python floats rather than 0-d tensors for matplotlib.
            cx, cy, bw, bh = bboxes[obj_idx].tolist()
            class_id = int(cls[obj_idx].item())

            # Denormalize center-xywh to the top-left corner + size matplotlib expects.
            x1 = (cx - bw / 2) * w
            y1 = (cy - bh / 2) * h
            rect = patches.Rectangle((x1, y1), bw * w, bh * h,
                                     linewidth=2, edgecolor='lime', facecolor='none')
            ax.add_patch(rect)

            ax.text(x1, y1 - 5, f"Class: {class_id}",
                    color='lime', fontsize=10, weight='bold',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor='black', alpha=0.7))

        print(f"Image {i}: {len(obj_indices)} labeled objects")

    # Hide any unused subplot cells (e.g. the last one for an odd image count).
    for ax in axes[num_images:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# --- HOW TO RUN ---
# Builds a fresh trainer (it reuses the class-level data lists assigned earlier)
# and plots one training batch with its ground-truth boxes.
viz_overrides = {'model': 'yolo26n.pt', 'data': 'wcs_dataset.yaml'}
test_trainer_visualize(URLStreamTrainer(overrides=viz_overrides))