In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from skimage import io, transform
import numpy as np
import os
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, utils
import torch.optim as optim
from sklearn.preprocessing import LabelEncoder
import pandas as pd
from sklearn.model_selection import train_test_split
import pickle
import matplotlib.pyplot as plt
from datetime import datetime
from models import Classifier
import warnings
# NOTE(review): this suppresses *every* Python warning for the whole kernel
# session, including deprecation warnings from torch/torchvision; consider
# narrowing it (e.g. by category or message) to the specific noisy warning.
warnings.filterwarnings("ignore")
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

--- **Utils** ---

In [13]:
def create_model_path():
    """Build a timestamped checkpoint path of the form ``models/<DD-MM-YYYY-HH-MM-SS>.pt``.

    The file name is the current local time at one-second resolution, so two
    calls within the same second return the same path.
    """
    stamp = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
    model_dir = 'models'
    return os.path.join(model_dir, stamp) + '.pt'

In [12]:
def imshow(img):
    """Show a normalised image tensor with matplotlib.

    Undoes ``Normalize(mean=0.5, std=0.5)`` (``x / 2 + 0.5``) and moves the
    first axis last, since ``plt.imshow`` expects (H, W, C). Assumes ``img`` is
    a CPU tensor of shape (C, H, W) that does not require grad.
    """
    pixels = (img / 2 + 0.5).numpy()
    plt.imshow(pixels.transpose(1, 2, 0))
    plt.show()

--- **Data, model and training** ---

