In [1]:
# Imports and one-time global setup for the ResNet-101 binary-classifier
# training notebook (num_classes=1 + BCEWithLogitsLoss further down).
import numpy as np
import matplotlib.pyplot as plt
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel
# session, including deprecation warnings from torch/torchvision imported
# below; consider filtering specific categories/modules instead.
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.transforms as tr 

# Project-local modules: dataset construction, seeding/training helpers,
# and the classifier + backbone definitions.
from data import dataset
from utils import seed, trainer
from model import basic_classifier
from model.backbone import resnet101

# Called before any dataset/model objects are created below.
# NOTE(review): presumably seeds the RNGs used here (torch/numpy/random);
# which ones and with what seed is defined in utils/seed — not visible here.
seed.seed_everything()

In [2]:
# Pick the compute device: the first CUDA GPU when one is visible,
# otherwise fall back to the CPU. Kept as a plain string because it is
# passed both to `.to(device=...)` and to `torch.load(map_location=...)`.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

cpu


In [None]:
### Fine-tuning ###

# Data pipeline: build the train/val/test datasets and wrap each in a DataLoader.
#
# NOTE(review): all three splits point at the same './data' directory, so
# validation and test metrics are computed on the training images unless
# dataset.make_dataset splits the data internally — confirm, and point each
# split at its own directory.
image_dir_train = './data'
image_dir_val = './data'
image_dir_test = './data'

IMAGE_SIZE = 512  # passed to tr.Resize for every split
BATCH_SIZE = 4    # shared by all three loaders

# Training-time preprocessing: resize, random horizontal/vertical flips and a
# random rotation (degrees=10), then convert to a tensor.
# NOTE(review): with an int argument, tr.Resize matches only the shorter image
# edge, so images of different aspect ratios produce different tensor sizes and
# the default collate cannot batch them; use (IMAGE_SIZE, IMAGE_SIZE) if the
# data is not uniformly shaped.
# NOTE(review): no tr.Normalize is applied even though the backbone is built
# with pretrain=True — confirm that ResNet101 normalizes its inputs internally.
transform = tr.Compose(
    [
        tr.Resize(IMAGE_SIZE),
        tr.RandomHorizontalFlip(),
        tr.RandomVerticalFlip(),
        tr.RandomRotation(10),
        tr.ToTensor(),
    ]
)

# Deterministic preprocessing (no augmentation), shared by validation and
# test so the two splits can never drift apart.
eval_transform = tr.Compose([tr.Resize(IMAGE_SIZE), tr.ToTensor()])

train_set = dataset.make_dataset(
    image_dir=image_dir_train,
    transform=transform
)

val_set = dataset.make_dataset(
    image_dir=image_dir_val,
    transform=eval_transform
)

test_set = dataset.make_dataset(
    image_dir=image_dir_test,
    transform=eval_transform
)

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True
)

val_loader = DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    shuffle=False
)

test_loader = DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False
)

In [None]:
# Master switch for PDA; it is passed to the classifier here and consulted
# again by the training cells below.
ENABLE_PDA = True

# Backbone created with pretrain=True, then moved to the selected device.
resnet = resnet101.ResNet101(pretrain=True)
resnet = resnet.to(device=device)

# Single-logit head (num_classes=1, trained with BCEWithLogitsLoss below).
# freezing=True: the later "freezing True -> False" cell re-enables gradients
# on model.backbone, so the backbone presumably starts frozen.
model = basic_classifier.BasicClassifier(
    model=resnet,
    num_classes=1,
    in_features=resnet.in_features,
    enable_PDA=ENABLE_PDA,
    freezing=True,
)
model = model.to(device=device)

# print(model) 

In [None]:
# Warm-up stage
# 0 ~ 10 epochs
# Trains with a higher learning rate while the backbone is presumably still
# frozen (the model was built with freezing=True), with PDA switched off.
# No checkpointing happens in this stage, so no best-loss tracker is needed.
#
# NOTE(review): the fine-tuning config cell reassigns EPOCHS to 40; re-running
# the warm-up loop after that cell without re-running this one would train
# 40 warm-up epochs.
EPOCHS = 10

# Binary objective on the classifier's single output logit.
criterion = nn.BCEWithLogitsLoss()
# higher lr for warm-up
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)

model.enable_PDA = False

In [None]:
# Warm-up training loop: one train pass and one validation pass per epoch.
# No LR scheduler is used and no checkpoint is saved; metrics are only
# printed.
for epoch in range(EPOCHS):
    train_loss, train_acc = trainer.model_train(
        model=model,
        data_loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        scheduler=None,
    )
    val_loss, val_acc = trainer.model_evaluate(
        model=model,
        data_loader=val_loader,
        criterion=criterion,
        device=device,
    )
    print(
        f'epoch {epoch + 1:02d}, loss: {train_loss:.5f}, accuracy: {train_acc:.5f}, '
        f'val_loss: {val_loss:.5f}, val_accuracy: {val_acc:.5f} \n'
    )

