In [1]:
%%capture
!pip install lightning datasets

In [2]:
import pandas as pd
import torch
from datasets import load_dataset, Dataset, DatasetDict
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple
from dataclasses import dataclass
from pprint import pprint
import torch.nn as nn
import torch.optim as optim
import numpy as np
import lightning as L
import random

SEED = 999
BATCH_SIZE = 32
torch.manual_seed(SEED)
L.seed_everything(SEED)

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
df = pd.read_csv('/kaggle/input/hexadecimalzip/randomized_shorthex2hex.csv')
df = df[:12000]
print(df.head())

                      text                                        text_hex  \
0            this is not a                      74686973206973206e6f742061   
1  and gives a comforting,  616e64206769766573206120636f6d666f7274696e672c   
2  killer). While some may  6b696c6c6572292e205768696c6520736f6d65206d6179   
3          in his closet &                  696e2068697320636c6f7365742026   
4       film to watch. Mr.            66696c6d20746f2077617463682e204d722e   

                                         deflate_hex  
0             789c2bc9c82c5600a2bcfc1285440021fe04a7  
1  789c4bcc4b5148cf2c4b2d56485448cecf4dcb2f2ac9cc...  
2  789ccbceccc9492dd2d45308cfc8cc495528cecf4d55c8...  
3     789ccbcc53c8c82c5648cec92f4e2d515003002b16052c  
4  789c4bcbccc95528c957284f2c49ced053f02dd203003d...  


In [None]:
df['text_hex'] = 'S' + df['text_hex'] + 'E'
df['deflate_hex'] = 'S' + df['deflate_hex'] + 'E'

In [None]:
ds = Dataset.from_pandas(df)
ds_train_test = ds.train_test_split(test_size=0.2, seed=SEED)
ds_test_dev = ds_train_test['test'].train_test_split(test_size=0.5, seed=SEED)
ds_splits = DatasetDict({
    'train': ds_train_test['train'],
    'valid': ds_test_dev['train'],
    'test': ds_test_dev['test']
})

ds_splits

DatasetDict({
    train: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 9600
    })
    valid: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 1200
    })
})

In [None]:
ds_splits['train'][0]

{'text': 'The D is the',
 'text_hex': 'S546865204420697320746865E',
 'deflate_hex': 'S789c0bc948557051c82c5628c9480500184c03e3E'}

In [None]:
token2id = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "a": 10, "b": 11, "c": 12, "d": 13, "e": 14, "f": 15, "P":16, "S": 17, "E":18 }

def create_id2token_vocab(token_to_id):
    id2token = {}
    for token, id in token_to_id.items():
        id2token[id] = token

    return id2token

id2token = create_id2token_vocab(token2id)
id2token

{0: '0',
 1: '1',
 2: '2',
 3: '3',
 4: '4',
 5: '5',
 6: '6',
 7: '7',
 8: '8',
 9: '9',
 10: 'a',
 11: 'b',
 12: 'c',
 13: 'd',
 14: 'e',
 15: 'f',
 16: 'P',
 17: 'S',
 18: 'E'}

In [None]:
def collate_fn(batch):

  def pad_sequences(sequences, maxlen, value=token2id['P']):
    padded_sequences = []
    for sequence in sequences:
        padded_sequence = sequence[:maxlen]
        padded_sequence.extend([value] * (maxlen - len(padded_sequence)))

        padded_sequence = sequence +  [value] * (maxlen - len(sequence))
        padded_sequences.append(padded_sequence)

    return padded_sequences


  texts = [elem['text_hex'] for elem in batch]
  encoded_hex = [[token2id[x] for x in hex] for hex in texts]


  outputs = [elem['deflate_hex'] for elem in batch]
  encoded_outputs = [[token2id[x] for x in hex] for hex in outputs]


  maxlen = 0
  for seq in encoded_hex:
    if len(seq) > maxlen:
      maxlen = len(seq)
  for seq in encoded_outputs:
    if len(seq) > maxlen:
      maxlen = len(seq)

  padded_encoded_hex = pad_sequences(encoded_hex, maxlen)
  padded_encoded_outputs = pad_sequences(encoded_outputs, maxlen)


  return {
      'inputs': torch.tensor(padded_encoded_hex),
      "outputs": torch.tensor(padded_encoded_outputs)
  }


In [None]:
import nltk
from nltk.metrics.distance import edit_distance

def decode_output(output):
    return ''.join([id2token[int(id)] for id in output])

def decode_input(input):
    return ''.join([id2token[int(id)] for id in input])

def evaluate(batch, _device, _print, _cycle, _training):
    model.eval()
    total_distance = 0
    total = 0
    
    distances_list = []
    
    #print(f"Batch inputs shape = {batch['inputs'].shape}")
    #print(f"Batch outputs shape = {batch['outputs'].shape}")

    x = batch["inputs"].transpose(0,1).to(_device)
    y = batch["outputs"].transpose(0,1).to(_device)

    y_hat = model(x, y)
    y_hat = torch.argmax(y_hat, dim=-1)
    
    y = y.transpose(0,1)
    y_hat = y_hat.transpose(0,1)
    
    assert len(y) == len(y_hat)
    
    for i in range(len(y)):
        output = decode_output(y[i])
        output_hat = decode_output(y_hat[i])

        output = [x for x in output if x != "P"]
        output_hat = [x for x in output_hat if x != "P"]

        first_eos_index = 0
        for i in range(len(output_hat)):
            if output_hat[i] == "E":
                first_eos_index = i
                break

        # REMOVE START OF SEQUENCE TOKEN
        output = output[1:]
        output_hat = output_hat[1:first_eos_index]
        distance = edit_distance(output, output_hat)
        distances_list.append(distance)

        if _print:
            print(f"output = {output}")
            print(f"output_hat = {output_hat}")

        total_distance += distance
        total += 1

        if distance == 0:
            print(f"DISTANCE = 0!")
            print(f"output = {output}")
            print(f"output_hat = {output_hat}")
            
    if _training:
        return total_distance/total
    
    return (total_distance/total, distances_list)

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import random

