In [42]:
import sys
import os
import torch
from ultralytics import YOLO
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms import functional as F
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
import torchvision.transforms as T


class CocoaDataset(Dataset):
    """Image/label dataset for YOLO-style text annotations.

    Each image ``<stem><ext>`` found in ``image_dir`` is paired with
    ``<stem>.txt`` in ``label_dir``. Every non-blank line of a label file holds
    five whitespace-separated numbers: ``class x_center y_center width height``.
    NOTE(review): the coordinates are presumably normalised to [0, 1] (YOLO
    convention); this class neither checks nor rescales them.

    Args:
        image_dir: directory scanned (non-recursively) for image files.
        label_dir: directory holding the matching ``.txt`` label files.
        transforms: optional callable applied to the RGB PIL image.
        extensions: filename suffix, or tuple of suffixes, selecting which files
            in ``image_dir`` are images (case-sensitive). Defaults to ``('.jpg',)``.

    ``__getitem__`` returns ``(image, boxes)`` where ``boxes`` is a float32
    tensor of shape ``(num_boxes, 5)``; it is ``(0, 5)`` when the label file is
    missing or contains no boxes.
    """

    def __init__(self, image_dir, label_dir, transforms=None, extensions=('.jpg',)):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transforms = transforms
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(extensions)
        # sorted(): os.listdir order is arbitrary, so without it the
        # index -> image mapping could differ between machines or runs.
        self.image_files = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.endswith(suffixes)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        image_path = self.image_files[index]
        # The context manager closes the file handle; convert() returns an
        # independent image, so it stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')  # ensure 3 channels (RGB)

        boxes = self._load_boxes(self._label_path(image_path))

        if self.transforms:
            image = self.transforms(image)

        return image, boxes

    def _label_path(self, image_path):
        """Return the label path for ``image_path``: same stem, ``.txt`` suffix, in ``label_dir``."""
        # splitext only strips the final suffix; str.replace would also rewrite
        # '.jpg' occurring elsewhere in the file name.
        stem, _ = os.path.splitext(os.path.basename(image_path))
        return os.path.join(self.label_dir, stem + '.txt')

    @staticmethod
    def _load_boxes(label_path):
        """Parse a label file into a float32 ``(num_boxes, 5)`` tensor.

        Returns an empty ``(0, 5)`` tensor when the file does not exist or has
        no non-blank lines.

        Raises:
            ValueError: if a non-blank line does not contain exactly 5 values
                (or a value is not a number).
        """
        if not os.path.exists(label_path):
            return torch.empty((0, 5))

        rows = []
        with open(label_path, 'r') as f:
            for lineno, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(parts) != 5:
                    raise ValueError(
                        f"{label_path}:{lineno}: expected 5 values, got {len(parts)}"
                    )
                rows.append([float(value) for value in parts])

        # reshape keeps the (0, 5) shape for a file with no boxes.
        return torch.tensor(rows, dtype=torch.float32).reshape(-1, 5)


# def collate_fn(batch):
#     images = []
#     targets = []

#     for i, (image, boxes) in enumerate(batch):
#         images.append(image)

#         if boxes.numel() > 0:  # If there are any boxes
#             # Add batch index to boxes
#             batch_indices = torch.full((boxes.size(0), 1), i, dtype=boxes.dtype)
#             boxes = torch.cat((batch_indices, boxes), dim=1)
        
#         targets.append(boxes)

#     # Stack images (images are already resized to the same size in the dataset)
#     images = torch.stack(images, dim=0)

#     # Concatenate all targets into a single tensor
#     targets = torch.cat(targets, dim=0) if len(targets) > 0 else torch.empty((0, 6))

#     return images, targets



def collate_fn(batch):
    """Merge a list of ``(image, boxes)`` samples into one training batch.

    All images must share one shape (``torch.stack`` requires it). Each sample's
    ``boxes`` tensor of shape ``(n, 5)`` holds ``class, x, y, w, h`` rows, as
    parsed by ``CocoaDataset``. The sample's index within the batch is
    prepended as a new first column, giving rows of
    ``[batch_idx, class, x, y, w, h]``.

    Returns:
        images: the images stacked along a new dimension 0.
        targets: tensor of shape ``(total_boxes, 6)``; an empty ``(0, 6)``
            tensor when no sample in the batch has boxes.
    """
    images, targets = zip(*batch)

    # Stack images
    images = torch.stack(images, 0)

    # Prepend the batch index as a new column rather than writing it into
    # column 0: column 0 holds the class label, and writing in place would
    # also mutate the sample tensors.
    batched_targets = [
        torch.cat(
            (torch.full((boxes.size(0), 1), i, dtype=boxes.dtype), boxes),
            dim=1,
        )
        for i, boxes in enumerate(targets)
        if boxes.numel() > 0
    ]
    batched_targets = torch.cat(batched_targets, dim=0) if batched_targets else torch.empty((0, 6))

    return images, batched_targets




