In [None]:
import torch
import gc
from ultralytics import YOLO
import albumentations as A
import cv2
import os
import numpy as np
from pathlib import Path
import shutil
from sklearn.model_selection import train_test_split
import yaml
# Output folders for the augmented training split and the untouched validation split.
train_dir = Path('split_train')
val_dir = Path('split_val')
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
class DetectionTrainer:
    def __init__(self, data_yaml_path):
        """Load the pretrained YOLOv8-small weights and build the augmentation pipelines.

        Args:
            data_yaml_path: Path to the dataset YAML; kept as an absolute path string.
        """
        self.device = dev
        print(f"Using device: {self.device}")

        self.model = YOLO('yolov8s.pt')
        self.data_yaml_path = str(Path(data_yaml_path).absolute())

        # One always-applied augmentation per output variant.
        augmentations = {
            'brightness': A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.0, p=1.0),
            'contrast': A.RandomBrightnessContrast(brightness_limit=0.0, contrast_limit=0.3, p=1.0),
            'gamma': A.RandomGamma(gamma_limit=(80, 120), p=1.0),
            'noise': A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
            'blur': A.GaussianBlur(blur_limit=(3, 7), p=1.0),
            'rotation': A.RandomRotate90(p=1.0),
            'flip': A.Flip(p=1.0),
        }

        # Pipelines for labelled images: augmentation, then a 640x640 resize, with
        # YOLO-format boxes whose classes travel in the 'class_labels' field.
        self.transforms_with_boxes = {}
        for variant, augmentation in augmentations.items():
            self.transforms_with_boxes[variant] = A.Compose(
                [augmentation, A.Resize(640, 640)],
                bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']),
            )

        # Pipelines for images without boxes: the same steps, minus the bbox handling.
        self.transforms_no_boxes = {}
        for variant, pipeline in self.transforms_with_boxes.items():
            steps = [step for step in pipeline if not isinstance(step, A.Resize)]
            steps.append(A.Resize(640, 640))
            self.transforms_no_boxes[variant] = A.Compose(steps)

    def process_single_image(self, img_path, label_path, dest_dir):
        """Write one image plus one augmented variant per transform into ``dest_dir``.

        The source image is re-saved under ``dest_dir/images`` and its label file, when
        present, is copied to ``dest_dir/labels``. Every configured transform then
        produces ``<stem>_<transform><suffix>`` and a matching ``<stem>_<transform>.txt``.

        Args:
            img_path: Path of the source image.
            label_path: Path of the YOLO label file (one
                ``class x_center y_center width height`` record per line); may not exist.
            dest_dir: Directory that already contains ``images`` and ``labels`` subfolders.

        Any exception is reported with ``print`` and processing of this image stops, so a
        single bad file does not abort the caller's loop.
        """
        try:
            img_path = Path(img_path)
            label_path = Path(label_path)
            dest_dir = Path(dest_dir)

            image = cv2.imread(str(img_path))
            if image is None:
                print(f"Warning: Could not read image {img_path}")
                return

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            has_label_file = label_path.exists()

            boxes = []
            class_labels = []
            if has_label_file:
                with open(label_path) as f:
                    for line in f:
                        fields = line.split()
                        if not fields:
                            # Blank lines (e.g. a trailing newline) carry no annotation
                            # and must not cause the whole image to be dropped.
                            continue
                        # Non-blank lines must hold exactly five numbers; anything else
                        # raises and is reported by the handler below.
                        class_id, x_center, y_center, width, height = map(float, fields)
                        boxes.append([x_center, y_center, width, height])
                        class_labels.append(class_id)

            images_dir = dest_dir / 'images'
            labels_dir = dest_dir / 'labels'

            # Save original image
            cv2.imwrite(str(images_dir / img_path.name),
                        cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
            if has_label_file:
                shutil.copy2(label_path, labels_dir / label_path.name)

            # Box-aware pipelines are only usable when there is at least one box.
            transforms = self.transforms_with_boxes if boxes else self.transforms_no_boxes

            for transform_name, transform in transforms.items():
                aug_img_name = f"{img_path.stem}_{transform_name}{img_path.suffix}"
                aug_label_path = labels_dir / f"{img_path.stem}_{transform_name}.txt"

                if boxes:
                    transformed = transform(
                        image=image,
                        bboxes=boxes,
                        class_labels=class_labels
                    )
                else:
                    transformed = transform(image=image)

                cv2.imwrite(
                    str(images_dir / aug_img_name),
                    cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR)
                )

                if boxes:
                    # Boxes may have moved (rotation/flip), so write the transformed ones.
                    with open(aug_label_path, 'w') as f:
                        for box, label in zip(transformed['bboxes'], transformed['class_labels']):
                            f.write(f"{int(label)} {' '.join(map(str, box))}\n")
                elif has_label_file:
                    # A label file without boxes is copied as-is for the augmented image.
                    shutil.copy2(label_path, aug_label_path)

            # Release the decoded image buffers before the next file is loaded.
            gc.collect()

        except Exception as e:
            print(f"Error processing {img_path}: {str(e)}")

    def prepare_split_datasets(self, source_dir, train_dir, val_dir, val_split=0.2):
        """Split dataset with memory management"""
        source_dir = Path(source_dir)
        train_dir = Path(train_dir)
        val_dir = Path(val_dir)
        
        for dir_path in [train_dir / 'images', train_dir / 'labels',
                        val_dir / 'images', val_dir / 'labels']:
            dir_path.mkdir(parents=True, exist_ok=True)
        
        # Process images in smaller batches
        batch_size = 80
        image_files = list((source_dir / 'images').glob('*.jpg'))
        
        train_imgs, val_imgs = train_test_split(
            image_files,
            test_size=val_split,
            random_state=42
        )
        
        print(f"Processing {len(train_imgs)} training images in batches...")
        for i in range(0, len(train_imgs), batch_size):
            batch = train_imgs[i:i + batch_size]
            for img_path in batch:
                label_path = source_dir / 'labels' / img_path.with_suffix('.txt').name
                self.process_single_image(img_path, label_path, train_dir)
            gc.collect()
            print(f"Processed batch {i//batch_size + 1}/{len(train_imgs)//batch_size + 1}")
            
        print(f"Processing {len(val_imgs)} validation images...")
        for img_path in val_imgs:
            label_path = source_dir / 'labels' / img_path.with_suffix('.txt').name
            shutil.copy2(img_path, val_dir / 'images' / img_path.name)
            if label_path.exists():
                shutil.copy2(label_path, val_dir / 'labels' / label_path.name)
        
        # Clear memory
        train_imgs = None
        val_imgs = None
        gc.collect()

    def train(self, epochs, imgsz=640, batch_size=32):
        """Train the model on the prepared train/validation split.

        Loads the dataset YAML given at construction, points its ``train`` and ``val``
        entries at the module-level ``train_dir``/``val_dir`` image folders, writes the
        result to ``temp_data.yaml`` in the current directory and starts training.

        Args:
            epochs: Number of training epochs.
            imgsz: Image size passed to the trainer.
            batch_size: Batch size, passed to the trainer as ``batch``.

        Returns:
            Whatever ``self.model.train`` returns.
        """
        print(f"Training on {self.device}")

        with open(self.data_yaml_path, 'r') as f:
            data_config = yaml.safe_load(f)

        # Write absolute paths: relative entries are resolved against the YAML's
        # ``path`` key, which is not necessarily the directory the split was written to.
        data_config['train'] = str((train_dir / 'images').absolute())
        data_config['val'] = str((val_dir / 'images').absolute())

        temp_yaml_path = 'temp_data.yaml'
        with open(temp_yaml_path, 'w') as f:
            yaml.dump(data_config, f)

        args = dict(
            data=temp_yaml_path,
            epochs=epochs,
            imgsz=imgsz,
            batch=batch_size,
            patience=20,
            save_period=10,
            verbose=True,
            device=self.device,
            project=str(Path().absolute() / 'runs'),
            augment=True,
            cache=False,
            workers=4,
            lr0=0.01,
            lrf=0.001,
            name="test_run"
        )

        # No cleanup afterwards: the split folders and temp_data.yaml stay on disk.
        return self.model.train(**args)

    def test(self, conf_threshold=0.25, iou_threshold=0.45):
        """Evaluate the model on the images under ``test/images``.

        Writes a temporary dataset YAML (``temp_test_data.yaml``) and runs validation
        with ``split='test'``.

        Args:
            conf_threshold: Confidence threshold passed to the validator as ``conf``.
            iou_threshold: IoU threshold passed to the validator as ``iou``.

        Returns:
            The object returned by ``self.model.val``.
        """
        print(f"\nRunning tests on {self.device}")

        # Create a temporary yaml file for testing
        with open(self.data_yaml_path, 'r') as f:
            data_config = yaml.safe_load(f)

        test_images = str(Path('test/images').absolute())

        # split='test' makes the validator read the 'test' entry, so make sure one
        # exists; an entry already present in the source YAML is respected.
        if not data_config.get('test'):
            data_config['test'] = test_images
        # 'val' is pointed at the test images as well (absolute, so it does not depend
        # on the YAML's 'path' key) instead of keeping the source YAML's value.
        data_config['val'] = test_images

        temp_yaml_path = 'temp_test_data.yaml'
        with open(temp_yaml_path, 'w') as f:
            yaml.dump(data_config, f)

        results = self.model.val(
            data=temp_yaml_path,
            split='test',
            conf=conf_threshold,
            iou=iou_threshold,
            device=self.device,  # same device source as train()
            verbose=True
        )
        return results

    def predict(self, image_path):
        """Run inference on a single image"""
        return self.model.predict(
            source=image_path,
            conf=0.25,
            iou=0.45,
            device=dev
        )

