In [21]:
%%capture
# Install the third-party packages this notebook imports; the %%capture cell magic above silences pip's output.
!pip install datasets transformers accelerate evaluate wandb nltk pandas pytorch_lightning

In [22]:
import os
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from datasets import Dataset, DatasetDict
from nltk.metrics.distance import edit_distance
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BartTokenizer, BartForConditionalGeneration, T5ForConditionalGeneration, AutoTokenizer

# Set up wandb environment
os.environ["WANDB_PROJECT"] = "Seq2SeqZip"

# Seed for reproducibility
SEED = 999
# Seed NumPy's global RNG as well, so any NumPy-based randomness is repeatable
# alongside the torch generators seeded below.
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


In [23]:
# Function to load and preprocess dataset
def load_and_preprocess_data(filepath, max_rows=15000, seed=None):
    """Load a CSV of hex strings and split it into train/valid/test datasets.

    Args:
        filepath: Path of the CSV file. It must contain the columns
            ``deflate_hex``, ``text_hex`` and ``text``.
        max_rows: Maximum number of rows to read from the top of the file.
            ``None`` reads the whole file. Defaults to 15000.
        seed: Seed used for both shuffled splits. ``None`` (the default) uses
            the module-level ``SEED``.

    Returns:
        A ``DatasetDict`` with keys ``'train'`` (80% of the rows), ``'valid'``
        and ``'test'`` (each half of the remaining 20%).
    """
    if seed is None:
        seed = SEED

    # nrows stops parsing after max_rows instead of loading the full file and
    # slicing it afterwards; it also yields a fresh frame, so the column
    # assignment below does not write into a slice of another DataFrame.
    df = pd.read_csv(filepath, nrows=max_rows)

    # Append the literal "</s>" marker to every string column.
    text_columns = ['deflate_hex', 'text_hex', 'text']
    df[text_columns] = df[text_columns] + "</s>"

    ds = Dataset.from_pandas(df)
    # First carve off 20% as a hold-out, then halve the hold-out into valid/test.
    train_vs_holdout = ds.train_test_split(test_size=0.2, seed=seed)
    valid_vs_test = train_vs_holdout['test'].train_test_split(test_size=0.5, seed=seed)
    return DatasetDict({
        'train': train_vs_holdout['train'],
        'valid': valid_vs_test['train'],
        'test': valid_vs_test['test']
    })

# Build the train/valid/test splits; the absolute /kaggle/input path presumably only exists inside a Kaggle runtime.
ds_splits = load_and_preprocess_data('/kaggle/input/full-dataset/randomized_shorthex2hex.csv')