In [None]:
# freezing True -> False: re-enable gradients on every backbone parameter so
# the fine-tuning stage below also updates the backbone, not just the head.
for p in model.backbone.parameters():
    p.requires_grad_(True)

In [None]:
# Fine-tuning stage
# 10 ~ 50 epochs
# Whole network trainable (backbone unfrozen above), lower learning rate,
# early stopping, a step LR schedule and best-checkpoint tracking.
EPOCHS = 40

criterion = nn.BCEWithLogitsLoss()
# lower lr for fine-tuning
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-3)
# patience = EPOCHS//2 (20). NOTE(review): exact stopping rule lives in
# trainer.EarlyStopping — presumably stops after `patience` epochs without a
# val_loss decrease larger than `delta`.
es = trainer.EarlyStopping(patience=EPOCHS//2, delta=0, mode='min', verbose=True)
# Multiplies the lr by 0.9 every EPOCHS//5 (8) scheduler steps.
# NOTE(review): the scheduler is handed to trainer.model_train — confirm
# whether it is stepped per epoch or per batch, since step_size=8 means very
# different schedules in the two cases.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=EPOCHS//5, gamma=0.9)

# Per-epoch metric curves, filled by the fine-tuning loop and plotted later.
history = {key: [] for key in ('train_loss', 'val_loss', 'train_accuracy', 'val_accuracy')}

# Lowest validation loss seen so far (despite the name); the loop saves a
# checkpoint whenever val_loss drops below it.
max_loss = np.inf

# Follow the ENABLE_PDA switch instead of hard-coding True, so disabling PDA
# at construction time also keeps it disabled during fine-tuning (and stays
# consistent with the loop's `if ENABLE_PDA` guard).
model.enable_PDA = ENABLE_PDA

In [None]:
# Fine-tuning loop: train, validate, record history, checkpoint the best
# model, log, and only then consult early stopping.
for epoch in range(EPOCHS):
    # When PDA is enabled, call the model's cutoff update once at the start
    # of every epoch.
    if ENABLE_PDA:
        model.update_cutoff_()

    train_loss, train_acc = trainer.model_train(
        model=model,
        data_loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        scheduler=scheduler,
    )
    val_loss, val_acc = trainer.model_evaluate(
        model=model,
        data_loader=val_loader,
        criterion=criterion,
        device=device
    )

    history['train_loss'].append(train_loss)
    history['train_accuracy'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_accuracy'].append(val_acc)

    # Checkpoint and log *before* the early-stopping check: previously the
    # `break` ran first, so the epoch that triggered the stop was never
    # logged (nor saved, had it improved).
    if val_loss < max_loss:
        print(f'[INFO] val_loss has been improved from {max_loss:.5f} to {val_loss:.5f}. Save model.')
        max_loss = val_loss
        torch.save(model.state_dict(), 'Best_Model_IMAGENET.pth')

    print(f'epoch {epoch+1:02d}, loss: {train_loss:.5f}, accuracy: {train_acc:.5f}, val_loss: {val_loss:.5f}, val_accuracy: {val_acc:.5f} \n')

    # Early Stop Check
    es(val_loss)
    if es.early_stop:
        print(f'[INFO] Early stopping at epoch {epoch+1:02d}.')
        break

In [None]:
# Plot the accuracy curves recorded in `history`. Only fine-tuning epochs are
# included: the warm-up loop does not append to `history`.
trainer.plot_acc(history=history)

In [None]:
# Plot the loss curves recorded in `history`. Only fine-tuning epochs are
# included: the warm-up loop does not append to `history`.
trainer.plot_loss(history=history)

In [None]:
# Evaluate the best fine-tuning checkpoint on the test split.
#
# Build a *fresh* backbone (pretrain=False: its weights are replaced by the
# checkpoint anyway) and wrap it in a classifier configured like the trained
# one. Previously the freshly built `backbone` was ignored and the
# already-trained `resnet` object from the training cells was wrapped again,
# so the checkpoint was never loaded into a clean model and loading mutated
# the training model's backbone in place.
backbone = resnet101.ResNet101(pretrain=False).to(device=device)
model = basic_classifier.BasicClassifier(
    model=backbone,
    in_features=backbone.in_features,
    freezing=True,
    enable_PDA=ENABLE_PDA,
    num_classes=1
).to(device=device)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): whatever state model.update_cutoff_() maintains is only
# restored here if it is registered as a parameter/buffer — confirm.
model.load_state_dict(torch.load('Best_Model_IMAGENET.pth', map_location=device))
model.eval()

# A fresh loss instance keeps this cell independent of training-cell state.
test_loss, test_acc = trainer.model_evaluate(
    model=model,
    data_loader=test_loader,
    criterion=nn.BCEWithLogitsLoss(),
    device=device
)

print('Test Loss: %s'%test_loss)
print('Test Accuracy: %s'%test_acc)