def prepare_dataset_structure():
    """
    Ensure the train/test image and label folders exist under the current directory.
    """
    for split in ('train', 'test'):
        for kind in ('images', 'labels'):
            # Existing folders are left untouched.
            Path(split, kind).mkdir(parents=True, exist_ok=True)

def create_data_yaml(dataset_path):
    """
    Create the data.yaml file for YOLOv8 in the current working directory.

    Args:
        dataset_path: Dataset root; converted to an absolute path and used as the
            prefix of the train/val/test image folders.
    """
    # Convert to absolute path
    abs_path = str(Path(dataset_path).absolute())

    data_config = {
        'path': abs_path,                      # dataset root directory
        'train': f"{abs_path}/train/images",   # train images
        'val': f"{abs_path}/temp_val/images",  # temporary validation images directory
        'test': f"{abs_path}/test/images",     # test images
        # Classes
        'names': {0: 'negative', 1: 'positive'},
    }

    # Serialise through PyYAML instead of pasting the path into YAML text, so a path
    # containing characters such as ' #' or ': ' is quoted rather than truncated.
    with open('data.yaml', 'w') as f:
        yaml.safe_dump(data_config, f, sort_keys=False)


In [None]:

# Absolute path of the current working directory; passed to create_data_yaml() in a
# later cell as the dataset root.
current_dir = str(Path().absolute())

