In [2]:
import os
from urllib import request
from zipfile import ZipFile

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, models, transforms

from skorch import NeuralNetClassifier
from skorch.helper import predefined_split

# Seed PyTorch's default random number generator with a fixed value.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy are not —
# confirm nothing downstream draws randomness from them.
torch.manual_seed(360);

<torch._C.Generator at 0x7f0d4867d710>

In [3]:
def download_and_extract_data(dataset_dir='datasets'):
    """Download the hymenoptera archive and unpack it into ``dataset_dir``.

    The work is skipped step by step: nothing is done if
    ``<dataset_dir>/hymenoptera_data`` already exists, and the download is
    skipped if ``<dataset_dir>/hymenoptera_data.zip`` is already present.

    Parameters
    ----------
    dataset_dir : str
        Directory that receives both the zip file and the extracted
        ``hymenoptera_data`` folder. It is created if it does not exist.

    Raises
    ------
    urllib.error.URLError
        If the archive cannot be fetched.
    zipfile.BadZipFile
        If the archive on disk is not a valid zip file.
    """
    data_zip = os.path.join(dataset_dir, 'hymenoptera_data.zip')
    data_path = os.path.join(dataset_dir, 'hymenoptera_data')
    url = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"

    if not os.path.exists(data_path):
        # On a fresh checkout the target directory may not exist yet; without
        # this, opening the zip for writing fails with FileNotFoundError.
        os.makedirs(dataset_dir or '.', exist_ok=True)

        if not os.path.exists(data_zip):
            print("Starting to download data...")
            # Stream into a temporary name and rename only on success, so an
            # interrupted download never leaves a truncated zip that a later
            # run would mistake for a complete one.
            partial_zip = data_zip + '.part'
            try:
                with request.urlopen(url, timeout=15) as response, \
                        open(partial_zip, 'wb') as f:
                    for chunk in iter(lambda: response.read(1 << 20), b''):
                        f.write(chunk)
                os.replace(partial_zip, data_zip)
            finally:
                # Only still present if the download or rename failed.
                if os.path.exists(partial_zip):
                    os.remove(partial_zip)

        print("Starting to extract data...")
        with ZipFile(data_zip, 'r') as zip_f:
            zip_f.extractall(dataset_dir)

    print("Data has been downloaded and extracted to {}.".format(dataset_dir))
    
# Fetch the dataset into ../data/raw/; the function skips the download and
# extraction when the extracted folder is already there.
download_and_extract_data('../data/raw/')

Starting to download data...
Starting to extract data...
Data has been downloaded and extracted to ../data/raw/.


In [4]:
# Root of the extracted dataset; the 'train' and 'val' sub-folders are read below.
data_dir = '../data/raw/hymenoptera_data'

# Per-channel mean / std used by both pipelines.
# NOTE(review): presumably the ImageNet statistics the pretrained ResNet expects — confirm.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Training pipeline: random 224px crop and random horizontal flip, then
# tensor conversion and normalisation.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Validation pipeline: fixed resize to 256 followed by a 224px centre crop,
# then the same tensor conversion and normalisation.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# One ImageFolder dataset per split, each with its own transform pipeline.
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'val')
train_ds = datasets.ImageFolder(train_dir, train_transforms)
val_ds = datasets.ImageFolder(val_dir, val_transforms)

In [5]:
class PretrainedModel(nn.Module):
    """Pretrained ResNet-18 whose final ``fc`` layer is replaced by a new linear layer.

    The wrapped network is stored as ``self.model``, so its parameters are
    named ``model.<...>`` (the new layer's parameters start with ``model.fc``).
    """

    def __init__(self, output_features):
        """Build the network.

        Parameters
        ----------
        output_features : int
            Number of outputs of the replacement ``fc`` layer.
        """
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Swap the original head for a fresh linear layer with the same input width.
        backbone.fc = nn.Linear(backbone.fc.in_features, output_features)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet and return its output."""
        output = self.model(x)
        return output

In [6]:
from skorch.callbacks import LRScheduler

# Learning-rate schedule callback: StepLR policy with step_size=7 and gamma=0.1
# (the learning rate is multiplied by 0.1 after every 7 scheduler steps).
lrscheduler = LRScheduler(policy='StepLR', step_size=7, gamma=0.1)

In [7]:
from skorch.callbacks import Checkpoint

# Checkpoint callback: writes the module parameters to 'best_model.pt' on epochs
# where the monitored 'valid_acc_best' flag is set (a new best validation accuracy).
checkpoint = Checkpoint(
    f_params='best_model.pt', monitor='valid_acc_best')

In [8]:
from skorch.callbacks import Freezer

# Freeze every parameter whose name does not start with 'model.fc', leaving only
# the replaced final layer of PretrainedModel trainable.
freezer = Freezer(lambda x: not x.startswith('model.fc'))

In [9]:
# skorch classifier around PretrainedModel: SGD with momentum, cross-entropy
# loss, the predefined validation split, and the three callbacks defined above.
net = NeuralNetClassifier(
    PretrainedModel,
    criterion=nn.CrossEntropyLoss,
    lr=0.001,
    batch_size=4,
    max_epochs=25,
    module__output_features=2,
    optimizer=optim.SGD,
    optimizer__momentum=0.9,
    iterator_train__shuffle=True,
    iterator_train__num_workers=4,
    # Do not shuffle the validation iterator: skorch also uses it for
    # inference, so shuffling would return predictions in an order that no
    # longer lines up with the dataset, and it adds nothing to validation.
    iterator_valid__shuffle=False,
    iterator_valid__num_workers=4,
    train_split=predefined_split(val_ds),
    callbacks=[lrscheduler, checkpoint, freezer],
    # Use the GPU when one is available, otherwise fall back to the CPU
    # instead of failing on machines without CUDA.
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

In [10]:
# Train on the training dataset. y=None is passed explicitly —
# presumably because train_ds yields (image, label) pairs itself; confirm.
net.fit(train_ds, y=None);

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /home/matheusguim/.cache/torch/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:09<00:00, 4.81MB/s]
  epoch    train_loss    valid_acc    valid_loss    cp      dur
-------  ------------  -----------  ------------  ----  -------
      1        [36m0.6321[0m       [32m0.8954[0m        [35m0.2665[0m     +  10.7125
      2        [36m0.5036[0m       0.8497        0.3550        7.7381
      3        [36m0.4133[0m       [32m0.9281[0m        [35m0.2334[0m     +  7.7809
      4        [36m0.3411[0m       [32m0.9412[0m        [35m0.1871[0m     +  7.9085
      5        0.4514       0.8627        0.3128        7.6515
      6        0.4206       0.9020        0.2500        7.6331
      7        0.4306       [32m0.9542[0m        [35m0.1625[0m     +  7.7217
      8        0.5042       0.9412        0.2440        7.1988
      9        [36m0.3114[0m       0.9412        0.2066        