In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Root folder containing the `train/` and `test/` ImageFolder directories.
path = "/home/aldw/Repositories/setsolver-card-classifier/images"

# Per-channel normalization statistics (computed offline by a separate script).
mean = [0.6, 0.6, 0.6]
std = [0.2, 0.2, 0.2]

# Fixed seed so the train/validation split is identical after Restart & Run All.
SPLIT_SEED = 0

# Random augmentation is applied to training samples only.
train_transforms = transforms.Compose([transforms.RandomPerspective(distortion_scale=0.2),
                                       transforms.RandomRotation(25),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean, std)])

# Deterministic pipeline used for validation and test samples.
test_transforms = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize(mean, std)])

# Two views of the same training folder: one augmented (for training) and one
# deterministic (for validation). Splitting a single augmented dataset would
# make the validation metrics depend on random perspective/rotation draws.
train_valid_data = datasets.ImageFolder(root=f"{path}/train", transform=train_transforms)
valid_view_data = datasets.ImageFolder(root=f"{path}/train", transform=test_transforms)
test_data = datasets.ImageFolder(root=f"{path}/test", transform=test_transforms)
assert len(train_valid_data) == len(valid_view_data), "train folder changed between scans"

# 80/20 split over one shared, seeded permutation so the two subsets are disjoint.
train_size = int(np.floor(0.8 * len(train_valid_data)))
valid_size = len(train_valid_data) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(len(train_valid_data), generator=split_generator).tolist()
train_data = torch.utils.data.Subset(train_valid_data, shuffled_indices[:train_size])
valid_data = torch.utils.data.Subset(valid_view_data, shuffled_indices[train_size:])
assert len(train_data) + len(valid_data) == len(train_valid_data)

# Create dataloaders. Validation order does not affect the metrics, so it is
# not shuffled; the test loader stays shuffled so previews show varied samples.
trainloader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True)
validloader = torch.utils.data.DataLoader(valid_data, batch_size=1, shuffle=False)
testloader = torch.utils.data.DataLoader(test_data, batch_size=100, shuffle=True)

In [None]:
# Some helper functions
def show_image(img, label="", norm_mean=None, norm_std=None):
    """Print `label` and display a normalized image tensor with matplotlib.

    Args:
        img: Image tensor in channel-first layout (C, H, W), as produced by
            `ToTensor()` followed by `Normalize(mean, std)`.
        label: Value printed above the image.
        norm_mean: Per-channel mean used for normalization. Defaults to the
            notebook-level `mean`.
        norm_std: Per-channel std used for normalization. Defaults to the
            notebook-level `std`.
    """
    print(label)
    # Fall back to the statistics the notebook's transforms actually used.
    norm_mean = mean if norm_mean is None else norm_mean
    norm_std = std if norm_std is None else norm_std

    img = img.detach().cpu()
    channel_mean = torch.as_tensor(norm_mean, dtype=img.dtype).view(-1, 1, 1)
    channel_std = torch.as_tensor(norm_std, dtype=img.dtype).view(-1, 1, 1)
    # Invert Normalize: x_norm = (x - mean) / std  =>  x = x_norm * std + mean.
    # Clamp because imshow expects float RGB values in [0, 1].
    img = (img * channel_std + channel_mean).clamp(0, 1)

    # matplotlib expects channel-last (H, W, C).
    plt.imshow(img.permute(1, 2, 0).numpy())
    plt.show()
    
def show(img_loader, n_images):
    """Display up to `n_images` samples from the first batch of `img_loader`.

    Args:
        img_loader: Iterable yielding `(images, labels)` batches.
        n_images: Maximum number of samples to display. If the first batch
            holds fewer samples, only those are shown instead of raising.
    """
    images, labels = next(iter(img_loader))
    # Slice by the requested count; slicing never overruns the batch.
    for image, label in zip(images[:n_images], labels[:n_images]):
        show_image(image, label)

In [None]:
# Preview two samples from the training loader, then two from the test loader.
for preview_loader in (trainloader, testloader):
    show(preview_loader, 2)

