In [None]:
# Load the data
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Where the image folders live and how many samples go into each batch.
ROOT_DIR = "dataset//"
batch_size = 64

# Per-channel normalization statistics, declared once so the training and
# validation pipelines cannot drift apart.
# NOTE(review): these look like the usual ImageNet statistics -- confirm they
# match what the downstream model expects.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: random geometric / photometric augmentation on the PIL
# image, then tensor conversion and normalization. RandomErasing comes last
# because it operates on the already-normalized tensor.
_train_steps = [
    transforms.RandomResizedCrop(224, scale=(0.7, 1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    transforms.RandomErasing(p=0.25),
]
train_transforms = transforms.Compose(_train_steps)

# Validation pipeline: deterministic resize to the same 224x224 input size,
# followed by the same tensor conversion and normalization (no augmentation).
_val_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
val_transforms = transforms.Compose(_val_steps)

# One sub-directory per class is expected under "<ROOT_DIR>train" and
# "<ROOT_DIR>val" (ImageFolder layout); each split gets its own transform
# pipeline.
train_ds = datasets.ImageFolder(f"{ROOT_DIR}train", transform=train_transforms)
val_ds = datasets.ImageFolder(f"{ROOT_DIR}val", transform=val_transforms)

# Training batches are reshuffled every epoch. The last, possibly smaller,
# batch is kept in both loaders (drop_last=False).
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
# Validation is iterated in a fixed order: shuffling gives no benefit for
# evaluation and makes per-batch results non-reproducible between runs.
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False)

In [None]:
# Take a look at the dataset
from torch.utils.tensorboard import SummaryWriter

# Show the class-name -> label-index mapping and the ordered class list.
print(train_ds.class_to_idx)
print(train_ds.classes)

# Number of individual samples to preview from each split, capped by the
# dataset sizes so a small dataset does not raise IndexError.
n_preview = min(10, len(train_ds), len(val_ds))

# NOTE(review): the tensors logged below have already been normalized by the
# transform pipelines, so their values are not confined to [0, 1] and the
# rendered previews will presumably look clipped/off-color. Un-normalize
# before logging if faithful colors are needed.
#
# The writer is used as a context manager so the event file is flushed and
# closed even if one of the logging calls raises.
with SummaryWriter("logs") as writer:
    # Individual samples straight from the datasets (transforms applied).
    for i in range(n_preview):
        train_img, target = train_ds[i]
        writer.add_image("train_ds", train_img, i)
        val_img, target = val_ds[i]
        writer.add_image("val_ds", val_img, i)

    # A DataLoader is an iterable, not a sequence: `train_dl[0]` raises
    # TypeError. Pull the first batch through an iterator instead.
    imgs, target = next(iter(train_dl))
    writer.add_images("train_dl", imgs, 0)
    imgs, target = next(iter(val_dl))
    writer.add_images("val_dl", imgs, 0)