In [1]:
import torch
import random
import numpy as np


# Central run configuration: file locations and hyper-parameters read by the cells below.
config = dict(
    train_file_path='dataset/train.csv',  # CSV that is split into train / validation rows
    test_file_path='dataset/test.csv',    # CSV that is scored by predict()
    train_val_ratio=0.1,                  # fraction of the training rows held out for validation
    model_path='dataset/BERT_model',      # directory passed to the *.from_pretrained() loaders
    batch_size=16,
    num_epochs=2,
    learning_rate=2e-5,
    logging_step=500,                     # run a validation pass every this many training steps
    seed=2021,
)
# Use the GPU when torch reports one, otherwise fall back to the CPU.
config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) generators.

    Args:
        seed: Integer seed handed unchanged to every generator.

    Returns:
        The same ``seed`` value, so the caller can echo or log it.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    return seed

# Seed every RNG once, up front; the call returns the seed, which the notebook echoes as this cell's output.
seed_everything(config['seed'])

2021

In [2]:
import pandas as pd
from tqdm import tqdm
from collections import defaultdict


def _append_encoding(features, encoding):
    """Append one tokenizer encoding to the per-field lists held in ``features``."""
    # NOTE: the store key is 'inputs_ids' (sic); TNEWSData and collate_fn read that exact key.
    features['inputs_ids'].append(encoding['input_ids'])
    features['token_type_ids'].append(encoding['token_type_ids'])
    features['attention_mask'].append(encoding['attention_mask'])


def read_data(config, tokenizer, mode='train'):
    """Read a CSV file, tokenize every sentence and return un-padded features.

    The label is taken from the second column of each row (train mode only) and
    the sentence from the last column. The tokenizer is called without any
    truncation arguments, so sequences keep their full tokenized length.

    Args:
        config: Mapping that provides ``'{mode}_file_path'`` and, in train mode,
            ``'train_val_ratio'``.
        tokenizer: Object exposing ``encode_plus(sentence, ...)`` that returns a
            mapping with ``'input_ids'``, ``'token_type_ids'`` and
            ``'attention_mask'``.
        mode: ``'train'`` to produce a train/validation split, anything else
            (e.g. ``'test'``) to produce a single unlabeled set.

    Returns:
        Train mode: ``(X_train, y_train, X_val, y_val, label2id, id2label)``.
        The first ``int(len(df) * train_val_ratio)`` rows form the validation
        set. ``y_*`` are ``torch.long`` tensors of label ids.
        Other modes: ``(X_test, y_test)`` where ``y_test`` is an all-zero
        ``torch.long`` placeholder tensor.
    """
    is_train = mode == 'train'
    data_df = pd.read_csv(config[f'{mode}_file_path'], sep=',')
    # The ratio is only needed (and only looked up) when a validation split is made.
    num_val = int(len(data_df) * config['train_val_ratio']) if is_train else 0

    X_train, y_train = defaultdict(list), []
    X_val, y_val = defaultdict(list), []
    X_test, y_test = defaultdict(list), []

    for i, row in tqdm(data_df.iterrows(), desc=f'preprocess {mode} data', total=len(data_df)):
        # Explicit positional access: integer keys on a string-labelled Series
        # (row[1], row[-1]) rely on a deprecated pandas fallback.
        label = row.iloc[1] if is_train else 0
        sentence = row.iloc[-1]
        # add_special_tokens inserts [CLS]/[SEP]; segment ids and the attention
        # mask are requested explicitly because collate_fn pads all three fields.
        inputs = tokenizer.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True, return_attention_mask=True)

        if not is_train:
            features, labels = X_test, y_test
        elif i < num_val:
            features, labels = X_val, y_val
        else:
            features, labels = X_train, y_train

        _append_encoding(features, inputs)
        labels.append(label)

    if not is_train:
        return X_test, torch.tensor(y_test, dtype=torch.long)

    # Build the mapping from BOTH splits: a label that only occurs in the
    # validation slice would otherwise raise KeyError below.
    label2id = {label: idx for idx, label in enumerate(np.unique(y_train + y_val))}
    id2label = {idx: label for label, idx in label2id.items()}
    y_train = torch.tensor([label2id[label] for label in y_train], dtype=torch.long)
    y_val = torch.tensor([label2id[label] for label in y_val], dtype=torch.long)

    return X_train, y_train, X_val, y_val, label2id, id2label

