In [19]:
from transformers import BartForConditionalGeneration, BartTokenizer
from datasets import load_dataset ,load_metric
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import pandas as pd
from datasets import Dataset
from datasets import load_dataset
import torch.nn.utils.prune as prune
from transformers import TrainingArguments,Trainer
import sacrebleu

In [2]:
# Load the tokenizer and model
model_name = 'facebook/bart-base'
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)


In [3]:
# Load WMT 2014 English-German dataset
dataset = load_dataset("wmt14", "de-en",split = 'train')
test_dataset = load_dataset('wmt14',"de-en",split = 'test')

val_dataset = load_dataset('wmt14',"de-en",split = 'validation')

In [4]:
sampled_dataset = dataset.shuffle(seed=42).select(range(int(0.001 * len(dataset))))
test_sample_dataset = test_dataset.shuffle(seed = 42).select(range(int(0.01*len(test_dataset))))

val_sample_dataset = val_dataset.shuffle(seed = 42).select(range(int(0.01*len(val_dataset))))

In [5]:
sample = sampled_dataset[0]

In [6]:
sample

{'translation': {'de': 'In diesem Rubrik finden Sie Fahndungsmeldungen, die auf Anfrage eines Staatsanwalts oder Untersuchungsrichter verbreitet werden.',
  'en': "On these pages you will find the wanted or missing notices that are issued at public prosecutor or examining magistrate's request."}}

In [7]:
# Function to preprocess the data
def preprocess_function(examples):
    # Extracting German and English texts from the 'translation' dictionary
    inputs = [ex['de'] for ex in examples['translation']]
    targets = [ex['en'] for ex in examples['translation']]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True, padding='max_length')

    # Set up the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=1024, truncation=True, padding='max_length')

    model_inputs['labels'] = labels['input_ids']
    return model_inputs

In [8]:
# Tokenize the dataset
tokenized_datasets = sampled_dataset.map(preprocess_function, batched=True, remove_columns=dataset.column_names)
test_tokenized_datasets = test_sample_dataset.map(preprocess_function, batched=True, remove_columns=dataset.column_names)

val_tokenized_datasets = val_sample_dataset.map(preprocess_function, batched=True, remove_columns=dataset.column_names)


Map:   0%|          | 0/30 [00:00<?, ? examples/s]



In [9]:


def apply_pruning(model):
    # Iterate over all modules and prune the linear layers found in the encoder and decoder
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Applying unstructured L1 pruning
            prune.l1_unstructured(module, name='weight', amount=0.2)
            # To make the pruning permanent, you might typically call prune.remove, but it is better to do it after training

In [10]:
# Pruning before training
apply_pruning(model)

In [11]:
# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    weight_decay=0.01,
    save_total_limit=1,
    num_train_epochs=1,
#     predict_with_generate=True
)


In [13]:
# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    tokenizer=tokenizer,
    eval_dataset=val_tokenized_datasets,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [14]:
import wandb
wandb.init(mode="disabled")



In [15]:
# Train the model
trainer.train()



Epoch,Training Loss,Validation Loss
1,0.1105,0.070609


Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


TrainOutput(global_step=1127, training_loss=0.5368960675880141, metrics={'train_runtime': 1123.559, 'train_samples_per_second': 4.012, 'train_steps_per_second': 1.003, 'total_flos': 2748691953745920.0, 'train_loss': 0.5368960675880141, 'epoch': 1.0})

In [17]:
pip install sacrebleu


Collecting sacrebleu
  Downloading sacrebleu-2.4.2-py3-none-any.whl.metadata (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.0/58.0 kB[0m [31m832.7 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-2.8.2-py3-none-any.whl.metadata (8.5 kB)
Downloading sacrebleu-2.4.2-py3-none-any.whl (106 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hDownloading portalocker-2.8.2-py3-none-any.whl (17 kB)
Installing collected packages: portalocker, sacrebleu
Successfully installed portalocker-2.8.2 sacrebleu-2.4.2
Note: you may need to restart the kernel to use updated packages.


In [39]:
def generate_translation(batch):
    # Assuming batch['translation'] is a list of dictionaries
    german_sentences = [item['de'] for item in batch['translation']]
    english_sentences = [item['en'] for item in batch['translation']]

    # Tokenize the German sentences
    inputs = tokenizer(german_sentences, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    inputs = {key: val.to(model.device) for key, val in inputs.items()}

    # Generate outputs
    outputs = model.generate(**inputs, max_length=512, num_beams=5)

    # Decode the outputs to human-readable translations
    translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return {"pred_translation": translations}


In [40]:
# Apply translation generation function to the test dataset
results =test_dataset.map(generate_translation, batched=True, batch_size=16)


Map:   0%|          | 0/3003 [00:00<?, ? examples/s]

In [31]:
# Apply translation generation function to the test dataset
# results =test_sample_dataset.map(generate_translation, batched=True, batch_size=16)


Map:   0%|          | 0/30 [00:00<?, ? examples/s]

In [41]:
# Extract the translations and references
translations = [result['pred_translation'] for result in results]
references = [[ref['en']] for ref in results['translation']]  # Note references are expected as a list of lists


In [42]:
# Compute BLEU score using sacrebleu
bleu = sacrebleu.corpus_bleu(translations, references)
print(f"BLEU Score: {bleu.score}")

BLEU Score: 21.3643503198117


In [35]:
# def generate_translation(batch):
#     # Print the batch structure to understand how the data is organized
#     print(batch)
#     return batch  # Return the batch as is for inspection


In [36]:
# test_sample_dataset.map(generate_translation, batched=True, batch_size=2)

Dataset({
    features: ['translation'],
    num_rows: 30
})