In [None]:
#Imports
import requests
import lzma
import os
from datasets import Dataset, DatasetDict
import re
from transformers import AutoTokenizer, AutoModelForMaskedLM, TrainingArguments, DataCollatorForLanguageModeling, Trainer
import numpy as np
import optuna

In [2]:
# url = "http://data.statmt.org/cc-100/sw.txt.xz"
# file_name = "sw.txt.xz"
# response = requests.get(url, stream=True)
# with open(file_name, "wb") as file:
#     for chunk in response.iter_content(chunk_size=1024):
#         if chunk:
#             file.write(chunk)
# print(f"Downloaded {file_name}")

# output_file = "sw.txt"
# with lzma.open(file_name, "rb") as compressed_file:
#     with open(output_file, "wb") as extracted_file:
#         extracted_file.write(compressed_file.read())
# print(f"Extracted to {output_file}")
# os.remove(file_name)

In [3]:
# Step 2: Read and prepare data
num_lines_to_read = 100000  # Adjust this as needed
text_data = []
with open('/datasets/mdawood/sw.txt', 'r', encoding='utf-8') as f:
    for i, line in enumerate(f):
        # The budget counts raw lines of the file, so blank lines use it up too.
        if i >= num_lines_to_read:
            break
        line = line.strip()
        # Keep only lines that still contain something after stripping whitespace.
        if line:
            text_data.append(line)

In [4]:
# Create dataset
# One column named 'text'; each collected line becomes one row.
data_dict = dict(text=text_data)
dataset = Dataset.from_dict(data_dict)


