# Import requirements

In [None]:
!pip install wandb



In [None]:
!pip install transformers



In [None]:
import os
import pdb
import wandb
import argparse
from dataclasses import dataclass, field
from typing import Optional
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence

import numpy as np
from tqdm import tqdm, trange

from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    AutoConfig,
    AdamW,
    get_linear_schedule_with_warmup
)

In [None]:
class EarlyStopping:
    """Stop training when the validation loss has not improved for ``patience`` calls.

    Each call receives the latest validation loss and the model. When the loss
    improves, the model's ``state_dict`` is saved to ``path`` and the counter is
    reset; otherwise the counter is incremented and ``early_stop`` is set to
    True once it reaches ``patience``.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): number of consecutive non-improving calls tolerated
                            after the last improvement.
                            Default: 7
            verbose (bool): if True, print a message every time the validation
                            loss improves and a checkpoint is written.
                            Default: False
            delta (float): minimum decrease of the validation loss that counts
                           as an improvement (a decrease of exactly ``delta``
                           counts; with 0, an equal loss counts).
                            Default: 0
            path (str): file the best model's state_dict is saved to.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record ``val_loss``; checkpoint ``model`` on improvement, else count toward stopping."""
        # Negate the loss so that a higher score means a better model.
        score = -val_loss

        if self.best_score is None:
            # First observation is always treated as an improvement.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save the model's state_dict to ``self.path`` and remember ``val_loss`` as the best loss.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# 1. Preprocess

In [None]:
def make_id_file(task, tokenizer, data_dir='sample_data/'):
    """Tokenize the sentiment train/dev files into space-separated id strings.

    Args:
        task: dataset name; not used by this function (kept for call compatibility).
        tokenizer: object exposing ``encode(text) -> list[int]``.
        data_dir: directory holding ``sentiment.{train,dev}.{0,1}``. Defaults to
            ``'sample_data/'``, the directory that used to be hard-coded.

    Returns:
        (train_pos, train_neg, dev_pos, dev_neg): each a list with one string of
        space-separated token ids per input line. ``*.1`` files feed the ``pos``
        lists and ``*.0`` files the ``neg`` lists.
    """
    def make_data_strings(file_name):
        # Each line is lowercased (trailing newline included) before encoding.
        with open(os.path.join(data_dir, file_name), 'r', encoding='utf-8') as f:
            return [' '.join(str(k) for k in tokenizer.encode(line.lower())) for line in f]

    print('it will take some times...')
    train_pos = make_data_strings('sentiment.train.1')
    train_neg = make_data_strings('sentiment.train.0')
    dev_pos = make_data_strings('sentiment.dev.1')
    dev_neg = make_data_strings('sentiment.dev.0')

    print('make id file finished!')
    return train_pos, train_neg, dev_pos, dev_neg

In [None]:
# Uncased BERT WordPiece tokenizer shared by the train, dev and test pipelines.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

In [None]:
# Encode the four yelp sentiment files into lists of id strings.
train_pos, train_neg, dev_pos, dev_neg = make_id_file(task='yelp', tokenizer=tokenizer)

it will take some times...
make id file finished!


In [None]:
# Preview the first ten encoded positive training sentences (space-separated token ids).
train_pos[:10]

['101 6581 2833 1012 102',
 '101 21688 8013 2326 1012 102',
 '101 2027 2036 2031 3679 19247 1998 3256 6949 2029 2003 2428 2204 1012 102',
 '101 2009 1005 1055 1037 2204 15174 2098 7570 22974 2063 1012 102',
 '101 1996 3095 2003 5379 1012 102',
 '101 2204 3347 2833 1012 102',
 '101 2204 2326 1012 102',
 '101 11350 1997 2154 2003 25628 1998 7167 1997 19247 1012 102',
 '101 2307 2173 2005 6265 2030 3347 27962 1998 5404 1012 102',
 '101 1996 2047 2846 3504 6429 1012 102']

