In [1]:
%%capture
%pip install torch pandas lightning trl wandb

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cuda


In [3]:
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn.functional as F

# Custom Dataset
class TextHexDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        text = self.dataframe.iloc[idx]['text_hex']
        text = [int(x, 16) for x in text]

        #pad to reach 256
        padded_text = text + [20] * (256 - len(text))
        tensor_text = torch.tensor(padded_text, dtype=torch.float)

        # Convert hex into tensor
        hex_data = self.dataframe.iloc[idx]['deflate_hex']
        hex_data = [int(x, 16) for x in hex_data]

        #pad to reach 256
        padded_hex_data = hex_data + [20] * (256 - len(hex_data))
        tensor_hex_data = torch.tensor(padded_hex_data, dtype=torch.float)
        return tensor_text, tensor_hex_data

# Load the dataset
df = pd.read_csv('/kaggle/input/shorthex2hex/shorthex2hex.csv')

# Create datasets
dataset = TextHexDataset(df)

# Split the dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# DataLoaders with batch_size = 1
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers = 2)
val_dataloader = DataLoader(val_dataset, batch_size=128, num_workers = 2)

In [4]:
import wandb
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [11]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

class EnhancedConvNet(pl.LightningModule):
    def __init__(self, input_channels=1, hidden_channels=[128, 256], output_dim=256, kernel_sizes=[3, 3], learning_rate=0.001,
                 dropout_rate=0.5, optimizer_type=AdamW, scheduler_type=StepLR,
                 scheduler_step_size=5, scheduler_gamma=0.1, input_length=256):
        super().__init__()
        self.save_hyperparameters()
        
        layers = []
        current_channels = input_channels
        current_length = input_length
        for hidden_channel, kernel_size in zip(hidden_channels, kernel_sizes):
            layers.append(nn.Conv1d(current_channels, hidden_channel, kernel_size))
            layers.append(nn.BatchNorm1d(hidden_channel))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            # Adjust input length for next layer
            current_length = current_length - kernel_size + 1
            current_channels = hidden_channel
        
        self.conv_layers = nn.Sequential(*layers)
        
        # Calculate the output size after the convolution operations
        self.fc1 = nn.Linear(hidden_channels[-1] * current_length, output_dim)
        self.loss = nn.MSELoss()

    def forward(self, x):
        x = x.unsqueeze(1)  # Adds a channel dimension, converting (N, L) to (N, 1, L)
        x = self.conv_layers(x)
        # Flatten the output for the linear layer
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc1(x)
        return x

    def configure_optimizers(self):
        optimizer = self.hparams.optimizer_type(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = self.hparams.scheduler_type(optimizer, step_size=self.hparams.scheduler_step_size, gamma=self.hparams.scheduler_gamma)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        x, y = x.to(self.device), y.to(self.device)
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x, y = x.to(self.device), y.to(self.device)
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)

# Initialize wandb logger
wandb_logger = WandbLogger(project='my_project', log_model='all')

# Assuming device, train_dataloader, and val_dataloader are defined
model = EnhancedConvNet().to(device)

# Define the trainer with wandb logging
trainer = pl.Trainer(max_epochs=50, logger=wandb_logger, enable_checkpointing=True, check_val_every_n_epoch=10)

# Train the model ⚡
trainer.fit(model, train_dataloader, val_dataloader)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

wandb: ERROR Error while calling W&B API: Error 1062 (23000): Duplicate entry 'vvzqboae0l8kmlodtcnokyyjdqhhl2cwy0v3am5y18r48eg91oexekrzv2exvhtx' for key 'client_id_mappings.PRIMARY' (<Response [409]>)


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [12]:
import nltk
from nltk.metrics.distance import edit_distance

model.to('cpu')

def decimal_to_hexadecimal(decimal):
    hex_digits = "0123456789abcdefpppppppppppppppppp"
    return hex_digits[decimal]

hex_predictions = []
hex_golds = []
scores = []

for text, gold in val_dataloader:

    prediction = model.forward(text).tolist()[0]

    #round every elem in prediction
    prediction = [round(x) for x in prediction]

    #set to 0 every negative elem in prediction
    prediction = [max(0, x) for x in prediction]

    #convert every 0 to 0, 1 to 1, ... , 15 to f
    hex_prediction = [decimal_to_hexadecimal(x) for x in prediction]

    gold = [round(x) for x in gold[0].tolist()]

    hex_gold = [decimal_to_hexadecimal(x) for x in gold]

    hex_prediction = [x for x in hex_prediction if x != 'p']
    hex_gold = [x for x in hex_gold if x != 'p']

    #convert to string
    hex_prediction = ''.join(hex_prediction)
    hex_gold = ''.join(hex_gold)

    print(f"hex_prediction = {hex_prediction}")
    print(f"hex_gold = {hex_gold}\n")

    hex_predictions.append(hex_prediction)
    hex_golds.append(hex_gold)

    scores.append(edit_distance(hex_prediction, hex_gold))

# Convert hex strings to integers
print(f"Average distance is {np.mean(scores)}")

hex_prediction = 789ce4bc47b86d53b45a3d5b3ec85bca3bdacc5a524539392d2d03005a770867
hex_gold = 789cf3cc53c8ad54c84b2d4b2dd24dcd4bc9cc4b57282c4d2d2e01006b3a08f2

hex_prediction = 789c2aad47cc44c65348ab4b3a2d63273c490400166627ba
hex_gold = 789c0bcf48cd53c85448cb2c2a2e51284e2c07002ffb05cf

hex_prediction = 789c29c8b93c56c8cd2cc4c53632966473b495aa95bdbcb8a37436a996cec
hex_gold = 789c0bc9c82c56c8cd2fcb4c5500328a0b5293331373722a01662508bb

hex_prediction = 789c25567ca5b3d2f4a5a53187a4653669abbbb2b232145a94abb
hex_gold = 789cf35448cb4c2f2d4a4d5128c9485528cfc8cf49050042e906f0

hex_prediction = 789c25439cb2d3bd9c88a43398c3d4b5b4126ba0b2a141039473979
hex_gold = 789cf35428c92f49ccc9a954484c2f4a4d5528c9482c01004c060768

hex_prediction = 789cf4cb3d3c555476646771d594a37a5363337aa6aac
hex_gold = 789cf3c92c4e54c82c564854c8c82f49cd010028e60543

hex_prediction = 789c1ac9b92d56227576698297d7e74548eac9ba8c3623579a6acb
hex_gold = 789c0bc9c82c5600a24485f28cfc94d2bcbccc120043890716

hex_prediction = 789c1a9ab8