In [5]:
# Clean text
def clean_text(example):
    """Normalise the text of one dataset row.

    The steps, in order: remove ``<...>`` tags, remove ``http...`` URLs, fold
    typographic apostrophes to ASCII, drop every character that is not a
    letter, an apostrophe or whitespace, drop apostrophes that are not
    enclosed by letters, collapse whitespace runs, and lower-case the result.

    Word-internal apostrophes are kept on purpose: in Swahili the digraph
    ``ng'`` (as in "ng'ombe") is spelled with one, and stripping it merges the
    word with a different spelling. The umlaut/eszett letters of the previous
    whitelist are still accepted, so text that passed before still passes.

    Args:
        example: Mapping with a ``'text'`` key holding the raw string.

    Returns:
        dict: ``{'text': cleaned_string}``; the string may be empty when
        nothing survives the cleaning.
    """
    text = example['text']
    text = re.sub(r'<.*?>', '', text)    # HTML-like tags
    text = re.sub(r'http\S+', '', text)  # URLs (http and https)
    # Curly / modifier apostrophes -> ASCII, so the whitelist below keeps them.
    text = re.sub('[\u2018\u2019\u02bc]', "'", text)
    text = re.sub(r"[^a-zA-ZäöüÄÖÜßẞ'\s]", '', text)
    # An apostrophe is only meaningful between two letters; anything else is a
    # quotation mark or a leftover and is removed.
    text = re.sub(r"(?<![a-zA-ZäöüÄÖÜßẞ])'|'(?![a-zA-ZäöüÄÖÜßẞ])", '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    text = text.lower()
    return {'text': text}



In [None]:
# Clean every row, then drop the rows that cleaning reduced to an empty string
# (for example lines that held only digits, punctuation or a URL): such rows
# carry no text to learn from. Finally shuffle with a fixed seed.
dataset = dataset.map(clean_text)
dataset = dataset.filter(lambda example: len(example['text']) > 0)
dataset = dataset.shuffle(seed=42)


In [7]:
# Split dataset (80% train, 10% validation, 10% test)
# A 20% hold-out is carved off first and then halved into validation and test.
split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
test_valid = split_dataset['test'].train_test_split(test_size=0.5, seed=42)
swahili_dataset = DatasetDict(
    train=split_dataset['train'],
    validation=test_valid['train'],
    test=test_valid['test'],
)


In [None]:
# Print the number of samples in each split
print(f"Number of samples in train: {swahili_dataset['train'].num_rows}")
print(f"Number of samples in validation: {swahili_dataset['validation'].num_rows}")
print(f"Number of samples in test: {swahili_dataset['test'].num_rows}")

In [None]:
# Tokenization
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')


def tokenize_function(batch):
    """Tokenize the 'text' column of a batch.

    Every sequence is truncated and padded to exactly 128 tokens.
    """
    encode_options = dict(truncation=True, padding='max_length', max_length=128)
    return tokenizer(batch['text'], **encode_options)


# Tokenize all splits in batches using 4 worker processes.
tokenized_datasets = swahili_dataset.map(
    tokenize_function,
    num_proc=4,
    batched=True,
)
# Expose only the two model inputs, as torch tensors.
tokenized_datasets.set_format(columns=['input_ids', 'attention_mask'], type='torch')



In [None]:
# Pre-trained model loading
def model_init():
    """Load and return the 'xlm-roberta-base' checkpoint with a masked-LM head.

    Each call goes through ``from_pretrained`` again, so every caller gets its
    own model object.
    """
    pretrained_model = AutoModelForMaskedLM.from_pretrained('xlm-roberta-base')
    return pretrained_model

In [11]:
# Batch collator for masked language modelling, configured with a masking
# probability of 0.15 and the tokenizer loaded above.
data_collator = DataCollatorForLanguageModeling(
    mlm_probability=0.15,
    mlm=True,
    tokenizer=tokenizer,
)


In [None]:
# Training arguments (initial, can be overwritten by hyperparameter search)
training_args = TrainingArguments(
    # Where results and logs go
    output_dir='./results',
    logging_dir='./logs',
    overwrite_output_dir=True,
    # Evaluation, logging and checkpoint cadence
    evaluation_strategy='epoch',
    logging_steps=500,
    save_strategy='no',  # Avoid saving too many models during hyperparameter search
    # Output noise
    disable_tqdm=True,  # Disable tqdm to reduce output during hyperparameter search
    report_to=['none'],  # Disable reporting to external services
)

In [13]:
# Initialize Trainer
# The model is supplied through model_init rather than as an instance.
trainer = Trainer(
    args=training_args,
    model_init=model_init,
    data_collator=data_collator,
    eval_dataset=tokenized_datasets['validation'],
    train_dataset=tokenized_datasets['train'],
)

In [None]:
# Hyperparameter search space
def hp_space(trial):
    """Sample one hyperparameter configuration from ``trial``.

    Args:
        trial: Object providing ``suggest_float``, ``suggest_categorical`` and
            ``suggest_int`` (an Optuna trial).

    Returns:
        dict: Sampled values under the keys ``learning_rate`` (1e-5 to 5e-5,
        log scale), ``weight_decay`` (0.0 to 0.1),
        ``per_device_train_batch_size`` (8, 16 or 32) and
        ``num_train_epochs`` (2 to 4).
    """
    # The suggest_* calls are issued in the same order as the keys below.
    sampled_lr = trial.suggest_float('learning_rate', 1e-5, 5e-5, log=True)
    sampled_decay = trial.suggest_float('weight_decay', 0.0, 0.1)
    sampled_batch = trial.suggest_categorical(
        'per_device_train_batch_size', [8, 16, 32]
    )
    sampled_epochs = trial.suggest_int('num_train_epochs', 2, 4)
    return dict(
        learning_rate=sampled_lr,
        weight_decay=sampled_decay,
        per_device_train_batch_size=sampled_batch,
        num_train_epochs=sampled_epochs,
    )

In [None]:
# Objective function for Optuna
def model_objective(trial):
    """Optuna objective: fine-tune with one sampled configuration and score it.

    Args:
        trial: Optuna trial from which the hyperparameters are sampled
            (through ``hp_space``).

    Returns:
        Validation perplexity, computed as ``exp(eval_loss)``. Lower is
        better, so the study should minimise it.
    """
    # Sample through hp_space so the search space is declared in one place
    # only; its keys are TrainingArguments keyword names.
    trial_params = hp_space(trial)

    args = TrainingArguments(
        output_dir='./results',
        overwrite_output_dir=True,
        evaluation_strategy='epoch',
        # Never checkpoint during the search: leaving save_strategy unset
        # falls back to step-based saving, which writes a full model
        # checkpoint to disk repeatedly in every trial.
        save_strategy='no',
        logging_dir='./logs',
        logging_steps=500,
        report_to=['none'],
        disable_tqdm=True,
        **trial_params,
    )

    # model_init gives every trial its own freshly loaded model.
    trainer = Trainer(
        model_init=model_init,
        args=args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['validation'],
        data_collator=data_collator,
    )

    # Train the model
    trainer.train()

    # Evaluate the model on the validation split
    eval_results = trainer.evaluate()
    perplexity = np.exp(eval_results['eval_loss'])
    return perplexity  # Objective is to minimize perplexity

In [None]:
# Run hyperparameter search
# The sampler is seeded so the sequence of suggested configurations is
# repeatable, matching the fixed seed used for shuffling and splitting.
study = optuna.create_study(
    direction='minimize',
    sampler=optuna.samplers.TPESampler(seed=42),
)
# gc_after_trial runs a garbage collection after each trial so the previous
# trial's Trainer and model can be reclaimed before the next one is built.
study.optimize(model_objective, n_trials=10, gc_after_trial=True)

In [None]:
# Print best hyperparameters
print("Best hyperparameters:", study.best_params)

In [None]:
# Update training arguments with best hyperparameters
# (parameters of the trial with the best objective value)
best_params = study.best_params

In [None]:


# Final training configuration, built from the best trial's hyperparameters.
training_args = TrainingArguments(
    # Tuned values
    learning_rate=best_params['learning_rate'],
    weight_decay=best_params['weight_decay'],
    num_train_epochs=best_params['num_train_epochs'],
    per_device_train_batch_size=best_params['per_device_train_batch_size'],
    # Fixed settings
    output_dir='./results',
    logging_dir='./logs',
    overwrite_output_dir=True,
    evaluation_strategy='epoch',
    logging_steps=500,
    save_total_limit=2,
)


In [None]:
# Initialize Trainer with best hyperparameters
trainer = Trainer(
    args=training_args,
    model_init=model_init,
    data_collator=data_collator,
    eval_dataset=tokenized_datasets['validation'],
    train_dataset=tokenized_datasets['train'],
)

# %%
# Train the model with best hyperparameters
trainer.train()

In [None]:
# Evaluate on validation set
# evaluate() is called without a dataset argument here; the trainer was built
# with the validation split as its eval_dataset. Perplexity is the exponential
# of the reported 'eval_loss'.
eval_results_after = trainer.evaluate()
perplexity_after = np.exp(eval_results_after['eval_loss'])
print(f"Perplexity after fine-tuning: {perplexity_after:.2f}")

In [None]:
# Evaluate on test set
# The held-out test split is passed explicitly; perplexity is again the
# exponential of the reported 'eval_loss'.
eval_results_test = trainer.evaluate(eval_dataset=tokenized_datasets['test'])
perplexity_test = np.exp(eval_results_test['eval_loss'])
print(f"Perplexity on test set: {perplexity_test:.2f}")

In [None]:
# # Print best hyperparameters
# print("Best hyperparameters:", study.best_trial.params)

In [None]:
# # Step 4: Evaluate pre-trained model (before fine-tuning)
# eval_results_before = trainer_before.evaluate()
# perplexity_before = np.exp(eval_results_before['eval_loss'])
# print(f"Perplexity before fine-tuning: {perplexity_before:.2f}")

In [None]:
# Step 5: Fine-tune the model on Swahili dataset
# training_args = TrainingArguments(
#     output_dir='./results',
#     overwrite_output_dir=True,
#     evaluation_strategy='epoch',
#     learning_rate=5e-5,
#     per_device_train_batch_size=8,
#     per_device_eval_batch_size=8,
#     num_train_epochs=3,
#     weight_decay=0.01,
#     save_total_limit=2,
#     logging_dir='./logs',
#     logging_steps=500,
# )

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=tokenized_datasets['train'],
#     eval_dataset=tokenized_datasets['validation'],
#     data_collator=data_collator,
# )

In [None]:
# trainer.train()


In [None]:
# # Step 6: Calculate perplexity after fine-tuning
# eval_results_after = trainer.evaluate()
# perplexity_after = np.exp(eval_results_after['eval_loss'])
# print(f"Perplexity after fine-tuning: {perplexity_after:.2f}")


In [None]:
# # Step 7: Save the fine-tuned model
# trainer.save_model('./swahili-xlmr-finetuned')
# tokenizer.save_pretrained('./swahili-xlmr-finetuned')