# TRL-Simp model

This notebook implement PPO Trainer using Transformer reinforcement Learning library.
(The pipeline was adapted based on example in this Github Repo: https://github.com/lvwerra/trl)

- Batch-size: 15
- Learning rate: 5e-4
- Add customization tuning for memory efficient: Learning rate scheduler, pass SGD optimizer (https://github.com/lvwerra/trl/blob/main/docs/source/customization.mdx)
 
View project at https://wandb.ai/ml2_g10/trl

View run at https://wandb.ai/ml2_g10/trl/runs/v0dcp47v

In [1]:
# from google.colab import drive

# drive.mount('/content/drive')
# cd '/content/drive/MyDrive/1_STUDY AT TWENTE/1_ML2 project'
# !ls
# !pip install transformers

In [2]:
import os
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

In [3]:
import torch

In [4]:
# pip install trl-0.2.1.tar.gz

In [5]:
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model,AutoModelForSeq2SeqLMWithValueHead
from trl.core import respond_to_batch, LengthSampler
from transformers import AutoTokenizer, BartForConditionalGeneration, RobertaForSequenceClassification


In [7]:
!pip install bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.37.0-py3-none-any.whl (76.3 MB)
[K     |████████████████████████████████| 76.3 MB 8.4 MB/s eta 0:00:011   |▏                               | 348 kB 3.5 MB/s eta 0:00:22     |████████████████████████████████| 76.2 MB 8.4 MB/s eta 0:00:01
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.37.0


## Configuration

In [8]:
!pip install wandb -qqq

In [9]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33manhtth[0m ([33mml2_g10[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
#  import wandb
#  !wandb login --relogin --host=https://api.wandb.ai

In [11]:
config = PPOConfig(learning_rate=5e-4,
                   forward_batch_size=1,
                   batch_size=15,
                   log_with="wandb")
#config = PPOConfig(learning_rate=5e-5,forward_batch_size=1,batch_size=1)

In [12]:
tags=["<ident>", "<para>", "<ssplit>", "<dsplit>"]

clfmodel = RobertaForSequenceClassification.from_pretrained("liamcripwell/ctrl44-clf")
simpmodel = AutoModelForSeq2SeqLMWithValueHead(BartForConditionalGeneration.from_pretrained("liamcripwell/ctrl44-simp"))

tokenizerclf = AutoTokenizer.from_pretrained("liamcripwell/ctrl44-clf")
tokenizersimp = AutoTokenizer.from_pretrained("liamcripwell/ctrl44-simp")

## Dataset

In [13]:
# def pipelinemodel(text):
# #not batch
#     inputs1=[]
#     logits=[]
#     inputs2=[]
#     outputs=[]
#     simpsentence=[]
#     inputs1 = tokenizerclf(text, return_tensors="pt")

#     with torch.no_grad():
#         logits = clfmodel(**inputs1).logits.to(device)

#     predicted_class_id = logits.argmax().item()    
#     inputs2 = tokenizersimp((tags[predicted_class_id] + ' ' + text), return_tensors="pt")

#     outputs = simpmodel.generate(**inputs2, num_beams=10, max_length=128).to(device)

#     simpsentence=tokenizersimp.decode(outputs[0])

#     return inputs1, outputs,simpsentence


In [14]:
import datasets

ds_asset = datasets.load_dataset("asset", split = 'validation')
ds_asset.set_format(type="torch")

No config specified, defaulting to: asset/simplification
Found cached dataset asset (/Users/anhtth/.cache/huggingface/datasets/asset/simplification/1.0.0/a1ebd31e2a43bb6d4b5826423c73e8397d1696526af6c99f20da612f51799a8f)


## Reward Function

In [15]:
# rm -r easse
# !git clone https://github.com/feralvam/easse.git
# cd easse
# pip install -e .

In [16]:
from easse.sari import corpus_sari

corpus_sari(orig_sents=["About 95 species are currently accepted.", "The cat perched on the mat."],  
            sys_sents=["About 95 you now get in.", "Cat on mat."], 
            refs_sents=[["About 95 species are currently known.", "The cat sat on the mat."],
                        ["About 95 species are now accepted.", "The cat is on the mat."],  
                        ["95 species are now accepted.", "The cat sat."]])

33.17472563619544

In [17]:
from easse.report import get_all_scores

get_all_scores(orig_sents=["About 95 species are currently accepted.", "The cat perched on the mat."],  
            sys_sents=["About 95 you now get in.", "Cat on mat."], 
            refs_sents=[["About 95 species are currently known.", "The cat sat on the mat."],
                        ["About 95 species are now accepted.", "The cat is on the mat."],  
                        ["95 species are now accepted.", "The cat sat."]])

{'BLEU': 14.99,
 'SARI': 31.95,
 'FKGL': 0,
 'Compression ratio': 0.52,
 'Sentence splits': 1.0,
 'Levenshtein similarity': 0.52,
 'Exact copies': 0.0,
 'Additions proportion': 0.36,
 'Deletions proportion': 0.57,
 'Lexical complexity score': 8.96}

In [19]:
# from easse.samsa import corpus_samsa
# corpus_samsa(orig_sents=["About 95 species are currently accepted.", "The cat perched on the mat."],  
#             sys_sents=["About 95 you now get in.", "Cat on mat."], 
#             refs_sents=[["About 95 species are currently known.", "The cat sat on the mat."],
#                         ["About 95 species are now accepted.", "The cat is on the mat."],  
#                         ["95 species are now accepted.", "The cat sat."]])

## PPO

In [18]:
output_min_length = 16
output_max_length = 32
output_length_sampler = LengthSampler(output_min_length, output_max_length)

In [19]:
def batchpipelinemodel(text):
    inputs1 = []
    inputs2 = []
    inputs11 = []
    logits = []
    outputs = []
    simpsentences = []
    
    inputs1 = tokenizerclf(text, return_tensors="pt",padding=True)

    with torch.no_grad():
        logits = clfmodel(**inputs1).logits

    for i in range(len(text)):
        predicted_class_id = logits[i].argmax().item()    
        input2 = tags[predicted_class_id] + ' ' + text[i]
        inputs2.append(input2)     
  
    inputs22 = tokenizersimp(inputs2, return_tensors="pt",padding='longest')

    output = simpmodel.generate(**inputs22, num_beams=10, max_length=128)
    for i in range(len(text)):
         simpsentences.append(tokenizersimp.decode(output[i]))

    return inputs1, output, simpsentences

In [20]:
referes=ds_asset[0]['simplifications']
print([[j]for j in referes])

[['countries next to it are Marin, Mendocino, Lake, Napa, Solano, and Contra Costa.'], ['Nearby counties are Marin, Mendocino, Lake, Napa, and Solano and Contra Costa.'], ['Adjacent counties are Marin, Mendocino, Lake, Napa, Solano and Contra Costa.'], ['Neighboring counties are Marin, Mendocino, Lake, Napa, Solano, and Contra Costa.'], ['Adjacent counties are Marin (south), Mendocino (north), Lake (northeast), and Napa (east). Solano and Contra Costa are to the southeast.'], ['Counties next to it are Marin (to the south), Mendocino (to the north), Lake (northeast), Napa (to the east), and Solano and Contra Costa (to the southeast).'], ['Marin, Mendocino, Lake, Napa, Solano, and Contra Costa counties are next to it.'], ['Adjacent counties are Marin, Mendocino, Lake, Napa, Solano, and Contra Costa.'], ['Counties next door are Marin (south), Mendocino (north), Lake (northeast), Napa (east), and Solano and Contra Costa (southeast).'], ['Nearby counties are Marin (to the south), Mendocino 

In [25]:
# # Customization for memory efficient - Adam8bit
# import bitsandbytes as bnb
# # Create optimizer
# model=simpmodel
# model_ref=create_reference_model(model)
# optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)
# lr_scheduler = lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

In [28]:
%%wandb
#batch
wandb.init(project='trl-simp-full', config=config)

model=simpmodel
model_ref=create_reference_model(model)
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
#optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)
lr_scheduler = lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

ppo_trainer = PPOTrainer(config, model, model_ref, tokenizersimp,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler)

train_loader = torch.utils.data.DataLoader(ds_asset,
                                           shuffle=True,
                                           num_workers=1,
                                           batch_size=config.batch_size)


for epoch in train_loader:

    query_tensors = []
    response_tensors = []
    batch = {}
    rewards1=[]
    rewards=[]
    ref=[]
    for j in range(len(epoch['simplifications'])):
        
        each = [i for i in epoch['simplifications'][j]]
        ref.append(each) 

    ori=epoch['original']
    print('original sentences:')
    print(ori)
    query_tensor, response_tensor, response = batchpipelinemodel(ori)

    for i in range(config.batch_size):
        query_tensors.append(torch.tensor(query_tensor['input_ids'][i]))
    for i in range(config.batch_size):
        response_tensors.append(torch.tensor(response_tensor[i]))
    
    print('simplified sentences:')
    print(response)

    print(epoch['simplifications'][0])
    for i in range(config.batch_size):
        refs_sents = [[epoch['simplifications'][j][i]] for j in range(10)]
        orig_sents=[ori[i]]
        sys_sents=[response[i]]
        sari = get_all_scores(orig_sents, sys_sents, refs_sents)['SARI']
        bleu = get_all_scores(orig_sents, sys_sents, refs_sents)['BLEU']
        fkgl = get_all_scores(orig_sents, sys_sents, refs_sents)['FKGL']
        reward = (float(sari) + float(bleu) + float(fkgl))/3
        rewards1.append(reward)
    
    rewards = [torch.tensor(rew) for rew in rewards1]
    print(rewards)
        
    batch['query'] = epoch
    batch['response'] = response
    #### Run PPO step
    print(query_tensors)
    print(response_tensors)
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
wandb.run

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016678435516666923, max=1.0…

VBox(children=(Label(value='0.024 MB of 0.024 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016731035899999823, max=1.0…

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
original sentences:
['Sacred music While Rore is best known for his Italian madrigals, he was also a prolific composer of sacred music, both masses and motets.', 'The endpoints where a continental divide meets the coast are not always definite, because the exact border between adjacent bodies of water is usually not clearly defined.', 'This includes all natural and human accidents and errors and is calculated over roughly 160,000 Shinkansen trips completed.', 'While there he meets a young half-black boy by the name of Arthur Stuart, the son of a slave and a slave-owner who has been adopted by the owners of the local guesthouse.', 'Terrorists have seized control of Las Vegas to instill widespread panic in bot

  query_tensors.append(torch.tensor(query_tensor['input_ids'][i]))
  response_tensors.append(torch.tensor(response_tensor[i]))


simplified sentences:
['</s><s> Sacred music. While Rore is best known for his Italian madrigals, he was also a prolific composer of sacred music, both masses and motets.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '</s><s> The endpoints where a continental divide meets the coast are not always definite. This is because the exact border between adjacent bodies of water is usually not clearly defined.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '</s><s> This includes all natural and human accidents and errors. It is calculated over roughly 160,000 Shinkansen trips completed.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '</s><s> While there he meets a young half-black boy by the name of Arthur Stuart. He is the son of a slave and a slave-owner who has been adopted by the owners of the local guesthouse.</s><pad><pad>', '</s><s> Terrorists have seized control of Las Vegas to instill wid

IndexError: index 5 is out of bounds for dimension 0 with size 5

In [29]:
#ppo_trainer.state_dict()
model.save_pretrained('trl-simp-batch15', push_to_hub=False)
tokenizersimp.save_pretrained('trl-simp-batch15', push_to_hub=False)

('trl-simp-batch15/tokenizer_config.json',
 'trl-simp-batch15/special_tokens_map.json',
 'trl-simp-batch15/vocab.json',
 'trl-simp-batch15/merges.txt',
 'trl-simp-batch15/added_tokens.json',
 'trl-simp-batch15/tokenizer.json')

## Evaluation