In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import glob
import cv2
from torchvision.transforms import ToTensor, Resize
import random
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class MyDataset(Dataset):
    def __init__(self, train):
        data_path = "dataset"
        self.data_list = glob.glob(data_path + "/*/*.png")
        random.shuffle(self.data_list)
        if train:
            self.data_list = self.data_list[:int(len(self.data_list) * 0.8)]
        else:
            self.data_list = self.data_list[int(len(self.data_list) * 0.8):]

    def __getitem__(self, index):
        img = cv2.imread(self.data_list[index])
        img = cv2.resize(img, (64, 64))
        img = ToTensor()(img)
        label = self.data_list[index].split("/")[-1].split("\\")[-2]
        return img, int(label)

    def __len__(self):
        return len(self.data_list)

ds_train = MyDataset(train=True)
ds_val = MyDataset(train=False)

In [3]:
dl_train = DataLoader(ds_train, batch_size = 32, shuffle = True)
dl_val = DataLoader(ds_val, batch_size = 32, shuffle = True)

In [4]:
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)

        self.pool = nn.MaxPool2d(2, 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.pool4 = nn.MaxPool2d(2, 2)
        
        self.dropout = nn.Dropout(0.25)

        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 2)
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = self.pool(out)

        out = self.conv2(out)
        out = F.relu(out)
        out = self.pool2(out)

        out = self.conv3(out)
        out = F.relu(out)
        out = self.pool3(out)

        out = self.conv4(out)
        out = F.relu(out)
        out = self.pool4(out)

        out = out.view(out.size(0), -1)

        out = self.dropout(out)
        out = self.fc1(out)

        out = F.relu(out)

        out = self.dropout(out)
        out = self.fc2(out)
        
        return out
    
model = MyModel()

In [5]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

In [6]:
def correction(pred, label):
    correct = 0
    with torch.no_grad():
        for data, label in dl_val:
            pred = model(data)
            pred = torch.argmax(pred, dim = 1)
            correct += torch.sum(pred == label).item()
    return correct / len(dl_val.dataset)

In [7]:
epochs = 30
for epoch in range(epochs):
    for data, label in dl_train:
        pred = model(data)
        loss = loss_fn(pred, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print("Epoch: {}".format(epoch))
    print("Train Loss: {}".format(loss.item()))
    print("Validation Accuracy: {}".format(correction(pred, label)))

Epoch: 0
Train Loss: 0.7015120387077332
Validation Accuracy: 0.5
Epoch: 1
Train Loss: 0.6604321002960205
Validation Accuracy: 0.75
Epoch: 2
Train Loss: 0.7293635606765747
Validation Accuracy: 0.7386363636363636
Epoch: 3
Train Loss: 0.35158881545066833
Validation Accuracy: 0.8863636363636364
Epoch: 4
Train Loss: 0.4426625072956085
Validation Accuracy: 0.9318181818181818
Epoch: 5
Train Loss: 0.5439110994338989
Validation Accuracy: 0.9204545454545454
Epoch: 6
Train Loss: 0.26401054859161377
Validation Accuracy: 0.9431818181818182
Epoch: 7
Train Loss: 0.16540445387363434
Validation Accuracy: 0.9431818181818182
Epoch: 8
Train Loss: 0.11339705437421799
Validation Accuracy: 0.9659090909090909
Epoch: 9
Train Loss: 0.1407967209815979
Validation Accuracy: 0.9545454545454546
Epoch: 10
Train Loss: 0.10916218906641006
Validation Accuracy: 0.9545454545454546
Epoch: 11
Train Loss: 0.051885705441236496
Validation Accuracy: 0.9886363636363636
Epoch: 12
Train Loss: 0.18683941662311554
Validation Accurac

In [16]:
torch.save(model.state_dict(), 'model_weights3.pth')