In [1]:
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, GenerationConfig, DataCollatorWithPadding
from datasets import load_dataset
from peft import PeftModel, PeftConfig, LoraConfig, TaskType
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler
import torch
import evaluate
import numpy as np
import pandas as pd
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset
from huggingface_hub import login


data = load_dataset("kapilrk04/codemix-en_enhi",split="train",use_auth_token=True,cache_dir="codemix-dataset")
data = data.train_test_split(test_size=0.2,shuffle=True)
print(data)

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home2/likhithasapu/.cache/huggingface/token
Login successful




DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 1003028
    })
    test: Dataset({
        features: ['id', 'translation'],
        num_rows: 250757
    })
})


Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home2/likhithasapu/.cache/huggingface/token
Login successful


In [4]:
# Load tokenizers and models
translation_model_name = "likhithasapu/codemix-indicbart"
acceptability_model_name = "likhithasapu/indic-bert-regression-v1"
tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)
acceptability_model = AutoModelForSequenceClassification.from_pretrained(acceptability_model_name)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# Preprocess function
def preprocess_function(examples):
    prompts = [f"Translate the English sentence to Hindi-English sentence: <s> {example['en']} </s>" for example in examples['translation']]
    responses = [f"<s> {example['en-hi']} </s>" for example in examples['translation']]
    
    model_inputs = tokenizer(prompts, truncation=True, padding=True, max_length=256)
    labels = tokenizer(responses, truncation=True, padding=True, max_length=256)["input_ids"]
    
    model_inputs["labels"] = labels
    return model_inputs

In [6]:

# Apply the preprocess function to the dataset with batching
batch_size = 64
processed_data = data.map(
    preprocess_function,
    batched=True,
    batch_size=batch_size,
    remove_columns=data["train"].column_names
)


Map:   0%|          | 0/1003028 [00:00<?, ? examples/s]

Map: 100%|██████████| 1003028/1003028 [02:39<00:00, 6277.20 examples/s]
Map: 100%|██████████| 250757/250757 [00:40<00:00, 6148.58 examples/s]


In [7]:
# Define PPO configuration
ppo_config = PPOConfig(
    model_name=translation_model_name,    
    learning_rate=1.41e-5,
    ppo_epochs=1,
    mini_batch_size=2,
    batch_size=8
)

In [8]:
# Define reward function
def compute_reward(predictions):
    inputs = tokenizer(predictions, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(acceptability_model.device) for k, v in inputs.items()}
    with torch.no_grad():
        scores = acceptability_model(**inputs).logits.squeeze().tolist()
        
    return scores

In [9]:
from torch.utils.data import DataLoader
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

train_dataloader = DataLoader(processed_data["train"], batch_size=8, collate_fn=data_collator)
eval_dataloader = DataLoader(processed_data["test"], batch_size=8, collate_fn=data_collator)

In [10]:
# Initialize PPO model
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(translation_model)
ref_model = create_reference_model(ppo_model)
ppo_trainer = PPOTrainer(config=ppo_config, model=ppo_model, ref_model=ref_model, tokenizer=tokenizer, dataset=processed_data["train"], data_collator=data_collator)

Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [12]:
# Training loop
max_ppo_steps = 10000
# Set up length sampler and generation configurations
output_min_length = 10
output_max_length = 100


# Training loop
for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    if step >= max_ppo_steps:
        break   

    prompt_tensors = [tensor for tensor in batch["input_ids"]]
    attention_mask = [tensor for tensor in batch["attention_mask"]]

    # Get response from the model
    summary_tensors = []

    for prompt_tensor,attention_tensor in zip(prompt_tensors,attention_mask):        
        summary = ppo_model.generate(input_ids = prompt_tensor.unsqueeze(0).to("cuda"), attention_mask = attention_tensor.unsqueeze(0).to("cuda"), max_new_tokens=output_max_length, num_beams=5, no_repeat_ngram_size=3, do_sample=True, num_return_sequences=1)
        summary_tensors.append(summary.squeeze()[-output_max_length:])
        
    batch["response"] = [tokenizer.decode(r.squeeze(), skip_special_tokens=True) for r in summary_tensors]

    # Compute reward outputs.
    query_response_pairs = [tokenizer.decode(q, skip_special_tokens=True) + r for q, r in zip(batch["input_ids"], batch["response"])]    
    rewards = [torch.tensor(reward) for reward in compute_reward(query_response_pairs)]
    print(rewards)
    # Run PPO step.
    
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

0it [00:00, ?it/s]

[tensor(3.1194), tensor(3.1821), tensor(3.1052), tensor(3.1015), tensor(3.0763), tensor(3.1072), tensor(3.0079), tensor(3.0481)]


1it [00:05,  5.14s/it]

[tensor(3.4473), tensor(3.1617), tensor(3.0443), tensor(3.0003), tensor(3.1177), tensor(3.1408), tensor(3.1187), tensor(3.0363)]


2it [00:09,  4.95s/it]

[tensor(3.0271), tensor(2.9640), tensor(2.9981), tensor(3.2500), tensor(3.0266), tensor(3.1734), tensor(3.1685), tensor(3.1751)]


3it [00:16,  5.58s/it]

[tensor(2.9837), tensor(3.0993), tensor(2.9794), tensor(2.9580), tensor(3.0001), tensor(3.4215), tensor(2.8588), tensor(3.0236)]


