In [None]:
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

class CaptchaDatasetRGB(Dataset):
    """Dataset of CAPTCHA images loaded as RGB, with optional labels.

    ``data_dir`` must contain an ``images`` sub-folder with ``.png`` files.
    If a ``labels.json`` file sits next to it, it is read as a JSON list of
    objects keyed by ``image_id`` (the image file name without extension).
    Folders without ``labels.json`` still work: their samples carry an empty
    captcha string and an empty annotation list.
    """

    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir (string): Direct path to the specific data folder (train, val, or test)
                               e.g., '/path/to/UTN-CV25-Captcha-Dataset/part2/train'
            transform (callable, optional): Optional transform to be applied on images

        Raises:
            FileNotFoundError: If ``data_dir/images`` does not exist.
        """
        self.data_dir = data_dir
        self.images_dir = os.path.join(self.data_dir, 'images')
        self.transform = transform
        # Sorted so the index -> file mapping is deterministic across runs
        self.image_list = sorted(
            name for name in os.listdir(self.images_dir) if name.endswith('.png')
        )

        # Labels are optional; without them every lookup falls back to defaults
        self.labels_dict = {}
        labels_file = os.path.join(self.data_dir, 'labels.json')
        if os.path.exists(labels_file):
            # Read as UTF-8 explicitly instead of relying on the locale default
            with open(labels_file, 'r', encoding='utf-8') as f:
                labels = json.load(f)
            # Index by image_id for O(1) lookup in __getitem__
            self.labels_dict = {item['image_id']: item for item in labels}

    def __len__(self):
        """Return the number of ``.png`` images found in the images folder."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx (int or Tensor): Position in the sorted image list.

        Returns:
            dict: ``image`` (the RGB image, after ``transform`` if one was
            given), ``image_id`` (file name without extension),
            ``captcha_string`` ('' when unlabeled) and ``annotations``
            ([] when unlabeled).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = self.image_list[idx]
        img_path = os.path.join(self.images_dir, img_name)

        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed deterministically as soon as the conversion is done
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        # Labels are keyed by the file name without its extension
        image_id = os.path.splitext(img_name)[0]
        label_info = self.labels_dict.get(image_id, {})

        captcha_string = label_info.get('captcha_string', '')
        annotations = label_info.get('annotations', [])

        # Explicit None check: a valid transform object must never be skipped
        # just because it happens to evaluate as falsy
        if self.transform is not None:
            image = self.transform(image)

        return {
            'image': image,
            'image_id': image_id,
            'captcha_string': captcha_string,
            'annotations': annotations,
        }

# Custom collate function for RGB images
def custom_collate_fn_rgb(batch):
    """Merge a list of per-sample dicts into a single batch dict.

    Image tensors are stacked along a new leading dimension. Image ids,
    captcha strings and annotations stay plain Python lists (one entry per
    sample), since annotation lists can differ in length between samples.
    """
    def column(key):
        # Gather one field from every sample, preserving batch order
        return [sample[key] for sample in batch]

    collated = {'image': torch.stack(column('image'))}
    for key in ('image_id', 'captcha_string', 'annotations'):
        collated[key] = column(key)
    return collated

# Helper function to create RGB dataloaders
def get_dataloader_rgb(data_folder, batch_size=32, shuffle=True, num_workers=0):
    """Build a DataLoader over the RGB CAPTCHA images in one data folder.

    Args:
        data_folder (string): Direct path to specific data folder
                            e.g. '/path/to/UTN-CV25-Captcha-Dataset/part2/train'
        batch_size (int): Batch size for the dataloader
        shuffle (bool): Whether to shuffle the data
        num_workers (int): Number of loader worker processes; the default 0
                           loads data in the calling process

    Returns:
        DataLoader: Yields batch dicts produced by ``custom_collate_fn_rgb``.
    """
    # No Resize step: images keep their native resolution. The collate
    # function stacks the image tensors, so all images in a batch must
    # already share the same size.
    transform = transforms.Compose([
        transforms.ToTensor(),         # Convert to tensor [0, 1]
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],     # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1] per channel
            std=[0.5, 0.5, 0.5],
        ),
    ])

    dataset = CaptchaDatasetRGB(
        data_dir=data_folder,
        transform=transform,
    )

    # The custom collate keeps the variable-length annotations as lists
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=custom_collate_fn_rgb,
    )

# Example usage and visualization
def visualize_rgb_sample(batch):
    """Display the first image of an RGB CAPTCHA batch and print batch stats.

    Args:
        batch (dict): Batch dict with an 'image' tensor laid out as
            [batch, channel, height, width] and a 'captcha_string' sequence
            with one entry per image. The image tensor is assumed to have
            been normalized with mean 0.5 and std 0.5 per channel.
    """
    # Detach and move to CPU so batches that live on an accelerator or are
    # part of an autograd graph can still be plotted
    img = batch['image'][0].detach().cpu()

    # Denormalize from [-1, 1] to [0, 1]; the clamp keeps every value inside
    # the range matplotlib accepts for float RGB data
    img = (img * 0.5 + 0.5).clamp(0.0, 1.0)

    # Convert from [C, H, W] to [H, W, C] for matplotlib and hand it a
    # NumPy array rather than relying on implicit tensor conversion
    img = img.permute(1, 2, 0).numpy()

    plt.figure(figsize=(12, 4))
    plt.imshow(img)
    plt.title(f"RGB CAPTCHA: {batch['captcha_string'][0]}")
    plt.axis('off')
    plt.show()

    print(f"Image shape: {batch['image'].shape}")  # expected [batch_size, 3, H, W]
    print(f"Image value range: [{batch['image'].min():.3f}, {batch['image'].max():.3f}]")

In [None]:
import os

# Root folder of the part-2 dataset; each split lives in its own sub-folder
base_path = '/home/utn/omul36yx/git/UTN-CAPTCHASOLVER/UTN-CV25-Captcha-Dataset/part2'

# One RGB dataloader per split, built in train/val/test order.
# Only the training split is shuffled.
train_loader_rgb, val_loader_rgb, test_loader_rgb = (
    get_dataloader_rgb(
        os.path.join(base_path, split),
        batch_size=32,
        shuffle=(split == 'train'),
    )
    for split in ('train', 'val', 'test')
)

# Report how many images each split contains
print(f"RGB Training samples: {len(train_loader_rgb.dataset)}")
print(f"RGB Validation samples: {len(val_loader_rgb.dataset)}")
print(f"RGB Test samples: {len(test_loader_rgb.dataset)}")

# Show only the first training batch, then stop iterating
for batch in train_loader_rgb:
    visualize_rgb_sample(batch)
    print(f"Number of annotations for first image: {len(batch['annotations'][0])}")
    break