In [3]:
# Preprocessing pipeline: resize to 60x60 (the Classifier's first linear layer,
# 128 * 5 * 5 inputs, is sized for exactly this resolution), convert to a
# tensor, then normalise each RGB channel with mean 0.5 / std 0.5.
composed = transforms.Compose(
    [
        transforms.Resize((60, 60)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# One class per sub-directory of simpsons_dataset/ (ImageFolder layout).
main_set = ImageFolder(root='simpsons_dataset', transform=composed)

In [4]:
# Number of samples per mini-batch, shared by the training and validation loaders.
batch_size: int = 500

In [5]:
# Hold out a fixed fraction of the images for validation. The seeded generator
# makes the split reproducible across kernel restarts; without it the
# train/validation membership changes on every run.
VAL_FRACTION = 0.2
SPLIT_SEED = 42
val_size = int(len(main_set) * VAL_FRACTION)
train_size = len(main_set) - val_size
train_set, val_set = random_split(
    main_set,
    (train_size, val_size),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [6]:
# Training batches are reshuffled every epoch and loaded by 2 worker processes;
# validation batches are read in a fixed order in the main process.
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
val_loader = DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
)

In [7]:
class Classifier(nn.Module):
    """Small three-block CNN image classifier.

    Each conv block is conv -> ReLU -> BatchNorm -> 2x2 max-pool -> dropout(0.25).
    For a 60x60 input the spatial size goes 60 -> 56 -> 28 -> 26 -> 13 -> 11 -> 5,
    so the flattened feature vector has 128 * 5 * 5 entries; other input sizes
    raise a shape error in ``linear1``.

    Args:
        softmax: if True, ``forward`` returns class probabilities (softmax over
            dim 1) instead of raw logits. Leave False when training with
            ``CrossEntropyLoss``, which expects logits.
        num_classes: number of output classes. The default of 42 presumably
            matches the class folders of the dataset; verify with
            ``len(main_set.classes)``.
    """

    def __init__(self, softmax=False, num_classes=42):
        super().__init__()
        self.apply_softmax = softmax

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)

        self.batch_norm1 = nn.BatchNorm2d(32)
        self.batch_norm2 = nn.BatchNorm2d(64)
        self.batch_norm3 = nn.BatchNorm2d(128)

        self.pool = nn.MaxPool2d(2, 2)

        self.linear1 = nn.Linear(in_features=128 * 5 * 5, out_features=64)
        self.linear2 = nn.Linear(in_features=64, out_features=num_classes)

        self.dropout1 = nn.Dropout(p=.25)
        # NOTE(review): dropout2 is never applied in forward(); confirm whether
        # it was meant to follow linear1.
        self.dropout2 = nn.Dropout(p=.5)

    def forward(self, x):
        """Map an image batch to class scores.

        Args:
            x: image batch; the layer sizes assume shape (N, 3, 60, 60).

        Returns:
            (N, num_classes) logits, or probabilities when ``softmax=True``.
        """
        conv_blocks = (
            (self.conv1, self.batch_norm1),
            (self.conv2, self.batch_norm2),
            (self.conv3, self.batch_norm3),
        )
        for conv, batch_norm in conv_blocks:
            x = conv(x)
            x = F.relu(x)
            x = batch_norm(x)
            x = self.pool(x)
            x = self.dropout1(x)

        # Flatten per sample. Unlike view(-1, 128 * 5 * 5), this never folds a
        # mis-sized input into a different batch size; linear1 rejects it instead.
        x = x.flatten(start_dim=1)

        x = F.relu(self.linear1(x))

        # Output layer: raw logits unless probabilities were requested.
        x = self.linear2(x)
        if self.apply_softmax:
            return F.softmax(x, dim=1)
        return x

In [8]:
# Seed PyTorch's global RNG so weight initialisation, dropout masks and the
# training loader's shuffling are reproducible across kernel restarts.
# NOTE(review): `Classifier` here resolves to the class defined in the notebook,
# which shadows the `from models import Classifier` import; confirm which is intended.
torch.manual_seed(42)
classifier = Classifier()

In [9]:
# Training hyper-parameters.
epochs = 20
learning_rate = 3e-3

optimizer = optim.Adam(params=classifier.parameters(), lr=learning_rate)

# Separate (identically configured) loss objects for training and validation.
criterion = CrossEntropyLoss()
criterion_validation = CrossEntropyLoss()

In [10]:
# Train for `epochs` epochs, recording the mean per-sample training and
# validation loss of each epoch, then checkpoint the model.
losses = []
validation_losses = []


def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    Puts ``model`` in training mode first, so dropout is active and BatchNorm
    normalises with (and updates) batch statistics.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    for images, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # Assumes `criterion` returns the batch *mean* (reduction='mean', the
        # CrossEntropyLoss default); re-weight by batch size so a short final
        # batch does not skew the epoch average.
        loss_sum += loss.item() * images.size(0)
        n_seen += images.size(0)
    return loss_sum / max(n_seen, 1)


def evaluate(model, loader, criterion):
    """Return the mean per-sample loss of ``model`` over ``loader`` without training.

    Switches ``model`` to eval mode so dropout is disabled and BatchNorm uses,
    and does not update, its running statistics.
    """
    model.eval()
    loss_sum, n_seen = 0.0, 0
    with torch.no_grad():
        for images, labels in loader:
            loss = criterion(model(images), labels)
            loss_sum += loss.item() * images.size(0)
            n_seen += images.size(0)
    return loss_sum / max(n_seen, 1)


model_path = create_model_path()
os.makedirs(os.path.dirname(model_path), exist_ok=True)
try:
    for i in range(epochs):
        total_loss = train_one_epoch(classifier, train_loader, optimizer, criterion)
        valid_loss = evaluate(classifier, val_loader, criterion_validation)
        print('Epoch: {}, loss: {}, valid loss: {}'.format(i+1, total_loss, valid_loss))
        losses.append(total_loss)
        validation_losses.append(valid_loss)
finally:
    # Save even when training is interrupted (e.g. KeyboardInterrupt) so the
    # completed epochs are not lost; the exception still propagates afterwards.
    # The saved module keeps its current train/eval mode, so call .train() or
    # .eval() after loading it.
    torch.save(classifier, model_path)

KeyboardInterrupt: 

In [None]:
# Plot per-epoch training and validation loss. The x-values come from the
# recorded history rather than `epochs`, so the cell still works after an
# interrupted run. Epochs are numbered from 1, matching the training printout.
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(losses) + 1), losses, label='loss')
ax.plot(np.arange(1, len(validation_losses) + 1), validation_losses, label='valid loss')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('Training vs. validation loss')
ax.legend()
plt.show()