4it [00:22,  5.64s/it]

[tensor(2.9908), tensor(3.1220), tensor(3.0971), tensor(2.9426), tensor(2.9960), tensor(3.2007), tensor(2.9821), tensor(2.9597)]


5it [00:28,  5.85s/it]

[tensor(3.0827), tensor(3.0198), tensor(2.9308), tensor(3.2331), tensor(3.0502), tensor(3.0070), tensor(3.0140), tensor(3.0698)]


6it [00:34,  5.96s/it]

[tensor(3.2477), tensor(3.2156), tensor(2.9408), tensor(3.3888), tensor(3.1056), tensor(2.9571), tensor(3.4194), tensor(3.2556)]


7it [00:41,  6.48s/it]

[tensor(3.0938), tensor(3.1014), tensor(2.9085), tensor(3.1620), tensor(3.3297), tensor(3.2002), tensor(2.9234), tensor(3.1501)]


8it [00:50,  7.16s/it]

[tensor(3.1071), tensor(3.1838), tensor(3.1378), tensor(3.2612), tensor(3.0304), tensor(3.3437), tensor(3.2392), tensor(3.1272)]


9it [00:56,  6.81s/it]

[tensor(3.1148), tensor(3.0586), tensor(3.0680), tensor(3.1916), tensor(3.1167), tensor(3.1355), tensor(2.9137), tensor(3.0921)]


10it [01:03,  6.90s/it]

[tensor(3.0466), tensor(3.1549), tensor(3.0523), tensor(3.1326), tensor(3.3409), tensor(3.0271), tensor(3.1874), tensor(3.2121)]


11it [01:11,  7.19s/it]

[tensor(2.9791), tensor(3.0006), tensor(2.9567), tensor(3.3520), tensor(3.3087), tensor(3.0773), tensor(2.9697), tensor(3.1030)]


12it [01:19,  7.48s/it]

[tensor(3.1096), tensor(3.2876), tensor(3.3587), tensor(3.4035), tensor(3.5088), tensor(3.0374), tensor(3.2350), tensor(3.2014)]


13it [01:25,  7.02s/it]

[tensor(2.9840), tensor(2.9042), tensor(3.0774), tensor(3.1028), tensor(3.0903), tensor(3.4317), tensor(3.0658), tensor(3.0225)]


14it [01:34,  7.71s/it]

[tensor(2.9242), tensor(3.3597), tensor(3.2713), tensor(2.9748), tensor(3.2112), tensor(3.3000), tensor(3.1779), tensor(3.1415)]


15it [01:42,  7.80s/it]

[tensor(3.2104), tensor(3.1485), tensor(3.0929), tensor(2.9943), tensor(2.9972), tensor(3.0719), tensor(3.1454), tensor(3.0253)]


16it [01:50,  7.57s/it]

[tensor(3.2045), tensor(3.1871), tensor(3.0353), tensor(2.9451), tensor(3.1337), tensor(3.1867), tensor(2.9902), tensor(3.1255)]


17it [01:58,  7.74s/it]

[tensor(3.3264), tensor(3.0703), tensor(2.9121), tensor(3.1824), tensor(3.0180), tensor(2.9473), tensor(3.2299), tensor(3.1021)]


18it [02:05,  7.50s/it]

[tensor(3.0330), tensor(3.2176), tensor(3.0820), tensor(3.2160), tensor(2.9877), tensor(3.0193), tensor(2.8865), tensor(2.9971)]


19it [02:11,  7.23s/it]

[tensor(3.0032), tensor(3.1261), tensor(3.1557), tensor(3.0696), tensor(2.9865), tensor(3.1301), tensor(3.0872), tensor(3.0986)]


20it [02:19,  7.33s/it]

[tensor(3.4305), tensor(3.0113), tensor(3.0106), tensor(3.1369), tensor(2.9468), tensor(3.3672), tensor(3.1441), tensor(3.0082)]


21it [02:28,  7.86s/it]

[tensor(3.0272), tensor(3.0087), tensor(3.2467), tensor(3.0810), tensor(2.7429), tensor(3.2359), tensor(3.2044), tensor(3.1598)]


22it [02:36,  7.84s/it]

[tensor(3.0821), tensor(3.0572), tensor(3.1914), tensor(3.1519), tensor(3.2178), tensor(3.2035), tensor(3.2776), tensor(3.0350)]


23it [02:43,  7.77s/it]

[tensor(3.1683), tensor(3.3021), tensor(2.9545), tensor(3.0296), tensor(3.0098), tensor(3.1960), tensor(3.0943), tensor(3.0369)]


24it [02:50,  7.46s/it]

[tensor(3.0867), tensor(3.1141), tensor(3.1626), tensor(3.2045), tensor(2.9264), tensor(3.3164), tensor(3.2027), tensor(3.2596)]


25it [02:59,  7.86s/it]

[tensor(2.9955), tensor(2.9819), tensor(3.1767), tensor(3.1659), tensor(3.3106), tensor(3.2230), tensor(3.1631), tensor(3.2228)]


26it [03:06,  7.73s/it]

[tensor(3.3302), tensor(3.3460), tensor(2.9360), tensor(3.2255), tensor(3.1143), tensor(3.2441), tensor(3.2350), tensor(3.0052)]


