In [4]:
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from pathlib import Path

# Select the compute device once: use CUDA when torch reports it as
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Dataloader

In [42]:
class CatsAndDogsDataset(Dataset):
    """Image dataset for a flat folder of files named ``<label>.<id>.jpg``.

    The class label of a sample is the part of its filename before the first
    dot, e.g. ``cat.0.jpg`` -> ``'cat'``.

    Args:
        target_directory: Folder that directly contains the images. Only
            files matching ``*.*.jpg`` are collected (no recursion).
        transform: Optional callable applied to each loaded PIL image before
            it is returned from ``__getitem__``.

    Attributes:
        paths: Sorted list of ``Path`` objects, one per sample.
        transform: The callable passed in, or ``None``.
        classes: Sorted list of the distinct labels found in ``paths``.
        class_to_idx: Mapping from label to its position in ``classes``.
    """

    def __init__(self, target_directory, transform=None):
        # glob() yields entries in filesystem order, which differs between
        # machines; sorting makes the index -> sample mapping reproducible.
        self.paths = sorted(Path(target_directory).glob('*.*.jpg'))
        self.transform = transform
        self.classes = sorted({self.get_label(path) for path in self.paths})
        # Precomputed lookup so __getitem__ does not scan `classes` per sample.
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

    @staticmethod
    def get_label(path):
        """Return the label encoded in a filename.

        Args:
            path: ``Path`` or string pointing at a ``<label>.<id>.jpg`` file.

        Returns:
            The filename text before its first dot.
        """
        return Path(path).name.split('.', 1)[0]

    def load_image(self, index):
        """Load the image stored at position ``index`` of ``paths``.

        Returns:
            A fully loaded PIL image in ``'RGB'`` mode.
        """
        image_path = self.paths[index]
        # Image.open() is lazy and keeps the file open; converting inside the
        # context manager reads the pixels and then releases the handle.
        # Forcing RGB gives every sample the same number of channels.
        with Image.open(image_path) as img:
            return img.convert('RGB')

    def __len__(self):
        """Return the number of samples."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image, class_idx)`` for the sample at ``index``.

        ``image`` is the output of ``transform`` when one was supplied,
        otherwise the PIL image from ``load_image``. ``class_idx`` is the
        position of the sample's label in ``classes``.
        """
        img = self.load_image(index)
        class_idx = self.class_to_idx[self.get_label(self.paths[index])]

        # Compare against None rather than relying on the transform's
        # truthiness.
        if self.transform is not None:
            img = self.transform(img)
        return img, class_idx


In [41]:
# Preprocessing for training images: resize to a fixed 64x64, then convert
# the PIL image to a tensor.
train_transforms = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor()]
)

# Build the dataset from the local training folder.
dataset = CatsAndDogsDataset(target_directory='./train', transform=train_transforms)

# Sanity check: label parsing on one known filename.
p = Path('./train/cat.0.jpg')
print(p.name)
CatsAndDogsDataset.get_label(p)

cat.0.jpg


'cat'