## [0] Import

---------------------

In [1]:
#!pip install transformers datasets beartype --upgrade accelerate numba nvidia-ml-py3

In [2]:
import pandas as pd
import numpy
import pickle
import datetime
import time

# Model Save & Load
import os
# GPU Reset
from numba import cuda

import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
# from torchsummary import summary

# 모델, Tokenizer Load
from transformers import AutoModel, AutoConfig, AutoTokenizer, LEDForConditionalGeneration, LEDModel, LEDTokenizer, BartModel, LEDConfig

# 데이터셋 Load from Summarize_from_feedvback, Huggingface
from datasets import load_dataset

# RL Training
from beartype.typing import Deque, Tuple, List
from collections import deque, namedtuple


print("This code is written at " + str(datetime.datetime.now()))

This code is written at 2023-06-22 21:00:24.852697


#### GPU Reset & Setting device

In [3]:
"""
def GPU_reset():
    device = cuda.get_current_device()
    device.reset
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(device)

    return device
"""

"\ndef GPU_reset():\n    device = cuda.get_current_device()\n    device.reset\n    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n    print(device)\n\n    return device\n"

In [4]:
#device = GPU_reset()

device0 = "cuda:0"
device1 = "cuda:1"
device2 = "cuda:2"

In [5]:
!nvidia-smi

Thu Jun 22 21:00:25 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.56.06    Driver Version: 520.56.06    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:81:00.0 Off |                    0 |
| N/A   37C    P0    45W / 300W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  On   | 00000000:A1:00.0 Off |                    0 |
| N/A   38C    P0    46W / 300W |      0MiB / 81920MiB |      0%      Default |
|       

## [*] Hyper Parameter Setting

-------------------

In [6]:
BATCH_SIZE = 1
EPOCH = 1
LEARNING_RATE = 1e-4

## [1] Tokenizer
- Longformer의 Tokenizer 정의
- Global Attention Mask 함수 정의

-----------------------------------

In [7]:
from transformers import LEDTokenizer
tokenizer = LEDTokenizer.from_pretrained("allenai/led-large-16384-arxiv")

In [8]:
def generate_global_attention_mask(tokenizer, input_ids):
    mask = torch.torch.zeros_like(input_ids)
    mask[((input_ids == tokenizer.bos_token_id) | (input_ids == tokenizer.eos_token_id)).nonzero(as_tuple=True)] = 1
    return mask

## [2] DataLoad
- 데이터 Load
- 데이터 전처리
- 데이터 Loader

---------------------------------

### [2.1] Data Preprocessing

- Human Feedback 데이터 Dictionary를 Dataframe으로 변환
- 알맞은 문장 추출
- DataLoader 생성

#### Text DataFrame 생성

In [9]:
'''
class Data_Preprocessing
    - HuggingFace의 Summarize from feedback 전용 데이터 전처리 Class
    - Train 데이터와 Validation 데이터 출력
'''

class Data_Preprocessing:
    def __init__(self):
        # DownLoad Data from huggingFace
        ## Text Summary 데이터
        ## CNN, TL;DR, Daily Mail
        self.data_feedback = load_dataset("ccdv/arxiv-summarization")

        # Split into Train and Validation dataset
        # Convert to DataFrame
        self.df_train = pd.DataFrame(self.data_feedback['train'])
        self.df_valid = pd.DataFrame(self.data_feedback['validation'])


    # Original Text + Summarized Text 데이터 Columm 추출
    def Data_cleaning(self, df):
        df['original_text'] = df['article']
        # df_valid['original_text'] = self.df_valid['article']

        # df_train['sum_text'] = [row['text'] for row in df_train['summary']]
        # df_valid['sum_text'] = [row['text'] for row in df_valid['summary']]

        # df_all = pd.concat([df_train[['original_text', 'sum_text']], df_valid[['original_text', 'sum_text']]], ignore_index=True)

        return df

    # 최종 DataFrame 출력
    def data_complete_form(self):
        df_train = self.Data_cleaning(self.df_train)
        df_valid = self.Data_cleaning(self.df_valid)

        return df_train, df_valid

# 실행 코드
df_train, df_valid = Data_Preprocessing().data_complete_form()

No config specified, defaulting to: arxiv-summarization/section
Found cached dataset arxiv-summarization (/root/.cache/huggingface/datasets/ccdv___arxiv-summarization/section/1.0.0/fa2c9abf4312afb8660ef8e041d576b8e3943ea96ae771bd3cd091b5798e7cc3)


  0%|          | 0/3 [00:00<?, ?it/s]