27it [03:14,  7.68s/it]

[tensor(3.1625), tensor(2.9414), tensor(3.1652), tensor(3.0344), tensor(3.2098), tensor(3.1870), tensor(3.0555), tensor(2.8255)]


28it [03:20,  7.31s/it]

[tensor(3.1185), tensor(3.2130), tensor(3.1820), tensor(3.2705), tensor(2.5936), tensor(2.8630), tensor(2.9474), tensor(3.1307)]


29it [03:29,  7.79s/it]

[tensor(3.1720), tensor(3.0021), tensor(3.2459), tensor(3.1706), tensor(3.0989), tensor(3.2096), tensor(3.1912), tensor(2.9079)]


30it [03:36,  7.55s/it]

[tensor(3.1913), tensor(3.1889), tensor(3.1621), tensor(3.2937), tensor(3.0495), tensor(3.4647), tensor(3.1459), tensor(3.2033)]


31it [03:45,  7.99s/it]

[tensor(2.9687), tensor(3.1414), tensor(3.1431), tensor(3.1275), tensor(3.1719), tensor(3.1053), tensor(3.1081), tensor(3.0109)]


32it [03:52,  7.77s/it]

[tensor(3.2501), tensor(2.8006), tensor(3.0881), tensor(3.1044), tensor(2.9194), tensor(3.0881), tensor(3.1650), tensor(3.2738)]


33it [03:57,  6.89s/it]

[tensor(3.1914), tensor(3.0695), tensor(2.9500), tensor(3.1225), tensor(3.1484), tensor(3.0404), tensor(3.1230), tensor(3.3775)]


34it [04:05,  7.16s/it]

[tensor(3.1349), tensor(3.2519), tensor(3.3901), tensor(3.2612), tensor(3.0157), tensor(3.2349), tensor(3.1829), tensor(3.3470)]


35it [04:13,  7.32s/it]

[tensor(3.1361), tensor(3.2173), tensor(3.1113), tensor(3.2112), tensor(3.2348), tensor(3.3040), tensor(3.0975), tensor(3.2109)]


36it [04:20,  7.42s/it]

[tensor(2.8794), tensor(3.2186), tensor(3.1894), tensor(3.3592), tensor(3.0173), tensor(3.0305), tensor(2.9925), tensor(2.9922)]


37it [04:28,  7.62s/it]

[tensor(3.1057), tensor(3.1061), tensor(3.2642), tensor(3.3033), tensor(3.1804), tensor(3.0737), tensor(3.1385), tensor(3.0834)]


38it [04:38,  8.05s/it]

[tensor(3.1393), tensor(3.1817), tensor(3.0282), tensor(3.4872), tensor(2.9599), tensor(2.8761), tensor(3.0077), tensor(2.8921)]


39it [04:46,  8.15s/it]

[tensor(3.0692), tensor(3.4592), tensor(3.0078), tensor(3.1262), tensor(3.0959), tensor(2.9706), tensor(3.0055), tensor(3.1410)]


40it [04:53,  7.85s/it]

[tensor(3.3328), tensor(3.3725), tensor(3.0540), tensor(3.0836), tensor(3.2419), tensor(3.1534), tensor(3.2795), tensor(3.2204)]


41it [05:01,  7.93s/it]

[tensor(3.0440), tensor(3.1097), tensor(3.0259), tensor(3.0221), tensor(3.0712), tensor(3.0458), tensor(3.3643), tensor(3.1822)]


42it [05:08,  7.50s/it]

[tensor(3.4327), tensor(3.0273), tensor(2.9522), tensor(3.3480), tensor(3.1089), tensor(3.0266), tensor(3.1141), tensor(3.2476)]


43it [05:14,  7.25s/it]

[tensor(3.0219), tensor(3.2756), tensor(3.0309), tensor(3.0124), tensor(3.2117), tensor(3.0299), tensor(3.0410), tensor(3.1339)]


44it [05:20,  6.82s/it]

[tensor(3.1692), tensor(3.2532), tensor(3.2115), tensor(2.9819), tensor(3.1936), tensor(3.2703), tensor(3.1906), tensor(3.2119)]


45it [05:26,  6.67s/it]

[tensor(3.0350), tensor(3.1820), tensor(3.2735), tensor(3.3045), tensor(3.0694), tensor(3.0662), tensor(3.0713), tensor(3.2098)]


46it [05:33,  6.70s/it]

[tensor(3.1634), tensor(3.0788), tensor(3.1461), tensor(3.1030), tensor(3.0918), tensor(2.8193), tensor(3.0484), tensor(3.1036)]


47it [05:42,  7.44s/it]

[tensor(3.0524), tensor(3.1506), tensor(3.2247), tensor(3.1866), tensor(3.0669), tensor(3.3244), tensor(3.3876), tensor(3.2517)]


48it [05:51,  7.64s/it]

[tensor(3.2038), tensor(3.0283), tensor(3.0267), tensor(3.3951), tensor(3.1963), tensor(3.1229), tensor(3.1568), tensor(3.2953)]


49it [05:59,  7.84s/it]

[tensor(3.0848), tensor(3.0419), tensor(3.3791), tensor(3.1156), tensor(3.1436), tensor(3.2695), tensor(3.3465), tensor(3.0181)]


50it [06:08,  8.34s/it]

