# Transfer learning
---
In this module, we will
 - analyze the ResNet neural network architecture
 - make predictions on random images with a ResNet pretrained on the ImageNet dataset (***http://www.image-net.org/***)
 - create your own dataset
 - fine-tune ResNet on our custom dataset


In [None]:
# Reload edited Python modules automatically before each cell runs, and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

---

## Introduction to Resnet18

### Architecture
PyTorch tooling related to image processing can be found in `torchvision` module. Information about accessible pretrained models for PyTorch can be found at ***https://pytorch.org/docs/stable/torchvision/models.html***.

In [None]:
from torchvision import models

In [None]:
# ResNet-18 without pretrained weights; displaying the module prints its layer structure.
resnet18 = models.resnet18()
resnet18

In [None]:
# Inspect the first convolution of the first residual block in `layer4`.
conv1 = resnet18.layer4[0].conv1
conv1

### Prediction with pretrained model

In [None]:
from image_processing_workshop.utils import get_image_from_url
from image_processing_workshop.visual import plot_image
from torchvision import datasets, transforms
import numpy as np
import os
import torch

In [None]:
# Keep torch's download cache (pretrained weights) in the working directory
# rather than the default location.
os.environ.update(TORCH_HOME="./")

In [None]:
# Display the URLs torchvision downloads the pretrained ResNet weights from.
# NOTE(review): `models.resnet.model_urls` only exists in older torchvision releases;
# newer releases expose weights through enums such as `models.ResNet18_Weights` — confirm the installed version.
models.resnet.model_urls

In [None]:
# ResNet-18 with ImageNet-pretrained weights (downloaded into TORCH_HOME on first use).
resnet18 = models.resnet18(pretrained=True)
# eval() switches batch-norm to inference behaviour and returns the model;
# it is bound to `info` only to suppress the notebook cell output.
info = resnet18.eval()

In [None]:
# Download an example photo from the web and display it.
url = 'https://media.wired.com/photos/5b86fce8900cb57bbfd1e7ee/master/pass/Jaguar_I-PACE_S_Indus-Silver_065.jpg'
img = get_image_from_url(url)
plot_image(img)

In [None]:
# Transformation for resnet, normalization is important!!
# ToPILImage first: presumably `get_image_from_url` returns an array/tensor, not a PIL image — confirm.
# Normalize uses the ImageNet per-channel mean/std, which the pretrained weights expect.
transformation = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [None]:
# Preprocess the downloaded image for the network.
transformed_img = transformation(img)

In [None]:
# Compare the raw image shape with the shape of the preprocessed tensor.
img.shape, transformed_img.shape

In [None]:
# Add a leading batch dimension of size 1 and show the first 10 raw output scores.
batch = transformed_img.unsqueeze(0)
resnet18(batch)[0][:10]

In [None]:
# Let's adjust last layer and add softmax, softmax doesn't need training, it's just normalization.
# The existing fc layer is kept and wrapped, so its pretrained weights are unchanged;
# the model now outputs class probabilities instead of raw scores.
resnet18.fc = torch.nn.Sequential(
    resnet18.fc, 
    torch.nn.Softmax(dim=1))

In [None]:
# Same first 10 outputs, now normalized by the appended Softmax.
resnet18(batch)[0][:10]

In [None]:
from image_processing_workshop.utils import get_imagenet_category_names
from image_processing_workshop.visual import plot_classify

In [None]:
# First 10 ImageNet class names (use_cache=True presumably reuses a locally cached list — confirm).
get_imagenet_category_names(use_cache=True)[:10]

In [None]:
# Plot the image together with its 5 highest-scoring ImageNet classes.
plot_classify(transformed_img, resnet18, topn=5, category_names=get_imagenet_category_names(), figsize=(15,15))

In [None]:
# Repeat the download -> preprocess -> classify pipeline on another image.
url = 'https://s.w-x.co/util/image/w/411spacex.jpg?v=at&w=815&h=458'
img = get_image_from_url(url)
transformed_img = transformation(img)
plot_classify(transformed_img, resnet18, topn=5, category_names=get_imagenet_category_names())

### Prediction with a fine-tuned model on CEOs

In [None]:
# Load the fine-tuned CEO checkpoint. map_location='cpu' lets a checkpoint saved
# from GPU tensors load on a CPU-only machine (nothing in this notebook moves
# tensors to a GPU). torch.load unpickles the file: only load trusted checkpoints.
state_dict = torch.load('./models/ceo_resnet.pth', map_location='cpu')

In [None]:
# Inspect the checkpoint's top-level keys; the weights under 'net' are loaded below.
state_dict.keys()

In [None]:
# Loading model
# Rebuild the CEO classifier's architecture: a ResNet-18 body plus a 512 -> 4 linear
# head followed by Softmax (4 outputs, one per name in `ceo_names`). Then load the
# fine-tuned weights stored under the checkpoint's 'net' key.
model = models.resnet18(pretrained=False)
model.fc = torch.nn.Sequential(
    torch.nn.Linear(512, 4),
    torch.nn.Softmax(dim=1))
model.load_state_dict(state_dict['net'])
info = model.eval()  # inference mode; bound to `info` only to suppress the cell output

# Transformation for resnet, normalization is important!!
# Identical to the preprocessing used for the ImageNet model earlier.
transformation = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [None]:
# Display names for the 4 outputs of the CEO model, indexed by output position.
# NOTE(review): this order must match the class indices used during training
# (presumably the sorted class-folder names) — confirm against the training data.
ceo_names = ['Elon Musk', 'Satya Nadella', 'Steve Jobs', 'Sundar Pichai']

In [None]:
# Classify a CEO photo from the web with the fine-tuned model.
# NOTE(review): topn=5 exceeds the model's 4 outputs; check that plot_classify caps topn.
url = 'http://www.gstatic.com/tv/thumb/persons/476283/476283_v9_ba.jpg'
img = get_image_from_url(url)
transformed_img = transformation(img)
plot_classify(transformed_img, model, topn=5, category_names=ceo_names)

---

## Building of our dataset
Let's Google a few (2-4) image categories (different people, a few animal species, different cars, flowers, etc.) and build a dataset.

### 1. Scrape image url
 - go to Google images https://www.google.com/imghp?hl=EN
 - search image category and scroll a bit through it
 - open the browser's developer console (e.g. `Ctrl-Shift-J`)
 - paste the JavaScript snippet below into the console  

```javascript
javascript:document.body.innerHTML = `<a href="data:text/csv;charset=utf-8,${escape(Array.from(document.querySelectorAll('.rg_di .rg_meta')).map(el=>JSON.parse(el.textContent).ou).join('\n'))}" download="urls.txt">download urls</a>`;
```

 - rename `urls.txt` in `download="urls.txt"` according to the category you scrape, e.g. `cats.txt`
 - download image urls
 - copy `cats.txt` to docker shared folder with this ipython notebook

### 2. Download images to proper folder hierarchy
 - run `scrape_urls` with url file and appropriate category name
 - each run of `scrape_urls` will create the following folder hierarchy
 
`root_folder/train/class_name/*.jpg`   
`root_folder/valid/class_name/*.jpg`

In [None]:
from image_processing_workshop.utils import scrape_urls

In [None]:
# Root directory of your dataset; scrape_urls fills its train/ and valid/ sub-folders.
root_folder = './your_dataset'

In [None]:
# Replace the placeholders with your URL file and its category name; run once per category.
scrape_urls('your_urls.txt', category_name='category name', root_folder=root_folder)

### 3. First look at data

In [None]:
from torchvision import datasets, transforms, utils
from image_processing_workshop.visual import plot_image

In [None]:
# Let's visualize the downloaded images.
# Deliberately no Normalize step, so the images plot with their original colors.
transformation = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()])
dataset = datasets.ImageFolder(os.path.join(root_folder, 'valid'), transformation)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

