In [None]:
import timm
from DataLoader import AOIDataset
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
from statistics import mean
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import pandas as pd
import warnings

# Silence every warning category for the rest of the kernel session.
# NOTE(review): this also hides deprecation warnings raised by the libraries used below.
warnings.filterwarnings(action='ignore')

### Sweep 초기화 및 하이퍼파라미터 선언

학습 결과 다음의 파라미터에서 최적의 성능을 보임을 확인하였다. 
- batch size : 64
- learning rate : 1e-4
- image size : 224

차후 이를 이용하여 모델의 평가를 진행하고자 한다. 

In [None]:
# W&B sweep definition: a grid search in which every hyperparameter is pinned
# to a single value (lr, batch size, input image size).
sweep_config = dict(
    name='resnet18-aoi',
    method='grid',
    parameters=dict(
        lr=dict(value=1e-4),
        batch_size=dict(value=64),
        img_size=dict(value=224),
    ),
)

In [None]:
# Number of training epochs per sweep run.
epochs = 50

# Prefer the second GPU when the host has one; otherwise fall back to the
# first GPU, then to CPU. Selecting 'cuda:1' unconditionally fails on
# single-GPU machines even though torch.cuda.is_available() is True.
if torch.cuda.device_count() > 1 :
    device = 'cuda:1'
elif torch.cuda.is_available() :
    device = 'cuda:0'
else :
    device = 'cpu'

# Register the sweep with W&B; the returned id is consumed by wandb.agent below.
sweep_id = wandb.sweep(sweep_config, project='test')

### 데이터 정의