In [10]:
df_train.head(2)

Unnamed: 0,article,abstract,original_text
0,additive models @xcite provide an important fa...,additive models play an important role in semi...,additive models @xcite provide an important fa...
1,the leptonic decays of a charged pseudoscalar ...,"we have studied the leptonic decay @xmath0 , v...",the leptonic decays of a charged pseudoscalar ...


### [2.2] DataLoader
- 원본 Text와 Summarize가 합쳐진 데이터 형식의 DataFrame을 DataLoader로 처리
- 입력: DataFrame <br>
- Feature : original_with_good_sum,   original_with_bad_sum  </br>
- 내용: 원본 텍스트 + 긍정 Summary, 원본 텍스트 + 부정 Summary

In [11]:
class RL_Dataset(torch.utils.data.Dataset):

    def __init__(self, df_textsum): #, transforms_=None, random_masking = False,  unaligned=True ):

        self.original_text = df_textsum['original_text']
        # self.sum_text = df_textsum['sum_text']
        # self.old_action = df_textsum['old_action_prob']

        print(f"My_dataset __init__ received : {self.original_text.shape}") # , {self.old_action.shape}")
        print(f"Data Type : {type(self.original_text[0])}") #, type(self.old_action[0])}")
        # print(f"Data example : {self.original_text[0]}") # , {self.sum_text[0]}}")

    def __getitem__(self, index):
        original_text = self.original_text[index]
        # sum_text = self.sum_text[index]
        # old_action_prob = self.old_action[index]

        return original_text # , sum_text# old_action_prob


    def __len__(self):
        return len(self.original_text)


In [12]:
df_train.head(2)

Unnamed: 0,article,abstract,original_text
0,additive models @xcite provide an important fa...,additive models play an important role in semi...,additive models @xcite provide an important fa...
1,the leptonic decays of a charged pseudoscalar ...,"we have studied the leptonic decay @xmath0 , v...",the leptonic decays of a charged pseudoscalar ...


In [13]:
print('====================================================================')
print('')
print("TRAIN LOADER")
train_loader = torch.utils.data.DataLoader(RL_Dataset(df_train), batch_size=BATCH_SIZE, shuffle=False, drop_last = False)
print('')
print("====================================================================")
print('')
print("VALID LOADER")
valid_loader = torch.utils.data.DataLoader(RL_Dataset(df_valid), batch_size=BATCH_SIZE, shuffle=False, drop_last = False)
print('')
print("====================================================================")

# test_loader = torch.utils.data.DataLoader(RL_Dataset(df_test), batch_size=BATCH_SIZE, shuffle=False, drop_last = False)
# print('')
# print("====================================================================")


TRAIN LOADER
My_dataset __init__ received : (203037,)
Data Type : <class 'str'>


VALID LOADER
My_dataset __init__ received : (6436,)
Data Type : <class 'str'>



## [3] Model
- Policy model : LED For Conditional Generation
- Input : (token, global_attention_mask) type: tensor
-------------------------

### [3.1] Policy Model

In [14]:
# policy_model = LEDForConditionalGeneration.from_pretrained("allenai/led-large-16384-arxiv").to(device)

### [3.2] Reward Model

