In [1]:
import os

from matplotlib import pyplot as plt
import numpy as np
import time
from tqdm import tqdm
from torchvision import datasets, models, transforms
import torch
from functools import reduce
from typing import Union
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from pathlib import Path

%load_ext autoreload
%autoreload 2

## Описание эксперимента
В этом примере мы применим методы визуализации для анализа классификационной модели.  
Пусть у нас есть набор данных caltech101, для которого мы хотим получить хороший классификатор. 
 
Этот ноутбук состоит из следующих этапов:
1. Подготовка данных
2. Загрузка предобученной на ImageNet модели, которую мы будем дообучать
3. Дообучение модели
4. Анализ ошибок модели на валидации с помощью методов визуализации, рассмотренных на лекции

# 1. Подготовка данных 

In [2]:
### Global hyperparameters and settings shared by all cells of this notebook

# Path to a directory with image dataset and subfolders for training, validation and final testing
DATA_PATH = "../datasets"  # PATH TO THE DATASET

# Number of worker processes used by the data loaders
NUM_WORKERS = 4

# Image size: every image is resized to 224x224, the standard input size
# for the ImageNet-pretrained torchvision models
SIZE_H = SIZE_W = 224
N_CHANNELS = 3

# Number of classes in the dataset (Caltech101 has 101 object categories;
# this matches the size of the classifier head created below)
NUM_CLASSES = 101

# Epochs: number of full passes over the training data
EPOCH_NUM = 30

# Batch size: for batch gradient descent optimization, usually selected as 2**K elements
BATCH_SIZE = 32

# Channel-wise mean and std used to normalize images
# (the ImageNet statistics documented for torchvision's pretrained models)
image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225]

# Last layer (embeddings) size for CNN models
EMBEDDING_SIZE = 256

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [3]:
# Preprocessing pipeline applied to each image before it reaches the network.
transformer = transforms.Compose(
    [
        # 1) bring every image to the same spatial size
        transforms.Resize((SIZE_H, SIZE_W)),
        # 2) convert the image to a tensor
        transforms.ToTensor(),
        # 3) single-channel (gray) images are stacked three times along the
        #    channel axis; images with any other channel count pass through
        transforms.Lambda(
            lambda img: img if img.shape[0] != 1 else torch.cat([img, img, img], 0)
        ),
        # 4) per-channel normalization with the configured mean / std
        transforms.Normalize(image_mean, image_std),
    ]
)

In [4]:
# Load the dataset and split it into train and val.
# `caltech101` applies the preprocessing pipeline to every sample.
caltech101 = torchvision.datasets.Caltech101(
    root=DATA_PATH, download=True, transform=transformer
)

# Derive the validation size from the actual dataset length so that
# random_split does not raise if the dataset is not exactly 8677 images long.
TRAIN_SIZE = 7000
torch.manual_seed(0)  # fixed seed -> reproducible split
train_dataset, val_dataset = torch.utils.data.random_split(
    caltech101, [TRAIN_SIZE, len(caltech101) - TRAIN_SIZE]
)

# Same data without any transform (used later for visualization).
caltech101_unchanged = torchvision.datasets.Caltech101(root=DATA_PATH, download=True)

# Reuse the indices of the split above instead of re-seeding and splitting again:
# this guarantees that index i refers to the same sample in both views.
train_dataset_unchanged = torch.utils.data.Subset(
    caltech101_unchanged, train_dataset.indices
)
val_dataset_unchanged = torch.utils.data.Subset(
    caltech101_unchanged, val_dataset.indices
)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Extracting ../datasets/caltech101/101_ObjectCategories.tar.gz to ../datasets/caltech101


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Extracting ../datasets/caltech101/Annotations.tar to ../datasets/caltech101
Files already downloaded and verified


In [5]:
# Number of samples in each split
n_train = len(train_dataset)
n_val = len(val_dataset)

# Training batches are reshuffled on every pass over the data
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Validation batches keep the dataset order
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# 2. Загрузка предобученной на ImageNet модели VGG16

In [6]:
# VGG16: start from the pretrained network and replace only the last
# classifier layer with a fresh linear head that has one output per class.
num_classes = 101
model_ft = models.vgg16(pretrained=True)
head_in_features = model_ft.classifier[6].in_features
model_ft.classifier[6] = nn.Linear(head_in_features, num_classes)
model_ft = model_ft.to(device)



# 3. Дообучение модели VGG16 на нашем датасете
Если у вас не хватает видеопамяти, попробуйте уменьшить размер батча *BATCH_SIZE*

