In [None]:
import os, sys
sys.path.append(os.path.abspath("../"))
from src.data.make_dataset import get_dataset

import pandas as pd

# Interim (preprocessed) dataset; tab-separated, one sentence pair per row.
INTERIM_DATASET_PATH = '../data/interim/preprocessed.tsv'

df = pd.read_csv(INTERIM_DATASET_PATH, sep='\t')

In [None]:
df.head(5)  # first rows of the loaded frame

In [None]:
from sklearn.model_selection import train_test_split

# Train / eval split: with shuffle=False the first 80% of rows become `train`
# and the last 20% become `eval`, so `eval` keeps its original index labels
# (they do not start at 0).
# NOTE(review): `eval` shadows the Python builtin; kept because later cells use it.
train, eval = train_test_split(
    df,
    shuffle=False,
    test_size=0.2,
)

In [None]:
train.head(5)  # first rows of the training split

## Create the Dataset and DataLoader

In [None]:
import torch

# Pick the compute device: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device


In [None]:
import torch
from torch.utils.data import Dataset
import re

def string_to_list(s):
    """Parse every run of decimal digits in ``s`` into an int, in order.

    Used to decode the token-id / attention-mask columns, which arrive from the
    TSV as text (presumably like ``"[101, 2023, 102]"``). Signs and decimal
    points are not recognised: ``"-1"`` parses as ``[1]``.
    """
    digit_runs = re.findall(r'\d+', s)
    return list(map(int, digit_runs))

class TextDetoxDataset(Dataset):
    """Map-style dataset over the preprocessed text-detoxification dataframe.

    The token-id and attention-mask columns are stored as text and are parsed
    into lists of ints with ``string_to_list`` once, at construction time.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Must contain the columns read in ``__init__``. Its index may be
        arbitrary: items are fetched by position. This means slices whose
        index does not start at 0 work, such as the un-shuffled eval split
        returned by ``train_test_split``.
    """

    def __init__(self, dataframe):
        self.reference_input_ids = dataframe['reference_input_ids'].apply(string_to_list)
        self.reference_attention_mask = dataframe['reference_attention_mask'].apply(string_to_list)
        self.similarity = dataframe['similarity']
        self.ref_tox = dataframe['ref_tox']
        self.trn_tox = dataframe['trn_tox']
        # The input column is spelled 'lenght_diff'; it is exposed as 'length_diff'.
        self.length_diff = dataframe['lenght_diff']
        self.translation_input_ids = dataframe['translation_input_ids'].apply(string_to_list)
        self.translation_attention_mask = dataframe['translation_attention_mask'].apply(string_to_list)

    def __len__(self):
        """Number of rows in the wrapped dataframe."""
        return len(self.reference_input_ids)

    def __getitem__(self, idx):
        """Return the row at 0-based position ``idx`` (negatives allowed) as a dict.

        ``.iloc`` is deliberate. ``series[idx]`` on an integer index is a
        *label* lookup, so it raises ``KeyError`` whenever the index does not
        start at 0 (e.g. the eval split).
        """
        return {
            "reference_input_ids": self.reference_input_ids.iloc[idx],
            "reference_attention_mask": self.reference_attention_mask.iloc[idx],
            "similarity": self.similarity.iloc[idx],
            "ref_tox": self.ref_tox.iloc[idx],
            "trn_tox": self.trn_tox.iloc[idx],
            "length_diff": self.length_diff.iloc[idx],
            "translation_input_ids": self.translation_input_ids.iloc[idx],
            "translation_attention_mask": self.translation_attention_mask.iloc[idx],
        }
    
# Wrap each split; the text-encoded token columns are parsed here, once.
train_dataset = TextDetoxDataset(dataframe=train)
eval_dataset = TextDetoxDataset(dataframe=eval)

In [None]:
# Sanity check: count training items whose reference_input_ids are not exactly
# EXPECTED_SEQ_LEN tokens long. A non-zero count means the sequences were not
# all padded/truncated to one length.
# Unlike a try/assert/except loop, this neither swallows unrelated errors
# (they now propagate) nor silently reports 0 under `python -O`.
EXPECTED_SEQ_LEN = 512
ex = sum(
    len(train_dataset[i]['reference_input_ids']) != EXPECTED_SEQ_LEN
    for i in range(len(train_dataset))
)
print(ex)

In [None]:
# Show which fields a single dataset item exposes.
train_dataset[0].keys()