In [15]:
class RewardModel(nn.Module):
    def __init__(self, device="cuda"):
        super().__init__()

        self.device = device

        self.led = AutoModel.from_config(AutoConfig.from_pretrained("allenai/led-large-16384-arxiv")).get_encoder()
        self.bart = AutoModel.from_config(AutoConfig.from_pretrained("facebook/bart-large")).get_encoder()

        self.flatten = nn.Linear(1024, 1)

        self.head = nn.Sequential(
            nn.Linear(17408, 4096),
            nn.ReLU(),
            nn.Linear(4096, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, input_ids, summary_input_ids, global_attention_mask=None):
        hidden_state = self.led(input_ids, global_attention_mask=global_attention_mask).last_hidden_state
        with torch.no_grad():
            output = torch.zeros((hidden_state.shape[0], 16384, 1024)).to(self.device) # head가 fixed size만 받을 수 있으므로 0으로 padding
            output[:, :hidden_state.shape[1], :] = hidden_state

        bart_hidden_state = self.bart(summary_input_ids).last_hidden_state
        bart_output = torch.zeros((bart_hidden_state.shape[0], 1024, 1024)).to(self.device) # head가 fixed size만 받을 수 있으므로 0으로 padding
        bart_output[:, :bart_hidden_state.shape[1], :] = bart_hidden_state

        concat = torch.cat([output.repeat((summary_input_ids.shape[0], 1, 1)), bart_output], dim=1)
        concat = self.flatten(concat)
        result = self.head(concat.transpose(1, 2))

        return result.squeeze()

In [16]:
# reward_model = RewardModel().to(device)

In [17]:
def save_model_info(_model, _version="ver_1"):
    if not os.path.isdir("./RL_policy_model"):
        os.makedirs("./RL_policy_model")
    # 모델 정보 저장
    _model = _model.cpu()
    torch.save({'model_state_dict': _model.state_dict(),
                # 'optimizer_state_dict': _optimizer.state_dict(),
                # 'record_list' : {'train_loss': _train_loss, 'valid_loss': _valid_loss},
                }, f"./policy_model/RL_policy_model_{_version}.pth")  #policy_model_ver_1

    print(f"******************* Model Saved : RL_policy_model_{_version} *******************")

In [18]:
'''
def load_model_info(_file_path):
    """
    if not os.path.exists(_file_path):
        print("FATAL ERROR : model path not exist")
    model_info = torch.load(_file_path)
    print(f"******************* model_loaded FROM {_file_path} *******************")
    """
    config = LEDConfig.from_pretrained("allenai/led-large-16384-arxiv")
    model = LEDForConditionalGeneration(config)

    #model.load_state_dict(model_info['model_state_dict'])
    model.to(device)
    model.eval()

    return model
'''

'\ndef load_model_info(_file_path):\n    """\n    if not os.path.exists(_file_path):\n        print("FATAL ERROR : model path not exist")\n    model_info = torch.load(_file_path)\n    print(f"******************* model_loaded FROM {_file_path} *******************")\n    """\n    config = LEDConfig.from_pretrained("allenai/led-large-16384-arxiv")\n    model = LEDForConditionalGeneration(config)\n\n    #model.load_state_dict(model_info[\'model_state_dict\'])\n    model.to(device)\n    model.eval()\n\n    return model\n'

In [19]:
config = LEDConfig.from_pretrained("allenai/led-large-16384-arxiv")

old_policy_model = LEDForConditionalGeneration(config)
new_policy_model = LEDForConditionalGeneration(config)#LEDForConditionalGeneration.from_pretrained("allenai/led-large-16384-arxiv")

# old_policy_model = LEDForConditionalGeneration.from_pretrained("allenai/led-large-16384-arxiv").to(device)
# new_policy_model = LEDForConditionalGeneration.from_pretrained("allenai/led-large-16384-arxiv").to(device)

reward_model = RewardModel(device2)
#reward_model.load_state_dict(torch.load("reward_model_openai_2190.pth"))

In [20]:
#old_policy_model = nn.DataParallel(old_policy_model,  device_ids = [0,1])
#new_policy_model = nn.DataParallel(new_policy_model,  device_ids = [0,1])
#reward_model = nn.DataParallel(reward_model,  device_ids = [0,1])

In [21]:
old_policy_model = old_policy_model.to(device0)
new_policy_model = new_policy_model.to(device1)
reward_model = reward_model.to(device2)

# old_policy_model = old_policy_model.to(device)
# new_policy_model = new_policy_model.to(device)
# reward_model = reward_model.to(device)

In [22]:
old_policy_model.device, new_policy_model.device, reward_model.device

(device(type='cuda', index=0), device(type='cuda', index=1), 'cuda:0')

In [23]:
!nvidia-smi

Thu Jun 22 21:01:05 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.56.06    Driver Version: 520.56.06    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:81:00.0 Off |                    0 |
| N/A   37C    P0    63W / 300W |   4365MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  On   | 00000000:A1:00.0 Off |                    0 |
| N/A   37C    P0    65W / 300W |   2323MiB / 81920MiB |      0%      Default |
|       

## [4] Reinforcement Learning Training
---

In [24]:
import nvidia_smi

torch.cuda.empty_cache()
nvidia_smi.nvmlInit()
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)

In [25]:
optimizer = optim.AdamW(new_policy_model.parameters(), lr=1e-08)

In [26]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "true"

In [27]:
dataset = load_dataset("ccdv/arxiv-summarization")

