# import tool

In [1]:
import numpy as np
import datetime
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms

# import self-define tool

In [2]:
from network_structure import ResNet_50

# REPRODUCIBILITY

In [3]:
torch.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
IM[PO]

# read data

In [4]:
BATCH_SIZE = 128
dataset = dset.ImageFolder(root = "traindata",
                           transform=transforms.Compose([transforms.Resize([224, 224]), 
                                                         transforms.ToTensor(),
                                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                                        ]))
train_loader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)  

In [5]:
dataset = dset.ImageFolder(root = "testdata",
                           transform=transforms.Compose([transforms.Resize([224, 224]), 
                                                         transforms.ToTensor(),
                                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                                        ]))
test_loader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = False)  

# predict function

In [6]:
def get_prediction(data_loader_):
    data_loader = data_loader_
    
    total = 0
    correct = 0
    
    with torch.no_grad():
        for X, y in data_loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            outputs = net(X)
            _, pred = torch.max(outputs, 1)

            correct = correct + torch.sum(pred == y).item()
            total = total + len(y)
    
    return( correct/total )

# network structure

In [7]:
DEVICE = "cuda:1"

net = ResNet_50()
net = net.to(DEVICE)

# loss

In [8]:
criterion = nn.CrossEntropyLoss()

# optimizer

In [9]:
LR = 0.0001
optimizer = optim.Adam(net.parameters(), lr = LR)

# start train

In [11]:
EPOCH = 10
train_acc = []
test_acc = []
loss_record = []
print( datetime.date.today().strftime('%Y-%m-%d %H:%M:%S') )
for epoch in range(EPOCH):
    running_loss = 0.0
    
    for X, y in tqdm(train_loader):
        X, y = X.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        outputs = net(X)

        loss = criterion(outputs, y)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
    
    print("==== EPOCH %d/%d ====" % (epoch+1, EPOCH))
    print("loss: %.4f" % running_loss)
    buf1, buf2 = get_prediction(train_loader), get_prediction(test_loader)
    train_acc = train_acc + [buf1]
    test_acc = test_acc + [buf2]
    loss_record = loss_record + [running_loss]
    print("train acc: %.4f ||| test acc: %.4f" % (buf1, buf2))
    
    Path = "resnet_50_epoch_%s.pt" % epoch
    torch.save(net, Path)

2020-08-04 00:00:00


HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))


==== EPOCH 1/10 ====
loss: 185.5225
train acc: 0.7351 ||| test acc: 0.7335


HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))


==== EPOCH 2/10 ====
loss: 183.3139
train acc: 0.7351 ||| test acc: 0.7335


HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))


==== EPOCH 3/10 ====
loss: 181.3128
train acc: 0.7377 ||| test acc: 0.7354


HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))


==== EPOCH 4/10 ====
loss: 177.8002
train acc: 0.7391 ||| test acc: 0.7372


HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))


==== EPOCH 5/10 ====
loss: 173.7288
train acc: 0.7448 ||| test acc: 0.7422


HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))


==== EPOCH 6/10 ====
loss: 169.4968
train acc: 0.7484 ||| test acc: 0.7422


HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))


==== EPOCH 7/10 ====
loss: 165.7687
train acc: 0.7534 ||| test acc: 0.7443


HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))


==== EPOCH 8/10 ====
loss: 163.5760
train acc: 0.7535 ||| test acc: 0.7458


HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))


==== EPOCH 9/10 ====
loss: 160.3512
train acc: 0.7595 ||| test acc: 0.7496


HBox(children=(FloatProgress(value=0.0, max=220.0), HTML(value='')))


==== EPOCH 10/10 ====
loss: 157.5790
train acc: 0.7665 ||| test acc: 0.7495
