## Goal
- Implement image classification with ResNet-18 on the Tiny ImageNet dataset using the d2l library
    - Wrap the Tiny ImageNet data so it can be used with the d2l library
    - Train on a GPU
    - Record validation loss, accuracy, and the time taken to train

In [None]:
import torch
import time
import d2l.torch as d2l
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
from torchvision.transforms import transforms
from torchvision.models import resnet18
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, Image
import json
import requests

In [None]:
# Training configuration shared by the cells below.
num_epochs = 2
learning_rate = 0.01
batch_size, num_workers = 256, 2

# Device selected by d2l.try_gpu(); the bare expression on the last line
# makes the notebook display which device was picked.
device = d2l.try_gpu()
device

In [None]:
class TinyImagenetD2l(d2l.Module):
    """Tiny ImageNet wrapper exposing d2l-style train/validation dataloaders.

    The "train" and "valid" splits of the ``Maysee/tiny-imagenet`` dataset are
    loaded with ``datasets.load_dataset``. Every image is converted to RGB,
    resized to 128x128 and turned into a tensor on access. Batches are
    dictionaries with the keys ``'image'`` and ``'label'``.

    Args:
        batch_size: Number of samples per batch produced by the dataloaders.
        num_workers: Number of worker processes used by the dataloaders.

    Attributes set in ``__init__``:
        train_data / val_data: the two dataset splits.
        transform: the per-image torchvision transform pipeline.
        imagenet_reverse_index: dict mapping the first element of each
            ImageNet class-index entry to the second (see
            ``get_imagenet_labels``).
        num_classes: number of classes of the ``label`` feature.
    """

    # NOTE(review): d2l.Module is d2l's model base class; d2l.DataModule may be
    # the more natural parent for a data wrapper. Left unchanged so existing
    # callers and isinstance checks keep working.
    def __init__(self, batch_size, num_workers):
        super().__init__()
        # Stores batch_size and num_workers as attributes on self.
        self.save_hyperparameters()
        self.train_data = load_dataset("Maysee/tiny-imagenet", split="train")
        self.val_data = load_dataset("Maysee/tiny-imagenet", split="valid")
        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ])
        self.get_imagenet_labels()

        # get number of classes
        self.num_classes = self.train_data.features['label'].num_classes

    def get_imagenet_labels(self):
        """Download the ImageNet class index and build ``imagenet_reverse_index``.

        Raises:
            requests.RequestException: if the request times out, cannot
                connect, or the server answers with an HTTP error status.
        """
        # A timeout keeps a stalled connection from hanging the notebook.
        response = requests.get(
            "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json",
            timeout=30,
        )
        # Fail with a clear HTTP error instead of a confusing JSON parse error
        # when the server returns an error page.
        response.raise_for_status()
        imagenet_index = json.loads(response.text)
        self.imagenet_reverse_index = {
            entry[0]: entry[1] for entry in imagenet_index.values()
        }

    def transforms(self, batch):
        """Transform a batch dict in place: images to tensors, labels to a tensor.

        Args:
            batch: dict with an ``'image'`` list of PIL images and a
                ``'label'`` list of integer class ids.

        Returns:
            The same dict with both entries replaced.
        """
        # Convert to RGB first so every image tensor has three channels.
        batch['image'] = [self.transform(x.convert("RGB")) for x in batch['image']]
        batch['label'] = torch.tensor(batch['label'])
        return batch

    def get_dataloader(self, train):
        """Return a ``DataLoader`` over the training or validation split.

        Args:
            train: If true, use the training split and shuffle it; otherwise
                use the validation split without shuffling.
        """
        data = self.train_data if train else self.val_data
        data.set_transform(self.transforms)
        # num_workers was previously accepted by __init__ but never forwarded,
        # so all loading ran in the main process.
        dataloader = DataLoader(
            data,
            batch_size=self.batch_size,
            shuffle=train,
            num_workers=self.num_workers,
        )
        return dataloader

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        """Return the unshuffled validation dataloader."""
        return self.get_dataloader(train=False)

    def visualize(self, batch, max_images=16):
        """Show up to ``max_images`` images of ``batch`` with their class names.

        Args:
            batch: dict with an ``'image'`` tensor and a ``'label'`` tensor, as
                produced by the dataloaders of this class.
            max_images: Maximum number of images to display.
        """
        # Move channels last (N, C, H, W) -> (N, H, W, C) for plotting.
        images = batch['image'][:max_images].permute(0, 2, 3, 1)
        label_ids = batch['label'][:max_images].tolist()
        label_names = self.train_data.features['label'].int2str(label_ids)

        # Prefer the ImageNet name; fall back to the dataset's own label
        # string when it is not in the index.
        titles = [self.imagenet_reverse_index.get(name, name) for name in label_names]

        # Size the grid to the number of images actually shown (4 per row)
        # instead of always assuming a full 4x4 grid.
        num_cols = 4
        num_rows = max(1, -(-len(images) // num_cols))
        d2l.show_images(images, num_rows, num_cols, titles=titles)
        
    

# Build the dataset wrapper and both dataloaders, then preview one batch.
tiny_imagenet = TinyImagenetD2l(batch_size, num_workers)
train_loader = tiny_imagenet.train_dataloader()
val_loader = tiny_imagenet.val_dataloader()

# Pull a single batch from the training loader and plot it.
sample_batch = next(iter(train_loader))
tiny_imagenet.visualize(sample_batch)

        


In [None]:
class ResnetD2l(d2l.Classifier):
    """ResNet-18 classifier wrapped as a d2l ``Classifier``.

    Args:
        num_classes: Number of outputs of the replacement final layer.
        pretrained: Forwarded to torchvision's ``resnet18``. When true, the
            loaded backbone weights are kept and only the new final layer is
            re-initialized; when false, every Conv2d/Linear layer gets Xavier
            initialization.
        lr: Learning rate, stored on the instance by ``save_hyperparameters``.
    """

    def __init__(self, num_classes, pretrained=False, lr=0.01):
        super().__init__()
        self.save_hyperparameters()
        self.net = resnet18(pretrained=pretrained)
        # Replace the classification head; read the input width from the
        # existing layer rather than hard-coding 512.
        self.net.fc = nn.Linear(self.net.fc.in_features, num_classes)

        def init_weights(m):
            # Xavier-initialize the weights of linear and convolutional layers.
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)

        if pretrained:
            # Re-initializing the whole network would overwrite the pretrained
            # weights, so only the freshly created head is initialized.
            init_weights(self.net.fc)
        else:
            self.net.apply(init_weights)

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        return self.net(x)

    def loss(self, y_hat, y):
        """Return the cross-entropy loss between predictions and targets."""
        return nn.CrossEntropyLoss()(y_hat, y)

In [None]:
@d2l.add_to_class(d2l.Trainer)
def prepare_batch(self, batch):
    """Turn a dict batch into an ``(images, labels)`` tuple on ``self.device``.

    Args:
        batch: dict holding the tensors under the keys ``'image'`` and
            ``'label'``.

    Returns:
        A 2-tuple of the image tensor and the label tensor, both moved to
        ``self.device``. The ``device`` attribute must have been assigned on
        the trainer before this is called.
    """
    images, labels = batch['image'], batch['label']
    target = self.device
    return images.to(target), labels.to(target)

In [None]:
# Experiment 1: train ResNet-18 from randomly initialized weights.
data = TinyImagenetD2l(batch_size, num_workers)
model = ResnetD2l(num_classes=data.num_classes, pretrained=False, lr=learning_rate)
model.to(device)
trainer = d2l.Trainer(max_epochs=num_epochs, num_gpus=1)
# The prepare_batch patched onto d2l.Trainer reads trainer.device.
trainer.device = device

# Record the wall-clock training time, as called for in the Goal section.
start_time = time.perf_counter()
trainer.fit(model=model, data=data)
print(f"Training from scratch took {time.perf_counter() - start_time:.1f} s")


In [None]:
# Experiment 2: fine-tune ResNet-18 starting from pretrained weights.
model = ResnetD2l(num_classes=data.num_classes, pretrained=True, lr=learning_rate)
model.to(device)
trainer = d2l.Trainer(max_epochs=num_epochs, num_gpus=1)
# The prepare_batch patched onto d2l.Trainer reads trainer.device.
trainer.device = device

# Record the wall-clock training time, as called for in the Goal section.
start_time = time.perf_counter()
trainer.fit(model=model, data=data)
print(f"Training with pretrained weights took {time.perf_counter() - start_time:.1f} s")