# 패키지 불러오기

# In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
from kogpt2_transformers import get_kogpt2_model, get_kogpt2_tokenizer
from torch.utils.data import Dataset
from tqdm import tqdm

# 데이터 로더

# In [None]:
class ReviewAutoRegressiveDataset(Dataset):
    """Question/answer pairs encoded as fixed-length lists of token ids.

    Every usable line of the input file has the form
    ``<question><four spaces><answer>``. It is encoded as
    ``[bos] question [eos] [bos] answer [eos]`` and then padded (or
    truncated) to exactly ``n_ctx`` ids. Lines without the four-space
    separator are skipped.
    """

    # Column separator used by the data file (four spaces).
    _SEPARATOR = "    "

    def __init__(self,
                 file_path="QAdataset_0820.txt",
                 n_ctx=512
                 ):
        """Read ``file_path`` and encode every question/answer line.

        Args:
            file_path: UTF-8 text file with one question/answer pair per line.
            n_ctx: Length, in token ids, of every stored sample.
        """
        self.file_path = file_path
        self.data = []
        self.tokenizer = get_kogpt2_tokenizer()

        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id
        encode = self.tokenizer.encode

        # `with` guarantees the handle is closed even if encoding raises.
        with open(self.file_path, 'r', encoding='utf-8') as file:
            for line in file:
                pair = self._split_line(line)
                if pair is None:
                    continue
                question, answer = pair
                token_ids = (
                    [bos_id] + encode(question) + [eos_id]
                    + [bos_id] + encode(answer) + [eos_id]
                )
                self.data.append(
                    self._pad_or_truncate(token_ids, n_ctx, pad_id, eos_id)
                )

    @staticmethod
    def _split_line(line):
        """Split one raw line into ``(question, answer)``.

        Returns ``None`` when the line has no separator. If the separator
        occurs more than once, everything after the first field is
        concatenated (separators dropped) to form the answer.
        """
        # Strip only the line terminator; a final line without one must not
        # lose its last real character.
        fields = line.rstrip("\n").split(ReviewAutoRegressiveDataset._SEPARATOR)
        if len(fields) < 2:
            return None
        return fields[0], "".join(fields[1:])

    @staticmethod
    def _pad_or_truncate(token_ids, n_ctx, pad_id, eos_id):
        """Return ``token_ids`` adjusted to exactly ``n_ctx`` entries.

        Short sequences are right-padded with ``pad_id``. Over-long sequences
        are cut to ``n_ctx`` and their last id is replaced by ``eos_id`` so
        the sample still ends with an end-of-sequence marker.
        """
        if len(token_ids) > n_ctx:
            clipped = token_ids[:n_ctx]
            if clipped:
                clipped[-1] = eos_id
            return clipped
        return token_ids + [pad_id] * (n_ctx - len(token_ids))

    def __len__(self):
        """Return the number of encoded samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a plain list of token ids."""
        return self.data[index]

# When this file is executed directly (not imported), build the dataset once
# using its default arguments.
if __name__ == "__main__":
    dataset = ReviewAutoRegressiveDataset()

# KoGPT-2 로드

# In [None]:
class DialogKoGPT2(nn.Module):
    """Thin ``nn.Module`` wrapper around the model from ``get_kogpt2_model()``.

    The wrapped model is stored as ``self.kogpt2``; ``forward`` and
    ``generate`` delegate to it.
    """

    def __init__(self):
        super().__init__()
        self.kogpt2 = get_kogpt2_model()

    def generate(self,
                 input_ids,
                 do_sample=True,
                 max_length=60,
                 top_p=0.92,
                 top_k=50,
                 temperature=0.6,
                 no_repeat_ngram_size=None,
                 num_return_sequences=3,
                 early_stopping=False,
                 ):
        """Generate continuations for ``input_ids`` with the wrapped model.

        Every argument is forwarded unchanged to ``self.kogpt2.generate``;
        this method only supplies the default decoding settings shown in the
        signature. See the wrapped model's ``generate`` documentation for the
        meaning of each option.

        Returns:
            Whatever ``self.kogpt2.generate`` returns.
        """
        return self.kogpt2.generate(
            input_ids,
            do_sample=do_sample,
            max_length=max_length,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            early_stopping=early_stopping,
        )

    def forward(self, input, labels=None):
        """Run the wrapped model on ``input``.

        Args:
            input: Passed as the first positional argument to ``self.kogpt2``.
            labels: Optional; forwarded as the ``labels`` keyword argument.

        Returns:
            The wrapped model's output, unmodified.
        """
        # NOTE(review): assumes the wrapped model's `labels` parameter
        # defaults to None (as Hugging Face LM-head models do), so passing
        # None equals omitting it -- confirm for other backends.
        return self.kogpt2(input, labels=labels)

# 데이터 학습

# In [None]:
# Fine-tuning script: trains DialogKoGPT2 on the QA dataset and periodically
# writes a checkpoint to `save_ckpt_path`.
root_path = ''
data_path = f"{root_path}QAdataset_0820.txt"
save_ckpt_path = "./0820_kogpt2-review-auto-regressive_v1.pth"

n_epoch = 5
batch_size = 1
ctx = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(ctx)
save_step = 100  # write a checkpoint every `save_step` steps within an epoch
learning_rate = 5e-5

dataset = ReviewAutoRegressiveDataset(data_path)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = DialogKoGPT2()
model.to(device)
# Hugging Face `from_pretrained` returns models in eval mode (dropout off);
# switch to training mode explicitly before fine-tuning.
model.train()

# Ignore padding positions in the loss. Use the same pad id the dataset pads
# with instead of a hard-coded constant.
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=dataset.tokenizer.pad_token_id)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


def save_checkpoint(epoch, train_no, loss_value):
    """Write model/optimizer state plus progress metadata to `save_ckpt_path`."""
    torch.save({
        'epoch': epoch,
        'train_no': train_no,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Stored as a plain float so the checkpoint does not hold a
        # device-bound tensor.
        'loss': loss_value
    }, save_ckpt_path)


losses = []
loss_total = 0.0  # running sum so the mean is O(1) per step
epoch = 0
count = 0
for epoch in range(n_epoch):
    count = 0
    with tqdm(total=len(train_loader), desc=f"Train({epoch})") as pbar:
        for data in train_loader:
            optimizer.zero_grad()
            # The loader yields a list of per-position tensors; stack and
            # transpose so the first dimension indexes samples in the batch.
            data = torch.stack(data).transpose(1, 0).to(device)

            outputs = model(data, labels=data)
            # The first output (the model's own loss) is discarded; the loss
            # is recomputed below with padding ignored.
            _, logits = outputs[:2]

            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = data[..., 1:].contiguous()

            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            losses.append(loss_value)
            loss_total += loss_value

            if (count > 0 and count % save_step == 0) or (len(data) < batch_size):
                save_checkpoint(epoch, count, loss_value)
            count += 1
            pbar.update(1)
            pbar.set_postfix_str(f"Loss: {loss_value:.3f} ({loss_total / len(losses):.3f})")

# Persist the final weights: steps after the last `save_step` multiple would
# otherwise never be written to disk.
if losses:
    save_checkpoint(epoch, count, losses[-1])