In [1]:
import os
import time
import random
import numpy as np
import pandas as pd
from glob import glob
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [3]:
from data import preprocess_df, CustomDataset
from model import BertModel

In [4]:
# Inference settings for the two classification heads. Both use the same
# backbone and input column and differ only in the column they predict.
# NOTE(review): 'select' is not read by any code visible in this notebook.
bias_config = dict(
    model_name='beomi/KcELECTRA-base',
    input_col='comment_title',
    target_col='bias',
    dropout=0.1,
    max_len=128,
    select=False,
)
hate_config = dict(
    model_name='beomi/KcELECTRA-base',
    input_col='comment_title',
    target_col='hate',
    dropout=0.1,
    max_len=128,
    select=False,
)

In [5]:
def stack_inference(config, target_col, load_path, test_path='./test.csv', batch_size=1):
    """Ensemble several checkpoints on the test set by summing their logits.

    Each checkpoint in ``load_path`` is loaded into the same ``BertModel``
    instance, run over the whole test set, and its logits are added to a
    running total. The predicted class is the argmax of that total.

    Args:
        config: Dict with the keys read here: ``'model_name'``, ``'input_col'``,
            ``'target_col'``, ``'dropout'`` and ``'max_len'``.
        target_col: ``'bias'`` (3 classes) or ``'hate'`` (2 classes). Must equal
            ``config['target_col']``.
        load_path: Iterable of checkpoint paths (a single path is also
            accepted). Each file must hold something ``model.load_state_dict``
            accepts.
        test_path: CSV containing at least ``'comment'`` and ``'title'`` columns.
        batch_size: Rows per forward pass. The default of 1 matches the original
            behaviour; larger values assume ``CustomDataset`` yields
            equally-shaped tensors for every row -- TODO confirm before raising.

    Returns:
        The preprocessed test DataFrame with ``target_col`` replaced by the
        predicted string labels. The other label column keeps its placeholder 0.

    Raises:
        ValueError: If ``target_col`` is unsupported or ``load_path`` is empty.
        AssertionError: If ``config['target_col']`` differs from ``target_col``.
        RuntimeError: If a checkpoint does not yield one logit row per test row.
    """
    label_maps = {
        'bias': {'none': 0, 'gender': 1, 'others': 2},
        'hate': {'none': 0, 'hate': 1},
    }

    # Validate everything cheap up front, before reading files or loading weights.
    if target_col not in label_maps:
        raise ValueError(
            f"Unsupported target_col {target_col!r}; expected one of {sorted(label_maps)}"
        )
    assert config['target_col'] == target_col, (
        f"config['target_col']={config['target_col']!r} does not match target_col={target_col!r}"
    )

    # A bare string would otherwise be iterated character by character.
    if isinstance(load_path, (str, os.PathLike)):
        load_path = [load_path]
    load_path = list(load_path)
    if not load_path:
        # With no checkpoints the summed logits stay at zero and argmax would
        # silently label every row as class 0 ('none').
        raise ValueError('load_path is empty: no checkpoints to ensemble')

    inv_label_map = {idx: name for name, idx in label_maps[target_col].items()}
    n_class = len(inv_label_map)

    df = pd.read_csv(test_path)
    # Placeholder labels: the dataset requires a target array even at inference.
    df['bias'] = 0
    df['hate'] = 0
    df['comment_title'] = df['comment'] + ' ' + df['title']
    df['title_comment'] = df['title'] + ' ' + df['comment']

    # Preprocess once instead of once per checkpoint, so every model sees
    # exactly the same text and the logit buffer is sized from the final frame.
    # NOTE(review): assumes preprocess_df is idempotent, as the original
    # per-checkpoint call did implicitly -- confirm.
    df = preprocess_df(df, col=config['input_col'])

    x_data = np.array(df[config['input_col']].tolist())
    y_data = df[target_col].values

    tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
    test_set = CustomDataset(x_data, y_data, tokenizer, config['max_len'])
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    model = BertModel(model_name=config['model_name'], n_class=n_class, p=config['dropout'])
    model.to(device)

    stack_logits = torch.zeros(len(df), n_class)

    for path in load_path:
        print(path)
        # Load lazily and onto the CPU: only one state dict is alive at a time,
        # and checkpoints saved from a GPU still load on a CPU-only machine.
        # SECURITY: torch.load unpickles the file -- only use trusted checkpoints.
        state_dict = torch.load(path, map_location='cpu')
        model.load_state_dict(state_dict)
        del state_dict
        model.eval()

        batch_logits = []
        with torch.no_grad():
            for x, _ in tqdm(test_loader):
                # squeeze(1) presumably drops a singleton axis added per sample
                # by the tokenizer inside CustomDataset -- TODO confirm.
                mask = x['attention_mask'].squeeze(1).to(device)
                input_ids = x['input_ids'].squeeze(1).to(device)
                segment_ids = x['token_type_ids'].squeeze(1).to(device)

                logits = model(input_ids, mask, segment_ids)
                # Flatten to (rows_in_batch, n_class) so any batch size concatenates.
                batch_logits.append(logits.detach().cpu().reshape(-1, n_class))

        if batch_logits:
            res = torch.cat(batch_logits, dim=0)
        else:
            res = torch.zeros(0, n_class)
        if res.shape != stack_logits.shape:
            raise RuntimeError(
                f'{path}: got logits of shape {tuple(res.shape)}, '
                f'expected {tuple(stack_logits.shape)}'
            )
        stack_logits += res

    y_pred = stack_logits.argmax(dim=-1).tolist()
    df[target_col] = [inv_label_map[i] for i in y_pred]
    return df

In [6]:
# Preview which checkpoint files match the '*ct_bias_*' pattern.
# glob returns matches in arbitrary (filesystem-dependent) order.
glob('./saved/KcELECTRA-base/*ct_bias_*')

['./saved/KcELECTRA-base/4_ct_bias_0.770.pt',
 './saved/KcELECTRA-base/0_ct_bias_0.754.pt',
 './saved/KcELECTRA-base/2_ct_bias_0.738.pt',
 './saved/KcELECTRA-base/3_ct_bias_0.762.pt',
 './saved/KcELECTRA-base/1_ct_bias_0.759.pt']

In [None]:
# Gather every checkpoint matching the bias pattern and ensemble them.
bias_load_path = glob('./saved/KcELECTRA-base/*ct_bias_*')
bias_result = stack_inference(
    config=bias_config,
    target_col='bias',
    load_path=bias_load_path,
)

In [None]:
# Gather every checkpoint matching the hate pattern and ensemble them.
hate_load_path = glob('./saved/KcELECTRA-base/*ct_hate_*')
hate_result = stack_inference(
    config=hate_config,
    target_col='hate',
    load_path=hate_load_path,
)

In [9]:
# Fill the submission template with both sets of predicted labels.
# Series assignment aligns on the index, so the template and the prediction
# frames must share the same row index.
submission = pd.read_csv('sample_submission.csv').assign(
    bias=bias_result['bias'],
    hate=hate_result['hate'],
)
submission

Unnamed: 0,ID,bias,hate
0,0,none,none
1,1,none,none
2,2,none,hate
3,3,none,hate
4,4,none,hate
...,...,...,...
506,506,none,hate
507,507,none,none
508,508,others,hate
509,509,others,hate


In [10]:
# Write the final predictions; index=False keeps the DataFrame index out of the file.
# (The path has no placeholders, so the stray f-string prefix is dropped.)
submission.to_csv('./_submission.csv', index=False)