# Import requirements

In [None]:
# Mount Google Drive at /content/drive (Colab-only API).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install simpletransformers

In [2]:
import os
import pdb
import argparse
from dataclasses import dataclass, field
from typing import Optional
from collections import defaultdict

import torch
from torch.nn.utils.rnn import pad_sequence

import numpy as np
from tqdm import tqdm, trange

from transformers import (
    XLNetForSequenceClassification,
    XLNetTokenizer,
    XLNetModel,
    AutoConfig,
    AdamW
)

# 1. Preprocess

In [None]:
def make_id_file(task, tokenizer):
    """Encode the four sentiment text files into token-id strings.

    Reads 'sentiment.train.1', 'sentiment.train.0', 'sentiment.dev.1' and
    'sentiment.dev.0' from the current working directory. Each line is
    lower-cased, passed to ``tokenizer.encode`` and the returned ids are
    joined with single spaces.

    Args:
        task: Not used by this function; kept for interface compatibility.
        tokenizer: Any object exposing ``encode(str)`` that returns an
            iterable of ids.

    Returns:
        A ``(train_pos, train_neg, dev_pos, dev_neg)`` tuple; each element is
        a list with one id string per input line.
    """
    def encode_file(path):
        # Read every line up front, then turn each one into "id id id ...".
        with open(path, 'r', encoding='utf-8') as handle:
            raw_lines = handle.readlines()
        return [
            ' '.join(map(str, tokenizer.encode(text.lower())))
            for text in raw_lines
        ]

    print('it will take some times...')
    file_names = (
        'sentiment.train.1',
        'sentiment.train.0',
        'sentiment.dev.1',
        'sentiment.dev.0',
    )
    # The generator is consumed in order, so files are processed in the
    # sequence listed above.
    train_pos, train_neg, dev_pos, dev_neg = (
        encode_file(name) for name in file_names
    )

    print('make id file finished!')
    return train_pos, train_neg, dev_pos, dev_neg

In [None]:
# Load the pretrained tokenizer for the 'xlnet-base-cased' checkpoint.
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

In [None]:
# Open Colab's file-upload widget.
# NOTE(review): presumably used to upload the sentiment.* files that
# make_id_file reads from the working directory — confirm.
from google.colab import files
uploaded = files.upload()

In [None]:
!ls

In [None]:
# Encode the four sentiment files into lists of space-separated id strings.
# The 'yelp' task argument is accepted but not used by make_id_file.
train_pos, train_neg, dev_pos, dev_neg = make_id_file('yelp', tokenizer)

In [None]:
# Preview the first ten encoded positive training sentences.
train_pos[:10]

In [None]:
class SentimentDataset(object):
    """In-memory sentiment dataset built from encoded id strings.

    Positive sentences are stored first with label ``[1]``, followed by the
    negative sentences with label ``[0]``.

    Args:
        tokenizer: Stored on the instance; not used by any method here.
        pos: Iterable of strings of whitespace-separated integer ids.
        neg: Iterable of strings of whitespace-separated integer ids.
    """

    def __init__(self, tokenizer, pos, neg):
        self.tokenizer = tokenizer
        self.data = []
        self.label = []

        # Walk both groups with their label so the parsing logic is written once.
        for sentences, tag in ((pos, 1), (neg, 0)):
            for sentence in sentences:
                id_strings = sentence.strip().split()
                self.data.append(self._cast_to_int(id_strings))
                self.label.append([tag])

    def _cast_to_int(self, sample):
        """Convert an iterable of numeric strings to a list of ints."""
        return list(map(int, sample))

    def __len__(self):
        """Return the number of stored sentences."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(ids, label)`` for ``index`` as two NumPy arrays."""
        ids = self.data[index]
        tag = self.label[index]
        return np.array(ids), np.array(tag)

In [None]:
# Wrap the encoded train/dev splits in SentimentDataset (positives first,
# then negatives).
train_dataset = SentimentDataset(tokenizer, train_pos, train_neg)
dev_dataset = SentimentDataset(tokenizer, dev_pos, dev_neg)

In [None]:
# Sanity check: print the first eleven (ids, label) samples (indices 0-10).
for i, item in enumerate(train_dataset):
    print(item)
    if i == 10:
        break

