In [None]:
from datasets import load_dataset
import random
from datasets import Dataset
from datasets import DatasetDict
import torch
# Pick the first CUDA device when one is available, otherwise fall back to CPU.
# NOTE(review): `device` is only printed in the cells shown here; confirm it is still needed.
gpu = torch.cuda.is_available()
device = torch.device("cuda:0") if gpu else torch.device("cpu")
print(device)
# Define a function to process each example
def process_example(example):
    # Split the example into IN and OUT parts, and remove the labels
    text = example['text']
    parts = text.split('OUT:')
    in_part = parts[0].replace('IN:', '').strip()
    out_part = parts[1].strip() if len(parts) > 1 else ''
    return {'input': in_part, 'output': out_part}

# Load the raw train/test files as a text dataset: one "IN: ... OUT: ..." example per line
# (the format parsed by process_example above).
datasets = load_dataset(
    'text',
    data_files={
        'train': 'data/simple_split/tasks_train_simple.txt',
        'test': 'data/simple_split/tasks_test_simple.txt',
    },
)

# # Assuming your processed dataset is stored in a Hugging Face Dataset object called `processed_dataset`
# # Get the original samples as a list of dictionaries
# original_samples = datasets['train'].to_dict()

# # Calculate how many additional samples are needed
# total_samples = 100000
# original_count = len(original_samples['text'])
# additional_count = total_samples - original_count

# # Randomly sample additional samples with replacement
# additional_samples = {
#     key: random.choices(original_samples[key], k=additional_count)
#     for key in original_samples
# }

# # Combine the original samples and the additional samples
# datasets['train'] = Dataset.from_dict({
#         key: original_samples[key] + additional_samples[key]
#         for key in original_samples
#     })

# # Verify the structure of the combined DatasetDict
# print(datasets)

# Add 'input'/'output' columns to every split; the raw 'text' column is kept (later cells use it).
datasets = datasets.map(process_example)
# Display the processed dataset
print(datasets, datasets['train'][0])



In [None]:

from transformers import BertTokenizer, DataCollatorWithPadding

# Cased WordPiece tokenizer for the same checkpoint the model is loaded from ('bert-base-cased').
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-cased',
    force_download=False,
)

In [None]:
# Sanity check: tokenize the first raw training line (still carrying its "IN:"/"OUT:" markers).
input_tokens = tokenizer(
    datasets['train'][0]['text'],
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=512,
)
input_tokens

In [None]:
MAX_LENGTH = 512


def tokenize_function(example):
    """Build masked-LM inputs and labels for a batch of input/output pairs.

    ``example`` is a batch (used with ``datasets.map(..., batched=True)``) holding
    parallel ``"input"`` and ``"output"`` lists of strings.

    * Model input: ``<input>[SEP] [MASK] [MASK] ...``.
    * Labels: ``<input>[SEP]<output> [SEP] [SEP] ...`` tokenized the same way.
      Every label position up to and including the first [SEP] (i.e. the
      leading special token, the input tokens and the separator) is set to
      -100 so the loss ignores it. A row with no [SEP] at all is ignored
      entirely.

    Both strings are padded with MAX_LENGTH filler tokens and then truncated to
    MAX_LENGTH, so input and label rows come out with equal length. Position
    alignment assumes the shared ``<input>[SEP]`` prefix tokenizes identically in
    both strings.

    Returns:
        dict of the tokenizer outputs for the masked input plus ``"labels"``.
    """
    sep_token = tokenizer.sep_token
    input_str = [
        command + sep_token + (' ' + tokenizer.mask_token) * MAX_LENGTH
        for command in example["input"]
    ]
    output_str = [
        command + sep_token + target + (' ' + sep_token) * MAX_LENGTH
        for command, target in zip(example["input"], example["output"])
    ]

    input_tokens = tokenizer(input_str,
                             return_tensors='pt',
                             padding=True,
                             truncation=True,
                             max_length=MAX_LENGTH)
    output_tokens = tokenizer(output_str,
                              return_tensors='pt',
                              padding=True,
                              truncation=True,
                              max_length=MAX_LENGTH)

    inputs = dict(input_tokens)
    labels = output_tokens['input_ids']

    # Vectorized form of "ignore everything up to and including the first [SEP]":
    # for each position, count the [SEP]s strictly before it; positions with zero
    # such [SEP]s are the prefix (plus the first [SEP] itself) and are ignored.
    sep_id = tokenizer.convert_tokens_to_ids(sep_token)
    is_sep = (labels == sep_id).long()
    seps_before = is_sep.cumsum(dim=1) - is_sep
    labels[seps_before == 0] = -100

    inputs['labels'] = labels
    return inputs

# Tokenize every split in batches, replacing the raw text columns with the model-ready fields.
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=['text', 'input', 'output'],
)


In [None]:
# Display the tokenized splits as this cell's output.
tokenized_datasets