In [None]:
class SentimentDataset(object):
    """Labelled sentiment dataset built from encoded id strings.

    Each entry of ``pos``/``neg`` is a whitespace-separated string of token ids
    (as produced by ``make_id_file``). Positive sentences get label ``[1]`` and
    negative sentences label ``[0]``. Examples are interleaved pos, neg, pos,
    neg, ...; once the shorter list is exhausted, the rest of the longer list
    follows in order, so no example is dropped.
    """
    def __init__(self, tokenizer, pos, neg):
        from itertools import zip_longest

        self.tokenizer = tokenizer  # stored but not used by this class
        self.data = []
        self.label = []

        # zip() would silently discard the unpaired tail of the longer list;
        # the positive and negative files need not have the same length.
        for pos_sent, neg_sent in zip_longest(pos, neg):
            if pos_sent is not None:
                self.data.append(self._cast_to_int(pos_sent.strip().split()))
                self.label.append([1])
            if neg_sent is not None:
                self.data.append(self._cast_to_int(neg_sent.strip().split()))
                self.label.append([0])

    def _cast_to_int(self, sample):
        """Convert a list of id strings into a list of ints."""
        return [int(word_id) for word_id in sample]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(token_ids, label)`` as NumPy arrays; ``label`` has shape (1,)."""
        sample = self.data[index]
        return np.array(sample), np.array(self.label[index])

In [None]:
# Wrap the encoded id strings into labelled train/dev datasets.
train_dataset = SentimentDataset(tokenizer=tokenizer, pos=train_pos, neg=train_neg)
dev_dataset = SentimentDataset(tokenizer=tokenizer, pos=dev_pos, neg=dev_neg)

In [None]:
# Print the first eleven (token_ids, label) pairs of the training set.
for i in range(min(11, len(train_dataset))):
    item = train_dataset[i]
    print(item)

(array([ 101, 6581, 2833, 1012,  102]), array([1]))
(array([  101,  1045,  2001, 13718, 13534,  1012,   102]), array([0]))
(array([  101, 21688,  8013,  2326,  1012,   102]), array([1]))
(array([  101,  2061,  2006,  2000,  1996,  7570, 22974,  2229,  1010,
        1996,  3059,  2003,  2236,  2448,  1997,  1996,  4971,  1012,
         102]), array([0]))
(array([  101,  2027,  2036,  2031,  3679, 19247,  1998,  3256,  6949,
        2029,  2003,  2428,  2204,  1012,   102]), array([1]))
(array([  101, 10124,  6240,  1998,  1037, 10228,  1997, 29022,  2292,
        8525,  3401,  1012,   102]), array([0]))
(array([  101,  2009,  1005,  1055,  1037,  2204, 15174,  2098,  7570,
       22974,  2063,  1012,   102]), array([1]))
(array([  101,  2498,  2428,  2569,  1004,  2025, 11007,  1997,  1996,
        1002,  1035, 16371,  2213,  1035,  3976,  6415,  1012,   102]), array([0]))
(array([ 101, 1996, 3095, 2003, 5379, 1012,  102]), array([1]))
(array([  101,  2117,  1010,  1996, 21475,  7570, 2

In [None]:
def collate_fn_sentiment(samples):
    """Pad a batch of ``(token_ids, label)`` pairs for the BERT classifier.

    The batch is sorted by decreasing sequence length before padding with 0.

    Returns:
        (input_ids, attention_mask, token_type_ids, position_ids, labels), all
        tensors whose rows follow the sorted order. ``attention_mask`` is 1 on
        real tokens and 0 on padding; ``labels`` keeps each sample's label shape.
    """
    input_ids, labels = zip(*samples)
    # Measure lengths BEFORE padding: after pad_sequence every row has max_len,
    # which previously turned the attention mask into all ones.
    lengths = [len(input_id) for input_id in input_ids]
    max_len = max(lengths)
    sorted_indices = np.argsort(lengths)[::-1]

    padded_input_ids = pad_sequence([torch.tensor(input_ids[index]) for index in sorted_indices],
                                    batch_first=True)
    attention_mask = torch.tensor(
        [[1] * lengths[index] + [0] * (max_len - lengths[index]) for index in sorted_indices])
    token_type_ids = torch.tensor([[0] * max_len for _ in sorted_indices])
    position_ids = torch.tensor([list(range(max_len)) for _ in sorted_indices])
    labels = torch.tensor(np.stack(labels, axis=0)[sorted_indices])

    return padded_input_ids, attention_mask, token_type_ids, position_ids, labels

In [None]:
train_batch_size = 128
eval_batch_size = 128

# Training batches are reshuffled every epoch; pin_memory=True asks the loader
# for page-locked host tensors.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn_sentiment,
    pin_memory=True,
    num_workers=4,
)
# The dev set is read in a fixed order.
dev_loader = torch.utils.data.DataLoader(
    dataset=dev_dataset,
    batch_size=eval_batch_size,
    shuffle=False,
    collate_fn=collate_fn_sentiment,
    num_workers=2,
)

  cpuset_checked))


In [None]:
# Seed the global NumPy and PyTorch RNGs for reproducibility.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Prefer the GPU when CUDA is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Pre-trained uncased BERT; the load log reports the classification-head
# weights as newly initialized rather than loaded from the checkpoint.
model = BertForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path='bert-base-uncased')
model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [None]:
# Put the model in training mode (dropout active) and configure optimization.
model.train()
train_epoch = 20
learning_rate = 5e-5
optimizer = AdamW(params=model.parameters(), lr=learning_rate)

# train_model steps the scheduler once per training batch, so the total is
# batches-per-epoch x epochs; the first tenth of those steps is linear warmup.
t_total = len(train_loader) * train_epoch
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=t_total / 10, num_training_steps=t_total)

In [None]:
def compute_acc(predictions, target_labels):
    """Return the fraction of positions where ``predictions`` equals ``target_labels``."""
    predicted = np.asarray(predictions)
    expected = np.asarray(target_labels)
    matches = predicted == expected
    return matches.mean()

In [None]:
def _evaluate(model, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Returns:
        (batch_losses, predictions, target_labels): one loss value per batch plus
        the flattened predicted and gold labels. The model's train/eval mode is
        restored before returning, so callers can resume training directly.
    """
    was_training = model.training
    model.eval()
    batch_losses = []
    predictions = []
    target_labels = []
    try:
        with torch.no_grad():
            for input_ids, attention_mask, token_type_ids, position_ids, labels in tqdm(
                    data_loader, desc='Eval', position=1, leave=None):
                output = model(input_ids=input_ids.to(device),
                               attention_mask=attention_mask.to(device),
                               token_type_ids=token_type_ids.to(device),
                               position_ids=position_ids.to(device),
                               labels=labels.to(device, dtype=torch.long))
                batch_losses.append(output.loss.item())
                logits = output.logits
                # Predict 1 unless the class-0 logit is strictly larger (ties -> 1).
                predictions += (~(logits[:, 0] > logits[:, 1])).long().tolist()
                target_labels += labels.reshape(-1).tolist()
    finally:
        # Previously the model was left in eval mode, disabling dropout for
        # the rest of training.
        model.train(was_training)
    return batch_losses, predictions, target_labels