In [24]:
@dataclass
class DataCollatorSeq2SeqWithPadding:
    """Collate raw dataset rows into a padded seq2seq batch.

    Each row must provide ``"text_hex"`` (model input) and ``"deflate_hex"``
    (target). Both sides are tokenized with ``tokenizer``, padded to the
    longest sequence in the batch and truncated to the length cap.

    Attributes are documented inline below.
    """

    # Callable tokenizer exposing ``pad_token_id``.
    tokenizer: AutoTokenizer
    # Token cap applied to inputs and labels. ``None`` defers to the
    # module-level MAX_SEQ_LEN, looked up when the collator is called.
    max_length: Optional[int] = None

    def __call__(self, dataset_elements):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` tensors for one batch."""
        # Resolve the cap lazily so the global is only needed when no explicit
        # max_length was supplied.
        limit = MAX_SEQ_LEN if self.max_length is None else self.max_length

        sources = [row["text_hex"] for row in dataset_elements]
        targets = [row["deflate_hex"] for row in dataset_elements]

        encoded_sources = self.tokenizer(
            sources, return_tensors="pt", padding=True, truncation=True, max_length=limit
        )
        label_ids = self.tokenizer(
            targets, return_tensors="pt", padding=True, truncation=True, max_length=limit
        )["input_ids"]

        # Replace padding positions in the labels with -100.
        label_ids = label_ids.masked_fill(label_ids == self.tokenizer.pad_token_id, -100)

        return {
            "input_ids": encoded_sources["input_ids"],
            "attention_mask": encoded_sources["attention_mask"],
            "labels": label_ids,
        }


In [25]:
#MODEL CHOICES: bart-base, bart-large, t5-base
MODEL = "t5-base"

# Maps each supported MODEL value to (tokenizer class, model class, checkpoint name).
_MODEL_REGISTRY = {
    "bart-base": (BartTokenizer, BartForConditionalGeneration, "facebook/bart-base"),
    "bart-large": (BartTokenizer, BartForConditionalGeneration, "facebook/bart-large"),
    "t5-base": (AutoTokenizer, T5ForConditionalGeneration, "t5-base"),
}

# Fail loudly on a typo instead of silently falling back to t5-base.
if MODEL not in _MODEL_REGISTRY:
    raise ValueError(
        f"Unknown MODEL {MODEL!r}; expected one of {sorted(_MODEL_REGISTRY)}"
    )

_tokenizer_cls, _model_cls, _checkpoint = _MODEL_REGISTRY[MODEL]
tokenizer = _tokenizer_cls.from_pretrained(_checkpoint)
model = _model_cls.from_pretrained(_checkpoint)

data_collator = DataCollatorSeq2SeqWithPadding(tokenizer = tokenizer)

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [26]:
def compute_metrics(preds, labels, tokenizer):
    """Score predicted token ids against label token ids by edit distance.

    Args:
        preds: Predicted token ids, either a tensor or nested lists
            (one sequence per example).
        labels: Reference token ids, either a tensor or nested lists.
            Entries equal to -100 are replaced with the tokenizer's pad id
            before decoding.
        tokenizer: Object providing ``pad_token_id`` and
            ``batch_decode(ids, skip_special_tokens=...)``.

    Returns:
        dict with
        ``"average_edit_distance"``: mean character-level edit distance between
        decoded predictions and decoded labels (NaN for an empty batch), and
        ``"count_unzippable"``: number of examples whose edit distance is 0.
    """
    # Normalise labels to a tensor *before* masking, so list inputs work too;
    # the original masked first and only then checked for tensor-ness.
    labels = torch.as_tensor(labels)
    # Ensure labels with -100 are replaced by pad_token_id
    labels = labels.masked_fill(labels == -100, tokenizer.pad_token_id)

    # Decoding works on plain Python lists, so move everything off-device.
    label_ids = labels.detach().cpu().tolist()
    pred_ids = preds.detach().cpu().tolist() if torch.is_tensor(preds) else preds

    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    distances = [edit_distance(p, l) for p, l in zip(decoded_preds, decoded_labels)]
    if not distances:
        # np.mean([]) would also yield NaN, but with a RuntimeWarning.
        return {"average_edit_distance": float("nan"), "count_unzippable": 0}

    avg_distance = np.mean(distances)
    count_unzippable = distances.count(0)

    return {"average_edit_distance": avg_distance, "count_unzippable": count_unzippable}


In [None]:
# Training hyperparameters, read as module-level globals by the code in this notebook.
BATCH_SIZE = 16  # batch_size of the train/valid/test DataLoaders
MAX_EPOCHS = 5  # passed to pl.Trainer(max_epochs=...)
LEARNING_RATE = 1e-4  # AdamW lr in Seq2Seq.configure_optimizers
WEIGHT_DECAY = 1e-2  # AdamW weight_decay in Seq2Seq.configure_optimizers
MAX_SEQ_LEN = 256  # tokenizer max_length read by DataCollatorSeq2SeqWithPadding at call time

class Seq2Seq(pl.LightningModule):
    """LightningModule wrapper around a seq2seq model that returns ``loss`` and ``logits``.

    Training logs ``train_loss``. Validation and test log ``val_loss`` /
    ``test_loss`` plus the metrics returned by ``compute_metrics``.

    Args:
        tokenizer: Tokenizer handed to ``compute_metrics`` for decoding.
        model: The wrapped model; called with ``input_ids``,
            ``attention_mask`` and ``labels``.
        data_collator: Stored on the module; not used by the module itself.
    """

    # How each metric is aggregated over an epoch. Counts must be summed;
    # with Lightning's default ("mean") the logged value would be the average
    # count per batch rather than the total number of exact matches.
    _METRIC_REDUCTIONS = {"count_unzippable": "sum"}

    def __init__(self, tokenizer, model, data_collator):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.data_collator = data_collator

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the wrapped model and return its output object."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        outputs = self.forward(**batch)
        self.log('train_loss', outputs.loss, prog_bar=True, logger=True)
        return outputs.loss

    def _shared_eval_step(self, batch, loss_name):
        """Log loss and edit-distance metrics for one evaluation batch; return the loss."""
        outputs = self.forward(**batch)
        self.log(loss_name, outputs.loss, prog_bar=True, logger=True)

        # Predictions are the per-position argmax of the logits from the same
        # forward pass that received the labels (no generate() call here).
        preds = torch.argmax(outputs.logits, dim=-1)
        metrics = compute_metrics(preds, batch['labels'], self.tokenizer)
        for key, value in metrics.items():
            self.log(
                f'{key}',
                value,
                prog_bar=True,
                logger=True,
                reduce_fx=self._METRIC_REDUCTIONS.get(key, "mean"),
            )

        return outputs.loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch, logging the loss as ``val_loss``."""
        return self._shared_eval_step(batch, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch, logging the loss as ``test_loss``."""
        return self._shared_eval_step(batch, 'test_loss')

    def configure_optimizers(self):
        """Return AdamW over all parameters using the global LEARNING_RATE and WEIGHT_DECAY."""
        # Directly use learning rate and weight decay values here
        return AdamW(self.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Lightning trainer: mixed 16-bit precision, fixed epoch budget, progress bar on.
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    enable_progress_bar=True,
    precision='16-mixed',
)


def _build_loader(split_name, shuffle):
    """Create a DataLoader over one split of ``ds_splits`` with the shared batch settings."""
    return DataLoader(
        ds_splits[split_name],
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=data_collator,
        num_workers=3,
    )


# Only the training split is shuffled; evaluation order stays fixed.
train_dataloader = _build_loader("train", shuffle=True)
valid_dataloader = _build_loader("valid", shuffle=False)
test_dataloader = _build_loader("test", shuffle=False)

# Fit on train/valid, then evaluate the same module on the held-out test split.
seq2seq_model = Seq2Seq(tokenizer, model, data_collator)
trainer.fit(seq2seq_model, train_dataloader, valid_dataloader)

trainer.test(seq2seq_model, test_dataloader)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]