In [1]:
# Mount Google Drive at /content/drive; the poem dataset is read from a path
# under this mount point later in the notebook.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Import Libraries

In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

# Read Text

In [3]:
# Read the whole poem corpus into memory.
# The context manager closes the file handle as soon as the read finishes, and
# the explicit UTF-8 encoding keeps the Devanagari text from depending on the
# platform's default locale encoding.
with open('/content/drive/MyDrive/poem/datasets/poem.txt', 'r', encoding='utf-8') as poem_file:
    poem = poem_file.read()

In [4]:
# Break the raw text into one entry per line and preview the first five.
# NOTE(review): str.split("\n") yields a trailing empty string when the text
# ends with a newline, so the corpus may end with one blank entry.
poem_corpus = poem.split("\n")
print(poem_corpus[:5])

['नछाडी जानोस् हे मेरा प्राण ! अकेली मलाई,', 'मनको वनमा ननिभ्ने गरी विरह जलाई !', 'ननिभ्ने गरी विरह जलाई,', 'लोचनका तारा ! हे मेर प्यारा ! यो जोति  बिलाए !', 'के भनूँ? भन्ने म केही थिइन  विष नै पिलाए !']


In [5]:
def remove_noise(sentences):
    """Strip unwanted tokens from every sentence.

    Each token below (newline, byte-order mark, ASCII digits, Devanagari
    digits and the zero-width joiner) is deleted from each sentence, in the
    listed order.

    Args:
        sentences: Iterable of strings to clean.

    Returns:
        A new list with one cleaned string per input sentence, in the same
        order. Sentences that become empty are kept as empty strings.
    """
    noise_tokens = (
        '\n', '\ufeff',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        '०', '१', '२', '३', '४', '५', '६', '७', '८', '९', '१०',
        '\u200d',
    )

    def _scrub(text):
        # Apply the removals one after another, exactly in token order.
        for token in noise_tokens:
            text = text.replace(token, '')
        return text

    return [_scrub(raw) for raw in sentences]

In [6]:
# Clean every line of the corpus with remove_noise and preview the first five
# results to compare against the raw preview.
processed_poem_corpus = remove_noise(poem_corpus)
print(processed_poem_corpus[:5])

['नछाडी जानोस् हे मेरा प्राण ! अकेली मलाई,', 'मनको वनमा ननिभ्ने गरी विरह जलाई !', 'ननिभ्ने गरी विरह जलाई,', 'लोचनका तारा ! हे मेर प्यारा ! यो जोति  बिलाए !', 'के भनूँ? भन्ने म केही थिइन  विष नै पिलाए !']


In [7]:
# Persist the cleaned corpus, one line per entry, for the fine-tuning step.
# The encoding is pinned to UTF-8 so the Devanagari text is written the same
# way regardless of the runtime's default locale encoding.
with open('processed_poem.txt', 'w', encoding='utf-8') as f:
  for line in processed_poem_corpus:
    f.write(line + '\n')

# Fine-Tune GPT-2

# Load Dataset Function

In [8]:
def load_dataset(file_path, tokenizer, block_size = 128):
    """Create a ``TextDataset`` for the given text file.

    Args:
        file_path: Path of the text file handed to ``TextDataset``.
        tokenizer: Tokenizer instance handed to ``TextDataset``.
        block_size: Block size handed to ``TextDataset`` (default 128).

    Returns:
        The constructed ``TextDataset``.
    """
    return TextDataset(
        tokenizer=tokenizer,
        file_path=file_path,
        block_size=block_size,
    )


def load_data_collator(tokenizer, mlm = False):
    """Create a ``DataCollatorForLanguageModeling`` for the tokenizer.

    Args:
        tokenizer: Tokenizer instance handed to the collator.
        mlm: Value forwarded as the collator's ``mlm`` flag (default False).

    Returns:
        The constructed ``DataCollatorForLanguageModeling``.
    """
    return DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=mlm)


# Train Function

