In [2]:
import torch                                            # Framework
import torch.nn as nn                                   # Layer architecture
import torch.optim as optim                             # Optimizers
from torch.optim import lr_scheduler                    # Learning rate adjuster
import torch.backends.cudnn as cudnn                    # CUDA interface
import numpy as np                                      # yep
import torchvision                                      # CV torch
from torchvision import datasets, models, transforms    # yep
import matplotlib.pyplot as plt                         # yep
import time                                             # yep
import os                                               # yep
from PIL import Image                                   # yep
from tempfile import TemporaryDirectory                 # Automatic temporary folders

# Global runtime switches for the rest of the notebook.
# Per the PyTorch docs, benchmark mode lets cuDNN try several convolution
# algorithms and keep the fastest one.
# NOTE(review): presumably a no-op for the CPU-only run recorded in this notebook.
cudnn.benchmark = True          # Enable the cuDNN autotuner
# The value returned by plt.ion() is what appears as this cell's saved output.
plt.ion()                       # interactive mode

<contextlib.ExitStack at 0x14b0fbad160>

## Load data

In [4]:
# Preprocessing pipelines, one per dataset split.
#   'train' -> random augmentation, re-drawn every time an image is read
#   'val'   -> deterministic resize followed by a center crop
# Both pipelines end by converting the image to a tensor and normalizing each
# channel with the same mean / std values.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),      # random crop, new one on every access (e.g. each epoch)
        transforms.RandomHorizontalFlip(),      # random flip, new draw on every access
        transforms.ToTensor(),                  # image -> PyTorch tensor with values in [0, 1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406],    # per-channel mean
                             std=[0.229, 0.224, 0.225]),    # per-channel std
    ]),
    val=transforms.Compose([
        transforms.Resize(256),                 # smaller image side -> 256
        transforms.CenterCrop(224),             # central crop with side 224
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
)


# Dataset root; it is expected to contain one sub-folder per split.
data_dir = 'data/hymenoptera_data'

# split name -> ImageFolder over <data_dir>/<split>, i.e. labeled images that
# are passed through that split's transform pipeline when read.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), transform=pipeline)
    for split, pipeline in data_transforms.items()
}


# split name -> DataLoader yielding shuffled mini-batches of 4 (images, labels),
# loaded with 4 worker processes.
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
    )
    for split, dataset in image_datasets.items()
}


# Number of samples per split and the class labels found in the training folder.
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}
class_names = image_datasets['train'].classes

# Use the current accelerator when one is available, otherwise fall back to the
# CPU. The run saved in this notebook had no CUDA install, so it reports "cpu".
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
else:
    device = "cpu"
print(f"Using {device} device")

Using cpu device


## Visualize some images

In [None]:
def imshow(inp, title=None):
    """Display a normalized image tensor with matplotlib.

    Args:
        inp: Image tensor laid out as [C, H, W] whose three channels were
            normalized with the mean / std values used below. The tensor may
            require grad or live on a non-CPU device; it is detached and
            copied to host memory before conversion.
        title: Optional title for the plot; nothing is set when it is None.
    """
    # Tensor.numpy() raises for tensors that require grad or are stored on
    # another device, so drop autograd tracking and move to the CPU first.
    # For a plain CPU tensor both calls are no-ops.
    image = inp.detach().cpu().numpy()
    image = image.transpose((1, 2, 0))          # [C,H,W] -> plt's [H,W,C]

    # Revert the normalization done by transforms.Normalize(): x * std + mean
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean

    image = np.clip(image, 0, 1)                # Limit values to the range [0, 1]
    plt.imshow(image)

    if title is not None:
        plt.title(title)
    plt.pause(0.001)                            # Pause so that plots are updated


# Pull one shuffled mini-batch from the training loader:
# `inputs` holds the images, `classes` the matching label indices.
inputs, classes = next(iter(dataloaders['train']))

# Translate every label index in the batch into its class name for the title.
batch_titles = [class_names[label] for label in classes]

# Arrange the batch images into a single grid image and display it.
out = torchvision.utils.make_grid(inputs)
imshow(out, title=batch_titles)