In [1]:
import gc
import torch
import numpy as np
import pandas as pd
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments

2024-06-08 15:56:36.680123: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-08 15:56:36.680246: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-08 15:56:36.834121: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# The tokenizer and the seq2seq model are both loaded from the same
# pretrained checkpoint id.
MODEL_CHECKPOINT = "google/flan-t5-base"

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

In [None]:
class dataset_model(Dataset):
    """Map-style dataset over a DataFrame of parallel English/French sentences.

    Args:
        data: DataFrame exposing an ``'en'`` (source) column and a ``'fr'``
            (target) column.
    """

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __len__(self):
        """Return the number of rows (sentence pairs) in the frame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``[source_sentence, target_sentence]`` for the row at position ``idx``.

        Lookup is positional (``.iloc``), not label-based: callers index a
        dataset with positions in ``range(len(self))``, which only coincide
        with the frame's index labels when the index is the default
        ``0..n-1`` range.
        """
        src_sentence = self.data['en'].iloc[idx]
        tgt_sentence = self.data['fr'].iloc[idx]
        return [src_sentence, tgt_sentence]

In [None]:
def apply_masking(text, mask_rate=0.05):
    """Randomly zero out elements of a tensor.

    Every element is handled independently: it is kept when a uniform sample
    from [0, 1) is greater than ``mask_rate`` and replaced by 0 otherwise.
    All positions are candidates; nothing is excluded from masking.

    NOTE(review): 0 is presumably the tokenizer's pad id, so masked tokens
    look like padding to the model -- confirm this is the intent.

    Args:
        text: Tensor of token ids, any shape.
        mask_rate: Probability-like threshold; larger values zero more elements.

    Returns:
        A new tensor with the same shape as ``text`` on the same device.
    """
    # Draw the mask on the input's device so a CUDA input is not multiplied
    # by a CPU mask.
    keep = torch.rand(text.shape, device=text.device) > mask_rate
    return text * keep

In [None]:
task = 'translate English to French:'
pad = '<pad>'

def collator(data):
    """Turn a list of ``[source, target]`` sentence pairs into a model batch.

    Source sentences are prefixed with the task prompt, tokenized, and passed
    through ``apply_masking``. Target sentences get the literal ``pad`` text
    prepended before tokenization; the tokenized target is then split into
    decoder inputs (last column dropped) and labels (first column dropped),
    i.e. the labels are the decoder inputs shifted left by one position.

    NOTE(review): this relies on the tokenizer mapping the ``'<pad>'`` text to
    a single special token that serves as the decoder start token -- confirm
    for the tokenizer in use.

    Args:
        data: Sequence of ``[source_sentence, target_sentence]`` pairs, as
            returned by ``dataset_model.__getitem__``.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``decoder_input_ids``,
        ``decoder_attention_mask`` and ``labels`` tensors.
    """
    # str() keeps non-string cells (e.g. NaN) from breaking concatenation.
    src_sentences = [str(pair[0]) for pair in data]
    tgt_sentences = [str(pair[1]) for pair in data]

    # A space separates the task prefix from the sentence
    # ("translate English to French: <sentence>").
    scr = tokenizer(
        [f"{task} {sentence}" for sentence in src_sentences],
        padding='longest',
        truncation=True,
        max_length=100,
        return_token_type_ids=False,
        return_tensors='pt',
    )
    tgt = tokenizer(
        [pad + sentence for sentence in tgt_sentences],
        padding='longest',
        truncation=True,
        max_length=100,
        return_token_type_ids=False,
        return_tensors='pt',
    )

    # Positions the tokenizer padded (attention_mask == 0) are set to -100,
    # the index ignored by the cross-entropy loss, so padding does not
    # contribute to the training loss.
    labels = tgt['input_ids'][:, 1:].masked_fill(tgt['attention_mask'][:, 1:] == 0, -100)

    return {
        'input_ids': apply_masking(scr['input_ids']),
        'attention_mask': scr['attention_mask'],
        'decoder_attention_mask': tgt['attention_mask'][:, :-1],
        'decoder_input_ids': tgt['input_ids'][:, :-1],
        'labels': labels,
    }

In [None]:
# Read the first 600,000 rows of the parallel corpus, then drop pairs with a
# missing side: downstream string coercion would otherwise turn a missing
# value into the literal text 'nan' and train on it. The index is reset so
# row labels stay equal to row positions (0..n-1) after the drop.
data = (
    pd.read_csv('/kaggle/input/en-fr-translation-dataset/en-fr.csv', nrows=600000)
    .dropna(subset=['en', 'fr'])
    .reset_index(drop=True)
)
dataset = dataset_model(data)

In [None]:
# Disabled evaluation hook: the code below is wrapped in a bare string
# literal, so `metric` and `eval_metric` are never defined at runtime.
# NOTE(review): before re-enabling --
#   * `evaluate` is not imported in the visible import cell;
#   * the `reference=` keyword is presumably meant to be `references=`
#     (confirm against the `evaluate` metric API);
#   * `preds` / `labels` are token ids here, while a word-error-rate metric
#     presumably expects decoded text -- verify and decode first if so.
"""metric = evaluate.load('wer')
def eval_metric(pred) :
    labels = pred.label_ids
    preds = np.argmax(pred.predictions, axis=-1)
    output = metric.compute(predictions = preds, reference = labels)   
    return {'wer' : output}"""

In [None]:
class train_model(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards keyword arguments to a wrapped model.

    Args:
        base_model: Module to wrap. Defaults to the notebook-level ``model``.
        device: Device the wrapped module is moved to. Defaults to ``'cuda'``
            when CUDA is available and ``'cpu'`` otherwise.
    """

    def __init__(self, base_model=None, device=None):
        super().__init__()
        if base_model is None:
            base_model = model  # notebook-level pretrained model
        if device is None:
            # Fall back to CPU instead of failing on machines without a GPU.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = base_model.to(device)

    def forward(self, **params):
        """Call the wrapped model with ``params`` and return its output unchanged."""
        return self.model(**params)

In [None]:
# Configuration for one fine-tuning epoch. No checkpoints are written during
# training (save_strategy='no'); the weights are saved explicitly at the end.
training_args = TrainingArguments(
    # output / logging
    output_dir="/kaggle/working/",
    report_to='none',
    save_strategy='no',
    logging_steps=300,
    # schedule
    num_train_epochs=1,
    per_device_train_batch_size=16,
    # optimisation
    learning_rate=3e-4,
    weight_decay=2e-4,
    lr_scheduler_type='cosine',
    warmup_steps=40,
    # execution
    fp16=True,
    torch_compile=True,
    torch_compile_mode='max-autotune',
)

training_model = train_model()
trainer = Trainer(
    model=training_model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
)
trainer.train()

# Persist the weights of the notebook-level `model`, which is the module
# `train_model` wraps.
state_dict = model.state_dict()
torch.save(state_dict, '/kaggle/working/model_weights.pth')