In [1]:
from nnet import Net
from torchvision import transforms
from torchvision.datasets import ImageFolder
import torch
from torch.utils.data import DataLoader,random_split
import torch.nn as nn
import torch.optim as optim
import numpy as np
import cv2
from PIL import Image
from sklearn.metrics import f1_score 
from utils import TRANSFORM,NUM_CHANNELS

In [2]:
# Training hyperparameters. Everything a reader might want to tune lives here.
EPOCHS = 10       # number of full passes over the training loader
BATCH_SIZE = 16   # mini-batch size handed to the data loaders
LR = 0.01         # learning rate passed to optim.SGD
MOMENTUM = 0.9    # momentum passed to optim.SGD
SEED = 42         # fixed seed so that Restart & Run All is reproducible

# Seed the global RNGs before the model is built and the dataset is split:
# weight initialisation, random_split and DataLoader shuffling all draw from
# torch's default generator. NumPy is seeded as well in case the transforms
# imported from utils rely on it (not visible from this notebook).
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
def load_dataset_from_folder(all_data_path = '../data/Generic',validation_split_size= 0.1,batch_size = 16, num_workers = 2,shuffle = True, seed = None):
    """Load an image-folder dataset and split it into training/validation loaders.

    Args:
        all_data_path: Root directory handed to ``ImageFolder``.
        validation_split_size: Fraction of the samples held out for
            validation; must lie in ``[0, 1]``.
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes used by both loaders.
        shuffle: Whether the *training* loader reshuffles every epoch.
        seed: Optional integer. When given, the train/validation partition is
            drawn from a dedicated generator seeded with this value, so the
            split is reproducible. When ``None`` the global torch RNG is used.

    Returns:
        ``(training_data_loader, validation_dataset_loader, classes)`` where
        ``classes`` is the ``classes`` attribute of the ``ImageFolder``.

    Raises:
        ValueError: If ``validation_split_size`` is outside ``[0, 1]``.
    """
    # Reject nonsensical fractions up front; otherwise the split sizes below
    # would silently become negative or exceed the dataset length.
    if not 0.0 <= validation_split_size <= 1.0:
        raise ValueError(
            'validation_split_size must be between 0 and 1, got %r'
            % (validation_split_size,)
        )

    all_data = ImageFolder(
        root = all_data_path,
        transform = TRANSFORM
    )

    classes = all_data.classes

    # The validation part is rounded down; the training part gets the rest,
    # so the two sizes always add up to len(all_data).
    validation_size = int(validation_split_size * len(all_data))
    train_size = len(all_data) - validation_size

    # Only pass a generator when a seed was requested, so the default call is
    # exactly the original random_split(dataset, lengths) invocation.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_dataset, validation_dataset = random_split(
        all_data, [train_size, validation_size], **split_kwargs
    )

    training_data_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle
    )

    # Evaluation does not benefit from shuffling; a fixed order keeps
    # validation runs comparable with each other.
    validation_dataset_loader = DataLoader(
        validation_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False
    )

    return training_data_loader,validation_dataset_loader,classes

In [4]:
# Build the model; NUM_CHANNELS is imported from utils.
net = Net(NUM_CHANNELS)
# NOTE: `testloader` is the validation split returned by
# load_dataset_from_folder, not a separate test set.
trainloader,testloader,classes = load_dataset_from_folder(batch_size = BATCH_SIZE)
# net.save(classes)
# Cross-entropy loss on the network outputs, optimised with SGD + momentum
# using the LR and MOMENTUM constants from the configuration cell.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=MOMENTUM)

In [5]:
# Train for EPOCHS passes over `trainloader`. The mean loss of the last ten
# batches is printed every tenth batch, and the model is saved (together with
# the class names) at the end of every epoch.
for epoch in range(EPOCHS):
    running_loss = 0.0

    # `step` is the 1-based batch number within the current epoch.
    for step, (inputs, labels) in enumerate(trainloader, start=1):
        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if step % 10 != 0:
            continue

        # Report the average over the ten batches seen since the last report.
        print('[%d, %5d] loss: %.3f' %
              (epoch + 1, step, running_loss / 10))
        running_loss = 0.0

    net.save(classes)
    print('Model saved')

[1,    10] loss: 1.392
[1,    20] loss: 1.386
[1,    30] loss: 1.386
[1,    40] loss: 1.386
[1,    50] loss: 1.381
[1,    60] loss: 1.351
[1,    70] loss: 1.343
[1,    80] loss: 1.255
[1,    90] loss: 1.313
[1,   100] loss: 1.318
[2,    10] loss: 1.292
[2,    20] loss: 1.247
[2,    30] loss: 1.187
[2,    40] loss: 1.150
[2,    50] loss: 1.106
[2,    60] loss: 0.939
[2,    70] loss: 0.942
[2,    80] loss: 0.846
[2,    90] loss: 0.759
[2,   100] loss: 0.821
[3,    10] loss: 0.571
[3,    20] loss: 0.765
[3,    30] loss: 0.531
[3,    40] loss: 0.501
[3,    50] loss: 0.453
[3,    60] loss: 0.264
[3,    70] loss: 0.263
[3,    80] loss: 0.322
[3,    90] loss: 0.239
[3,   100] loss: 0.227
[4,    10] loss: 0.260
[4,    20] loss: 0.105
[4,    30] loss: 0.317
[4,    40] loss: 0.168
[4,    50] loss: 0.148
[4,    60] loss: 0.120


KeyboardInterrupt: 

In [6]:
# Per-class accuracy and weighted F1 score on the held-out loader.
# The number of classes is taken from `classes` instead of being hard-coded.
num_classes = len(classes)
class_correct = [0.0] * num_classes
class_total = [0.0] * num_classes

# Plain Python ints (not 0-dim tensors) so they can be used as list indices
# and passed straight to f1_score.
y_true = []
y_pred = []

# Evaluate in eval mode and restore the previous mode afterwards, so that
# re-running the training cell later is unaffected.
# NOTE(review): assumes Net is an nn.Module (it already exposes parameters()).
was_training = net.training
net.eval()
try:
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            # Index of the highest score along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)

            batch_true = labels.tolist()
            batch_pred = predicted.tolist()
            y_true.extend(batch_true)
            y_pred.extend(batch_pred)

            # Compare element-wise on the Python lists. The original called
            # .squeeze() on the comparison tensor, which turns a final batch
            # of size 1 into a 0-dim tensor and makes indexing it fail.
            for true_label, pred_label in zip(batch_true, batch_pred):
                class_correct[true_label] += float(true_label == pred_label)
                class_total[true_label] += 1
finally:
    net.train(was_training)


for i in range(num_classes):
    if class_total[i] == 0:
        # A small validation split may contain no sample of some class.
        print('Accuracy of %5s : n/a (no samples)' % classes[i])
        continue
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

print(f1_score(y_true, y_pred, average='weighted'))

torch.Size([16, 5, 50, 50])
torch.Size([16, 5, 50, 50])
torch.Size([16, 5, 50, 50])
torch.Size([16, 5, 50, 50])
torch.Size([16, 5, 50, 50])
torch.Size([16, 5, 50, 50])
torch.Size([16, 5, 50, 50])
torch.Size([16, 5, 50, 50])
torch.Size([16, 5, 50, 50])
torch.Size([16, 5, 50, 50])
torch.Size([16, 5, 50, 50])
torch.Size([14, 5, 50, 50])
Accuracy of  Next : 98 %
Accuracy of Others : 87 %
Accuracy of Pause : 88 %
Accuracy of  Prev : 97 %
0.927645220646553
