In [10]:
import os
import pandas as pd
import random
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# Define transformations for resizing and normalization
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
     # Standard normalization
])

# Function to get 10 random images from each class
def get_random_images_per_class(csv_path, root_dir, transform, num_samples=10):
    # Load the CSV file
    data = pd.read_csv(csv_path)

    # Group data by class
    grouped = data.groupby('category')  # Updated to match the 'category' column

    # Store selected images
    images_by_class = {}

    for label, group in grouped:
        # Randomly sample `num_samples` rows
        sampled = group.sample(n=num_samples, replace=True if len(group) < num_samples else False)
        images_by_class[label] = []

        for _, row in sampled.iterrows():
            img_path = os.path.join(root_dir, row['image:FILE'])  # Updated to match 'image:FILE' column
            image = Image.open(img_path).convert("RGB")
            if transform:
                image = transform(image)
            images_by_class[label].append(image)

    return images_by_class

# Plot the images


# Usage
csv_path = "dataset/train.csv"  # Path to your training CSV file
root_dir = "dataset/"  # Root directory containing the images
images_by_class = get_random_images_per_class(csv_path, root_dir, transform, num_samples=10)
print(images_by_class)



{0: [tensor([[[0.5686, 0.5373, 0.4863,  ..., 0.4588, 0.4392, 0.4275],
         [0.5569, 0.5294, 0.4824,  ..., 0.4471, 0.4235, 0.4118],
         [0.5373, 0.5137, 0.4784,  ..., 0.4275, 0.4078, 0.3922],
         ...,
         [0.1412, 0.1412, 0.1451,  ..., 0.5765, 0.5922, 0.5843],
         [0.1529, 0.1529, 0.1569,  ..., 0.6039, 0.6039, 0.5804],
         [0.1608, 0.1608, 0.1608,  ..., 0.6667, 0.6667, 0.6431]],

        [[0.4627, 0.4353, 0.3961,  ..., 0.5490, 0.5294, 0.5176],
         [0.4510, 0.4275, 0.3922,  ..., 0.5373, 0.5137, 0.5020],
         [0.4314, 0.4118, 0.3843,  ..., 0.5216, 0.4980, 0.4863],
         ...,
         [0.1216, 0.1255, 0.1216,  ..., 0.5373, 0.5529, 0.5451],
         [0.1255, 0.1255, 0.1255,  ..., 0.5647, 0.5647, 0.5412],
         [0.1216, 0.1216, 0.1216,  ..., 0.6275, 0.6275, 0.6039]],

        [[0.3804, 0.3569, 0.3216,  ..., 0.1255, 0.1137, 0.1059],
         [0.3686, 0.3490, 0.3137,  ..., 0.1137, 0.0980, 0.0902],
         [0.3490, 0.3333, 0.3059,  ..., 0.0980, 0.086

In [11]:
import pandas as pd

# Load the CSV file
data = pd.read_csv("dataset/train.csv")

# Print column names
print(data.columns)


Index(['image:FILE', 'category'], dtype='object')