[tensor(3.2377), tensor(3.1838), tensor(3.0798), tensor(3.1216), tensor(3.1222), tensor(3.1213), tensor(3.3460), tensor(3.2444)]


51it [06:16,  8.28s/it]

[tensor(3.2254), tensor(3.0921), tensor(3.2476), tensor(3.2570), tensor(2.9627), tensor(3.1902), tensor(3.3104), tensor(2.9985)]


52it [06:24,  8.20s/it]

[tensor(3.0598), tensor(3.0834), tensor(3.0813), tensor(3.3721), tensor(3.1497), tensor(2.9957), tensor(3.2832), tensor(3.0745)]


53it [06:34,  8.55s/it]

[tensor(3.1869), tensor(3.1772), tensor(3.1709), tensor(3.0377), tensor(3.0546), tensor(3.1399), tensor(3.1438), tensor(3.2487)]


54it [06:43,  8.62s/it]

[tensor(3.0812), tensor(3.0136), tensor(3.1878), tensor(3.0549), tensor(2.8550), tensor(3.2819), tensor(3.2835), tensor(2.8938)]


55it [06:51,  8.54s/it]

[tensor(3.1476), tensor(3.0780), tensor(3.2028), tensor(3.2453), tensor(3.0792), tensor(3.0919), tensor(3.0180), tensor(2.9957)]


56it [07:00,  8.55s/it]

[tensor(3.1123), tensor(2.9809), tensor(3.1626), tensor(3.1338), tensor(3.0747), tensor(3.3051), tensor(3.0680), tensor(3.1830)]


57it [07:08,  8.50s/it]

[tensor(3.0572), tensor(2.9270), tensor(2.9268), tensor(3.2254), tensor(3.3471), tensor(3.3500), tensor(3.1124), tensor(3.0592)]


58it [07:16,  8.49s/it]

[tensor(3.2005), tensor(3.0577), tensor(3.2356), tensor(3.2176), tensor(2.9665), tensor(3.2539), tensor(3.3063), tensor(3.3493)]


59it [07:24,  8.31s/it]

[tensor(3.2341), tensor(3.2356), tensor(3.1506), tensor(3.0403), tensor(3.1360), tensor(3.1709), tensor(3.3870), tensor(3.2567)]


60it [07:34,  8.73s/it]

[tensor(3.1454), tensor(3.1223), tensor(3.1927), tensor(3.1542), tensor(3.4025), tensor(3.3101), tensor(3.1526), tensor(3.1069)]


61it [07:41,  8.33s/it]

[tensor(2.9640), tensor(3.1454), tensor(3.1018), tensor(3.0588), tensor(3.2406), tensor(3.1806), tensor(3.0838), tensor(3.0482)]


62it [07:50,  8.56s/it]

[tensor(3.0138), tensor(3.1088), tensor(3.1884), tensor(3.0974), tensor(2.9439), tensor(3.0578), tensor(3.1421), tensor(3.2517)]


63it [07:59,  8.51s/it]

[tensor(3.2301), tensor(3.1595), tensor(3.1803), tensor(3.2743), tensor(3.0003), tensor(2.9408), tensor(3.1379), tensor(3.2726)]


64it [08:08,  8.70s/it]

[tensor(3.0769), tensor(3.2116), tensor(3.1654), tensor(2.9217), tensor(3.2286), tensor(2.9884), tensor(3.1738), tensor(3.2081)]


65it [08:18,  8.94s/it]

[tensor(3.2325), tensor(3.4040), tensor(3.2511), tensor(3.4508), tensor(3.0683), tensor(3.3937), tensor(2.8784), tensor(3.2152)]


66it [08:26,  8.95s/it]

[tensor(2.8756), tensor(3.2265), tensor(3.0295), tensor(3.2841), tensor(3.0902), tensor(3.0311), tensor(3.2199), tensor(3.3550)]


67it [08:36,  9.15s/it]

[tensor(3.2465), tensor(3.0712), tensor(3.0372), tensor(3.1755), tensor(3.1571), tensor(3.3301), tensor(3.3149), tensor(3.1643)]


68it [08:46,  9.26s/it]

[tensor(2.9668), tensor(3.0942), tensor(2.9190), tensor(3.1004), tensor(2.9458), tensor(3.0965), tensor(3.0618), tensor(3.2215)]


69it [08:55,  9.18s/it]

[tensor(3.1304), tensor(3.0935), tensor(3.2608), tensor(2.8575), tensor(3.2639), tensor(3.1494), tensor(3.0539), tensor(3.3625)]


70it [09:04,  9.26s/it]

[tensor(2.6518), tensor(3.1313), tensor(3.3727), tensor(3.1346), tensor(3.0916), tensor(3.2771), tensor(3.0203), tensor(3.1769)]


71it [09:14,  9.32s/it]

[tensor(3.0244), tensor(2.9873), tensor(3.1355), tensor(3.1795), tensor(3.2272), tensor(3.0247), tensor(2.9067), tensor(3.1666)]


72it [09:23,  9.34s/it]

[tensor(3.2050), tensor(3.1113), tensor(3.1539), tensor(3.1560), tensor(2.9861), tensor(2.4981), tensor(3.0662), tensor(2.9319)]


73it [09:32,  9.41s/it]

