In [1]:
%%capture
! pip install datasets
! pip install transformers -U
! pip install accelerate -U
! pip install evaluate
! pip install bleu
! pip install python-Levenshtein
! pip install wandb

In [4]:
from typing import Dict, List, Tuple
from dataclasses import dataclass
from tqdm import tqdm
import numpy as np
import torch
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict

from transformers import BartTokenizer, BartForConditionalGeneration

SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)



In [2]:
LOG = False

if LOG:
  import wandb
  wandb.login()

  import os
  os.environ["WANDB_PROJECT"] = "Seq2SeqZip"

# Dataset

In [5]:
df = pd.read_csv('/kaggle/input/hexadecimalzip/shorthex.csv')
df = df[:8000]
print(df.head())
df['deflate_hex'] = [elem + "</s>" for elem in df['deflate_hex']]      
df['text_hex'] = [elem + "</s>" for elem in df['text_hex']]

                             text  \
0                One of the other   
1  A wonderful little production.   
2              I thought this was   
3      Basically there's a family   
4        Petter Mattei's "Love in   

                                            text_hex  \
0                   4f6e65206f6620746865206f74686572   
1  4120776f6e64657266756c206c6974746c652070726f64...   
2               492074686f75676874207468697320776173   
3  4261736963616c6c79207468657265277320612066616d...   
4   506574746572204d6174746569277320224c6f766520696e   

                                         deflate_hex  
0         789cf3cf4b55c84f5328c9005240a208002eb405bb  
1  789c735428cfcf4b492d4a2bcd51c8c92c29c949552828...  
2   789cf35428c9c82f4dcf2801d299c50ae589c5003dea06b0  
3  789c734a2cce4c4eccc9a95428c9482d4a552f56485448...  
4  789c0b482d29492d52f04d045299eac50a4a3ef965a90a...  


In [6]:
ds = Dataset.from_pandas(df)
ds_train_test = ds.train_test_split(test_size=0.2, seed=SEED)
ds_test_dev = ds_train_test['test'].train_test_split(test_size=0.5, seed=SEED)
ds_splits = DatasetDict({
    'train': ds_train_test['train'],
    'valid': ds_test_dev['train'],
    'test': ds_test_dev['test']
})

print(ds_splits)

DatasetDict({
    train: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 6400
    })
    valid: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 800
    })
    test: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 800
    })
})


In [7]:
ds_splits['train'][0]

{'text': 'I just finished watching',
 'text_hex': '49206a7573742066696e6973686564207761746368696e67</s>',
 'deflate_hex': '789cf354c82a2d2e5148cbcccb2cce484d51284f2c49cec8cc4b07006d1d090f</s>'}

# Model

In [None]:
model_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

In [9]:
@dataclass
class DataCollatorSeq2SeqWithPadding:
    tokenizer: BartTokenizer

    def __call__(self, dataset_elements) -> Dict[str, torch.Tensor]:

        # collect the input and output sequences
        input_text = [de["text_hex"] for de in dataset_elements]
        output_text = [de["deflate_hex"] for de in dataset_elements]

        # tokenize both sequences in batch so that it will be much faster!
        input_features = self.tokenizer(
            input_text,
            return_tensors="pt",  # output directly tensors
            padding=True, # add the padding on each sequence if needed
            truncation=True # If the input sequence is too long, truncate it
        )

        output_features = self.tokenizer(
            output_text,
            return_tensors="pt",
            padding=True,
            truncation=True
        )["input_ids"]  # here we only need the input_ids (output actually)

        output_features[output_features==self.tokenizer.pad_token_id] = -100 # cross entropy ignore index

        # This is the only parameters we need for the forward pass
        # to understand why, take a look to the BartForConditionalGeneration.forward method signature.
        batch = {
            "input_ids": input_features["input_ids"],
            "attention_mask": input_features["attention_mask"],
            "labels": output_features,
        }

        return batch

In [10]:
data_collator = DataCollatorSeq2SeqWithPadding(tokenizer)

# Trainer