In [None]:
def define_datasets(run) : 
    """Build the train / validation / test DataLoaders for one sweep run.

    Args:
        run: object exposing ``config.img_size`` and ``config.batch_size``
            (the W&B run in this notebook).

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    img_size = run.config.img_size
    # Resize(int) only rescales the shorter edge and keeps the aspect ratio.
    # Use an explicit (H, W) so inputs are always square, matching the
    # Resize((224, 224)) used by the evaluation cell of this notebook.
    target_size = (img_size, img_size) if isinstance(img_size, int) else tuple(img_size)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(target_size)
    ])

    train_data = AOIDataset(train=True, transform=transform)
    val_data = AOIDataset(val=True, transform=transform)
    test_data = AOIDataset(test=True, transform=transform)

    batch_size = run.config.batch_size
    # Only the training set needs shuffling; evaluation metrics do not depend
    # on sample order, and a fixed order keeps runs comparable.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader, test_loader


### 모델 및 기타 학습용 객체 정의

In [None]:
def define_training_object(run) : 
    """Create the model, optimizer and loss used for one sweep run.

    The pretrained ResNet-18 backbone is frozen and only the final classifier
    (``model.fc``) is left trainable.

    Args:
        run: object exposing ``config.lr`` (the W&B run in this notebook).

    Returns:
        Tuple ``(model, optimizer, criterion)``.
    """
    model = timm.create_model('resnet18', pretrained=True, num_classes=7).to(device)

    # Freeze the feature-extraction layers and update only the classifier.
    # The attribute is `requires_grad`; a misspelled name silently creates an
    # unrelated attribute and freezes nothing.
    for param in model.parameters() :
        param.requires_grad = False

    for param in model.fc.parameters() :
        param.requires_grad = True

    # Build the optimizer after freezing so it only tracks trainable parameters.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=run.config.lr)

    # NLLLoss expects log-probabilities; train/validation apply log_softmax first.
    criterion = nn.NLLLoss()

    return model, optimizer, criterion

### 학습 함수 정의

In [None]:
def train(model, optimizer, criterion, data_loader) : 
    """Run one training epoch.

    Args:
        model: network producing class logits of shape (batch, classes).
        optimizer: optimizer stepping the model's trainable parameters.
        criterion: loss applied to log-probabilities and integer labels
            (``nn.NLLLoss`` in this notebook).
        data_loader: iterable of ``(data, label)`` batches.

    Returns:
        Tuple ``(train_acc, train_loss)``: fraction of correctly classified
        samples and mean per-sample loss over the epoch, both Python floats.
    """
    model.train()

    # Send batches to wherever the model actually lives.
    model_device = next(model.parameters()).device
    # A mean-reduced loss must be re-weighted by batch size before averaging
    # over the epoch; a sum-reduced loss is already a per-batch total.
    loss_is_sum = getattr(criterion, 'reduction', 'mean') == 'sum'

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for data, label in data_loader :
        data = data.to(model_device)
        label = label.to(model_device)

        logit = model(data)
        # Explicit class dimension; the implicit-dim form is deprecated.
        log_prob = F.log_softmax(logit, dim=1)
        pred = torch.argmax(log_prob, dim=1)

        loss = criterion(log_prob, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n_batch = label.size(0)
        total_loss += loss.item() if loss_is_sum else loss.item() * n_batch
        total_correct += torch.sum(pred == label).item()
        total_samples += n_batch

    if total_samples == 0 :
        return 0.0, 0.0

    train_loss = total_loss / total_samples
    train_acc = total_correct / total_samples

    return train_acc, train_loss


def validation(model, optimizer, criterion, val_loader) :
    """Evaluate the model on the validation set without updating weights.

    Args:
        model: network producing class logits of shape (batch, classes).
        optimizer: unused; kept so existing callers keep working.
        criterion: loss applied to log-probabilities and integer labels.
        val_loader: iterable of ``(data, label)`` batches.

    Returns:
        Tuple ``(val_acc, val_loss)``: accuracy and mean per-sample loss,
        both plain Python floats (safe to log and format).
    """
    model.eval()

    # Send batches to wherever the model actually lives.
    model_device = next(model.parameters()).device
    # A mean-reduced loss must be re-weighted by batch size before averaging.
    loss_is_sum = getattr(criterion, 'reduction', 'mean') == 'sum'

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad() :
        for data, label in val_loader :
            data = data.to(model_device)
            label = label.to(model_device)

            logit = model(data)
            log_prob = F.log_softmax(logit, dim=1)
            pred = torch.argmax(log_prob, dim=1)

            loss = criterion(log_prob, label)

            n_batch = label.size(0)
            total_loss += loss.item() if loss_is_sum else loss.item() * n_batch
            # .item() keeps the running count a Python int instead of a tensor.
            total_correct += torch.sum(pred == label).item()
            total_samples += n_batch

    if total_samples == 0 :
        return 0.0, 0.0

    val_loss = total_loss / total_samples
    val_acc = total_correct / total_samples

    return val_acc, val_loss


def test(model, test_loader) :
    """Evaluate the model on the test set and build a classification report.

    Side effect: stores the mean per-batch forward-pass time (milliseconds)
    in ``wandb.config.infer_time``.

    Args:
        model: network producing class logits of shape (batch, classes).
        test_loader: iterable of ``(data, label)`` batches.

    Returns:
        The text report produced by ``sklearn.metrics.classification_report``.
    """
    model.eval()

    # Send batches to wherever the model actually lives.
    model_device = next(model.parameters()).device
    on_cuda = model_device.type == 'cuda'

    prediction, ground_truth = [], []
    infer_times = []

    with torch.no_grad() :
        for data, label in test_loader :
            data = data.to(model_device)

            # CUDA kernels run asynchronously; synchronize around the forward
            # pass so the timer measures compute rather than launch overhead.
            if on_cuda :
                torch.cuda.synchronize(model_device)
            t1 = time.perf_counter()
            logit = model(data)
            if on_cuda :
                torch.cuda.synchronize(model_device)
            t2 = time.perf_counter()
            infer_times.append((t2 - t1) * 1000)

            # softmax is monotonic, so argmax over the logits gives the same class.
            pred = torch.argmax(logit, dim=1)

            prediction.extend(pred.cpu().tolist())
            ground_truth.extend(label.cpu().tolist())

    infer_time = mean(infer_times)
    wandb.config.infer_time = infer_time

    # sklearn expects (y_true, y_pred); swapping them exchanges precision and recall.
    report = classification_report(ground_truth, prediction)

    return report

### 학습 진행

In [None]:
def main() : 
    """Entry point executed by the W&B sweep agent for a single run.

    Trains for ``epochs`` epochs, logging train/validation metrics once per
    epoch, then prints the test-set report and saves the model weights under
    ``model/``.
    """
    import os  # local import: only needed to make sure the output directory exists

    run = wandb.init()
    run_tag = 'resnet18_' + str(run.config.lr) + '_' + str(run.config.batch_size) + '_' + str(run.config.img_size)
    run.name = run_tag
    wandb.config.epochs = epochs

    train_loader, val_loader, test_loader = define_datasets(run)
    model, optimizer, criterion = define_training_object(run)

    print(run.config)

    for epoch in range(epochs) :

        train_acc, train_loss = train(model, optimizer, criterion, train_loader)
        val_acc, val_loss = validation(model, optimizer, criterion, val_loader)

        # Normalize to Python floats so logging/formatting never receives a tensor.
        train_acc, train_loss = float(train_acc), float(train_loss)
        val_acc, val_loss = float(val_acc), float(val_loss)

        # One log call per epoch: each run.log() call advances the W&B step,
        # so separate calls would spread one epoch's metrics over several steps.
        run.log({
            'epoch' : epoch,
            'train_acc' : train_acc,
            'train_loss' : train_loss,
            'val_acc' : val_acc,
            'val_loss' : val_loss,
        })

        print('epoch : {} train_acc : {:.4f} train_loss : {:.4f} val_acc : {:.4f} val_loss : {:.4f}'.format(epoch, train_acc, train_loss, val_acc, val_loss))

    print(test(model, test_loader))

    # Create the output directory first so a missing folder cannot discard
    # the trained weights at the very end of the run.
    os.makedirs('model', exist_ok=True)
    torch.save(model.state_dict(), 'model/' + run_tag + '.pt')

### sweep 실행

In [None]:
# Start a sweep agent in this kernel; main is passed as the function the agent runs for the sweep identified by sweep_id.
wandb.agent(sweep_id, function=main)

### test 데이터에 적용하여 분류 성능 평가

In [None]:
# Rebuild the architecture only: load_state_dict overwrites every weight,
# so downloading the ImageNet-pretrained weights here would be wasted work.
model = timm.create_model('resnet18', pretrained=False, num_classes=7).to(device)
# map_location lets the checkpoint load even if it was saved from another device.
model.load_state_dict(torch.load('model/resnet18_0.0001_64_224.pt', map_location=device))

# Load the test data
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224,224))
])

test_data = AOIDataset(test=True, transform=transform)
# Evaluation does not depend on sample order, so shuffling is unnecessary.
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4)


In [None]:
def test(model, test_loader) :
    """Evaluate the model on the test set; print timing and a confusion matrix.

    This definition replaces the earlier ``test`` from the training section
    of the notebook once this cell has been executed.

    Args:
        model: network producing class logits of shape (batch, classes).
        test_loader: iterable of ``(data, label)`` batches.

    Returns:
        The text report produced by ``sklearn.metrics.classification_report``.
    """
    model.eval()

    # Send batches to wherever the model actually lives.
    model_device = next(model.parameters()).device
    on_cuda = model_device.type == 'cuda'

    prediction, ground_truth = [], []
    infer_times = []

    with torch.no_grad() :
        for data, label in test_loader :
            data = data.to(model_device)

            # CUDA kernels run asynchronously; synchronize around the forward
            # pass so the timer measures compute rather than launch overhead.
            if on_cuda :
                torch.cuda.synchronize(model_device)
            t1 = time.perf_counter()
            logit = model(data)
            if on_cuda :
                torch.cuda.synchronize(model_device)
            t2 = time.perf_counter()
            infer_times.append((t2 - t1) * 1000)

            # softmax is monotonic, so argmax over the logits gives the same class.
            pred = torch.argmax(logit, dim=1)

            prediction.extend(pred.cpu().tolist())
            ground_truth.extend(label.cpu().tolist())

    infer_time = mean(infer_times)
    print('mean inference time per batch : {:.2f} ms'.format(infer_time))

    # NOTE(review): assumes class index i corresponds to cols[i] — confirm against AOIDataset.
    cols = ['normal', 'burr', 'substance', 'metalburr', 'crack', 'overflow', 'unfulfilled']
    class_ids = list(range(len(cols)))

    # sklearn expects (y_true, y_pred); swapping them exchanges precision and recall.
    report = classification_report(ground_truth, prediction, labels=class_ids, target_names=cols)

    # Passing labels keeps the matrix len(cols) x len(cols) even when a class
    # is absent from the data, so the DataFrame construction cannot fail on shape.
    cm = confusion_matrix(ground_truth, prediction, labels=class_ids)
    cm = pd.DataFrame(cm, columns=cols, index=cols).rename_axis(index='actual', columns='predicted')
    print(cm)

    return report

In [None]:
# Run the evaluation loop on the test loader and print the returned report.
report = test(model, test_loader)
print(report)