# 📜 문서 타입 분류 대회

> - kimkihong / helpotcreator@gmail.com / Upstage AI Lab 3기
> - 2024.07.30.화 10:00 ~ 2024.08.11.일 19:00

In [1]:
import os
import time
import random
import copy

import timm
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
import albumentations as A
from albumentations import ImageOnlyTransform
from augraphy import *
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

import matplotlib.pyplot as plt

pre_path = '/kkh/'
train_kr_aug_image_path = pre_path + 'data/train_kr_aug'
meta_kr_csv_path = pre_path + 'data/meta_kr.csv'
train_kr_csv_path = pre_path + 'data/train_kr.csv'
train_kr_aug_csv_path = pre_path + 'data/train_kr_aug.csv'
meta_kr_df = pd.read_csv(meta_kr_csv_path)
train_kr_df = pd.read_csv(train_kr_csv_path)
train_kr_aug_df = pd.read_csv(train_kr_aug_csv_path)

# PRE_PATH = '/kkh/'
# TRAIN_KR_IMAGE_PATH = PRE_PATH + 'data/train_kr'
# TRAIN_KR_AUG_IMAGE_PATH = PRE_PATH + 'data/train_kr_aug'
# TEST_IMAGE_PATH = PRE_PATH + 'data/test'

# META_KR_CSV_PATH = PRE_PATH + 'data/meta_kr.csv'
# META_KR_DF = pd.read_csv(META_KR_CSV_PATH)
# TRAIN_KR_CSV_PATH = PRE_PATH + 'data/train_kr.csv'
# TRAIN_KR_DF = pd.read_csv(TRAIN_KR_CSV_PATH)
# TRAIN_KR_AUG_CSV_PATH = PRE_PATH + 'data/train_kr_aug.csv'
# TRAIN_KR_AUG_DF = pd.read_csv(TRAIN_KR_AUG_CSV_PATH)
# TEST_CSV_PATH = PRE_PATH + 'data/sample_submission.csv'
# TEST_DF = pd.read_csv(TEST_CSV_PATH)

In [2]:
# 시드를 고정합니다.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = True

In [3]:
# 데이터셋 클래스를 정의합니다.
class ImageDataset(Dataset):
    def __init__(self, csv, path, transform=None, oversample=False):
        self.df = pd.read_csv(csv)
        self.path = path
        self.transform = transform
        self.oversample = oversample

        # 클래스간 불균형 해소를 위한 샘플 증식
        if self.oversample:
            # 각 클래스별로 데이터 수 계산
            class_counts = np.bincount(self.df.values[:, 1].astype(int))

            # 각 클래스별로 증식할 횟수 설정 (이 예제에서는 최대 데이터 수에 맞춤)
            max_class_count = max(class_counts)
            oversample_factors = [max_class_count // count for count in class_counts]
            # Class 3, 7 가중치 2로 변경
            oversample_factors[3] = 2
            oversample_factors[7] = 2 
            # oversample_factors[14] = 3 

            # 각 클래스별로 데이터를 증식한 새로운 데이터 프레임 생성
            oversampled_data = [self.df.values[self.df.values[:, 1] == cls].repeat(factor, axis=0) for cls, factor in enumerate(oversample_factors)]
            oversampled_data = np.vstack(oversampled_data)

            self.df = pd.DataFrame(oversampled_data, columns=self.df.columns)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        name, target = self.df.iloc[idx]
        img = np.array(Image.open(os.path.join(self.path, name)).convert("RGB"))
        
        if self.transform:
            img = self.transform(image=img)['image']
        return img, target
    
label_to_class_name = dict(zip(meta_kr_df['target'], meta_kr_df['class_name']))

In [4]:
# one epoch 학습을 위한 함수
def training(model, dataloader, criterion, optimizer, device, epoch, num_epochs):
    model.train()
    train_loss = 0
    preds_list = []
    targets_list = []

    pbar = tqdm(dataloader)
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        model.zero_grad(set_to_none=True)

        preds = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())
        targets_list.extend(labels.detach().cpu().numpy())

        pbar.set_description(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {loss.item()}")
        
    train_loss /= len(dataloader)
    train_acc = accuracy_score(targets_list, preds_list)    
    train_f1 = f1_score(targets_list, preds_list, average='macro')

    return model, train_loss, train_acc, train_f1

def evaluation(model, dataloader, criterion, device, epoch, num_epochs):
    model.eval()  # 모델을 평가 모드로 설정
    valid_loss = 0.0
    preds_list = []
    targets_list = []

    with torch.no_grad():
        tbar = tqdm(dataloader)
        for images, labels in tbar:
            images = images.to(device)
            labels = labels.to(device)

            preds = model(images)
            loss = criterion(preds, labels)

            valid_loss += loss.item()
            preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())
            targets_list.extend(labels.detach().cpu().numpy())

            tbar.set_description(f"Epoch [{epoch+1}/{num_epochs}] - Valid Loss: {loss.item()}")

    valid_loss = valid_loss / len(dataloader)
    valid_acc = accuracy_score(targets_list, preds_list)  
    valid_f1 = f1_score(targets_list, preds_list, average='macro')

    return valid_loss, valid_acc, valid_f1

