In [None]:
! pip install datasets
! pip install -U transformers
! pip install -U accelerate
! pip install evaluate
! pip install wandb
! pip install --upgrade accelerate

In [None]:
import os

import wandb

# SECURITY: the W&B API key used to be pasted into this cell in plain text.
# A key that has been shared that way should be treated as leaked and revoked
# from the W&B account settings.
# The key is now read from the WANDB_API_KEY environment variable (populate it
# from your platform's secret store). When the variable is unset, `key=None`
# lets wandb fall back to its own credential lookup / interactive prompt.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Project under which W&B runs started from this notebook are grouped.
os.environ["WANDB_PROJECT"] = "Seq2SeqZip"

In [None]:
from typing import Dict, List, Tuple
from dataclasses import dataclass
from tqdm import tqdm

import Levenshtein
import numpy as np
import torch
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
import evaluate
from torch.utils.data import DataLoader

from transformers import BartTokenizer, BartForConditionalGeneration, Seq2SeqTrainingArguments
import pytorch_lightning as pl


# Single seed shared by PyTorch's RNG here and by the dataset splits below.
SEED = 999
torch.manual_seed(SEED)

# Mirror the seed onto the CUDA generators, but only when a GPU is present.
if torch.cuda.is_available():
    for seed_cuda_rng in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda_rng(SEED)

In [None]:
# Number of leading CSV rows used for this experiment.
MAX_ROWS = 8000

# `nrows` stops parsing after MAX_ROWS records instead of reading the whole
# file and slicing it afterwards. The result is also a standalone frame rather
# than a slice of a larger one, so assigning columns to it later cannot trigger
# pandas' SettingWithCopyWarning.
df = pd.read_csv('/kaggle/input/hexadecimalzip/randomized_shorthex2hex.csv', nrows=MAX_ROWS)
print(df.head())

In [None]:
# Terminate every sequence with an explicit end-of-sequence marker.
# NOTE(review): BartTokenizer normally appends "</s>" on its own when encoding,
# so this explicit suffix probably produces a doubled EOS token -- confirm that
# this is intended.
EOS_MARKER = "</s>"
for column in ("deflate_hex", "text_hex", "text"):
    # The endswith guard keeps this cell idempotent: re-running it no longer
    # stacks another "</s>" onto rows that were already processed.
    df[column] = [
        elem if elem.endswith(EOS_MARKER) else elem + EOS_MARKER
        for elem in df[column]
    ]

# 80% train; the held-out 20% is then halved into validation and test.
ds = Dataset.from_pandas(df)
ds_train_test = ds.train_test_split(test_size=0.2, seed=SEED)
ds_test_dev = ds_train_test['test'].train_test_split(test_size=0.5, seed=SEED)
ds_splits = DatasetDict({
    'train': ds_train_test['train'],
    'valid': ds_test_dev['train'],
    'test': ds_test_dev['test']
})

print(ds_splits)

In [None]:
@dataclass
class DataCollatorSeq2SeqWithPadding:
    """Turn a list of dataset rows into one padded seq2seq training batch.

    Every row must provide a ``"text_hex"`` entry (the model input) and a
    ``"deflate_hex"`` entry (the target sequence).
    """

    # Tokenizer used for both the input and the target sequences.
    tokenizer: BartTokenizer

    def __call__(self, dataset_elements) -> Dict[str, torch.Tensor]:
        """Tokenize, pad and pack ``dataset_elements``.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` for the inputs and
            ``labels`` holding the target ids, where padding positions are
            replaced by -100.
        """

        def pluck(key):
            # Gather one field from every row, preserving row order.
            return [row[key] for row in dataset_elements]

        sources = pluck("text_hex")
        targets = pluck("deflate_hex")

        # Same tokenizer options for both sides: tensors out, pad to the
        # longest sequence of the batch, truncate over-long sequences.
        tokenize_options = dict(return_tensors="pt", padding=True, truncation=True)

        encoded_sources = self.tokenizer(sources, **tokenize_options)
        # Only the token ids of the targets are needed.
        label_ids = self.tokenizer(targets, **tokenize_options)["input_ids"]

        # -100 is the ignore index of the cross-entropy loss, so padded target
        # positions do not contribute to it.
        is_padding = label_ids == self.tokenizer.pad_token_id
        label_ids = label_ids.masked_fill(is_padding, -100)

        # These three entries are the keyword arguments fed to the forward pass.
        return {
            "input_ids": encoded_sources["input_ids"],
            "attention_mask": encoded_sources["attention_mask"],
            "labels": label_ids,
        }