# Create the train/ and test/ image and label folders (existing ones are kept).
prepare_dataset_structure()

# Initialize trainer. Only the YAML path is stored at this point; the file itself is
# read when train()/test() run, so data.yaml does not have to exist yet.
trainer = DetectionTrainer('data.yaml')

In [None]:
# Set to True after editing the augmentation pipelines to rebuild the split from scratch.
changed_augments=False

# On a fresh checkout the split folders do not exist yet; build them in that case too,
# otherwise training would start on a missing dataset.
split_ready = all(
    images_dir.is_dir() and any(images_dir.iterdir())
    for images_dir in (train_dir / 'images', val_dir / 'images')
)

if changed_augments or not split_ready:
    for dir_path in [train_dir, val_dir]:
        if dir_path.exists():
            shutil.rmtree(dir_path)
    print("Preparing train/val splits and augmentations...")
    trainer.prepare_split_datasets('train', train_dir, val_dir, val_split=0.2)

In [None]:
# Create data.yaml with absolute paths. This must run before train()/test(), which
# both read the file named when the trainer was constructed.
create_data_yaml(current_dir)

# Train model. train() writes temp_data.yaml pointing at the split_train/split_val
# image folders and trains for the given number of epochs.
trainer.train(epochs=40)


# Run tests and get results. test() validates with split='test' using the given
# confidence and IoU thresholds and returns the validator's result object.
test_results = trainer.test(conf_threshold=0.25, iou_threshold=0.45)