In [None]:
class SetCardClassifier(nn.Module):
    """CNN that maps a card image to log-probabilities over card classes.

    Five conv + ReLU + 2x2 max-pool stages are followed by three fully
    connected layers with dropout. The first linear layer is sized for a
    3x6x256 feature map, which is what a 3x160x250 input produces.

    Args:
        num_classes: Number of output classes (default 81).
    """

    def __init__(self, num_classes=81):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        # 5x5 kernels with padding=1 shrink each spatial dimension by 2.
        self.conv4 = nn.Conv2d(64, 128, 5, padding=1)
        self.conv5 = nn.Conv2d(128, 256, 5, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(3*6*256, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return log-probabilities of shape (batch, num_classes).

        Args:
            x: Image batch of shape (batch, 3, 160, 250).

        Raises:
            RuntimeError: If the input size does not yield a 3x6x256 feature
                map (shape mismatch in the first linear layer).
        """
        x = self.pool(F.relu(self.conv1(x))) # start: 160x250x3, after: 80x125x16
        x = self.pool(F.relu(self.conv2(x))) # after: 40x62x32
        x = self.pool(F.relu(self.conv3(x))) # after: 20x31x64
        x = self.pool(F.relu(self.conv4(x))) # after: 9x14x128
        x = self.pool(F.relu(self.conv5(x))) # after: 3x6x256
        # Flatten per sample and keep the batch dimension fixed. With
        # view(-1, 3*6*256) an unexpected input size could silently merge
        # several samples into one row instead of failing.
        x = torch.flatten(x, start_dim=1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = F.log_softmax(self.fc3(x), dim=1)
        return x

In [None]:
# Network to train; its forward pass already returns log-probabilities.
model = SetCardClassifier()
# Negative log-likelihood loss, matching the log_softmax model output.
criterion = nn.NLLLoss()
# Adam over all model parameters with a learning rate of 3e-4.
optimizer = optim.Adam(params=model.parameters(), lr=3e-4)

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f'training on device {device}')

epochs = 1500
# float("inf") instead of np.Inf: the np.Inf alias was removed in NumPy 2.0.
min_valid_loss = float("inf") # track change in validation loss
saved_train_loss = float("inf") # track the training loss when saving a model (in case the valid loss is equal)

for epoch in range(epochs):
    #######################
    # Training            #
    #######################
    # Set train mode at the start of every epoch so dropout is active even if
    # the cell is re-run after an interruption during validation.
    model.train()
    train_loss = 0.0
    train_seen = 0
    for inputs, labels in trainloader:
        # Move input and label tensors to the default device
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        logps = model(inputs)
        loss = criterion(logps, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight by batch size so the epoch
        # average is per sample even when the last batch is smaller.
        batch_size = labels.size(0)
        train_loss += loss.item() * batch_size
        train_seen += batch_size

    #######################
    # Validation          #
    #######################
    valid_loss = 0.0
    valid_correct = 0
    valid_seen = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in validloader:
            inputs, labels = inputs.to(device), labels.to(device)
            logps = model(inputs)

            batch_size = labels.size(0)
            valid_loss += criterion(logps, labels).item() * batch_size

            # The most likely class is the argmax of the log-probabilities.
            valid_correct += (logps.argmax(dim=1) == labels).sum().item()
            valid_seen += batch_size

    #######################
    # Logging             #
    #######################
    avg_train_loss = train_loss / train_seen
    avg_valid_loss = valid_loss / valid_seen
    avg_valid_acc = valid_correct / valid_seen
    print(f"Epoch {epoch+1:03d}/{epochs:03d}.. "
          f"Train loss: {avg_train_loss:.3f}.. "
          f"Valid loss: {avg_valid_loss:.3f}.. "
          f"Valid accuracy: {avg_valid_acc:.3f}")

    #######################
    # Saving best model   #
    #######################
    # Keep the checkpoint with the lowest validation loss; break ties by the
    # lower training loss.
    if (avg_valid_loss < min_valid_loss) or (avg_valid_loss == min_valid_loss and avg_train_loss < saved_train_loss):
        print("found a better model, saving...")
        torch.save(model.state_dict(), 'model.pt')
        min_valid_loss = avg_valid_loss
        saved_train_loss = avg_train_loss
    

In [None]:
# Load the best checkpoint and measure its accuracy on the held-out test set.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
state_dict = torch.load('model.pt', map_location=device)
best_model = SetCardClassifier()
best_model.load_state_dict(state_dict)
best_model.to(device)

test_correct = 0
test_seen = 0

best_model.eval()
with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        logps = best_model(inputs)

        # Count correct predictions per sample. Averaging per-batch accuracies
        # would over-weight a smaller final batch.
        test_correct += (logps.argmax(dim=1) == labels).sum().item()
        test_seen += labels.size(0)

test_accuracy = test_correct / test_seen
print(f"Test accuracy: {test_accuracy:.3f}")