In [9]:
def train(train_file_path,model_name,
          output_dir,
          overwrite_output_dir,
          per_device_train_batch_size,
          num_train_epochs,
          save_steps):
  """Fine-tune a GPT-2 language model on a text file and save the result.

  Args:
    train_file_path: Path of the training text file.
    model_name: Name or path given to ``from_pretrained`` for both the
      tokenizer and the model.
    output_dir: Directory that receives the tokenizer, the model and the
      trainer output.
    overwrite_output_dir: Forwarded to ``TrainingArguments``.
    per_device_train_batch_size: Forwarded to ``TrainingArguments``.
    num_train_epochs: Forwarded to ``TrainingArguments``.
    save_steps: Forwarded to ``TrainingArguments`` as ``save_steps``.
      Previously this argument was accepted but silently ignored.

  Returns:
    None. The trained model is written to ``output_dir``.
  """
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
  train_dataset = load_dataset(train_file_path, tokenizer)
  data_collator = load_data_collator(tokenizer)

  # Store the tokenizer next to the model so output_dir can be loaded on its own.
  tokenizer.save_pretrained(output_dir)

  model = GPT2LMHeadModel.from_pretrained(model_name)

  # Writes the not-yet-trained weights; trainer.save_model() below saves the
  # trained model to the same output_dir.
  model.save_pretrained(output_dir)

  training_args = TrainingArguments(
          output_dir=output_dir,
          overwrite_output_dir=overwrite_output_dir,
          per_device_train_batch_size=per_device_train_batch_size,
          num_train_epochs=num_train_epochs,
          save_steps=save_steps,
      )

  trainer = Trainer(
          model=model,
          args=training_args,
          data_collator=data_collator,
          train_dataset=train_dataset,
  )

  trainer.train()
  trainer.save_model()

In [18]:
# Training configuration: adjust these values before running the training cell.
train_file_path = "/content/processed_poem.txt"  # text file passed to train()
model_name = 'gpt2'  # checkpoint name passed to from_pretrained
output_dir = '/content/'  # where the tokenizer, model and trainer output are written
overwrite_output_dir = False  # forwarded to TrainingArguments
per_device_train_batch_size = 8  # forwarded to TrainingArguments
num_train_epochs = 10  # forwarded to TrainingArguments
save_steps = 500  # passed to train() as save_steps

In [19]:
# Run fine-tuning with the configuration defined in the previous cell.
# Training takes about 30 minutes in Colab.
train(
    train_file_path=train_file_path,
    model_name=model_name,
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    num_train_epochs=num_train_epochs,
    save_steps=save_steps
)



Step,Training Loss
500,1.6345
1000,1.4429
1500,1.3624


In [23]:
def load_model(model_path):
    """Load a ``GPT2LMHeadModel`` from ``model_path`` and return it."""
    return GPT2LMHeadModel.from_pretrained(model_path)


def load_tokenizer(tokenizer_path):
    """Load a ``GPT2Tokenizer`` from ``tokenizer_path`` and return it."""
    return GPT2Tokenizer.from_pretrained(tokenizer_path)


def generate_text(sequence, max_length, model_path="/content/"):
    """Generate a continuation of ``sequence`` and print it.

    Args:
        sequence: Prompt text; it is formatted to a string before encoding.
        max_length: Value passed to ``model.generate`` as ``max_length``.
        model_path: Directory holding the saved model and tokenizer.
            Defaults to "/content/", the location that used to be hard-coded,
            so existing two-argument calls behave as before.

    Returns:
        None. The decoded text is printed.
    """
    model = load_model(model_path)
    tokenizer = load_tokenizer(model_path)
    ids = tokenizer.encode(f'{sequence}', return_tensors='pt')
    final_outputs = model.generate(
        ids,
        do_sample=True,  # sample instead of greedy decoding
        max_length=max_length,
        # Use the end-of-sequence id from the model config as the padding id.
        pad_token_id=model.config.eos_token_id,
        top_k=50,
        top_p=0.95,
    )
    print(tokenizer.decode(final_outputs[0], skip_special_tokens=True))

In [27]:
sequence = input() # prompt text used to seed generation
max_len = int(input()) # passed to generate_text as max_length
generate_text(sequence, max_len) # prints the generated continuation

विष नै पिलाए
100
विष नै पिलाएन ।
देख्न प्राण! भन तिमी बिलाएन ।
बचाल स्वर्ग सारा च�


In [None]:
# NOTE(review): leftover scratch cell with a bare literal; it can likely be removed.
1