In [None]:
def collate_fn_style(samples):
    """Collate ``(ids, label)`` samples into padded, length-sorted tensors.

    Sequences are ordered from longest to shortest and right-padded with 0
    (the ``pad_sequence`` default) up to the longest sequence in the batch.

    Args:
        samples: Non-empty sequence of ``(ids, label)`` pairs, where ``ids``
            is a 1-D integer array and ``label`` is an array such as ``[1]``.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, position_ids,
        labels)``. The first four are ``(batch, max_len)`` integer tensors:
        ``attention_mask`` is 1 on real tokens and 0 on padding,
        ``token_type_ids`` is all zeros and ``position_ids`` is
        ``0..max_len-1`` in every row. ``labels`` holds the stacked labels in
        the same sorted order.
    """
    input_ids, labels = zip(*samples)

    # Lengths must be captured BEFORE padding. The previous version measured
    # the already-padded tensor, so every length equalled max_len and the
    # attention mask came out all ones (padding was never masked).
    lengths = [len(input_id) for input_id in input_ids]
    max_len = max(lengths)
    sorted_indices = np.argsort(lengths)[::-1]
    sorted_lengths = [lengths[index] for index in sorted_indices]
    batch_size = len(sorted_lengths)

    padded_ids = pad_sequence(
        [torch.tensor(input_ids[index]) for index in sorted_indices],
        batch_first=True)

    # Position p of a row is a real token iff p < that row's original length.
    positions = torch.arange(max_len)
    attention_mask = (
        positions.unsqueeze(0) < torch.tensor(sorted_lengths).unsqueeze(1)
    ).long()
    token_type_ids = torch.zeros((batch_size, max_len), dtype=torch.long)
    position_ids = positions.repeat(batch_size, 1)
    labels = torch.tensor(np.stack(labels, axis=0)[sorted_indices])

    return padded_ids, attention_mask, token_type_ids, position_ids, labels

In [None]:
# Seed the NumPy and PyTorch random number generators with a fixed value.
random_seed=42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def compute_acc(predictions, target_labels):
    """Return the fraction of predictions that equal their target label.

    Args:
        predictions: Sequence of predicted labels.
        target_labels: Sequence of reference labels, aligned with
            ``predictions``.

    Returns:
        The mean of the element-wise equality between the two sequences.
    """
    predicted = np.array(predictions)
    expected = np.array(target_labels)
    matches = predicted == expected
    return matches.mean()

