In [None]:
import common

import os
import time
import random

import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

from dotenv import load_dotenv
from datetime import datetime
from zoneinfo import ZoneInfo
import wandb

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# Ensemble settings read by the inference loop further down.
augment_ratio = 1       # forwarded to common.ImageDataset(augment_ratio=...)
is_soft_voting = True   # True -> soft_voting(), False -> hard_voting()

In [None]:
def get_all_targets_count(csv_path="datasets_fin/sample_submission.csv"):
    """Return the number of data rows in a submission CSV.

    Args:
        csv_path: Path of the CSV file to count. Defaults to the
            sample-submission file used by this notebook, so existing
            zero-argument calls behave exactly as before.

    Returns:
        int: Number of rows parsed by ``pandas.read_csv`` (the header line
        is not counted).
    """
    return len(pd.read_csv(csv_path))
    
def hard_voting(predictions):
    """Pick the most frequent label along axis 0 of integer predictions.

    Args:
        predictions: Array-like of non-negative integer labels; axis 0
            enumerates the voters (one entry per model).

    Returns:
        numpy.ndarray: The winning label for every position along the
        remaining axes. Ties resolve to the smallest label, because
        ``argmax`` returns the first maximum of the bincount histogram.
    """
    votes = np.asarray(predictions)

    def _most_common_label(labels):
        # Histogram the labels, then take the index of the tallest bin.
        return np.bincount(labels).argmax()

    return np.apply_along_axis(_most_common_label, 0, votes)

def soft_voting(predictions, all_targets_count=None):
    """Average class scores over models and augmented copies, then take argmax.

    Args:
        predictions: Array-like stacked along axis 0 with one entry per
            model; each entry holds one row of class scores per predicted
            sample (rows x classes).
        all_targets_count: Number of original (non-augmented) test samples.
            When ``None`` (the default) it is obtained from
            ``get_all_targets_count()``, matching the previous behaviour.
            Passing it explicitly avoids re-reading the CSV on every call.

    Returns:
        numpy.ndarray: Predicted class index for each averaged row.

    Raises:
        AssertionError: If the number of prediction rows is not a whole
            multiple of ``all_targets_count``.
    """
    # Average the per-model scores first.
    mean_scores = np.mean(np.asarray(predictions), axis=0)

    # Also take the predictions made on augmented data into account.
    if all_targets_count is None:
        all_targets_count = get_all_targets_count()

    aug_size, remainder = divmod(len(mean_scores), all_targets_count)
    assert remainder == 0, (
        f"{len(mean_scores)} prediction rows is not a multiple of "
        f"{all_targets_count} targets"
    )

    if aug_size > 1:
        # Rows are treated as `aug_size` consecutive chunks, each covering all
        # targets in the same order; average the chunks element-wise.
        chunked_shape = (aug_size, all_targets_count) + mean_scores.shape[1:]
        mean_scores = mean_scores.reshape(chunked_shape).mean(axis=0)

    return mean_scores.argmax(axis=1)

In [None]:
# Checkpoint files for fold 1 (filenames end in `fold_1_folds_2`). The ensemble
# loop below pairs every entry here with every entry of fold_2_file_list.
# NOTE(review): the `epc` / `vl` / `va` / `vf1` filename fields presumably encode
# epoch, validation loss, accuracy and F1, and `_PICK` presumably marks a
# hand-selected checkpoint — confirm against the training notebook.
fold_1_file_list = [
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_1_aug_200_vl_0.3414_va_0.9158_vf1_0.9147_fold_1_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_2_aug_200_vl_0.3114_va_0.9193_vf1_0.9192_fold_1_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_3_aug_200_vl_0.4536_va_0.9146_vf1_0.9121_fold_1_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_4_aug_200_vl_0.4589_va_0.9204_vf1_0.9195_fold_1_folds_2_PICK.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_5_aug_200_vl_0.6567_va_0.8950_vf1_0.8920_fold_1_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_6_aug_200_vl_0.6449_va_0.8916_vf1_0.8892_fold_1_folds_2.pt',
]

