<a href="https://colab.research.google.com/github/aniketmaurya/talks/blob/main/2021-06-27%20PyTorch%20Lightning/02%20flash-intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Lightning Flash together with its image extras (-U upgrades, -q quiet).
# NOTE(review): this install is unpinned; newer lightning-flash / pytorch-lightning
# releases may have changed APIs this notebook uses (e.g. the Trainer `gpus=`
# argument) — consider pinning the versions the notebook was written against.
!pip install -U 'lightning-flash[image]' -q

In [None]:
import warnings

warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import flash
import pytorch_lightning as pl

In [None]:
from flash import Trainer
from flash.image import ImageClassifier

In [None]:
# Helpers to display image tensors that were normalized by ``transform``.
def to_display_array(img, mean=0.5, std=0.5):
    """Undo ``Normalize(mean, std)`` and return the image as an HWC NumPy array.

    Args:
        img: Image tensor in CHW layout (e.g. the output of
            ``torchvision.utils.make_grid``).
        mean: Scalar mean that was subtracted from every channel.
        std: Scalar std every channel was divided by.

    Returns:
        ``numpy.ndarray`` of shape (H, W, C), ready for ``plt.imshow``.
    """
    # detach()/cpu() make this work for tensors that require grad or live on a
    # GPU; calling .numpy() directly raises in both cases.
    img = img.detach().cpu() * std + mean  # unnormalize: x * std + mean
    return np.transpose(img.numpy(), (1, 2, 0))


def imshow(img, mean=0.5, std=0.5):
    """Unnormalize a CHW image tensor and display it with matplotlib."""
    plt.imshow(to_display_array(img, mean, std))
    plt.show()

In [None]:
# Preprocessing for CIFAR-10 images: convert each PIL image to a tensor, then
# map every RGB channel via (x - 0.5) / 0.5. The inverse is x * 0.5 + 0.5.
_channel_mean = (0.5,) * 3
_channel_std = (0.5,) * 3
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_channel_mean, _channel_std),
    ]
)


In [None]:
class CIFARDataModule(pl.LightningDataModule):
    """LightningDataModule serving CIFAR-10 to a Flash ``ImageClassifier``.

    Both splits are created (and downloaded if missing) as soon as the object
    is constructed, using the module-level ``transform``.

    Args:
        batch_size: Number of samples per batch for both dataloaders.
        data_dir: Directory where torchvision stores / looks up CIFAR-10. The
            default keeps the original author's path; pass your own directory
            (e.g. ``'./data'``) to make the notebook portable.
        num_workers: Worker processes for each DataLoader (0 loads data in the
            main process, as before).
    """

    def __init__(self, batch_size: int = 64,
                 data_dir: str = '/Users/aniket/datasets/',
                 num_workers: int = 0):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers

        self.train_data = torchvision.datasets.CIFAR10(
            data_dir, train=True, download=True, transform=transform)
        # The CIFAR-10 test split is used as the validation set.
        self.val_data = torchvision.datasets.CIFAR10(
            data_dir, train=False, download=True, transform=transform)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training split."""
        return torch.utils.data.DataLoader(self.train_data,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=self.num_workers)

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the test (validation) split."""
        return torch.utils.data.DataLoader(self.val_data,
                                           batch_size=self.batch_size,
                                           shuffle=False,
                                           num_workers=self.num_workers)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """Turn an ``(images, labels)`` batch into ``{'input': ..., 'target': ...}``.

        NOTE(review): presumably these are the keys Flash's ImageClassifier
        reads from a batch; confirm against the installed Flash version.
        """
        return {'input': batch[0], 'target': batch[1]}
    
# Constructing the data module immediately builds (and, if needed, downloads)
# both CIFAR-10 splits, because the datasets are created in __init__.
cifar_dm = CIFARDataModule(batch_size=64)

In [None]:
# CIFAR-10 class names; a label index i corresponds to classes[i].
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

num_classes = len(classes)
backbone = 'resnet18'

# Flash image classifier with a resnet18 backbone and one output per class.
model = ImageClassifier(num_classes=num_classes, backbone=backbone)

# Use a single GPU when one is available and fall back to CPU otherwise;
# PyTorch Lightning rejects gpus=1 on a machine without CUDA.
trainer = flash.Trainer(max_epochs=5,
                        gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(model, datamodule=cifar_dm)

## Prediction

In [None]:
# Take one batch from the held-out (test-split) loader so the predictions below
# are made on images the model was not trained on. The loader is unshuffled, so
# this is always the same first batch. The batch is an (images, labels) pair
# straight from the DataLoader; on_after_batch_transfer is not applied here.
data = next(iter(cifar_dm.val_dataloader()))

In [None]:


batch_size = 4  # number of images to display (and predict on below)
images, labels = data[0][:batch_size], data[1][:batch_size]

# show images
imshow(torchvision.utils.make_grid(images))
# print the class name of each shown image, right-aligned to 5 characters;
# iterating the sliced labels avoids IndexError if the batch is shorter.
print(' '.join(f'{classes[label]:>5}' for label in labels.tolist()))

In [None]:
# Ground-truth class indices of the displayed images (indices into `classes`).
labels

In [None]:
# Switch to inference mode: eval() makes BatchNorm use its running statistics
# instead of this 4-image batch's statistics (and stops updating them), and
# no_grad() avoids building an autograd graph.
model.eval()
with torch.no_grad():
    # Move the inputs to wherever the model's parameters live (CPU or GPU).
    logits = model(images.to(model.device))
# Softmax is monotonic, so argmax over the raw logits picks the same class.
predictions = torch.argmax(logits, dim=1).cpu()
predictions