# Section 4 - Computer vision-based machine learning 
## Introduction to PyTorch models

## Dr. Antonin Vacheret (avachere@imperial.ac.uk) 
## High Energy Physics Group
## 523 Blackett Lab

A quick run through some PyTorch basics, starting with an exploration of the pre-trained models readily available in torchvision.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

import torch
torch.__version__

## I. Pre-trained legacy computer vision classifier models

In [None]:
from torchvision import models
dir(models)

This is the famous AlexNet model that shook the field of machine learning in 2012:
https://proceedings.neurips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf

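Note: the lowercase names (e.g. models.alexnet) are functions that build a fixed architecture and can load pretrained weights, while the capitalized names (e.g. models.AlexNet) are the model classes themselves, which create a network with randomly initialized weights.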

In [None]:
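alexnet_function = models.AlexNet() # the "empty shell" of AlexNet: the architecture with randomly initialized weights
alexnet_trained = models.alexnet(pretrained=True) # fixed architecture with weights pretrained on ImageNet
# in torchvision >= 0.13 the preferred form is models.alexnet(weights=models.AlexNet_Weights.DEFAULT)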

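This one is ResNet-101, where ResNet stands for residual network; this is the 101-layer version.
https://arxiv.org/abs/1512.03385
It beat several benchmarks in 2015, winning the ILSVRC 2015 ImageNet classification challenge, and showed that very deep networks can be trained effectively thanks to residual (skip) connections. The pretrained weights come from training on ImageNet: 1.2M images in 1000 categories.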


In [None]:
resnet = models.resnet101(pretrained=True) # beware: the weights download can take a few minutes

Let's take a look at a high-definition picture of a dog. You can replace it with one of your own.

In [None]:
from PIL import Image
img = Image.open("img/mydoge.jpg").convert("RGB") # make sure the image has 3 channels, e.g. if you swap in a PNG

In [None]:
img

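Before feeding the image to the network, we define a chain of transformations (a very powerful feature of PyTorch!) to preprocess it and get the right input size: resize it so the shorter side is 256 pixels, crop the central 224x224 region, convert it to a tensor, and normalize each channel with the ImageNet mean and standard deviation.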

In [None]:
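from torchvision import transforms
preprocess = transforms.Compose([
        transforms.Resize(256),        # resize so that the shorter side is 256 pixels
        transforms.CenterCrop(224),    # keep the central 224x224 region
        transforms.ToTensor(),         # PIL image -> float tensor of shape (3, H, W) with values in [0, 1]
        transforms.Normalize(          # standardize each channel with the ImageNet mean and std
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )])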
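transforms.Normalize subtracts a per-channel mean and divides by a per-channel standard deviation, output[c] = (input[c] - mean[c]) / std[c]. The NumPy sketch below reproduces that arithmetic on hand-made arrays (the normalize helper is our own, not part of torchvision), showing that values in [0, 1] end up roughly centred on zero.

```python
import numpy as np

# ImageNet per-channel statistics, as in the preprocess pipeline above
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize(x):
    """Standardize a (3, H, W) array channel by channel: (x - mean) / std."""
    return (x - mean) / std

# A 3x2x2 "image" whose pixels all equal the channel means maps to all zeros
at_mean = normalize(np.broadcast_to(mean, (3, 2, 2)))
# A single white pixel (all channels 1.0) maps to roughly [2.25, 2.43, 2.64]
white = normalize(np.ones((3, 1, 1))).ravel()
print(at_mean.max(), white)
```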

In [None]:
img_t = preprocess(img)

In [None]:
img_t

In [None]:
plt.imshow(img_t[2,:,:]) # show channel 2 (blue) of the normalized tensor as a single 2D map
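plt.imshow expects channels last, shape (H, W, 3), with values in [0, 1], while img_t is channels first and normalized. A minimal sketch, assuming you want to see the full-colour preprocessed crop: the hypothetical helper to_displayable inverts the normalization and reorders the axes with NumPy, and a random array stands in for img_t.

```python
import numpy as np

def to_displayable(img_chw, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo ImageNet normalization and move channels last for plt.imshow.

    img_chw: array of shape (3, H, W), e.g. img_t.numpy()
    returns: array of shape (H, W, 3) with values clipped to [0, 1]
    """
    mean = np.asarray(mean).reshape(3, 1, 1)
    std = np.asarray(std).reshape(3, 1, 1)
    img = img_chw * std + mean          # invert (x - mean) / std
    img = np.transpose(img, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    return np.clip(img, 0.0, 1.0)       # clip small overshoots so imshow does not complain

# Stand-in for img_t: a random "normalized" 3x224x224 array
fake = np.random.default_rng(0).normal(size=(3, 224, 224))
rgb = to_displayable(fake)
print(rgb.shape)
```

In the notebook you would call it as plt.imshow(to_displayable(img_t.numpy())).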

In [None]:
batch_t = torch.unsqueeze(img_t, 0) # add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
batch_t

In [None]:
resnet.eval() # inference mode: layers such as dropout and batch norm switch to their evaluation behaviour
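Note that eval() changes how layers such as dropout and batch normalization behave, but it does not by itself stop PyTorch from recording operations for gradient computation; that is the job of torch.no_grad(). The toy model below is our own illustration, not part of the lesson: a linear layer followed by dropout, run in both modes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# A toy model with one layer (Dropout) that behaves differently in train and eval mode
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.randn(1, 8)

model.eval()                      # inference behaviour: Dropout is switched off
y_eval = model(x)
print(y_eval.requires_grad)       # True: eval() alone still builds the autograd graph

model.train()                     # training behaviour: Dropout zeroes about half the outputs
y_train = model(x)                # surviving outputs are scaled by 1 / (1 - p) = 2

model.eval()
with torch.no_grad():             # no_grad() is what turns gradient tracking off
    y_nograd = model(x)
print(y_nograd.requires_grad)     # False
```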

In [None]:
with torch.no_grad(): # gradients are not needed for inference; this saves memory and time
    out = resnet(batch_t)
out

In [None]:
#scores  = out.detach().numpy()
#plt.plot(scores[0])
#plt.show()

#### A forward pass through a massive 44.5M-parameter network has just taken place!
This has produced a vector of 1000 scores, one for each class of the ImageNet training set. Let's load the file that lists the ImageNet labels.

Now we need to figure out how our dog picture was ranked.

In [None]:
with open('data/imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]
labels

In [None]:
_, index = torch.max(out, 1) # returns the highest score and its index along dimension 1 (the 1000 classes)
print(index)

ResNet gives us raw scores, but what we are really interested in is something like the probability of the image belonging to each category. We use the softmax function for that, as in any multi-class classifier.
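Softmax turns a vector of scores z into probabilities p_i = exp(z_i) / sum_j exp(z_j): every p_i is positive, they sum to 1, and the ranking of the scores is preserved. This is what torch.nn.functional.softmax(out, dim=1) computes along the class dimension; the NumPy version below is an illustrative sketch with made-up scores, using the usual trick of subtracting the maximum score before exponentiating to avoid overflow.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = np.asarray(z, dtype=float)
    shifted = z - z.max(axis=-1, keepdims=True)  # subtracting the max avoids overflow in exp
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[2.0, 1.0, 0.1]])   # made-up scores for 3 classes
probs = softmax(scores)
print(probs, probs.sum())              # the probabilities sum to 1
```

Subtracting the maximum does not change the result, because the common factor exp(-max) cancels between numerator and denominator.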

In [None]:
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100 # softmax over the 1000 classes; [0] selects the only image in the batch; * 100 gives percentages
percentage

In [None]:
labels[index[0]], percentage[index[0]].item() 

Exercises:

* Sort the ResNet output so that the five highest probabilities come out on top
    
* Download AlexNet and look at its output for our dog image. Which model is best?

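## II. A pre-trained example of another type of model: CycleGAN
Adapted from the book Deep Learning with PyTorch (Stevens, Antiga and Viehmann). The generator defined below has been trained to turn pictures of horses into pictures of zebras.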

In [None]:
import torch
import torch.nn as nn

class ResNetBlock(nn.Module): # residual block: two 3x3 convolutions plus a skip connection

    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        conv_block = []

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x) # skip connection: add the block's input to its output
        return out


class ResNetGenerator(nn.Module):

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # RGB in and out, 64 base filters, 9 residual blocks

        assert(n_blocks >= 0)
        super(ResNetGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input): # run the image through the whole sequential model
        return self.model(input)
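The generator is built so that the output image has the same spatial size as the input: two stride-2 convolutions shrink each side by a factor of 4, the residual blocks keep it fixed, and two transposed convolutions grow it back. The pure-Python sketch below traces one side length through those layers using the output-size formulas from the PyTorch Conv2d and ConvTranspose2d documentation (dilation 1 assumed); the helper functions are our own, not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output side length of nn.ConvTranspose2d with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

def generator_side(size, n_downsampling=2):
    """Trace one spatial side length through the layers of ResNetGenerator."""
    size = conv_out(size + 2 * 3, kernel=7)  # ReflectionPad2d(3) + 7x7 conv
    for _ in range(n_downsampling):          # stride-2 3x3 convs halve the side
        size = conv_out(size, kernel=3, stride=2, padding=1)
    # one ResNetBlock: (ReflectionPad2d(1) + 3x3 conv) twice; all 9 blocks keep the size
    for _ in range(2):
        size = conv_out(size + 2, kernel=3)
    for _ in range(n_downsampling):          # transposed convs double the side
        size = conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1)
    return conv_out(size + 2 * 3, kernel=7)  # final ReflectionPad2d(3) + 7x7 conv

print(generator_side(256))  # 256
print(generator_side(255))  # 256: sides that are not multiples of 4 get rounded up
```

Since Resize(256) only fixes the shorter side of the horse picture, the zebra image can come out up to 3 pixels larger along the longer side when that side is not a multiple of 4.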

In [None]:
netG = ResNetGenerator()

In [None]:
model_path = 'data/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path, map_location='cpu') # load the weights onto the CPU, whichever device they were saved from
netG.load_state_dict(model_data)

In [None]:
netG.eval()

In [None]:
from PIL import Image
from torchvision import transforms

In [None]:
preprocess = transforms.Compose([transforms.Resize(256),
                                 transforms.ToTensor()])

In [None]:
img = Image.open("img/horse.jpg")
img

In [None]:
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)

In [None]:
with torch.no_grad(): # inference only, no gradients needed
    batch_out = netG(batch_t)

In [None]:
out_t = (batch_out.detach().squeeze() + 1.0) / 2.0 # drop the batch dimension and map the Tanh output from [-1, 1] to [0, 1]
out_img = transforms.ToPILImage()(out_t)
# out_img.save('data/zebra.jpg')
out_img