In [93]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader

from tqdm.autonotebook import tqdm, trange

import torchvision
from torchvision import models, datasets, transforms
import os

# Global seaborn styling for every plot drawn below: fonts scaled by 1.4
# and the 'whitegrid' style.
sns.set(font_scale=1.4, style='whitegrid')

In [2]:
# Pick the GPU only when CUDA is actually usable, otherwise fall back to the CPU.
# `torch.cuda.is_available` has to be *called*: the bare function object is
# always truthy, which would select 'cuda' even on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
# Per-channel mean/std normalisation shared by both splits (same values as before,
# defined once so the two pipelines cannot drift apart).
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# NOTE(review): 244 looks like a possible typo for the conventional 224 crop;
# it is kept unchanged so previously reported results stay comparable — confirm.
data_transforms = {
    # Training: random crop + horizontal flip as augmentation.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(244),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
    # Validation: deterministic resize + centre crop, no augmentation.
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(244),
        transforms.ToTensor(),
        normalize,
    ]),
}

data_dir = './hymenoptera_data'
# One ImageFolder per split, read from `<data_dir>/train` and `<data_dir>/val`.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']
}

# Shuffle only the training split: shuffling validation data does not change the
# metrics and makes evaluation/visualisation order non-reproducible.
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=4, shuffle=(x == 'train'), num_workers=3)
    for x in ['train', 'val']
}

# Number of samples per split; used to average loss and accuracy.
dataset_size = {x: len(image_datasets[x]) for x in ['train', 'val']}
# Class labels as exposed by the training ImageFolder.
class_name = image_datasets['train'].classes

In [50]:
# Show the class labels taken from the training dataset's `.classes`.
class_name

['ants', 'bees']