In [None]:
# Take one (shuffled) batch of images and their integer labels.
images, labels = next(iter(loader))

In [None]:
# Tile the batch into a single grid image and display it.
image_grid = utils.make_grid(images)
plot_image(image_grid, figsize=(15,15))

### 4. Prepare training dataset

In [None]:
# Per-split preprocessing:
#  - 'train' adds random augmentation (random resized crop + horizontal flip);
#  - 'valid' is deterministic (resize + center crop).
# Both end with the same ImageNet normalization used by the pretrained model.
transformation = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

In [None]:
# One ImageFolder dataset and one shuffled DataLoader per split; classes come from
# the class_name sub-folders of root_folder/<split>/.
dataset = {x: datasets.ImageFolder(os.path.join(root_folder, x), transformation[x]) for x in ['train', 'valid']}
loader = {x: torch.utils.data.DataLoader(dataset[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'valid']}

# Number of images per split, and the class names found in the training split.
dataset_size = {x: len(dataset[x]) for x in ['train', 'valid']}
class_names = dataset['train'].classes

In [None]:
# Number of images in each split.
dataset_size

In [None]:
# Classes found in the training split.
class_names

### 5. Classify with the original ResNet

In [None]:
# Classify one of our validation images with the ImageNet model (resnet18 with the appended Softmax).
plot_classify(dataset['valid'][4][0], resnet18, topn=10, category_names=get_imagenet_category_names())

---

## Finetune resnet
We will use a pretrained model, on top of which we will place a small feed-forward classification network. Then we will train this whole architecture in 2 steps:  

1) freeze the pretrained ResNet weights and train only our simple added classifier  
2) use a smaller learning rate and also fine-tune the top layers of ResNet


