In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import logging
from tqdm import tqdm
import math
import argparse
import os

In [None]:
!git clone https://github.com/huggingface/transformers
!pip install transformers/
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

Cloning into 'transformers'...
remote: Enumerating objects: 141205, done.[K
remote: Counting objects: 100% (1178/1178), done.[K
remote: Compressing objects: 100% (521/521), done.[K
remote: Total 141205 (delta 685), reused 939 (delta 545), pack-reused 140027[K
Receiving objects: 100% (141205/141205), 139.29 MiB | 21.95 MiB/s, done.
Resolving deltas: 100% (105657/105657), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing ./transformers
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers==4.30.0.dev0)
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers=

In [None]:
# Run configuration, declared as command-line style options so that the
# defaults below can be overridden without editing the notebook body.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=88888)
parser.add_argument("--model_name", default="gpt2", type=str)
parser.add_argument("--max_seq_length", default=512, type=int)
parser.add_argument("--train_batch_size", default=2, type=int)
parser.add_argument("--valid_batch_size", default=2, type=int)
parser.add_argument("--num_train_epochs", default=1, type=int)
parser.add_argument("--warmup", default=0.1, type=float)
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--input_text_path", default='../content/', type=str)
# parse_known_args tolerates argv entries this parser does not declare
# (the second return value holds them and is discarded).
args, _ = parser.parse_known_args()

# Actually apply the seed: previously --seed was parsed but never used, so
# dropout during training and sampling during generation were not repeatable.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

In [None]:
# Directory that combinetext() reads the prompt/story files from (--input_text_path).
DATAPATH=args.input_text_path
def combinetext(prompt, story, datapath=None, max_story_words=300):
    """Pair each prompt line with its story line as 'prompt <sep> story'.

    Args:
        prompt: file name of the prompts file, one prompt per line.
        story: file name of the stories file, one story per line, aligned
            line-by-line with the prompts file.
        datapath: directory containing both files. Defaults to the
            module-level DATAPATH when omitted.
        max_story_words: maximum number of whitespace-separated words kept
            from each story (default 300).

    Returns:
        A list of strings, one per line pair: the right-stripped prompt,
        the literal ' <sep> ', then the truncated story with its words
        re-joined by single spaces.

    Raises:
        AssertionError: if the two files have a different number of lines.
    """
    if datapath is None:
        datapath = DATAPATH
    # Context managers guarantee both handles are closed, even on error.
    with open(os.path.join(datapath, prompt), encoding='utf8') as fp, \
            open(os.path.join(datapath, story), encoding='utf8') as fs:
        prompts = fp.readlines()
        stories = fs.readlines()
    assert len(prompts) == len(stories), "prompt and story files must have the same number of lines"
    return [
        prompt_line.rstrip() + ' <sep> ' + " ".join(story_line.split()[:max_story_words])
        for prompt_line, story_line in zip(prompts, stories)
    ]

# Do a little text cleaning around punctuation and contractions.
def cleanpunctuation(s):
    """Return `s` with detached punctuation and contraction suffixes re-attached.

    The space in front of each of ``! , . : ; ?`` is removed, then a fixed
    sequence of substring rewrites glues contraction suffixes back onto the
    preceding word, and finally the ``<newline>`` marker becomes a real
    newline character.
    """
    for mark in '!,.:;?':
        s = s.replace(' ' + mark, mark)

    # Applied strictly in this order. The second " 'm" entry is intentional:
    # it cleans up after the " ' m" rewrite, which can leave a space behind.
    rewrites = (
        (" n't", "n't"),
        (" 's", "'s"),
        (" 're", "'re"),
        (" 've", "'ve"),
        (" 'll", "'ll"),
        (" 'am", "'am"),
        (" 'm", "'m"),
        (" ' m", "'m"),
        (" 'm", "'m"),
        (" ' ve", "'ve"),
        (" ' s", "'s"),
        ("<newline>", "\n"),
    )
    for old, new in rewrites:
        s = s.replace(old, new)
    return s

# Build the cleaned "prompt <sep> story" corpora.
# Note: the training corpus is read from the valid.* files and the
# validation corpus from the test.* files.
train_text = [cleanpunctuation(example) for example in combinetext('valid.wp_source', 'valid.wp_target')]
valid_text = [cleanpunctuation(example) for example in combinetext('test.wp_source', 'test.wp_target')]

In [None]:
# Display one cleaned "prompt <sep> story" training example as a sanity check.
train_text[6]

"[ WP ] Everyone in the world has magic with various levels of mastery over it. You are extremely powerful with almost no control so you find a demon that's very weak but extremely good at controlling his powers. <sep> `` Imagine you're in a field. '' Green extends in all directions. `` You're alone, the earth is flat, and the blue sky touches the horizon. '' Blue shoots from the ground, arcing overhead. `` The sun appears, tiny in the sky. '' There's a bright light, rays casting shadow behind me. `` What color is it? '' \n \n `` Yellow. '' It burns so brightly, winking playfully. \n \n `` Good. '' She licks her chapped lips, the sound distorting my tiny sun's light. `` Look ahead of you. There's a sheep. '' Something soft and downy wanders across the green, its shadow stretching far beyond the horizon. `` What color is it? '' \n \n My brows crease. `` Uh- '' \n \n `` What color is it? '' \n \n The green wavers. Baa baa black sheep, have you any wool? `` Uh. '' Mary had a little lamb, 

In [None]:
# Load the tokenizer for the checkpoint selected with --model_name
# (default "gpt2") instead of a hard-coded name that ignored the option.
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name)
# Reuse the EOS token as the padding token so that padding=True can pad batches.
tokenizer.pad_token = tokenizer.eos_token

# Encode both corpora, padding every example to a common length and
# truncating at --max_seq_length tokens.
inputs_train = tokenizer(train_text, padding=True, truncation=True, max_length=args.max_seq_length)
inputs_valid = tokenizer(valid_text, padding=True, truncation=True, max_length=args.max_seq_length)

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [None]:
def create_labels(inputs):
    """Add a 'labels' entry to `inputs`, masking padded positions with -100.

    Args:
        inputs: mapping with 'input_ids' and 'attention_mask', each a list of
            equally long integer sequences (one per example).

    Returns:
        None. `inputs['labels']` is set in place: each label sequence is a
        copy of the corresponding input ids in which every position whose
        attention-mask value is 0 is replaced by -100.

    The mask is applied position by position, so the result is correct
    whichever side the padding is on (the previous version assumed that
    all padding sat at the end of the sequence).
    """
    inputs['labels'] = [
        [token_id if attended else -100 for token_id, attended in zip(ids, attention_mask)]
        for ids, attention_mask in zip(inputs['input_ids'], inputs['attention_mask'])
    ]
    
# Attach training labels to both encoded corpora (modifies them in place).
for encoded in (inputs_train, inputs_valid):
    create_labels(encoded)

In [None]:
# Show the three encoded fields of one training example.
for field in ('input_ids', 'attention_mask', 'labels'):
    print(inputs_train[field][6])

[58, 28993, 2361, 11075, 287, 262, 995, 468, 5536, 351, 2972, 2974, 286, 30677, 625, 340, 13, 921, 389, 4457, 3665, 351, 2048, 645, 1630, 523, 345, 1064, 257, 3222, 326, 338, 845, 4939, 475, 4457, 922, 379, 12755, 465, 5635, 13, 1279, 325, 79, 29, 7559, 18450, 345, 821, 287, 257, 2214, 13, 10148, 3469, 14582, 287, 477, 11678, 13, 7559, 921, 821, 3436, 11, 262, 4534, 318, 6228, 11, 290, 262, 4171, 6766, 18105, 262, 17810, 13, 10148, 4518, 20611, 422, 262, 2323, 11, 610, 2259, 16965, 13, 7559, 383, 4252, 3568, 11, 7009, 287, 262, 6766, 13, 10148, 1318, 338, 257, 6016, 1657, 11, 24823, 13092, 9082, 2157, 502, 13, 7559, 1867, 3124, 318, 340, 30, 10148, 220, 198, 220, 198, 7559, 12550, 13, 10148, 632, 20246, 523, 35254, 11, 266, 8040, 711, 2759, 13, 220, 198, 220, 198, 7559, 4599, 13, 10148, 1375, 300, 3378, 607, 442, 6320, 11914, 11, 262, 2128, 1233, 24707, 616, 7009, 4252, 338, 1657, 13, 7559, 6803, 4058, 286, 345, 13, 1318, 338, 257, 15900, 13, 10148, 13742, 2705, 290, 866, 88, 11569, 36

In [None]:
class StoryDataset:
    """Index-addressable view over encoded examples for use with a DataLoader.

    `inputs` must provide 'input_ids', 'attention_mask' and 'labels', each a
    sequence with one entry per example.
    """

    def __init__(self, inputs):
        self.ids = inputs['input_ids']
        self.attention_mask = inputs['attention_mask']
        self.labels = inputs['labels']

    def __len__(self):
        """Number of examples."""
        return len(self.ids)

    def __getitem__(self, item):
        """Return [input_ids, attention_mask, labels] for example `item` as int64 tensors."""
        fields = (self.ids, self.attention_mask, self.labels)
        return [torch.tensor(field[item], dtype=torch.long) for field in fields]

In [None]:
# Batch sizes come from the parsed configuration.
train_batch_size = args.train_batch_size
valid_batch_size = args.valid_batch_size

# Wrap the encoded corpora so they can be indexed example by example.
traindata = StoryDataset(inputs_train)
validdata = StoryDataset(inputs_valid)

# Both loaders iterate in dataset order (shuffling is disabled for each).
train_dataloader = torch.utils.data.DataLoader(traindata, batch_size=train_batch_size, shuffle=False)
valid_dataloader = torch.utils.data.DataLoader(validdata, batch_size=valid_batch_size, shuffle=False)

In [None]:
# Load the pretrained weights for the checkpoint selected with --model_name
# (default "gpt2") rather than a hard-coded name that ignored the option.
model = GPT2LMHeadModel.from_pretrained(args.model_name)

Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [None]:
# Report the parameter count in millions, rounded to one decimal place.
params_in_millions = np.round(model.num_parameters()/1e6, 1)
print(f"GPT2 base size: {params_in_millions} M parameters")

GPT2 base size: 124.4 M parameters


In [None]:
# Training schedule: one optimizer step per batch, for every epoch.
num_train_epochs = args.num_train_epochs
training_steps_per_epoch = len(train_dataloader)
total_num_training_steps = int(training_steps_per_epoch * num_train_epochs)
weight_decay = 0
learning_rate = args.learning_rate
adam_epsilon = 1e-8
# --warmup is the fraction of all steps spent ramping the learning rate up.
warmup_steps = int(total_num_training_steps * args.warmup)

# Parameters whose name contains one of these substrings are excluded from
# weight decay. "LayerNorm.weight" alone never matches GPT-2, whose layer-norm
# modules are named ln_1 / ln_2 / ln_f, so those names are listed explicitly.
# With weight_decay = 0 both groups behave the same; the split only matters
# once a non-zero decay is configured.
no_decay = ["bias", "LayerNorm.weight", "ln_1.weight", "ln_2.weight", "ln_f.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
# Linear warm-up for warmup_steps, then linear decay over the remaining steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_num_training_steps
)



In [None]:
# Fine-tune the model, then evaluate on the validation loader after each epoch.
print("***** Running training *****")
print("  Total_num_training_step = {}".format(total_num_training_steps))
print("  Num Epochs = {}".format(num_train_epochs))
print(f"  Train_batch_size per device = {train_batch_size}")
print(f"  Valid_batch_size per device = {valid_batch_size}")
# Use the GPU when one is present, otherwise fall back to the CPU instead of
# crashing on a hard-coded 'cuda' device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
for epoch in range(num_train_epochs):
    print(f"Start epoch{epoch+1} of {num_train_epochs}")
    train_loss = 0
    epoch_iterator = tqdm(train_dataloader, desc='Iteration')
    model.train()
    model.zero_grad()
    for input_ids, attention_mask, labels in epoch_iterator:
        output = model(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
            labels=labels.to(device),
        )
        # When labels are supplied, the first output element is the loss.
        batch_loss = output[0]
        batch_loss.backward()
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        # Read the scalar once; each .item() call forces a device sync.
        batch_loss_value = batch_loss.item()
        train_loss += batch_loss_value
        epoch_iterator.set_description('(batch loss=%g)' % batch_loss_value)
    print(f'Average train loss per example={train_loss/training_steps_per_epoch} in epoch{epoch+1}')
    print(f'Starting evaluate after epoch {epoch+1}')
    eval_losses = []
    model.eval()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for input_ids, attention_mask, labels in tqdm(valid_dataloader, desc="eval"):
            output = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                labels=labels.to(device),
            )
            eval_losses.append(output[0].item())
    # Mean of the per-batch losses; perplexity is its exponential.
    eval_loss = np.mean(eval_losses)
    perplexity = math.exp(eval_loss)
    print(f'Average valid loss per example={eval_loss} in epoch{epoch+1}')
    print(f'Perplextiy for valid dataset in epoch{epoch+1} is {perplexity}')

***** Running training *****
  Total_num_training_step = 7810
  Num Epochs = 1
  Train_batch_size per device = 2
  Valid_batch_size per device = 2
Start epoch1 of 1


(batch loss=2.72667): 100%|██████████| 7810/7810 [42:54<00:00,  3.03it/s]


Average train loss per example=3.2757100843742166 in epoch1
Starting evaluate after epoch 1


eval: 100%|██████████| 7569/7569 [12:46<00:00,  9.88it/s]

Average valid loss per example=3.186393145384696 in epoch1
Perplextiy for valid dataset in epoch1 is 24.200980415633015





In [None]:
# Split one validation example into its prompt and its reference story
# at the '<sep>' marker (5 characters long).
sample = valid_text[300]
sep_at = sample.find('<sep>')
prompt = sample[:sep_at]
target = sample[sep_at+5:]

def generate_story(prompt,target,k=0,p=0.9,output_length=300,temperature=1,num_return_sequences=3,repetition_penalty=1.0):
    """Print the prompt, the reference story, and sampled continuations of the prompt.

    Args:
        prompt: text the model is asked to continue.
        target: reference story, printed for comparison only.
        k: top-k sampling cutoff passed to `generate` as `top_k`.
        p: nucleus sampling threshold passed to `generate` as `top_p`.
        output_length: passed to `generate` as `max_length`.
        temperature: sampling temperature.
        num_return_sequences: number of continuations to sample and print.
        repetition_penalty: repetition penalty passed to `generate`.

    Returns:
        None; everything is written to stdout.

    Side effects:
        Moves the module-level `model` to the CPU and puts it in eval mode.
    """
    print("====prompt====\n")
    print(prompt+"\n")
    print('====target story is as below===\n')
    print(target+"\n")
    encoded_prompt = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    model.to('cpu')
    model.eval()
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        # A single, unpadded prompt: every position is a real token. Passing the
        # mask and pad id explicitly avoids relying on generate()'s fallback,
        # which emits a warning about unreliable results.
        attention_mask=torch.ones_like(encoded_prompt),
        pad_token_id=tokenizer.eos_token_id,
        max_length=output_length,
        temperature=temperature,
        top_k=k,
        top_p=p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        num_return_sequences=num_return_sequences
    )
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
        # Decode text
        text = tokenizer.decode(generated_sequence.tolist(), clean_up_tokenization_spaces=True)
        # Remove all text after the eos token. str.find returns -1 when the
        # token is absent; slicing with -1 would silently drop the last
        # character, so only cut when the token was actually found.
        eos_at = text.find(tokenizer.eos_token)
        if eos_at != -1:
            text = text[:eos_at]
        print(text)

# Sample stories for the selected validation prompt, using the default decoding settings.
generate_story(prompt=prompt, target=target)

====prompt====

Children's logic dictates the way the world works. [ WP ] 

====target story is as below===

 “ That ’ s not an option I ’ m currently willing to exercise. ” 
 
 I pinch the bridge of my nose to stave off the headache building behind my eyes. If this goes on much longer, I ’ m gon na have to start to start cutting back on the vegetables. 
 
 “ She ’ s dangerous, Jimmy. You know that. You ’ ve seen it. Dealt with it first hand. She just doesn ’ t play by anyone ’ s rules. ” 
 
 Ali finished off her sucker and unwrapped a fresh one, offering it to me. I declined. I ’ d sworn off the things after my third cavity scare. That one saw me at the dentist for the third time in as many months. I don ’ t care what my dad says, I know that guy is evil. Who owns a drill like that? A murderer, that ’ s who. I still hear the damn thing in my nightmares. 
 
 While she savored the smooth flavor of blue-raspberry, I pondered her words. We both knew she was right. The situation was spiral

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


=== GENERATED SEQUENCE 1 ===
Children's logic dictates the way the world works. [ WP ] 
 
 In modern times, only earth's purest water exists. That's not a problem, as earth extends only through the ocean and another ocean seep through the ground below. All the water in the world has joined the cycle of the earth and the oceans, clearing land into the ocean. However, we humans don't live in the simple world of the Earth; we live in nature. 
 
 On the first day of our creation, Earth is rounded to the center of the earth. There, we see two transparent spheres, each at the same point in the center. We go through the earth, each sphere is shining brightly, until finally there's only light. This light is just dimmer than the rest of the light, and the light fades to nothing. Earth didn't see this light, but it began to fade to nothing. This was the beginning of all life, a universe like no other, since the words `` We are all in this '' were taken from within Earth and placed into outer spa