# import tool

In [1]:
import numpy as np
import datetime
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms

# import self-define tool

In [2]:
from network_structure import ResNet_18

# REPRODUCIBILITY

In [3]:
torch.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# read data

In [4]:
BATCH_SIZE = 4
dataset = dset.ImageFolder(root = "traindata",
                           transform=transforms.Compose([transforms.Resize([224, 224]), 
                                                         transforms.ToTensor(),
                                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                                        ]))
train_loader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)  

In [5]:
dataset = dset.ImageFolder(root = "testdata",
                           transform=transforms.Compose([transforms.Resize([224, 224]), 
                                                         transforms.ToTensor(),
                                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                                        ]))
test_loader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = False)  

# predict function

In [6]:
def get_prediction(data_loader_):
    data_loader = data_loader_
    
    total = 0
    correct = 0
    
    with torch.no_grad():
        for X, y in data_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            outputs = net(X)
            _, pred = torch.max(outputs, 1)

            correct = correct + torch.sum(pred == y).item()
            total = total + len(y)
    
    return( correct/total )

# network structure

In [7]:
DEVICE = "cuda:0"

net = ResNet_18()
net = net.to(DEVICE)

# loss

In [8]:
criterion = nn.CrossEntropyLoss()

# optimizer

In [9]:
LR = 0.0001
optimizer = optim.Adam(net.parameters(), lr = LR)

# start train

In [10]:
EPOCH = 50
train_acc = []
test_acc = []
loss_record = []
print( datetime.date.today().strftime('%Y-%m-%d %H:%M:%S') )
for epoch in range(EPOCH):
    running_loss = 0.0
    
    for X, y in tqdm(train_loader):
        X, y = X.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        outputs = net(X)

        loss = criterion(outputs, y)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
    
    print("==== EPOCH %d/%d ====" % (epoch+1, EPOCH))
    print("loss: %.4f" % running_loss)
    buf1, buf2 = get_prediction(train_loader), get_prediction(test_loader)
    train_acc = train_acc + [buf1]
    test_acc = test_acc + [buf2]
    loss_record = loss_record + [running_loss]
    print("train acc: %.4f ||| test acc: %.4f" % (buf1, buf2))
    
    Path = "resnet_18_epoch_%s.pt" % epoch
    torch.save(net, Path)

2020-08-04 00:00:00


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 1/50 ====
loss: 6160.3105
train acc: 0.7351 ||| test acc: 0.7335


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 2/50 ====
loss: 6030.4169
train acc: 0.7351 ||| test acc: 0.7335


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 3/50 ====
loss: 5778.7020
train acc: 0.7405 ||| test acc: 0.7381


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 4/50 ====
loss: 5607.1452
train acc: 0.7425 ||| test acc: 0.7405


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 5/50 ====
loss: 5491.7367
train acc: 0.7485 ||| test acc: 0.7431


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 6/50 ====
loss: 5406.7231
train acc: 0.7521 ||| test acc: 0.7476


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 7/50 ====
loss: 5343.2354
train acc: 0.7539 ||| test acc: 0.7500


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 8/50 ====
loss: 5255.2323
train acc: 0.7565 ||| test acc: 0.7459


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 9/50 ====
loss: 5175.7705
train acc: 0.7595 ||| test acc: 0.7472


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 10/50 ====
loss: 5111.0778
train acc: 0.7679 ||| test acc: 0.7502


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 11/50 ====
loss: 5008.1442
train acc: 0.7703 ||| test acc: 0.7526


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 12/50 ====
loss: 4898.3500
train acc: 0.7705 ||| test acc: 0.7540


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 13/50 ====
loss: 4785.5025
train acc: 0.7749 ||| test acc: 0.7489


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 14/50 ====
loss: 4741.3190
train acc: 0.7958 ||| test acc: 0.7522


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 15/50 ====
loss: 4500.2491
train acc: 0.7878 ||| test acc: 0.7519


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 16/50 ====
loss: 4368.6704
train acc: 0.8054 ||| test acc: 0.7452


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))

train acc: 0.8138 ||| test acc: 0.7490


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 18/50 ====
loss: 3916.4179
train acc: 0.8256 ||| test acc: 0.7421


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 19/50 ====
loss: 3672.5048
train acc: 0.8392 ||| test acc: 0.7310


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 20/50 ====
loss: 3393.1156
train acc: 0.8468 ||| test acc: 0.7367


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 21/50 ====
loss: 3086.6010
train acc: 0.8679 ||| test acc: 0.7243


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 22/50 ====
loss: 2805.4026
train acc: 0.8938 ||| test acc: 0.7264


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 23/50 ====
loss: 2495.8973
train acc: 0.8955 ||| test acc: 0.7011


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 24/50 ====
loss: 2209.0172
train acc: 0.9175 ||| test acc: 0.6957


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 25/50 ====
loss: 1982.4393
train acc: 0.9276 ||| test acc: 0.6900


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 26/50 ====
loss: 1739.4738
train acc: 0.9145 ||| test acc: 0.7139


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 27/50 ====
loss: 172282.4068
train acc: 0.7968 ||| test acc: 0.7107


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 28/50 ====
loss: 2200.6894
train acc: 0.9650 ||| test acc: 0.6921


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 29/50 ====
loss: 1302.1396
train acc: 0.9611 ||| test acc: 0.6865


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 30/50 ====
loss: 1105.7706
train acc: 0.9578 ||| test acc: 0.7012


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 31/50 ====
loss: 1165.9462
train acc: 0.9623 ||| test acc: 0.6948


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 32/50 ====
loss: 1103.8261
train acc: 0.9337 ||| test acc: 0.6679


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 33/50 ====
loss: 1027.7736
train acc: 0.9786 ||| test acc: 0.6874


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 34/50 ====
loss: 942.4779
train acc: 0.9510 ||| test acc: 0.6396


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 35/50 ====
loss: 888.1288
train acc: 0.9694 ||| test acc: 0.6895


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 36/50 ====
loss: 842.0036
train acc: 0.9649 ||| test acc: 0.6878


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 37/50 ====
loss: 837.7734
train acc: 0.9712 ||| test acc: 0.6870


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 38/50 ====
loss: 723.7255
train acc: 0.9785 ||| test acc: 0.6994


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 39/50 ====
loss: 787.9840
train acc: 0.9781 ||| test acc: 0.6864


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 40/50 ====
loss: 678.6150
train acc: 0.9709 ||| test acc: 0.6689


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 41/50 ====
loss: 732.5013
train acc: 0.9703 ||| test acc: 0.6816


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 42/50 ====
loss: 655.1049
train acc: 0.9828 ||| test acc: 0.6974


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 43/50 ====
loss: 3172843.6027
train acc: 0.9293 ||| test acc: 0.7001


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 44/50 ====
loss: 1124.5209
train acc: 0.9830 ||| test acc: 0.7056


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 45/50 ====
loss: 470.5164
train acc: 0.9856 ||| test acc: 0.7089


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 46/50 ====
loss: 522.5099
train acc: 0.9868 ||| test acc: 0.6780


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 47/50 ====
loss: 631.1640
train acc: 0.9812 ||| test acc: 0.6880


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 48/50 ====
loss: 606.3966
train acc: 0.9815 ||| test acc: 0.6770


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 49/50 ====
loss: 568.1118
train acc: 0.9787 ||| test acc: 0.7068


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 50/50 ====
loss: 596.3461
train acc: 0.9781 ||| test acc: 0.6833
