# Imports and installation


In [1]:
%%capture
!pip install lightning datasets

In [2]:
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
import torch.nn as nn
import lightning as L
import random

SEED = 999
BATCH_SIZE = 64
torch.manual_seed(SEED)
L.seed_everything(SEED)

INFO: Seed set to 999


999

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Data Preparation

In [4]:
df = pd.read_csv('/kaggle/input/shorthex2hex/shorthex2hex.csv')

In [5]:
df.head()

Unnamed: 0,text,text_hex,deflate_hex
0,One of the other,4f6e65206f6620746865206f74686572,789cf3cf4b55c84f5328c9005240a208002eb405bb
1,A wonderful little production.,4120776f6e64657266756c206c6974746c652070726f64...,789c735428cfcf4b492d4a2bcd51c8c92c29c949552828...
2,I thought this was,492074686f75676874207468697320776173,789cf35428c9c82f4dcf2801d299c50ae589c5003dea06b0
3,Basically there's a family,4261736963616c6c79207468657265277320612066616d...,789c734a2cce4c4eccc9a95428c9482d4a552f56485448...
4,"Petter Mattei's ""Love in",506574746572204d6174746569277320224c6f766520696e,789c0b482d29492d52f04d045299eac50a4a3ef965a90a...


Instead of using the standard \<EOS> and \<SOS> tags we're using the letter S and E since they are not present in the vocabulary

In [6]:
df['text_hex'] = 'S' + df['text_hex'] + 'E'
df['deflate_hex'] = 'S' + df['deflate_hex'] + 'E'

In [7]:
ds = Dataset.from_pandas(df)
ds_train_test = ds.train_test_split(test_size=0.2, seed=SEED)
ds_test_dev = ds_train_test['test'].train_test_split(test_size=0.5, seed=SEED)
ds_splits = DatasetDict({
    'train': ds_train_test['train'],
    'valid': ds_test_dev['train'],
    'test': ds_test_dev['test']
})

ds_splits

DatasetDict({
    train: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 40000
    })
    valid: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 5000
    })
})

# Tokenizzare in caratteri singoli o in sequenze di caratteri?

In [8]:
ds_splits['train'][0]

{'text': 'First of all, this',
 'text_hex': 'S4669727374206f6620616c6c2c2074686973E',
 'deflate_hex': 'S789c73cb2c2a2e51c84f5348ccc9d15128c9c82c06003c54065bE'}

## Data tokenization

In [9]:
token2id = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "a": 10, "b": 11, "c": 12, "d": 13, "e": 14, "f": 15, "P":16, "S": 17, "E":18 }

In [10]:
def create_id2token_vocab(token_to_id):
    id2token = {}
    for token, id in token_to_id.items():
        id2token[id] = token

    return id2token

id2token = create_id2token_vocab(token2id)
id2token

{0: '0',
 1: '1',
 2: '2',
 3: '3',
 4: '4',
 5: '5',
 6: '6',
 7: '7',
 8: '8',
 9: '9',
 10: 'a',
 11: 'b',
 12: 'c',
 13: 'd',
 14: 'e',
 15: 'f',
 16: 'P',
 17: 'S',
 18: 'E'}

In [11]:
# Corrected collate_fn function
def collate_fn(batch):
    def pad_sequences(sequences, maxlen, value=token2id['P']):
        return [sequence + [value] * (maxlen - len(sequence)) for sequence in sequences]

    texts = [elem['text_hex'] for elem in batch]
    encoded_hex = [[token2id[char] for char in text] for text in texts]

    outputs = [elem['deflate_hex'] for elem in batch]
    encoded_outputs = [[token2id[char] for char in output] for output in outputs]

    maxlen = max(MAX_SEQ_LEN, max(max(len(seq) for seq in encoded_hex), max(len(seq) for seq in encoded_outputs)))

    padded_encoded_hex = pad_sequences(encoded_hex, maxlen)
    padded_encoded_outputs = pad_sequences(encoded_outputs, maxlen)

    return {
        'inputs': torch.tensor(padded_encoded_hex),
        'outputs': torch.tensor(padded_encoded_outputs)
    }