torch.set_printoptions(profile="full")

EPOCHS = 100
LR = 3e-4
EMBEDDING_DIM = 256
HIDDEN_DIM = 512
NUM_LAYERS = 4
DROPOUT = 0.4
BIDIRECTIONAL = False
MAX_SEQ_LEN = 256
BATCH_SIZE = 1024

# Global variable to choose the RNN type
RNN_TYPE = 'RNN'  # Options: 'LSTM', 'GRU', 'RNN'

class Seq2Seq(pl.LightningModule):
    def __init__(self, vocab_len, embedding_dim, hidden_dim, output_dim, num_layers, bidirectional, dropout):
        super(Seq2Seq, self).__init__()
        self.rnn_type = RNN_TYPE
        self.embedding = nn.Embedding(vocab_len, embedding_dim, padding_idx=token2id['P'])
        
        if self.rnn_type == 'LSTM':
            rnn_cell = nn.LSTM
        elif self.rnn_type == 'GRU':
            rnn_cell = nn.GRU
        else:  # Default to RNN if neither LSTM nor GRU is selected
            rnn_cell = nn.RNN
        
        self.encoder_rnn = rnn_cell(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout if num_layers > 1 else 0)
        self.decoder_rnn = rnn_cell(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout if num_layers > 1 else 0)
        
        self.dropout = nn.Dropout(dropout)
        self.output_dim = output_dim
        self.linear = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        target_len = target.shape[0]
        batch_size = target.shape[1]
        target_vocab_size = self.output_dim

        outputs = torch.zeros(target_len, batch_size, target_vocab_size).to(self.device)

        x = self.dropout(self.embedding(source))
        rnn_output, h = self.encoder_rnn(x)

        x = target[0]
        for t in range(1, target_len):
            x = self.dropout(self.embedding(x.unsqueeze(0)))
            out, h = self.decoder_rnn(x, h if self.rnn_type in ['LSTM', 'GRU'] else None)
            predictions = self.linear(out)
            predictions = predictions.squeeze(0)
            outputs[t] = predictions
            pred = predictions.argmax(1)
            x = target[t] if random.random() < teacher_forcing_ratio else pred

        return outputs

    def step(self, batch):
        inputs, targets = batch['inputs'], batch['outputs']
        inputs = inputs.transpose(0, 1)
        targets = targets.transpose(0, 1)

        output = self(inputs, targets)
        output_dim = output.shape[-1]

        output = output.reshape(-1, output_dim)
        targets = targets.reshape(-1)
        
        return (output, targets)
    
    def training_step(self, batch):
        loss = self.criterion(*self.step(batch))
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch):
        loss = self.criterion(*self.step(batch))
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

model = Seq2Seq(len(token2id), EMBEDDING_DIM, HIDDEN_DIM, len(token2id), NUM_LAYERS, BIDIRECTIONAL, DROPOUT)

train_dataloader = DataLoader(ds_splits['train'], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=2)
val_dataloader = DataLoader(ds_splits['valid'], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2)
test_dataloader = DataLoader(ds_splits['test'], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2)

In [None]:
# Train the model
trainer = pl.Trainer(max_epochs=EPOCHS)
trainer.fit(model, train_dataloader, val_dataloader)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

# TEST


In [None]:
from tqdm import tqdm

total_distances = []
for batch in tqdm(test_dataloader):
    _, distances_list = evaluate(batch = batch, _device = device, _print = False, _cycle = True, _training = False)
    total_distances.extend(distances_list)

print(f"Total sentences = {len(total_distances)}")
print(f"Average TEST distance = {np.mean(total_distances)}")
    


  0%|          | 0/38 [00:00<?, ?it/s][A
  3%|▎         | 1/38 [00:00<00:07,  4.74it/s][A
  5%|▌         | 2/38 [00:00<00:05,  6.03it/s][A
  8%|▊         | 3/38 [00:00<00:05,  6.96it/s][A
 11%|█         | 4/38 [00:00<00:05,  6.80it/s][A
 13%|█▎        | 5/38 [00:00<00:04,  6.72it/s][A
 16%|█▌        | 6/38 [00:00<00:04,  6.98it/s][A
 18%|█▊        | 7/38 [00:01<00:04,  7.10it/s][A
 21%|██        | 8/38 [00:01<00:04,  7.25it/s][A
 24%|██▎       | 9/38 [00:01<00:04,  7.22it/s][A
 26%|██▋       | 10/38 [00:01<00:03,  7.32it/s][A
 29%|██▉       | 11/38 [00:01<00:03,  7.32it/s][A
 32%|███▏      | 12/38 [00:01<00:03,  6.95it/s][A
 34%|███▍      | 13/38 [00:01<00:03,  7.17it/s][A
 37%|███▋      | 14/38 [00:02<00:03,  6.78it/s][A
 39%|███▉      | 15/38 [00:02<00:03,  6.94it/s][A
 42%|████▏     | 16/38 [00:02<00:03,  6.70it/s][A
 45%|████▍     | 17/38 [00:02<00:03,  6.99it/s][A
 47%|████▋     | 18/38 [00:02<00:02,  7.10it/s][A
 50%|█████     | 19/38 [00:02<00:02,  7.12it/s]

Total sentences = 1200
Average TEST distance = 59.33833333333333



