In [9]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [10]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cpu


In [13]:
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn.functional as F

# Custom Dataset
class TextHexDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        text = self.dataframe.iloc[idx]['text_hex']
        text = [int(x, 16) for x in text]

        #pad to reach 256
        padded_text = text + [20] * (256 - len(text))
        tensor_text = torch.tensor(padded_text, dtype=torch.float)

        # Convert hex into tensor
        hex_data = self.dataframe.iloc[idx]['deflate_hex']
        hex_data = [int(x, 16) for x in hex_data]

        #pad to reach 256
        padded_hex_data = hex_data + [20] * (256 - len(hex_data))
        tensor_hex_data = torch.tensor(padded_hex_data, dtype=torch.float)
        return tensor_text, tensor_hex_data

# Load the dataset
df = pd.read_csv('../../Datasets/shorthex2hex.csv')

# Create datasets
dataset = TextHexDataset(df)

# Split the dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# DataLoaders with batch_size = 1
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers = 4)
val_dataloader = DataLoader(val_dataset, batch_size=16, num_workers = 4)

In [14]:
class LSTM(pl.LightningModule):
    def __init__(self, input_dim=256, hidden_dim=256, output_dim=256, num_layers=2, dropout_rate=0.1, learning_rate=0.1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True, dropout=dropout_rate)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # Multiply by 2 for bidirectional
        self.learning_rate = learning_rate
        self.loss = nn.MSELoss()
        
    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        x.to(device)
        y.to(device)
        y_hat = self(x)
        loss = self.loss(y_hat, y)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x.to(device)
        y.to(device)
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        print('val_loss', loss)

model = LSTM().to(device)

# Define the trainer with 50 epochs and showing eval results every 10 epochs
trainer = pl.Trainer(max_epochs=10, enable_checkpointing=False, logger=False)

# Train the model ⚡
trainer.fit(model, train_dataloader, val_dataloader)



Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]val_loss tensor(326.6190)
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 11.11it/s]val_loss tensor(326.1970)
Epoch 0:   3%|▎         | 74/2500 [00:04<02:37, 15.44it/s]                 

/home/tommaiberone/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
import nltk
from nltk.metrics.distance import edit_distance

model.to('cpu')

def decimal_to_hexadecimal(decimal):    
    hex_digits = "0123456789abcdefghilmnopqrstuvzppppppppppppppppppppp"
    return hex_digits[decimal]

hex_predictions = []
hex_golds = []
scores = []

for text, gold in val_dataloader:

    prediction = model.forward(text).tolist()[0]

    #round every elem in prediction
    prediction = [round(x) for x in prediction]

    #set to 0 every negative elem in prediction
    prediction = [max(0, x) for x in prediction]

    #convert every 0 to 0, 1 to 1, ... , 15 to f
    hex_prediction = [decimal_to_hexadecimal(x) for x in prediction]

    gold = [round(x) for x in gold[0].tolist()]

    hex_gold = [decimal_to_hexadecimal(x) for x in gold]

    #convert to string
    hex_prediction = ''.join(hex_prediction)
    hex_gold = ''.join(hex_gold)

    print(hex_prediction)
    print(hex_gold)

    hex_predictions.append(hex_prediction)
    hex_golds.append(hex_gold)

    scores.append(edit_distance(hex_prediction, hex_gold))

# Convert hex strings to integers
print(f"Average distance is {np.mean(scores)}")

torch.Size([256, 83])
torch.Size([256, 1232])
torch.Size([256, 1232])
<s>Just reading why this show got canceled makes me rather steamed. This was a favorite of mine as a kid</s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>
01111000100111000001110111001010110000010000110110000000001000000001000000000100110000000101011010110110000000100000101111110001011011010000001100010111010110001110000010000010010000001100001010011101000100101011101100110111111110101001110111001100011111101001100101100011010100101010001011110110100011000101010100011110011110000101000110000011100101011011000110010000100001110010001101001000000011110011110000011001110100011010010011010010110100001000100000101001010111100011100001100001010011100110100110001100000110111000111011101111001011110011000100001000100100101101110001100011101010100001001100100011101000010110100100100111011111101010