def training_loop(model, train_dataloader, valid_dataloader, criterion, optimizer, device, num_epochs, patience, model_name):
    best_valid_loss = float('inf')  # 가장 좋은 validation loss를 저장
    early_stop_counter = 0  # 카운터
    valid_max_accuracy = -1
    best_model = None

    for epoch in range(num_epochs):
        model, train_loss, train_acc, train_f1 = training(model, train_dataloader, criterion, optimizer, device, epoch, num_epochs)
        valid_loss, valid_acc, valid_f1 = evaluation(model, valid_dataloader, criterion, device, epoch, num_epochs)

        monitoring_value = {'train_loss': train_loss, 'train_accuracy': train_acc, 'train_f1': train_f1, 
                            'valid_loss': valid_loss, 'valid_accuracy': valid_acc, 'valid_f1': valid_f1}
        
        print(f'''Epoch [{epoch + 1}/{num_epochs}] Finished
        Train Loss: {train_loss}, Train Accuracy: {train_acc}, Train F1: {train_f1}
        Valid Loss: {valid_loss}, Valid Accuracy: {valid_acc}, Valid F1: {valid_f1}''')

        if valid_acc > valid_max_accuracy:
          valid_max_accuracy = valid_acc

        # validation loss가 감소하면 모델 저장 및 카운터 리셋
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            best_model = model
            torch.save(model.state_dict(), PRE_PATH + f"model_{model_name}.pt")
            early_stop_counter = 0
            print('Model Saved')

        # validation loss가 증가하거나 같으면 카운터 증가
        else:
            early_stop_counter += 1

        # 조기 종료 카운터가 설정한 patience를 초과하면 학습 종료
        if early_stop_counter >= patience:
            print("Early stopping")
            break

    return best_model, valid_max_accuracy

In [5]:
# 이미지 plotting을 위한 함수
def normalize_image(image):
    # 이미지를 [0, 1] 범위로 정규화
    image_min = image.min()
    image_max = image.max()
    normalized_image = (image - image_min) / (image_max - image_min)
    return normalized_image

def plot_images(images, labels, classes, normalize = True):

    n_images = len(images)
    
    num_rows = n_images // 3  # 행의 개수 계산
    if n_images % 3 != 0:
        num_rows += 1    
    fig, axes = plt.subplots(num_rows, 3, figsize=(30, 10 * num_rows))

    for i in range(n_images):
        image = images[i]
        label = classes[labels[i]]
        row_idx = i // 3
        col_idx = i % 3
        
        if normalize:
            image = normalize_image(image)
        
        axes[row_idx, col_idx].imshow(image.permute(1, 2, 0))
        axes[row_idx, col_idx].set_title(label)
        axes[row_idx, col_idx].axis('off')

    plt.show()

In [6]:
# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# validation config
VALID_RATIO = 0.8

# model config
model_name = 'efficientnet_b4'
pretrained_size = 380
pretrained_means = [0.485, 0.456, 0.406]
pretrained_stds= [0.229, 0.224, 0.225]

# training config
LR = 5e-4
BATCH_SIZE = 32
dropout_ratio = 0.2
patience = 5
num_workers = 0
num_classes = 17

In [7]:
# train image 변환을 위한 transform 코드
train_transform = A.Compose([
    # PatternGeneratorTransform(pattern, p=0.3), # 패턴 노이즈
    # A.Resize(height=pretrained_size, width=pretrained_size), # 이미지 크기 조정
    # 이미지 긴 측면 크기 조절 후 패딩 적용
    A.LongestMaxSize(max_size=pretrained_size, always_apply=True), 
    A.PadIfNeeded(min_height=pretrained_size, min_width=pretrained_size, border_mode=0, value=(255, 255, 255)),
    
    A.Normalize(mean=pretrained_means, std=pretrained_stds), # images normalization
    ToTensorV2() # numpy 이미지나 PIL 이미지를 PyTorch 텐서로 변환
])