def train_model(model, batch_size, patience, n_epochs):
    """Fine-tune ``model`` with early stopping on the dev-set loss.

    Uses the notebook globals ``train_loader``, ``dev_loader``, ``device``,
    ``optimizer`` and ``scheduler``.

    Args:
        model: classifier whose forward returns an object with ``.loss`` and ``.logits``.
        batch_size: unused; the batch size is fixed by ``train_loader``.
        patience: number of epochs without dev-loss improvement before stopping.
        n_epochs: maximum number of epochs.

    Returns:
        (model, avg_train_losses, avg_valid_losses): the model reloaded from the
        best early-stopping checkpoint (if one was written), plus per-epoch mean
        training and dev losses.
    """
    avg_train_losses = []
    avg_valid_losses = []

    lowest_valid_loss = float('inf')
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    # Evaluate every fifth of an epoch (iteration 0 excluded); max(1, ...)
    # avoids a modulo-by-zero when the loader has fewer than five batches.
    eval_every = max(1, len(train_loader) // 5)

    for epoch in range(n_epochs):
        model.train()
        train_losses = []
        # Dev-batch losses from every evaluation in this epoch; their mean is
        # the value passed to early stopping.
        valid_losses = []
        with tqdm(train_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            for iteration, (input_ids, attention_mask, token_type_ids, position_ids, labels) in enumerate(tepoch):
                optimizer.zero_grad()

                output = model(input_ids=input_ids.to(device),
                               attention_mask=attention_mask.to(device),
                               token_type_ids=token_type_ids.to(device),
                               position_ids=position_ids.to(device),
                               labels=labels.to(device, dtype=torch.long))

                loss = output.loss
                loss.backward()

                optimizer.step()
                scheduler.step()

                loss_value = loss.item()
                train_losses.append(loss_value)
                tepoch.set_postfix(loss=loss_value)

                if iteration != 0 and iteration % eval_every == 0:
                    eval_losses, predictions, target_labels = _evaluate(model, dev_loader, device)
                    valid_losses += eval_losses

                    acc = compute_acc(predictions, target_labels)
                    # Loss of THIS evaluation only (it used to be cumulative).
                    valid_loss = sum(eval_losses) / len(eval_losses)
                    if lowest_valid_loss > valid_loss:
                        lowest_valid_loss = valid_loss
                        print('Acc for model which have lower valid loss: ', acc)

        if not valid_losses:
            # No mid-epoch evaluation ran (very short loader); evaluate once so
            # early stopping still receives a real value.
            valid_losses, _, _ = _evaluate(model, dev_loader, device)

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        early_stopping(valid_loss, model)

        if early_stopping.early_stop:
            print("Early Stopping")
            break

    # Only reload when this run actually wrote a checkpoint; otherwise a stale
    # file from an earlier run (or a missing file) would be loaded.
    if early_stopping.best_score is not None:
        model.load_state_dict(torch.load(early_stopping.path, map_location=device))
    return model, avg_train_losses, avg_valid_losses

In [None]:
patience = 10  # epochs without dev-loss improvement tolerated before stopping

model, train_loss, valid_loss = train_model(
    model,
    batch_size=train_batch_size,
    patience=patience,
    n_epochs=train_epoch,
)

  cpuset_checked))