No config specified, defaulting to: arxiv-summarization/section
Found cached dataset arxiv-summarization (/root/.cache/huggingface/datasets/ccdv___arxiv-summarization/section/1.0.0/fa2c9abf4312afb8660ef8e041d576b8e3943ea96ae771bd3cd091b5798e7cc3)


  0%|          | 0/3 [00:00<?, ?it/s]

In [28]:
train_loss = []
start_time = time.time()

# device0 - old
# device1 - new
# device2 - reward

new_policy_model.train()
old_policy_model.eval()
reward_model.eval()
for index, origin in enumerate(dataset["train"]["article"]):    
  # Preparation
  origin_token = tokenizer.batch_encode_plus([origin], padding=True, return_tensors='pt').input_ids
  print(origin_token[:,:5])

  if origin_token.shape[1] > 16000:
      continue
    
  print(origin_token.device, origin_token.shape)
  origin_mask = generate_global_attention_mask(tokenizer, origin_token)
  # print("ORIGIN_TOKEN ", origin_token.size())
  # print("ORIGIN MASK ", origin_mask.size())
  # info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
  # print(f" GPU: {100 * (1 - info.free / info.total)}% used")
    
  # Old Actor output
  print("1. START", end="/")
  !nvidia-smi
  action_logits = new_policy_model(input_ids = origin_token.to(device1), decoder_input_ids = origin_token.to(device1)[:,:1000], global_attention_mask = origin_mask.to(device1)).logits
  print("2. ON", end="/")
  action_prob = torch.softmax(action_logits, dim=-1).to(device0)
    
  print("3 ON", end="/")
  # Actor output
  with torch.no_grad():
      old_action_logits = old_policy_model(input_ids = origin_token.to(device0), decoder_input_ids = origin_token.to(device0)[:,:1000], global_attention_mask = origin_mask.to(device0)).logits
  print("4 ON", end="/")
  old_action_prob = torch.softmax(old_action_logits, dim=-1)

    
  print("5 ON", end="/")
  sum_token = torch.argmax(action_prob, dim=-1 ).cpu()

    
  print("6 ON")  
  print(origin_token.shape, sum_token.shape, origin_mask.shape)
  print("6 ON")  
  print(torch.max(origin_token), torch.max(sum_token), torch.max(origin_mask))
  reward = reward_model(origin_token.to(device2), sum_token.to(device2), origin_mask.to(device2))
    
  # Dimension Padding
  if old_action_prob.size(1) != action_prob.size(1):
    print("7 ON", end="/")
    max_dim = max(old_action_prob.size(1), action_prob.size(1))
    padding_size = abs(action_prob.size(1) - old_action_prob.size(1))

    padding_token = tokenizer.pad_token_id

    padding_token_matrix = torch.full((action_prob.size(0), padding_size, 50265), padding_token)
    action_prob_matrix = torch.zeros((action_prob.size(0), padding_size, 50265))
    
    action_prob_matrix[ : , :, padding_token] = 1
    padding_token_matrix[..., :] = action_prob_matrix

    if old_action_prob.size(1) < max_dim:
      old_action_prob = torch.cat([old_action_prob, padding_token_matrix], dim=1)
    else:
      action_prob = torch.cat([action_prob, padding_token_matrix], dim=1)

    
  # Loss Calculation
  action_log_prob = torch.log(action_prob + 1e-8)
    
  old_action_log_prob = torch.log(old_action_prob + 1e-8)
  print("8 ON", end="/")
  KL_divergence = F.kl_div(input = old_action_log_prob,
                           target = action_log_prob,
                           reduction='mean',  # 'none' | 'batchmean' | 'sum' | 'mean'
                           log_target=True)

  print("9 ON", end="/")
  R = reward + KL_divergence

  ratio = (action_log_prob - old_action_log_prob).exp()

  surr1 = ratio * R

  surr2 = (torch.clamp(ratio, 0.8, 1.2)* R)

  loss = -torch.min(surr1, surr2).mean()

  optimizer.zero_grad()
  print("10 ON", end="/")
  loss.backward()
  print("11 ON", end="/")
  optimizer.step()
  
  train_loss.append(loss.item())
  print("LOSS : ", loss.item())
  del loss
  del action_logits
  del action_prob
  del old_action_logits
  del old_action_prob
  del sum_token
  del reward
  del action_log_prob
  del old_action_log_prob
  del KL_divergence
  del R
  del ratio
  del surr1
  del surr2
  """
  if (index+1)%1 == 0:
    mean = sum(train_loss)/len(train_loss)
    info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    end_time = time.time()
    print("=============================================================================================================================")
    print(f" Index: {index+1} / {len(train_loader)} \t LOSS: {mean :.4f} \t \n GPU: {100 * (1 - info.free / info.total) :.4f}% used \t Elapsed Time : {end_time-start_time :.2f}secs \n")
    train_loss = []
  
  if (index+1)%10 == 0: 
    save_model_info(new_policy_model, f"ver_{(index+1)//300}")
    new_policy_model.to(device)
  """
  

