In [1]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
print("Device:", device)


Device: cuda


In [3]:
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn.functional as F

# Custom Dataset
class TextHexDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        text = self.dataframe.iloc[idx]['text']

        # Convert hex into tensor
        hex_data = self.dataframe.iloc[idx]['deflate_hex']
        hex_data = [int(x, 16) for x in hex_data]
        
        #pad to reach 256
        padded_hex_data = hex_data + [20] * (256 - len(hex_data))
        tensor_hex_data = torch.tensor(padded_hex_data)
        return text, tensor_hex_data

# Load the dataset
df = pd.read_csv('../../Datasets/shorthex2hex.csv')

# Create datasets
dataset = TextHexDataset(df)

# Split the dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# DataLoaders with batch_size = 1
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers = 4)
val_loader = DataLoader(val_dataset, batch_size=16, num_workers = 4)

In [4]:
import pytorch_lightning as pl
from transformers import BartTokenizer, BartModel
import torch
import torch.nn.functional as F

DEBUG = False


class TransformerModel(pl.LightningModule):
    def __init__(self):
        super(TransformerModel, self).__init__()
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
        self.transformer = BartModel.from_pretrained('facebook/bart-base')

        #set the transformer padding character to 20
        self.tokenizer.pad_token = self.tokenizer.convert_ids_to_tokens(20)
        self.transformer.config.pad_token_id = 20

        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 768, 256),
            nn.Dropout(p=0.2)  # Example of adding a dropout
            # You can add more layers here
        )

        self.loss = torch.nn.MSELoss()

        self.to(device)

    def forward(self, text):

        # Tokenize the text with padding
        encoding = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
        input_ids = encoding.input_ids
        attention_mask = encoding.attention_mask

        # Ensure padding is on the device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Ensure padding is on the device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        if DEBUG:
            print(f"FORWARD: input_ids.shape = {input_ids.shape}")
            print(f"FORWARD: attention_mask.shape = {attention_mask.shape}")

        # Pass tokenized and padded text through the transformer
        transformer_output = self.transformer(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        if DEBUG:
            print(f"FORWARD: transformer_output.shape = {transformer_output.shape}")

        # Apply the linear layer
        final_output = self.encoder(transformer_output)

        if DEBUG:
            print(f"FORWARD: final_output.shape = {final_output.shape}")

        return final_output
    
    def training_step(self, batch, batch_idx):
        text, hex_data = batch

        if DEBUG:
            print(f"TRAINING_STEP: text = {text}")
            print(f"TRAINING_STEP: hex_data.shape = {hex_data.shape}")

        # Pass the text through the transformer
        transformer_output = self.forward(text)
        if DEBUG:
            print(f"TRAINING_STEP: transformer_output.shape = {transformer_output.shape}")

        # Calculate the loss
        loss = self.loss(transformer_output, hex_data.float())
        print(f"TRAINING_STEP: loss = {loss}")

        return loss
    
    def validation_step(self, batch, batch_idx):
        text, hex_data = batch
        if DEBUG:
            print(f"VALIDATION_STEP: text = {text}")
            print(f"VALIDATION_STEP: hex_data = {hex_data}")
        
        # Pass the text through the transformer
        transformer_output = self.forward(text)
        if DEBUG:
            print(f"VALIDATION_STEP: transformer_output.shape = {transformer_output.shape}")

        # Calculate the loss
        loss = self.loss(transformer_output, hex_data.float())
        print(f"VALIDATION_STEP: loss = {loss}")

        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=1e-5)



model = TransformerModel()

trainer = pl.Trainer(max_epochs=5, enable_checkpointing=False, logger=False)
trainer.fit(model, train_loader, val_loader)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


In [None]:
#save weights
torch.save(model.state_dict(), 'bart_model.pt')

In [None]:
model.to(device)

# Test the model
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

def decimal_to_hexadecimal(decimal):    
    hex_digits = "0123456789abcdefghilmnopqrstuvzppppppppppppppppppppp"
    return hex_digits[decimal]

hex_predictions = []
hex_golds = []

for text, gold in val_loader:

    prediction = model.forward(text[0]).tolist()[0]

    #round every elem in prediction
    prediction = [round(x) for x in prediction]

    #set to 0 every negative elem in prediction
    prediction = [max(0, x) for x in prediction]

    #convert every 0 to 0, 1 to 1, ... , 15 to f
    hex_prediction = [decimal_to_hexadecimal(x) for x in prediction]

    hex_gold = [decimal_to_hexadecimal(x) for x in gold[0].tolist()]

    #convert to string
    hex_prediction = ''.join(hex_prediction)
    hex_gold = ''.join(hex_gold)

    print(f"Prediction: {hex_prediction}")
    print(f"Gold: {hex_gold}")

    hex_predictions.append(hex_prediction)
    hex_golds.append(hex_gold)

Prediction: 9c0edbcd00785d8h0bee0g4a7aigepefedi8ci5014d0gl00hchggqfoc00ee07n0n0eqh0g00hfo0rroppppq000or0pqqoq0s00opp0rqprp0pqti00p0q00oqr00rpqq0qop0qnoqprqosp0nqq00pprqospp00usrnqtrpr0soqq0nrmn00n0pp0rqoonolprqo0000qr0pr0s00srqqp0o00qms00rq00pp0p0qrqqsorq00qro0lprnq0s
Gold: 789cf3cc53c8ad54c84b2d4b2dd24dcd4bc9cc4b57282c4d2d2e01006b3a08f2mmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm


In [None]:
from nltk.metrics.distance import edit_distance

assert len(hex_predictions) == len(hex_golds)

scores = []
pred_lenghts = []
gold_lenghts = []

for i in range(len(hex_predictions)):
    pred = hex_predictions[i]
    gold = hex_golds[i]
    scores.append(edit_distance(pred, gold))
    pred_lenghts.append(len(pred))
    gold_lenghts.append(len(gold))
    
print(f"Average prediction lenght is {np.mean(pred_lenghts)}")
print(f"Average gold lenght is {np.mean(gold_lenghts)}")
print(f"Average distance is {np.mean(scores)}")