[tensor(2.8493), tensor(3.0444), tensor(2.9559), tensor(3.3343), tensor(3.0030), tensor(2.9450), tensor(3.2399), tensor(3.2602)]


74it [09:41,  9.23s/it]

[tensor(2.9807), tensor(3.0857), tensor(3.1151), tensor(3.2129), tensor(3.3844), tensor(3.1186), tensor(3.5220), tensor(3.2500)]


75it [09:50,  9.11s/it]

[tensor(3.3377), tensor(3.2271), tensor(2.9094), tensor(3.1720), tensor(3.1636), tensor(3.3947), tensor(3.3026), tensor(3.2604)]


76it [10:00,  9.26s/it]

[tensor(3.1259), tensor(3.0535), tensor(3.1880), tensor(3.1619), tensor(3.2198), tensor(2.6805), tensor(3.2355), tensor(3.0322)]


77it [10:09,  9.25s/it]

[tensor(3.0906), tensor(3.1095), tensor(3.3467), tensor(3.1776), tensor(3.0608), tensor(3.2714), tensor(2.9369), tensor(3.3335)]


78it [10:18,  9.14s/it]

[tensor(3.1611), tensor(3.2333), tensor(2.9510), tensor(3.1540), tensor(3.0777), tensor(3.2657), tensor(3.0744), tensor(3.1512)]


79it [10:27,  9.28s/it]

[tensor(3.3256), tensor(3.0280), tensor(3.0301), tensor(2.7862), tensor(3.4284), tensor(3.2865), tensor(3.0787), tensor(3.1765)]


80it [10:37,  9.36s/it]

[tensor(2.9617), tensor(3.0790), tensor(3.1831), tensor(3.0657), tensor(3.0635), tensor(3.2924), tensor(3.1713), tensor(3.2595)]


81it [10:46,  9.25s/it]

[tensor(2.9150), tensor(3.1952), tensor(3.2879), tensor(3.2035), tensor(3.4003), tensor(3.2159), tensor(3.3775), tensor(3.0316)]


82it [10:55,  9.29s/it]

[tensor(3.2449), tensor(3.3247), tensor(3.3175), tensor(3.1738), tensor(3.0517), tensor(3.1718), tensor(3.1618), tensor(3.0777)]


83it [11:05,  9.29s/it]

[tensor(3.0007), tensor(2.9919), tensor(3.2471), tensor(3.2039), tensor(3.2011), tensor(3.0712), tensor(3.1754), tensor(3.3107)]


84it [11:14,  9.34s/it]

[tensor(3.2163), tensor(3.0085), tensor(3.0951), tensor(3.1815), tensor(3.1372), tensor(3.2465), tensor(3.3919), tensor(3.0276)]


85it [11:23,  9.17s/it]

[tensor(3.1754), tensor(3.5548), tensor(3.2640), tensor(3.3583), tensor(3.0908), tensor(3.2172), tensor(3.1647), tensor(3.0439)]


86it [11:32,  9.20s/it]

[tensor(3.1408), tensor(3.3770), tensor(3.0801), tensor(3.1814), tensor(3.0238), tensor(3.1770), tensor(3.0546), tensor(3.0343)]


87it [11:42,  9.25s/it]

[tensor(3.2675), tensor(3.2224), tensor(3.2418), tensor(3.2470), tensor(3.1447), tensor(2.9445), tensor(3.0574), tensor(3.2633)]


88it [11:50,  9.05s/it]

[tensor(3.2443), tensor(3.4308), tensor(3.0244), tensor(3.2464), tensor(3.2664), tensor(3.1711), tensor(3.0784), tensor(3.3225)]


89it [11:59,  8.98s/it]

[tensor(3.3548), tensor(3.2237), tensor(3.0697), tensor(3.3562), tensor(3.0798), tensor(3.1288), tensor(3.0948), tensor(2.9896)]


90it [12:08,  9.13s/it]

[tensor(3.2890), tensor(2.9039), tensor(3.3065), tensor(3.0469), tensor(3.2033), tensor(3.2514), tensor(3.1046), tensor(3.2310)]


91it [12:18,  9.24s/it]

[tensor(3.1115), tensor(3.2557), tensor(3.3596), tensor(3.2035), tensor(3.3137), tensor(3.3066), tensor(3.3504), tensor(3.2429)]


92it [12:27,  9.21s/it]

[tensor(3.2714), tensor(3.3400), tensor(3.0485), tensor(3.2378), tensor(3.2438), tensor(3.3783), tensor(3.0003), tensor(3.0517)]


93it [12:36,  9.06s/it]

[tensor(2.9937), tensor(3.2095), tensor(3.0560), tensor(3.1157), tensor(3.2046), tensor(3.0794), tensor(2.9927), tensor(3.1056)]


94it [12:45,  9.18s/it]

[tensor(3.0565), tensor(2.8411), tensor(3.1881), tensor(3.2194), tensor(3.4738), tensor(3.1807), tensor(3.2176), tensor(2.8973)]


95it [12:54,  9.22s/it]

[tensor(2.9667), tensor(3.2491), tensor(3.1192), tensor(3.0605), tensor(3.1758), tensor(2.9683), tensor(3.1107), tensor(3.3299)]


96it [13:04,  9.25s/it]

[tensor(3.2551), tensor(2.9507), tensor(3.2847), tensor(2.9601), tensor(3.1419), tensor(3.5001), tensor(3.3367), tensor(3.3962)]