def train(model, dataloader, optimizer, device, verbose=True):
    """Run one training epoch and return the loss of the last optimised batch.

    Args:
        model: wrapper whose ``.model`` attribute is the underlying
            ``nn.Module``. It is called as ``model(images, targets)`` and is
            expected to return ``(loss, output)``.
            NOTE(review): confirm this call contract against the installed
            ultralytics version; its documented training entry point is
            ``YOLO.train(data=...)`` rather than a direct call.
        dataloader: iterable of ``(images, targets)`` batches, e.g. produced
            with ``collate_fn``.
        optimizer: torch optimizer over the model's parameters.
        device: device the batch tensors are moved to.
        verbose: when True (default), print per-batch shapes and skip notices.

    Returns:
        ``loss.item()`` of the last batch that was optimised, or 0 if every
        batch was skipped.

    Raises:
        RuntimeError: re-raised from the forward pass after printing the
            offending batch.
    """
    # Put the underlying module in training mode (affects dropout/batch-norm).
    model.model.train()
    last_loss = 0

    for images, targets in dataloader:
        images = images.to(device)
        # One call casts to float32 and moves to the device.
        targets = targets.to(device, dtype=torch.float32)

        # Batches without any annotation get no forward/backward pass.
        if targets.numel() == 0:
            if verbose:
                print("Skipping batch with no targets.")
            continue

        if verbose:
            print(f"Images shape: {images.shape}")  # Should be [batch_size, 3, 640, 640]
            print(f"Targets shape: {targets.shape}")  # Should be [num_annotations, 6]

        optimizer.zero_grad()

        # Forward pass
        try:
            loss, output = model(images, targets)
        except RuntimeError as e:
            print(f"Error during forward pass: {e}")
            print(f"Images: {images.shape}")
            print(f"Targets: {targets}")
            raise  # bare raise keeps the original traceback untouched

        loss.backward()
        optimizer.step()

        last_loss = loss.item()

    return last_loss
    
def validate(model, dataloader, device):
    """Return the mean loss over ``dataloader`` computed without gradients.

    Batches with no annotations are skipped, the same way ``train`` skips them,
    and the mean is taken over the batches actually evaluated.

    Args:
        model: wrapper whose ``.model`` attribute is the underlying
            ``nn.Module``; called as ``model(images, targets)`` and expected to
            return ``(loss, output)``.
            NOTE(review): confirm the wrapper still returns a loss while the
            module is in eval mode.
        dataloader: iterable of ``(images, targets)`` batches.
        device: device the batch tensors are moved to.

    Returns:
        Mean of ``loss.item()`` over evaluated batches, or ``float('nan')``
        when no batch was evaluated (e.g. an empty loader). The previous
        ``/ len(dataloader)`` raised ZeroDivisionError in that case.
    """
    # Put the underlying module in evaluation mode (affects dropout/batch-norm).
    model.model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)
            targets = targets.to(device, dtype=torch.float32)

            if targets.numel() == 0:
                continue

            # Forward pass
            loss, _ = model(images, targets)

            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float('nan')

In [44]:
# Preprocessing shared by both splits: resize every image to exactly 640x640
# (so collate_fn can torch.stack a batch), then convert the PIL image to a tensor.
transforms = T.Compose([T.Resize((640, 640)), T.ToTensor()])

# Dataset root; each split lives under images/<split> and labels/<split>.
dataset_dir = '../datasets/cocoa_diseases'


def make_split_dataset(split):
    """Build the CocoaDataset for one split (e.g. 'train' or 'val') under dataset_dir."""
    return CocoaDataset(
        image_dir=os.path.join(dataset_dir, f'images/{split}'),
        label_dir=os.path.join(dataset_dir, f'labels/{split}'),
        transforms=transforms,
    )


train_dataset = make_split_dataset('train')
val_dataset = make_split_dataset('val')


# Smoke test: push ONE batch through the model to check tensor shapes before
# launching the full training run.
# NOTE(review): `model`, `device` and `train_loader` are created in the training
# cell further down, so under Restart & Run All this cell runs before they
# exist. The guard skips the check instead of raising NameError; ideally move
# this cell below the training-setup cell.
_missing = [name for name in ('model', 'device', 'train_loader') if name not in globals()]
if _missing:
    print(f"Smoke test skipped: {', '.join(_missing)} not defined yet.")
else:
    images, targets = next(iter(train_loader))
    images = images.to(device)
    targets = targets.to(device, dtype=torch.float32)

    print(f"Images shape: {images.shape}")
    print(f"Targets shape: {targets.shape}")
    # Only the first rows; printing the whole tensor floods the output.
    print(f"Targets: {targets[:5]}")

    try:
        # This is only a shape check, so skip building an autograd graph; the
        # graph would keep activations alive on the GPU and add memory pressure.
        with torch.no_grad():
            loss, output = model(images, targets)
        print(f"Loss: {loss}")
    except RuntimeError as e:
        print(f"Error during forward pass: {e}")


OutOfMemoryError: CUDA out of memory. Tried to allocate 300.00 MiB (GPU 0; 14.57 GiB total capacity; 1.27 GiB already allocated; 204.75 MiB free; 1.28 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Training configuration: tune values here rather than inside the calls below.
# NOTE(review): the CUDA OOM captured above reported ~1.3 GiB reserved by
# PyTorch but only ~200 MiB free on a 14.6 GiB GPU. That suggests another
# process or kernel is holding GPU memory; free it first, then lower BATCH_SIZE
# if OOM persists.
SEED = 0
BATCH_SIZE = 64
LEARNING_RATE = 0.001
num_epochs = 10

# Seed torch's global RNG so the shuffled batch order is reproducible.
torch.manual_seed(SEED)

# Load dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Pick the device first and move the model to it. Hard-coding 'cuda' here
# crashed on CPU-only machines even though `device` falls back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YOLO('../yolo11n.pt').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training and validation loop
for epoch in range(num_epochs):
    last_train_loss = train(model, train_loader, optimizer, device)
    average_val_loss = validate(model, val_loader, device)

    print(f'Epoch [{epoch+1}/{num_epochs}], Last Training Loss: {last_train_loss:.4f}, Validation Loss: {average_val_loss:.4f}')