# Recognition   

In [6]:
import os
import glob
import random
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from torchvision import transforms
from models.detection import Model, ModelTrainer

In [19]:
class CustomDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        label_path = img_path.replace('images', 'labels').replace('.jpg', '.txt')
        
        with open(label_path, 'r') as f:
            label = list(map(float, f.readline().strip().split()))

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label)

def prepare_dataloaders(data_dirs, batch_size=32, valid_split=0.1, test_split=0.1):
    all_images = []
    for data_dir in data_dirs:
        all_images += glob.glob(os.path.join(data_dir, 'imgs_and_labels', 'images', '*.jpg'))

    # Filter images to include only those that have corresponding labels
    all_images = [img for img in all_images if os.path.exists(img.replace('images', 'labels').replace('.jpg', '.txt'))]

    random.shuffle(all_images)
    total_images = len(all_images)
    test_size = int(total_images * test_split)
    valid_size = int(total_images * valid_split)
    train_size = total_images - test_size - valid_size

    train_images = all_images[:train_size]
    valid_images = all_images[train_size:train_size + valid_size]
    test_images = all_images[train_size + valid_size:]

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor()
    ])

    train_dataset = CustomDataset(train_images, transform=transform)
    valid_dataset = CustomDataset(valid_images, transform=transform)
    test_dataset = CustomDataset(test_images, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, valid_loader, test_loader

In [20]:
# Define the data directories
data_dirs = [
    'data/real_1geo_bright_512867',
    'data/real_1geo_onlyflash_512875',
    'data/real_2geos_bright_512830',
    'data/real_2geos_onlyflash_512887',
    'data/real_3geos_bright_512712',
    'data/real_4geos_bright_512639',
    'data/real_4geos_onlyflash_512894'
]

In [21]:
train_loader, valid_loader, test_loader = prepare_dataloaders(data_dirs, batch_size=32)

In [22]:
# Assuming train_dataset and val_dataset are already defined PyTorch datasets
model = Model(num_channels=3, num_classes=4)
trainer = ModelTrainer(model, train_loader, val_dataset=valid_loader)
trainer.train(num_epochs=20)

TypeError: 'DataLoader' object is not subscriptable