In [None]:
%%capture
%pip install torch pandas lightning trl wandb

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [None]:
SEED = 999
cuda_available = torch.cuda.is_available()

# Seed torch's CPU generator and, when a GPU is present, the CUDA generators,
# so random_split, DataLoader shuffling and weight initialisation repeat.
torch.manual_seed(SEED)
if cuda_available:
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

# Compute device: the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda" if cuda_available else "cpu")
print("Device:", device)


Device: cuda


In [None]:
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn.functional as F

# Custom Dataset
class TextHexDataset(Dataset):
    """Map-style dataset yielding (text, deflate) pairs as fixed-length float tensors.

    Each row of ``dataframe`` must provide hex strings in the columns
    ``'text_hex'`` and ``'deflate_hex'``. Every hex character becomes one
    integer in 0..15. The sequence is right-padded with ``pad_value`` up to
    ``max_len``, so the default collate function can stack samples into a batch.

    Args:
        dataframe: pandas DataFrame holding the two hex-string columns.
        max_len: fixed length of every returned tensor (default 256, matching
            the model's default input width).
        pad_value: filler for positions past the end of a sequence (default 20,
            outside the 0..15 hex-digit range so it can be told apart).
    """

    def __init__(self, dataframe, max_len=256, pad_value=20):
        self.dataframe = dataframe
        self.max_len = max_len
        self.pad_value = pad_value

    def __len__(self):
        return len(self.dataframe)

    def _encode(self, hex_string):
        """Convert a hex string into a padded float tensor of length ``max_len``.

        Raises:
            ValueError: if the string has more than ``max_len`` characters.
                Failing here gives a clear message instead of an over-long
                tensor that would only break later, at batch collation or in
                the model's first linear layer.
        """
        digits = [int(ch, 16) for ch in hex_string]
        if len(digits) > self.max_len:
            raise ValueError(
                f"hex sequence of length {len(digits)} exceeds max_len={self.max_len}"
            )
        digits.extend([self.pad_value] * (self.max_len - len(digits)))
        return torch.tensor(digits, dtype=torch.float)

    def __getitem__(self, idx):
        """Return ``(text_tensor, deflate_tensor)`` for row ``idx``."""
        row = self.dataframe.iloc[idx]
        return self._encode(row['text_hex']), self._encode(row['deflate_hex'])

# Load the (text hex, deflate hex) pairs; TextHexDataset reads one example per row.
df = pd.read_csv('shorthex2hex.csv')

# Wrap the frame so each row becomes a pair of fixed-length float tensors.
dataset = TextHexDataset(df)

# 80/20 train/validation split. A dedicated generator seeded with SEED keeps the
# split the same on every re-run, independent of what else drew from the global RNG.
TRAIN_FRACTION = 0.8
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)

# DataLoaders: batches of 128; only the training loader is shuffled.
BATCH_SIZE = 128
NUM_WORKERS = 2
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)



In [None]:
# Authenticate with Weights & Biases. If no API key is configured yet this
# prompts for one interactively, so this cell needs a human on a fresh machine.
import wandb

wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

class FeedForward(pl.LightningModule):
    """Two-layer MLP that predicts the padded deflate sequence from the padded text sequence.

    Layers: Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear. Inputs and
    targets are float tensors whose entries are hex-digit values (0..15) or the
    dataset's padding value. The head emits one real number per position,
    which the evaluation code rounds back to an integer digit code.

    Args:
        input_dim: length of the input sequence.
        hidden_dim: width of the hidden layer.
        output_dim: length of the predicted sequence.
        learning_rate: learning rate passed to the optimizer.
        dropout_rate: dropout probability after the hidden activation.
        optimizer_type: optimizer class, called as ``optimizer_type(params, lr=...)``.
        scheduler_type: LR scheduler class, called with ``step_size`` and ``gamma``.
        scheduler_step_size: ``step_size`` forwarded to the scheduler.
        scheduler_gamma: ``gamma`` (multiplicative decay) forwarded to the scheduler.
    """

    def __init__(self, input_dim=256, hidden_dim=256, output_dim=256, learning_rate=0.001,
                 dropout_rate=0.5, optimizer_type=AdamW, scheduler_type=StepLR,
                 scheduler_step_size=5, scheduler_gamma=0.1):
        super().__init__()
        self.save_hyperparameters()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Per-position regression loss. The targets are real-valued digit codes
        # with the same shape as the output, and evaluation rounds each output.
        # CrossEntropyLoss would instead read each target row as a
        # class-probability distribution, which these values are not.
        self.loss = nn.MSELoss()

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def configure_optimizers(self):
        optimizer = self.hparams.optimizer_type(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = self.hparams.scheduler_type(optimizer, step_size=self.hparams.scheduler_step_size, gamma=self.hparams.scheduler_gamma)
        return [optimizer], [scheduler]

    def _shared_step(self, batch):
        """Compute the loss for one ``(x, y)`` batch.

        The Trainer has already moved the batch to the module's device.
        """
        x, y = batch
        return self.loss(self(x), y)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)

