In [None]:
import os
import warnings

import albumentations as A
import cv2
import pandas as pd
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset

# Custom Dataset Class for Crop Disease Detection
class CropDiseaseDataset(Dataset):
    """Map-style dataset of ``(image, target)`` pairs for crop disease detection.

    Every row of the annotation CSV is one sample:
    - the image file named by ``Image_ID``, resolved relative to ``img_dir``;
    - one bounding box ``(xmin, ymin, xmax, ymax)``;
    - an integer ``class_id``.

    NOTE(review): if the CSV lists several boxes for the same image on separate
    rows, each box becomes its own sample with a single-box target. Confirm this
    matches the intended training setup.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        """
        Args:
            csv_file (str): Path to the CSV file with image annotations. It must
                contain the columns read by ``__getitem__``: ``Image_ID``,
                ``xmin``, ``ymin``, ``xmax``, ``ymax`` and ``class_id``.
            img_dir (str): Directory with all the images.
            transform (callable, optional): Albumentations-style transform. It is
                called with keyword arguments and returns a dict with an
                ``'image'`` entry.
                - If it is an ``A.Compose`` built with
                  ``bbox_params=A.BboxParams(format="pascal_voc",
                  label_fields=["labels"])``, the box and label are transformed
                  together with the image.
                - Otherwise only the image is transformed and a warning is issued.
                - When omitted (``None``), a bbox-aware resize / flip /
                  brightness-contrast pipeline is used.
        """
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        # Compare against None explicitly: A.Compose defines __len__, so an empty
        # (identity) Compose is falsy and must not be replaced by the default.
        if transform is None:
            transform = A.Compose(
                [
                    A.Resize(256, 256),          # Resize for faster training, adjust as needed
                    A.HorizontalFlip(p=0.5),     # Random horizontal flip
                    A.RandomBrightnessContrast(p=0.2),  # Random brightness/contrast adjustments
                    ToTensorV2(),                # Convert image to PyTorch tensor
                ],
                # Without bbox_params, Resize/HorizontalFlip change the image but
                # leave the box coordinates in the original image's frame.
                bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
            )
        self.transform = transform

        if not self._transform_handles_bboxes():
            warnings.warn(
                "transform was built without bbox_params: geometric augmentations "
                "(e.g. Resize, HorizontalFlip) will change the image but not the "
                "bounding boxes in the returned targets.",
                stacklevel=2,
            )

    def _transform_handles_bboxes(self):
        """Return True if ``self.transform`` is a Compose configured with ``bbox_params``."""
        # albumentations.Compose registers a 'bboxes' processor only when built
        # with bbox_params. Passing ``bboxes=`` to a Compose without one is
        # rejected, so bbox-unaware transforms get the image alone.
        processors = getattr(self.transform, "processors", None)
        return isinstance(processors, dict) and "bboxes" in processors

    def __len__(self):
        """Return the number of annotation rows (one sample per row)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, target)``.

        ``image`` is the transform's ``'image'`` output. With the default
        pipeline this is the ``ToTensorV2`` tensor: channels-first and not
        scaled to [0, 1].

        ``target`` is a dict with two entries:
        - ``'boxes'``: float32 tensor of shape ``(N, 4)`` in
          ``xmin, ymin, xmax, ymax`` order;
        - ``'labels'``: int64 tensor of shape ``(N,)``.

        ``N`` is 1 unless a bbox-aware transform drops the box, in which case
        it is 0.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
            ValueError: a bbox-aware albumentations transform may raise this
                when it validates a degenerate or out-of-image box.
        """
        row = self.data.iloc[idx]
        # str() so purely numeric IDs read by pandas still form a valid path.
        img_path = os.path.join(self.img_dir, str(row['Image_ID']))

        # Load image and convert OpenCV's BGR channel order to RGB.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Image not found at path: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Bounding box and class label for this row.
        boxes = [[float(row['xmin']), float(row['ymin']), float(row['xmax']), float(row['ymax'])]]
        labels = [int(row['class_id'])]

        if self._transform_handles_bboxes():
            # Keep the box aligned with the augmented image (resize, flip, ...).
            transformed = self.transform(image=image, bboxes=boxes, labels=labels)
            boxes = transformed['bboxes']
            labels = transformed['labels']
        else:
            transformed = self.transform(image=image)
        image = transformed['image']

        target = {
            # reshape keeps shape (N, 4) even when a transform removed every box.
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.tensor(list(labels), dtype=torch.int64),
        }
        return image, target

# Define transformations with albumentations for data augmentation.
# NOTE: this Compose has no ``bbox_params``, so its geometric augmentations
# (Resize, HorizontalFlip) are not applied to the bounding boxes in the targets.
transform = A.Compose([
    A.Resize(256, 256),                   # Resize for consistency and speed
    A.HorizontalFlip(p=0.5),              # Augmentation: random horizontal flip
    A.RandomBrightnessContrast(p=0.2),    # Augmentation: brightness/contrast
    ToTensorV2()                          # Convert to PyTorch tensor
])


def collate_detection_batch(batch):
    """Collate ``(image, target)`` samples into ``(stacked images, list of targets)``.

    Images must all have the same shape; the Resize above guarantees this.
    Targets are kept as a list of per-image dicts instead of being stacked.
    Detection targets can carry a different number of boxes per image, and
    torchvision detection models take targets in this list-of-dicts form.
    """
    images, targets = zip(*batch)
    return torch.stack(images), list(targets)


# Initialize the dataset and a loader yielding (image batch, list of target dicts).
train_dataset = CropDiseaseDataset(csv_file='Train.csv', img_dir='datasets/train/images', transform=transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    # NOTE: with the 'spawn' start method (default on Windows and macOS), worker
    # processes may fail to unpickle classes/functions defined in a notebook.
    # Drop to num_workers=0 if the workers crash.
    num_workers=4,
    # Pinned host memory only speeds up host-to-GPU copies; skip it without CUDA.
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_detection_batch,
)

# Example usage: Iterate through the DataLoader
for images, targets in train_loader:
    print(f"Batch of images shape: {images.shape}")
    print(f"Batch of targets: {targets}")
    break  # Just loading one batch for demonstration