In [7]:
def compute_accuracy(model, val_loader):
    """
    Compute the classification accuracy of `model` on every batch of `val_loader`.

    model: network mapping a batch of inputs to per-class logits
    val_loader: iterable yielding (X_batch, y_batch) pairs

    Returns a list with one accuracy value (fraction of correct predictions,
    in [0, 1]) per batch, in iteration order.

    The caller is responsible for putting the model into evaluation mode;
    this function does not change the model's train/eval state.
    """
    val_accuracy = []
    # Evaluation needs no gradients: skip building the autograd graph
    # to save memory and time.
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            # move data to target device
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            # predicted class = index of the largest logit
            y_pred = model(X_batch).argmax(dim=1)
            # compare on the device, transfer only the boolean result
            val_accuracy.append(np.mean((y_batch == y_pred).cpu().numpy()))
    return val_accuracy


def train_model(model, train_loader, val_loader, loss_fn, opt, n_epochs):
    """
    Train `model` for `n_epochs` epochs and validate it after every epoch.

    model: neural network to train
    train_loader, val_loader: data loaders yielding (X_batch, y_batch) pairs
    loss_fn: objective that is optimized
    opt: optimizer (updates the network weights)
    n_epochs: number of epochs, i.e. full passes over the training data

    Returns (train_loss, val_accuracy):
      train_loss: loss value of every training batch, over all epochs
      val_accuracy: accuracy of every validation batch, over all epochs

    After each epoch the elapsed time, the mean training loss and the mean
    validation accuracy of that epoch are printed.
    """
    train_loss = []
    val_accuracy = []

    for epoch in range(n_epochs):
        start_time = time.time()

        model.train(True)  # enable dropout / batch_norm training behavior
        epoch_loss = []
        for X_batch, y_batch in train_loader:
            # move data to target device
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            # train on batch: zero the grads, compute loss, calc grads, perform optimizer step
            opt.zero_grad()
            predictions = model(X_batch)
            loss = loss_fn(predictions, y_batch)
            loss.backward()
            opt.step()
            epoch_loss.append(loss.item())
        train_loss += epoch_loss

        model.train(False)  # disable dropout / use averages for batch_norm
        # validation needs no gradients
        with torch.no_grad():
            epoch_accuracy = list(compute_accuracy(model, val_loader))
        val_accuracy += epoch_accuracy

        # print the results for this epoch:
        print(f"Epoch {epoch + 1} of {n_epochs} took {time.time() - start_time:.3f}s")

        # Average over the batches of this epoch only. Tracking them in
        # per-epoch lists avoids any dependence on notebook-level globals
        # (dataset sizes, batch size) and works for any loader passed in.
        train_loss_value = np.mean(epoch_loss) if epoch_loss else float("nan")
        val_accuracy_value = (
            np.mean(epoch_accuracy) * 100 if epoch_accuracy else float("nan")
        )

        print(f"  training loss (in-iteration): \t{train_loss_value:.6f}")
        print(f"  validation accuracy: \t\t\t{val_accuracy_value:.2f} %")

    return train_loss, val_accuracy

In [8]:
# Fine-tune all parameters of the network: cross-entropy objective,
# SGD with momentum as the optimizer.
loss_fn = nn.CrossEntropyLoss()
optimizer_ft = torch.optim.SGD(params=model_ft.parameters(), lr=0.001, momentum=0.9)
train_loss, val_accuracy = train_model(
    model=model_ft,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=loss_fn,
    opt=optimizer_ft,
    n_epochs=EPOCH_NUM,
)

Epoch 1 of 30 took 65.559s
  training loss (in-iteration): 	1.282204
  validation accuracy: 			91.43 %
Epoch 2 of 30 took 66.473s
  training loss (in-iteration): 	0.259038
  validation accuracy: 			92.66 %
Epoch 3 of 30 took 66.967s
  training loss (in-iteration): 	0.116465
  validation accuracy: 			92.96 %
Epoch 4 of 30 took 67.112s
  training loss (in-iteration): 	0.071890
  validation accuracy: 			93.25 %
Epoch 5 of 30 took 67.167s
  training loss (in-iteration): 	0.031698
  validation accuracy: 			93.19 %
Epoch 6 of 30 took 67.244s
  training loss (in-iteration): 	0.033263
  validation accuracy: 			93.84 %
Epoch 7 of 30 took 67.261s
  training loss (in-iteration): 	0.018628
  validation accuracy: 			94.14 %
Epoch 8 of 30 took 67.262s
  training loss (in-iteration): 	0.020157
  validation accuracy: 			93.49 %
Epoch 9 of 30 took 67.373s
  training loss (in-iteration): 	0.023079
  validation accuracy: 			94.14 %
