# Imports and installation


In [1]:
%%capture
!pip install lightning datasets

In [2]:
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader
import torch.nn as nn
import lightning as L
import random

# Global configuration: random seed and data/model sizes.
SEED = 124
BATCH_SIZE, HIDDEN_SIZE, NUM_LAYERS = 128, 512, 2

# Seed torch explicitly, then seed through Lightning; seed_everything returns
# the seed, which is what this cell displays as its output.
torch.manual_seed(SEED)
L.seed_everything(SEED)

INFO: Seed set to 124


124

In [3]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

# Data Preparation

In [5]:
# Read the CSV and keep only its first 40960 rows (positional slice).
df = pd.read_csv('/kaggle/input/shortnew/shorthex2hex.csv').iloc[:40960]

In [None]:
df.head()

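Instead of the standard `<EOS>` and `<SOS>` tags, we use the letters S and E, since they are not present in the vocabulary. (Note: the `token2id` vocabulary defined below currently contains only the hex digits and the padding token `P`, so this remark may be out of date.)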

In [7]:
# Wrap the DataFrame as a Dataset and split it into train / valid / test.
ds = Dataset.from_pandas(df)

# First cut: hold out 20% of the rows.
ds_train_test = ds.train_test_split(test_size=0.2, seed=SEED)

# Second cut: halve the held-out rows into validation and test parts.
ds_test_dev = ds_train_test['test'].train_test_split(test_size=0.5, seed=SEED)

ds_splits = DatasetDict(
    dict(
        train=ds_train_test['train'],
        valid=ds_test_dev['train'],
        test=ds_test_dev['test'],
    )
)

ds_splits

DatasetDict({
    train: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 32768
    })
    valid: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 4096
    })
    test: Dataset({
        features: ['text', 'text_hex', 'deflate_hex'],
        num_rows: 4096
    })
})

# Tokenize into single characters or into sequences of characters?

In [None]:
ds_splits['train'][0]

## Data tokenization

In [9]:
# Vocabulary: the 16 lowercase hex digits get ids 0-15 (their numeric value);
# "P" (id 16) is the token used for padding.
token2id = {char: index for index, char in enumerate("0123456789abcdef" + "P")}

In [10]:
def create_id2token_vocab(token_to_id):
    """Return the inverse of a token -> id mapping.

    Args:
        token_to_id: dict mapping tokens to ids.

    Returns:
        dict mapping each id back to its token. If several tokens share an
        id, the one that appears last in `token_to_id` wins.
    """
    return {index: token for token, index in token_to_id.items()}

# Reverse lookup (id -> token) built from token2id; the bare name on the
# last line shows the mapping as the cell output.
id2token = create_id2token_vocab(token2id)
id2token

{0: '0',
 1: '1',
 2: '2',
 3: '3',
 4: '4',
 5: '5',
 6: '6',
 7: '7',
 8: '8',
 9: '9',
 10: 'a',
 11: 'b',
 12: 'c',
 13: 'd',
 14: 'e',
 15: 'f',
 16: 'P'}

In [11]:
# Fixed length (in tokens) that every encoded sequence is padded to.
MAX_SEQ_LEN = 256


def collate_fn(batch, max_seq_len=None, vocab=None):
    """Encode and pad a batch of hex-string pairs into integer id tensors.

    Args:
        batch: sequence of dict-like rows, each providing the strings
            'text_hex' (model input) and 'deflate_hex' (target).
        max_seq_len: length every sequence is padded to. Defaults to
            MAX_SEQ_LEN.
        vocab: character -> id mapping that also contains the "P" padding
            token. Defaults to the module-level `token2id`.

    Returns:
        dict with keys 'inputs' and 'outputs', each a torch.long tensor of
        shape (len(batch), max_seq_len).

    Raises:
        KeyError: if a character is missing from the vocabulary.
        ValueError: if a sequence is longer than `max_seq_len` (it could not
            be padded to a common length).
    """
    if max_seq_len is None:
        max_seq_len = MAX_SEQ_LEN
    if vocab is None:
        vocab = token2id
    pad_id = vocab["P"]

    def encode_column(column):
        rows = []
        for elem in batch:
            ids = [vocab[char] for char in elem[column]]
            if len(ids) > max_seq_len:
                raise ValueError(
                    f"'{column}' sequence has {len(ids)} tokens, "
                    f"more than max_seq_len={max_seq_len}"
                )
            rows.append(ids + [pad_id] * (max_seq_len - len(ids)))
        # Build the integer tensor directly (no float round-trip); the reshape
        # keeps the result 2-D even when the batch is empty.
        return torch.tensor(rows, dtype=torch.long).reshape(len(rows), max_seq_len)

    return {
        'inputs': encode_column('text_hex'),
        'outputs': encode_column('deflate_hex'),
    }

In [12]:
# Shuffle only the training data. The validation loader keeps a fixed order so
# that validation metrics are computed on the same batches every epoch.
train_dataloader = DataLoader(ds_splits['train'], batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(ds_splits['valid'], batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# Print tensors in full (no "..." summarisation) from here on.
torch.set_printoptions(profile="full")

# Sanity check: show the shape of the first training batch only.
for batch in train_dataloader:
    print(batch['inputs'].shape)
    break

torch.Size([128, 256])


In [13]:
import nltk
from nltk.metrics.distance import edit_distance

def decode_output(output):
    """Map a sequence of token ids to their tokens (via `id2token`) and join them into one string."""
    tokens = (id2token[int(token_id)] for token_id in output)
    return ''.join(tokens)

def decode_input(input):
    """Decode a sequence of token ids into a string; inputs and outputs share one vocabulary, so this is the same operation as `decode_output`."""
    return decode_output(input)

def evaluate(_device, _print, max_samples=1):
    """Score the global `model` on `val_dataloader` with the edit distance.

    Targets and greedy (argmax) predictions are decoded, padding tokens "P"
    are stripped, and the edit distance between the two token lists is
    averaged over the scored samples.

    Args:
        _device: device the input batch is moved to before the forward pass.
        _print: when true, print the decoded target and prediction of the
            first scored sample.
        max_samples: positive number of validation samples to score. The
            default of 1 scores only the first sample of the first batch;
            pass None to score the whole validation loader.

    Returns:
        The mean edit distance over the scored samples as a float, or None
        if the loader yields no samples.
    """
    model.eval()
    total_distance = 0
    total = 0

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch in val_dataloader:
            inputs = batch["inputs"].to(_device)
            predictions = torch.argmax(model(inputs), dim=-1)

            # Convert whole tensors to Python lists once instead of reading
            # them element by element while decoding.
            target_rows = batch["outputs"].tolist()
            predicted_rows = predictions.tolist()

            for target_ids, predicted_ids in zip(target_rows, predicted_rows):
                output = [tok for tok in decode_output(target_ids) if tok != "P"]
                output_hat = [tok for tok in decode_output(predicted_ids) if tok != "P"]

                if _print and total == 0:
                    print(f"output = {output}")
                    print(f"output_hat = {output_hat}")

                total_distance += edit_distance(output, output_hat)
                total += 1

                if max_samples is not None and total >= max_samples:
                    return total_distance / total

    return total_distance / total if total else None

In [18]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

# Default hyperparameters for the FeedForward model defined below.
INPUT_DIM = 256                # number of rows of the nn.Embedding table
EMBED_DIM = 128                # embedding width (also the attention embed_dim)
HIDDEN_DIM = 512               # width of the first fully connected layer
OUTPUT_DIM = len(token2id)     # one logit per vocabulary entry
LEARNING_RATE = 0.001
DROPOUT_RATE = 0.5
NUM_HEADS = 4                  # heads of the self-attention layer
ATTENTION_USED = True          # forward() skips self-attention when False


class FeedForward(pl.LightningModule):
    """Per-position token classifier.

    Each input id is embedded, optionally passed through multi-head
    self-attention (when the module-level ATTENTION_USED flag is true), and
    then through three fully connected layers that produce `output_dim`
    logits for every position.

    Args:
        input_dim: number of rows of the embedding table.
        embed_dim: embedding width, also used as the attention embed_dim.
        hidden_dim: width of the first fully connected layer (the second one
            is hidden_dim // 2).
        output_dim: number of output classes per position.
        learning_rate: learning rate handed to the optimizer.
        dropout_rate: dropout used inside attention and after the first
            fully connected layer.
        optimizer_type: optimizer class, called as
            optimizer_type(params, lr=learning_rate).
        scheduler_type: scheduler class, called with step_size and gamma.
        scheduler_step_size: step_size passed to the scheduler.
        scheduler_gamma: gamma passed to the scheduler.
        num_heads: number of attention heads (defaults to NUM_HEADS).
    """

    def __init__(self, input_dim=INPUT_DIM, embed_dim = EMBED_DIM, hidden_dim=HIDDEN_DIM, output_dim=OUTPUT_DIM, learning_rate=LEARNING_RATE,
                 dropout_rate=DROPOUT_RATE, optimizer_type=AdamW, scheduler_type=StepLR,
                 scheduler_step_size=5, scheduler_gamma=0.1, num_heads=NUM_HEADS):
        super().__init__()
        self.save_hyperparameters()
        self.embed = nn.Embedding(input_dim, embed_dim)
        self.self_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout_rate)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim//2)
        self.norm2 = nn.LayerNorm(hidden_dim//2)
        self.fc3 = nn.Linear(hidden_dim//2, output_dim)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Map token ids [Batch, Seq_len] to logits [Batch, Seq_len, output_dim]."""
        # Embedding
        x = self.embed(x)  # Shape: [Batch, Seq_len, Embed_dim]

        if ATTENTION_USED:
            # nn.MultiheadAttention is built with its default batch_first=False,
            # so it expects the sequence dimension first.
            x = x.transpose(0, 1)  # Shape: [Seq_len, Batch, Embed_dim]

            # Self-attention
            attn_output, _ = self.self_attention(x, x, x)
            # Transpose back to batch-first for the following layers
            x = attn_output.transpose(0, 1)  # Shape: [Batch, Seq_len, Embed_dim]

        # Fully connected layers
        x = torch.relu(self.norm1(self.fc1(x)))
        x = self.dropout(x)
        x = torch.relu(self.norm2(self.fc2(x)))
        x = self.fc3(x)

        return x

    def configure_optimizers(self):
        """Create the optimizer and LR scheduler from the saved hyperparameters."""
        optimizer = self.hparams.optimizer_type(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = self.hparams.scheduler_type(optimizer, step_size=self.hparams.scheduler_step_size, gamma=self.hparams.scheduler_gamma)
        return [optimizer], [scheduler]

    def _loss_and_logits(self, batch):
        """Run one forward pass and return (cross-entropy loss, logits)."""
        x = batch["inputs"]
        y = batch["outputs"]
        logits = self(x)
        # Flatten batch and sequence dimensions: every position is one
        # classification example for CrossEntropyLoss.
        loss = self.loss(logits.reshape(-1, logits.shape[-1]), y.reshape(-1))
        return loss, logits

    def _first_sample_edit_distance(self, logits, targets):
        """Edit distance between target and argmax prediction of the batch's first sample.

        Padding tokens "P" are removed from both sequences first. Uses the
        notebook-level `decode_output` and nltk's `edit_distance`.
        """
        predicted_ids = torch.argmax(logits[0], dim=-1).tolist()
        target_ids = targets[0].tolist()
        output = [tok for tok in decode_output(target_ids) if tok != "P"]
        output_hat = [tok for tok in decode_output(predicted_ids) if tok != "P"]
        return float(edit_distance(output, output_hat))

    def step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch."""
        loss, _ = self._loss_and_logits(batch)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx)
        self.log('train_loss', loss, prog_bar = True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Score this module (self) on the batch at hand, reusing the logits of
        # the loss computation, instead of calling the global evaluate(), which
        # scores the global `model` on a fresh pass over val_dataloader.
        loss, logits = self._loss_and_logits(batch)
        self.log('val_loss', loss, prog_bar = True)
        self.log("edit_distance", self._first_sample_edit_distance(logits, batch["outputs"]), prog_bar = True)
        return loss

# Build the model with its default hyperparameters and move it to `device`.
# Requires `device`, `train_dataloader` and `val_dataloader` from earlier cells.
model = FeedForward().to(device)

# Trainer limited to at most 50 epochs.
trainer = pl.Trainer(max_epochs=50)

# Fit on the training loader, validating on the validation loader.
trainer.fit(model, train_dataloader, val_dataloader)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

In [16]:
# Show the decoded target and prediction of a validation sample, then print
# the edit distance returned by evaluate().
print(evaluate(_device = device, _print = True))

output = ['7', '8', '9', 'c', 'f', '3', '5', '4', '2', '8', '4', 'f', '2', 'c', '5', '6', 'c', '8', 'c', 'c', 'c', 'b', '2', 'c', 'c', '9', '4', 'c', 'c', 'c', 'c', '9', 'a', '9', '5', '4', '4', '8', 'c', 'b', '2', 'f', '4', 'a', '4', 'e', '4', 'd', '0', '1', '0', '0', '5', 'a', 'c', '6', '0', '8', '3', '7']
output_hat = ['7', '8', 'c', 'c', 'c', 'c', 'c', '8', 'c', '9', 'c', 'c', 'c', '8', 'c', 'c', 'c', '8', 'c', '7', 'c', '8', 'c', '8', 'c', 'c', 'c', 'c', 'c', '8', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c', '9', 'c', 'c', 'c', '7']
42


# Model

# Training