In [None]:
!pip install transformers

In [None]:
!pip install sentencepiece

In [1]:
# import
import re
import json
import torch
import random
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, random_split
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer, TrainingArguments, Trainer, GPT2LMHeadModel


  from .autonotebook import tqdm as notebook_tqdm


## Dataset

In [2]:
# Dataset class
class BeautyDataset(Dataset):
    def __init__(self, txt_list, label_list, tokenizer, max_length):
        # define variables    
        self.input_ids = []
        self.attn_masks = []
        #self.labels = []
        
        # iterate through the dataset
        # truncate long content > 256
        count = 1
        for txt, label in zip(txt_list, label_list):
            txt= txt.replace("\n", "")
            txt_ = (txt[:256]) if len(txt) > 256 else txt
            if count ==1:
                print("label : ", label)
                count = 0
            # prepare the text
            prep_txt = f'<s>Content: {txt_}[SEP] Title: {label}</s>'
            # tokenize
            encodings_dict = tokenizer(prep_txt, truncation=True,
                                       max_length=max_length, padding="max_length")
            # append to list
            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
            #self.labels.append(label)

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx]


In [3]:
def read_content_title():
  import glob

  path = r'/home/khinkhant2hlaing/text-generation-research/all' # folder
  all_files = glob.glob(path + "/*.csv")

  list_ = []

  for filename in all_files:
    try:
      #print(filename)
      df_t = pd.read_csv(filename, index_col=None, header=None, skiprows=1, encoding='utf-8')
      list_.append(df_t)
    except:
      print(f"reading error in{filename}")



  frame = pd.concat(list_, axis=0, ignore_index=True)
  print(f"length of data frame: {len(frame)}")
  return frame

In [4]:
# Data load function
def load_beauty_dataset(tokenizer, random_seed = 1):
    # load dataset and sample.
    #df = pd.read_csv(file_path, encoding='ISO-8859-1', header=None)
    df = read_content_title()
    df = df[[0, 1]]
    df.columns = ['content', 'title']
    #df = df.sample(20000, random_state=1)
    
    max_length = max([len(tokenizer.encode(description)) for description in df['content']])
    print("Max length: {}".format(max_length))

    dataset = BeautyDataset(df['content'].tolist(), df['title'].tolist(), tokenizer, max_length=512)
    dataset.__getitem__(5)

    
    train_size = int(0.9 * len(dataset))
    train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])
    print(len(dataset))

    # return
    return train_dataset, val_dataset

In [5]:
from transformers import T5Tokenizer, AutoModelForCausalLM, GPT2LMHeadModel
  
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium", bos_token='<s>', eos_token='</s>', pad_token='<pad>')


model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium").cuda()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
model.resize_token_embeddings(len(tokenizer))

Embedding(32001, 1024)

In [7]:
for trial_no in range(1):
  print("Loading dataset...")
  train_dataset, val_dataset = load_beauty_dataset(tokenizer, trial_no)
  

Loading dataset...
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_796718.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_427210.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_959723.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_410930.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_899707.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_927244.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_667202.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_687498.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_761047.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_464814.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_308484.csv
reading error in/home/khinkhant2hla

reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_833496.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_580047.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_962220.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_484913.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_759824.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_657314.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_776394.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_912263.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_857602.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_541467.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_450957.csv
reading error in/home/khinkhant2hlaing/text-generation

reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_997207.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_426099.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_885282.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_742658.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_654533.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_489494.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_866750.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_472777.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_516918.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_866299.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_897646.csv
reading error in/home/khinkhant2hlaing/text-generation

reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_698902.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_558784.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_820667.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_828239.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_572339.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_947894.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_523099.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_550979.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_891784.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_978668.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_477400.csv
reading error in/home/khinkhant2hlaing/text-generation

reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_814048.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_740942.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_432373.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_871044.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_917361.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_432091.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_809322.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_687553.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_914840.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_327371.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_726665.csv
reading error in/home/khinkhant2hlaing/text-generation

reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_466609.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_986625.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_460053.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_933405.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_621981.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_283249.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_675178.csv
reading error in/home/khinkhant2hlaing/text-generation-research/all/beauty_860874.csv
length of data frame: 37366
Max length: 2404
label :  目指すは60年代の海外女優！レトロショートでおしゃれ度UP


  f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"


37366


In [None]:
print("Start training...")
training_args = TrainingArguments(output_dir=r'/home/khinkhant2hlaing/text-generation-research/gpt-2_finetune/result-content-title', 
                                num_train_epochs=6, 
                                logging_steps=5000, load_best_model_at_end=True,
                                save_strategy='steps',
                                evaluation_strategy="steps",
                                save_steps=10000,
                                per_device_train_batch_size=12, per_device_eval_batch_size=12,
                                learning_rate=0.001,
                                warmup_steps=1, weight_decay=0.0001, logging_dir='logs')


trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
          eval_dataset=val_dataset, data_collator=lambda data: {'input_ids': torch.stack([f[0] for f in data]),
                                                                'attention_mask': torch.stack([f[1] for f in data]),
                                                                'labels': torch.stack([f[0] for f in data])})
trainer.train()                                                  


***** Running training *****
  Num examples = 33629
  Num Epochs = 6
  Instantaneous batch size per device = 12
  Total train batch size (w. parallel, distributed & accumulation) = 12
  Gradient Accumulation steps = 1
  Total optimization steps = 16818
  Number of trainable parameters = 336129024


Start training...


Step,Training Loss,Validation Loss
5000,0.4485,0.41802
10000,0.3004,0.448923


***** Running Evaluation *****
  Num examples = 3737
  Batch size = 12
***** Running Evaluation *****
  Num examples = 3737
  Batch size = 12
Saving model checkpoint to /home/khinkhant2hlaing/text-generation-research/gpt-2_finetune/result-content-title/checkpoint-10000
Configuration saved in /home/khinkhant2hlaing/text-generation-research/gpt-2_finetune/result-content-title/checkpoint-10000/config.json
Model weights saved in /home/khinkhant2hlaing/text-generation-research/gpt-2_finetune/result-content-title/checkpoint-10000/pytorch_model.bin


In [None]:
import os

output_dir = '/home/khinkhant2hlaing/text-generation-research/gpt-2_finetune/result-content-title/23-11-content-title-jpt2'

# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`

model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [None]:
import shutil
model_file = r'/home/khinkhant2hlaing/text-generation-research/gpt-2_finetune/result-content-title/23-11-content-title-jpt2'

shutil.make_archive(model_file, 'zip', 'result-content-tilte')

In [None]:
model = AutoModelForCausalLM.from_pretrained("/home/khinkhant2hlaing/text-generation-research/gpt-2_finetune/result-content-title/23-11-content-title-jpt2").cuda()


In [None]:
text = '【ブローチェ　アヴェダ】また、ピンクブラウンにはメリットが多いのも魅力の一つ。自分に似合う髪色が見つからないという方にとって魅力的なメリットがたくさんあります◎【ピンクブラウンのメリット】①光に当たった時の柔らかさと透明感 ②日本人の肌に馴染みのいい色味 ③パーソナルカラーのイエベさんもブルベさんも取り入れやすい色合い ここからはそんなピンクブラウンについて、さまざまな角度からおすすめヘアスタイルを紹介していきます。光に当たった時の柔らかさと透明感日本人の肌に馴染みのいい色味パーソナルカラーのイエベさんもブルベさんも取り入れやすい色合い'
prompt = f'Content: {text}[SEP] Title:'
generated = tokenizer(f"<s> {prompt}", return_tensors="pt").input_ids.cuda()
sample_outputs = model.generate(generated, do_sample=False, top_k=50, max_length=256, top_p=0.90, 
            temperature=0, num_return_sequences=0).cuda()
pred_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
           

In [None]:
print(pred_text)

ref: https://qiita.com/m__k/items/36875fedf8ad1842b729