# test image 변환을 위한 transform 코드
test_transform = A.Compose([    
    # A.Resize(height=pretrained_size, width=pretrained_size),
    A.LongestMaxSize(max_size=pretrained_size, always_apply=True),
    A.PadIfNeeded(min_height=pretrained_size, min_width=pretrained_size, border_mode=0, value=(255, 255, 255)),
    
    A.Normalize(mean=pretrained_means, std=pretrained_stds),
    ToTensorV2()
])

# aug_test_transform = A.Compose([    
#     A.RandomRotate90(),
#     A.Flip(p=0.5),              
                        
#     # A.Resize(height=pretrained_size, width=pretrained_size),
#     A.LongestMaxSize(max_size=pretrained_size, always_apply=True),
#     A.PadIfNeeded(min_height=pretrained_size, min_width=pretrained_size, border_mode=0, value=(255, 255, 255)),
    
#     A.Normalize(mean=pretrained_means, std=pretrained_stds),
#     ToTensorV2()
# ])

# 시각화를 위한 transform 코드
base_transform = A.Compose([
    ToTensorV2()
])

In [8]:
# Training Dataset 정의
train_dataset = ImageDataset(
    TRAIN_KR_AUG_CSV_PATH,
    TRAIN_KR_AUG_IMAGE_PATH,
    transform=train_transform,
    oversample=True
)

# Test Dataset 정의
test_dataset = ImageDataset(
    TEST_CSV_PATH,
    TEST_IMAGE_PATH,
    transform=test_transform
)

# aug_test_dataset = ImageDataset(
#     data_path + 'sample_submission.csv',
#     data_path + 'test/',
#     transform=aug_test_transform
# )

# 시각화용 Dataset 정의
train_dataset_v = ImageDataset(
    TRAIN_KR_AUG_CSV_PATH,
    TRAIN_KR_AUG_IMAGE_PATH,
    transform=base_transform
)

test_dataset_v = ImageDataset(
    TEST_CSV_PATH,
    TEST_IMAGE_PATH,
    transform=base_transform
)

print(len(train_dataset), len(test_dataset))

29872 3140


In [9]:
# 데이터 셋을 학습 데이터 셋과 검증 데이터 셋으로 분리
total_size = len(train_dataset)
train_num, valid_num = int(total_size * VALID_RATIO), total_size - int(total_size * VALID_RATIO)

# train - valid set 나누기
generator = torch.Generator().manual_seed(SEED)
train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [train_num, valid_num], generator = generator)

valid_data = copy.deepcopy(valid_dataset)
valid_data.dataset.transform = test_transform

print(f'Train dataset 개수: {len(train_dataset)}')
print(f'Validation dataset 개수: {len(valid_dataset)}')
print(f'Test dataset 개수: {len(test_dataset)}')

Train dataset 개수: 23897
Validation dataset 개수: 5975
Test dataset 개수: 3140


In [10]:
# DataLoader 정의
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=False
    )

valid_dataloader = DataLoader(
    valid_dataset, 
    batch_size = BATCH_SIZE, 
    shuffle = False,
    num_workers=0,
    pin_memory=True
    )

test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True
    )

# aug_test_dataloader = DataLoader(
#     aug_test_dataset,
#     batch_size=BATCH_SIZE,
#     shuffle=False,
#     num_workers=0,
#     pin_memory=True
#     )

In [11]:
# 변환된 학습 이미지 확인
N_IMAGES = 24

# 무작위로 선택된 인덱스 리스트를 생성
selected_indices = random.sample(range(len(train_dataset_v)), N_IMAGES)

images, labels = zip(*[(image, label) for image, label in [train_dataset_v[i] for i in selected_indices]])
# images, labels = zip(*[(image, label) for image, label in [train_dataset_v[i] for i in range(N_IMAGES)]])

# plot_images(images, labels, label_to_class_name)