97it [13:13,  9.34s/it]

[tensor(3.3150), tensor(3.4441), tensor(3.2965), tensor(3.0019), tensor(3.3485), tensor(3.2219), tensor(3.2419), tensor(3.2947)]


98it [13:23,  9.38s/it]

[tensor(3.1435), tensor(2.9964), tensor(3.4018), tensor(3.0415), tensor(3.2813), tensor(3.2844), tensor(3.1187), tensor(3.2220)]


99it [13:32,  9.43s/it]

[tensor(3.1573), tensor(3.2862), tensor(2.9639), tensor(3.0577), tensor(3.0155), tensor(3.1020), tensor(3.1899), tensor(3.0846)]


100it [13:42,  9.47s/it]

[tensor(3.0913), tensor(3.0577), tensor(3.1046), tensor(3.3903), tensor(3.1719), tensor(2.9873), tensor(3.2847), tensor(3.2256)]


101it [13:51,  9.40s/it]

[tensor(2.9253), tensor(3.0957), tensor(3.0897), tensor(3.0657), tensor(3.1409), tensor(3.2357), tensor(3.1740), tensor(3.3151)]


102it [14:01,  9.46s/it]

[tensor(3.0676), tensor(3.1204), tensor(3.2803), tensor(3.1511), tensor(3.4769), tensor(3.2921), tensor(3.2070), tensor(3.1675)]


103it [14:10,  9.47s/it]

[tensor(2.9804), tensor(3.2413), tensor(3.0644), tensor(3.1506), tensor(3.2603), tensor(3.1771), tensor(3.1320), tensor(2.9215)]


104it [14:19,  9.22s/it]

[tensor(3.2738), tensor(3.0237), tensor(3.0194), tensor(3.0892), tensor(3.2748), tensor(3.4323), tensor(3.1997), tensor(3.5099)]


105it [14:28,  9.29s/it]

[tensor(3.2626), tensor(3.0822), tensor(3.2631), tensor(3.0806), tensor(3.1341), tensor(2.9496), tensor(3.1504), tensor(3.0984)]


106it [14:38,  9.37s/it]

[tensor(3.3300), tensor(3.2796), tensor(3.1930), tensor(3.1042), tensor(3.2419), tensor(3.4641), tensor(3.0919), tensor(3.1998)]


107it [14:47,  9.41s/it]

[tensor(3.0816), tensor(3.2636), tensor(3.1500), tensor(2.3504), tensor(3.3053), tensor(2.9996), tensor(3.3639), tensor(3.0846)]


108it [14:57,  9.42s/it]

[tensor(3.3528), tensor(3.0594), tensor(3.1352), tensor(3.1189), tensor(2.7509), tensor(2.9629), tensor(3.0899), tensor(3.3211)]


109it [15:06,  9.31s/it]

[tensor(3.3218), tensor(3.2795), tensor(3.2491), tensor(3.3606), tensor(3.1527), tensor(2.9613), tensor(3.4337), tensor(3.2191)]


110it [15:15,  9.34s/it]

[tensor(3.1784), tensor(2.9083), tensor(3.3359), tensor(3.5898), tensor(3.4653), tensor(3.2163), tensor(3.1775), tensor(3.1915)]


111it [15:25,  9.38s/it]

[tensor(3.4045), tensor(3.2339), tensor(3.2973), tensor(3.2072), tensor(3.1316), tensor(3.0413), tensor(3.3709), tensor(3.1836)]


112it [15:34,  9.42s/it]

[tensor(3.0116), tensor(2.9286), tensor(3.2939), tensor(3.0917), tensor(3.2460), tensor(2.8458), tensor(3.1776), tensor(2.9930)]


113it [15:44,  9.46s/it]

[tensor(3.0679), tensor(3.2513), tensor(3.3000), tensor(3.2215), tensor(3.1673), tensor(3.1182), tensor(3.0382), tensor(3.2546)]


114it [15:54,  9.57s/it]

[tensor(3.2671), tensor(3.1007), tensor(3.1800), tensor(3.0694), tensor(3.4102), tensor(3.3414), tensor(3.2937), tensor(3.3746)]


115it [16:03,  9.53s/it]

[tensor(3.3596), tensor(3.2748), tensor(2.9621), tensor(3.2301), tensor(3.1208), tensor(3.2055), tensor(3.0008), tensor(3.0596)]


116it [16:12,  9.44s/it]

[tensor(2.8891), tensor(3.5270), tensor(3.1313), tensor(3.1690), tensor(3.4242), tensor(3.0349), tensor(3.2250), tensor(3.1281)]


117it [16:22,  9.48s/it]

[tensor(3.2469), tensor(3.3252), tensor(3.5477), tensor(3.0491), tensor(3.2969), tensor(3.2526), tensor(2.9748), tensor(3.1050)]


118it [16:31,  9.49s/it]

[tensor(3.1609), tensor(3.2400), tensor(3.0420), tensor(3.2332), tensor(3.0754), tensor(3.2998), tensor(3.2940), tensor(3.2754)]


119it [16:41,  9.53s/it]

[tensor(3.2959), tensor(3.2122), tensor(3.2445), tensor(3.1701), tensor(3.3401), tensor(3.3068), tensor(3.2344), tensor(3.2240)]