In [3]:
from torch.utils.data import Dataset

class TNEWSData(Dataset):
    """Dataset view over tokenized features and their label tensor.

    ``X`` is a mapping with the keys ``'inputs_ids'``, ``'token_type_ids'`` and
    ``'attention_mask'``, each indexable by sample position. ``y`` must expose
    ``.size(0)`` (e.g. a torch tensor) and be indexable by sample position.
    """

    def __init__(self, X, y):
        """Store references to the feature mapping and the labels (no copy is made)."""
        self.x = X
        self.y = y

    def __len__(self):
        """Number of samples, taken from the first dimension of the labels."""
        return self.y.size(0)

    def __getitem__(self, idx):
        """Return the un-padded fields of sample ``idx`` as a dict."""
        features, targets = self.x, self.y
        sample = dict(inputs_ids=features['inputs_ids'][idx], label=targets[idx])
        sample['token_type_ids'] = features['token_type_ids'][idx]
        sample['attention_mask'] = features['attention_mask'][idx]
        return sample

In [4]:
def collate_fn(examples):
    """Pad a list of samples to a common length and stack them into tensors.

    Args:
        examples: Sequence of dicts with the keys ``'inputs_ids'``, ``'label'``,
            ``'token_type_ids'`` and ``'attention_mask'`` (as produced by
            ``TNEWSData.__getitem__``).

    Returns:
        Dict of ``torch.long`` tensors: ``'input_ids'``, ``'token_type_ids'``
        and ``'attention_mask'`` of shape ``(batch, longest_sequence)``,
        right-padded with zeros, plus ``'labels'`` of shape ``(batch,)``.
    """
    columns = {'inputs_ids': [], 'label': [], 'token_type_ids': [], 'attention_mask': []}
    for sample in examples:
        for key, bucket in columns.items():
            bucket.append(sample[key])

    batch_size = len(columns['label'])
    widest = max(map(len, columns['inputs_ids']))

    # All three feature tensors start as zeros, so any untouched tail is padding.
    padded_ids = torch.zeros((batch_size, widest), dtype=torch.long)
    padded_types = torch.zeros_like(padded_ids)
    padded_mask = torch.zeros_like(padded_ids)

    per_sample = zip(columns['inputs_ids'], columns['token_type_ids'], columns['attention_mask'])
    for row, (ids, types, mask) in enumerate(per_sample):
        width = len(ids)
        padded_ids[row, :width] = torch.tensor(ids, dtype=torch.long)
        padded_types[row, :width] = torch.tensor(types, dtype=torch.long)
        padded_mask[row, :width] = torch.tensor(mask, dtype=torch.long)

    return {
        'input_ids': padded_ids,
        'labels': torch.tensor(columns['label'], dtype=torch.long),
        'token_type_ids': padded_types,
        'attention_mask': padded_mask,
    }

In [5]:
from transformers import BertTokenizer
from torch.utils.data import DataLoader

def build_dataloader(config):
    """Tokenize the train and test CSV files and wrap each split in a DataLoader.

    Args:
        config: Mapping providing ``'model_path'`` (tokenizer location),
            ``'batch_size'`` and the keys consumed by ``read_data``. An optional
            ``'num_workers'`` entry overrides the number of loader worker
            processes (default ``4``, the previously hard-coded value); set it
            to ``0`` to load batches in the main process.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader, id2label)``.
        Only the training loader shuffles.
    """
    tokenizer = BertTokenizer.from_pretrained(config['model_path'])
    X_train, y_train, X_val, y_val, label2id, id2label = read_data(config, tokenizer, mode='train')
    X_test, y_test = read_data(config, tokenizer, mode='test')

    num_workers = config.get('num_workers', 4)

    def make_loader(features, targets, shuffle):
        # One place for the settings every split shares.
        return DataLoader(
            TNEWSData(features, targets),
            batch_size=config['batch_size'],
            num_workers=num_workers,
            shuffle=shuffle,
            collate_fn=collate_fn,
        )

    train_dataloader = make_loader(X_train, y_train, shuffle=True)
    val_dataloader = make_loader(X_val, y_val, shuffle=False)
    test_dataloader = make_loader(X_test, y_test, shuffle=False)

    return train_dataloader, val_dataloader, test_dataloader, id2label

