In [10]:
import os 
import sys
import transformers
import tensorflow as tf
from datasets import load_dataset


In [3]:
from transformers import AutoTokenizer
from transformers import TFAutoModelForSeq2SeqLM, DataCollatorForSeq2Seq


In [7]:
from transformers import AdamWeightDecay

In [8]:
model_checkpoint ='Helsinki-NLP/opus-mt-en-hi'

In [11]:
raw_datasets = load_dataset('cfilt/iitb-english-hindi')

Downloading readme: 100%|██████████| 3.14k/3.14k [00:00<00:00, 5.67MB/s]
Downloading metadata: 100%|██████████| 953/953 [00:00<00:00, 2.26MB/s]
Downloading data: 100%|██████████| 190M/190M [00:33<00:00, 5.64MB/s] 
Downloading data: 100%|██████████| 85.7k/85.7k [00:01<00:00, 52.9kB/s]
Downloading data: 100%|██████████| 500k/500k [00:01<00:00, 255kB/s]
Generating train split: 100%|██████████| 1659083/1659083 [00:01<00:00, 1012923.22 examples/s]
Generating validation split: 100%|██████████| 520/520 [00:00<00:00, 160796.08 examples/s]
Generating test split: 100%|██████████| 2507/2507 [00:00<00:00, 351077.43 examples/s]


In [12]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 1659083
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 520
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 2507
    })
})

In [13]:
raw_datasets['train'][1]

{'translation': {'en': 'Accerciser Accessibility Explorer',
  'hi': 'एक्सेर्साइसर पहुंचनीयता अन्वेषक'}}

In [15]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [16]:
tokenizer("hello how are you")

{'input_ids': [39915, 287, 54, 27, 0], 'attention_mask': [1, 1, 1, 1, 1]}

In [18]:
with tokenizer.as_target_tokenizer():
    print(tokenizer(['एक्सेर्साइसर पहुंचनीयता अन्वेषक']))

{'input_ids': [[26618, 16155, 346, 33383, 0]], 'attention_mask': [[1, 1, 1, 1, 1]]}




In [21]:
source_lang ='en'
target_lang='hi'

max_input_length =128
max_target_length=128

def preprocess_function(examples):
    inputs = [ex[source_lang] for ex in examples['translation']]
    targets =[ex[target_lang] for ex in examples['translation']]

    model_inputs =tokenizer(inputs,max_length=max_input_length,truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets,max_length=max_target_length,truncation=True)

    model_inputs['labels'] = labels['input_ids']
    return model_inputs

In [22]:
preprocess_function(raw_datasets['train'][:2])

{'input_ids': [[3872, 85, 2501, 132, 15441, 36398, 0], [32643, 28541, 36253, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1]], 'labels': [[63, 2025, 18, 16155, 346, 20311, 24, 2279, 679, 0], [26618, 16155, 346, 33383, 0]]}

In [23]:
tokenized_datasets = raw_datasets.map(preprocess_function,batched=True)

Map: 100%|██████████| 1659083/1659083 [04:54<00:00, 5637.13 examples/s] 
Map: 100%|██████████| 520/520 [00:00<00:00, 7045.72 examples/s]
Map: 100%|██████████| 2507/2507 [00:00<00:00, 6221.71 examples/s]


In [24]:
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

2024-08-13 18:10:15.866941: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-08-13 18:10:15.867054: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-08-13 18:10:15.867067: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-08-13 18:10:17.796568: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-08-13 18:10:17.796629: I external/local_xla/xla/stream_executor

In [25]:
batch_size = 16
learning_rate =2e-5
weight_decay =0.01
num_train_epochs =1

In [26]:
data_collator = DataCollatorForSeq2Seq(tokenizer,model = model,return_tensors='tf')

In [27]:
generation_data_collator = DataCollatorForSeq2Seq(tokenizer,model = model,return_tensors='tf',pad_to_multiple_of=128)

In [28]:
train_dataset =model.prepare_tf_dataset(
    tokenized_datasets['test'],
    batch_size = batch_size,
    shuffle = True,
    collate_fn =data_collator,
)

In [29]:
validation_data =model.prepare_tf_dataset(
    tokenized_datasets['validation'],
    batch_size = batch_size,
    shuffle = True,
    collate_fn =data_collator,
)

In [30]:
generation_dataset =model.prepare_tf_dataset(
    tokenized_datasets['validation'],
    batch_size = 8,
    shuffle = False,
    collate_fn =generation_data_collator,
)

In [31]:
optimizer = AdamWeightDecay(learning_rate=learning_rate,weight_decay_rate=weight_decay)
model.compile(optimizer=optimizer)

In [32]:
model.fit(train_dataset,validation_data=validation_data,epochs=1)

Cause: for/else statement not yet supported
Cause: for/else statement not yet supported


<tf_keras.src.callbacks.History at 0x7f9b45f2c160>

In [34]:
model.save_pretrained('tf_model/')

Non-default generation parameters: {'max_length': 512, 'num_beams': 4, 'bad_words_ids': [[61949]], 'forced_eos_token_id': 0}


In [35]:
tokenizer =AutoTokenizer.from_pretrained(model_checkpoint)



In [38]:
model = TFAutoModelForSeq2SeqLM.from_pretrained('tf_model/')

All model checkpoint layers were used when initializing TFMarianMTModel.

All the layers of TFMarianMTModel were initialized from the model checkpoint at tf_model/.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFMarianMTModel for predictions without further training.


In [49]:
input_text = 'vaibhav is a bad boy'
tokenized = tokenizer([input_text],return_tensors='np')
out = model.generate(**tokenized,max_length =128)
print(out)

tf.Tensor([[61949  6206  5525  7668    38  1161  4504     5     0 61949]], shape=(1, 10), dtype=int32)


In [50]:
with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0],skip_special_tokens=True))

विबाध एक बुरा लड़का है