120it [16:51,  9.51s/it]

[tensor(3.1822), tensor(3.2582), tensor(3.0495), tensor(2.9701), tensor(3.1967), tensor(3.4695), tensor(3.4941), tensor(3.0330)]


121it [17:00,  9.49s/it]

[tensor(3.2468), tensor(3.1072), tensor(3.1042), tensor(3.2674), tensor(3.2031), tensor(3.2250), tensor(3.1135), tensor(3.0792)]


122it [17:10,  9.51s/it]

[tensor(3.1805), tensor(3.1015), tensor(3.0021), tensor(3.1626), tensor(3.1179), tensor(3.2154), tensor(3.2657), tensor(3.0466)]


123it [17:19,  9.51s/it]

[tensor(3.2739), tensor(3.0588), tensor(3.2455), tensor(3.1168), tensor(3.4940), tensor(3.3891), tensor(3.1890), tensor(3.2398)]


124it [17:29,  9.53s/it]

[tensor(3.0112), tensor(3.4935), tensor(3.1544), tensor(3.2326), tensor(3.3259), tensor(3.2887), tensor(3.4499), tensor(3.1427)]


125it [17:38,  9.54s/it]

[tensor(3.2433), tensor(3.1970), tensor(3.3824), tensor(3.2222), tensor(3.1587), tensor(3.1642), tensor(3.2495), tensor(3.1959)]


126it [17:48,  9.56s/it]

[tensor(3.2478), tensor(3.0987), tensor(2.9973), tensor(3.2836), tensor(3.2795), tensor(3.3179), tensor(3.0681), tensor(3.2377)]


127it [17:57,  9.56s/it]

[tensor(3.1181), tensor(3.2205), tensor(3.3153), tensor(3.1010), tensor(3.2781), tensor(3.3897), tensor(3.3182), tensor(3.1852)]


128it [18:07,  9.54s/it]

[tensor(3.4588), tensor(3.1058), tensor(3.2978), tensor(3.2662), tensor(3.2808), tensor(3.4161), tensor(3.2144), tensor(3.2065)]


129it [18:16,  9.53s/it]

[tensor(3.1923), tensor(3.1855), tensor(3.3657), tensor(3.1259), tensor(3.1177), tensor(3.0333), tensor(3.0963), tensor(3.0428)]


130it [18:26,  9.52s/it]

[tensor(3.1122), tensor(3.4188), tensor(3.1444), tensor(3.2005), tensor(3.3978), tensor(3.3640), tensor(3.2832), tensor(3.1572)]


131it [18:35,  9.50s/it]

[tensor(3.0911), tensor(3.3426), tensor(3.4355), tensor(3.1923), tensor(3.2877), tensor(3.1930), tensor(3.3363), tensor(3.1356)]


132it [18:45,  9.52s/it]

[tensor(3.4769), tensor(3.4025), tensor(3.1384), tensor(3.1364), tensor(3.2519), tensor(3.1873), tensor(2.9677), tensor(3.3302)]


133it [18:54,  9.51s/it]

[tensor(3.4019), tensor(3.3106), tensor(3.3849), tensor(3.3712), tensor(3.3229), tensor(3.2425), tensor(3.2280), tensor(3.0916)]


134it [19:04,  9.54s/it]

[tensor(3.2216), tensor(3.3616), tensor(3.2806), tensor(3.1111), tensor(3.2026), tensor(3.3618), tensor(3.1136), tensor(3.3010)]


135it [19:13,  9.52s/it]

[tensor(3.0374), tensor(3.1389), tensor(3.2598), tensor(3.3876), tensor(3.4212), tensor(3.3483), tensor(3.3833), tensor(3.1735)]


136it [19:23,  9.49s/it]

[tensor(3.1849), tensor(3.2836), tensor(3.2748), tensor(3.4348), tensor(3.1661), tensor(3.2955), tensor(3.3266), tensor(3.1940)]


137it [19:32,  9.48s/it]

[tensor(3.2349), tensor(3.2629), tensor(3.2240), tensor(3.1123), tensor(3.2027), tensor(3.2648), tensor(3.1728), tensor(3.0117)]


138it [19:42,  9.49s/it]

[tensor(3.2280), tensor(3.2437), tensor(3.1630), tensor(3.1847), tensor(3.2461), tensor(2.9295), tensor(3.0873), tensor(2.9833)]


139it [19:51,  9.50s/it]

[tensor(3.3935), tensor(3.0087), tensor(3.1008), tensor(3.3205), tensor(3.2348), tensor(3.2817), tensor(3.2434), tensor(3.2507)]


140it [20:01,  9.52s/it]

[tensor(3.2690), tensor(3.0499), tensor(3.0802), tensor(3.4456), tensor(3.6533), tensor(3.0511), tensor(3.2365), tensor(3.0294)]


141it [20:10,  9.53s/it]

[tensor(3.1225), tensor(3.1927), tensor(3.3300), tensor(3.3422), tensor(3.1037), tensor(3.3896), tensor(3.3099), tensor(3.1410)]


142it [20:20,  9.54s/it]

[tensor(3.0669), tensor(3.1766), tensor(3.4136), tensor(3.2787), tensor(3.1536), tensor(3.3624), tensor(3.3934), tensor(3.2859)]