### Setup model

In [None]:
from torch import nn, optim

In [None]:
# Number of classes in our dataset = number of outputs of the new classifier head.
output_size = len(dataset['train'].classes)
output_size

In [None]:
# Fresh ImageNet-pretrained ResNet-18 as the starting point for fine-tuning.
resnet18 = models.resnet18(pretrained=True)

In [None]:
# Replace the ImageNet head with a classifier for our own classes. It ends with
# Softmax, so the model outputs class probabilities (512 = ResNet-18 feature size).
resnet18.fc = torch.nn.Sequential(
    torch.nn.Linear(512, output_size),
    torch.nn.Softmax(dim=1))


def loss_fce(probabilities, labels, eps=1e-12):
    """Cross-entropy loss for a model whose output is already softmax-normalized.

    ``nn.CrossEntropyLoss`` expects raw logits and applies ``log_softmax`` itself.
    Feeding it the Softmax output above would apply softmax twice, which squashes
    the gradients and prevents the loss from approaching zero. Instead, take the
    log of the probabilities and use the negative log-likelihood, which equals
    cross-entropy on the underlying logits.

    Args:
        probabilities: model output with one probability per class along dim 1.
        labels: integer class index for each sample.
        eps: lower clamp on the probabilities so that log(0) cannot occur.

    Returns:
        Scalar tensor: the mean loss over the batch.
    """
    return nn.functional.nll_loss(torch.log(probabilities.clamp(min=eps)), labels)


# Only the new head is optimized in the first stage.
optimizer = optim.Adam(resnet18.fc.parameters())

### Basic finetuning of last layers

In [None]:
# Freeze the whole network: no parameter will receive gradients.
resnet18.requires_grad_(False)

In [None]:
# Re-enable gradients for the classifier head only; it is what the optimizer trains in this stage.
resnet18.fc.requires_grad_(True)

In [None]:
# Alias used by the training and evaluation cells below (same object as resnet18).
model = resnet18