In [None]:
# Tokenizer and seq2seq weights are loaded from the same pretrained checkpoint.
CHECKPOINT = "facebook/bart-large"
tokenizer = BartTokenizer.from_pretrained(CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(CHECKPOINT)

In [None]:
# Batch builder shared by the train / validation / test DataLoaders.
data_collator = DataCollatorSeq2SeqWithPadding(tokenizer=tokenizer)

In [None]:
from nltk.metrics.distance import edit_distance

def compute_metrics(preds, labels):
    """Score generated token sequences against their reference sequences.

    Both inputs are decoded with the module-level ``tokenizer`` (special tokens
    skipped) and compared pairwise with ``edit_distance``.

    Args:
        preds: sequence of token-id sequences produced by the model.
        labels: sequence of reference token-id sequences; entries equal to
            -100 (the loss ignore index) are treated as padding. The caller's
            object is NOT modified.

    Returns:
        dict with
            ``average_edit_distance``: mean edit distance over all pairs
                (NaN when there are no pairs), and
            ``count_unzippable``: number of pairs whose distance is exactly 0.

    Raises:
        AssertionError: if ``preds`` and ``labels`` decode to a different
            number of strings.
    """
    pad_id = tokenizer.pad_token_id

    # -100 cannot be decoded, so swap it back to the pad id. New lists are
    # built instead of overwriting `labels` in place, so the caller's data
    # (e.g. buffers kept on the Lightning module) is left untouched.
    restored_labels = [
        [pad_id if token == -100 else token for token in sequence]
        for sequence in labels
    ]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(restored_labels, skip_special_tokens=True)

    assert len(decoded_preds) == len(decoded_labels)

    distances = [
        edit_distance(pred_text, label_text)
        for pred_text, label_text in zip(decoded_preds, decoded_labels)
    ]
    count_unzippable = sum(1 for distance in distances if distance == 0)

    # np.mean([]) yields NaN but also emits a RuntimeWarning; return NaN directly.
    avg_distance = np.mean(distances) if distances else float("nan")

    return {"average_edit_distance": avg_distance, "count_unzippable": count_unzippable}

In [None]:
import numpy as np

class Seq2Seq(pl.LightningModule):
    """LightningModule wrapping a seq2seq model for training and evaluation.

    Training and testing use the loss returned by the wrapped model;
    validation generates sequences and scores them with the injected
    ``compute_metrics`` callable.

    Args:
        tokenizer: tokenizer associated with ``model`` (stored for reference).
        model: the seq2seq model; must return an object with a ``loss``
            attribute when called with labels and must support ``generate``.
        data_collator: collator used to build batches (stored for reference).
        training_args: object exposing ``learning_rate`` and ``weight_decay``.
        compute_metrics: callable ``(preds, labels) -> dict`` returning at
            least ``average_edit_distance`` and ``count_unzippable``.
    """

    def __init__(self, tokenizer, model, data_collator, training_args, compute_metrics):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.data_collator = data_collator
        self.training_args = training_args
        self.compute_metrics = compute_metrics
        # Per-step training losses of the current epoch.
        self.epoch_loss = []
        # Reference / generated token ids gathered during one validation run.
        self.valid_labels = []
        self.valid_predictions = []

    def forward(self, input_ids, attention_mask, labels=None):
        """Delegate to the wrapped model; the output carries the loss when ``labels`` is given."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        outputs = self(**batch)
        loss = outputs.loss
        self.log('train_loss', loss)
        self.epoch_loss.append(loss.item())

        return loss

    def on_train_epoch_end(self):
        """Print the mean training loss of the epoch and reset the buffer."""
        print(f"Loss = {np.mean(self.epoch_loss)}")
        self.epoch_loss = []

    def validation_step(self, batch, batch_idx):
        """Generate predictions for one batch and buffer them with their labels."""
        # Only the encoder inputs are handed to generate(); `labels` is a
        # training-time argument and has no role in generation.
        generated = self.model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_new_tokens=256,
        )

        self.valid_predictions.extend(gen.cpu().tolist() for gen in generated)
        self.valid_labels.extend(label.cpu().tolist() for label in batch["labels"])

    def on_validation_epoch_end(self):
        """Score the buffered predictions, print the metrics, reset the buffers."""
        # Use the metric function injected through __init__ rather than a
        # module-level global of the same name.
        res = self.compute_metrics(self.valid_predictions, self.valid_labels)
        print(f"Validation Edit distance = {res['average_edit_distance']}")
        print(f"Validation Unzippable = {res['count_unzippable']}")

        # Start the next validation run from empty buffers; otherwise each
        # run's metrics would also include samples from all earlier runs.
        self.valid_predictions = []
        self.valid_labels = []

    def test_step(self, batch, batch_idx):
        """Compute, log and return the loss for one test batch."""
        outputs = self(**batch)
        loss = outputs.loss
        self.log('test_loss', loss)
        return loss

    def configure_optimizers(self):
        """Build an AdamW optimizer from the learning rate / weight decay in ``training_args``."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.training_args.learning_rate,
            weight_decay=self.training_args.weight_decay
        )
        return optimizer