# Checkpoint files for fold 2 (filenames end in `fold_2_folds_2`). Each entry is
# combined with every fold-1 checkpoint in the ensemble loop below.
# NOTE(review): filename fields (`epc`, `vl`, `va`, `vf1`, `_PICK`) are assumed to
# follow the same convention as the fold-1 list — confirm against the training
# notebook.
fold_2_file_list = [
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_1_aug_200_vl_0.3597_va_0.8939_vf1_0.8987_fold_2_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_2_aug_200_vl_0.4646_va_0.8754_vf1_0.8716_fold_2_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_3_aug_200_vl_0.5112_va_0.8939_vf1_0.8987_fold_2_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_4_aug_200_vl_0.4933_va_0.9262_vf1_0.9213_fold_2_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_5_aug_200_vl_0.6814_va_0.9077_vf1_0.9041_fold_2_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_6_aug_200_vl_0.8049_va_0.9031_vf1_0.9027_fold_2_folds_2_PICK.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_7_aug_200_vl_0.4791_va_0.9135_vf1_0.9131_fold_2_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_8_aug_200_vl_0.6528_va_0.9008_vf1_0.8995_fold_2_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_9_aug_200_vl_0.8085_va_0.9031_vf1_0.9043_fold_2_folds_2.pt',
    'model_bak3/cp-densenet121.ra_in1k_sd_42_epc_10_aug_200_vl_0.5794_va_0.9239_vf1_0.9237_fold_2_folds_2.pt',
]

In [None]:
def _predict_with_checkpoint(cp_filename):
    """Load one checkpoint, run it over the test set and return its predictions.

    Args:
        cp_filename: Path of a checkpoint file holding the keys 'model',
            'tst_img_size', 'batch_size' and 'model_state_dict'.

    Returns:
        tuple: ``(preds, tst_loader)`` where ``preds`` is whatever
        ``common.get_preds_list_by_tst_loader`` returns for this model and
        ``tst_loader`` is the DataLoader the predictions were computed on.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(cp_filename, map_location=device)
    print(checkpoint['model'], checkpoint['tst_img_size'], checkpoint['batch_size'])

    # pretrained=False: load_state_dict (strict by default) overwrites every
    # weight right below, so fetching ImageNet weights first is wasted work.
    model = timm.create_model(
        checkpoint['model'],
        pretrained=False,
        num_classes=17,
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference mode for dropout / batch-norm; a no-op if the prediction
    # helper already switches the model to eval.
    model.eval()

    tst_transform = common.create_tst_transform(checkpoint['tst_img_size'])
    tst_aug_transform = common.create_trn_aug_transform(checkpoint['tst_img_size'])

    tst_dataset = common.ImageDataset(
        "datasets_fin/sample_submission.csv",
        "datasets_fin/test/",
        transform=tst_transform,
        aug_transform=tst_aug_transform,
        augment_ratio=augment_ratio,
    )

    tst_loader = DataLoader(
        tst_dataset,
        batch_size=checkpoint['batch_size'],
        shuffle=False,
        num_workers=12,
        pin_memory=True,
    )

    preds = common.get_preds_list_by_tst_loader(model, tst_loader, device, is_soft_voting)
    return preds, tst_loader


# Every checkpoint takes part in many (fold 1, fold 2) pairs; predict with each
# one only once and reuse the result instead of reloading it for every pair.
# NOTE(review): if the augmentation transform is random, pairs sharing a
# checkpoint now also share its single prediction run.
checkpoint_predictions = {}

for fold_1_idx, fold_1_filename in enumerate(fold_1_file_list):
    for fold_2_idx, fold_2_filename in enumerate(fold_2_file_list):
        print(f'fold 1 filename: {fold_1_filename}')
        print(f'fold 2 filename: {fold_2_filename}')

        preds_list = []

        for cp_filename in (fold_1_filename, fold_2_filename):
            if cp_filename not in checkpoint_predictions:
                checkpoint_predictions[cp_filename] = _predict_with_checkpoint(cp_filename)

            # After this loop `tst_loader` is the fold-2 loader, which is the
            # one handed to the CSV writer below.
            preds, tst_loader = checkpoint_predictions[cp_filename]
            preds_list.append(preds)

        # Final prediction
        if is_soft_voting:
            final_pred = soft_voting(preds_list)
        else:
            final_pred = hard_voting(preds_list)

        # Save to CSV
        csv_filename = f'pred_0808_hoho_{fold_1_idx}_{fold_2_idx}.csv'
        common.preds_list_to_save_to_csv(final_pred, tst_loader, csv_filename)
        print(f'prediction save to {csv_filename}.')
        print()