In [25]:
import os
import torch
import imageio
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
from PIL import Image


In [26]:
# Optional, run once if imageio is missing: %pip install imageio

In [27]:
# Optional, run once if torchvision is missing: %pip install torchvision

In [28]:
import os

class PokemonDataset(Dataset):
    """Image-classification dataset built from a folder-per-class directory tree.

    The dataset is either converted from raw images (``convert=True``) and
    cached to ``dataset_save``, or loaded from a previously cached file.
    After loading, the samples are split 80/20 with scikit-learn's
    ``train_test_split`` and this instance exposes only the requested half.

    Args:
        dataset_save: Path of the cached ``(data, targets, labels)`` file
            written by :meth:`convert` and read by :meth:`load`.
        raw_data: Root directory containing one sub-directory per class.
            Required when ``convert=True``.
        train: If True expose the training split, otherwise the test split.
        shuffle: If False the split uses the fixed seed 42. If True a seed in
            ``[0, 100)`` is drawn from torch's global RNG for this instance.
            NOTE: two instances created with ``shuffle=True`` draw their seeds
            independently, so a "train" and a "test" instance are not
            guaranteed to be complementary.
        transform: Optional callable applied to each image in ``__getitem__``.
        target_transform: Optional callable applied to each target.
        convert: If True rebuild the cache from ``raw_data`` before splitting.
        size: Side length (pixels) images are resized to during conversion.

    Attributes:
        data: Image tensor of the exposed split, shape ``(N, 3, size, size)``
            when produced by :meth:`convert`.
        targets: ``torch.long`` tensor of class indices for the exposed split.
        labels: Class names; ``labels[i]`` is the folder name of class ``i``.
        X_train, X_test, y_train, y_test: Both halves of the split.
    """

    # Lower-case file extensions treated as images during conversion.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, dataset_save="pokemon_data.pt", raw_data=None, train=True, shuffle=False,
                 transform=None, target_transform=None, convert=False, size=32):
        self.targets = []
        self.labels = []
        self.data = []

        self.X_train = []
        self.X_test = []
        self.y_train = []
        self.y_test = []

        self.transform = transform
        self.target_transform = target_transform

        if convert:
            self.convert(dataset_save, raw_data, size)
        else:
            self.load(dataset_save)

        # Fixed seed keeps the split reproducible; shuffle=True trades that
        # for a randomly chosen (but still small-range) seed.
        if shuffle:
            seed = int(torch.randint(0, 100, (1,)).item())
        else:
            seed = 42
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.data, self.targets, test_size=0.2, random_state=seed)

        if train:
            self.data = self.X_train
            self.targets = self.y_train
        else:
            self.data = self.X_test
            self.targets = self.y_test

    def __len__(self):
        """Return the number of samples in the exposed split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, target)`` at ``index`` with optional transforms applied."""
        image, target = self.data[index], self.targets[index]
        # Compare against None so a callable that happens to be falsy
        # (e.g. an object defining __len__) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def convert(self, dataset_save, raw_data, size):
        """Build the dataset from ``raw_data`` and cache it to ``dataset_save``.

        Every sub-directory of ``raw_data`` is one class; its name becomes the
        label. Files whose extension is in ``IMAGE_EXTENSIONS`` are opened,
        converted to RGB, resized to ``size x size`` and turned into tensors.
        Images that fail to load are reported on stdout and skipped.

        Directory and file listings are sorted so label indices and sample
        order do not depend on filesystem enumeration order.

        Args:
            dataset_save: Destination path for ``torch.save``.
            raw_data: Root directory with one sub-directory per class.
            size: Target width and height in pixels.

        Raises:
            ValueError: If ``raw_data`` is None. Without this check
                ``os.listdir(None)`` would silently scan the current
                working directory.
            RuntimeError: If no image could be read under ``raw_data``.
        """
        if raw_data is None:
            raise ValueError("raw_data must be a directory path to convert from, got None")

        samples = []
        targets = []
        self.labels = []

        for folder in sorted(os.listdir(raw_data)):
            folder_path = os.path.join(raw_data, folder)
            if not os.path.isdir(folder_path):
                continue

            # The class is registered lazily, on the first image-like file,
            # so folders without any candidate image get no label.
            label_idx = None
            for file in sorted(os.listdir(folder_path)):
                if not file.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                if label_idx is None:
                    label_idx = len(self.labels)
                    self.labels.append(folder)

                img_path = os.path.join(folder_path, file)
                try:
                    # The context manager releases the file handle even for
                    # images Pillow would otherwise keep open.
                    with Image.open(img_path) as img:
                        resized = img.convert("RGB").resize((size, size))
                    img_tensor = transforms.ToTensor()(resized)
                    samples.append(img_tensor)
                    targets.append(label_idx)
                except Exception as e:
                    # Best-effort conversion: one unreadable file must not
                    # abort the whole run.
                    print(f"Skipped {img_path} due to error: {e}")

        if not samples:
            # torch.stack([]) would raise a RuntimeError with an unhelpful
            # message; keep the type but say what actually went wrong.
            raise RuntimeError(f"No readable images found under {raw_data!r}")

        self.data = torch.stack(samples)
        self.targets = torch.tensor(targets, dtype=torch.long)
        torch.save((self.data, self.targets, self.labels), dataset_save)

    def load(self, dataset_save):
        """Load ``(data, targets, labels)`` from a file written by :meth:`convert`.

        NOTE(security): ``torch.load`` is pickle-based; only load cache files
        from a trusted source.
        """
        self.data, self.targets, self.labels = torch.load(dataset_save)

In [29]:
# Progress marker only: prints a fixed message and does not validate PokemonDataset.
print("Dataset class loaded")

Dataset class loaded
