In [204]:
import os

import numpy as np
import pandas as pd
import torch
import torchvision.datasets
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [205]:
# Per-channel mean/std: the standard ImageNet statistics used with
# torchvision's pretrained models.
_imagenet_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Deterministic (no augmentation) preprocessing, applied identically to the
# train and test splits below.
transform = transforms.Compose([
    transforms.Resize(256),      # int size: shorter edge -> 256 px, aspect ratio kept
    transforms.CenterCrop(224),  # central 224x224 patch
    transforms.ToTensor(),       # PIL image -> float CxHxW tensor scaled to [0, 1]
    _imagenet_normalize,
])

In [206]:
class CSVImageDataset(Dataset):
    """Map-style dataset that loads the images named in a pandas Series.

    Item ``idx`` is read from ``root/<df.iloc[idx]>``, converted to RGB and,
    if ``transform`` is given, passed through it.

    Args:
        df: Series of image file names, indexed positionally via ``.iloc``
            (NOTE(review): pass a single column such as ``data['image']``;
            a whole DataFrame is not supported).
        transform: Optional callable applied to the loaded PIL image.
        labels: Optional Series aligned position-by-position with ``df``.
            When given, items are ``(image, label)`` tuples so the dataset
            can feed supervised training; when omitted, items are just the
            image, as before.
        root: Directory the file names in ``df`` are relative to.

    Raises:
        ValueError: if ``labels`` is given and its length differs from ``df``.
    """

    def __init__(self, df, transform=None, labels=None, root='data/images'):
        # Fail fast: a length mismatch would otherwise surface as an
        # IndexError or silently mispaired labels deep inside a DataLoader.
        if labels is not None and len(labels) != len(df):
            raise ValueError(
                f"labels has {len(labels)} entries but df has {len(df)}"
            )
        self.df = df
        self.transform = transform
        self.labels = labels
        self.root = root

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # .iloc is positional, which matters because train_test_split returns
        # Series with a shuffled, non-contiguous index.
        img_path = os.path.join(self.root, str(self.df.iloc[idx]))

        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image that remains valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        if self.labels is None:
            return image
        return image, self.labels.iloc[idx]

In [207]:
# Training manifest: the 'image' column holds file names that the dataset
# class resolves under data/images/; 'class' is presumably the target label.
data = pd.read_csv('data/train.csv')
X, y = data['image'], data['class']

In [208]:
# Fixed seed so the split is reproducible across Restart & Run All.
# NOTE(review): the split is not stratified; consider stratify=y if the
# class distribution is imbalanced.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.25,  # sklearn's default, made explicit
    random_state=1000,
)

In [209]:
# Wrap both file-name splits in image datasets sharing the same preprocessing.
# NOTE(review): y_train / y_test are not attached here, so these datasets yield
# images only. Also, this cell rebinds X_train / X_test from Series to Dataset,
# so running it twice wraps a Dataset in a Dataset.
X_train, X_test = [
    CSVImageDataset(split, transform=transform)
    for split in (X_train, X_test)
]