tensor([[    0,  4917, 15589,  3092,   787]])
cpu torch.Size([1, 6791])
1. START/Thu Jun 22 21:01:12 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.56.06    Driver Version: 520.56.06    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:81:00.0 Off |                    0 |
| N/A   37C    P0    66W / 300W |   4365MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  On   | 00000000:A1:00.0 Off |                    0 |
| N/A  



11 ON/LOSS :  0.03374732285737991
tensor([[    0,   627,  2084,  3320, 10003]])
cpu torch.Size([1, 4830])
1. START/Thu Jun 22 21:01:17 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.56.06    Driver Version: 520.56.06    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:81:00.0 Off |                    0 |
| N/A   37C    P0    64W / 300W |  28351MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  On   | 00000000:A1:00.0 Of

Token indices sequence length is longer than the specified maximum sequence length for this model (21321 > 16384). Running this sequence through the model will result in indexing errors


11 ON/LOSS :  0.03502907603979111
tensor([[    0,  5605,  9779, 26683,  1635]])
tensor([[    0, 26302,   918,     9, 16426]])
cpu torch.Size([1, 2919])
1. START/Thu Jun 22 21:01:53 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.56.06    Driver Version: 520.56.06    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:81:00.0 Off |                    0 |
| N/A   39C    P0    68W / 300W |  65025MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  

UnboundLocalError: local variable 'child' referenced before assignment

In [None]:
print(next(reward_model.parameters()).device)

