In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
import cv2
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict

class DetectionDataset(Dataset):
    def __init__(self, image_dir: str, label_dir: str, transforms=None):
        self.image_dir = Path(image_dir)
        self.label_dir = Path(label_dir)
        self.transforms = transforms
        self.image_files = list(self.image_dir.glob('*.jpg'))
        
    def __len__(self) -> int:
        return len(self.image_files)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        img_path = self.image_files[idx]
        image = cv2.imread(str(img_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Read labels
        label_path = self.label_dir / img_path.with_suffix('.txt').name
        boxes = []
        class_labels = []
        
        if label_path.exists():
            with open(label_path) as f:
                for line in f.readlines():
                    class_id, x_center, y_center, width, height = map(float, line.strip().split())
                    boxes.append([x_center, y_center, width, height])
                    class_labels.append(class_id)
        
        # Convert to numpy arrays
        boxes = np.array(boxes) if boxes else np.zeros((0, 4))
        class_labels = np.array(class_labels) if class_labels else np.zeros(0)
        
        # Apply transforms
        if self.transforms:
            transformed = self.transforms(
                image=image,
                bboxes=boxes,
                class_labels=class_labels
            )
            image = transformed['image']
            boxes = transformed['bboxes']
            class_labels = transformed['class_labels']
        
        # Convert to tensors
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(image).permute(2, 0, 1) / 255.0
            
        boxes = torch.tensor(boxes, dtype=torch.float32)
        class_labels = torch.tensor(class_labels, dtype=torch.long)
        
        return {
            'image': image,
            'boxes': boxes,
            'labels': class_labels,
            'image_id': str(img_path)
        }

class YOLOTrainer:
    def __init__(self, num_classes: int):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        
        # Create a fresh YOLO model with the correct number of classes
        self.model = YOLO('yolov8n.yaml')
        # Update the model configuration
        self.model.model.args['nc'] = num_classes
        # Move model to device
        self.model.to(self.device)
        
        # Define transforms
        self.train_transforms = A.Compose([
            A.RandomBrightnessContrast(p=0.5),
            A.RandomGamma(p=0.5),
            A.GaussNoise(p=0.3),
            A.OneOf([
                A.MotionBlur(p=0.5),
                A.MedianBlur(blur_limit=3, p=0.5),
                A.GaussianBlur(blur_limit=3, p=0.5),
            ], p=0.3),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.Resize(640, 640),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            ToTensorV2(),
        ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
        
        self.val_transforms = A.Compose([
            A.Resize(640, 640),
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
            ToTensorV2(),
        ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
    
    def train(
        self,
        train_dir: str,
        epochs: int = 100,
        batch_size: int = 16,
        learning_rate: float = 0.001,
        val_split: float = 0.2
    ):
        # Load all image paths
        all_images = list(Path(f"{train_dir}/images").glob('*.jpg'))
        
        # Split into train and validation sets
        train_imgs, val_imgs = train_test_split(
            all_images,
            test_size=val_split,
            random_state=42
        )
        
        # Create datasets
        train_dataset = DetectionDataset(
            train_dir + "/images",
            train_dir + "/labels",
            transforms=self.train_transforms
        )
        train_dataset.image_files = train_imgs  # Override with split images
        
        val_dataset = DetectionDataset(
            train_dir + "/images",
            train_dir + "/labels",
            transforms=self.val_transforms
        )
        val_dataset.image_files = val_imgs  # Override with split images
        
        print(f"Training on {len(train_imgs)} images")
        print(f"Validating on {len(val_imgs)} images")
        
        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=self._collate_fn,
            pin_memory=True
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            collate_fn=self._collate_fn,
            pin_memory=True
        )
        
        # Optimizer and scheduler
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=learning_rate,
            epochs=epochs,
            steps_per_epoch=len(train_loader)
        )
        
        # Training loop
        best_val_loss = float('inf')
        
        for epoch in range(epochs):
            self.model.train()
            epoch_loss = 0
            
            for batch_idx, batch in enumerate(train_loader):
                # Move data to device
                images = batch['image'].to(self.device)
                targets = [{
                    'boxes': b.to(self.device),
                    'labels': l.to(self.device)
                } for b, l in zip(batch['boxes'], batch['labels'])]
                
                # Forward pass
                loss_dict = self.model(images, targets)
                loss = sum(loss_dict.values())
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
                
                # Update metrics
                epoch_loss += loss.item()
                
                if batch_idx % 10 == 0:
                    print(f"Epoch [{epoch+1}/{epochs}], "
                          f"Batch [{batch_idx}/{len(train_loader)}], "
                          f"Loss: {loss.item():.4f}")
            
            # # Validation phase
            # if epoch % 5 == 0:
            #     val_loss = self.validate(val_loader)
                
            #     # Save best model
            #     if val_loss < best_val_loss:
            #         best_val_loss = val_loss
            #         torch.save({
            #             'epoch': epoch,
            #             'model_state_dict': self.model.state_dict(),
            #             'optimizer_state_dict': optimizer.state_dict(),
            #             'loss': val_loss,
            #         }, 'best_model.pt')
            
            # # Save regular checkpoint
            # if epoch % 10 == 0:
            #     torch.save({
            #         'epoch': epoch,
            #         'model_state_dict': self.model.state_dict(),
            #         'optimizer_state_dict': optimizer.state_dict(),
            #         'loss': epoch_loss,
            #     }, f'checkpoint_epoch_{epoch}.pt')

    
    @staticmethod
    def _collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
        """Custom collate function for the data loader"""
        images = torch.stack([item['image'] for item in batch])
        boxes = [item['boxes'] for item in batch]
        labels = [item['labels'] for item in batch]
        image_ids = [item['image_id'] for item in batch]
        
        return {
            'image': images,
            'boxes': boxes,
            'labels': labels,
            'image_id': image_ids
        }
    
    def validate(self, val_loader: DataLoader):
        """Run validation"""
        self.model.eval()
        val_loss = 0
        
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(self.device)
                targets = [{
                    'boxes': b.to(self.device),
                    'labels': l.to(self.device)
                } for b, l in zip(batch['boxes'], batch['labels'])]
                
                loss_dict = self.model(images, targets)
                val_loss += sum(loss_dict.values()).item()
        
        print(f"Validation Loss: {val_loss/len(val_loader):.4f}")
        
    def predict(self, image_path: str, conf_thres: float = 0.25):
        """Run inference on a single image"""
        self.model.eval()
        
        # Load and preprocess image
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Apply validation transforms
        transformed = self.val_transforms(image=image)
        image_tensor = transformed['image'].unsqueeze(0).to(self.device)
        
        # Run inference
        with torch.no_grad():
            predictions = self.model(image_tensor)
        
        return predictions

# Initialize trainer with number of classes
trainer = YOLOTrainer(num_classes=2)  # Adjust number of classes

# Train the model
trainer.train(
    train_dir='train/',
    epochs=10,
    batch_size=128,
    learning_rate=0.01
)

Using device: cuda


  validated_self = self.__pydantic_validator__.validate_python(data, self_instance=self)
  A.Flip(p=0.5),


Training on 432 images
Validating on 108 images
New https://pypi.org/project/ultralytics/8.3.41 available 😃 Update with 'pip install -U ultralytics'
[34m[1mengine/trainer: [0mtask=detect, mode=train, model=yolov8n.yaml, data=coco8.yaml, epochs=100, time=None, patience=100, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=cuda:0, workers=8, project=None, name=train3, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_con

[34m[1mtrain: [0mScanning /home/ayman/ml/pt/yolo/datasets/coco8/labels/train.cache... 4 images, 0 backgrounds, 0 corrupt: 100%|██████████| 4/4 [00:00<?, ?it/s]

[34m[1malbumentations: [0mBlur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01, num_output_channels=3, method='weighted_average'), CLAHE(p=0.01, clip_limit=(1.0, 4.0), tile_grid_size=(8, 8))



[34m[1mval: [0mScanning /home/ayman/ml/pt/yolo/datasets/coco8/labels/val.cache... 4 images, 0 backgrounds, 0 corrupt: 100%|██████████| 4/4 [00:00<?, ?it/s]


Plotting labels to runs/detect/train3/labels.jpg... 
[34m[1moptimizer:[0m 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
[34m[1moptimizer:[0m AdamW(lr=0.000119, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
[34m[1mTensorBoard: [0mmodel graph visualization added ✅
Image sizes 640 train, 640 val
Using 8 dataloader workers
Logging results to [1mruns/detect/train3[0m
Starting training for 100 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      1/100     0.856G      2.827       5.36      4.451         18        640: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 1/1 [00:00<00:00,  2.27it/s]

                   all          4         17          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      2/100      0.83G       3.45       5.49      4.262         22        640: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 1/1 [00:00<00:00,  3.13it/s]

                   all          4         17          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      3/100      0.83G      4.056      5.643      4.238         24        640: 100%|██████████| 1/1 [00:01<00:00,  1.11s/it]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 1/1 [00:00<00:00,  1.49it/s]

                   all          4         17          0          0          0          0






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      4/100      0.83G       3.49      5.544      4.246         34        640: 100%|██████████| 1/1 [00:22<00:00, 22.90s/it]


KeyboardInterrupt: 