In [None]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def custom_collate_fn(batch, device=None):
    """Turn a list of ``TextDetoxDataset`` items into a seq2seq training batch.

    Args:
        batch: list of item dicts as returned by ``TextDetoxDataset.__getitem__``.
            Only ``reference_input_ids``, ``reference_attention_mask``,
            ``translation_input_ids`` and (if present)
            ``translation_attention_mask`` are used; the scalar fields are dropped.
        device: optional ``torch.device``/str to move the tensors to. Defaults
            to ``None``, which leaves the tensors on the CPU. DataLoader memory
            pinning and worker processes need CPU tensors, and ``Trainer`` moves
            each batch to its device itself.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``. Each is a
        ``torch.long`` tensor of shape (batch, longest row in batch). Rows are
        right-padded. Padded or masked-out label positions are set to -100,
        the label value the model's loss ignores.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    def _stack(key, fill):
        # Right-pad variable-length rows to a common length; equal-length rows are just stacked.
        return pad_sequence(
            [torch.as_tensor(entry[key], dtype=torch.long) for entry in batch],
            batch_first=True,
            padding_value=fill,
        )

    # Padded input positions get attention_mask 0, so their id value is never attended to.
    input_ids = _stack("reference_input_ids", 0)
    attention_mask = _stack("reference_attention_mask", 0)
    labels = _stack("translation_input_ids", -100)
    if "translation_attention_mask" in batch[0]:
        # Also exclude the target's own padding tokens from the loss.
        # Assumes each item's translation mask is as long as its translation ids.
        target_mask = _stack("translation_attention_mask", 0)
        labels = labels.masked_fill(target_mask == 0, -100)

    collated_batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
    if device is not None:
        collated_batch = {key: value.to(device) for key, value in collated_batch.items()}
    return collated_batch

# Batches of 4 items in file order, assembled by custom_collate_fn.
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=custom_collate_fn,
    shuffle=False,
    batch_size=4,
)

## Define and fine-tune the model

In [36]:
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Load pre-trained BERT models and tokenizer
# Load pre-trained BERT checkpoints as both encoder and decoder (bert2bert), plus the tokenizer.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoder_decoder_model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

# A warm-started EncoderDecoderModel has no decoder start / pad token configured.
# Both are needed to shift `labels` into `decoder_input_ids` during training,
# and for generation when predict_with_generate=True.
encoder_decoder_model.config.decoder_start_token_id = tokenizer.cls_token_id
encoder_decoder_model.config.pad_token_id = tokenizer.pad_token_id

# Fine-tuning configuration
training_args = Seq2SeqTrainingArguments(
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    predict_with_generate=True,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # adjust to the installed version.
    evaluation_strategy="steps",
    # Dataset items use custom keys; keep them so the collator can remap them
    # instead of having the Trainer drop them as "unused" columns.
    remove_unused_columns=False,
    logging_dir='./logs',
    logging_steps=100,
    save_steps=500,
    eval_steps=750,
    save_total_limit=3,
    output_dir="./output",
)

# Dataset items carry keys such as `reference_input_ids`, which the model's
# forward() does not accept. custom_collate_fn renames them to
# input_ids / attention_mask / labels.
trainer = Seq2SeqTrainer(
    model=encoder_decoder_model,
    args=training_args,
    data_collator=custom_collate_fn,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Fine-tune the model
trainer.train()

In [None]:
# from transformers import BertTokenizer
# from torch.optim import AdamW
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# # Hyperparameters
# HIDDEN_DIM = 256
# OUTPUT_DIM = len(tokenizer.vocab)  # Vocabulary size
# N_LAYERS = 2
# DROPOUT = 0.5
# LR = 5e-5
# EPOCHS = 3

# # Initialize model, loss, optimizer
# bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
# model = TextDetoxifier(bert_model, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, DROPOUT).to(device)
# criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id).to(device)
# optimizer = AdamW(model.parameters(), lr=LR)

# # (Optional) Set up a learning rate scheduler
# scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train)*EPOCHS)

## Manual training loop (currently disabled — the cells are commented out)

In [None]:
# from tqdm import tqdm

# # Training loop
# for epoch in range(EPOCHS):
#     model.train()
#     total_loss = 0
    
#     # Wrapping train_dataloader with tqdm to get a progress bar
#     progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    
#     for batch in progress_bar:
#         # Get inputs and targets from batch
#         input_ids = batch[0]
#         attention_mask = batch[1]
#         target_ids = batch[2]

#         optimizer.zero_grad()
         
#         # print(len(input_ids), len(attention_mask), len(target_ids))
#         # print(target_ids.shape)
#         # break
#         # Forward pass
#         outputs = model(input_ids, attention_mask, target_ids)

#         print('x')

#         # Reshape outputs and target_ids for loss calculation
#         outputs = outputs.view(-1, OUTPUT_DIM)
#         target_ids = target_ids.view(-1)
        
#         # Calculate loss
#         loss = criterion(outputs, target_ids)
        
#         # Backward pass
#         loss.backward()
        
#         # Gradient clipping (often used with BERT)
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
#         optimizer.step()
        
#         # Update the learning rate
#         scheduler.step()
        
#         total_loss += loss.item()

#     print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {total_loss/len(train_dataloader)}")