In [None]:
print(

##### 모델 불러오기

In [None]:
# def save_model_info(self, _model, _version="ver_1"):
#     if not os.path.isdir("./policy_model"):
#         os.makedirs("./policy_model")
#     # 모델 정보 저장
#     _model = _model.cpu()
#     torch.save({'model_state_dict': _model.state_dict(),
#                 }, f"./policy_model/policy_model_{_version}.pth")  #policy_model_ver_1

#     print(f"******************* Model Saved : policy_model_{_version} *******************")

## [5] Memory 채우기

In [None]:
# old_action_prob = None

# optimizer = optim.AdamW(policy_model.parameters(), lr=LEARNING_RATE)

# action_prob_list = []

# for index, (origin, sum) in enumerate(train_loader):

#   # Preparation
#   origin_token = tokenizer.batch_encode_plus(origin, padding=True, return_tensors='pt').input_ids.to(device)

#   origin_mask = generate_global_attention_mask(tokenizer, origin_token).to(device)

#   # Action
#   action_logits = policy_model(input_ids = origin_token, global_attention_mask = origin_mask).logits # , labels = sum_token)

#   action_prob = torch.softmax(action_logits, dim=-1)

#   action_prob = action_prob.squeeze(0).detach().cpu().numpy()
#   action_prob_list.append(action_prob)

#   if (index+1)%200 ==0:
#     print("200")

In [None]:
# import pickle

# # 출력 결과를 pickle로 저장
# with open('./old_action_200.pickle', 'wb') as f:
#     pickle.dump(action_prob_list, f)

# # # 불러온 출력 결과 확인
# # print(loaded_output)

-----------------------------

In [None]:
class RL_Trainer:
    def __init__(self, _epoch= EPOCH, _policy_model = policy_model, _reward_model = reward_model, _tokenizer = tokenizer, _lr = LEARNING_RATE, _train_loader = train_loader, _valid_loader = valid_loader):

        # Training 관련
        self.policy_model = _policy_model
        self.reward_model = _reward_model

        self.tokenizer = _tokenizer

        self.epoch = _epoch

        self.optimizer = optim.AdamW(self.policy_model.parameters(), lr=_lr)
        ## CHECK: Reward Optimizer 있어야 할 듯 ?

        self.train_loader = _train_loader
        self.valid_loader = _valid_loader

    # 모델 저장
    def save_model_info(self, _model, _optimizer, _train_loss = [], _valid_loss= [], _version="ver_1"):
        if not os.path.isdir("./policy_model"):
            os.makedirs("./policy_model")
        # 모델 정보 저장
        torch.save({'model_state_dict': _model.state_dict(),
                    'optimizer_state_dict': _optimizer.state_dict(),
                    'record_list' : {'train_loss': _train_loss, 'valid_loss': _valid_loss},
                    }, f"./policy_model/policy_model_{_version}.pth")  #reward_model_ver_1

        print(f"model_saved : policy_model_{_version}")



    ''' ======================================= 매   우   중   요 ============================================================='''
    # 모델 Training
    def train(self):

        ## 초기화
        train_loss_list = []
        valid_loss_list = []

        record_train_loss = []
        record_valid_loss = []

        # Optimizer & Loss function
        optimizer = self.optimizer

        # Data Loader
        train_loader = self.train_loader
        valid_loader= self.valid_loader

        # 모델 정의
        model = self.policy_model
        tokenizer = self.tokenizer

        # Hyper Parameter
        epoch = self.epoch

        for i in range(epoch):
            start_time = time.time()

            model.train()

            for index, (original_text, sum_text) in enumerate(train_loader):

                original_token = tokenizer.batch_encode_plus(original_text, padding=True, return_tensors='pt').input_ids
                sum_token = tokenizer.batch_encode_plus(sum_text, padding=True, return_tensors='pt').input_ids

                original_attention_mask = generate_global_attention_mask(tokenizer, original_token)

                original_token = original_token.to(device)
                original_attention_mask = original_attention_mask.to(device)
                sum_token = sum_token.to(device)

                output = model(input_ids = original_token,
                               global_attention_mask = original_attention_mask,
                               labels = sum_token)

                # Log Sigmoid
                loss = output[0]

                optimizer.zero_grad()

                loss.backward()

                optimizer.step()

                end_time = time.time()
                train_loss_list.append(loss.item())

#                 print(f"========================{index+1}===========================")
#                 print(f"Loss : {loss.item()}")

                if (index+1)%200 == 0:
                    # Validation loss 계산
                    valid_loss_list = []

                    ##### 이거 validatoin 끝나고 model.train() 있는지 꼭꼭 확인 #####
                    model.eval()

                    with torch.no_grad():
                        for valid_index, (valid_original_text, valid_sum_text) in enumerate(valid_loader):

                            original_token = tokenizer.batch_encode_plus(valid_original_text, padding=True, return_tensors='pt').input_ids
                            sum_token = tokenizer.batch_encode_plus(valid_sum_text, padding=True, return_tensors='pt').input_ids

                            original_attention_mask = generate_global_attention_mask(tokenizer, original_token)

                            original_token = original_token.to(device)
                            sum_token = sum_token.to(device)
                            original_attention_mask= original_attention_mask.to(device)

                            valid_output = model(input_ids = original_token,
                                                 global_attention_mask = original_attention_mask,
                                                 labels = sum_token)

                            valid_loss = valid_output[0]
                            valid_loss_list.append(valid_loss.item())

                            if (valid_index+1)%20 == 0:
                                break;

                    model.train()

                    train_loss_mean = sum(train_loss_list) / len(train_loss_list)
                    valid_loss_mean = sum(valid_loss_list) / len(valid_loss_list)

                    print("==================================================================================")
                    print(f"Batch {(index+1)}  ({((index+1)/len(train_loader))*100 :.3f} %) \t \
                            Train Loss : {train_loss_mean :.4f} \t \
                            Valid Loss : {valid_loss_mean :.4f} \t \
                            Elapsed Time: {(end_time - start_time) :.2f} sec")

                    train_loss_list = []
                    record_train_loss.append(train_loss_mean)
                    record_valid_loss.append(valid_loss_mean)

                if (index+1)%1000 == 0:
                    self.save_model_info(model, optimizer, record_train_loss, record_valid_loss, f"ver_{(index+1)//1000}")

        return model, record_train_loss, record_valid_loss





In [None]:
policy_trainer = RL_Trainer(EPOCH,
                                policy_model,
                                tokenizer,
                                LEARNING_RATE,
                                train_loader,
                                valid_loader)


In [None]:
model, record_train_loss, record_valid_loss = policy_trainer.train()

----------------------