In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchsummary import summary

In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Fix every RNG PyTorch uses so training runs are reproducible.
SEED = 777
torch.manual_seed(SEED)
if device == 'cuda':
    torch.cuda.manual_seed_all(SEED)

In [3]:
# Resize explicitly: CNN's first Linear layer is sized for 64x128 inputs
# (16 * 13 * 29 features after the conv stages), and the test pipeline resizes
# to the same (64, 128). Without this, any training image of another size
# crashes at the Linear layer.
trans = transforms.Compose([
    transforms.Resize((64, 128)),
    transforms.ToTensor()
])

train_data = torchvision.datasets.ImageFolder(root='./custom_data/train_data', transform=trans)

In [4]:
# Shuffled mini-batches of 8 images, loaded by 2 worker subprocesses.
data_loader = DataLoader(
    train_data,
    batch_size=8,
    shuffle=True,
    num_workers=2,
)

In [5]:
class CNN(nn.Module):
    """Small LeNet-style CNN for 3-channel image classification.

    Two ``Conv2d(5x5, no padding) -> ReLU -> MaxPool2d(2)`` stages feed a
    two-layer MLP head (``Linear -> ReLU -> Linear``).

    Args:
        input_size: ``(height, width)`` of the input images. The default
            ``(64, 128)`` reproduces the original ``16 * 13 * 29`` flattened size,
            so existing checkpoints still load.
        num_classes: number of output logits (default 2).

    Raises:
        ValueError: if ``input_size`` is too small to survive both conv/pool stages.
    """

    def __init__(self, input_size=(64, 128), num_classes=2):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        feat_h = self._feature_extent(input_size[0])
        feat_w = self._feature_extent(input_size[1])
        if feat_h <= 0 or feat_w <= 0:
            raise ValueError(
                'input_size {} is too small for two conv(5)+pool(2) stages'.format(tuple(input_size))
            )
        self.layer3 = nn.Sequential(
            nn.Linear(16 * feat_h * feat_w, 120),
            nn.ReLU(),
            nn.Linear(120, num_classes),
        )

    @staticmethod
    def _feature_extent(size):
        """Spatial extent after two (5x5 valid conv -> 2x2 max-pool) stages."""
        for _ in range(2):
            # 5x5 conv without padding removes 4; MaxPool2d(2) floor-halves.
            size = (size - 4) // 2
        return size

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` image batch to ``(batch, num_classes)`` logits."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = torch.flatten(out, 1)  # keep the batch dim, flatten channels x H x W
        return self.layer3(out)

In [6]:
# Instantiate the model, move it to the selected device, and display a
# per-layer summary for one 3x64x128 input.
net = CNN()
net = net.to(device)
summary(net, (3, 64, 128), device=device)

-------------------------------------------------------------------------------------------------------------------
        Layer (type)    Kernel Shape          Input Shape              Output Shape         Param #     Multi Ops #
            Conv2d-1    [3, 6, 5, 5]      [2, 3, 64, 128]           [2, 6, 60, 124]       3,348,006       3,348,000
              ReLU-2               -      [2, 6, 60, 124]           [2, 6, 60, 124]               0               0
         MaxPool2d-3               -      [2, 6, 60, 124]            [2, 6, 30, 62]               0               0
            Conv2d-4   [6, 16, 5, 5]       [2, 6, 30, 62]           [2, 16, 26, 58]       3,619,216       3,619,200
              ReLU-5               -      [2, 16, 26, 58]           [2, 16, 26, 58]               0               0
         MaxPool2d-6               -      [2, 16, 26, 58]           [2, 16, 13, 29]               0               0
            Linear-7     [6032, 120]            [2, 6032]               

(7691424, 7691424)

In [7]:
# Adam with a small fixed learning rate; cross-entropy over the model's class logits.
optimizer = optim.Adam(params=net.parameters(), lr=5e-5)
loss_func = nn.CrossEntropyLoss()
loss_func = loss_func.to(device)

In [13]:
total_batch = len(data_loader)


def train(epochs=7):
    """Train ``net`` on ``data_loader`` and print the mean loss after each epoch.

    Uses the notebook-level ``net``, ``optimizer``, ``loss_func``, ``device``
    and ``total_batch``.

    Args:
        epochs: number of full passes over the training data (default 7).
    """
    net.train()  # ensure training-mode behavior is active before optimizing
    for epoch in range(epochs):
        avg_cost = 0.0
        for imgs, labels in data_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            out = net(imgs)
            loss = loss_func(out, labels)
            loss.backward()
            optimizer.step()

            # .item() converts to a plain Python float. Accumulating the tensor
            # itself would chain every batch's autograd graph into avg_cost,
            # keeping it alive for the whole epoch.
            avg_cost += loss.item() / total_batch

        print('[Epoch:{}] cost = {}'.format(epoch + 1, avg_cost))
    print('Learning Finished!')

In [14]:
# Run the full training loop once and report CPU and wall-clock time.
%time train()

[Epoch:1] cost = 0.005010341294109821
[Epoch:2] cost = 0.0037535890005528927
[Epoch:3] cost = 0.0028771429788321257
[Epoch:4] cost = 0.002247781725600362
[Epoch:5] cost = 0.0018486643675714731
[Epoch:6] cost = 0.001478617312386632
[Epoch:7] cost = 0.0012478209100663662
Learning Finished!
CPU times: user 30.3 s, sys: 1.16 s, total: 31.5 s
Wall time: 11.6 s


In [9]:
# Persist only the learned weights (state_dict), not the pickled module object.
MODEL_PATH = "./models/model.pth"
# torch.save does not create parent directories; make sure ./models exists.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(net.state_dict(), MODEL_PATH)

In [10]:
# Rebuild the architecture and restore the saved weights into it. map_location
# remaps stored tensors onto the current device, so a checkpoint written on a
# GPU machine also loads on a CPU-only kernel (and vice versa).
new_net = CNN().to(device)
new_net.load_state_dict(torch.load('./models/model.pth', map_location=device))

<All keys matched successfully>

In [11]:
# Test-time preprocessing: resize to the 64x128 input size the CNN is built for.
# Bound to its own name instead of reusing `trans`, so re-running the
# training-data cell later cannot silently pick up this pipeline.
test_trans = transforms.Compose([
    transforms.Resize((64, 128)),
    transforms.ToTensor()
])
test_data = torchvision.datasets.ImageFolder(root='./custom_data/test_data', transform=test_trans)

In [15]:
# Fixed-size batches keep memory bounded no matter how large the test set is.
# (batch_size=len(test_data) also raised ValueError on an empty test set.)
TEST_BATCH_SIZE = 64
test_set = DataLoader(dataset=test_data, batch_size=TEST_BATCH_SIZE)


def test(model=None):
    """Evaluate a model on ``test_set`` and print its overall accuracy.

    Args:
        model: network to evaluate. Defaults to ``new_net``, the model reloaded
            from the saved checkpoint, so this also checks the save/load round trip.

    Returns:
        Accuracy over the whole test set as a float in [0, 1], or ``None`` if
        the test set is empty.
    """
    if model is None:
        model = new_net
    model.eval()  # inference mode for any train/eval-dependent layers
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in test_set:
            imgs = imgs.to(device)
            labels = labels.to(device)
            prediction = model(imgs)
            # Count hits per batch so accuracy is exact across all batches.
            correct += (torch.argmax(prediction, 1) == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        print('Accuracy: n/a (empty test set)')
        return None
    accuracy = correct / total
    print('Accuracy:', accuracy)
    return accuracy

In [16]:
# Evaluate on the test set and report CPU and wall-clock time.
%time test()

Accuracy: 1.0
CPU times: user 2.48 s, sys: 116 ms, total: 2.6 s
Wall time: 2.63 s
