In [1]:
import common

import os
import time
import random

import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

from dotenv import load_dotenv
from datetime import datetime
from zoneinfo import ZoneInfo
import wandb

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [3]:
# Checkpoint files of the models that take part in the ensemble.
model_checkpoint_filename_list = ['checkpoint-resnet34_seed_42_epoch_0_isFull_False.pt',
                                  'checkpoint-resnet34_seed_42_epoch_2_isFull_False.pt']

# Size of the classification head. It has to match the checkpoints' state
# dicts, otherwise load_state_dict() below fails on the mismatched layer.
NUM_CLASSES = 17

model_list = []
tst_loader_list = []

for cp_filename in model_checkpoint_filename_list:
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(cp_filename, map_location = device)

    print(checkpoint['model'], checkpoint['tst_img_size'], checkpoint['batch_size'])

    # pretrained=False: every weight is overwritten by load_state_dict() right
    # below, so fetching pretrained weights would only waste time and require
    # network access.
    model = timm.create_model(
        checkpoint['model'],
        pretrained = False,
        num_classes = NUM_CLASSES,
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference only: put layers such as BatchNorm/Dropout into evaluation
    # behaviour (harmless if the prediction helper also calls eval()).
    model.eval()
    model_list.append(model)

    # Each checkpoint records its own test image size and batch size, so every
    # model gets a matching transform, dataset and loader.
    tst_transform = common.create_tst_transform(checkpoint['tst_img_size'])

    tst_dataset = common.ImageDataset(
        "datasets_fin/sample_submission.csv",
        "datasets_fin/test/",
        transform = tst_transform
    )

    tst_loader = DataLoader(
        tst_dataset,
        batch_size = checkpoint['batch_size'],
        shuffle = False,  # keep sample order identical across models for voting
        num_workers = 0,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
        pin_memory = (device.type == 'cuda')
    )

    tst_loader_list.append(tst_loader)

print(len(model_list), len(tst_loader_list))

resnet34 32 32
resnet34 32 32
2 2


In [4]:
# One prediction sequence per ensemble member, each produced with that
# member's own test loader (models and loaders are paired by position).
preds_list = [
    common.get_preds_list_by_tst_loader(member, member_loader, device)
    for member, member_loader in zip(model_list, tst_loader_list)
]

100%|██████████| 99/99 [00:11<00:00,  8.77it/s]
100%|██████████| 99/99 [00:10<00:00,  9.02it/s]


In [6]:
# Hard-voting implementation
def hard_voting(predictions):
    """Combine per-model class predictions by majority vote.

    Args:
        predictions: Sequence of per-model label sequences, i.e. array-like
            whose first axis indexes the models. Labels must be non-negative
            integers; integral values stored as floats are accepted.

    Returns:
        numpy.ndarray holding, for every position along the remaining axes,
        the label that received the most votes. Ties go to the smallest
        label, so with only two voters every disagreement is settled by the
        lower label.

    Raises:
        ValueError: If ``predictions`` is empty or contains non-integral
            labels.
    """
    votes = np.asarray(predictions)
    if votes.size == 0:
        raise ValueError("hard_voting requires at least one non-empty prediction sequence")

    if not np.issubdtype(votes.dtype, np.integer):
        # np.bincount refuses float input, so convert labels that are
        # integral in value and reject anything that would be truncated.
        as_int = votes.astype(np.int64)
        if not np.array_equal(as_int, votes):
            raise ValueError("hard_voting expects integer class labels")
        votes = as_int

    # bincount counts the votes per label; argmax returns the first (smallest)
    # label among those with the highest count.
    return np.apply_along_axis(lambda labels: np.bincount(labels).argmax(), axis=0, arr=votes)

# Final prediction: majority vote over every ensemble member's predictions.
final_pred = hard_voting(preds_list)

# Save to CSV. The loader is taken from the list explicitly instead of relying
# on a loop variable leaked from an earlier cell; the last entry is the object
# that leftover name pointed to. All loaders are built from the same CSV and
# test directory, so presumably any of them would do — verify against the helper.
common.preds_list_to_save_to_csv(final_pred, tst_loader_list[-1], 'pred_ensemble.csv')