MAX_SEQ_LEN = 256

# Model

In [12]:
# Encoder with Bidirectional GRU
class Encoder(nn.Module):
    def __init__(self, vocab_len, embedding_dim, hidden_dim, num_layers, bidirectional, dropout):
        super().__init__()

        self.embedding = nn.Embedding(vocab_len, embedding_dim, padding_idx=token2id['P'])
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, batch):
        embedded = self.dropout(self.embedding(batch))
        outputs, hidden = self.gru(embedded)
        return hidden

# Decoder adjusted for bidirectional input
class Decoder(nn.Module):
    def __init__(self, vocab_len, embedding_dim, hidden_dim, output_dim, num_layers, bidirectional, dropout):
        super().__init__()

        self.hidden_dim = hidden_dim * 2 if bidirectional else hidden_dim  # Adjusted for bidirectional
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        self.embedding = nn.Embedding(vocab_len, embedding_dim, padding_idx=token2id['P'])
        self.gru = nn.GRU(embedding_dim, self.hidden_dim, num_layers=num_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(self.hidden_dim, output_dim)

    def forward(self, x, hidden):
        x = x.unsqueeze(0)
        embedded = self.dropout(self.embedding(x))
        output, hidden = self.gru(embedded, hidden)
        logits = self.linear(output.squeeze(0))
        return logits, hidden

# Training

In [None]:
import pytorch_lightning as pl
import torch
import random
from nltk.metrics.distance import edit_distance

def decode_sequence(sequence_ids, id2token):
    decoded_sequence = []
    for id in sequence_ids:
        if id == token2id['E']:  # Check for ending token
            break  # Stop decoding once padding token is encountered
        decoded_sequence.append(id2token.get(id, ''))
    return ''.join(decoded_sequence)

EPOCHS = 3
LR = 5e-3
EMBEDDING_DIM = 256
HIDDEN_DIM = 1024
NUM_LAYERS = 4
DROPOUT = 0.3
BIDIRECTIONAL = True


# Combined EncoderDecoder model
class EncoderDecoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder(len(token2id), EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS, BIDIRECTIONAL, DROPOUT)
        self.decoder = Decoder(len(token2id), EMBEDDING_DIM, HIDDEN_DIM, len(token2id), NUM_LAYERS, BIDIRECTIONAL, DROPOUT)
        self.criterion = nn.CrossEntropyLoss(ignore_index=token2id['P'])

    def forward(self, source, target, teacher_forcing_ratio=0.5, training=False):
        target_len = target.shape[0]
        batch_size = target.shape[1]
        target_vocab_size = len(token2id)

        outputs = torch.zeros(target_len, batch_size, target_vocab_size).to(source.device)
        encoder_hidden = self.encoder(source)

        # Adjust for bidirectional
        if BIDIRECTIONAL:
            h = torch.cat((encoder_hidden[-2,:,:], encoder_hidden[-1,:,:]), dim = 1)
        else:
            h = encoder_hidden[-1]

        h = h.unsqueeze(0).repeat(NUM_LAYERS, 1, 1)  # Repeat hidden state

        x = target[0]

        for t in range(1, target_len):
            output, h = self.decoder(x, h)
            outputs[t] = output
            if training:
                # Teacher forcing: next input is current target
                x = target[t] if random.random() < teacher_forcing_ratio else output.argmax(1)
            else:
                # Inference mode: next input is current output
                x = output.argmax(1)

        return outputs

    def training_step(self, batch, batch_idx):
        inputs, targets = batch['inputs'], batch['outputs']
        outputs = self(inputs, targets, training=True)
        output_dim = outputs.shape[-1]
        outputs = outputs[1:].view(-1, output_dim)
        targets = targets[1:].view(-1)
        loss = self.criterion(outputs, targets)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def process_batch(self, batch):
        inputs, targets = batch['inputs'], batch['outputs']
        outputs = self(inputs, targets)  # Assuming outputs are logits
        output_dim = outputs.shape[-1]

        # Convert outputs to predicted indices
        _, predicted_indices = torch.max(outputs, dim=2)

        edit_distances = []
        
        for idx in range(inputs.size(0)):  # Iterate over each example in the batch
            predicted_seq = decode_sequence(predicted_indices[idx, :].tolist(), id2token)
            target_seq = decode_sequence(targets[idx, :].tolist(), id2token)

            # Calculate edit distance for the current example
            edit_dist = edit_distance(predicted_seq, target_seq)
            edit_distances.append(edit_dist)

            # Optionally, print/log the sequences and their edit distance
            print(f'Predicted: {predicted_seq}, Target: {target_seq}, Edit Distance: {edit_dist}')

        # Return the average edit distance for the batch
        return sum(edit_distances) / len(edit_distances)

    def validation_step(self, batch, batch_idx):
        avg_edit_distance = self.process_batch(batch)
        self.log('val_avg_edit_distance', avg_edit_distance, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        avg_edit_distance = self.process_batch(batch)
        self.log('test_avg_edit_distance', avg_edit_distance, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

class HexDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, test_dataset):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers = 2, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=BATCH_SIZE, num_workers = 2, collate_fn=collate_fn)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

# Initialize the data module and model
data_module = HexDataModule(ds_splits['train'], ds_splits['valid'], ds_splits['test'])
model = EncoderDecoder()

# Train the model
trainer = pl.Trainer(max_epochs=EPOCHS)
trainer.fit(model, datamodule=data_module)

trainer.test(datamodule=data_module)


2024-02-03 19:46:29.432547: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-03 19:46:29.432625: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-03 19:46:29.434105: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Predicted: 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, Target: S789cf354c8c9cc4e4d5108c92f50702fcdd303002c8e053e, Edit Distance: 249
Predicted: 14PPcP4P0b1P0b0P8P1b1P0P8P1P1P4P9P2P428P4P080b0808080b0008209000000090090900000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, Target: S789c0bcf48cd53f05448cb2c2a2e51c8484d2c4a01003ad80668, Edit Distance: 241
Predicted: 12P28P, Target: S789cf354c8484c51c8cb2fc94c4e4d5128c9c82c060039e10675, Edit Distance: 49
Predicted: 128282, Target: S789c0bc9c82c5628ce4dccc9d151282ccd4c2dd151c8482ccacdcfcbcc2f2d06009d7f0ad7, Edit Distance: 69
Predicted: 128282282888988282121228829812229222228229829299829299992229999999999

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Predicted: 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, Target: S789cf354c8c9cc4e4d5108c92f50702fcdd303002c8e053e, Edit Distance: 249
Predicted: S789c982828282c282c2c08282c2c2c2c2c28000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, Target: S789c0bcf48cd53f05448cb2c2a2e51c8484d2c4a01003ad80668, Edit Distance: 239
Predicted: S789cf84848484c4848280848484828482828000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, Target: S789cf354c8484c51c8cb2fc9

Validation: |          | 0/? [00:00<?, ?it/s]

Predicted: 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, Target: S789cf354c8c9cc4e4d5108c92f50702fcdd303002c8e053e, Edit Distance: 249
Predicted: S78ccc8c8c8c8c8c8c8c8c8c8c8c8c8ccc8cccccc0ccc0c0c0c0c0c0c0c0c0c0c0c0c000c000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, Target: S789c0bcf48cd53f05448cb2c2a2e51c8484d2c4a01003ad80668, Edit Distance: 239
Predicted: S78ccc8c8c8c8c8c8c8c8c8c8c8c8c8cccccccccc0c0c0c0c0c0c0c0c0c0c0c0c0c0c000c000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000, Target: S789cf354c8484c51c8cb2fc9