## Goal
- Implement image classification with ResNet-18 on the Tiny ImageNet dataset using the d2l library
    - Wrap the Tiny ImageNet data with the d2l library
    - Train on the GPU
    - Record validation loss, accuracy, and training time

In [7]:
import torch
import time
import d2l.torch as d2l
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from datasets import load_dataset

In [None]:
# Optimisation settings.
num_epochs = 10
learning_rate = 0.01

# Data-loading settings.
batch_size, num_workers = 128, 2

# Ask d2l for the device to use; the bare name on the last line makes the
# notebook display it as the cell output.
device = d2l.try_gpu()
device

device(type='cuda', index=0)

In [12]:
class TinyImagenetTorch(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over an indexable collection of image records.

    Each record is read as a mapping with an ``"image"`` entry (an object
    exposing ``convert("RGB")``, e.g. a PIL image) and a ``"label"`` entry
    (an integer class index).

    Args:
        dataset: Sized, indexable collection of records as described above.
        transform: Optional callable applied to the RGB-converted image.
            When ``None`` the converted image is returned as-is.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        Args:
            idx: Index forwarded unchanged to the wrapped dataset.

        Returns:
            A tuple ``(x, y)`` where ``x`` is the RGB-converted image after
            the optional transform and ``y`` is a 0-dim ``torch.int64``
            tensor holding the label.
        """
        # Fetch the record exactly once: indexing the wrapped dataset may be
        # expensive (e.g. it may decode the image on every access), so looking
        # it up separately for the image and the label doubles the work.
        record = self.dataset[idx]

        # Convert to RGB so every sample has the same channel layout.
        x = record["image"].convert("RGB")
        if self.transform is not None:
            x = self.transform(x)

        y = torch.tensor(record["label"], dtype=torch.int64)
        return x, y

class TinyImagenetD2l(d2l.Module):
    """d2l wrapper that serves the ``Maysee/tiny-imagenet`` dataset in batches.

    Loads the ``train`` and ``valid`` splits with ``load_dataset`` and exposes
    them as PyTorch ``DataLoader`` objects.

    NOTE(review): this class only provides data, so ``d2l.DataModule`` looks
    like the more natural base class than ``d2l.Module`` — confirm and switch
    if nothing relies on the ``nn.Module`` behaviour.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by each ``DataLoader``.
    """

    def __init__(self, batch_size, num_workers):
        super().__init__()
        # Makes the constructor arguments available as self.batch_size and
        # self.num_workers (both are read in get_dataloader).
        self.save_hyperparameters()
        self.train_data = load_dataset("Maysee/tiny-imagenet", split="train")
        self.val_data = load_dataset("Maysee/tiny-imagenet", split="valid")
        self.transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])

    def get_dataloader(self, train):
        """Build a ``DataLoader`` over the training or validation split.

        Args:
            train: ``True`` selects the training split and enables shuffling;
                ``False`` selects the validation split in its stored order.

        Returns:
            A ``torch.utils.data.DataLoader`` yielding ``(images, labels)``
            batches of at most ``batch_size`` samples.
        """
        data = self.train_data if train else self.val_data
        return torch.utils.data.DataLoader(
            TinyImagenetTorch(data, self.transform),
            batch_size=self.batch_size,
            shuffle=train,
            # Previously the stored num_workers was never forwarded, so
            # loading always ran in the main process regardless of the setting.
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        # NOTE(review): added because d2l.Trainer is expected to call
        # train_dataloader()/val_dataloader() on the data object — confirm
        # against the installed d2l version.
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        """Return the validation loader (no shuffling)."""
        return self.get_dataloader(train=False)
    

# Smoke test: build the data wrapper and report the shapes of one training batch.
tiny_imagenet = TinyImagenetD2l(batch_size, num_workers)
train_loader = tiny_imagenet.get_dataloader(train=True)
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    x, y = first_batch
    print(x.shape, y.shape)

        


torch.Size([32, 3, 64, 64]) torch.Size([32])


In [None]:
# Display val_data as the cell output.
# NOTE(review): val_data is not defined in any cell visible here — presumably it
# was created in a cell that has since been removed; confirm before re-running.
val_data

In [28]:
# Print only the first item yielded by val_data to inspect its structure.
# NOTE(review): val_data is not defined in the visible cells — confirm its source.
for batch in val_data:
    print(batch)
    break  # one item is enough for inspection

{'image': tensor([[[247, 248, 250,  ..., 254, 254, 254],
         [251, 250, 250,  ..., 254, 254, 254],
         [254, 254, 252,  ..., 253, 253, 253],
         ...,
         [198, 238, 255,  ..., 222, 221, 217],
         [245, 247, 230,  ..., 249, 245, 240],
         [242, 238, 245,  ..., 254, 252, 252]],

        [[253, 254, 254,  ..., 254, 254, 254],
         [255, 254, 254,  ..., 254, 254, 254],
         [255, 255, 254,  ..., 253, 253, 253],
         ...,
         [200, 240, 255,  ..., 222, 221, 217],
         [255, 255, 241,  ..., 249, 248, 243],
         [255, 255, 255,  ..., 254, 255, 255]],

        [[251, 252, 253,  ..., 254, 254, 254],
         [254, 253, 253,  ..., 254, 254, 254],
         [255, 255, 253,  ..., 253, 253, 253],
         ...,
         [199, 239, 255,  ..., 230, 231, 227],
         [255, 255, 237,  ..., 255, 255, 252],
         [255, 253, 255,  ..., 255, 255, 255]]], dtype=torch.uint8), 'label': tensor(0), 'data': tensor([0])}