Epoch 10 of 30 took 67.407s
  training loss (in-iteration): 	0.018414
  v

In [9]:
# Save the fine-tuned weights. The parent directory is created first so the
# save does not fail on a fresh checkout where it does not exist yet.
weights_path = Path("../trained_models/vgg16.pt")
weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model_ft.state_dict(), weights_path)

# load model

num_classes = 101
# pretrained=False: every weight is overwritten by load_state_dict below,
# so fetching the ImageNet checkpoint here would be wasted work.
model_ft = models.vgg16(pretrained=False)
model_ft.classifier[6] = nn.Linear(model_ft.classifier[6].in_features, num_classes)
model_ft.to(device)
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
model_ft.load_state_dict(torch.load(weights_path, map_location=device))
model_ft.eval();

In [10]:
# Validation accuracy in percent.
# Note: this is the plain average of the per-batch accuracies.

100 * np.mean(compute_accuracy(model_ft, val_loader))

94.96099419448475

# 4. Анализ предсказаний нейросети

## Визуализация одного примера

In [13]:
import sys

# Extend the module search path so that the `s5_visualization` package
# (looked up relative to the notebook's working directory) can be imported.
sys.path.append("../../../seminars/")
# Helper used below to produce and save the visualizations
from s5_visualization.scripts.visualize_cnn import get_explanations

In [14]:
# Visualize a single validation sample with respect to its true class.
image_index = 44
image_unchanged, image_category = val_dataset_unchanged[image_index]

# Preprocess the raw image, add a leading batch dimension, move it to the device
image_transformed = transformer(image_unchanged).unsqueeze(0).to(device)

get_explanations(
    model_ft,
    image_transformed,
    image_unchanged,
    caltech101.categories[image_category],  # human-readable class name
    image_category,  # target class index
    Path("../outputs/test.png"),  # output file
)



## Визуализация валидации

In [15]:
# Run the model over the whole validation split and save a visualization for
# every image: correctly classified images go to `true/<class>/`, misclassified
# ones go to `mis_class/<predicted class>/` with one visualization for the
# predicted class and one for the true class.
model_ft.eval()
out_dir = Path("../outputs/caltech101_vis/")
for image_index in tqdm(range(len(val_dataset_unchanged))):
    image_unchanged, image_category = val_dataset_unchanged[image_index]
    true_category_name = caltech101.categories[image_category]
    # prepare image
    image_transformed = torch.unsqueeze(transformer(image_unchanged), 0).to(device)
    # get class scores (raw model outputs); plain prediction needs no
    # gradients, so skip building the autograd graph for this forward pass
    with torch.no_grad():
        class_scores = model_ft(image_transformed).cpu().numpy()[0]
    predicted_class = np.argmax(class_scores)

    # zero-padded index so that files sort in dataset order
    formated_image_index = str(image_index).zfill(4)

    if predicted_class == image_category:
        # correctly classified image: save its visualization to the folder
        # named after the class category

        save_path = out_dir / Path(
            f"true/{true_category_name}_id_{image_category}/{formated_image_index}.png"
        )
        get_explanations(
            model_ft,
            image_transformed,
            image_unchanged,
            true_category_name,
            image_category,
            save_path,
        )
    else:
        # misclassified image: save visualizations with respect to both the
        # true and the predicted class
        predicted_category_name = caltech101.categories[predicted_class]
        predicted_class_score = str(round(class_scores[predicted_class], 3))
        true_class_score = str(round(class_scores[image_category], 3))

        # both visualizations live in the folder of the predicted class
        mis_class_dir = out_dir / Path(
            f"mis_class/{predicted_category_name}_id_{predicted_class}_predicted"
        )
        save_path_predicted_vis = mis_class_dir / Path(
            f"{formated_image_index}_predicted_category({predicted_category_name})_score_{predicted_class_score}_vis.png"
        )
        save_path_true_vis = mis_class_dir / Path(
            f"{formated_image_index}_true_category({true_category_name})_score_{true_class_score}_vis.png"
        )

        # predicted target vis
        # NOTE(review): the true class name is passed together with the
        # predicted class index; confirm against get_explanations whether
        # predicted_category_name is expected here instead.
        get_explanations(
            model_ft,
            image_transformed,
            image_unchanged,
            true_category_name,
            predicted_class,
            save_path_predicted_vis,
        )
        # true target vis
        get_explanations(
            model_ft,
            image_transformed,
            image_unchanged,
            true_category_name,
            image_category,
            save_path_true_vis,
        )

100%|██████████| 1677/1677 [35:26<00:00,  1.27s/it] 
