Cell 1: Markdown

In [None]:
# YOLOv8 Conjunctiva Segmentation Training

This notebook trains a custom YOLOv8 model to segment conjunctiva from eye images using your real dataset.

## Dataset Overview
- **Training images**: 100+ eye images with conjunctiva annotations
- **Validation images**: 80+ eye images with conjunctiva annotations  
- **Class**: Single class (conjunctiva)
- **Format**: YOLO format with normalized coordinates

SyntaxError: invalid syntax (3237995433.py, line 3)

Cell 2: Setup and Imports

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import json
from tqdm import tqdm
import warnings
import yaml
import albumentations as A
from albumentations.pytorch import ToTensorV2
warnings.filterwarnings('ignore')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

PyTorch version: 2.8.0+cpu
CUDA available: False
Using device: cpu


Cell 3: Import YOLOv8 Model

In [None]:
# Import only the model classes that exist
from yolov8_simple import YOLOv8, YOLOConfig

print("Successfully imported YOLOv8 model components")

Successfully imported YOLOv8 model components


Cell 4: Dataset Configuration

In [None]:
# Load dataset configuration
with open('data.yaml', 'r') as f:
    dataset_config = yaml.safe_load(f)

print("Dataset Configuration:")
print(f"  - Training path: {dataset_config['train']}")
print(f"  - Validation path: {dataset_config['val']}")
print(f"  - Test path: {dataset_config['test']}")
print(f"  - Classes: {dataset_config['names']}")

# Count images in each split
train_images = len([f for f in os.listdir(dataset_config['train']) if f.endswith(('.jpg', '.jpeg', '.png'))])
valid_images = len([f for f in os.listdir(dataset_config['val']) if f.endswith(('.jpg', '.jpeg', '.png'))])
test_images = len([f for f in os.listdir(dataset_config['test']) if f.endswith(('.jpg', '.jpeg', '.png'))])

print(f"\nDataset Statistics:")
print(f"  - Training images: {train_images}")
print(f"  - Validation images: {valid_images}")
print(f"  - Test images: {test_images}")
print(f"  - Total images: {train_images + valid_images + test_images}")

Dataset Configuration:
  - Training path: train/images
  - Validation path: valid/images
  - Test path: test/images
  - Classes: {0: 'conjunctiva'}

Dataset Statistics:
  - Training images: 300
  - Validation images: 100
  - Test images: 16
  - Total images: 416


Cell 5: Data Augmentation Setup