In [6]:
# Tokenize both CSV files and build the train / validation / test loaders plus the id -> label map.
train_dataloader, val_dataloader, test_dataloader, id2label = build_dataloader(config)

preprocess train data: 100%|███████████████████████████| 53360/53360 [00:21<00:00, 2441.42it/s]
preprocess test data: 100%|████████████████████████████| 10000/10000 [00:04<00:00, 2362.04it/s]


In [7]:
# Sanity check: print only the first collated training batch, then stop iterating.
for batch in train_dataloader:
    print(batch)
    break

{'input_ids': tensor([[  101,  2682,  1762,  6948,  2336,  1453,  6804,   743,  5018,   753,
          1947,  2791,  2094,  8024,   126,  1283,  2340,  1381,  4638,  8024,
          3300,   784,   720,  1962,  4638,  2972,  5773,  8043,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  8108,  2207,  3198,   704,  1744,  7674,  2168,   131,  2199,
          8208,   674,  2207,   868,  1773,   976,  1168,  2399,  1057, 10194,
          8157,   783,   117,  4706,  4518,  7790,  1213,  1772,  6631,  3330,
          1649,  6411,   102,     0,     0,     0,     0,     0],
        [  101,  3173,  3528,  1059,  1744,  6121,  1266,   776,  4991,   100,
          2207,  4923,  2415,  3173,  6629,  4157,  2199,   715,  1215,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  5303,   754,  5023,  1168,   872,  8013,  1266,   677,  2408

In [8]:
from sklearn.metrics import f1_score

def evaluation(config, model, val_dataloader):
    """Run one full pass over ``val_dataloader`` and score the predictions.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        config: Mapping providing ``'device'``.
        model: Callable taking the batch as keyword arguments and returning a
            sequence whose first two items are ``(loss, logits)``.
        val_dataloader: Sized iterable of dicts of tensors that include
            ``'labels'``.

    Returns:
        ``(avg_val_loss, f1)``: the mean per-batch loss and the macro-averaged
        F1 score over all samples.
    """
    model.eval()
    gold_chunks, pred_chunks = [], []
    running_loss = 0.
    progress = tqdm(val_dataloader, desc='Evaluation', total=len(val_dataloader))

    with torch.no_grad():
        for batch in progress:
            # Keep the gold labels as delivered by the loader, before the device transfer.
            gold_chunks.append(batch['labels'])
            on_device = {name: tensor.to(config['device']) for name, tensor in batch.items()}
            loss, logits = model(**on_device)[:2]

            running_loss += loss.item()
            pred_chunks.append(logits.argmax(dim=-1).detach().cpu())

    mean_loss = running_loss / len(val_dataloader)
    y_true = torch.cat(gold_chunks, dim=0).numpy()
    y_pred = torch.cat(pred_chunks, dim=0).numpy()
    return mean_loss, f1_score(y_true, y_pred, average='macro')

In [9]:
from transformers import BertConfig, BertForSequenceClassification
from tqdm import trange

def train(config, id2label, train_dataloader, val_dataloader):
    """Fine-tune a BERT sequence classifier and periodically report validation metrics.

    Args:
        config: Mapping providing ``'model_path'``, ``'learning_rate'``,
            ``'device'``, ``'num_epochs'`` and ``'logging_step'``.
        id2label: Mapping from class id to label; only its length is used here,
            to size the classification head.
        train_dataloader: Sized iterable of batches (dicts of tensors) that are
            passed to the model as keyword arguments.
        val_dataloader: Batches handed to ``evaluation`` at every logging step.

    Returns:
        The fine-tuned model, left in training mode on ``config['device']``.
    """
    bert_config = BertConfig.from_pretrained(config['model_path'])
    bert_config.num_labels = len(id2label)
    model = BertForSequenceClassification.from_pretrained(config['model_path'], config=bert_config)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'])

    model.to(config['device'])
    epoch_iterator = trange(config['num_epochs'])
    global_steps = 0
    train_loss = 0.    # running sum of every batch loss seen so far
    logging_loss = 0.  # value of train_loss at the previous logging step

    for _ in epoch_iterator:
        train_iterator = tqdm(train_dataloader, desc='Training', total=len(train_dataloader))
        model.train()
        # Iterate the tqdm wrapper, not the raw loader: otherwise the bar is
        # created but never advances and stays at 0% for the whole epoch.
        for batch in train_iterator:
            batch = {item: value.to(config['device']) for item, value in batch.items()}
            loss = model(**batch)[0]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            global_steps += 1

            if global_steps % config['logging_step'] == 0:
                # Mean training loss over the window since the last report.
                print_train_loss = (train_loss - logging_loss) / config['logging_step']
                logging_loss = train_loss
                avg_val_loss, f1 = evaluation(config, model, val_dataloader)
                print_log = f'>>> training loss: {print_train_loss:.4f}, valid_loss: {avg_val_loss:.4f}, ' \
                            f'valid f1 score: {f1:.4f}'
                print(print_log)
                # evaluation() leaves the model in eval mode; switch back before the next step.
                model.train()

    return model

In [10]:
# Fine-tune the classifier; train() prints training/validation metrics every config['logging_step'] steps.
model = train(config, id2label, train_dataloader, val_dataloader)

Some weights of the model checkpoint at dataset/BERT_model were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint a

Evaluation:  82%|████████████████████████████████████▏       | 275/334 [00:07<00:01, 37.68it/s][A[A

Evaluation:  84%|████████████████████████████████████▉       | 280/334 [00:07<00:01, 38.84it/s][A[A

Evaluation:  85%|█████████████████████████████████████▍      | 284/334 [00:07<00:01, 37.80it/s][A[A

Evaluation:  87%|██████████████████████████████████████      | 289/334 [00:08<00:01, 39.52it/s][A[A

Evaluation:  88%|██████████████████████████████████████▋     | 294/334 [00:08<00:00, 40.47it/s][A[A

Evaluation:  90%|███████████████████████████████████████▍    | 299/334 [00:08<00:00, 40.98it/s][A[A

Evaluation:  91%|████████████████████████████████████████    | 304/334 [00:08<00:00, 40.50it/s][A[A

Evaluation:  93%|████████████████████████████████████████▋   | 309/334 [00:08<00:00, 40.53it/s][A[A

Evaluation:  94%|█████████████████████████████████████████▎  | 314/334 [00:08<00:00, 39.67it/s][A[A

Evaluation:  96%|██████████████████████████████████████████  | 319/334 [0

>>> training loss: 1.6371, valid_loss: 1.3800, valid f1 score: 0.5081




Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A[A

Evaluation:   0%|▏                                             | 1/334 [00:00<01:23,  3.98it/s][A[A

Evaluation:   2%|▊                                             | 6/334 [00:00<00:16, 19.90it/s][A[A

Evaluation:   3%|█▎                                           | 10/334 [00:00<00:12, 26.38it/s][A[A

Evaluation:   4%|██                                           | 15/334 [00:00<00:10, 31.83it/s][A[A

Evaluation:   6%|██▋                                          | 20/334 [00:00<00:08, 35.31it/s][A[A

Evaluation:   7%|███▏                                         | 24/334 [00:00<00:08, 35.60it/s][A[A

Evaluation:   8%|███▊                                         | 28/334 [00:00<00:08, 36.27it/s][A[A

Evaluation:  10%|████▎                                        | 32/334 [00:01<00:08, 35.98it/s][A[A

Evaluation:  11%|████▊                                        | 36/334 

>>> training loss: 1.3593, valid_loss: 1.3190, valid f1 score: 0.5140




Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A[A

Evaluation:   0%|▏                                             | 1/334 [00:00<01:32,  3.58it/s][A[A

Evaluation:   1%|▋                                             | 5/334 [00:00<00:21, 14.98it/s][A[A

Evaluation:   3%|█▏                                            | 9/334 [00:00<00:14, 21.98it/s][A[A

Evaluation:   4%|█▊                                           | 13/334 [00:00<00:12, 26.00it/s][A[A

Evaluation:   5%|██▎                                          | 17/334 [00:00<00:10, 28.88it/s][A[A

Evaluation:   6%|██▊                                          | 21/334 [00:00<00:10, 30.62it/s][A[A

Evaluation:   7%|███▎                                         | 25/334 [00:00<00:09, 32.02it/s][A[A

Evaluation:   9%|███▉                                         | 29/334 [00:01<00:09, 32.14it/s][A[A

Evaluation:  10%|████▍                                        | 33/334 

Evaluation:  96%|██████████████████████████████████████████  | 319/334 [00:09<00:00, 32.02it/s][A[A

Evaluation:  97%|██████████████████████████████████████████▌ | 323/334 [00:09<00:00, 32.70it/s][A[A

Evaluation:  98%|███████████████████████████████████████████ | 327/334 [00:09<00:00, 33.50it/s][A[A

Evaluation: 100%|████████████████████████████████████████████| 334/334 [00:09<00:00, 34.37it/s][A[A


>>> training loss: 1.3054, valid_loss: 1.2878, valid f1 score: 0.5250




Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A[A

Evaluation:   0%|▏                                             | 1/334 [00:00<01:28,  3.76it/s][A[A

Evaluation:   1%|▋                                             | 5/334 [00:00<00:20, 15.86it/s][A[A

Evaluation:   3%|█▏                                            | 9/334 [00:00<00:14, 22.97it/s][A[A

Evaluation:   4%|█▊                                           | 13/334 [00:00<00:11, 27.44it/s][A[A

Evaluation:   5%|██▎                                          | 17/334 [00:00<00:10, 31.22it/s][A[A

Evaluation:   7%|██▉                                          | 22/334 [00:00<00:08, 34.77it/s][A[A

Evaluation:   8%|███▋                                         | 27/334 [00:00<00:08, 36.89it/s][A[A

Evaluation:   9%|████▏                                        | 31/334 [00:01<00:08, 37.43it/s][A[A

Evaluation:  10%|████▋                                        | 35/334 

>>> training loss: 1.2782, valid_loss: 1.2546, valid f1 score: 0.5487




Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A[A

Evaluation:   0%|▏                                             | 1/334 [00:00<01:29,  3.71it/s][A[A

Evaluation:   1%|▋                                             | 5/334 [00:00<00:21, 15.41it/s][A[A

Evaluation:   3%|█▏                                            | 9/334 [00:00<00:14, 22.95it/s][A[A

Evaluation:   4%|█▊                                           | 13/334 [00:00<00:11, 27.80it/s][A[A

Evaluation:   5%|██▎                                          | 17/334 [00:00<00:10, 30.58it/s][A[A

Evaluation:   6%|██▊                                          | 21/334 [00:00<00:09, 31.78it/s][A[A

Evaluation:   7%|███▎                                         | 25/334 [00:00<00:09, 32.99it/s][A[A

Evaluation:   9%|███▉                                         | 29/334 [00:01<00:09, 32.99it/s][A[A

Evaluation:  10%|████▍                                        | 33/334 

>>> training loss: 1.2716, valid_loss: 1.2468, valid f1 score: 0.5312




Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A[A

Evaluation:   0%|▏                                             | 1/334 [00:00<01:20,  4.11it/s][A[A

Evaluation:   2%|▊                                             | 6/334 [00:00<00:16, 20.27it/s][A[A

Evaluation:   3%|█▎                                           | 10/334 [00:00<00:12, 26.13it/s][A[A

Evaluation:   4%|█▉                                           | 14/334 [00:00<00:10, 29.66it/s][A[A

Evaluation:   5%|██▍                                          | 18/334 [00:00<00:09, 32.81it/s][A[A

Evaluation:   7%|██▉                                          | 22/334 [00:00<00:09, 34.24it/s][A[A

Evaluation:   8%|███▌                                         | 26/334 [00:00<00:08, 35.62it/s][A[A

Evaluation:   9%|████                                         | 30/334 [00:00<00:08, 36.06it/s][A[A

Evaluation:  10%|████▋                                        | 35/334 

>>> training loss: 1.2501, valid_loss: 1.2249, valid f1 score: 0.5282


 50%|█████████████████████████████▌                             | 1/2 [06:38<06:38, 398.77s/it]

Training:   0%|                                                       | 0/3002 [06:38<?, ?it/s][A[A

Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A
Evaluation:   0%|▏                                             | 1/334 [00:00<01:08,  4.83it/s][A
Evaluation:   1%|▋                                             | 5/334 [00:00<00:18, 18.22it/s][A
Evaluation:   3%|█▍                                           | 11/334 [00:00<00:10, 31.60it/s][A
Evaluation:   4%|██                                           | 15/334 [00:00<00:09, 34.14it/s][A
Evaluation:   6%|██▌                                          | 19/334 [00:00<00:09, 34.11it/s][A
Evaluation:   7%|███                                          | 23/334 [00:00<00:09, 34.16it/s][A
Evaluation:   8%|███▋                                         | 27/334 [00:00<00:08, 35.49it/s][A
Evaluati

Evaluation: 100%|████████████████████████████████████████████| 334/334 [00:09<00:00, 35.05it/s][A


>>> training loss: 1.0153, valid_loss: 1.2458, valid f1 score: 0.5257



Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A
Evaluation:   0%|▏                                             | 1/334 [00:00<01:09,  4.78it/s][A
Evaluation:   2%|▊                                             | 6/334 [00:00<00:14, 22.20it/s][A
Evaluation:   4%|█▌                                           | 12/334 [00:00<00:09, 33.87it/s][A
Evaluation:   5%|██▍                                          | 18/334 [00:00<00:07, 40.54it/s][A
Evaluation:   7%|███                                          | 23/334 [00:00<00:07, 43.17it/s][A
Evaluation:   8%|███▊                                         | 28/334 [00:00<00:06, 44.94it/s][A
Evaluation:  10%|████▍                                        | 33/334 [00:00<00:06, 45.10it/s][A
Evaluation:  11%|█████                                        | 38/334 [00:00<00:06, 45.95it/s][A
Evaluation:  13%|█████▊                                       | 43/334 [00:01<00:06, 46.38it/s][A
Evaluatio

>>> training loss: 1.0388, valid_loss: 1.2395, valid f1 score: 0.5412



Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A
Evaluation:   0%|▏                                             | 1/334 [00:00<01:21,  4.08it/s][A
Evaluation:   1%|▋                                             | 5/334 [00:00<00:19, 16.72it/s][A
Evaluation:   3%|█▏                                            | 9/334 [00:00<00:13, 23.32it/s][A
Evaluation:   4%|█▊                                           | 13/334 [00:00<00:11, 27.10it/s][A
Evaluation:   5%|██▎                                          | 17/334 [00:00<00:10, 29.28it/s][A
Evaluation:   6%|██▊                                          | 21/334 [00:00<00:10, 30.28it/s][A
Evaluation:   7%|███▎                                         | 25/334 [00:00<00:09, 31.70it/s][A
Evaluation:   9%|███▉                                         | 29/334 [00:01<00:09, 33.13it/s][A
Evaluation:  10%|████▍                                        | 33/334 [00:01<00:08, 33.73it/s][A
Evaluatio

>>> training loss: 1.0356, valid_loss: 1.2401, valid f1 score: 0.5250



Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A
Evaluation:   0%|▏                                             | 1/334 [00:00<01:20,  4.14it/s][A
Evaluation:   2%|▊                                             | 6/334 [00:00<00:16, 19.88it/s][A
Evaluation:   3%|█▎                                           | 10/334 [00:00<00:12, 25.02it/s][A
Evaluation:   4%|█▉                                           | 14/334 [00:00<00:10, 29.39it/s][A
Evaluation:   5%|██▍                                          | 18/334 [00:00<00:09, 31.82it/s][A
Evaluation:   7%|██▉                                          | 22/334 [00:00<00:09, 33.85it/s][A
Evaluation:   8%|███▋                                         | 27/334 [00:00<00:08, 35.80it/s][A
Evaluation:   9%|████▏                                        | 31/334 [00:01<00:08, 35.14it/s][A
Evaluation:  10%|████▋                                        | 35/334 [00:01<00:08, 35.47it/s][A
Evaluatio

>>> training loss: 1.0190, valid_loss: 1.2509, valid f1 score: 0.5510



Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A
Evaluation:   0%|▏                                             | 1/334 [00:00<01:08,  4.89it/s][A
Evaluation:   1%|▋                                             | 5/334 [00:00<00:17, 18.38it/s][A
Evaluation:   3%|█▏                                            | 9/334 [00:00<00:12, 25.85it/s][A
Evaluation:   4%|█▊                                           | 13/334 [00:00<00:10, 29.51it/s][A
Evaluation:   5%|██▎                                          | 17/334 [00:00<00:10, 30.44it/s][A
Evaluation:   7%|██▉                                          | 22/334 [00:00<00:09, 34.02it/s][A
Evaluation:   8%|███▌                                         | 26/334 [00:00<00:09, 33.12it/s][A
Evaluation:   9%|████                                         | 30/334 [00:01<00:09, 33.29it/s][A
Evaluation:  10%|████▌                                        | 34/334 [00:01<00:08, 33.91it/s][A
Evaluatio

Evaluation: 100%|████████████████████████████████████████████| 334/334 [00:09<00:00, 33.83it/s][A


>>> training loss: 1.0558, valid_loss: 1.2284, valid f1 score: 0.5487



Evaluation:   0%|                                                      | 0/334 [00:00<?, ?it/s][A
Evaluation:   0%|▏                                             | 1/334 [00:00<01:22,  4.04it/s][A
Evaluation:   1%|▋                                             | 5/334 [00:00<00:19, 16.95it/s][A
Evaluation:   3%|█▎                                           | 10/334 [00:00<00:11, 27.03it/s][A
Evaluation:   4%|█▉                                           | 14/334 [00:00<00:10, 30.90it/s][A
Evaluation:   5%|██▍                                          | 18/334 [00:00<00:09, 32.76it/s][A
Evaluation:   7%|██▉                                          | 22/334 [00:00<00:09, 33.68it/s][A
Evaluation:   8%|███▌                                         | 26/334 [00:00<00:08, 34.88it/s][A
Evaluation:   9%|████                                         | 30/334 [00:01<00:08, 34.02it/s][A
Evaluation:  10%|████▌                                        | 34/334 [00:01<00:08, 34.26it/s][A
Evaluatio

>>> training loss: 1.0513, valid_loss: 1.1985, valid f1 score: 0.5474


100%|███████████████████████████████████████████████████████████| 2/2 [12:58<00:00, 389.39s/it]
Training:   0%|                                                       | 0/3002 [06:20<?, ?it/s]


In [11]:
def predict(config, id2label, model, test_dataloader):
    """Predict a label for every test sample and write ``submission.csv``.

    Args:
        config: Mapping providing ``'device'`` and ``'test_file_path'``.
        id2label: Mapping from predicted class id back to the original label.
        model: Callable taking the batch as keyword arguments; element ``[1]``
            of its output is used as the logits.
        test_dataloader: Sized iterable of dicts of tensors, in the same row
            order as the test CSV.

    Returns:
        None. The test CSV is re-read, a ``'label'`` column is inserted at
        position 1, the ``'sentence'`` column is dropped, and the result is
        written to ``submission.csv`` in the current working directory.
    """
    progress = tqdm(test_dataloader, desc='Predicting', total=len(test_dataloader))
    model.eval()
    batch_preds = []
    with torch.no_grad():
        for batch in progress:
            on_device = {name: tensor.to(config['device']) for name, tensor in batch.items()}
            logits = model(**on_device)[1]
            batch_preds.append(logits.argmax(dim=-1).detach().cpu())

    pred_ids = torch.cat(batch_preds, dim=0).numpy()
    pred_labels = [id2label[class_id] for class_id in pred_ids]

    submission = pd.read_csv(config['test_file_path'], sep=',')
    submission.insert(1, column='label', value=pred_labels)
    submission = submission.drop(columns=['sentence'])
    submission.to_csv('submission.csv', index=False, encoding='utf8')

In [12]:
# Run inference on the test loader; predict() writes the predicted labels to submission.csv.
predict(config, id2label, model, test_dataloader)

Predicting: 100%|████████████████████████████████████████████| 625/625 [00:17<00:00, 36.52it/s]