In [37]:
def fit(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Every epoch runs a training pass followed by a validation pass over the
    module-level ``dataloaders``; batches are moved to the module-level
    ``device`` and averages are taken over ``dataset_size``.

    Args:
        model: Network to train; it is updated in place.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
            Assumed to return the *mean* loss of the batch (the default for
            ``nn.CrossEntropyLoss``) so that the epoch loss is a per-sample average.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training pass.
        num_epochs: Number of train/val epochs to run.

    Returns:
        tuple: ``(losses, model)`` where ``losses`` maps ``'train'`` / ``'val'``
        to the list of per-epoch average losses and ``model`` carries the
        weights of the epoch with the highest validation accuracy.
    """
    def snapshot_weights():
        # state_dict() hands back references to the live tensors, so they must
        # be cloned; otherwise later optimizer steps overwrite the "best" copy.
        return {name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()}

    losses = {'train': [], 'val': []}

    best_model_wts = snapshot_weights()
    best_acc = 0.0

    pbar = trange(num_epochs, desc='Epoch:')
    for _ in pbar:

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # train(True) enables training behaviour; train(False) is eval().
            model.train(is_train)

            running_loss = 0.0
            running_corrects = 0

            for xb, yb in tqdm(dataloaders[phase], leave=True, desc=f'{phase} iter:'):
                xb, yb = xb.to(device), yb.to(device)

                if is_train:
                    optimizer.zero_grad()

                # Track gradients only while training; the validation pass needs
                # no autograd graph (the old `phase == 'eval'` check never fired
                # because the phase is called 'val').
                with torch.set_grad_enabled(is_train):
                    outputs = model(xb)
                    loss = criterion(outputs, yb)

                if is_train:
                    loss.backward()
                    optimizer.step()

                preds = torch.argmax(outputs, -1)

                # Weight the batch-mean loss by the batch size so that dividing
                # by the dataset size below yields a true per-sample average.
                running_loss += loss.item() * xb.size(0)
                running_corrects += int(torch.sum(preds == yb))

            if is_train:
                # Step the schedule after the epoch's optimizer updates, not before.
                scheduler.step()

            epoch_loss = running_loss / dataset_size[phase]
            epoch_acc = running_corrects / dataset_size[phase]

            losses[phase].append(epoch_loss)

            pbar.set_description(f'{phase} Loss: {epoch_loss}, Acc: {epoch_acc}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = snapshot_weights()

    model.load_state_dict(best_model_wts)
    return losses, model

In [84]:
def visualize_model(model, num_images=6):
    """Display validation images titled with the model's predicted class.

    Iterates over the module-level ``dataloaders['val']`` and draws each image
    through ``imshow`` until ``num_images`` images have been shown.

    Args:
        model: Classifier whose outputs are indexed by class along dim 1.
        num_images: Maximum number of images to show; values <= 0 show nothing.

    Returns:
        None. The model's train/eval mode is restored before returning.
    """
    if num_images <= 0:
        return

    # Predict in eval mode, then put the model back the way it was found.
    was_training = model.training
    model.eval()
    images_so_far = 0

    try:
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for inp, _ in dataloaders['val']:
                inp = inp.to(device)

                outputs = model(inp)
                _, preds = torch.max(outputs, 1)

                for j in range(inp.size(0)):
                    images_so_far += 1
                    # imshow converts to NumPy, so hand it a CPU tensor.
                    imshow(inp[j].cpu(), class_name[int(preds[j])])

                    if images_so_far >= num_images:
                        return
    finally:
        model.train(was_training)

In [14]:
def evaluate(model):
    """Return the model's accuracy on the validation split.

    Puts ``model`` into eval mode, runs it over the module-level
    ``dataloaders['val']`` (batches moved to ``device``) and counts predictions
    that match the labels.

    Args:
        model: Classifier whose outputs are indexed by class along dim 1.

    Returns:
        float: Number of correct predictions divided by ``dataset_size['val']``.
    """
    model.eval()

    running_correct = 0

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for xb, yb in dataloaders['val']:
            xb, yb = xb.to(device), yb.to(device)

            outputs = model(xb)
            _, preds = torch.max(outputs, 1)

            running_correct += int(torch.sum(preds == yb))

    return running_correct / dataset_size['val']

In [91]:
def imshow(inp, title=None):
    """Draw a normalised image tensor in a new 15x12 figure.

    The tensor is converted to NumPy, its first axis is moved to the end, the
    per-channel normalisation is reverted and values are clipped to [0, 1]
    before plotting. ``title`` is used as the plot title when given.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # Move the leading axis last, then undo the normalisation.
    pixels = np.transpose(inp.numpy(), (1, 2, 0))
    pixels = np.clip(channel_std * pixels + channel_mean, 0, 1)

    plt.figure(figsize=(15, 12))
    plt.axis('off')
    plt.imshow(pixels)
    if title is not None:
        plt.title(title)
    # Short pause so the figure gets a chance to be drawn.
    plt.pause(0.001)

In [64]:
# Build VGG16 with pretrained weights (fetched to the local torch hub cache
# on first use, as the download log below shows).
# NOTE(review): newer torchvision releases reportedly prefer a `weights=`
# argument over `pretrained=True` — confirm against the installed version.
model_vgg_extractor = models.vgg16(pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\Pack/.cache\torch\hub\checkpoints\vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

In [65]:
# Print the module tree to inspect the layers before modifying the network.
model_vgg_extractor

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [66]:
# Freeze the `features` sub-module: none of its parameters will receive
# gradients, so only the layers outside it can be trained.
for feature_weight in model_vgg_extractor.features.parameters():
    feature_weight.requires_grad = False

In [68]:
# Size of the flattened tensor fed to the classifier.
# NOTE(review): presumably 512 channels * 7 * 7 from VGG's adaptive pooling —
# confirm against the installed torchvision VGG definition.
num_features = 25088

# Replace the whole classifier with one linear layer sized to the dataset's
# classes (derived from `class_name` instead of a hard-coded 2).
model_vgg_extractor.classifier = nn.Linear(num_features, len(class_name))
model_vgg_extractor = model_vgg_extractor.to(device)

loss_fn = nn.CrossEntropyLoss()
# Only the new classifier's parameters are handed to the optimizer.
optimizer = optim.Adam(model_vgg_extractor.classifier.parameters(), lr=0.001)
# Multiply the learning rate by 0.2 every 5 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)

In [None]:
# Train with the default number of epochs; `fit` returns the per-epoch losses
# and the model, which is rebound to the same name here.
losses, model_vgg_extractor = fit(model_vgg_extractor, loss_fn, optimizer, exp_lr_scheduler)

In [70]:
# Accuracy of the trained model on the validation split.
f'Accuracy: {evaluate(model_vgg_extractor)}'

'Accuracy: 0.934640522875817'

In [71]:
# Persist only the weights (state_dict), then reload them as a round-trip check.
checkpoint_path = 'VGG16_extractor.pth'
torch.save(model_vgg_extractor.state_dict(), checkpoint_path)
# map_location lets a checkpoint saved from the GPU load on a CPU-only
# machine as well, by remapping the tensors to the active `device`.
model_vgg_extractor.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [None]:
# Show validation images titled with the model's predicted class
# (default: up to 6 images).
visualize_model(model_vgg_extractor)