# import tool

In [1]:
import numpy as np
import datetime
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.models as models

# read data

In [2]:
BATCH_SIZE = 4
dataset = dset.ImageFolder(root = "traindata",
                           transform=transforms.Compose([transforms.Resize([224, 224]), 
                                                         transforms.ToTensor(),
                                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                                        ]))
train_loader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)  

In [3]:
dataset = dset.ImageFolder(root = "testdata",
                           transform=transforms.Compose([transforms.Resize([224, 224]), 
                                                         transforms.ToTensor(),
                                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                                        ]))
test_loader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = False)  

# predict function

In [4]:
def get_prediction(data_loader_):
    data_loader = data_loader_
    
    total = 0
    correct = 0
    
    with torch.no_grad():
        for X, y in data_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            outputs = net(X)
            _, pred = torch.max(outputs, 1)

            correct = correct + torch.sum(pred == y).item()
            total = total + len(y)
    
    return( correct/total )

# network structure

In [5]:
DEVICE = "cuda:0"
net = models.resnet18(pretrained=True)

"""def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

set_parameter_requires_grad(net, feature_extracting = True)"""
num_ftrs = net.fc.in_features

net.fc = nn.Linear(num_ftrs, 5)

net = net.to(DEVICE)

# loss

In [6]:
criterion = nn.CrossEntropyLoss()

# optimizer

In [7]:
LR = 0.0001
optimizer = optim.SGD(net.parameters(), lr = LR, weight_decay=0.0005, momentum=0.9)

# start train

In [8]:
EPOCH = 20
train_acc = []
test_acc = []
loss_record = []
print( datetime.date.today().strftime('%Y-%m-%d %H:%M:%S') )
for epoch in range(EPOCH):
    running_loss = 0.0
    
    for X, y in tqdm(train_loader):
        X, y = X.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        outputs = net(X)

        loss = criterion(outputs, y)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
    
    print("==== EPOCH %d/%d ====" % (epoch+1, EPOCH))
    print( datetime.date.today().strftime('%Y-%m-%d %H:%M:%S') )
    print("loss: %.4f" % running_loss)
    buf1, buf2 = get_prediction(train_loader), get_prediction(test_loader)
    train_acc = train_acc + [buf1]
    test_acc = test_acc + [buf2]
    loss_record = loss_record + [running_loss]
    print("train acc: %.4f ||| test acc: %.4f" % (buf1, buf2))
    
    Path = "resnet_18_pretrain_epoch_%s.pt" % epoch
    torch.save(net, Path)

2020-08-04 00:00:00


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 1/20 ====
loss: 5585.0540
train acc: 0.7474 ||| test acc: 0.7302


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 2/20 ====
loss: 5133.5474
train acc: 0.7630 ||| test acc: 0.7320


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 3/20 ====
loss: 4863.6559
train acc: 0.7838 ||| test acc: 0.7149


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 4/20 ====
loss: 4621.3932
train acc: 0.7943 ||| test acc: 0.7209


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 5/20 ====
loss: 4397.3102
train acc: 0.8011 ||| test acc: 0.7274


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 6/20 ====
loss: 4115.9777
train acc: 0.8073 ||| test acc: 0.6791


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 7/20 ====
loss: 3886.9651
train acc: 0.8286 ||| test acc: 0.6974


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 8/20 ====
loss: 3636.3650
train acc: 0.8469 ||| test acc: 0.7180


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 9/20 ====
loss: 3324.8899
train acc: 0.8512 ||| test acc: 0.7112


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 10/20 ====
loss: 2965.6940
train acc: 0.8795 ||| test acc: 0.6743


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 11/20 ====
loss: 2585.0422
train acc: 0.8947 ||| test acc: 0.6908


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 12/20 ====
loss: 2308.6105
train acc: 0.9111 ||| test acc: 0.6962


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 13/20 ====
loss: 1944.6212
train acc: 0.9288 ||| test acc: 0.6660


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 14/20 ====
loss: 1696.8842
train acc: 0.9380 ||| test acc: 0.6985


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 15/20 ====
loss: 1414.6688
train acc: 0.9515 ||| test acc: 0.7045


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 16/20 ====
loss: 1234.0689
train acc: 0.9519 ||| test acc: 0.7140


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 17/20 ====
loss: 1116.7217
train acc: 0.9473 ||| test acc: 0.6592


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 18/20 ====
loss: 983.8261
train acc: 0.9549 ||| test acc: 0.6801


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 19/20 ====
loss: 885.7848
train acc: 0.9701 ||| test acc: 0.7042


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 20/20 ====
loss: 815.0270
train acc: 0.9661 ||| test acc: 0.7035
