In [3]:
import random

import matplotlib.pyplot as plt
import numpy
import torch
import torchvision
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor # converts PIL image or numpy array into tensors

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Report the installed PyTorch version and the selected device, one per line.
print(torch.__version__, device, sep='\n')

2.2.2
cpu


# TinyVGG Model

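This model is based on the TinyVGG architecture: two convolutional blocks (conv → ReLU → conv → ReLU → max-pool) followed by a flatten + linear classifier.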

In [2]:
class TinyVGG(nn.Module):
  """TinyVGG-style convolutional image classifier.

  Architecture: two convolutional blocks, each
  conv3x3 -> ReLU -> conv3x3 -> ReLU -> 2x2 max-pool, followed by a
  flatten + linear classifier that produces one logit per class.

  Args:
    input_shape: number of input channels (e.g. 3 for RGB images).
    hidden_units: number of output channels of every convolution.
    output_shape: number of classes, i.e. the length of the logit vector.
    image_size: height and width of the (square) input images. Only used to
      size the classifier's input. Defaults to 64 to match the 64x64 resize
      applied to this notebook's data.

  Raises:
    ValueError: if image_size is smaller than 4 (the two max-pools would
      shrink the feature map to nothing).
  """

  def __init__(self,
               input_shape: int,
               hidden_units: int,
               output_shape: int,
               image_size: int = 64):
    super().__init__()
    if image_size < 4:
      raise ValueError(f'image_size must be at least 4, got {image_size}')
    self.conv_block_1 = self._make_conv_block(input_shape, hidden_units)
    self.conv_block_2 = self._make_conv_block(hidden_units, hidden_units)
    # Each 3x3 conv with stride=1 and padding=1 keeps H and W unchanged, and
    # each 2x2/stride-2 max-pool floors them in half, so conv_block_2 outputs
    # (batch, hidden_units, image_size // 4, image_size // 4).
    feature_size = image_size // 4
    self.classifier = nn.Sequential(nn.Flatten(),
                                    nn.Linear(in_features=hidden_units * feature_size * feature_size,
                                              out_features=output_shape))

  @staticmethod
  def _make_conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
    """Build one conv3x3 -> ReLU -> conv3x3 -> ReLU -> 2x2 max-pool block."""
    return nn.Sequential(nn.Conv2d(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1),
                         nn.ReLU(),
                         nn.Conv2d(in_channels=out_channels,
                                   out_channels=out_channels,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1),
                         nn.ReLU(),
                         nn.MaxPool2d(kernel_size=2,
                                      stride=2)) # default stride is same as kernel size

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return class logits for a batch of images.

    Args:
      x: image batch, assumed (batch, input_shape, image_size, image_size).

    Returns:
      Logits of shape (batch, output_shape).
    """
    x = self.conv_block_1(x)
    x = self.conv_block_2(x)
    return self.classifier(x)

------------------
# Dataloader
The dataset contains 1,000 images per class.

In [4]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# Preprocessing pipeline: resize to the 64x64 input used with TinyVGG, flip
# each image horizontally with probability 0.5 (augmentation), then convert
# the image to a tensor.
data_transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
)
# NOTE(review): `img` is not defined in any visible cell — presumably a PIL
# image opened earlier; on a fresh kernel this line raises NameError otherwise.
data_transform(img)

In [None]:
def plot_transformed_images(image_paths, transform, n=3, seed=42):
  """Plot randomly chosen images side by side with their transformed versions.

  One figure is created per sampled image: the original on the left, the
  transformed result on the right, titled with the image's class name.

  Args:
    image_paths: iterable of image file paths. Each path's ``.parent.stem`` is
      used as the class name, so ``pathlib.Path``-like objects are expected.
    transform: callable applied to each opened image. Its result is
      ``.permute(1, 2, 0)``-ed before display, so a 3-D (C, H, W) tensor is
      assumed (as produced by ``ToTensor``).
    n: number of images to sample; capped at the number of available paths.
    seed: seed for the global ``random`` module, or ``None`` to leave the
      current random state untouched. ``0`` is a valid seed.
  """
  # `if seed:` would silently skip seeding for seed=0.
  if seed is not None:
    random.seed(seed)
  # random.sample needs a sequence (generators such as Path.glob are rejected)
  # and raises ValueError when k exceeds the population size.
  population = list(image_paths)
  random_image_paths = random.sample(population, k=min(n, len(population)))
  for image_path in random_image_paths:
    # NOTE(review): `Image` (presumably PIL.Image) is not imported in any
    # visible cell — confirm it is imported before this runs.
    with Image.open(image_path) as f:
      fig, ax = plt.subplots(nrows=1, ncols=2)
      ax[0].imshow(f)
      ax[0].set_title(f'Original\nSize: {f.size}')
      ax[0].axis('off')

      # Tensor images are channel-first; imshow expects channel-last.
      transformed_image = transform(f).permute(1, 2, 0)
      ax[1].imshow(transformed_image)
      ax[1].set_title(f'Transformed\nSize: {transformed_image.shape}')
      ax[1].axis('off')

      fig.suptitle(f'Class: {image_path.parent.stem}', fontsize=16)

# NOTE(review): `image_path_list` is not defined in any visible cell —
# presumably a list of pathlib.Path image files built earlier; confirm before
# Restart & Run All. seed=None shows a different random sample on each run.
plot_transformed_images(
  image_paths=image_path_list,
  transform=data_transform,
  n=3,
  seed=None,
)