Epoch 0:  20%|██        | 554/2770 [08:35<34:29,  1.07batch/s, loss=0.0863]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.08it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:12,  2.49it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.73it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:09,  2.83it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.86it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.86it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:08,  2.78it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.80it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.80it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.85it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.85it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.83it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.88it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.84it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.82it/s][A
Eval:  50%|

Acc for model which have lower valid loss:  0.96075


Epoch 0:  40%|████      | 1108/2770 [17:20<25:08,  1.10batch/s, loss=0.0797]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.10it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:11,  2.53it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.76it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:09,  2.82it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.84it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.84it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.76it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.77it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.80it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.84it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.88it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:06,  2.86it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.93it/s][A
Eval:  44%|████▍     | 14/32 [00:04<00:06,  2.89it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:05,  2.86it/s][A
Eval:  50%|█████     | 16/32 

Acc for model which have lower valid loss:  0.9655


Epoch 0:  60%|██████    | 1662/2770 [26:06<17:07,  1.08batch/s, loss=0.0265]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.10it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:11,  2.53it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.75it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:09,  2.83it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.86it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.86it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:08,  2.80it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.77it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.83it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.85it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.87it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:06,  2.86it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.91it/s][A
Eval:  44%|████▍     | 14/32 [00:04<00:06,  2.87it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:05,  2.86it/s][A
Eval:  50%|█████     | 16/32 

Acc for model which have lower valid loss:  0.972


Epoch 0:  80%|████████  | 2216/2770 [34:54<08:22,  1.10batch/s, loss=0.0362]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.12it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:11,  2.50it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.75it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:09,  2.84it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.87it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.85it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.76it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.78it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.81it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.84it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.88it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.86it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.90it/s][A
Eval:  44%|████▍     | 14/32 [00:04<00:06,  2.85it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:05,  2.84it/s][A
Eval:  50%|█████     | 16/32 

Acc for model which have lower valid loss:  0.97125


Epoch 0: 100%|██████████| 2770/2770 [43:37<00:00,  1.06batch/s, loss=0.0244]


Validation loss decreased (inf --> 0.085259).  Saving model ...


Epoch 1:  20%|██        | 554/2770 [08:35<32:50,  1.12batch/s, loss=0.0813]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.12it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:11,  2.52it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.78it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:09,  2.87it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.87it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.88it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.78it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.78it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.81it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.84it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.87it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.86it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.90it/s][A
Eval:  44%|████▍     | 14/32 [00:04<00:06,  2.89it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.83it/s][A
Eval:  50%|█████     | 16/32 [

Acc for model which have lower valid loss:  0.97325


Epoch 1:  40%|████      | 1108/2770 [17:21<25:44,  1.08batch/s, loss=0.1]   
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.16it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:11,  2.55it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.75it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:09,  2.81it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.84it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.84it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.74it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.78it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.77it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.84it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.82it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.82it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.87it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.82it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.80it/s][A
Eval:  50%|█████     | 16/32 

Acc for model which have lower valid loss:  0.97475


Epoch 1:  60%|██████    | 1662/2770 [26:10<16:52,  1.09batch/s, loss=0.0451]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.09it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:12,  2.50it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.71it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:10,  2.79it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.81it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.81it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.74it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.77it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.77it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.82it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.83it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.83it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.88it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.83it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.81it/s][A
Eval:  50%|█████     | 16/32 

Acc for model which have lower valid loss:  0.979


Epoch 1:  80%|████████  | 2216/2770 [34:59<09:12,  1.00batch/s, loss=0.0244]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.15it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:12,  2.49it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.69it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:10,  2.78it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.81it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.80it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.76it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.74it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.78it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.80it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.83it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.82it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.87it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.85it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.82it/s][A
Eval:  50%|█████     | 16/32 

Acc for model which have lower valid loss:  0.97775


Epoch 1: 100%|██████████| 2770/2770 [43:48<00:00,  1.05batch/s, loss=0.00466]


Validation loss decreased (0.085259 --> 0.063852).  Saving model ...


Epoch 2:  20%|██        | 554/2770 [08:36<34:38,  1.07batch/s, loss=0.00853]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.07it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:11,  2.51it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.72it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:10,  2.79it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.84it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.83it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.76it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.76it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.78it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.83it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.86it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.85it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.90it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.82it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.82it/s][A
Eval:  50%|█████     | 16/32 

Acc for model which have lower valid loss:  0.97875


Epoch 2:  40%|████      | 1108/2770 [17:26<26:13,  1.06batch/s, loss=0.0145]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.10it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:12,  2.50it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.73it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:10,  2.80it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.82it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.83it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.76it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.79it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.79it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.82it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.85it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.84it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.87it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.85it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.80it/s][A
Eval:  50%|█████     | 16/32 

Acc for model which have lower valid loss:  0.97825


Epoch 2:  60%|██████    | 1662/2770 [26:14<16:54,  1.09batch/s, loss=0.007] 
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.10it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:12,  2.50it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.71it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:10,  2.78it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.79it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.81it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.73it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.73it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.76it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.79it/s][A
Eval:  34%|███▍      | 11/32 [00:04<00:07,  2.82it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.82it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.86it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.83it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.81it/s][A
Eval:  50%|█████     | 16/32 

Acc for model which have lower valid loss:  0.97625


Epoch 2:  80%|████████  | 2216/2770 [35:01<08:39,  1.07batch/s, loss=0.0231] 
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.12it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:11,  2.50it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.73it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:09,  2.81it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.83it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.83it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.76it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.78it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.78it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.84it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.83it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.83it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.85it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.83it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.80it/s][A
Eval:  50%|█████     | 16/32

Acc for model which have lower valid loss:  0.97975


Epoch 2: 100%|██████████| 2770/2770 [43:50<00:00,  1.05batch/s, loss=0.00588]


EarlyStopping counter: 1 out of 10


Epoch 3:  20%|██        | 554/2770 [08:38<34:59,  1.06batch/s, loss=0.0433]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:15,  1.97it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:12,  2.46it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.66it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:10,  2.77it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.82it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.80it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.75it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.75it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.79it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.81it/s][A
Eval:  34%|███▍      | 11/32 [00:04<00:07,  2.84it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.85it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.85it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.84it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.78it/s][A
Eval:  50%|█████     | 16/32 [

Acc for model which have lower valid loss:  0.97775


Epoch 3:  40%|████      | 1108/2770 [17:26<25:23,  1.09batch/s, loss=0.0213] 
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.19it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:11,  2.54it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.70it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:10,  2.78it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.80it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.81it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:09,  2.76it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.76it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.79it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.81it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.84it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.82it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.86it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.83it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.80it/s][A
Eval:  50%|█████     | 16/32

Acc for model which have lower valid loss:  0.9785


Epoch 3:  60%|██████    | 1662/2770 [26:15<17:14,  1.07batch/s, loss=0.00419]
Eval:   0%|          | 0/32 [00:00<?, ?it/s][A
Eval:   3%|▎         | 1/32 [00:00<00:14,  2.15it/s][A
Eval:   6%|▋         | 2/32 [00:00<00:11,  2.55it/s][A
Eval:   9%|▉         | 3/32 [00:01<00:10,  2.76it/s][A
Eval:  12%|█▎        | 4/32 [00:01<00:09,  2.83it/s][A
Eval:  16%|█▌        | 5/32 [00:01<00:09,  2.86it/s][A
Eval:  19%|█▉        | 6/32 [00:02<00:09,  2.84it/s][A
Eval:  22%|██▏       | 7/32 [00:02<00:08,  2.78it/s][A
Eval:  25%|██▌       | 8/32 [00:02<00:08,  2.76it/s][A
Eval:  28%|██▊       | 9/32 [00:03<00:08,  2.79it/s][A
Eval:  31%|███▏      | 10/32 [00:03<00:07,  2.81it/s][A
Eval:  34%|███▍      | 11/32 [00:03<00:07,  2.83it/s][A
Eval:  38%|███▊      | 12/32 [00:04<00:07,  2.82it/s][A
Eval:  41%|████      | 13/32 [00:04<00:06,  2.85it/s][A
Eval:  44%|████▍     | 14/32 [00:05<00:06,  2.83it/s][A
Eval:  47%|████▋     | 15/32 [00:05<00:06,  2.81it/s][A
Eval:  50%|█████     | 16/32

Acc for model which have lower valid loss:  0.9785


Epoch 3:  74%|███████▍  | 2059/2770 [32:36<11:05,  1.07batch/s, loss=0.006]

In [None]:
import pandas as pd

# Unlabelled test split; its 'Id' column is later read as raw sentence text.
test_df = pd.read_csv(filepath_or_buffer='sample_data/test_no_label.csv')

In [None]:
# The sentences to classify live in the 'Id' column of the test CSV; each entry
# is lowercased and tokenized by make_id_file_test. This name is reassigned to a
# SentimentTestDataset further down.
test_dataset = test_df['Id']

In [None]:
def make_id_file_test(tokenizer, test_dataset):
    """Encode each test sentence (lowercased) into a space-separated string of token ids.

    Args:
        tokenizer: object exposing ``encode(text) -> list[int]``.
        test_dataset: iterable of raw sentence strings.

    Returns:
        list[str]: one id string per input sentence, in input order.
    """
    encoded = (tokenizer.encode(sentence.lower()) for sentence in test_dataset)
    return [' '.join(map(str, token_ids)) for token_ids in encoded]

In [None]:
# Encode every test sentence into a space-separated id string.
test = make_id_file_test(tokenizer=tokenizer, test_dataset=test_dataset)

In [None]:
# Preview the first ten encoded test sentences.
test[:10]

In [None]:
class SentimentTestDataset(object):
    """Unlabelled dataset over encoded test sentences.

    Each entry of ``test`` is a whitespace-separated string of token ids;
    indexing returns that sentence as a NumPy integer array.
    """
    def __init__(self, tokenizer, test):
        self.tokenizer = tokenizer  # stored but not used by this class
        self.data = [self._cast_to_int(line.strip().split()) for line in test]

    def _cast_to_int(self, sample):
        """Convert a list of id strings into a list of ints."""
        return list(map(int, sample))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return np.array(self.data[index])

In [None]:
# Wrap the encoded test sentences in a dataset for the test DataLoader.
test_dataset = SentimentTestDataset(tokenizer=tokenizer, test=test)

In [None]:
def collate_fn_sentiment_test(samples):
    """Pad a batch of unlabelled token-id arrays, preserving the input order.

    Order is kept so predictions line up with the rows of the test file.

    Returns:
        (input_ids, attention_mask, token_type_ids, position_ids): tensors of
        shape (batch, max_len); ``attention_mask`` is 1 on real tokens and 0 on
        padding (pad id 0).
    """
    # Measure lengths BEFORE padding: afterwards every row has max_len, which
    # previously turned the attention mask into all ones.
    lengths = [len(input_id) for input_id in samples]
    max_len = max(lengths)

    input_ids = pad_sequence([torch.tensor(input_id) for input_id in samples],
                             batch_first=True)
    attention_mask = torch.tensor(
        [[1] * length + [0] * (max_len - length) for length in lengths])
    token_type_ids = torch.tensor([[0] * max_len for _ in lengths])
    position_ids = torch.tensor([list(range(max_len)) for _ in lengths])

    return input_ids, attention_mask, token_type_ids, position_ids

In [None]:
test_batch_size = 32
# shuffle=False keeps predictions in the same order as the rows of test_df.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    collate_fn=collate_fn_sentiment_test,
    num_workers=2,
)

In [None]:
# Run the model over the test set (in loader order) and collect one 0/1 label per sentence.
with torch.no_grad():
    model.eval()
    predictions = []
    for input_ids, attention_mask, token_type_ids, position_ids in tqdm(
            test_loader, desc='Test', position=1, leave=None):
        # Move every tensor of the batch to the model's device.
        input_ids, attention_mask, token_type_ids, position_ids = (
            tensor.to(device)
            for tensor in (input_ids, attention_mask, token_type_ids, position_ids)
        )

        output = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       token_type_ids=token_type_ids,
                       position_ids=position_ids)

        logits = output.logits
        # Same rule as `0 if row[0] > row[1] else 1`, applied to the whole batch:
        # label 1 unless the class-0 logit is strictly larger (ties/NaN -> 1).
        batch_predictions = (~(logits[:, 0] > logits[:, 1])).long().tolist()
        predictions.extend(batch_predictions)

In [None]:
# Store the 0/1 predictions as the 'Category' column. test_loader is not
# shuffled and collate_fn_sentiment_test keeps input order, so the predictions
# line up row-for-row with test_df.
test_df['Category'] = predictions

In [None]:
# Write the submission file without the DataFrame's index column.
test_df.to_csv(path_or_buf='submission.csv', index=False)