In [None]:
import os
import urllib.request
import tarfile
from shutil import copyfile
import math
import numpy as np
import pandas as pd

from torchvision import transforms

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision.datasets import ImageFolder

from torch.utils import model_zoo

from pytoune.framework import Model, ModelCheckpoint, BestModelRestore, CSVLogger
from pytoune import torch_to_numpy

In [None]:
def download_and_extract_dataset(path):
    """Download the CUB-200 image archive and extract it into ``path``.

    The archive is stored as ``images.tgz`` in the current working directory.
    When that file already exists it is reused instead of being downloaded
    again, so re-running the notebook does not repeat the transfer.

    Args:
        path: Directory into which the archive is extracted. It is created
            when it does not exist yet.
    """
    tgz_filename = "images.tgz"
    url = "http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz"

    if not os.path.exists(tgz_filename):
        # Download under a temporary name and rename only once the transfer
        # has completed, so that an interrupted download is never mistaken
        # for a complete archive on the next run.
        partial_filename = tgz_filename + ".part"
        urllib.request.urlretrieve(url, partial_filename)
        os.replace(partial_filename, tgz_filename)

    os.makedirs(path, exist_ok=True)
    # The context manager closes the archive even when extraction fails.
    with tarfile.open(tgz_filename) as archive:
        if hasattr(tarfile, 'data_filter'):
            # The archive was fetched over plain HTTP: when this Python
            # version supports extraction filters, refuse members that
            # would be written outside of `path`.
            archive.extractall(path, filter='data')
        else:
            archive.extractall(path)

In [None]:
def copy(source_path, filenames, dest_path):
    """Copy the named files from one directory to another.

    Args:
        source_path: Directory that currently holds the files.
        filenames: Iterable of file names, relative to ``source_path``.
        dest_path: Existing directory receiving the copies; every file keeps
            its name.
    """
    for name in filenames:
        copyfile(os.path.join(source_path, name), os.path.join(dest_path, name))

def split_train_valid_test(dataset_path, train_path, valid_path, test_path, train_split=0.6, valid_split=0.2): # test_split=0.2
    """Copy every class directory of a dataset into train/valid/test directories.

    Each sub-directory of ``dataset_path`` is treated as one class. Its files
    are sorted, shuffled and copied into ``<split_path>/<classname>``
    directories, which are created when missing. Entries whose name starts
    with ``.`` are ignored, and so are entries of ``dataset_path`` that are
    not directories.

    Args:
        dataset_path: Directory containing one sub-directory per class.
        train_path: Root directory of the training split.
        valid_path: Root directory of the validation split.
        test_path: Root directory of the test split.
        train_split: Fraction of each class given to the training split
            (the count is rounded up).
        valid_split: Fraction of each class given to the validation split
            (the count is rounded down). The remaining files form the test
            split.
    """
    # Seed NumPy's global generator so the shuffles below are the same on
    # every run.
    np.random.seed(42)
    for classname in sorted(os.listdir(dataset_path)):
        dataset_class_path = os.path.join(dataset_path, classname)
        # Only directories are classes: a stray file next to them (README,
        # archive, ...) would otherwise make the listing below fail.
        if classname.startswith('.') or not os.path.isdir(dataset_class_path):
            continue

        train_class_path = os.path.join(train_path, classname)
        valid_class_path = os.path.join(valid_path, classname)
        test_class_path = os.path.join(test_path, classname)

        os.makedirs(train_class_path, exist_ok=True)
        os.makedirs(valid_class_path, exist_ok=True)
        os.makedirs(test_class_path, exist_ok=True)

        # Sorting before shuffling makes the result independent of the order
        # in which the file system lists the directory.
        filenames = sorted(filename for filename in os.listdir(dataset_class_path) if not filename.startswith('.'))
        np.random.shuffle(filenames)

        num_examples = len(filenames)
        train_end_idx = math.ceil(num_examples*train_split)
        valid_end_idx = train_end_idx + math.floor(num_examples*valid_split)
        train_filenames = filenames[0:train_end_idx]
        valid_filenames = filenames[train_end_idx:valid_end_idx]
        test_filenames = filenames[valid_end_idx:]
        copy(dataset_class_path, train_filenames, train_class_path)
        copy(dataset_class_path, valid_filenames, valid_class_path)
        copy(dataset_class_path, test_filenames, test_class_path)

In [None]:
# Locations of the extracted images and of the three per-split copies.

base_path = './CUB200'
dataset_path, train_path, valid_path, test_path = (
    os.path.join(base_path, subdir)
    for subdir in ('images', 'train', 'valid', 'test')
)

In [None]:
# Fetch the image archive and unpack it under `base_path`, then copy the
# images of every class into the train/valid/test directories (roughly
# 60/20/20 with the default split fractions).
download_and_extract_dataset(base_path)
split_train_valid_test(dataset_path, train_path, valid_path, test_path)

In [None]:
# Use the GPU with index `cuda_device` when CUDA is available, else the CPU.
cuda_device = 0
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(cuda_device))
else:
    device = torch.device("cpu")

In [None]:
# Hyperparameters used for training.

