In [1]:
# custom model
from DataLoader import XrayDataset

# library for deep learning
import torch 
from torch import nn 
import torch.nn.functional as F
import timm
from torch.utils.data import DataLoader
from torchvision import transforms

# classification
from sklearn.metrics import classification_report

# ETC
import pandas as pd 
import wandb
import time
import warnings

warnings.filterwarnings(action='ignore')

### defect 유형 탐지 및 전처리 

In [None]:
# Load the annotation table and report how many samples each defect label has.
df = pd.read_csv('DATA PATH')

# Keep the raw label column as a plain list so occurrences can be counted directly.
labels = list(df['label'])
unique_label = list(set(labels))

for label_type in unique_label:
    n_samples = labels.count(label_type)
    print('{} defect : {}'.format(label_type, n_samples))

### 하이퍼파라미터 및 sweep 정의하기 

In [None]:
# Grid-search definition for the W&B sweep.
# 'value' pins a hyperparameter to a single setting; 'values' lists the settings
# the grid enumerates (here 3 batch sizes x 3 learning rates = 9 runs).
sweep_config = {
    'name' : 'sse_xray_resnet18',
    'method' : 'grid',
    'parameters' : {
        'img_size' : {
            'value' : 224
        },
        'batch_size' : {
            'values' : [16, 32, 64]
        },
        'lr' : {
            'values' : [1e-3, 1e-4, 1e-5]
        },
        'epochs' : {
            'value' : 20
        }
    }
}

# Prefer the second GPU as before, but only when it actually exists: on a
# single-GPU host 'cuda:1' is an invalid device ordinal and the first .to(device)
# call would raise.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Register the sweep with W&B; the returned id is what the agent is started with.
sweep_id = wandb.sweep(sweep_config, project='resnet18_xray_eval', entity='wandb')

### 데이터 불러오기

In [None]:
def declare_data(run) :
    """Build the train / validation / test data loaders for one sweep run.

    Args:
        run: W&B run whose ``config`` provides ``img_size`` and ``batch_size``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. The test loader always
        uses a batch size of 1.
    """
    side = run.config.img_size
    # Images are converted to tensors first, then resized to a square of
    # img_size x img_size.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((side, side))
    ])

    # NOTE(review): the split flags (train/val/test) are interpreted by the
    # project-local XrayDataset; its behaviour is not visible from this notebook.
    train_data = XrayDataset(train=True, transform=transform)
    val_data = XrayDataset(val=True, transform=transform)
    test_data = XrayDataset(test=True, transform=transform)

    # Only the training split is shuffled. Shuffling evaluation data does not
    # change aggregate metrics but makes per-sample outputs non-reproducible.
    train_loader = DataLoader(train_data, batch_size=run.config.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_data, batch_size=run.config.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=4)

    return train_loader, val_loader, test_loader

### train, validation 함수 정의하기

