In [None]:
%load_ext autoreload
%autoreload 2

# 0. Intro: deepglue fashion-mnist demo
<img src="https://raw.githubusercontent.com/EricThomson/deepglue/main/docs/images/deep_glue_logo.png" alt="deepglue logo" align="right" width="160" style="margin-left:10px;">

Welcome to the deepglue fashion-mnist demo! If you are new to deepglue, this is the best place to start, as it introduces the key features and functions.

deepglue is a library of pytorch utilities designed to simplify and streamline deep learning workflows (our motto is "Keeping the useful stuff together in one place"). In this demo, we'll highlight some of deepglue's key features using the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset, which contains 70,000 images of clothing items organized into 10 categories (e.g., t-shirts). Because this demo is meant to run in multiple environments, including Google Colab, we will use a small subset of Fashion MNIST with only 3,000 images (`fashion3k`).
  
The main steps we'll walk through in this notebook include:
- **Set up project**: Create directory structure for project.
- **Download and explore dataset**: Retrieve and explore the fashion3k dataset.
- **Define the network**: Set up resnet18 convolutional network for transfer learning.
- **Set up data for training**: Define augmentation transforms and data loaders for training. 
- **Train and evaluate the model**: Train the model on fashion3k, and evaluate performance on validation data.
- **Visualize features**: Visualize feature clusters to help us understand the network's behavior.
  
If anything is unclear in this demo, feel free to ask questions at the repo's [Discussion forum](https://github.com/EricThomson/deepglue/discussions). If you find a problem, please [raise an issue](https://github.com/EricThomson/deepglue/issues). 

Also, in the rest of the notebook you can explore any function's documentation by typing the function name followed by a question mark. E.g., type `dg.train_one_epoch?` in a code cell to print the documentation for `train_one_epoch()`. To dig deeper, you can also explore [deepglue's source code](https://github.com/EricThomson/deepglue) and [online documentation](https://deepglue.readthedocs.io/en/latest/).

### Import packages
Let's get started by importing the packages we'll use in the rest of the notebook.

For those in Colab, we first need to install packages that are not natively part of the virtual environment.

In [None]:
try:
    import google.colab # noqa: F401 # ruff: ignore unused import error 
    in_colab = True
except ImportError:
    in_colab = False

if in_colab:
    print("Installing deepglue and umap-learn for Colab environment")
    !pip install -q deepglue umap-learn

In [None]:
import gdown
from IPython.display import HTML
import logging
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from PIL import Image
from pprint import pprint
import umap
import zipfile

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor
import torchvision.transforms.v2 as transforms

import deepglue as dg

Determine whether to use CPU or GPU to train. If your pytorch install detects a GPU, then set your device to `cuda`, otherwise default to `cpu`.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

Many deepglue functions use loggers to report what they are doing internally. For now, we will set the logging level to `WARNING`, so only warnings and errors are printed. To get more informative output, you can set it to `INFO`.

If you don't want your world cluttered with logging messages, you can delete the following or comment it out.  

In [None]:
log_format = "%(asctime)s - %(filename)s - %(funcName)s - line %(lineno)d - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.WARNING,  # INFO DEBUG WARNING ERROR etc
                    format=log_format,
                    force=True)  # to override colab defaults 

# 1. Set up project structure 
It's key to be organized for any deep learning project. Here, we will set up a directory structure for our fashion mnist project. 

## Set up directory for all projects
Since this is the first demo, we'll set up the default `projects` directory where all deepglue projects will be stored. 

The following will create a `.deepglue/projects/` directory in your system's home directory. The exact location will depend on your OS and environment (e.g., on Linux/Mac it will be `~`, on Windows it will be `C:/Users/Username/`, and in Colab it will be `/root/`). If you prefer a different base directory, you can change `deepglue_dir` to a different value below.

In [None]:
deepglue_dir = Path.home() / ".deepglue"
projects_dir = deepglue_dir / "projects"

try:
    projects_dir.mkdir(parents=True, exist_ok=False) # prevent overwriting
except FileExistsError:
    print("'projects/' directory already exists. Skipping creation.\n")

print(f"Your deepglue projects directory is: {projects_dir}")

## Directory for fashion project
Within the deepglue project folder, let's create the minimal directory structure for our fashion mnist classification project. Given a new project name (`fashion`), deepglue's `create_project()` function creates the following directory structure within `projects/`:

    fashion/
        data/
        models/

This is a minimal structure for a deep learning project; you can add more directories as needed.

In [None]:
project_name = "fashion"
project_dir, project_data_dir, project_models_dir = dg.create_project(projects_dir, project_name)

print(f"Your new project directory: {project_dir}")
print(f"Datasets will go into {project_data_dir}")
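For intuition, here is a minimal sketch of the kind of folder setup `dg.create_project()` performs, written with `pathlib` only. The function `make_project_dirs` is a hypothetical stand-in made up for illustration, not deepglue's implementation; the real function may behave differently (for example, when the project already exists).

```python
from pathlib import Path
import tempfile

def make_project_dirs(projects_dir, project_name):
    """Hypothetical stand-in for dg.create_project(): create <project>/data and <project>/models."""
    project_dir = Path(projects_dir) / project_name
    data_dir = project_dir / "data"
    models_dir = project_dir / "models"
    for folder in (data_dir, models_dir):
        # parents=True also creates project_dir; exist_ok=True makes reruns safe
        folder.mkdir(parents=True, exist_ok=True)
    return project_dir, data_dir, models_dir

# Usage: build the structure inside a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    demo_project_dir, _, _ = make_project_dirs(tmp, "fashion")
    print(sorted(p.name for p in demo_project_dir.iterdir()))  # ['data', 'models']
```

Using `exist_ok=True` keeps the sketch safe to rerun, which matters in notebooks where cells are often executed more than once.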

# 2. Download data
As mentioned above, we'll use a subsampled version of the full Fashion MNIST dataset, already split into training, validation, and test sets and stored online:

In [None]:
data_url = r'https://drive.google.com/uc?id=1tBiL2H9xcjy7ClkLsxuag4_gBKusau1J'
filename = 'fashion3k.zip'
data_zip_path = project_data_dir / filename
data_dir = project_data_dir / 'fashion3k' # data dir we'll unzip into
print(f"Will attempt download to {data_zip_path}")

In [None]:
%%time
if not data_zip_path.exists():
    print(f"Downloading {filename}. This can take a minute.")
    gdown.download(data_url, str(data_zip_path), quiet=False, fuzzy=True, use_cookies=True)
else:
    print(f"{filename} already downloaded. Download skipped.")

Unzip the compressed data into `project_data_dir`; the archive contains the `fashion3k/` folder, which becomes `data_dir`. The code below skips extraction if it has already been done, and checks that the file has a `.zip` extension before unzipping.

In [None]:
%%time
if data_dir.exists():
    print("Already extracted. Skipping.")
else:
    if data_zip_path.suffix == '.zip':
        print(f"Unzipping to {project_data_dir}...")
        with zipfile.ZipFile(data_zip_path, 'r') as zip_ref:
            zip_ref.extractall(project_data_dir)
        print("Done unzipping!\n")
    else:
        print('Not a zip file.')

train_dir = data_dir / 'train'
valid_dir = data_dir / 'valid'
test_dir = data_dir / 'test'
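The check above only looks at the file's extension. As a hedged alternative, the standard-library `zipfile.is_zipfile()` inspects the file's contents, so it also catches a corrupted or incomplete download that still ends in `.zip`. The helper `extract_once` below is hypothetical, written just for this illustration:

```python
from pathlib import Path
import tempfile
import zipfile

def extract_once(zip_path, extract_to, expected_dir):
    """Extract zip_path into extract_to unless expected_dir already exists.

    Returns True if extraction happened, False if it was skipped.
    """
    zip_path, extract_to, expected_dir = Path(zip_path), Path(extract_to), Path(expected_dir)
    if expected_dir.exists():
        return False  # already extracted on a previous run
    if not zipfile.is_zipfile(zip_path):  # inspects the bytes, not the file name
        raise ValueError(f"{zip_path} is not a valid zip archive")
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(extract_to)
    return True

# Usage with a tiny throwaway archive
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    archive = tmp / "demo.zip"
    with zipfile.ZipFile(archive, "w") as zf:
        zf.writestr("demo/train/0/placeholder.txt", "x")
    print(extract_once(archive, tmp, tmp / "demo"))  # True (extracted)
    print(extract_once(archive, tmp, tmp / "demo"))  # False (skipped)
```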

Our final directory structure follows the one-folder-per-category convention expected by `torchvision`'s `ImageFolder` class, which we'll use below:

        projects/
            fashion/
                models/
                data/
                    fashion3k/
                        train/
                            0/  [tshirt]
                            1/  [trouser]
                        valid/
                            0/
                            1/   
                        test/
                            0/
                            1/

Each leaf directory in this tree (e.g., `train/0/`) contains the images for one category: feel free to browse to these folders on your machine to check out the data. Below we will show how to use `deepglue` to inspect random samples of the data.

Our mapping from subdirectory names to actual categories is as follows:

In [None]:
category_map = {'0': 'tshirt',
                '1': 'trouser',
                '2': 'pullover',
                '3': 'dress',
                '4': 'coat',
                '5': 'sandal',
                '6': 'shirt',
                '7': 'sneaker',
                '8': 'bag',
                '9': 'ankle_boot'}

categories = ['0','1','2','3','4','5','6','7','8','9']
category_names = [category_map[key] for key in categories]
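These category names only line up with the model's outputs if the folder names map to label indices in the same order. `torchvision`'s `ImageFolder` (used in step 5) assigns indices by sorting the class folder names as strings and numbering them in that order. The helper `folder_class_to_idx` below is a hypothetical re-creation of that rule for illustration; once the datasets exist, you can inspect the real mapping with `data['train'].class_to_idx`.

```python
def folder_class_to_idx(class_folder_names):
    """Mimic ImageFolder's labeling rule: sort folder names as strings, then number them."""
    classes = sorted(class_folder_names)
    return {name: idx for idx, name in enumerate(classes)}

digit_folders = [str(i) for i in range(10)]
print(folder_class_to_idx(digit_folders)["7"])  # 7: single-character names sort in numeric order

# With two-digit folder names, string order no longer matches numeric order
print(folder_class_to_idx(["0", "1", "2", "10"]))  # {'0': 0, '1': 1, '10': 2, '2': 3}
```

Because our folders are the single characters `'0'` to `'9'`, string order matches numeric order, so `category_names` lines up with the label indices. With ten or more numbered folders, zero-padding the names (e.g., `'00'`, `'01'`) would avoid the mismatch.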

<div class="alert alert-info">
<h3 style="margin: 1px 0 6px 0;">Scalable data structures</h3>
In real-world projects there will often be multiple datasets placed in the project's <code>data/</code> directory. Since we are working with a single dataset, we could in theory just extract the <code>train/</code>, <code>valid/</code> and <code>test/</code> folders directly into <code>data/</code> to keep a more flat structure. But our more nested structure mimics how larger projects handle multiple datasets in a scalable way, so we'll stick with it.
</div>

# 3. Explore data 
An initial look at the dataset.

Deepglue will plot some random images from the dataset:

In [None]:
dg.plot_random_sample(data_dir, category_map, split_type='train', num_to_plot=20);

Within the project, the data are divided into training, validation, and testing splits. How many are in each split?

In [None]:
num_per_split = dg.count_by_split(data_dir)

total_samples = num_per_split['train'] + num_per_split['valid'] + num_per_split['test']
print(f"Num samples total: {total_samples}")
pprint(num_per_split)
data_splits = ['train', 'valid', 'test']
proportion_per_split =  [num_per_split['train']/total_samples, 
                         num_per_split['valid']/total_samples, 
                         num_per_split['test']/total_samples]

# Plot number in train/validation/test splits
plt.bar(data_splits, proportion_per_split)
plt.title("Proportion per split")
plt.xlabel('Split Type')
plt.ylabel('Proportion');

In this tiny sample of `fashion_mnist`, we have 3,000 total samples: 2,100 images set aside for training and 450 each for validation and testing (a 70/15/15 split).
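As a quick check of that arithmetic, using the counts reported above:

```python
reported_counts = {"train": 2100, "valid": 450, "test": 450}  # counts stated above for fashion3k
total_count = sum(reported_counts.values())
split_proportions = {split: n / total_count for split, n in reported_counts.items()}
print(total_count)        # 3000
print(split_proportions)  # {'train': 0.7, 'valid': 0.15, 'test': 0.15}
```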


<div class="alert alert-info">
<h3 style="margin: 1px 0 6px 0;">fashion3k versus fashion_mnist</h3>

<p>Because fashion3k is a much smaller dataset than the full fashion-mnist (70,000 images), this lets us train our models even in runtime environments without a great deal of computational resources (e.g., free-tier Colab). The tradeoff is that we will be sacrificing accuracy. We made this tradeoff because the main point of this demo is to quickly illustrate key concepts, not to maximize accuracy.</p>

<p>If you want more data, we discuss how to download the full fashion mnist dataset at the end of this notebook.</p>
</div>

In [None]:
num_category_by_split = dg.count_category_by_split(data_dir)

`count_category_by_split()` returns a dict whose keys are the splits (`train`, `valid`, and `test`); each value is itself a dict mapping each category to its number of samples:

In [None]:
print('Training split counts:')
pprint(num_category_by_split['train'])
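To make the structure of this nested dict concrete, here is a hypothetical sketch of how such counts could be computed by walking the split and category folders. `count_images_by_split_and_category` is made up for illustration and is not deepglue's implementation (for example, it counts every file, not just images):

```python
from pathlib import Path
import tempfile

def count_images_by_split_and_category(data_dir, splits=("train", "valid", "test")):
    """Hypothetical sketch: return {split: {category_folder_name: number_of_files}}."""
    data_dir = Path(data_dir)
    counts = {}
    for split in splits:
        counts[split] = {
            category_dir.name: sum(1 for f in category_dir.iterdir() if f.is_file())
            for category_dir in sorted((data_dir / split).iterdir())
            if category_dir.is_dir()
        }
    return counts

# Usage on a tiny fake dataset: 3 files per category in train, 1 in valid and test
with tempfile.TemporaryDirectory() as tmp:
    for split, n_files in [("train", 3), ("valid", 1), ("test", 1)]:
        for category in ["0", "1"]:
            folder = Path(tmp) / split / category
            folder.mkdir(parents=True)
            for i in range(n_files):
                (folder / f"img{i}.png").touch()
    fake_counts = count_images_by_split_and_category(tmp)
    print(fake_counts)
    # {'train': {'0': 3, '1': 3}, 'valid': {'0': 1, '1': 1}, 'test': {'0': 1, '1': 1}}
```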

Let's plot the proportion of samples in each category for the three data splits:

In [None]:
# get counts and proportions
train_counts = np.array([num_category_by_split['train'][key] for key in categories])
train_proportions = train_counts/sum(train_counts)

valid_counts = np.array([num_category_by_split['valid'][key] for key in categories])
valid_proportions = valid_counts/sum(valid_counts)

test_counts = np.array([num_category_by_split['test'][key] for key in categories])
test_proportions = test_counts/sum(test_counts)

# plot them
fig, (ax_test, ax_val, ax_train) = plt.subplots(3,1,figsize=(5,5)) # width x height

#train (bottom)
ax_train.bar(category_names, train_proportions)
ax_train.tick_params(axis='x', labelrotation=45)
ax_train.set_title("Training Data")
ax_train.set_xlabel('Category')

# validation (middle)
ax_val.bar(category_names, valid_proportions)
ax_val.set_ylabel('Proportion');
ax_val.set_title("Validation Data")
ax_val.set_xticks([]) 

# test (top)
ax_test.bar(category_names, test_proportions)
ax_test.set_title("Test Data")
ax_test.set_xticks([])

plt.tight_layout()

This is (by design) an extremely balanced data set, in all three splits. 

# 4. Define network model
To learn the fashion classification task, we'll start with a [resnet18 model](https://www.run.ai/guides/deep-learning-for-computer-vision/pytorch-resnet) pre-trained on the ImageNet-1k dataset (1,000 categories with over 1 million total images). We first freeze all of the model's parameters, then unfreeze its last two residual stages (`layer3` and `layer4`), and finally replace the original fully connected (fc) output layer with a small new classification head: a hidden layer followed by 50% dropout to reduce overfitting (dropout is only active during training), and an output layer with one unit per category.

This keeps the early layers, which extract general low-level features, fixed, while allowing the later layers to learn the higher-level, fashion-specific features in Fashion MNIST.

Feel free to substitute a different model and adapt the code to suit your needs. For instance, `resnet50` performs better, but it also uses much more memory and takes longer to train. For this little demo, the smaller and faster `resnet18` performs well enough. We'll discuss trying out different models at the end, in the `Test data` section.

In [None]:
num_classes = len(category_map)  # Adjust to match the number of classes in Fashion MNIST
num_fc_hidden_units = 128  # Number of hidden units in fully connected layer: feel free to tweak

In [None]:
# Load pre-trained ResNet18 model
resnet_weights = models.ResNet18_Weights.IMAGENET1K_V1
resnet18 = models.resnet18(weights=resnet_weights)

# Freeze all model parameters initially
for param in resnet18.parameters():
    param.requires_grad = False
    
# unfreeze blocks 3 and 4
for param in resnet18.layer3.parameters():
    param.requires_grad = True
    
for param in resnet18.layer4.parameters():
    param.requires_grad = True

# Get the number of inputs for the original fully connected (fc) layer
num_fc_inputs = resnet18.fc.in_features

# Replace the final fully connected layer (note requires_grad is True by default for a brand-new layer)
resnet18.fc = nn.Sequential(nn.Linear(num_fc_inputs, num_fc_hidden_units),  # Projection from backbone to hidden units
                            nn.ReLU(),                                      # Activation for non-linearity
                            nn.Dropout(p=0.5),                             # Reduce overfitting
                            nn.Linear(num_fc_hidden_units, num_classes))    # Final layer for class prediction
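As a sanity check on the size of the new head: each `nn.Linear(n_in, n_out)` layer holds `n_in * n_out` weights plus `n_out` biases, while `ReLU` and `Dropout` have no parameters. The calculation below assumes `resnet18.fc.in_features` is 512 (its value for torchvision's resnet18), together with the 128 hidden units and 10 classes set above; `linear_param_count` is a hypothetical helper. In the notebook itself you could confirm the count with `sum(p.numel() for p in resnet18.fc.parameters())`.

```python
def linear_param_count(n_in, n_out):
    """nn.Linear(n_in, n_out) has an n_out x n_in weight matrix plus n_out biases."""
    return n_in * n_out + n_out

# 512 backbone features -> 128 hidden units -> 10 classes (ReLU/Dropout add no parameters)
head_params = linear_param_count(512, 128) + linear_param_count(128, 10)
print(head_params)  # 66954
```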

We also need to define a loss function and an optimizer for the network. Here we use Adam with a fixed learning rate (no learning rate scheduler).

In [None]:
loss_func = nn.CrossEntropyLoss() 
optimizer = optim.Adam(resnet18.parameters(), lr=0.0001)  

<div class="alert alert-info">
<h3 style="margin: 1px 0 6px 0;">Expected inputs for resnet</h3>
Resnet models were trained on 224 x 224 images, and this matters: if you feed the model images of a very different size, the pretrained features will typically not transfer as well. Fashion mnist images are 28x28, so in what follows we will upsample them to 224x224. We'll also convert the single-channel grayscale inputs to 3-channel RGB, because resnet expects three input channels. We could define a completely new network and train from scratch, but then we'd lose out on the gains from the pretrained network.
</div>

# 5. Set up data for training
Having data organized in folders is a good start, but `torchvision` provides many utilities that make it easy to funnel such data through training pipelines. It also has a great API for integrating transformations such as random cropping, rotations, and normalization directly into the data pipeline, which simplifies training.

Defining datasets and data loaders requires us first to define the transforms we will use on the data as we load it from disk. 

## Transforms
Torchvision has an amazing set of transforms that apply whether you are doing classification, object detection, or scene segmentation. To learn more about the transforms, see their [documentation](https://pytorch.org/vision/main/transforms.html#v2-api-ref) or [example page](https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_illustrations.html#sphx-glr-auto-examples-transforms-plot-transforms-illustrations-py). 

We are going to set up a relatively simple transform here just to show the logic. First we'll define a couple of transforms that we will randomly apply in our transform function:

In [None]:
gaussian_noise = transforms.GaussianNoise(mean=0, # mean of sampled noise
                                         sigma=0.1, # std of sampled noise
                                         clip=True)  # clip to [0,1] after adding noise
gaussian_blur = transforms.GaussianBlur(kernel_size=(7,7), # kernel size (width, height)
                                        sigma=(0.8, 0.8)) # sigma is sampled uniformly from (min, max); equal values fix it at 0.8
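To make the behavior of these random pieces concrete, here is a NumPy illustration (not torchvision's code) of Gaussian noise with clipping, and of applying a transform with probability `p=1/3` the way `RandomApply` does. The helper names `add_gaussian_noise` and `random_apply` are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded so the illustration is reproducible

def add_gaussian_noise(img, mean=0.0, sigma=0.1, clip=True):
    """Add independent Gaussian noise to every pixel of a float image in [0, 1]."""
    noisy = img + rng.normal(mean, sigma, size=img.shape)
    return np.clip(noisy, 0.0, 1.0) if clip else noisy  # clip=True keeps values in [0, 1]

def random_apply(fn, img, p):
    """Apply fn with probability p; otherwise return the input unchanged."""
    return fn(img) if rng.random() < p else img

img = np.full((28, 28), 0.5)  # a flat gray test image
noisy_img = add_gaussian_noise(img)
print(noisy_img.min() >= 0.0, noisy_img.max() <= 1.0)  # True True

# Over many draws, roughly a third of the images get noise when p = 1/3
num_draws = 3000
num_applied = sum(random_apply(add_gaussian_noise, img, p=1/3) is not img for _ in range(num_draws))
applied_fraction = num_applied / num_draws
print(applied_fraction)  # close to 1/3
```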

Then define a function for a data transform that can be applied to different data splits. 

In [None]:
def data_transform(train=False):
    """
    Define data transformation to be used when data is ingested for use by datasets.
    """
    transform_pipeline = []
    
    transform_pipeline.append(transforms.Resize((224, 224)))     # resnet expects this size
    transform_pipeline.append(transforms.ToImage())  # Convert PIL image to tensor: many transforms only work on torch tensor
    transform_pipeline.append(transforms.ToDtype(torch.float32, scale=True)) # many transforms only work with float
    transform_pipeline.append(transforms.RGB()) # resnet expects RGB (if it is already RGB, nothing changes)

    if train:
        transform_pipeline.append(transforms.RandomApply([gaussian_noise], p=1/3)) # will apply p of the time
        transform_pipeline.append(transforms.RandomApply([gaussian_blur], p=1/3))  
        transform_pipeline.append(transforms.RandomHorizontalFlip(p=1/2))
        transform_pipeline.append(transforms.RandomRotation(30, fill=0.445, expand=False))  # random angle in (-30, 30) degrees; expand=True would enlarge the output to fit the rotated image
        
    # Normalize to standard values for resnet
    transform_pipeline.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                   std=[0.229, 0.224, 0.225]))  
    return transforms.Compose(transform_pipeline)

In the above, we have defined a general transform function that we can use for different data splits. The training transform applies distortions for augmentation purposes (Gaussian noise and blur, flips, and rotations). The validation and test splits only get "nondistorting" changes: they are resized to the proper size, converted to float tensors and to RGB, and normalized.
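The last steps of the nondistorting pipeline can be illustrated in NumPy: a single grayscale channel is copied into three identical channels, and each channel is then normalized as `(x - mean) / std` using the ImageNet statistics passed to `transforms.Normalize`. The function below is a hypothetical sketch operating on an already-resized image with values in `[0, 1]`, not torchvision's implementation:

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def gray_to_normalized_rgb(gray):
    """gray: (H, W) floats in [0, 1] -> (3, H, W) array normalized per channel."""
    rgb = np.repeat(gray[np.newaxis, :, :], 3, axis=0)  # copy the single channel three times
    # Broadcast the per-channel statistics over height and width: (x - mean) / std
    return (rgb - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]

black = np.zeros((224, 224))
normalized = gray_to_normalized_rgb(black)
print(normalized.shape)                  # (3, 224, 224)
print(np.round(normalized[:, 0, 0], 3))  # [-2.118 -2.036 -1.804]
```

Note that after normalization a black pixel no longer maps to 0, and the three channels of a grayscale image end up with slightly different values.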

### Transform demo
Let's look at how the transform works in the case of training data just for fun. We'll set train to `True` so we'll see the effects of the augmenting transforms. 

In [None]:
demo_transform = data_transform(train=True)

Let's get a random image from our dataset:

In [None]:
category_to_plot = '7'
demo_image_path, demo_image_category = dg.sample_random_images(data_dir, 
                                                               category_map, 
                                                               category=category_to_plot) 
demo_image = np.array(Image.open(demo_image_path[0]))
dg.plot_transformed(demo_image, demo_transform, num_to_plot=9);
print(category_map[category_to_plot])

You can see that we get multiple views of the sneaker, some are pretty distorted. This is good! We want to give the network some tricky instances for training. 

### Define transforms for our datasets
We need transforms for training, validation, and test data:

In [None]:
image_transforms = {'train': data_transform(train=True),
                    'valid': data_transform(train=False),
                    'test': data_transform(train=False)}

## Datasets
Let's create our training, validation, and test datasets using torchvision's `datasets.ImageFolder` class. This is a convenience class that works with image folders structured the way we have set them up. More generally, pytorch datasets provide a flexible way to define how your data is loaded, transformed, and accessed, which makes them useful for deep learning workflows.

Datasets are designed to load data lazily, meaning they load individual items on the fly when accessed: this minimizes memory usage, which is especially helpful for large datasets. The primary use of a dataset is to be wrapped in a `DataLoader`, which we will see in the next section.

We'll create a dictionary of datasets, one dataset for each data split.

In [None]:
data = {'train': datasets.ImageFolder(root=train_dir, transform=image_transforms['train']),
        'valid': datasets.ImageFolder(root=valid_dir, transform=image_transforms['valid']),
        'test': datasets.ImageFolder(root=test_dir, transform=image_transforms['test'])}
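To illustrate what "lazy" means here, the sketch below implements the `__len__`/`__getitem__` protocol that pytorch datasets follow, using only the standard library. `LazyFolderDataset` is hypothetical and far simpler than `ImageFolder`: it keeps only `(path, label)` pairs in memory and reads a file from disk only when an item is requested. The default loader, which returns raw bytes, is an assumption for the demo; a real image dataset would open the file with PIL.

```python
from pathlib import Path
import tempfile

class LazyFolderDataset:
    """Hypothetical minimal dataset: stores file paths, reads a file only when indexed."""

    def __init__(self, root, loader=None, transform=None):
        self.root = Path(root)
        classes = sorted(p.name for p in self.root.iterdir() if p.is_dir())
        self.class_to_idx = {c: i for i, c in enumerate(classes)}
        # Only (path, label) pairs are kept in memory, not the file contents
        self.samples = [(f, self.class_to_idx[c])
                        for c in classes
                        for f in sorted((self.root / c).iterdir()) if f.is_file()]
        self.loader = loader or (lambda path: path.read_bytes())
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        item = self.loader(path)  # the disk read happens here, on demand
        if self.transform is not None:
            item = self.transform(item)
        return item, label

# Usage: two fake categories with 2 and 3 tiny files
with tempfile.TemporaryDirectory() as tmp:
    for category, n_files in [("0", 2), ("1", 3)]:
        (Path(tmp) / category).mkdir()
        for i in range(n_files):
            (Path(tmp) / category / f"{i}.bin").write_bytes(bytes([i]))
    lazy_ds = LazyFolderDataset(tmp)
    print(len(lazy_ds))  # 5 (only paths were collected so far)
    item, label = lazy_ds[4]
    print(item, label)   # b'\x02' 1
```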

## Dataloaders
The point of the above transforms/datasets is to create data loaders, which is what we actually use directly during training. They shuffle the data, create batches, and parallelize the workload using multiprocessing to make everything go faster.

Dataloaders require a few parameters that you might have to adjust depending on your system. To review some ML terminology: one *epoch* is a single pass through the entire dataset. Our training split has 2,100 images, so one epoch of training is a pass through all 2,100 images. Rather than processing them all at once, which would typically take too much memory, we split each epoch into *batches*: subsets of `batch_size` images that are processed together, with the error calculated and the network updated once per batch. `num_workers` is the number of (CPU) processes that generate batches in parallel: such parallelization can *significantly* speed up runtime.
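The batch arithmetic for our splits can be worked out directly. The hypothetical helper below assumes the DataLoader's default batching, where `drop_last=True` discards the final incomplete batch and `drop_last=False` keeps it as a smaller batch:

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Batches per epoch for a DataLoader with default batching (hypothetical helper)."""
    if drop_last:
        return num_samples // batch_size        # incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)  # incomplete final batch is kept

print(batches_per_epoch(2100, 256, drop_last=True))  # 8  (8 * 256 = 2048 images used, 52 skipped)
print(batches_per_epoch(450, 256))                   # 2  (batches of 256 and 194 images)
```

With `batch_size = 256`, 52 training images are skipped in each epoch; because the training loader shuffles, a different set of images is dropped each time.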

In [None]:
batch_size = 256   # 128 on laptop, 256 on workstation

if in_colab:
    num_workers = 2  # max CPUs available on free colab is 2
else:
    num_workers = 6   # 4 on laptop, 6 on workstation

# persistent workers stay alive between epochs, so they don't have to be re-created each epoch
# (DataLoader only allows persistent_workers=True when num_workers > 0)
persist_workers = num_workers > 0

In [None]:
train_data_loader = DataLoader(data['train'], 
                               batch_size=batch_size, 
                               shuffle=True,
                               num_workers=num_workers,
                               drop_last=True,  # drop dangling batch at end
                               persistent_workers=persist_workers)

valid_data_loader = DataLoader(data['valid'], 
                               batch_size=batch_size, 
                               shuffle=False,  # evaluation data does not need shuffling
                               num_workers=num_workers,
                               persistent_workers=persist_workers)

test_data_loader = DataLoader(data['test'], 
                              batch_size=batch_size, 
                              shuffle=False,  # evaluation data does not need shuffling
                              num_workers=num_workers,
                              persistent_workers=persist_workers)

<div class="alert alert-info">
<h3 style="margin: 1px 0 6px 0;">Multiprocessing on different operating systems</h3>
Starting worker processes can take extra time on some systems. Windows, and macOS with Python 3.8 or later, use the <code>spawn</code> method to create new workers, whereas Linux has traditionally used the faster-starting <code>fork</code> method. On some older Windows systems, setting <code>num_workers</code> to any number greater than <code>0</code> will simply cause your system to hang. If this happens, you unfortunately have to set <code>num_workers = 0</code>. However, do give it a minute to start, because once the workers are set up initially, the speed payoff is <em>significant</em>.
</div>

If you have already trained a model, you can jump to step 7 -- load the network.

# 6. Train the network
We'll use deepglue's `train_and_validate()` to update the weights in the network using the training data, and check performance on the validation data. This is what we've been building toward! 

There are a couple of parameters we need to set: the number of passes through the training data (`num_epochs`), and which top-k accuracies to track (`topk`).

A top-k prediction counts as correct if the true category is among the k categories to which the network assigns the highest probabilities, even if it wasn't the top choice. For example, if the network's three highest-probability predictions for an image of a sandal are `[sneaker, sandal, ankle_boot]`, the prediction is top-3 correct but not top-1 correct. This is useful for datasets with lots of similar categories.
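Here is a minimal NumPy sketch of how top-k accuracy can be computed from predicted probabilities, reported as a percentage. `topk_accuracy` is a hypothetical helper for illustration; deepglue computes its own top-k values internally, and its implementation may differ. The probabilities below are made up:

```python
import numpy as np

def topk_accuracy(probs, labels, k):
    """Percentage of samples whose true label is among the k highest-probability classes."""
    # Column indices of the k largest probabilities in each row, largest first
    topk_preds = np.argsort(probs, axis=1)[:, ::-1][:, :k]
    hits = (topk_preds == labels[:, None]).any(axis=1)
    return 100.0 * hits.mean()

# Made-up probabilities over 4 classes: 0 = sneaker, 1 = sandal, 2 = ankle_boot, 3 = bag
demo_probs = np.array([[0.50, 0.30, 0.15, 0.05],   # true sandal: 2nd choice
                       [0.10, 0.70, 0.15, 0.05],   # true sandal: 1st choice
                       [0.40, 0.04, 0.06, 0.50]])  # true ankle_boot: 3rd choice
demo_labels = np.array([1, 1, 2])
print(topk_accuracy(demo_probs, demo_labels, k=1))  # one of three correct (about 33.3)
print(topk_accuracy(demo_probs, demo_labels, k=3))  # 100.0
```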

In [None]:
topk = (1,3)

# very slow if only running on cpu
if str(device) == 'cpu':  # str() works whether device is a string or a torch.device
    num_epochs = 5
else:
    num_epochs = 10 

This is the slowest step in the notebook. Although we've made the dataset small to keep things fast, training can still be slow on CPU-only systems, so we set `num_epochs` to 5 on such systems. Feel free to increase it if you don't mind waiting. Under the hood, `dg.train_and_validate()` works by churning through the data from the training and validation dataloaders `num_epochs` times.

In [None]:
#  reset our logging level so we can get some feedback during training
logging.getLogger().setLevel(logging.INFO)

trained_model, train_history = dg.train_and_validate(resnet18,
                                                     train_data_loader,
                                                     valid_data_loader,
                                                     loss_func,
                                                     optimizer,
                                                     device=device,
                                                     topk=topk,
                                                     epochs=num_epochs);

## View loss/accuracy
Let's see how training went. We will explore other metrics below; this is just to get a quick sense of how things went.

In [None]:
epoch_array = np.arange(num_epochs)

In [None]:
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10,4))

# Plot loss 
ax_loss.plot(epoch_array, train_history['train_loss'], color='blue',  label="Training Data", marker='.')
ax_loss.plot(epoch_array, train_history['val_loss'], color='firebrick', label="Validation Data", marker='.');
ax_loss.set_title('Training Loss')
ax_loss.set_ylabel('Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.legend();

# Plot topk accuracies
ax_acc.plot(epoch_array, train_history['val_topk_accuracy'][:,1], color='firebrick',  label="Validation top 3", marker='.');
ax_acc.plot(epoch_array, train_history['train_topk_accuracy'][:,1], color='blue', label="Training top 3", marker='.')
ax_acc.plot(epoch_array, train_history['val_topk_accuracy'][:,0],  color='darksalmon', label="Validation top 1", marker='.');
ax_acc.plot(epoch_array, train_history['train_topk_accuracy'][:,0],  color='lightsteelblue', label="Training top 1", marker='.')

ax_acc.axhline(y=100, color='k', linestyle='--', linewidth=0.5)
ax_acc.set_title('Training Accuracy')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.legend();

plt.tight_layout()

Things look reasonable on a first pass. The main thing to look for is validation loss rising (or validation accuracy falling) while the training metrics keep improving, which would be a classic sign of overfitting.
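One simple way to make this check explicit is to flag epochs where the training loss fell while the validation loss rose compared with the previous epoch. The helper `overfitting_epochs` and the loss values below are hypothetical; validation curves are noisy, so a single flagged epoch is weak evidence, while a sustained run of them is more concerning. In this notebook you could try it on `train_history['train_loss']` and `train_history['val_loss']`.

```python
def overfitting_epochs(train_loss, val_loss):
    """Return 0-based epochs where training loss fell but validation loss rose vs. the previous epoch."""
    flagged = []
    for epoch in range(1, len(train_loss)):
        train_improved = train_loss[epoch] < train_loss[epoch - 1]
        val_worsened = val_loss[epoch] > val_loss[epoch - 1]
        if train_improved and val_worsened:
            flagged.append(epoch)
    return flagged

# Made-up loss curves for illustration (not results from this notebook)
demo_train_loss = [1.2, 0.8, 0.6, 0.5, 0.4]
demo_val_loss = [1.1, 0.9, 0.85, 0.9, 0.95]
print(overfitting_epochs(demo_train_loss, demo_val_loss))  # [3, 4]
```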

# 7. Save the network
Docs on this: https://pytorch.org/tutorials/beginner/saving_loading_models.html 

This is a temporary solution until proper checkpoint saving and loading is built into the training function.

In [None]:
checkpoint_save_name = "resnet18_final.pth"
checkpoint_save_path = project_models_dir / checkpoint_save_name
checkpoint_save_path

In [None]:
torch.save({'model': trained_model,   # entire model object (pickled), not just its state_dict
            'optimizer': optimizer,  # entire optimizer object (pickled), including its state
            'epochs': num_epochs,  
            'train_loss': train_history['train_loss'],
            'val_loss': train_history['val_loss'],  
            'train_accuracy': train_history['train_topk_accuracy'],
            'val_accuracy': train_history['val_topk_accuracy'],
            'topk': topk,}, checkpoint_save_path) 

## Load network (optional)
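Once you have trained and saved the network, you can save time in later sessions by running steps 1 to 5, skipping training (step 6) and the save cell above, and then setting `load_model = True` below. This will not work across Colab sessions, because files saved to the Colab runtime's local disk are deleted when the runtime is reset.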

In [None]:
load_model = False
if load_model:
    # initialize things
    print("Loading final model checkpoint")
    checkpoint_load_name = "resnet18_final.pth"
    checkpoint_load_path = project_models_dir / checkpoint_load_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU if available, else CPU
    if not checkpoint_load_path.exists():
        raise FileNotFoundError(f"No checkpoint found at {checkpoint_load_path}")

    # load the data: weights_only=False is needed because whole model/optimizer objects were pickled
    # (only do this with checkpoints you trust); map_location lets a GPU-saved checkpoint load on CPU
    final_checkpoint = torch.load(checkpoint_load_path, map_location=device, weights_only=False)

    # Unpack values you want (TODO: cut some of these you don't use)
    trained_model = final_checkpoint['model']
    optimizer = final_checkpoint['optimizer']
    num_epochs = final_checkpoint['epochs']
    train_loss = final_checkpoint['train_loss']
    val_loss = final_checkpoint['val_loss']
    train_topk_accuracy = final_checkpoint['train_accuracy']
    val_topk_accuracy = final_checkpoint['val_accuracy']
    topk = final_checkpoint['topk']
else:
    print("Not loading model -- likely a training run")

# 8. Check model performance
First let's visually inspect model performance over some random images from the validation data. Then we'll look at some metrics calculated over the entire validation dataset.

## Inspect some predictions
We'll use a few deepglue convenience functions to sample some random validation images and predict their identity with the trained model:

In [None]:
rand_paths, rand_categories = dg.sample_random_images(data_dir, 
                                                      category_map, 
                                                      split_type='valid', 
                                                      num_images=10)
random_stack = dg.load_images_for_model(rand_paths, 
                                        data_transform(train=False)); # nondistorting transform
predicted_probs = dg.predict_batch(trained_model, random_stack, device=device); 

Using `dg.plot_prediction_grid()`, we'll plot each image with its top prediction on the left, and the model's `top_n` predictions with their probabilities on the right. Here we check out the top five predictions (`top_n=5`).

In [None]:
dg.plot_prediction_grid(random_stack, 
                        predicted_probs, 
                        rand_categories, 
                        category_map, 
                        top_n=5, 
                        figsize_per_plot=(2, 2), 
                        logscale=True);

## Metrics
Spot-checking individual images is useful (and the model does well on these), but we should also examine performance over the entire validation set.

Scikit-learn has many metrics we can use. We have already built a dataloader for the validation data above, and deepglue has a function that generates predictions for all of the data, given a dataloader and the trained model.

In [None]:
val_preds, val_labels, val_probs = dg.predict_all(trained_model, valid_data_loader, device=device);

With the predictions and the correct labels in hand, we can compute many metrics. We'll focus on a few basic ones for now.

### Confusion matrix
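The confusion matrix counts, for every pair of actual and predicted categories, how many validation images fall into that cell (rows are actual categories, columns are predicted categories). The main diagonal shows when the network is correct; off-diagonal counts reveal which categories it confuses.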

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay

cm_disp = ConfusionMatrixDisplay.from_predictions(val_labels,
                                                  val_preds,
                                                  display_labels=category_names,
                                                  xticks_rotation=45.,
                                                  cmap='magma');
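To see what `ConfusionMatrixDisplay` is tabulating, here is a minimal NumPy sketch that builds the same kind of count matrix from scratch (rows are actual categories, columns are predicted ones). The toy labels are made up for illustration; with the real data you would pass `val_labels.numpy()` and `val_preds.numpy()`, assuming `predict_all()` returns CPU tensors, as the `.numpy()` calls later in this notebook suggest.

```python
import numpy as np


def confusion_counts(y_true, y_pred, num_classes):
    """Count (actual, predicted) pairs: rows are actual categories, columns are
    predicted categories, matching sklearn's confusion matrix orientation."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when the same (row, col) pair repeats
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


# Made-up labels for 3 categories (not fashion3k results)
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0, 2])
toy_cm = confusion_counts(y_true, y_pred, num_classes=3)
print(toy_cm)
# [[1 1 0]
#  [0 2 0]
#  [1 0 2]]
print("correct:", np.trace(toy_cm), "of", toy_cm.sum())  # correct: 5 of 7
```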

### Classification report
Scikit-learn generates a useful classification report that gives the precision, recall, and F1 score for each category. Precision reflects false positives: for a given category, these are the off-diagonal entries in that category's column of the confusion matrix. Recall reflects false negatives: the off-diagonal entries in that category's row. F1 combines the two as their harmonic mean.

The classification report also aggregates these measures: overall accuracy (the proportion of all images classified correctly), the unweighted mean of each measure across categories ('macro avg'), and the mean weighted by the number of images in each category ('weighted avg', which is more informative for imbalanced data).

For more on the classification report, there is a useful discussion here: https://www.nb-data.com/p/breaking-down-the-classification. 
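The per-category scores and their averages can all be computed directly from a confusion matrix. The following NumPy sketch does this for a made-up 3-category matrix; it mirrors the definitions above (precision from columns, recall from rows, F1 as their harmonic mean) rather than reproducing scikit-learn's internal code.

```python
import numpy as np


def per_class_metrics(cm):
    """Per-category precision, recall, F1, and support from a confusion matrix
    whose rows are actual categories and columns are predicted categories."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)                  # correct predictions per category
    fp = cm.sum(axis=0) - tp          # off-diagonal entries in each column
    fn = cm.sum(axis=1) - tp          # off-diagonal entries in each row
    zeros = np.zeros_like(tp)
    # np.divide with `where` leaves 0 wherever the denominator is 0
    precision = np.divide(tp, tp + fp, out=zeros.copy(), where=(tp + fp) > 0)
    recall = np.divide(tp, tp + fn, out=zeros.copy(), where=(tp + fn) > 0)
    pr_sum = precision + recall
    f1 = np.divide(2 * precision * recall, pr_sum, out=zeros.copy(), where=pr_sum > 0)
    support = cm.sum(axis=1)          # number of actual images per category
    return precision, recall, f1, support


# Made-up 3-category confusion matrix (not fashion3k results)
toy_cm = np.array([[1, 1, 0],
                   [0, 2, 0],
                   [1, 0, 2]])
precision, recall, f1, support = per_class_metrics(toy_cm)

accuracy = np.trace(toy_cm) / toy_cm.sum()                  # overall proportion correct
macro_f1 = f1.mean()                                        # 'macro avg': unweighted mean
weighted_recall = (recall * support).sum() / support.sum()  # 'weighted avg': by support
print(precision, recall, f1, accuracy, macro_f1, weighted_recall)
```

One consequence worth noticing: the support-weighted recall always equals overall accuracy, because each category's recall times its support is just its number of correct predictions.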

In [None]:
from sklearn.metrics import classification_report

In [None]:
report = classification_report(val_labels, val_preds, target_names=category_names);
print(report)

Some categories (trousers, bags) are classified nearly perfectly, while others are really tough for the classifier, in particular shirts: their recall is quite low, which means lots of false negatives (images of shirts that were classified as something else). Is this because the network needs more training, because resnet18 is not up to the task, or because of something intrinsic to our dataset?

Before investing a ton of time tweaking parameters, let's visualize our data and the network's feature space to look for error patterns.

# 9. Visualize clustering in feature space
As just mentioned, while our network performs well overall, we have a decision to make. Should we tweak some hyperparameters? Bring in a fancier learning rate scheduler? Train for more epochs? Heavier augmentation in our data transforms might help. Or maybe we should switch to a bigger network like `resnet50`.

These are all reasonable options, but this notebook is meant to demo deepglue basics, not dive into the weeds of machine vision. Also, it is often helpful to visualize what's happening inside a network before spending 80% of your time eking out a 2% improvement in model performance. PyTorch's torchvision library provides useful tools for extracting features from networks, and deepglue has utilities for visualizing the extracted features.

## Feature extraction
Feature extraction works by pushing images through the network and recording the activation patterns those images evoke in the layers of interest.

During normal network operation, we only care about the network's final output: its prediction. But for analyzing networks, torchvision provides `create_feature_extractor()`, which wraps a model so that it returns the activation values from *any layer* you designate rather than just the final output. Feeding images to the wrapped model, batch by batch from a data loader, then yields the activations of the designated layers for every image.

We will therefore create a data loader specifically for feature extraction: deepglue's `prepare_ordered_data()` creates a dataloader that goes in order through all the data in a data split. The function returns the image paths and the data loader (the image paths will become important for visualizing the data).

In [None]:
image_paths, ordered_loader = dg.prepare_ordered_data(data_dir, 
                                                      image_transforms['valid'], 
                                                      num_workers=num_workers, 
                                                      batch_size=56, 
                                                      split_type='valid')
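Why does the order matter? The extracted features are stacked batch by batch, so row *i* of the feature matrix only corresponds to `image_paths[i]` if the loader never shuffles. The sketch below illustrates this with a toy `TensorDataset` and a plain PyTorch `DataLoader` with `shuffle=False`; that `prepare_ordered_data()` works this way internally is an assumption based on its description.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset (hypothetical): 10 one-pixel "images" whose value is their index
toy_images = torch.arange(10, dtype=torch.float32).view(10, 1, 1, 1)
toy_labels = torch.arange(10) % 3
toy_dataset = TensorDataset(toy_images, toy_labels)

# shuffle=False: batches come out in dataset order, so after concatenation
# row i corresponds to dataset item i (and therefore to image_paths[i])
toy_ordered = DataLoader(toy_dataset, batch_size=4, shuffle=False)
order_seen = torch.cat([x.flatten() for x, _ in toy_ordered]).tolist()
print(order_seen)  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
```

With `shuffle=True` the order would change every epoch, and the correspondence between feature rows and file paths would be lost.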

Next, we have to decide which layers we want to extract features from. 

In torchvision, the layers to extract features from are specified with a dictionary, by convention called `return_nodes`: keys are the actual layer (node) names in the model, and values are the names we give their outputs for access later:

In [None]:
return_nodes = {'relu': 'cnn_features',                  # Early traditional CNN features
                'layer1.1.relu': 'resnetL1_features',
                'layer2.1.relu': 'resnetL2_features',
                'layer3.1.relu': 'resnetL3_features',
                'layer4.1.relu': 'resnetL4_features'}
pprint(return_nodes)

To list the names of all the modules in the network, which is where these node names come from, you could run the following code (expect a long printout):

    for name, module in trained_model.named_modules():
        print(name)

We are going to keep it simple and examine activity late in the network, to visualize clustering of higher-level, more abstract features. Note that if you were using `resnet50`, there would be *many* more layers to choose from.

Note that the earliest node (`relu`) comes right after the first convolutional layer (and its batch norm), before any residual blocks. It provides access to very low-level features, but beware: it has *lots* of dimensions, so you would need lots of RAM. The other four are taken from the last residual block in each of resnet18's four residual stages (`layer1` through `layer4`). We expect higher-level fashion features in `layer4`, so we will extract those in what follows.

Torchvision's `create_feature_extractor()` just needs our trained model and the dictionary of return nodes.

In [None]:
trained_model.eval();
feature_extractor = create_feature_extractor(trained_model.to(device), 
                                             return_nodes=return_nodes)

Things are all set up. We can pick the layer (or layers) we want, and run feature extraction. This can take a while because it is running through the entire validation dataset as it extracts the activation patterns from the selected layer.

In [None]:
layer = 'resnetL4_features'

In [None]:
features, labels = dg.extract_features(ordered_loader, 
                                       feature_extractor, 
                                       layer=layer,
                                       device=device)

The extracted features are fairly high-dimensional (over 25k dimensions per image). Even so, this is much lower than what we would get by extracting features from earlier layers in the network.

In [None]:
features.shape, labels.shape
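Where does the 25k come from? Assuming 224x224 inputs, each node's output is a (channels, height, width) activation map, and flattening it gives channels times height times width dimensions. The shapes below are the standard resnet18 ones under that assumption; they show why earlier layers, with their larger spatial maps, would be far bigger.

```python
# Flattened feature sizes for the resnet18 nodes in return_nodes,
# ASSUMING 224x224 inputs: (channels, height, width) of each activation map
node_shapes = {
    'relu':          (64, 112, 112),  # stem: stride-2 conv halves 224 -> 112
    'layer1.1.relu': (64, 56, 56),    # stride-2 max pool halves again
    'layer2.1.relu': (128, 28, 28),   # each later stage halves spatial size...
    'layer3.1.relu': (256, 14, 14),   # ...and doubles the channel count
    'layer4.1.relu': (512, 7, 7),
}
flat_dims = {name: c * h * w for name, (c, h, w) in node_shapes.items()}
for name, dims in flat_dims.items():
    print(f"{name:>14}: {dims:>7,} dimensions")
```

Under this assumption, the early `relu` node would give 32 times as many dimensions per image as `layer4`, which is why it needs so much more RAM.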

## Dimensionality reduction and visualization
It is hard to visualize tens of thousands of features. We will now use UMAP to project these high-dimensional activity patterns into a 2D space, to see how well they cluster by category. We will visualize the result with a static plot (using Matplotlib) and an interactive plot (using Bokeh).

The following UMAP-based visualizations are adapted from a UMAP demo: https://umap-learn.readthedocs.io/en/latest/basic_usage.html

In [None]:
umap_model = umap.UMAP(n_components=2, 
                       n_epochs=200, 
                       low_memory=True,
                       verbose=True)

UMAP is itself an iterative ML algorithm, so the following can take a minute:

In [None]:
%%time
layer_features_umap = umap_model.fit_transform(features)

<div class="alert alert-info">
<h3 style="margin: 1px 0 6px 0;">Other dimensionality reduction techniques</h3>

If you want to compare UMAP to PCA or other dimensionality reduction techniques, they should drop easily into the code. For instance:
<code>
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
layer_features_pca = pca.fit_transform(features)
</code>
    
You can then just replace `layer_features_umap` with `layer_features_pca` in the code that follows. PCA is much faster, and because it simply projects onto the directions of greatest variance, it is easier to interpret (the clusters won't look as nice, though).
</div>
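For intuition about what the PCA alternative does, here is a minimal NumPy sketch: center the features, take a thin SVD, and project onto the top right singular vectors (the directions of greatest variance). It runs on random toy data standing in for the feature matrix; scikit-learn's `PCA` does essentially this, though its solver choices and component sign conventions may differ.

```python
import numpy as np


def pca_project(features, n_components=2):
    """Project each row of `features` onto the top principal components.

    Returns the projected points and the fraction of total variance each
    component explains."""
    X = np.asarray(features, dtype=float)
    X = X - X.mean(axis=0)                      # center every feature dimension
    # Thin SVD: rows of Vt are the principal axes, ordered by singular value
    _, S, Vt = np.linalg.svd(X, full_matrices=False)
    variance = S**2 / (X.shape[0] - 1)          # variance along each axis
    explained_ratio = variance[:n_components] / variance.sum()
    return X @ Vt[:n_components].T, explained_ratio


# Random toy data standing in for the (num_images, num_features) matrix
rng = np.random.default_rng(0)
toy_features = rng.normal(size=(200, 50))
toy_proj, toy_ratio = pca_project(toy_features)
print(toy_proj.shape, toy_ratio)
```

The explained-variance ratios are part of what makes PCA easier to interpret: they tell you how much of the feature variance the 2D picture actually captures.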

### Static visualization
Let's first make a static visualization of the feature clusters using Matplotlib.

In [None]:
plt.figure(figsize=(5, 5))

# Scatter plot of the UMAP-projected features (one point per validation image)
scatter = plt.scatter(
    layer_features_umap[:, 0],  # x-coordinates from UMAP
    layer_features_umap[:, 1],  # y-coordinates from UMAP
    c=labels,           # Color code by labels
    cmap=plt.cm.tab10,              
    s=10,                          
    alpha=0.7                      
)

# Adding a colorbar for the labels
colorbar = plt.colorbar(scatter, boundaries=np.arange(num_classes + 1) - 0.5)
colorbar.set_ticks(np.arange(num_classes))
colorbar.set_ticklabels(category_names)
plt.xticks([])
plt.yticks([])

# Title and labels
plt.title('UMAP Projection of Fashion')
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')

# Show plot
plt.axis('equal')
plt.grid(True)
plt.show()

We can see patterns that match what we saw in our classification report: trousers and bags form clearly distinct clusters. Shirts (in pink) are mixed in with other categories: the network just hasn't learned features that separate them out. However, it would be nice to have more detail. For instance, which of those points were classified correctly or incorrectly? What do the shirts that lie close to t-shirts actually look like? Could *we* tell them apart? Let's build an interactive visualization that gives us such details.

### Interactive feature visualization
We will build an interactive scatter plot that shows a small image sprite when we hover over a point in the plot. Here is what one of these embeddable sprites looks like:

In [None]:
ind_to_embed = 7
encoded_image_str = dg.create_embeddable_image(image_paths[ind_to_embed], size=(60,60))
HTML(f'<img src="{encoded_image_str}" />')

Is the above image a shirt or a t-shirt? What criteria were used to define these categories anyway? These are things we can dig into a bit in what follows.
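A sprite like this is just an image file encoded as a base64 *data URI*, which HTML (and Bokeh's hover tooltips) can display inline without a separate file. The UMAP demo linked earlier builds its sprites in a similar way; the stdlib-only sketch below shows just the encoding step. `to_data_uri` is a hypothetical helper, and whether `dg.create_embeddable_image()` does exactly this (plus resizing to `size`) is an assumption.

```python
import base64


def to_data_uri(image_bytes, mime="image/png"):
    """Encode raw image-file bytes as a data URI usable directly in <img src=...>."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return f"data:{mime};base64,{encoded}"


# Illustration only: the 8-byte signature every PNG file starts with
png_signature = b"\x89PNG\r\n\x1a\n"
uri = to_data_uri(png_signature)
print(uri)  # data:image/png;base64,iVBORw0KGgo=
```

For a real image you would pass the file's bytes (e.g., `open(path, 'rb').read()`); the resulting string can go straight into an `<img src=...>` tag, as in the `HTML(...)` call used for the sprite.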

To get predictions that line up with the ordered data loader we used for feature extraction, we need to extract predictions using that same ordered data loader:

In [None]:
ordered_predictions, ordered_labels, _ = dg.predict_all(trained_model, ordered_loader, device=device);

In [None]:
actual = ordered_labels[ind_to_embed].item()         # tensor -> int
predicted = ordered_predictions[ind_to_embed].item()
print(f"BTW, the above was a {category_names[actual]}")
print("How did the network do?")
print(f"Actual category: {actual} ({category_names[actual]}); predicted category: {predicted} ({category_names[predicted]})")

We can use deepglue's `plot_interactive_projection()` to combine everything we want in a single plot:
- A scatter plot of the features projected into UMAP space, where hovering over a point shows the sprite of the corresponding image.
- o's mark correct predictions and x's mark incorrect predictions, so you can see which features confused the network.
- You can interact with the plot by zooming and selecting regions you are interested in (use the toolbar on the right to choose between tools such as wheel zoom and box select).
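If you want a static version of the correct/incorrect split, or a quick count of which mistakes are most common, the bookkeeping is simple. This NumPy sketch uses made-up toy data; with the real data you would pass `layer_features_umap`, `ordered_labels.numpy()`, `ordered_predictions.numpy()`, and `category_names`. The helper names are hypothetical, not part of deepglue.

```python
from collections import Counter

import numpy as np


def split_by_correctness(points, labels, preds):
    """Split 2D projected points into correctly and incorrectly classified sets
    (the o's and x's in the interactive plot)."""
    correct = np.asarray(labels) == np.asarray(preds)
    return points[correct], points[~correct]


def top_confusions(labels, preds, names, k=3):
    """Most frequent (actual, predicted) mistakes, with their counts."""
    mistakes = Counter((names[a], names[p]) for a, p in zip(labels, preds) if a != p)
    return mistakes.most_common(k)


# Made-up toy data: 6 projected points, 3 categories
toy_points = np.arange(12, dtype=float).reshape(6, 2)
toy_labels = np.array([0, 0, 1, 1, 2, 2])
toy_preds = np.array([0, 1, 1, 0, 2, 1])
toy_names = ['t-shirt', 'shirt', 'pullover']

good_pts, bad_pts = split_by_correctness(toy_points, toy_labels, toy_preds)
print(len(good_pts), "correct,", len(bad_pts), "incorrect")
print(top_confusions(toy_labels, toy_preds, toy_names))
```

The two point sets can then go to separate `plt.scatter()` calls with `marker='o'` and `marker='x'`.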

In [None]:
dg.plot_interactive_projection(layer_features_umap, 
                               ordered_labels.numpy(), 
                               image_paths, # to show sprites
                               category_map,
                               predictions=ordered_predictions.numpy(), # to differentiate correct/incorrect
                               title='UMAP Projection', 
                               image_size=(75, 75), # size of sprite popups
                               plot_size=500, # x/y dims of plot
                               legend_location="top_left",
                               show_in_notebook=True) # if not True, will pop up plot in new tab

With this visualization we can start to see the range of visual features that matter for each category, and how categories blend into each other in ways that make category membership very hard to determine (even for us). This kind of information can be very helpful when deciding what direction to take a project. We also see what differentiates t-shirts from shirts: canonical shirts have buttons and t-shirts do not, though this distinction is not perfectly honored in this messy real-world dataset.

# 10. Test data

So far, we have deliberately held off on analyzing the test data. To avoid data leakage, we keep the test data sealed off until we are sure we have settled on the final comparisons we want to make.

Once you have put your candidate models through the wringer and narrowed things down to a small set (e.g., your best resnet, VGG, or whatever, each with its best augmentation and learning rate schedule), it's time for the final test: evaluate them on the test data to estimate how well each generalizes beyond *all* of the data used during development (both the training and validation splits). Performance will probably be somewhat worse than on the validation data, which we used to tweak the model multiple times. That's why we set aside the test data and leave it in a lock box until the very end.

I'm deliberately leaving this step out of the notebook so you can try different models (e.g., `resnet50`), add different augmentations, and tweak things to improve the model(s). You can then apply the metrics we used on the validation data above (accuracy, recall, etc.) to the test data to see how well things generalize.

This is left as an exercise for the reader.

<div class="alert alert-info">
<h3 style="margin: 1px 0 6px 0;">What about fashion mnist?</h3>

If you want to use the above tools on the full Fashion MNIST dataset with all 70,000 images, it is available at the following link:

<code>full_data_url = r'https://drive.google.com/uc?id=1B15ViE9lKquepM2TCvApe9gGPeQIxiUS'</code>

You will have to adjust some of the above variable names from Section 2 of this notebook (e.g., `fashion_3k` will become `fashion_mnist` or something like that), but this will be a good exercise in working with multiple datasets within a project. 

Using the full dataset will yield much better metrics on your classifier.
</div>