In [11]:
import os
import pandas as pd
from pathlib import Path
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
import torch

# ------------------------
# Configuration
# ------------------------
# Root folder holding one sub-folder per split. Built with the `/` operator rather
# than a "data\\raw" literal so the path resolves on POSIX as well as Windows.
DATA_DIR = Path("data") / "raw"  # Change this if your dataset folder is elsewhere
SPLITS = ["train", "valid", "test"]  # Dataset splits (one sub-folder of DATA_DIR each)
IMAGE_EXT = ".jpg"  # Image file extension (not referenced by the code in this cell)


# ------------------------
# Function to process CSV and create mapping
# ------------------------
def process_classes_csv(csv_path):
    """
    Reads a classes.csv file and returns a DataFrame with only labeled images and a single 'label' column.
    """
    df = pd.read_csv(csv_path)
    # Remove unlabeled images
    df = df[df[" Unlabeled"] == False].copy()
    
    # Convert multi-label boolean columns to a single class label
    posture_cols = [" backwardbadposture", " forwardbadposture", " goodposture"]
    
    def determine_label(row):
        for col in posture_cols:
            if row[col]:
                return col
        return None  # fallback if no label found
    
    df["label"] = df.apply(determine_label, axis=1)
    df = df[["filename", "label"]].dropna()
    return df


# ------------------------
# PyTorch Dataset
# ------------------------
class PostureDataset(Dataset):
    """
    Map-style dataset of posture images listed in a classes CSV.

    Items are ``(image, label_idx)``. ``image`` is the RGB PIL image, or
    whatever ``transform`` returns for it. ``label_idx`` is the integer code
    that ``self.label_encoder`` assigns to the row's label.

    Args:
        images_dir: Directory containing the image files named in the CSV.
        csv_file: CSV parsed with ``process_classes_csv``.
        transform: Optional callable applied to each PIL image.
        label_encoder: Optional, already fitted ``LabelEncoder`` to reuse. Pass
            the training split's encoder when building validation/test splits so
            all splits share one label->index mapping. If omitted, a new encoder
            is fitted on this split's labels only; a split missing a class then
            gets different indices than the other splits.
    """

    def __init__(self, images_dir, csv_file, transform=None, label_encoder=None):
        self.images_dir = Path(images_dir)
        self.data = process_classes_csv(csv_file)
        self.transform = transform

        if label_encoder is None:
            # Fit a fresh mapping on this split's labels (original behavior).
            self.label_encoder = LabelEncoder()
            self.data["label_idx"] = self.label_encoder.fit_transform(self.data["label"])
        else:
            # Reuse a shared mapping; LabelEncoder.transform raises ValueError
            # if this split contains a label the encoder was not fitted on.
            self.label_encoder = label_encoder
            self.data["label_idx"] = label_encoder.transform(self.data["label"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_path = self.images_dir / row["filename"]
        # `with` closes the file handle Image.open holds; convert() loads the
        # pixel data into a new image, so it stays valid after the file closes.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        label = row["label_idx"]

        if self.transform:
            image = self.transform(image)

        return image, label


# ------------------------
# Example usage
# ------------------------
if __name__ == "__main__":
    from torchvision import transforms

    # Resize to 224x224 and normalize with the ImageNet channel statistics.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    datasets = {}
    dataloaders = {}
    shared_encoder = None  # label mapping of the first split built ("train")

    # Build "train" first (the stable sort keeps the remaining splits in their
    # configured order) so its label mapping can be reused by every other split.
    for split in sorted(SPLITS, key=lambda name: name != "train"):
        split_dir = DATA_DIR / split
        csv_file = split_dir / "_classes.csv"
        ds = PostureDataset(split_dir, csv_file, transform=transform)
        if shared_encoder is None:
            shared_encoder = ds.label_encoder
        else:
            # PostureDataset fits a fresh LabelEncoder per split by default, so a
            # split lacking a class would get shifted indices. Re-encode it with
            # the train mapping to keep label indices consistent across splits.
            ds.label_encoder = shared_encoder
            ds.data["label_idx"] = shared_encoder.transform(ds.data["label"])
        datasets[split] = ds
        dataloaders[split] = DataLoader(ds, batch_size=32, shuffle=(split == "train"))

    # Example: iterate through one batch
    for images, labels in dataloaders["train"]:
        print(images.shape, labels.shape)
        break

torch.Size([32, 3, 224, 224]) torch.Size([32])


YOLOv8 classification dataset created successfully.
