In [None]:
# Clone the workshop repository.
# NOTE(review): later cells import `utils` and read relative paths such as
# "train_dataset"; presumably these live in the cloned repository and the
# working directory must point at it — confirm.
!git clone https://github.com/SKKU-STEM/TEMworkshop29_DL

In [None]:
import os

# `np` is used directly in later cells; do not rely on the star import for it.
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from utils import *

In [None]:
## Dataset setup ##

# Directory passed to the helpers as the dataset location, and the list of
# label names shared by the dataset, dataloader and model cells.
dataset_dir = "train_dataset"
label_list = ["W", "V_W", "Se2", "Vac_Se", "Vac_Se2"]

# `show_dataset` comes from the star import; `show_img_num` presumably controls
# how many sample images are displayed — confirm against its definition.
show_dataset(
    dataset_dir=dataset_dir,
    label_list=label_list,
    show_img_num=10,
)

In [None]:
## DataLoader setup ##

# Train/validation ratio given as an (8, 2) pair, and the mini-batch size.
train_valid_ratio = (8, 2)
batch_size = 32

# `get_dataloader` hands back the training and validation loaders as a pair.
# NOTE(review): no random seed is set in the visible cells; if the split is
# random it will differ between runs — confirm.
train_dataloader, valid_dataloader = get_dataloader(
    dataset_dir=dataset_dir,
    label_list=label_list,
    train_valid_ratio=train_valid_ratio,
    batch_size=batch_size,
)

In [None]:
## Deep learning model setup ##

# Sizes handed to MODEL: image channels, three convolution feature counts and
# one fully-connected feature count (meanings inferred from the variable
# names — confirm against MODEL's definition).
img_channels = 1
conv_num_features1 = 32
conv_num_features2 = 64
conv_num_features3 = 128
fc_num_features = 64

# Arguments are positional, so their order must match MODEL's signature.
model = MODEL(
    img_channels,
    conv_num_features1,
    conv_num_features2,
    conv_num_features3,
    fc_num_features,
    label_list,
)

In [None]:
## Training hyper-parameter settings ##

learning_rate = 0.005

# Cross-entropy loss, optimised with plain SGD over all model parameters.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): the model is not moved to `device` in this cell; presumably
# `train`/`test` take care of that — confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
## Run training ##

EPOCH = 200

# One row per epoch: [epoch, train_loss, train_acc, valid_loss, valid_acc].
training_log = []
for epoch in range(1, EPOCH + 1):
    print(f"EPOCH : {epoch}")
    train_loss, train_acc, valid_loss, valid_acc = train(
        model=model,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        loss_function=loss_function,
        optimizer=optimizer,
        device=device,
    )
    epoch_row = [epoch, train_loss, train_acc, valid_loss, valid_acc]
    training_log.append(epoch_row)

# Hand the collected per-epoch rows to the `show_train_graph` helper.
show_train_graph(training_log)

In [None]:
## Run test ##

# Single simulated image used for this test run.
test_data_dir = "test_dataset/input/simulation0.tif"

# `test` returns a pair; both parts are kept for the following cells.
test_result, mapping_result = test(
    test_data_dir=test_data_dir,
    model=model,
    device=device,
    show_result=True,
    save_mapping_result=True,
)

In [None]:
## Check confusion matrix ##

# Evaluate the result of the single test image above and print the matrix.
confusion_mat = evaluation(test_data_dir=test_data_dir, test_result=test_result)
print(confusion_mat)

In [None]:
## Compute metrics ##

# Pass the single-image confusion matrix computed in the previous cell to
# `calculate_score` and keep whatever it returns in `score_result`.
score_result = calculate_score(confusion_mat)

In [None]:
## Overall confusion matrix using the test data ##

test_data_src_dir = "test_dataset/input"

# os.listdir returns entries in arbitrary order and also lists directories
# (e.g. Jupyter's ".ipynb_checkpoints") and hidden files. Sort the names and
# keep only regular, non-hidden files so the run is deterministic and `test`
# is never handed something that is not a data file.
test_data_list = sorted(
    file_name
    for file_name in os.listdir(test_data_src_dir)
    if not file_name.startswith(".")
    and os.path.isfile(f"{test_data_src_dir}/{file_name}")
)

# Size the accumulator from the label list instead of a hard-coded 5 so it
# follows any change to the classes.
# NOTE(review): assumes `evaluation` returns a square matrix with one
# row/column per entry of `label_list` — confirm.
num_labels = len(label_list)
total_confusion_mat = np.zeros((num_labels, num_labels))

# Loop-local names are deliberately distinct from `test_data_dir`,
# `test_result`, `mapping_result` and `confusion_mat`, so the single-image
# results from the earlier cells are not silently overwritten.
for file_name in test_data_list:
    # Keep the "/"-joined form: the helpers receive the path as a string.
    file_path = f"{test_data_src_dir}/{file_name}"
    file_result, file_mapping = test(
        test_data_dir=file_path,
        model=model,
        device=device,
        show_result=False,
        save_mapping_result=True,
    )
    file_confusion_mat = evaluation(test_data_dir=file_path, test_result=file_result)
    total_confusion_mat += file_confusion_mat

# The accumulator is float; cast only for display.
print(total_confusion_mat.astype(np.int32))

In [None]:
## Overall metrics using the test data ##

# Pass the confusion matrix accumulated over the whole test folder to
# `calculate_score` and keep whatever it returns in `total_score_result`.
total_score_result = calculate_score(total_confusion_mat)

In [None]:
## Test on real data ##

# Single experimental image; the mapping result is not saved for this run.
test_data_dir = "test_data/20MX 80kV Image1 7792.tif"

test_result, mapping_result = test(
    test_data_dir=test_data_dir,
    model=model,
    device=device,
    show_result=True,
    save_mapping_result=False,
)

In [None]:
## Test on real data and collect the results ##

test_data_src_dir = "test_data"

# os.listdir returns entries in arbitrary order and also lists directories
# (e.g. Jupyter's ".ipynb_checkpoints") and hidden files. Sort the names and
# keep only regular, non-hidden files so the processing order (and therefore
# the exported table) is deterministic and `test` only receives data files.
test_data_list = sorted(
    file_name
    for file_name in os.listdir(test_data_src_dir)
    if not file_name.startswith(".")
    and os.path.isfile(f"{test_data_src_dir}/{file_name}")
)

# `get_result_df` supplies the initial table; `update_result_df` returns the
# table to carry into the next iteration.
result_df = get_result_df()

# Loop-local names are deliberately distinct from `test_data_dir`,
# `test_result` and `mapping_result`, so the single-image results from the
# earlier cell are not silently overwritten.
for file_name in test_data_list:
    # Keep the "/"-joined form: the helpers receive the path as a string.
    file_path = f"{test_data_src_dir}/{file_name}"
    file_result, file_mapping = test(
        test_data_dir=file_path,
        model=model,
        device=device,
        show_result=True,
        save_mapping_result=True,
    )
    result_df = update_result_df(
        result_df=result_df,
        test_data_dir=file_path,
        test_result=file_result,
    )

# Write the accumulated table to the working directory.
result_df.to_csv("result_df.csv")