In [None]:
# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases. This notebook installs transformers unpinned, so use whichever
# spelling the installed version declares instead of hard-coding one.
eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in Seq2SeqTrainingArguments.__dataclass_fields__
    else "evaluation_strategy"
)

# Hyper-parameter container. Only some of these fields are read below
# (batch sizes, accumulation steps, max_steps, eval_steps, and learning rate /
# weight decay inside Seq2Seq.configure_optimizers).
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    weight_decay=2e-4,
    warmup_steps=500,
    num_train_epochs=4,
    fp16=True,
    per_device_eval_batch_size=16,
    generation_max_length=256,
    eval_steps=400,
    logging_steps=400,
    remove_unused_columns=False,
    label_names=["labels"],
    predict_with_generate=True,
    save_strategy="no",
    **{eval_strategy_kwarg: "epoch"},
)

# NOTE(review): max_epochs (10) disagrees with training_args.num_train_epochs (4);
# confirm which one is intended.
# NOTE(review): pl.callbacks.ProgressBar may be the abstract base class in the
# installed Lightning version (TQDMProgressBar being the concrete one) -- confirm.
trainer = pl.Trainer(
    precision="16-mixed",
    max_epochs = 10, 
    max_steps=training_args.max_steps,
    accumulate_grad_batches=training_args.gradient_accumulation_steps,
    val_check_interval=training_args.eval_steps,
    logger=False,
    callbacks=[pl.callbacks.ProgressBar()],
)

# Reshuffle the training samples every epoch; evaluation loaders keep a fixed order.
train_dataloader = DataLoader(
    ds_splits["train"],
    batch_size=training_args.per_device_train_batch_size,
    shuffle=True,
    collate_fn=data_collator,
    num_workers=4,
)
# The validation loader is sized by the *eval* batch size setting.
valid_dataloader = DataLoader(
    ds_splits["valid"],
    batch_size=training_args.per_device_eval_batch_size,
    collate_fn=data_collator,
    num_workers=4,
)

seq2seq_model = Seq2Seq(tokenizer, model, data_collator, training_args, compute_metrics)
trainer.fit(seq2seq_model, train_dataloader, valid_dataloader)
trainer.test(seq2seq_model, DataLoader(ds_splits["test"], batch_size=8, collate_fn=data_collator, num_workers=4))