143it [20:30,  9.53s/it]

[tensor(3.1304), tensor(3.1009), tensor(3.2924), tensor(3.5621), tensor(3.2862), tensor(3.1227), tensor(3.2483), tensor(3.1861)]


144it [20:39,  9.56s/it]

[tensor(3.4842), tensor(3.1334), tensor(3.3605), tensor(3.3219), tensor(3.0762), tensor(3.2837), tensor(3.3540), tensor(3.3542)]


145it [20:49,  9.57s/it]

[tensor(3.2571), tensor(3.1691), tensor(3.1389), tensor(3.2112), tensor(3.2093), tensor(3.1969), tensor(3.1250), tensor(3.0439)]


146it [20:58,  9.54s/it]

[tensor(3.1246), tensor(3.2001), tensor(3.1836), tensor(3.1993), tensor(3.2231), tensor(3.0990), tensor(3.1916), tensor(3.2159)]


147it [21:08,  9.54s/it]

[tensor(3.0903), tensor(3.0227), tensor(3.1506), tensor(3.5062), tensor(3.4897), tensor(3.1080), tensor(3.2446), tensor(3.2311)]


148it [21:17,  9.55s/it]

[tensor(3.0874), tensor(3.4382), tensor(3.3833), tensor(3.1421), tensor(3.2651), tensor(3.0209), tensor(3.0837), tensor(3.3518)]


149it [21:27,  9.53s/it]

[tensor(3.2811), tensor(3.3321), tensor(3.3972), tensor(3.1866), tensor(3.5338), tensor(3.4159), tensor(3.0540), tensor(3.3158)]


150it [21:36,  9.52s/it]

[tensor(3.3377), tensor(3.0438), tensor(3.2150), tensor(3.5389), tensor(3.0792), tensor(2.9320), tensor(3.1377), tensor(3.2888)]


151it [21:46,  9.50s/it]

[tensor(3.0861), tensor(3.4388), tensor(3.2122), tensor(3.4206), tensor(3.2535), tensor(3.2523), tensor(3.3983), tensor(3.1918)]


152it [21:55,  9.51s/it]

[tensor(2.9597), tensor(3.0750), tensor(3.2063), tensor(3.2164), tensor(2.9852), tensor(3.2543), tensor(3.0956), tensor(3.1521)]


153it [22:05,  9.51s/it]

[tensor(3.2766), tensor(3.1753), tensor(3.1738), tensor(3.2688), tensor(3.5161), tensor(3.3181), tensor(2.9661), tensor(3.1614)]


154it [22:14,  9.49s/it]

[tensor(3.3585), tensor(3.2308), tensor(3.4074), tensor(3.2981), tensor(3.2209), tensor(3.2369), tensor(3.4452), tensor(3.1675)]


155it [22:24,  9.52s/it]

[tensor(3.3820), tensor(3.4126), tensor(3.3739), tensor(3.1579), tensor(3.1177), tensor(3.4436), tensor(3.2707), tensor(3.0130)]


156it [22:33,  9.51s/it]

[tensor(3.3318), tensor(3.3307), tensor(3.4304), tensor(3.5055), tensor(3.2322), tensor(3.2361), tensor(3.2587), tensor(3.2847)]


157it [22:43,  9.49s/it]

[tensor(3.3722), tensor(3.2914), tensor(3.2891), tensor(3.3627), tensor(3.0851), tensor(2.8570), tensor(3.2525), tensor(2.9825)]


158it [22:52,  9.48s/it]

[tensor(3.2120), tensor(3.3218), tensor(3.2465), tensor(3.2638), tensor(3.1782), tensor(3.3605), tensor(3.3155), tensor(3.2195)]


159it [23:02,  9.47s/it]

[tensor(3.2004), tensor(3.2079), tensor(3.2882), tensor(3.1261), tensor(3.4547), tensor(3.2580), tensor(3.0977), tensor(3.1771)]


160it [23:11,  9.47s/it]

[tensor(3.0018), tensor(3.2265), tensor(3.2916), tensor(3.2850), tensor(3.2925), tensor(3.1708), tensor(3.0972), tensor(3.0685)]


161it [23:21,  9.48s/it]

[tensor(3.4290), tensor(3.3290), tensor(3.4818), tensor(2.9948), tensor(3.1246), tensor(3.4089), tensor(3.1560), tensor(3.2053)]


162it [23:30,  9.54s/it]

[tensor(3.2574), tensor(3.2739), tensor(3.2367), tensor(3.2182), tensor(3.0601), tensor(3.1154), tensor(3.3276), tensor(3.3904)]


163it [23:40,  9.51s/it]

[tensor(3.2065), tensor(3.2590), tensor(3.0876), tensor(3.3323), tensor(3.3033), tensor(3.1366), tensor(3.3350), tensor(3.3393)]


164it [23:49,  9.53s/it]

[tensor(3.1279), tensor(3.4538), tensor(3.0480), tensor(3.1522), tensor(3.3322), tensor(3.5140), tensor(3.2733), tensor(3.5368)]


165it [24:01,  8.74s/it]


KeyboardInterrupt: 

In [None]:
# Save the model
tokenizer.push_to_hub("likhithasapu/codemix-indicbart-ppo-1000")
ppo_model.push_to_hub("likhithasapu/codemix-indicbart-ppo-1000")