In [None]:
def get_train_transforms(input_size=640):
    """Get training transforms with data augmentation"""
    return A.Compose([
        A.Resize(height=input_size, width=input_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.RandomBrightnessContrast(p=0.2),
        A.HueSaturationValue(p=0.2),
        A.RandomGamma(p=0.2),
        A.Blur(blur_limit=3, p=0.1),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

def get_val_transforms(input_size=640):
    """Get validation transforms (no augmentation)"""
    return A.Compose([
        A.Resize(height=input_size, width=input_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

print("Data augmentation configured:")
print("  - Training: Resize, Flip, Brightness/Contrast, Hue/Saturation, Gamma, Blur")
print("  - Validation: Resize only")

Data augmentation configured:
  - Training: Resize, Flip, Brightness/Contrast, Hue/Saturation, Gamma, Blur
  - Validation: Resize only


Cell 6: Custom Dataset Class

In [None]:
import torch
from torch.utils.data import Dataset
import os
import cv2
#You might need to install this "pip install albumentations"
import albumentations as A
from albumentations.pytorch import ToTensorV2

class ConjunctivaDataset(Dataset):
    def __init__(self, images_dir, labels_dir, transform=None, input_size=640):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transform = transform
        self.input_size = input_size
        
        # Get all image files
        self.image_files = [f for f in os.listdir(images_dir) 
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        
        print(f"Found {len(self.image_files)} images in {images_dir}")
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        # Load image
        img_file = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_file)
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Load labels
        label_file = img_file.rsplit('.', 1)[0] + '.txt'
        label_path = os.path.join(self.labels_dir, label_file)
        
        bboxes = []
        class_labels = []
        
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        class_id = int(parts[0])
                        x_center = float(parts[1])
                        y_center = float(parts[2])
                        width = float(parts[3])
                        height = float(parts[4])
                        
                        bboxes.append([x_center, y_center, width, height])
                        class_labels.append(class_id)
        
        # Apply transforms
        if self.transform:
            transformed = self.transform(
                image=image,
                bboxes=bboxes,
                class_labels=class_labels
            )
            image = transformed['image']
            bboxes = transformed['bboxes']
            class_labels = transformed['class_labels']
        
        # Convert to tensor format for YOLO
        targets = []
        for bbox, class_id in zip(bboxes, class_labels):
            targets.append([class_id] + bbox)
        
        return image, targets
    
    def collate_fn(self, batch):
        images = []
        targets = []
        
        for img, target in batch:
            images.append(img)
            targets.extend([[len(images)-1] + t for t in target])
        
        images = torch.stack(images)
        targets = torch.tensor(targets, dtype=torch.float32)
        
        return images, targets

# Example transforms (you might have your own)
def get_train_transforms(input_size=640):
    return A.Compose([
        A.Resize(input_size, input_size),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

def get_val_transforms(input_size=640):
    return A.Compose([
        A.Resize(input_size, input_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

NameError: name 'Dataset' is not defined

Cell 7: Data Loaders

In [1]:
# Create data loaders
batch_size = 8  # Adjust based on your GPU memory
num_workers = 0  # Set to 0 for Windows, increase for Linux/Mac

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=train_dataset.collate_fn,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=val_dataset.collate_fn,
    pin_memory=True
)

print(f"Data loaders created:")
print(f"  - Training batch size: {batch_size}")
print(f"  - Validation batch size: {batch_size}")
print(f"  - Training batches per epoch: {len(train_loader)}")
print(f"  - Validation batches per epoch: {len(val_loader)}")

# Test data loading
print("\nTesting data loading...")
for batch_idx, (images, targets) in enumerate(train_loader):
    print(f"  - Batch {batch_idx + 1}: images shape {images.shape}, targets shape {targets.shape}")
    if batch_idx >= 2:  # Test first 3 batches
        break
print("Data loading test completed successfully!")

NameError: name 'DataLoader' is not defined

Cell 8: Loss Function Implementation

In [2]:
class YOLOLoss(nn.Module):
    def __init__(self, config):
        super(YOLOLoss, self).__init__()
        self.config = config
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.smooth_l1 = nn.SmoothL1Loss()
        
        # Constants
        self.lambda_class = 1
        self.lambda_noobj = 10
        self.lambda_obj = 1
        self.lambda_box = 10

    def forward(self, predictions, targets):
        obj = targets[..., 0] == 1  # in paper this is Iobj_i
        noobj = targets[..., 0] == 0  # in paper this is Inoobj_i

        # No object loss
        no_object_loss = self.bce(
            (predictions[..., 0:1][noobj]), (targets[..., 0:1][noobj]),
        )

        # Object loss
        object_loss = self.bce(
            (predictions[..., 0:1][obj]), (targets[..., 0:1][obj]),
        )

        # Box coordinate loss
        box_loss = self.smooth_l1(
            predictions[..., 1:5][obj], targets[..., 1:5][obj]
        )

        # Class loss
        class_loss = self.entropy(
            (predictions[..., 5:][obj]), (targets[..., 5][obj].long()),
        )

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss
        )

print("YOLO Loss function implemented")

NameError: name 'nn' is not defined

In [None]:
Cell 9: Model Configuration

In [None]:
# Configuration for conjunctiva detection
config = YOLOConfig(
    num_classes=1,  # Single class: conjunctiva
    input_size=640,
    anchors=3,
    anchor_masks=[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
)

print("Model Configuration:")
print(f"  - Number of classes: {config.num_classes}")
print(f"  - Input size: {config.input_size}")
print(f"  - Number of anchors: {config.anchors}")

# Initialize model
model = YOLOv8(config)
model = model.to(device)

# Print model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel Summary:")
print(f"  - Total parameters: {total_params:,}")
print(f"  - Trainable parameters: {trainable_params:,}")

Cell 10: Training Setup

In [None]:
# Training hyperparameters
learning_rate = 0.001
num_epochs = 50
save_dir = 'trained_models'

# Create save directory
os.makedirs(save_dir, exist_ok=True)

# Initialize optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
loss_fn = YOLOLoss(config)

# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

print("Training Configuration:")
print(f"  - Learning rate: {learning_rate}")
print(f"  - Number of epochs: {num_epochs}")
print(f"  - Save directory: {save_dir}")
print(f"  - Optimizer: Adam with weight decay")
print(f"  - Loss function: YOLOLoss")
print(f"  - Scheduler: CosineAnnealingLR")

Cell 11: Training Functions

In [None]:
def train_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    num_batches = 0
    
    progress_bar = tqdm(dataloader, desc="Training")
    
    for batch_idx, (images, targets) in enumerate(progress_bar):
        images = images.to(device)
        
        # Forward pass
        outputs = model(images)
        
        # Calculate loss
        loss = loss_fn(outputs, targets)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        # Update progress bar
        progress_bar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Avg Loss': f'{total_loss/num_batches:.4f}'
        })
    
    return total_loss / num_batches

def validate_epoch(model, dataloader, loss_fn, device):
    model.eval()
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validation")
        
        for batch_idx, (images, targets) in enumerate(progress_bar):
            images = images.to(device)
            
            # Forward pass
            outputs = model(images)
            
            # Calculate loss
            loss = loss_fn(outputs, targets)
            
            total_loss += loss.item()
            num_batches += 1
            
            # Update progress bar
            progress_bar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Avg Loss': f'{total_loss/num_batches:.4f}'
            })
    
    return total_loss / num_batches

print("Training functions defined")

Cell 12: Training Loop

In [None]:
print("Starting training...")

train_losses = []
val_losses = []
best_val_loss = float('inf')

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 50)
    
    # Training
    train_loss = train_epoch(model, train_loader, optimizer, loss_fn, device)
    train_losses.append(train_loss)
    
    # Validation
    val_loss = validate_epoch(model, val_loader, loss_fn, device)
    val_losses.append(val_loss)
    
    # Update learning rate
    scheduler.step()
    
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_path = os.path.join(save_dir, 'yolov8_conjunctiva_best.pth')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'config': config,
            'best_val_loss': best_val_loss
        }, best_model_path)
        print(f"New best model saved: {best_model_path}")
    
    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        checkpoint_path = os.path.join(save_dir, f'yolov8_conjunctiva_epoch_{epoch+1}.pth')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'config': config
        }, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

print("\nTraining completed!")

# Save final model
final_model_path = os.path.join(save_dir, 'yolov8_conjunctiva_final.pth')
torch.save({
    'model_state_dict': model.state_dict(),
    'config': config,
    'input_size': config.input_size,
    'num_classes': config.num_classes,
    'final_train_loss': train_losses[-1],
    'final_val_loss': val_losses[-1]
}, final_model_path)

print(f"Final model saved to: {final_model_path}")

Cell 13: Training Visualization

In [None]:
# Plot training curves
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(val_losses, label='Val Loss', color='red')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 2)
plt.plot(train_losses, label='Train Loss', color='blue')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 3, 3)
plt.plot(val_losses, label='Val Loss', color='red')
plt.title('Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

# Print final metrics
print(f"Final Training Loss: {train_losses[-1]:.4f}")
print(f"Final Validation Loss: {val_losses[-1]:.4f}")
print(f"Best Validation Loss: {best_val_loss:.4f}")

Cell 14: Simple Inference Function

In [None]:
def simple_detect(model, image_path, config, device, conf_threshold=0.5):
    """Simple inference function for testing"""
    # Load and preprocess image
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # Resize
    h, w = image.shape[:2]
    resized = cv2.resize(image, (config.input_size, config.input_size))
    
    # Normalize and convert to tensor
    resized = resized.astype(np.float32) / 255.0
    resized = (resized - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
    resized = torch.from_numpy(resized).permute(2, 0, 1).unsqueeze(0).to(device)
    
    # Inference
    model.eval()
    with torch.no_grad():
        predictions = model(resized)
    
    # Simple post-processing (this is a simplified version)
    # In a real implementation, you'd need proper NMS and anchor handling
    predictions = predictions.squeeze()
    
    # For now, just return the raw predictions
    return predictions.cpu().numpy()

print("Simple inference function defined")

Cell 15: Model Testing

In [None]:
# Load best model for testing
best_model_path = os.path.join(save_dir, 'yolov8_conjunctiva_best.pth')
if os.path.exists(best_model_path):
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']}")
    print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")

# Test on validation images
val_images_dir = dataset_config['val']
val_image_files = [f for f in os.listdir(val_images_dir) 
                   if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

# Test on first 6 validation images
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.ravel()

for i, img_file in enumerate(val_image_files[:6]):
    img_path = os.path.join(val_images_dir, img_file)
    
    # Load and display original image
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    # Perform simple inference
    predictions = simple_detect(model, img_path, config, device)
    
    axes[i].imshow(img)
    axes[i].set_title(f"{img_file}\nPredictions shape: {predictions.shape}")
    axes[i].axis('off')

plt.tight_layout()
plt.show()

print("Model testing completed!")

Cell 16: Save Model Info

In [None]:
# Save model information
model_info = {
    'model_path': final_model_path,
    'best_model_path': best_model_path,
    'input_size': config.input_size,
    'num_classes': config.num_classes,
    'device': str(device),
    'total_params': sum(p.numel() for p in model.parameters()),
    'trainable_params': sum(p.numel() for p in model.parameters() if p.requires_grad),
    'dataset_config': dataset_config,
    'training_config': {
        'learning_rate': learning_rate,
        'num_epochs': num_epochs,
        'batch_size': batch_size,
        'optimizer': 'Adam',
        'scheduler': 'CosineAnnealingLR'
    },
    'final_metrics': {
        'final_train_loss': train_losses[-1],
        'final_val_loss': val_losses[-1],
        'best_val_loss': best_val_loss
    }
}

with open(os.path.join(save_dir, 'model_info.json'), 'w') as f:
    json.dump(model_info, f, indent=2)

print("Model information saved to model_info.json")
print("\nTraining completed successfully!")
print(f"Best model saved at: {best_model_path}")
print(f"Final model saved at: {final_model_path}")