batch_size = 32       # number of images per mini-batch of the data loaders
learning_rate = 1e-2  # learning rate given to the SGD optimizer
n_epoch = 30          # number of training epochs
num_classes = 200     # output size of the replaced classification head

In [None]:
# Seed PyTorch's default random number generator with a fixed value.
torch.manual_seed(42)

In [None]:
# Build the PyTorch datasets and data loaders for the three splits.

# Per-channel (mean, std) pairs that can be passed to `transforms.Normalize`.
norm_coefs = {
    'cub200': [(0.47421962,  0.4914721 ,  0.42382449), (0.22846779,  0.22387765,  0.26495799)],
    'imagenet': [(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)],
}

# Every image is resized to 224x224, converted to a tensor and normalized
# with the 'cub200' coefficients.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(*norm_coefs['cub200']),
])

train_set, valid_set, test_set = (
    ImageFolder(split_path, transform=transform)
    for split_path in (train_path, valid_path, test_path)
)

# Only the training loader shuffles its examples.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader, test_loader = (
    torch.utils.data.DataLoader(split_set, batch_size=batch_size)
    for split_set in (valid_set, test_set)
)

In [None]:
# Load a pretrained ResNet-18 network and replace its head (the final
# fully connected layer) with a new one that has as many outputs as we
# have classes.

resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(in_features=resnet18.fc.in_features, out_features=num_classes)

In [None]:
# We freeze the network except for its head.

def freeze_weights(resnet18):
    """Disable gradient computation for every parameter outside the ``fc`` head.

    Args:
        resnet18: Module whose parameters are modified in place. Parameters
            whose name starts with ``fc.`` are left untouched; all others
            get ``requires_grad = False``.
    """
    for name, param in resnet18.named_parameters():
        if not name.startswith('fc.'):
            # The autograd flag is spelled `requires_grad`; assigning to any
            # other attribute name only creates an unused attribute and
            # leaves the parameter trainable.
            param.requires_grad = False

# Apply `freeze_weights` to the pretrained network; the intent is that only
# the parameters of the new `fc` head remain trainable.
freeze_weights(resnet18)

In [None]:
# Callbacks are one nice feature of Pytoune: they are handed to the training
# call through its `callbacks` argument.

callbacks = [
    # Checkpoint of the latest weights, so that the optimization can be
    # continued later for more epochs.
    ModelCheckpoint(
        'last_epoch.ckpt',
        temporary_filename='last_epoch.ckpt.tmp',
    ),

    # Checkpoint written to a new file (named after the epoch) each time the
    # validation accuracy improves on all previous epochs.
    ModelCheckpoint(
        'best_epoch_{epoch}.ckpt',
        monitor='val_acc',
        mode='max',
        save_best_only=True,
        restore_best=True,
        verbose=True,
        temporary_filename='best_epoch.ckpt.tmp',
    ),

    # Per-epoch losses and accuracies, logged as tab-separated values.
    CSVLogger('log.tsv', separator='\t'),
]

In [None]:
# Finally, we start the training and output its final test
# loss and accuracy.

# Optimizer and loss function.
# Only parameters that still require gradients are handed to the optimizer:
# frozen weights are then neither updated nor rejected by PyTorch versions
# that refuse to optimize parameters with `requires_grad=False`. When no
# parameter is frozen this is the same as passing `resnet18.parameters()`.
optimizer = optim.SGD(
    [param for param in resnet18.parameters() if param.requires_grad],
    lr=learning_rate,
    weight_decay=0.001,
)
loss_function = nn.CrossEntropyLoss()

# Pytoune Model
model = Model(resnet18, optimizer, loss_function, metrics=['accuracy'])

# Send model on GPU (or keep it on the CPU, depending on `device`)
model.to(device)

# Train
model.fit_generator(train_loader, valid_loader, epochs=n_epoch, callbacks=callbacks)

# Test
test_loss, test_acc = model.evaluate_generator(test_loader)
print('Test:\n\tLoss: {}\n\tAccuracy: {}'.format(test_loss, test_acc))

In [None]:
# Read back the per-epoch log written during training and look up the epoch
# with the highest validation accuracy.
logs = pd.read_csv('log.tsv', sep='\t')
print(logs)

best_epoch_idx = logs['val_acc'].idxmax()
best_epoch = int(logs.loc[best_epoch_idx, 'epoch'])
print("Best epoch: %d" % best_epoch)

In [None]:
# Restore best model from checkpoint and test it.

# A fresh ResNet-18 with `num_classes` outputs and no pretrained weights;
# its weights are loaded from the checkpoint below.
resnet18 = models.resnet18(pretrained=False, num_classes=num_classes)

# No optimizer is given (None): this model is only evaluated in this cell.
model = Model(resnet18, None, nn.CrossEntropyLoss(), metrics=['accuracy'])

model.to(device)

# Load the checkpoint file whose name carries the epoch stored in `best_epoch`.
model.load_weights('best_epoch_{epoch}.ckpt'.format(epoch=best_epoch))

# Evaluate on the test loader and print the resulting loss and accuracy.
test_loss, test_acc = model.evaluate_generator(test_loader)
print('Test:\n\tLoss: {}\n\tAccuracy: {}'.format(test_loss, test_acc))