In [12]:
class AttentionModule(nn.Module):
    def __init__(self, in_features, out_features):
        super(AttentionModule, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights

class CustomEfficientNetB4(nn.Module):
    def __init__(self, num_classes, attention_size=1792):
        super(CustomEfficientNetB4, self).__init__()
        self.base_model = timm.create_model('efficientnet_b4', pretrained=True)
        
        # Remove the existing classifier
        self.base_model.reset_classifier(0, '')

        # Add attention mechanism
        self.attention = AttentionModule(attention_size, attention_size)

        # New classifier with attention
        self.classifier = nn.Linear(attention_size, num_classes)
        
    def forward(self, x):
        x = self.base_model(x)
        
        # Global average pooling
        x = x.mean([2, 3])

        # Apply attention mechanism
        x = self.attention(x)

        # Final classification
        x = self.classifier(x)

        return x

In [13]:
# 모델 생성
model = CustomEfficientNetB4(num_classes).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LR)

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/efficientnet_b4.ra2_in1k)
INFO:timm.models._hub:[timm/efficientnet_b4.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


In [14]:
EPOCHS = 50

# 모델 학습 루프 실행 및 검증 최대 정확도 출력
model, valid_max_accuracy = training_loop(
    model,             # 학습할 모델
    train_dataloader,  # 훈련 데이터로더
    valid_dataloader,  # 검증 데이터로더
    loss_fn,           # 손실 함수
    optimizer,         # 옵티마이저
    device,            # 디바이스 (CPU 또는 GPU)
    EPOCHS,            # 총 에폭 수
    patience,          # 조기 중단을 위한 인내 수
    model_name         # 모델 이름
)

# 검증 데이터에서의 최대 정확도 출력
print(f'Valid Max Accuracy: {valid_max_accuracy}')


Epoch [1/50] - Train Loss: 0.0467059500515461: 100%|██████████| 747/747 [04:51<00:00,  2.56it/s]   
Epoch [1/50] - Valid Loss: 0.010883726179599762: 100%|██████████| 187/187 [00:30<00:00,  6.19it/s] 


Epoch [1/50] Finished
        Train Loss: 0.2966415800954155, Train Accuracy: 0.8985228271331129, Train F1: 0.9185054477159537
        Valid Loss: 0.04546667784166069, Valid Accuracy: 0.9834309623430962, Valid F1: 0.987270102305591
Model Saved


Epoch [2/50] - Train Loss: 0.026621723547577858: 100%|██████████| 747/747 [04:47<00:00,  2.59it/s]  
Epoch [2/50] - Valid Loss: 0.0009441324509680271: 100%|██████████| 187/187 [00:29<00:00,  6.33it/s] 


Epoch [2/50] Finished
        Train Loss: 0.0315025129474399, Train Accuracy: 0.9898732058417374, Train F1: 0.9920497353504713
        Valid Loss: 0.028349079499376022, Valid Accuracy: 0.9902928870292887, Valid F1: 0.9928895085247229
Model Saved


Epoch [3/50] - Train Loss: 0.005246365442872047: 100%|██████████| 747/747 [04:47<00:00,  2.60it/s]  
Epoch [3/50] - Valid Loss: 0.0019036014564335346: 100%|██████████| 187/187 [00:29<00:00,  6.28it/s] 


Epoch [3/50] Finished
        Train Loss: 0.01594429017339565, Train Accuracy: 0.9950621416914257, Train F1: 0.9961059643737592
        Valid Loss: 0.03200678187136085, Valid Accuracy: 0.9924686192468619, Valid F1: 0.9940218872297342


Epoch [4/50] - Train Loss: 0.0009074960835278034: 100%|██████████| 747/747 [04:47<00:00,  2.59it/s] 
Epoch [4/50] - Valid Loss: 0.007510175462812185: 100%|██████████| 187/187 [00:29<00:00,  6.42it/s]  


Epoch [4/50] Finished
        Train Loss: 0.01608634549155572, Train Accuracy: 0.9955224505168013, Train F1: 0.9963955926352405
        Valid Loss: 0.012771253544918273, Valid Accuracy: 0.9966527196652719, Valid F1: 0.9965334212470909
Model Saved


Epoch [5/50] - Train Loss: 0.0012349841417744756: 100%|██████████| 747/747 [04:47<00:00,  2.59it/s] 
Epoch [5/50] - Valid Loss: 0.001193458680063486: 100%|██████████| 187/187 [00:29<00:00,  6.34it/s]  


Epoch [5/50] Finished
        Train Loss: 0.011707027341491977, Train Accuracy: 0.9965686069381093, Train F1: 0.9970789434048617
        Valid Loss: 0.01188056285329542, Valid Accuracy: 0.9979916317991632, Valid F1: 0.9984656477168423
Model Saved


Epoch [6/50] - Train Loss: 0.008044359274208546: 100%|██████████| 747/747 [04:47<00:00,  2.59it/s]  
Epoch [6/50] - Valid Loss: 0.015514264814555645: 100%|██████████| 187/187 [00:30<00:00,  6.20it/s]  


Epoch [6/50] Finished
        Train Loss: 0.010264995422132696, Train Accuracy: 0.9967359919655187, Train F1: 0.9972985376289843
        Valid Loss: 0.010037427830851773, Valid Accuracy: 0.9969874476987448, Valid F1: 0.9978780032622426
Model Saved


Epoch [7/50] - Train Loss: 0.22850458323955536: 100%|██████████| 747/747 [04:48<00:00,  2.59it/s]   
Epoch [7/50] - Valid Loss: 0.008665334433317184: 100%|██████████| 187/187 [00:29<00:00,  6.41it/s]  


Epoch [7/50] Finished
        Train Loss: 0.009764235804257787, Train Accuracy: 0.997028915763485, Train F1: 0.9976916670368835
        Valid Loss: 0.013191404560967642, Valid Accuracy: 0.9964853556485356, Valid F1: 0.9966500457910473


Epoch [8/50] - Train Loss: 0.013936085626482964: 100%|██████████| 747/747 [04:47<00:00,  2.60it/s]  
Epoch [8/50] - Valid Loss: 0.10222700238227844: 100%|██████████| 187/187 [00:30<00:00,  6.22it/s]   


Epoch [8/50] Finished
        Train Loss: 0.01231026813545872, Train Accuracy: 0.9964849144244048, Train F1: 0.9970336432692853
        Valid Loss: 0.012510527402406433, Valid Accuracy: 0.9959832635983263, Valid F1: 0.9962905157274129


Epoch [9/50] - Train Loss: 0.00011603563325479627: 100%|██████████| 747/747 [04:47<00:00,  2.60it/s]
Epoch [9/50] - Valid Loss: 0.010353174060583115: 100%|██████████| 187/187 [00:29<00:00,  6.36it/s]  


Epoch [9/50] Finished
        Train Loss: 0.009721592706033465, Train Accuracy: 0.996903376992928, Train F1: 0.9973116739935586
        Valid Loss: 0.005669481178454998, Valid Accuracy: 0.999163179916318, Valid F1: 0.9992688064556239
Model Saved


Epoch [10/50] - Train Loss: 0.04793402552604675: 100%|██████████| 747/747 [04:47<00:00,  2.60it/s]   
Epoch [10/50] - Valid Loss: 0.00023983050778042525: 100%|██████████| 187/187 [00:30<00:00,  6.17it/s]


Epoch [10/50] Finished
        Train Loss: 0.001322715862557157, Train Accuracy: 0.9995815374314767, Train F1: 0.999596397103065
        Valid Loss: 0.03471588094725301, Valid Accuracy: 0.9984937238493724, Valid F1: 0.9986548938077152


Epoch [11/50] - Train Loss: 0.000147076731082052: 100%|██████████| 747/747 [04:48<00:00,  2.59it/s]  
Epoch [11/50] - Valid Loss: 0.001345082768239081: 100%|██████████| 187/187 [00:30<00:00,  6.19it/s]  


Epoch [11/50] Finished
        Train Loss: 0.014629798533522729, Train Accuracy: 0.9955642967736535, Train F1: 0.9961787841086768
        Valid Loss: 0.004518984323995563, Valid Accuracy: 0.9984937238493724, Valid F1: 0.9983671538082947
Model Saved


Epoch [12/50] - Train Loss: 0.007552257739007473: 100%|██████████| 747/747 [04:48<00:00,  2.59it/s]  
Epoch [12/50] - Valid Loss: 0.00042556560947559774: 100%|██████████| 187/187 [00:30<00:00,  6.21it/s]


Epoch [12/50] Finished
        Train Loss: 0.006121287266517027, Train Accuracy: 0.9983679959827594, Train F1: 0.9987469133214544
        Valid Loss: 0.006232105401450056, Valid Accuracy: 0.998326359832636, Valid F1: 0.998908632201776


Epoch [13/50] - Train Loss: 2.1314410787454108e-06: 100%|██████████| 747/747 [04:48<00:00,  2.59it/s]
Epoch [13/50] - Valid Loss: 2.7502615921548568e-05: 100%|██████████| 187/187 [00:29<00:00,  6.38it/s]


Epoch [13/50] Finished
        Train Loss: 0.0022227390322857423, Train Accuracy: 0.9992886136335105, Train F1: 0.9994468353161323
        Valid Loss: 0.011147532648887352, Valid Accuracy: 0.9971548117154811, Valid F1: 0.9976887318774937


Epoch [14/50] - Train Loss: 1.1714381798810791e-05: 100%|██████████| 747/747 [04:48<00:00,  2.59it/s]
Epoch [14/50] - Valid Loss: 0.00034648104337975383: 100%|██████████| 187/187 [00:30<00:00,  6.23it/s]


Epoch [14/50] Finished
        Train Loss: 0.006827067790108328, Train Accuracy: 0.9978658409005314, Train F1: 0.998444620827829
        Valid Loss: 0.0008154965906174963, Valid Accuracy: 0.9994979079497908, Valid F1: 0.9995064683551099
Model Saved


Epoch [15/50] - Train Loss: 0.0003958964953199029:   3%|▎         | 22/747 [00:08<04:54,  2.46it/s] 


KeyboardInterrupt: 

In [15]:
torch.cuda.empty_cache()

In [16]:
model.load_state_dict(torch.load(f'./model_{model_name}.pt'))
# model.load_state_dict(torch.load(PRE_PATH + f'model_{model_name}.pt'))
model.to(device)

N_TTA = 20
preds_list = []
with torch.no_grad():
    # loaders = [test_dataloader] + [aug_test_dataloader] * N_TTA
    loaders = [test_dataloader]

    for batches in tqdm(zip(*loaders), total=len(test_dataloader)):
        images, *aug_images = [images.to(device) for images, _ in batches]

        outputs_original = model(images)
        outputs_augmented = [model(aug_image) for aug_image in aug_images]

        final_outputs = (outputs_original + sum(outputs_augmented)) / N_TTA + 1
        preds_list.extend(final_outputs.argmax(dim=1).cpu().numpy())

# 예측 결과 확인
print("Ensemble Predictions:", preds_list)

100%|██████████| 99/99 [00:17<00:00,  5.65it/s]

Ensemble Predictions: [2, 12, 5, 4, 2, 15, 0, 8, 15, 11, 5, 3, 16, 9, 15, 4, 4, 5, 13, 11, 12, 7, 1, 6, 3, 0, 14, 16, 0, 6, 3, 0, 13, 2, 5, 16, 13, 14, 15, 0, 5, 9, 12, 9, 0, 8, 5, 0, 11, 7, 10, 10, 10, 6, 3, 12, 9, 5, 13, 13, 12, 4, 5, 5, 6, 1, 5, 7, 10, 6, 10, 10, 8, 15, 7, 15, 6, 12, 12, 13, 8, 9, 9, 11, 10, 10, 5, 13, 10, 0, 10, 8, 5, 15, 4, 16, 11, 11, 7, 11, 14, 3, 13, 1, 15, 11, 2, 12, 16, 8, 6, 2, 13, 4, 12, 16, 2, 7, 11, 4, 2, 13, 5, 8, 10, 6, 4, 4, 3, 6, 5, 7, 15, 10, 16, 16, 7, 6, 6, 8, 4, 11, 14, 2, 12, 8, 3, 5, 9, 8, 6, 8, 16, 12, 11, 16, 9, 15, 6, 8, 5, 5, 10, 10, 16, 15, 9, 12, 16, 5, 2, 8, 8, 16, 9, 8, 16, 16, 3, 4, 11, 15, 9, 9, 2, 3, 11, 10, 9, 0, 4, 0, 16, 5, 14, 15, 5, 12, 0, 4, 13, 2, 6, 16, 16, 10, 8, 9, 0, 10, 5, 1, 13, 7, 11, 2, 0, 4, 0, 13, 12, 0, 16, 4, 12, 5, 3, 0, 14, 6, 0, 3, 12, 12, 9, 13, 9, 10, 9, 15, 10, 14, 9, 11, 12, 0, 1, 11, 12, 6, 7, 13, 4, 3, 14, 15, 7, 12, 3, 4, 0, 15, 13, 11, 6, 12, 8, 7, 9, 0, 8, 3, 4, 5, 0, 0, 7, 0, 9, 12, 1, 4, 7, 8, 12, 1, 1




In [17]:
pred_df = pd.DataFrame(test_dataset.df, columns=['ID', 'target'])
pred_df['target'] = preds_list

sample_submission_df = pd.read_csv('/kkh/data/sample_submission.csv')
assert (sample_submission_df['ID'] == pred_df['ID']).all()
# assert (TEST_DF['ID'] == pred_df['ID']).all()

pred_df.to_csv("/kkh/submission/aug.csv", index=False)