In [1]:
%%capture
%pip install torch pandas lightning trl wandb

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cuda


In [3]:
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn.functional as F

# Custom Dataset
class TextHexDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        text = self.dataframe.iloc[idx]['text_hex']
        text = [int(x, 16) for x in text]

        #pad to reach 256
        padded_text = text + [20] * (256 - len(text))
        tensor_text = torch.tensor(padded_text, dtype=torch.float)

        # Convert hex into tensor
        hex_data = self.dataframe.iloc[idx]['deflate_hex']
        hex_data = [int(x, 16) for x in hex_data]

        #pad to reach 256
        padded_hex_data = hex_data + [20] * (256 - len(hex_data))
        tensor_hex_data = torch.tensor(padded_hex_data, dtype=torch.float)
        return tensor_text, tensor_hex_data

# Load the dataset
df = pd.read_csv('/kaggle/input/shorthex2hex/shorthex2hex.csv')

# Create datasets
dataset = TextHexDataset(df)

# Split the dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# DataLoaders with batch_size = 1
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers = 2)
val_dataloader = DataLoader(val_dataset, batch_size=128, num_workers = 2)

In [4]:
import wandb
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # print(f"x.shape in positional encoding: {x.shape}")
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class Transformer(pl.LightningModule):
    def __init__(self, input_dim=256, hidden_dim=256, output_dim=256, num_heads=4, num_layers=2, dropout_rate=0.1, learning_rate=0.1):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(input_dim, dropout_rate)
        encoder_layers = nn.TransformerEncoderLayer(input_dim, num_heads, hidden_dim, dropout_rate)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(input_dim, output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.learning_rate = learning_rate
        self.loss = nn.MSELoss()
        
        self.save_hyperparameters()

    def forward(self, x):
        x = x.transpose(0, 1).unsqueeze(-1)  # Now x is [256, 128, 1]
        # print(f"x.shape: {x.shape}")
        x = self.pos_encoder(x)
        # print(f"x.shape: {x.shape}")
        x = self.transformer_encoder(x)
        # print(f"x.shape: {x.shape}")
        x = x.mean(dim=0)  # Aggregate over the sequence dimension
        x = self.fc(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)

# Initialize wandb logger
wandb_logger = WandbLogger(project='my_project', log_model='all')

# Assuming device, train_dataloader, and val_dataloader are defined
model = Transformer().to(device)

# Define the trainer with wandb logging
trainer = pl.Trainer(max_epochs=50, logger=wandb_logger, enable_checkpointing=True, check_val_every_n_epoch=10)

# Train the model ⚡
trainer.fit(model, train_dataloader, val_dataloader)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [27]:
import nltk
from nltk.metrics.distance import edit_distance

model.to('cpu')

def decimal_to_hexadecimal(decimal):
    hex_digits = "0123456789abcdefpppppppppppppppppp"
    return hex_digits[decimal]

hex_predictions = []
hex_golds = []
scores = []

for text, gold in val_dataloader:

    prediction = model.forward(text).tolist()[0]

    #round every elem in prediction
    prediction = [round(x) for x in prediction]

    #set to 0 every negative elem in prediction
    prediction = [max(0, x) for x in prediction]

    #convert every 0 to 0, 1 to 1, ... , 15 to f
    hex_prediction = [decimal_to_hexadecimal(x) for x in prediction]

    gold = [round(x) for x in gold[0].tolist()]

    hex_gold = [decimal_to_hexadecimal(x) for x in gold]

    hex_prediction = [x for x in hex_prediction if x != 'p']
    hex_gold = [x for x in hex_gold if x != 'p']

    #convert to string
    hex_prediction = ''.join(hex_prediction)
    hex_gold = ''.join(hex_gold)

    print(f"hex_prediction = {hex_prediction}")
    print(f"hex_gold = {hex_gold}\n")

    hex_predictions.append(hex_prediction)
    hex_golds.append(hex_gold)

    scores.append(edit_distance(hex_prediction, hex_gold))

# Convert hex strings to integers
print(f"Average distance is {np.mean(scores)}")

hex_prediction = 789c7689796a6969796979696969696969695858586869899babcddeeff
hex_gold = 789cf3cc53c8ad54c84b2d4b2dd24dcd4bc9cc4b57282c4d2d2e01006b3a08f2

hex_prediction = 789c7689796a6969796979696969696969695858586869899babcddeeff
hex_gold = 789c0bcf48cd53c85448cb2c2a2e51284e2c07002ffb05cf

hex_prediction = 789c7689796a6969796979696969696969695858586869899babcddeeff
hex_gold = 789c0bc9c82c56c8cd2fcb4c5500328a0b5293331373722a01662508bb

hex_prediction = 789c7689796a6969796979696969696969695858586869899babcddeeff
hex_gold = 789cf35448cb4c2f2d4a4d5128c9485528cfc8cf49050042e906f0

hex_prediction = 789c7689796a6969796979696969696969695858586869899babcddeeff
hex_gold = 789cf35428c92f49ccc9a954484c2f4a4d5528c9482c01004c060768

hex_prediction = 789c7689796a6969796979696969696969695858586869899babcddeeff
hex_gold = 789cf3c92c4e54c82c564854c8c82f49cd010028e60543

hex_prediction = 789c7689796a6969796979696969696969695858586869899babcddeeff
hex_gold = 789c0bc9c82c5600a24485f28cfc94d2bcbccc12004389