In [None]:
def train(model, optimizer, criterion, data_loader) :
    """Run one training epoch.

    Args:
        model: Network producing class logits of shape (batch, num_classes).
        optimizer: Optimizer stepping the model's trainable parameters.
        criterion: Loss applied to log-probabilities (e.g. ``nn.NLLLoss``). It is
            assumed to use the default ``reduction='mean'``; the per-batch value
            is re-weighted by the batch size below.
        data_loader: Loader yielding ``(data, label)`` batches; its ``dataset``
            length is used as the number of samples.

    Returns:
        Tuple ``(train_acc, train_loss)`` of Python floats: the fraction of
        correctly classified samples and the mean per-sample loss.
    """
    model.train()
    # Send batches to wherever the model lives instead of relying on a
    # notebook-level ``device`` variable that may be stale or undefined.
    model_device = next(model.parameters()).device
    n_correct = 0
    loss_sum = 0.0

    for data, label in data_loader:
        data = data.to(model_device)
        label = label.to(model_device)

        logit = model(data)
        # dim must be explicit: the implicit-dim form of log_softmax is deprecated.
        log_prob = F.log_softmax(logit, dim=1)
        loss = criterion(log_prob, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n_correct += (log_prob.argmax(dim=1) == label).sum().item()
        # The criterion returns the batch mean; weight it by the batch size so the
        # epoch loss is a true per-sample mean, comparable across batch sizes.
        loss_sum += loss.item() * label.size(0)

    n_samples = len(data_loader.dataset)
    train_acc = n_correct / n_samples
    train_loss = loss_sum / n_samples

    return train_acc, train_loss



def validation(model, criterion, data_loader) :
    """Evaluate the model on a held-out loader without updating weights.

    Args:
        model: Network producing class logits of shape (batch, num_classes).
        criterion: Loss applied to log-probabilities (e.g. ``nn.NLLLoss``). It is
            assumed to use the default ``reduction='mean'``; the per-batch value
            is re-weighted by the batch size below.
        data_loader: Loader yielding ``(data, label)`` batches; its ``dataset``
            length is used as the number of samples.

    Returns:
        Tuple ``(val_acc, val_loss)`` of Python floats (same convention as
        ``train``): accuracy and mean per-sample loss.
    """
    model.eval()
    # Send batches to wherever the model lives instead of relying on a
    # notebook-level ``device`` variable that may be stale or undefined.
    model_device = next(model.parameters()).device
    n_correct = 0
    loss_sum = 0.0

    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(model_device)
            label = label.to(model_device)

            logit = model(data)
            # dim must be explicit: the implicit-dim form of log_softmax is deprecated.
            log_prob = F.log_softmax(logit, dim=1)
            loss = criterion(log_prob, label)

            # .item() keeps the accumulators as plain numbers rather than
            # device tensors, so the returned metrics are ordinary floats.
            n_correct += (log_prob.argmax(dim=1) == label).sum().item()
            # Weight the batch-mean loss by the batch size to get a per-sample mean.
            loss_sum += loss.item() * label.size(0)

    n_samples = len(data_loader.dataset)
    val_acc = n_correct / n_samples
    val_loss = loss_sum / n_samples

    return val_acc, val_loss
        

### 학습용 객체 정의하기

In [None]:
def define_training_object(run) :
    """Create the model, optimizer and loss for one sweep run.

    A pretrained timm ResNet-18 with a 4-class head is loaded, every parameter
    except the final ``fc`` classifier is frozen, and Adam is built over the
    trainable parameters only.

    Args:
        run: W&B run whose ``config.lr`` sets the learning rate.

    Returns:
        Tuple ``(model, optimizer, criterion)``. ``criterion`` is ``nn.NLLLoss``,
        so callers must feed it log-probabilities.
    """
    model = timm.create_model('resnet18', pretrained=True, num_classes=4).to(device)

    # Freeze the feature-extraction CNN layers of ResNet-18 and update only the
    # classifier. The attribute is `requires_grad`; assigning to the misspelled
    # `require_grad` silently creates an unused attribute and freezes nothing.
    for param in model.parameters():
        param.requires_grad = False

    for param in model.fc.parameters():
        param.requires_grad = True

    # Build the optimizer after freezing so it only tracks the trainable head.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=run.config.lr)
    criterion = nn.NLLLoss()

    return model, optimizer, criterion

### main 함수 정의하기 

In [None]:
def main() :
    """Entry point for a single sweep run: train, validate, log and save.

    Initialises a W&B run (its config is supplied by the sweep agent), trains
    for ``run.config.epochs`` epochs while logging train/validation accuracy and
    loss, and finally writes the model's state dict under ``model/xray/``.
    """
    import os  # local import: only needed to make sure the output directory exists

    # Initialise W&B and give the run a name that encodes its hyperparameters.
    run = wandb.init()
    run.name = 'xray_' + str(run.config.batch_size) + '_' + str(run.config.img_size) + '_' + str(run.config.lr)

    # Build the training objects (data loaders, model, optimizer, loss function).
    train_loader, val_loader, _ = declare_data(run)
    model, optimizer, criterion = define_training_object(run)

    print(run.config)

    # Training loop.
    for epoch in range(run.config.epochs):
        train_acc, train_loss = train(model, optimizer, criterion, train_loader)
        val_acc, val_loss = validation(model, criterion, val_loader)

        # Normalise to plain floats so logging and formatting do not depend on
        # whether the helpers return Python numbers or 0-dim tensors.
        train_acc, train_loss = float(train_acc), float(train_loss)
        val_acc, val_loss = float(val_acc), float(val_loss)

        # One log call per epoch keeps all four metrics on the same W&B step.
        run.log({
            'epoch' : epoch,
            'train/acc' : train_acc,
            'train/loss' : train_loss,
            'val/acc' : val_acc,
            'val/loss' : val_loss,
        })

        print('epoch : {} train_acc : {:.4f} train_loss : {:.4f} val_acc : {:.4f} val_loss : {:.4f}'.format(epoch, train_acc, train_loss, val_acc, val_loss))

    # Save the model. Create the target directory first so a missing folder does
    # not discard a finished training run.
    os.makedirs('model/xray', exist_ok=True)
    # The double underscore after "resnet18" is kept on purpose: the evaluation
    # cell loads 'model/xray/resnet18__64_224_1e-05.pt'.
    torch.save(model.state_dict(), 'model/xray/resnet18_' + '_' + str(run.config.batch_size) + '_' + str(run.config.img_size) + '_' +  str(run.config.lr) + '.pt')

### 학습 진행

In [None]:
# Start a W&B agent for the sweep registered as `sweep_id`; the agent invokes
# `main` for each run it picks up from the sweep.
wandb.agent(sweep_id, function=main)

---
### 모델 테스트

In [13]:
def test(run, model, criterion, data_loader) :
    """Evaluate a trained model on the test loader and report the results to W&B.

    Measures accuracy, mean per-sample loss and mean forward-pass time, stores
    them on ``run.config`` (``acc``, ``loss``, ``infer_time``), prints a
    scikit-learn classification report and logs a confusion matrix.

    Args:
        run: Active W&B run that receives the metrics and the confusion matrix.
        model: Network producing class logits of shape (batch, num_classes).
        criterion: Loss applied to log-probabilities (e.g. ``nn.NLLLoss``),
            assumed to use the default ``reduction='mean'``.
        data_loader: Loader yielding ``(data, label)`` batches; its ``dataset``
            length is used as the number of samples.

    Returns:
        None.
    """
    model.eval()
    # Send batches to wherever the model lives instead of relying on a
    # notebook-level ``device`` variable.
    model_device = next(model.parameters()).device
    on_cuda = model_device.type == 'cuda'

    n_correct = 0
    loss_sum = 0.0
    infer_ms = 0.0

    labels, predictions = [], []

    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(model_device)
            label = label.to(model_device)

            # CUDA kernels run asynchronously: without synchronising before and
            # after the forward pass, the timer only measures the kernel launch.
            if on_cuda:
                torch.cuda.synchronize(model_device)
            t1 = time.perf_counter()
            logit = model(data)
            if on_cuda:
                torch.cuda.synchronize(model_device)
            t2 = time.perf_counter()
            infer_ms += (t2 - t1) * 1000

            # dim must be explicit: the implicit-dim form of log_softmax is deprecated.
            log_prob = F.log_softmax(logit, dim=1)
            pred = log_prob.argmax(dim=1)
            loss = criterion(log_prob, label)

            n_correct += (pred == label).sum().item()
            # Weight the batch-mean loss by the batch size to get a per-sample mean.
            loss_sum += loss.item() * label.size(0)

            labels.extend(label.cpu().tolist())
            predictions.extend(pred.cpu().tolist())

    n_samples = len(data_loader.dataset)
    test_acc = n_correct / n_samples
    test_loss = loss_sum / n_samples
    # Mean forward-pass time in milliseconds per sample.
    infer_times = infer_ms / n_samples

    # Stored as plain floats (not tensors) so they serialise cleanly.
    run.config.acc = test_acc
    run.config.loss = test_loss
    run.config.infer_time = infer_times

    # Detailed metrics: per-class precision, recall and F1.
    report = classification_report(labels, predictions)
    print(report)

    #run.log({'roc curve/roc_curve' : wandb.plot.roc_curve(labels, predictions, labels=['good', 'bad', 'empty', 'in-spec'])})
    run.log({'confusion_matrix/confusion_matrix' : wandb.sklearn.plot_confusion_matrix(labels, predictions, ['good', 'bad', 'empty', 'in-spec'])})


In [14]:
# Evaluate one saved checkpoint on the test split and log the results to W&B.
run = wandb.init(entity='wandb', project='resnet18_xray_eval', name='evaluation')

# Prefer the second GPU, but only when it exists; otherwise fall back to the
# first GPU or the CPU instead of failing on an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# pretrained=False: every weight is overwritten by the checkpoint below, so
# fetching the ImageNet weights first would be wasted work.
model = timm.create_model('resnet18', pretrained=False, num_classes=4)
# map_location remaps tensors saved from another device (e.g. a GPU that is not
# present on this host) onto the device selected above.
model.load_state_dict(torch.load('model/xray/resnet18__64_224_1e-05.pt', map_location=device))
model = model.to(device)
criterion = nn.NLLLoss()

# Same preprocessing as in training, with the image size fixed to 224.
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224))
    ])

test_data = XrayDataset(test=True, transform=transform)
# Evaluation order does not affect the metrics, so keep it deterministic.
test_loader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=4)

test(run, model, criterion, test_loader)

              precision    recall  f1-score   support

           0       0.99      0.99      0.99      3858
           1       0.73      0.92      0.81       724
           2       1.00      1.00      1.00      4030
           3       0.69      0.32      0.44       370

    accuracy                           0.96      8982
   macro avg       0.85      0.81      0.81      8982
weighted avg       0.96      0.96      0.96      8982



In [15]:
# Close the evaluation W&B run now that all metrics have been logged.
run.finish()




VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…