In [14]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="temp",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    warmup_steps=500,
    max_steps=10000,
    evaluation_strategy="steps",
    fp16=True,
    per_device_eval_batch_size=8,
    generation_max_length=250,
    eval_steps=1000,  # evaluate on the validation every "eval_steps"
    logging_steps=1000,  # log standard metrics each "logging_steps"
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
    predict_with_generate=True,
    save_strategy = "no"
)

In [15]:
## UNUSED FOR NOW
def compute_metrics(eval_preds):
    preds, labels = eval_preds

    # decode preds and labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    print(f"Decoded preds = {decoded_preds}\n\n")
    print(f"Decoded labels = {decoded_labels}")

    #result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return 0

In [16]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=ds_splits["train"],
    eval_dataset=ds_splits["valid"],
    data_collator=data_collator,
    #compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

In [17]:
trainer.train()



Step,Training Loss,Validation Loss
1000,2.598,1.212662
2000,1.1272,0.83191
3000,0.8313,0.713727
4000,0.7107,0.651779
5000,0.6357,0.614619
6000,0.5897,0.59769
7000,0.5561,0.583713
8000,0.5339,0.578531
9000,0.5164,0.571998
10000,0.5069,0.570023


TrainOutput(global_step=10000, training_loss=0.8605920043945312, metrics={'train_runtime': 3008.2586, 'train_samples_per_second': 53.187, 'train_steps_per_second': 3.324, 'total_flos': 3360143295774720.0, 'train_loss': 0.8605920043945312, 'epoch': 25.0})

## Save model if necessary

In [39]:
trainer.save_model("/kaggle/working/bart_model")

In [40]:
!zip -r bart_model.zip /kaggle/working/bart_model

  adding: kaggle/working/bart_model/ (stored 0%)
  adding: kaggle/working/bart_model/merges.txt (deflated 53%)
  adding: kaggle/working/bart_model/training_args.bin (deflated 49%)
  adding: kaggle/working/bart_model/vocab.json (deflated 68%)
  adding: kaggle/working/bart_model/tokenizer_config.json (deflated 76%)
  adding: kaggle/working/bart_model/special_tokens_map.json (deflated 85%)
  adding: kaggle/working/bart_model/generation_config.json (deflated 47%)
  adding: kaggle/working/bart_model/config.json (deflated 64%)
  adding: kaggle/working/bart_model/model.safetensors (deflated 8%)


# TEST

In [18]:
test_dataloader = torch.utils.data.DataLoader(ds_splits["test"], batch_size=8, collate_fn=data_collator)

In [21]:
gold_strings = []
predicted_strings = []

model.eval()
for step, batch in enumerate(tqdm(test_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.inference_mode():

            generated_tokens = (
                model.generate(
                    input_ids=batch["input_ids"].to("cuda"),
                    max_new_tokens=255,
                )
                .cpu()
                .numpy()
            )

            labels = batch["labels"].cpu().numpy()
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

            # turn subwords ids back into text
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            #print("Gold summary: ", decoded_labels)
            #print("Predicted summary: ", decoded_preds)
            
            gold_strings.extend(decoded_labels)
            predicted_strings.extend(decoded_preds)
            
            
    del generated_tokens, labels, batch

100%|██████████| 100/100 [01:31<00:00,  1.09it/s]


In [37]:
import nltk
from nltk.metrics.distance import edit_distance

assert len(predicted_strings) == len(gold_strings)

scores = []
pred_lenghts = []
gold_lenghts = []

for i in range(len(predicted_strings)):
    pred = predicted_strings[i]
    gold = gold_strings[i]
    scores.append(edit_distance(pred, gold))
    pred_lenghts.append(len(pred))
    gold_lenghts.append(len(gold))
    
print(f"Average prediction lenght is {np.mean(pred_lenghts)}")
print(f"Average gold lenght is {np.mean(gold_lenghts)}")
print(f"Average distance is {np.mean(scores)}")

Average prediction lenght is 55.55875
Average gold lenght is 55.3275
Average distance is 7.305
