In [None]:
import numpy as np
import pandas as pd 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torchvision.models import ResNet18_Weights
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler, random_split
import matplotlib.pyplot as plt
from PIL import Image
from collections import deque
from pathlib import Path
import logging, os, glob
from _logging import set_logging
from _metrics import display_metrics
from _pckle import save_pickle_object, load_pickle_object
from _utility import gl, get_perc, get_dictionaries_from_list
from _model import train_model

set_logging(logging)

# Seed the RNGs used below so that Restart & Run All is reproducible: the new
# classifier head is randomly initialised, and any DataLoader shuffling or
# sampling that draws from the global torch RNG is affected as well.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classes = ["Business", "Other"]
# Class-name <-> index lookup tables built by the project helper.
dict_classes, dict_classes_rev = get_dictionaries_from_list(classes)


In [None]:
# Restore the train / validation / test DataLoaders persisted by an earlier
# notebook, loading them in that order.
# NOTE(review): load_pickle_object presumably unpickles these files; only load
# files from a trusted source, since unpickling can execute arbitrary code.
train_loader, val_loader, test_loader = (
    load_pickle_object(pickle_path)
    for pickle_path in (gl.pkl_train_loader, gl.pkl_val_loader, gl.pkl_test_loader)
)

In [None]:
def imshow(inp, _mean, _std, title=None):
    """Display a normalised image tensor with matplotlib.

    Args:
        inp: image tensor laid out as (channels, height, width). It may live on
            any device and may require grad.
        _mean: per-channel mean that was subtracted when normalising ``inp``.
        _std: per-channel std that ``inp`` was divided by when normalising.
        title: optional title for the current axes.
    """
    # detach()/cpu() make .numpy() safe for CUDA tensors and for tensors that
    # require grad. The transpose turns CHW into the HWC layout matplotlib expects.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Undo the normalisation; mean/std broadcast over the trailing channel axis.
    img = np.array(_std) * img + np.array(_mean)
    img = np.clip(img, 0, 1)  # keep pixel values in matplotlib's valid float range
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

In [None]:
def visualize_model(model, dataloaders, classes, num_images=6,
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot validation images annotated with the model's predicted class.

    Args:
        model: classifier whose output has one score per class along dim 1.
        dataloaders: mapping with a 'val' loader yielding (inputs, labels) batches.
        classes: sequence mapping a predicted class index to its display name.
        num_images: maximum number of images to plot, laid out in 2 columns.
        mean, std: per-channel normalisation statistics, passed to ``imshow`` to
            undo the input transform. The defaults are the ImageNet statistics used
            by torchvision's pretrained weights. NOTE(review): confirm they match
            the transform that built the pickled loaders.

    The model's original train/eval mode is restored before returning, even if
    an exception is raised.
    """
    if num_images < 1:
        return
    was_training = model.training
    model.eval()
    # Use the model's own device rather than relying on a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')
    n_rows = (num_images + 1) // 2  # ceil, so odd num_images still fits the grid
    images_so_far = 0
    plt.figure()
    try:
        with torch.no_grad():
            for inputs, _labels in dataloaders['val']:
                inputs = inputs.to(run_device)
                _, preds = torch.max(model(inputs), 1)
                cpu_inputs = inputs.cpu()  # copy the batch to host once, not per image

                for j in range(cpu_inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(n_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title(f'predicted: {classes[preds[j].item()]}')
                    imshow(cpu_inputs[j], mean, std)
                    if images_so_far == num_images:
                        return
    finally:
        model.train(mode=was_training)

In [None]:
dataloaders = {"train": train_loader, "val": val_loader}
dataset_sizes = {phase: len(loader.dataset) for phase, loader in dataloaders.items()}

# Use a pretrained ResNet-18 as a fixed feature extractor: freeze every
# backbone parameter.
model_conv = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
for param in model_conv.parameters():
    param.requires_grad = False
# Replace the classifier head. Parameters of newly constructed modules have
# requires_grad=True by default, so only this layer will be trained.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, len(classes))  # one output per class in `classes`
model_conv = model_conv.to(device)
criterion = nn.CrossEntropyLoss()
# Only parameters of final layer are being optimized
optimizer_ft = optim.Adam(model_conv.fc.parameters(), lr=0.0001)
# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [None]:
NUM_EPOCHS = 25  # total training epochs passed to train_model

model_conv = train_model(
    model_conv, criterion, optimizer_ft, exp_lr_scheduler,
    dataloaders, dataset_sizes, num_epochs=NUM_EPOCHS,
)
# Persist the returned model to the project's configured pickle path.
save_pickle_object(model_conv, gl.pkl_model_conv)

In [None]:
# Plot a grid of validation predictions, then turn interactive mode off and
# show the final figure.
visualize_model(model_conv, dataloaders, classes, num_images=6)

plt.ioff()
plt.show()