In [1]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cpu


In [33]:
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn.functional as F

class TextHexDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        def one_hot_encode(hex_string):
            hex_to_int = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
                          '8': 8, '9': 9, 'a': 10, 'b': 11, 'c': 12, 'd': 13, 'e': 14, 'f': 15}
            one_hot_encoded = []
            for char in hex_string:
                vec = [0] * 16
                vec[hex_to_int[char.lower()]] = 1
                one_hot_encoded.extend(vec)


            return one_hot_encoded
            # Your one-hot encoding implementation

        text = self.dataframe.iloc[idx]['text_hex']
        one_hot_text = one_hot_encode(text)
        tensor_text = torch.tensor(one_hot_text, dtype=torch.float)
        # print(f"tensor_text.shape: {tensor_text.shape}")

        hex_data = self.dataframe.iloc[idx]['deflate_hex']
        one_hot_hex_data = one_hot_encode(hex_data)
        tensor_hex_data = torch.tensor(one_hot_hex_data, dtype=torch.float)
        # print(f"tensor_hex_data.shape: {tensor_hex_data.shape}")

        return tensor_text, tensor_hex_data

    def collate_fn(batch):
        texts, hex_data = zip(*batch)
        texts_padded = rnn_utils.pad_sequence(texts, batch_first=True)
        hex_data_padded = rnn_utils.pad_sequence(hex_data, batch_first=True)
        lengths = torch.tensor([len(x) for x in texts])
        return texts_padded, hex_data_padded, lengths


# Load the dataset
df = pd.read_csv('../../Datasets/shorthex2hex.csv')

# Create datasets
dataset = TextHexDataset(df)

# Split the dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=TextHexDataset.collate_fn, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=16, collate_fn=TextHexDataset.collate_fn, num_workers=4)

In [46]:
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torch.nn.utils.rnn as rnn_utils

class Seq2SeqModel(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers):
        super(Seq2SeqModel, self).__init__()
        self.encoder = nn.LSTM(9504, hidden_dim, n_layers, batch_first=True)
        self.decoder = nn.LSTM(hidden_dim, output_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, x_lengths):
        
        print(f"FORWARD: x.shape = {x.shape}")

        # Pack padded batch of sequences for RNN module
        packed = rnn_utils.pack_padded_sequence(x, x_lengths, batch_first=True, enforce_sorted=False)

        print(f"FORWARD: packed.data.shape = {packed.data.shape}")

        # Forward pass through LSTM
        _, hidden = self.encoder(packed)
        
        # Unpack padding
        outputs, _ = rnn_utils.pad_packed_sequence(hidden[0], batch_first=True)

        print("Shape before linear layer:", outputs.shape)


        # Decode the hidden state of the last time step
        outputs = self.fc(outputs[:, -1, :])
        return outputs
    
    def training_step(self, batch, batch_idx):

        print(f"TRAINING STEP: batch = {batch}")

        x, y, x_lengths = batch
        print(f"TRAINING STEP: x.shape = {x.shape}")
        print(f"TRAINING STEP: y.shape = {y.shape}")
        print(f"TRAINING STEP: x_lengths = {x_lengths}")

        y_hat = self.forward(x, x_lengths)
        print(f"TRAINING STEP: y_hat.shape = {y_hat.shape}")

        mask = torch.arange(y.size(1))[None, :] < x_lengths[:, None]
        print(f"TRAINING STEP: mask.shape = {mask.shape}")

        train_loss = F.mse_loss(y_hat[mask], y[mask])
        print(f"TRAINING STEP: train_loss = {train_loss}")

        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):

        x, y, x_lengths = batch
        print(f"VALIDATION STEP: x.shape = {x.shape}")
        print(f"VALIDATION STEP: y.shape = {y.shape}")
        print(f"VALIDATION STEP: x_lengths = {x_lengths}")

        y_hat = self.forward(x, x_lengths)
        print(f"VALIDATION STEP: y_hat.shape = {y_hat.shape}")

        mask = torch.arange(y.size(1))[None, :] < x_lengths[:, None]
        print(f"VALIDATION STEP: mask = {mask}")

        val_loss = F.mse_loss(y_hat[mask], y[mask])
        print(f"VALIDATION STEP: val_loss = {val_loss}")
        
        return val_loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

# Model parameters
input_dim = 16  # One-hot encoded hex digits
hidden_dim = 4  # Adjust as needed
output_dim = 16  # One-hot encoded hex digits
n_layers = 2    # Adjust as needed

# Create the model
model = Seq2SeqModel(input_dim, hidden_dim, output_dim, n_layers)

# Training
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloader, val_dataloader)



Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]VALIDATION STEP: x.shape = torch.Size([16, 992])
VALIDATION STEP: y.shape = torch.Size([16, 1248])
VALIDATION STEP: x_lengths = tensor([704, 896, 448, 448, 640, 800, 320, 448, 992, 512, 608, 512, 576, 480,
        736, 384])
FORWARD: x.shape = torch.Size([16, 992])
FORWARD: packed.data.shape = torch.Size([9504])


RuntimeError: start (16) + length (16) exceeds dimension size (16).

In [None]:
import nltk
from nltk.metrics.distance import edit_distance

model.to('cpu')

def decimal_to_hexadecimal(decimal):
    hex_digits = "0123456789abcdefghilmnopqrstuvzppppppppppppppppppppp"
    return hex_digits[decimal]

hex_predictions = []
hex_golds = []
scores = []

for text, gold in val_dataloader:

    prediction = model.forward(text).tolist()[0]

    #round every elem in prediction
    prediction = [round(x) for x in prediction]

    #set to 0 every negative elem in prediction
    prediction = [max(0, x) for x in prediction]

    #convert every 0 to 0, 1 to 1, ... , 15 to f
    hex_prediction = [decimal_to_hexadecimal(x) for x in prediction]

    gold = [round(x) for x in gold[0].tolist()]

    hex_gold = [decimal_to_hexadecimal(x) for x in gold]

    #convert to string
    hex_prediction = ''.join(hex_prediction)
    hex_gold = ''.join(hex_gold)

    print(hex_prediction)
    print(hex_gold)

    hex_predictions.append(hex_prediction)
    hex_golds.append(hex_gold)

    scores.append(edit_distance(hex_prediction, hex_gold))

# Convert hex strings to integers
print(f"Average distance is {np.mean(scores)}")



789c8689796a6969796979696979697979695836263447789befiimmmmmmmmlmlmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm
789ccbcb2f5128c957c8c9cf4b57484c2b492d020038880657mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm
789c8689696a6969796969696979796979797969583626243668abdeiilmmmnnnmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm
789c0bc9c82c5648cbccc955282e492c2a2956c82f2d01004deb079ammmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm

Average distance is 57.6128