In [None]:
class Trainer():
    """Fine-tune 'xlnet-base-cased' for sequence classification.

    Handles model/optimizer construction, optional resume from a checkpoint,
    the training loop with periodic validation, and checkpointing whenever
    the validation loss improves.

    Args:
        device: ``torch.device`` the model and batches are moved to.
        output_path: Directory that checkpoints are written into.
        lr: Learning rate passed to ``AdamW``.
        resume_path: Path of a checkpoint to resume from, or a falsy value to
            start from the pretrained weights.
    """

    def __init__(self, device, output_path, lr, resume_path):
        self.output_path = output_path
        self.device = device
        self.model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased')
        self.optimizer = AdamW(self.model.parameters(), lr=lr)

        # One dict per evaluation; used to plot learning curves afterwards.
        self.training_stats = []
        if resume_path:
            checkpoint = torch.load(resume_path, map_location=device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # NOTE: checkpoints are written mid-epoch, so resuming at
            # epoch + 1 skips the remainder of the checkpointed epoch.
            self.start_train_epoch = checkpoint['epoch'] + 1
            self.lowest_valid_loss = checkpoint['lowest_valid_loss']
        else:
            self.start_train_epoch = 0
            self.lowest_valid_loss = 9999.
        self.model.to(self.device)
        # A restored optimizer state may live on another device; move every
        # tensor in it next to the model parameters.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    def _to_device(self, batch):
        """Move one collated batch to ``self.device`` (labels cast to long)."""
        input_ids, attention_mask, token_type_ids, position_ids, labels = batch
        return (
            input_ids.to(self.device),
            attention_mask.to(self.device),
            token_type_ids.to(self.device),
            position_ids.to(self.device),
            labels.to(self.device, dtype=torch.long),
        )

    def _forward(self, batch):
        """Run the model on a batch already placed on ``self.device``."""
        input_ids, attention_mask, token_type_ids, position_ids, labels = batch
        # NOTE(review): confirm that the installed transformers version
        # accepts `position_ids` for XLNet.
        return self.model(input_ids=input_ids,
                          attention_mask=attention_mask,
                          token_type_ids=token_type_ids,
                          position_ids=position_ids,
                          labels=labels)

    def _evaluate(self, dev_loader):
        """Return ``(mean validation loss, accuracy)`` over ``dev_loader``.

        Leaves the model in eval mode; the caller switches back to train mode.
        """
        self.model.eval()
        valid_losses = []
        predictions = []
        target_labels = []
        with torch.no_grad():
            for batch in tqdm(dev_loader, desc='Eval', position=1, leave=None):
                batch = self._to_device(batch)
                output = self._forward(batch)
                valid_losses.append(output.loss.item())

                logits = output.logits
                # Class 0 only when its logit is strictly larger, else class 1.
                predictions += (~(logits[:, 0] > logits[:, 1])).long().tolist()
                target_labels += batch[-1].reshape(-1).tolist()

        acc = compute_acc(predictions, target_labels)
        valid_loss = sum(valid_losses) / len(valid_losses)
        return valid_loss, acc

    def _save_checkpoint(self, epoch, iteration):
        """Write model/optimizer state and the current best validation loss."""
        # Create the directory on demand so the first save cannot fail on a
        # missing path.
        os.makedirs(self.output_path, exist_ok=True)
        torch.save({
            'epoch': epoch,
            'lowest_valid_loss': self.lowest_valid_loss,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'model_state_dict': self.model.state_dict(),
            }, f'{self.output_path}/checkpoint_epoch_{epoch}.{iteration}.pth')

    def training(self, train_loader, dev_loader, last_epoch):
        """Train from ``self.start_train_epoch`` up to (excluding) ``last_epoch``.

        The model is validated roughly ten times per epoch. Every evaluation
        appends a record to ``self.training_stats``, and a checkpoint is
        written whenever the validation loss drops below the best seen so far.

        Args:
            train_loader: Iterable with ``len()`` yielding batches in the
                order ``(input_ids, attention_mask, token_type_ids,
                position_ids, labels)``.
            dev_loader: Iterable yielding validation batches in the same
                layout.
            last_epoch: Exclusive upper bound of the epoch range.
        """
        # max(1, ...) keeps the modulo below valid for loaders with fewer
        # than ten batches.
        eval_interval = max(1, len(train_loader) // 10)

        self.model.train()
        for epoch in range(self.start_train_epoch, last_epoch):
            with tqdm(train_loader, unit="batch") as tepoch:
                for iteration, batch in enumerate(tepoch):
                    tepoch.set_description(f"Epoch {epoch}")
                    batch = self._to_device(batch)

                    self.optimizer.zero_grad()
                    output = self._forward(batch)
                    loss = output.loss
                    loss.backward()
                    self.optimizer.step()

                    # Captured here so the validation loop cannot overwrite it.
                    train_loss = loss.item()
                    tepoch.set_postfix(loss=train_loss)

                    if iteration == 0 or iteration % eval_interval != 0:
                        continue

                    valid_loss, acc = self._evaluate(dev_loader)
                    # Evaluation switched the model to eval mode; restore
                    # train mode before the next optimisation step.
                    self.model.train()

                    print(epoch)

                    # Learning-curve record: training loss of the latest
                    # batch plus validation loss and accuracy.
                    self.training_stats.append({
                        'epoch': epoch + 1,
                        'iteration': iteration,
                        'Training Loss': train_loss,
                        'Valid Loss': valid_loss,
                        'Valid Accuracy': acc})

                    if self.lowest_valid_loss > valid_loss:
                        print('')
                        print('Acc for model which have lower valid loss: ', acc)
                        # Update first so the checkpoint stores the new best
                        # value rather than the previous one.
                        self.lowest_valid_loss = valid_loss
                        self._save_checkpoint(epoch, iteration)
                        print('--------------save checkpoint at epoch : {}--------------'.format(epoch))
                        print('--------------lowest_valid_loss : {}--------------'.format(self.lowest_valid_loss))

In [None]:
# Hyperparameters.
train_batch_size = 32
eval_batch_size = 32
lr = 5e-5
last_epoch = 5

# Path settings: where checkpoints are written, and an optional checkpoint
# to resume from (None starts from the pretrained weights).
output_path = '/content/drive/MyDrive/Colab Notebooks/NLP/project/checkpoints'
resume_path = None

# Training batches are shuffled; validation batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn_style,
    pin_memory=True,
    num_workers=2,
)
dev_loader = torch.utils.data.DataLoader(
    dev_dataset,
    batch_size=eval_batch_size,
    shuffle=False,
    collate_fn=collate_fn_style,
    num_workers=2,
)

trainer = Trainer(device, output_path, lr, resume_path)
trainer.training(train_loader, dev_loader, last_epoch)