In [None]:
from transformers import BertForMaskedLM
# Masked-LM model on the same checkpoint as the tokenizer ('bert-base-cased').
model = BertForMaskedLM.from_pretrained(
    "bert-base-cased",
    force_download=False,
)

In [None]:
from torch.utils.data import DataLoader
from pprint import pprint
import torch
from tqdm import tqdm
# Evaluation

def calculate_accuracies(eval_preds, sep_token_id=None):
    """
    Calculate token-wise accuracy and sequence-wise accuracy (a ``compute_metrics`` hook).

    For every example, the scored target is the run of non-ignored labels
    (label != -100) up to, but not including, the first [SEP] label. The
    predictions at the same positions are compared against it. Examples with at
    least one non-ignored label are scored, including ones whose labels end
    without a [SEP].

    Args:
        eval_preds: ``(predictions, labels)`` pair, e.g. a ``transformers.EvalPrediction``.
            ``predictions`` may be raw logits shaped (batch, seq, vocab) or token ids
            shaped (batch, seq) (e.g. from a ``preprocess_logits_for_metrics`` hook).
            ``labels`` is shaped (batch, seq) with -100 on ignored positions.
            Both may be numpy arrays or torch tensors.
        sep_token_id: id that terminates a target sequence. Defaults to the [SEP]
            id of the global ``tokenizer``.

    Returns:
        dict with ``token_wise_accuracy`` and ``sequence_wise_accuracy`` (0.0 when
        nothing is scored), plus the extracted ``targets`` and ``predictions``
        token-id lists.
    """
    raw_predictions, labels = eval_preds
    # The Trainer passes numpy arrays; torch tensors (on any device) are accepted too.
    raw_predictions = torch.as_tensor(raw_predictions).cpu()
    labels = torch.as_tensor(labels).cpu()
    # Logits carry an extra vocabulary axis relative to the labels; ids do not.
    if raw_predictions.dim() == labels.dim() + 1:
        predicted_ids = raw_predictions.argmax(dim=-1)
    else:
        predicted_ids = raw_predictions
    if sep_token_id is None:
        sep_token_id = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)

    targets = []
    predictions = []
    for label_row, pred_row in zip(labels.tolist(), predicted_ids.tolist()):
        target, pred = [], []
        supervised = False
        for label_id, pred_id in zip(label_row, pred_row):
            if label_id == -100:
                continue
            supervised = True
            if label_id == sep_token_id:
                break
            target.append(label_id)
            pred.append(pred_id)
        # Previously rows without a [SEP] label were silently dropped, which
        # removed them from both denominators; score every supervised row.
        if supervised:
            targets.append(target)
            predictions.append(pred)

    total_tokens = sum(len(target_seq) for target_seq in targets)
    correct_tokens = sum(
        p == t
        for pred_seq, target_seq in zip(predictions, targets)
        for p, t in zip(pred_seq, target_seq)
    )
    correct_sequences = sum(
        pred_seq == target_seq for pred_seq, target_seq in zip(predictions, targets)
    )

    # Guard against empty eval sets / rows with nothing to score.
    token_wise_accuracy = correct_tokens / total_tokens if total_tokens else 0.0
    sequence_wise_accuracy = correct_sequences / len(targets) if targets else 0.0

    return {"token_wise_accuracy": token_wise_accuracy,
            "sequence_wise_accuracy": sequence_wise_accuracy,
            "targets": targets,
            "predictions": predictions}


In [None]:
from transformers import BertTokenizer, BertForMaskedLM, Trainer, TrainingArguments

# Training arguments
training_args = TrainingArguments(
    output_dir="./bert_finetune",
    evaluation_strategy="steps",
    learning_rate=5e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    weight_decay=0.01,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=100,
    # NOTE(review): with evaluation_strategy="steps" this runs a full pass over the
    # test split every 10 training steps, which likely dominates run time — confirm intended.
    eval_steps=10,
    report_to="tensorboard",
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, max_length=512, padding=True)


def logits_to_token_ids(logits, labels):
    """Reduce one eval batch of MLM logits to predicted token ids.

    Used as ``preprocess_logits_for_metrics``: the Trainer accumulates whatever
    this returns for the entire eval set before calling ``compute_metrics``, so
    returning (batch, seq) ids instead of (batch, seq, vocab) logits avoids
    holding full-vocabulary logits for every eval example. ``labels`` is unused
    but required by the hook's signature.
    """
    if isinstance(logits, tuple):  # defensive: some models return (logits, *extras)
        logits = logits[0]
    return logits.argmax(dim=-1)


# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=calculate_accuracies,
    # compute_metrics therefore receives (batch, seq) predicted token ids, not logits.
    preprocess_logits_for_metrics=logits_to_token_ids,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned weights and the tokenizer side by side.
model.save_pretrained("./bert_finetuned")
tokenizer.save_pretrained("./bert_finetuned")


In [None]:
# print("Token wise ACC:", results[0][0], ";Sentence wise ACC:", results[0][1])
# print(len(results))

In [None]:
# Training