In [None]:
def get_valid_acc_and_loss(model, loss_fce, valid_loader):
    """Evaluate ``model`` over every batch of ``valid_loader``.

    Accuracy and loss are weighted by the number of samples in each batch. Averaging
    per-batch means instead would over-weight a smaller final batch.

    Args:
        model: module whose output holds one score per class along dim 1.
        loss_fce: callable ``loss_fce(predictions, labels)`` returning the mean
            loss over the batch.
        valid_loader: iterable of ``(images, labels)`` batches.

    Returns:
        Tuple ``(accuracy_percent, mean_loss)``.

    Raises:
        ValueError: if ``valid_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0.0
    loss_sum = 0.0
    n_samples = 0
    try:
        # Gradients are never needed for evaluation; skip building the graph.
        with torch.no_grad():
            for images, labels in valid_loader:
                predictions = model(images)
                batch_size = labels.size(0)
                correct += (predictions.argmax(dim=1) == labels).sum().item()
                # loss_fce averages over the batch; scale it back up to a sum.
                loss_sum += loss_fce(predictions, labels).item() * batch_size
                n_samples += batch_size
    finally:
        # Restore whichever mode the caller had the model in, even on error.
        model.train(mode=was_training)

    if n_samples == 0:
        raise ValueError("valid_loader yielded no samples")
    return correct / n_samples * 100, loss_sum / n_samples

In [None]:
from collections import deque

# Training hyper-parameters: number of epochs and how many batches between validation reports.
epochs, report_period = 20, 3
# Global batch counter; the training loop never resets it, so it keeps counting across re-runs.
batch_iteration = 0

# Sliding window holding the training losses of the most recent `report_period` batches.
train_leak_loss = deque([], report_period)
# Recorded at every report: mean train loss, validation loss, validation accuracy (%).
train_loss_history, valid_loss_history, valid_acc_history = [], [], []

In [None]:
for epoch in range(epochs):
    # Put the net in train mode and make one pass over the training data.
    model.train()
    for images, labels in loader['train']:
        batch_iteration += 1

        ##################
        # Training Phase #
        ##################
        optimizer.zero_grad()
        # Call the module itself rather than .forward() so registered hooks run.
        predictions = model(images)
        loss = loss_fce(predictions, labels)
        loss.backward()
        optimizer.step()

        ####################
        # Validation Phase #
        ####################
        train_leak_loss.append(loss.item())
        if batch_iteration % report_period == 0:
            # get_valid_acc_and_loss switches the model to eval mode itself and
            # restores the previous (train) mode before returning.
            # We don't want to collect info for gradients from here.
            with torch.no_grad():
                valid_accuracy, valid_loss = get_valid_acc_and_loss(model, loss_fce, loader['valid'])

            # Mean training loss over the last `report_period` batches.
            train_loss = np.mean(train_leak_loss)
            print(f'Epoch: {epoch+1}/{epochs}.. ',
                  f"Train Loss: {round(train_loss, 2)}.. ",
                  f"Valid Loss: {round(valid_loss, 2)}.. ",
                  f"Valid Acc: {round(valid_accuracy, 2)}%")

            train_loss_history.append(train_loss)
            valid_loss_history.append(valid_loss)
            valid_acc_history.append(valid_accuracy)

### Extra finetuning of resnet layer4
Nice reading about advanced strategies for setup of lr here: ***https://www.jeremyjordan.me/nn-learning-rate/***

 - make the parameters of the last ResNet block (`layer4`) trainable
 - make the learning rate 10 times smaller
 - re-run the previous training cell with the adjusted setup

In [None]:
# Current learning rate of the optimizer's first (and only) parameter group.
optimizer.param_groups[0]['lr']

In [None]:
# Parameters currently optimized (just the classifier head, until layer4 is added below).
optimizer.param_groups[0]['params']

In [None]:
# Fewer epochs for the second, lower-learning-rate fine-tuning stage.
epochs=5

In [None]:
# Unfreeze the last residual stage so it is fine-tuned together with the head.
model.layer4.requires_grad_(True)

In [None]:
# Add layer4's parameters to the optimizer's single parameter group. Skip any
# parameters already registered, so re-running this cell cannot list the same
# tensor twice; a duplicated tensor would be updated more than once per step.
registered_ids = {id(p) for p in optimizer.param_groups[0]['params']}
optimizer.param_groups[0]['params'] = optimizer.param_groups[0]['params'] + [
    p for p in model.layer4.parameters() if id(p) not in registered_ids]

In [None]:
# Lower the learning rate of the (single) parameter group for the second fine-tuning stage.
optimizer.param_groups[0].update(lr=1e-4)

**Go back and re-run training cell**

### After all training save the model

In [None]:
# Checkpoint contents:
#  - 'net': the model weights (the key the CEO-model loading code reads).
#  - 'optimizer': optimizer.state_dict(), i.e. hyper-parameters plus Adam's
#    per-parameter state, restorable with optimizer.load_state_dict().
# Saving optimizer.param_groups instead would pickle the parameter tensors
# themselves (duplicating the weights) and drop the optimizer state.
training_state = {
    'optimizer': optimizer.state_dict(),
    'net': model.state_dict()
}

In [None]:
# Write the checkpoint to disk (the ./models directory must already exist).
torch.save(training_state, './models/finetuned_resnet.pth')

### Visualize progress of training

In [None]:
import matplotlib.pylab as plt 

In [None]:
# Training vs. validation loss, one point per validation report (every `report_period` batches).
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xlabel('Iteration')
ax.set_ylabel('Cross Entropy')
ax.plot(train_loss_history, label='Train loss')
ax.plot(valid_loss_history, label='Valid loss')
ax.legend(frameon=False)

In [None]:
# Validation accuracy (%) recorded at each validation report.
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(valid_acc_history, label='Valid acc')
ax.set_xlabel('Iteration')
ax.set_ylabel('Acc(%)')
ax.legend(frameon=False)

---

## Results evaluation

In [None]:
# Switch to inference mode for evaluation; bound to `info` only to suppress the cell output.
info = model.eval()

### View single images and predictions

In [None]:
from image_processing_workshop.visual import plot_classify, plot_image

In [None]:
# Show the fine-tuned model's predictions for one validation image.
plot_classify(dataset['valid'][120][0], model, category_names=class_names)

### Load results into a pandas DataFrame

In [None]:
from image_processing_workshop.eval import get_results_df
from image_processing_workshop.visual import plot_df_examples

In [None]:
# Collect the model's predictions on the validation loader into a DataFrame
# (presumably one row per image — confirm in image_processing_workshop.eval) and preview it.
df = get_results_df(model, loader['valid'])
df.head(10)

In [None]:
# Histogram of `label_class_score` for rows whose label is the given class.
# Replace the placeholder string with one of your class names.
fig = plt.figure(figsize=(10, 10))
ax = plt.gca()
ax.set_xlabel('Prediction Score')
df[df.label_class_name=='one of your categories'].label_class_score.hist(ax=ax)

In [None]:
# Plot the examples described by the first 25 result rows.
plot_df_examples(df.iloc[:25])

### Overall Recall and Precision


In [None]:
from image_processing_workshop.eval import get_rec_prec

In [None]:
# Recall and precision computed from the results DataFrame (presumably per class — confirm).
get_rec_prec(df, class_names)

### Accuracy

In [None]:
from image_processing_workshop.eval import get_accuracy

In [None]:
# Overall accuracy over the validation results.
get_accuracy(df)

### False Positives

In [None]:
from image_processing_workshop.eval import get_false_positives

In [None]:
# False positives for the given class (replace the placeholder with one of your class names),
# presumably images predicted as that class whose true label differs — confirm; then plot them.
fp = get_false_positives(df, label_class_name='one of your categories')
plot_df_examples(fp)

### Confusion Matrix

In [None]:
from image_processing_workshop.visual import plot_coocurance_matrix

In [None]:
# Confusion (co-occurrence) matrix of labels vs. predictions; use_log=False presumably keeps a linear scale.
plot_coocurance_matrix(df, use_log=False)