# import tool

In [1]:
import numpy as np
import datetime
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.models as models

# read data

In [2]:
BATCH_SIZE = 4
dataset = dset.ImageFolder(root = "traindata",
                           transform=transforms.Compose([transforms.Resize([224, 224]), 
                                                         transforms.ToTensor(),
                                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                                        ]))
train_loader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)  

In [3]:
dataset = dset.ImageFolder(root = "testdata",
                           transform=transforms.Compose([transforms.Resize([224, 224]), 
                                                         transforms.ToTensor(),
                                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                                        ]))
test_loader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = False)  

# predict function

In [4]:
def get_prediction(data_loader_):
    data_loader = data_loader_
    
    total = 0
    correct = 0
    
    with torch.no_grad():
        for X, y in data_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            outputs = net(X)
            _, pred = torch.max(outputs, 1)

            correct = correct + torch.sum(pred == y).item()
            total = total + len(y)
    
    return( correct/total )

# network structure

In [5]:
DEVICE = "cuda:0"
net = models.resnet50(pretrained=True)

# def set_parameter_requires_grad(model, feature_extracting):
#     if feature_extracting:
#         for param in model.parameters():
#             param.requires_grad = False

# set_parameter_requires_grad(net, feature_extracting = True)
num_ftrs = net.fc.in_features

net.fc = nn.Linear(num_ftrs, 5)

net = net.to(DEVICE)

# loss

In [6]:
criterion = nn.CrossEntropyLoss()

# optimizer

In [7]:
LR = 0.0001
optimizer = optim.SGD(net.parameters(), lr = LR, weight_decay=0.0005, momentum=0.9)

# start train

In [8]:
EPOCH = 20
train_acc = []
test_acc = []
loss_record = []
print( datetime.date.today().strftime('%Y-%m-%d %H:%M:%S') )
for epoch in range(EPOCH):
    running_loss = 0.0
    
    for X, y in tqdm(train_loader):
        X, y = X.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        outputs = net(X)

        loss = criterion(outputs, y)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
    
    print("==== EPOCH %d/%d ====" % (epoch+1, EPOCH))
    print("loss: %.4f" % running_loss)
    buf1, buf2 = get_prediction(train_loader), get_prediction(test_loader)
    train_acc = train_acc + [buf1]
    test_acc = test_acc + [buf2]
    loss_record = loss_record + [running_loss]
    print("train acc: %.4f ||| test acc: %.4f" % (buf1, buf2))
    
    Path = "resnet_50_pretrain_epoch_%s.pt" % epoch
    torch.save(net, Path)

2020-08-05 00:00:00


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 1/20 ====
loss: 5498.6668
train acc: 0.7576 ||| test acc: 0.7311


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 2/20 ====
loss: 4953.8710
train acc: 0.7695 ||| test acc: 0.7332


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 3/20 ====
loss: 4616.7150
train acc: 0.7933 ||| test acc: 0.7258


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 4/20 ====
loss: 4321.4646
train acc: 0.8115 ||| test acc: 0.7209


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 5/20 ====
loss: 4019.5215
train acc: 0.8242 ||| test acc: 0.6883


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 6/20 ====
loss: 3710.2095
train acc: 0.8418 ||| test acc: 0.7085


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 7/20 ====
loss: 3360.4398
train acc: 0.8588 ||| test acc: 0.7270


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 8/20 ====
loss: 2949.6839
train acc: 0.8799 ||| test acc: 0.6996


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 9/20 ====
loss: 2595.1365
train acc: 0.9003 ||| test acc: 0.6897


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 10/20 ====
loss: 2207.3363
train acc: 0.9187 ||| test acc: 0.7401


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 11/20 ====
loss: 1869.9799
train acc: 0.9330 ||| test acc: 0.7157


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 12/20 ====
loss: 1595.0532
train acc: 0.9538 ||| test acc: 0.7378


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 13/20 ====
loss: 1244.4241
train acc: 0.9518 ||| test acc: 0.6841


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 14/20 ====
loss: 1051.4407
train acc: 0.9543 ||| test acc: 0.6941


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 15/20 ====
loss: 937.8134
train acc: 0.9708 ||| test acc: 0.7173


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 16/20 ====
loss: 801.9001
train acc: 0.9689 ||| test acc: 0.7063


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 17/20 ====
loss: 741.5371
train acc: 0.9777 ||| test acc: 0.7277


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 18/20 ====
loss: 598.5383
train acc: 0.9852 ||| test acc: 0.7377


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 19/20 ====
loss: 491.8426
train acc: 0.9872 ||| test acc: 0.7375


HBox(children=(FloatProgress(value=0.0, max=7025.0), HTML(value='')))


==== EPOCH 20/20 ====
loss: 505.9279
train acc: 0.9794 ||| test acc: 0.7068
