# Lesson 21

[Course Repo](https://github.com/fastai/course22p2)



## WandB demo

Jono: Demo of WandB.  (notebook 21_cifar10_and_wandb.ipynb) 

Main points:

- Uses CIFAR10 instead of Fashion-MNIST. The tools still work.
- Weights and Biases (WandB) monitoring:
    -  [WandB](https://wandb.ai/site) for monitoring and tracking of model performance. 
    -  Call logging functions in your code to log metrics, hyperparameters, and other data to the WandB server, and then you can view and compare these results in the WandB dashboard.
    -  You can also use WandB to log models, code and more.
- Jono implements a WandB callback in a few lines to do this.
- Jeremy says he doesn't use this (intentionally) because he worries about just doing wide sweeps of hyperparameters.  He prefers to do a more focused search.
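
As a rough illustration of the logging pattern described above, here is a minimal sketch of a logging callback. It is not Jono's callback from the notebook: the class name `WandBLoggerSketch`, the `log_fn` parameter, and the project name are hypothetical. It assumes a learner that exposes a `loss` attribute and calls `before_fit`, `after_batch`, and `after_fit` on its callbacks. The only WandB calls used are `wandb.init`, `wandb.log`, and `wandb.finish`.

```python
from types import SimpleNamespace


class WandBLoggerSketch:
    """Hypothetical callback that forwards the training loss to WandB."""
    def __init__(self, config, project="cifar10_demo", log_fn=None):
        self.config, self.project = config, project
        self.log_fn = log_fn   # injectable, so the sketch also runs without a WandB account
        self.run = None

    def before_fit(self, learn):
        if self.log_fn is None:
            import wandb       # imported lazily; only needed for real logging
            self.run = wandb.init(project=self.project, config=self.config)
            self.log_fn = wandb.log

    def after_batch(self, learn):
        self.log_fn({"train_loss": float(learn.loss)})

    def after_fit(self, learn):
        if self.run is not None:
            import wandb
            wandb.finish()


# Offline usage: collect the logged dictionaries in a list instead of sending them.
logged = []
cb = WandBLoggerSketch(config={"lr": 1e-3}, log_fn=logged.append)
fake_learn = SimpleNamespace(loss=0.5)   # stand-in for a real learner
cb.before_fit(fake_learn)
cb.after_batch(fake_learn)
cb.after_fit(fake_learn)
print(logged)
```

Leaving `log_fn` unset makes the callback start a real WandB run in `before_fit`, which requires the `wandb` package and a login.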

## Image quality metrics

- We need a metric to evaluate the quality of the generated images, to compare different models.

- Fréchet Inception Distance (FID), demo in notebook 18_Fid.ipynb. Uses the DDPM model from the previous lesson.
   - FID is a metric for generative models.
   - It measures how similar the distribution of generated images is to the distribution of real images (lower is better).
   - It looks at what the 'typical' final-layer activations of a classifier look like for each set of images.


In [None]:

import pickle,gzip,math,os,time,shutil,torch,random
import fastcore.all as fc,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager
from scipy import linalg

from fastcore.foundation import L
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder

from minai.datasets import *
from minai.conv import *
from minai.learner import *
from minai.activations import *
from minai.init import *
from minai.sgd import *
from minai.resnet import *
from minai.augment import *
from minai.accel import *

In [None]:
from fastcore.test import test_close
from torch import distributions

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'

import logging
logging.disable(logging.WARNING)

set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8

### Use existing classifier with classifier head removed



In [2]:
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 512

@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))*2-1 for o in b[xl]]

dsd = load_dataset(name)
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

b = xb,yb = next(iter(dls.train))

cbs = [DeviceCB(), MixedPrecision()]
# I don't have a data_aug2 model, so hopefully this one will work
model = torch.load('models/data_aug.pkl')
learn = Learner(model, dls, F.cross_entropy, cbs=cbs, opt_func=None)




In [4]:
model

Sequential(
  (0): ResBlock(
    (convs): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GeneralRelu()
      )
      (1): Sequential(
        (0): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (idconv): Sequential(
      (0): Conv2d(1, 16, kernel_size=(1, 1), stride=(1, 1))
    )
    (act): GeneralRelu()
  )
  (1): ResBlock(
    (convs): Sequential(
      (0): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GeneralRelu()
      )
      (1): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): 

We want the output without the classifier head.  

In [5]:
del(learn.model[8])
del(learn.model[7])

Could also use callbacks to do this.
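
One way to read "use callbacks" is to leave the model intact and capture the pooled activations with a forward hook. The sketch below uses PyTorch's `register_forward_hook` on a toy model. The toy architecture, its layer sizes, and the names `toy`, `save_output`, and `feats_toy` are assumptions for illustration, not the course's classifier or its hooks callback.

```python
import torch
from torch import nn

# Toy stand-in for the classifier; the layer sizes are arbitrary assumptions.
toy = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),   # pooled features come out of index 3
    nn.Linear(8, 10),                        # classifier head
)

captured = []
def save_output(module, inputs, output):
    # Called after the hooked module's forward pass; keep a detached copy of its output.
    captured.append(output.detach())

handle = toy[3].register_forward_hook(save_output)
with torch.no_grad():
    logits = toy(torch.randn(4, 1, 32, 32))
handle.remove()   # remove the hook once the features are captured

feats_toy = captured[0]
print(feats_toy.shape, logits.shape)
```

The model still produces its usual logits, while `feats_toy` holds the activations that would otherwise require deleting the head.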

In [6]:
feats,y = learn.capture_preds()
feats = feats.float()
feats.shape,y



(torch.Size([10000, 512]), tensor([9, 2, 1,  ..., 8, 1, 5]))

### Fréchet Inception Distance (FID)


NOTE: At 31:00 he discovers "a bug" in the images scaling.  
 
- Uses the mean and covariance matrix of the global pooling layer activations of the classification model, computed across a set of samples.
- These statistics are compared between two sets of images (generated and real) to get a distance metric.
- The idea is that the pooling layer has an activation for each of various features in the images. If some features tend to occur together (e.g. ears and eyes), that shows up in the covariance matrix, so the metric compares not just how often features appear but also how they co-occur.
 
$$
d_{F}(\mathcal N(\mu, \Sigma), \mathcal N(\mu', \Sigma'))^2 = \lVert \mu - \mu' \rVert^2_2 + \operatorname{tr}\left(\Sigma + \Sigma' -2\left(\Sigma \Sigma'  \right)^\frac{1}{2} \right)
$$

Note the need of the matrix square root. 
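
The formula can be sketched in a few lines of NumPy and SciPy. This is a minimal illustration rather than the notebook's implementation: the function name `frechet_distance` and the random test features are assumptions, and `scipy.linalg.sqrtm` is used for the matrix square root.

```python
import numpy as np
from scipy import linalg

def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fitted to two (n_samples, n_features) arrays."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    covmean = linalg.sqrtm(cov_a @ cov_b)
    covmean = np.real(covmean)   # sqrtm can return tiny imaginary parts from round-off
    return float(((mu_a - mu_b) ** 2).sum() + np.trace(cov_a + cov_b - 2 * covmean))

rng = np.random.default_rng(0)
real = rng.normal(size=(500, 4))                  # stand-in for classifier features
fd_same = frechet_distance(real, real)            # identical sets: distance close to 0
fd_shifted = frechet_distance(real, real + 2.0)   # same covariance, means shifted by 2
print(fd_same, fd_shifted)
```

Shifting every feature by 2 leaves the covariance term at zero, so the distance comes entirely from the squared mean difference, 4 features times 2 squared.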

* Called Inception because the original paper used the Inception model.

* The primary caveat is that FID depends on the number of samples used, so be careful comparing FID scores computed on different sized data sets. It is not a universal metric.

* FID is a good metric for comparing different models on the same data set.


Now we need our generator again.

In [None]:
betamin,betamax,n_steps = 0.0001,0.02,1000
beta = torch.linspace(betamin, betamax, n_steps)
alpha = 1.-beta
alphabar = alpha.cumprod(dim=0)
sigma = beta.sqrt()

def noisify(x0, ᾱ):
    device = x0.device
    n = len(x0)
    t = torch.randint(0, n_steps, (n,), dtype=torch.long)
    ε = torch.randn(x0.shape, device=device)
    ᾱ_t = ᾱ[t].reshape(-1, 1, 1, 1).to(device)
    xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε
    return (xt, t.to(device)), ε

def collate_ddpm(b): return noisify(default_collate(b)[xl], alphabar)
def dl_ddpm(ds): return DataLoader(ds, batch_size=bs, collate_fn=collate_ddpm, num_workers=4)

dls2 = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))


from diffusers import UNet2DModel

class UNet(UNet2DModel):
    def forward(self, x): return super().forward(*x).sample


smodel = torch.load('models/fashion_ddpm_mp.pkl').cuda()

In [None]:
@torch.no_grad()
def sample(model, sz, alpha, alphabar, sigma, n_steps):
    device = next(model.parameters()).device
    x_t = torch.randn(sz, device=device)
    preds = []
    for t in reversed(range(n_steps)):
        t_batch = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
        z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(device)
        ᾱ_t1 = alphabar[t-1]  if t > 0 else torch.tensor(1)
        b̄_t = 1 - alphabar[t]
        b̄_t1 = 1 - ᾱ_t1
        x_0_hat = ((x_t - b̄_t.sqrt() * model((x_t, t_batch)))/alphabar[t].sqrt())
        x_t = x_0_hat * ᾱ_t1.sqrt()*(1-alpha[t])/b̄_t + x_t * alpha[t].sqrt()*b̄_t1/b̄_t + sigma[t]*z
        preds.append(x_0_hat.cpu())
    return preds
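
For reference, each iteration of the loop in `sample` corresponds to the update below, written in the usual DDPM notation as a reading of the code above. Here $\epsilon_\theta$ denotes the model's noise prediction, $\beta_t = 1-\alpha_t$, $\sigma_t = \sqrt{\beta_t}$, $z$ is standard Gaussian noise (zero at $t=0$), and $\bar{\alpha}_{t-1}$ is taken to be 1 at $t=0$.

$$
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}, \qquad
x_{t-1} = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,\hat{x}_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t + \sigma_t z
$$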


Jeremy shows an example for the case at hand, with 512 channels (around 38 minutes).

## KID 50:00

KID: Kernel Inception Distance.

- Uses the features directly, not the means across the set.
- "The math doesn't matter"
- A measure of similarity between the distributions of the features.
- Has low bias but high variance.
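
Even if the math does not matter much in practice, a small sketch shows what is being computed. KID is based on the squared maximum mean discrepancy (MMD) between two feature sets, usually with a cubic polynomial kernel. The function names and random features below are assumptions for illustration, and the step of averaging over several subsets, which KID implementations commonly do, is left out.

```python
import numpy as np

def poly_kernel(a, b):
    """Cubic polynomial kernel (a.b / d + 1)^3 between rows of a and rows of b."""
    return (a @ b.T / a.shape[1] + 1) ** 3

def squared_mmd(x, y):
    """Unbiased estimate of the squared MMD between (m, d) and (n, d) feature arrays."""
    m, n = len(x), len(y)
    kxx, kyy, kxy = poly_kernel(x, x), poly_kernel(y, y), poly_kernel(x, y)
    # Drop the diagonal (each sample paired with itself) for the unbiased estimate.
    sum_xx = kxx.sum() - np.trace(kxx)
    sum_yy = kyy.sum() - np.trace(kyy)
    return float(sum_xx / (m * (m - 1)) + sum_yy / (n * (n - 1)) - 2 * kxy.mean())

rng = np.random.default_rng(0)
a, a2 = rng.normal(size=(200, 8)), rng.normal(size=(200, 8))   # same distribution
b = rng.normal(size=(200, 8)) + 1.0                            # shifted distribution
kid_same, kid_shifted = squared_mmd(a, a2), squared_mmd(a, b)
print(kid_same, kid_shifted)
```

Because the estimate is unbiased, two sets from the same distribution give a value near zero that can come out slightly negative, which reflects the high variance noted above.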

Jeremy creates a class for this that returns both FID and KID. He shows, for example, how the images improve during the denoising process. KID and FID follow the same trend.

Jeremy also compares against the 'real' FID, computed with the Inception network rather than the custom classifier.

## Fixing the scaling bug 1:00

- Jeremy noticed that back in the DDPM_v2 notebook, the images were scaled from 0 to 1 instead of -1 to 1. He fixes this:
```
@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))*2-1 for o in b[xl]]
```

- This makes everything worse, so he spent two days trying to find the other bug that must be responsible and was being offset by this one.

- In the end he finds no relevant bug and asks: why is it a bug to scale from 0 to 1? Just because everyone else scales from -1 to 1? He then tried just subtracting 0.5, which centers the data while keeping the width of the range the same (-0.5 to 0.5, smaller than the usual -1 to 1). This leads to DDPM_v3. This is an improvement to the model, even according to the FID score.
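
The three scalings discussed above differ only in offset and width, which the following sketch makes concrete. The random array stands in for the output of `TF.to_tensor`, and the variable names are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((1, 28, 28))   # stand-in for TF.to_tensor output, values in [0, 1)

scaled_01 = img                 # DDPM_v2 behaviour: range 0 to 1
scaled_pm1 = img * 2 - 1        # the conventional scaling: range -1 to 1
scaled_half = img - 0.5         # the DDPM_v3 choice: range -0.5 to 0.5

for name, arr in [("0..1", scaled_01), ("-1..1", scaled_pm1), ("-0.5..0.5", scaled_half)]:
    print(f"{name:>9}: min={arr.min():.2f} max={arr.max():.2f} mean={arr.mean():.2f}")
```

Subtracting 0.5 centers the data near zero like the conventional scaling does, but keeps the spread of the original 0 to 1 data.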
 

### Schedule experiments 1:09

* Uses notebook 19 (DDPM_v3).

* As part of his debugging he started to question everything, for example the $\beta$ schedule.

* He tested a cosine schedule against the linear schedule. Remember that $\bar{\alpha}$ is what really matters.

Note that the linear schedule has a long stretch of timesteps where $\bar{\alpha}$ is near zero. So he also tried decreasing $\beta_{max}$. With the lower $\beta_{max}$, the linear and cosine curves are in fact similar. So in the next version he changed $\beta_{max}$ to 0.01. Results do look better. He also changed the model by making it bigger and training for longer.
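
The effect of $\beta_{max}$ on $\bar{\alpha}$ can be checked numerically. The sketch below is an illustration under stated assumptions: the cosine schedule is a simplified $\cos^2$ form without an offset term, the 0.01 threshold for "near zero" is arbitrary, and the function names are hypothetical.

```python
import numpy as np

def linear_alphabar(beta_max, beta_min=0.0001, n_steps=1000):
    """Cumulative product of (1 - beta) for a linear beta schedule."""
    beta = np.linspace(beta_min, beta_max, n_steps)
    return np.cumprod(1.0 - beta)

def cosine_alphabar(n_steps=1000):
    """Simplified cosine schedule: alphabar(t) = cos^2(t/T * pi/2), no offset term."""
    t = np.arange(n_steps) / n_steps
    return np.cos(t * np.pi / 2) ** 2

def frac_near_zero(alphabar, thresh=0.01):
    """Fraction of timesteps where alphabar is below thresh (almost pure noise)."""
    return float((alphabar < thresh).mean())

frac_lin_02 = frac_near_zero(linear_alphabar(0.02))   # original beta_max
frac_lin_01 = frac_near_zero(linear_alphabar(0.01))   # reduced beta_max
frac_cos = frac_near_zero(cosine_alphabar())
print(frac_lin_02, frac_lin_01, frac_cos)
```

With $\beta_{max}=0.02$ roughly a third of the timesteps fall below the threshold, while the reduced linear schedule and the cosine schedule each spend only a small fraction of steps there.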

The FID is nearly as good as that of real images.


* Skip sampling: call the model on only a subset of the timesteps.

In [19]:
n_steps  = 1000
[t for t in range(n_steps) if (t+101)%((t+101)//100)==0][290:]

[949, 959, 969, 979, 989, 999]
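
To see what this rule does, the sketch below counts how many timesteps are selected and how far apart they are at each end of the range. The helper name `is_sampled` and the gap variables are introduced here for illustration. The rule itself is the one from the cell above.

```python
n_steps = 1000

def is_sampled(t):
    """True if timestep t is selected by the skip rule from the cell above."""
    step = (t + 101) // 100        # grows from 1 (small t) to 11 (t = 999)
    return (t + 101) % step == 0

sample_at = [t for t in range(n_steps) if is_sampled(t)]
n_calls = len(sample_at)
gap_low = sample_at[1] - sample_at[0]      # spacing near t = 0
gap_high = sample_at[-2] - sample_at[-3]   # spacing near t = 999
print(n_calls, gap_low, gap_high)          # → 296 1 10
```

So only 296 of the 1000 timesteps are selected. Every step is kept at low `t`, where the final details are refined, and only every tenth step is kept at high `t`, where the image is still mostly noise.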

In [None]:

@torch.no_grad()
def sample2(model, sz):
    ps = next(model.parameters())
    x_t = torch.randn(sz).to(ps)
    sample_at = {t for t in range(n_steps) if (t+101)%((t+101)//100)==0}
    preds = []
    for t in reversed(range(n_steps)):
        t_batch = torch.full((x_t.shape[0],), t, device=ps.device, dtype=torch.long)
        z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(ps)
        ᾱ_t1 = alphabar[t-1]  if t > 0 else torch.tensor(1)
        b̄_t = 1-alphabar[t]
        b̄_t1 = 1-ᾱ_t1
        if t in sample_at: noise = model((x_t, t_batch))
        x_0_hat = ((x_t - b̄_t.sqrt() * noise)/alphabar[t].sqrt())
        x_t = x_0_hat * ᾱ_t1.sqrt()*(1-alpha[t])/b̄_t + x_t * alpha[t].sqrt()*b̄_t1/b̄_t + sigma[t]*z
        if t in sample_at: preds.append(x_t.float().cpu())
    return preds

He also has another variant that skips steps on a different schedule.

The FID score is still pretty good.

### DDIM 1:20

What's the best paper for faster generation?