In [50]:
import matplotlib.pyplot as plt
import torchvision.datasets as dts
from torchvision.transforms import ToTensor
from torchvision import transforms as transforms
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pathlib import Path
from PIL import Image
from typing import Tuple, List
import os
import random
from torch.utils.data import DataLoader




# All imports required by this notebook live in the cell above.

In [51]:

class CatsDogsDataset(Dataset):
    """Image-folder dataset whose labels come from sub-directory names.

    Every ``*.jpg`` file found recursively under ``root_dir`` is one sample.
    A sample's class is the name of the directory that directly contains it,
    and class indices are assigned by sorting the immediate sub-directories
    of ``root_dir`` alphabetically.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, root_dir: str, transform: transforms.Compose = None) -> None:
        super().__init__()
        self.root_dir = root_dir
        # glob() yields files in a filesystem-dependent order; sorting makes a
        # given index refer to the same file on every run and machine.
        self.path = sorted(Path(self.root_dir).glob("**/*.jpg"))
        self.transform = transform
        self.classes, self.classes_to_idx = self.__find_classes()

    def __getitem__(self, index) -> Tuple[Image.Image, int]:
        """Return ``(image, class_index)`` for the sample at ``index``.

        The image is a PIL RGB image when no transform is configured
        (which is the case the return annotation describes); otherwise it is
        whatever ``self.transform`` returns.

        Raises:
            KeyError: If the image's parent directory is not one of the
                immediate sub-directories of ``root_dir``.
        """
        image = self.__load_image(index)
        class_name = self.__get_class(index)
        class_idx = self.classes_to_idx[class_name]

        # Explicit None check: only the absence of a transform skips this step.
        if self.transform is not None:
            image = self.transform(image)

        return image, class_idx

    def __len__(self) -> int:
        """Return the number of image files discovered under ``root_dir``."""
        return len(self.path)

    def __find_classes(self) -> Tuple[List[str], dict]:
        """Return the sorted class names and a name -> index mapping."""
        classes = sorted(entry.name for entry in os.scandir(self.root_dir) if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __load_image(self, index: int) -> Image.Image:
        """Open the image at ``index`` and convert it to RGB."""
        return Image.open(self.path[index]).convert("RGB")

    def __get_class(self, index: int) -> str:
        """Return the name of the directory that directly contains the image."""
        return self.path[index].parent.name


# The pipelines below are defined at module level (not inside the class body)
# so that the names `train_transforms` / `test_transforms` exist when the
# datasets are instantiated in a later cell.

# Per-channel statistics shared by both pipelines so that training and
# evaluation images are scaled identically.
_NORM_MEAN = (0.5, 0.5, 0.5)
_NORM_STD = (0.5, 0.5, 0.5)

# Training pipeline: resize, random rotation as augmentation, then tensor
# conversion and normalization.
train_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Evaluation pipeline: no augmentation, but the same resize and normalization
# as training; otherwise the model would be evaluated on inputs in a
# different value range than the one it was trained on.
test_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])




In [49]:

# Build one dataset per split from its folder on disk, each with its own
# preprocessing pipeline.
train_cd = CatsDogsDataset(
    root_dir="data/train",
    transform=train_transforms,
)
test_cd = CatsDogsDataset(
    root_dir="data/test",
    transform=test_transforms,
)

# Uncomment to display the number of samples in each split:
# len(train_cd), len(test_cd)

(13338, 11662)

In [56]:
# Wrap both splits in mini-batch loaders of 32 samples; only the training
# loader is created with shuffle=True.
train_loader = DataLoader(dataset=train_cd, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_cd, batch_size=32, shuffle=False)

# Pull a single training batch to sanity-check the input pipeline.
data_iterator = iter(train_loader)
images, labels = next(data_iterator)

# Pick one image from that batch at random for inspection.
image = random.choice(images)



Train Loader: <torch.utils.data.dataloader.DataLoader object at 0x000001ADCF6AD890> | Size: 417
Test Loader: <torch.utils.data.dataloader.DataLoader object at 0x000001ADCF6AC550> | Size: 365