# Weights & Biases logger for this run; log_model='all' asks it to upload checkpoints.
wandb_logger = WandbLogger(project='my_project', log_model='all')

# Build the model and place it on the device selected in the setup cell.
model = FeedForward()
model.to(device)

# Train for 50 epochs, running validation only on every 10th epoch.
trainer_kwargs = dict(
    max_epochs=50,
    logger=wandb_logger,
    enable_checkpointing=True,
    check_val_every_n_epoch=10,
)
trainer = pl.Trainer(**trainer_kwargs)

# Train the model ⚡
trainer.fit(model, train_dataloader, val_dataloader)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type        | Params
----------------------------------------
0 | fc1     | Linear      | 65.8 K
1 | bn1     | BatchNorm1d | 512   
2 | dropout | Dropout     | 0     
3 | fc2     | Linear      | 65.8 K
4 | loss    | MSELoss     | 0     
----------------------------------------
132 K     Trainable params
0         Non-trainable params
132 K     Total params
0.528     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


In [None]:
import nltk
from nltk.metrics.distance import edit_distance

model.cpu()  # run evaluation on CPU, where the DataLoader's tensors live

def decimal_to_hexadecimal(decimal, pad_char='p'):
    """Map an integer digit code to its lowercase hex character.

    Codes 0..15 become '0'..'f'. Every other code, including the dataset's
    padding value and any out-of-range model prediction, becomes ``pad_char``
    so the caller can filter it out.

    Args:
        decimal: integer code, e.g. a rounded model output.
        pad_char: marker returned for codes outside 0..15 (default 'p').

    Returns:
        A single-character string.
    """
    hex_digits = "0123456789abcdef"
    # An explicit bounds check, rather than a padded lookup string, means large
    # predictions (e.g. 40) cannot raise IndexError and negative codes cannot
    # wrap around through Python's negative indexing.
    if 0 <= decimal < len(hex_digits):
        return hex_digits[decimal]
    return pad_char

# Score the whole validation split: decode predictions and targets back to hex
# strings and compare them with Levenshtein edit distance.
MAX_PRINTED_EXAMPLES = 10  # how many (prediction, gold) pairs to print


def _codes_to_hex(values):
    """Round each value, clamp negatives to 0, map to hex, drop pad markers and join."""
    chars = (decimal_to_hexadecimal(max(0, round(v))) for v in values)
    return ''.join(ch for ch in chars if ch != 'p')


hex_predictions = []
hex_golds = []
scores = []

# eval() turns dropout off and makes BatchNorm use its running statistics,
# so predictions are deterministic and evaluation leaves those statistics unchanged.
model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    for text, gold in val_dataloader:
        predictions = model(text).tolist()
        # Score every sample in the batch, not just the first one.
        for prediction, gold_row in zip(predictions, gold.tolist()):
            hex_prediction = _codes_to_hex(prediction)
            hex_gold = _codes_to_hex(gold_row)

            if len(scores) < MAX_PRINTED_EXAMPLES:
                print(f"hex_prediction = {hex_prediction}")
                print(f"hex_gold = {hex_gold}\n")

            hex_predictions.append(hex_prediction)
            hex_golds.append(hex_gold)
            scores.append(edit_distance(hex_prediction, hex_gold))

# Mean edit distance over every validation example.
print(f"Average distance is {np.mean(scores)}")

hex_prediction = 789c86a188676799b5d4a788998698959789849484758796a583534787958bd
hex_gold = 789cf3cc53c8ad54c84b2d4b2dd24dcd4bc9cc4b57282c4d2d2e01006b3a08f2

hex_prediction = 789c099c3c2c6a25577959889aba89596899795858787a79698aac8a685767899bde
hex_gold = 789c734c2b492d52284a4d4cc9cc4b5728c94855c82fca4ccfcc4bcc0100814f09c3

hex_prediction = 789c0b8b897c79a79a4a975a6859799a8a78595815344958478a
hex_gold = 789c2bcf48cd53f05448cb2c2a2e51c8484d2c4a01003d180688

hex_prediction = 678bb36798496949374438895b5a163658697958284657664836464556575779abdeff
hex_gold = 789cd3d3d32bc9485548cec8cc4951484c2ec92f2a56284f2d4a050061f0086f

hex_prediction = 789cf4976a478a8c4a897b4756898a4a68699969361335786969de
hex_gold = 789cf3cb2f5128cf2f2ac95028c9485528cb4c49cd0700468a071e

hex_prediction = 789ca676986b5a8a48559a9c7c5967886958475758577989ad
hex_gold = 789cf354c82a2d2e5148cbcccb2cce484d514804003644061b

hex_prediction = 789c49b8992946dacf3fa91953365a9a7958696859895847473447798aac
hex_gold = 789c0bc9c82c56c8