In [1]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
print("Device:", device)


Device: cuda


In [3]:
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn.functional as F

# Custom Dataset
class TextHexDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        text = self.dataframe.iloc[idx]['text']

        # Convert hex into tensor
        hex_data = self.dataframe.iloc[idx]['deflate_hex']
        hex_data = [int(x, 16) for x in hex_data]
        
        #pad to reach 512
        padded_hex_data = hex_data + [0] * (512 - len(hex_data))
        tensor_hex_data = torch.tensor(padded_hex_data)
        return text, tensor_hex_data

# Load the dataset
df = pd.read_csv('../../Datasets/new_dataset_deflate.csv')

# Create datasets
dataset = TextHexDataset(df)

# Split the dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# DataLoaders with batch_size = 1
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [6]:
import pytorch_lightning as pl
from transformers import BartTokenizer, BartModel
import torch
import torch.nn.functional as F

DEBUG = False


class TransformerModel(pl.LightningModule):
    def __init__(self):
        super(TransformerModel, self).__init__()
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
        self.transformer = BartModel.from_pretrained('facebook/bart-base')

        #set the transformer padding character to 0
        self.transformer.config.pad_token_id = 0

        self.flatten = nn.Flatten()
        self.linear = nn.Linear(128 * 768, 512)

        self.loss = torch.nn.MSELoss()

        self.to(device)

    def forward(self, text):

        # Tokenize the text
        input_ids = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
        attention_mask = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).attention_mask.to(device)

        # Pad input_ids and attention_mask to a fixed length of 512
        padded_input_ids = F.pad(input_ids, (0, 128 - input_ids.shape[1]), 'constant', 0)
        padded_attention_mask = F.pad(attention_mask, (0, 128 - attention_mask.shape[1]), 'constant', 0)

        # Ensure padding is on the device
        padded_input_ids = padded_input_ids.to(device)
        padded_attention_mask = padded_attention_mask.to(device)

        if DEBUG:
            print(f"FORWARD: padded_input_ids.shape = {padded_input_ids.shape}")
            print(f"FORWARD: padded_attention_mask.shape = {padded_attention_mask.shape}")

        # Pass tokenized and padded text through the transformer
        transformer_output = self.transformer(input_ids=padded_input_ids, attention_mask=padded_attention_mask).last_hidden_state
        if DEBUG:
            print(f"FORWARD: transformer_output.shape = {transformer_output.shape}")

        # Pooling over the sequence dimension
        flattened_output = self.flatten(transformer_output)
        if DEBUG:
            print(f"FORWARD: flattened_output.shape = {flattened_output.shape}")

        # Apply the linear layer
        final_output = self.linear(flattened_output)
        if DEBUG:
            print(f"FORWARD: final_output.shape = {final_output.shape}")

        return final_output
    
    def training_step(self, batch, batch_idx):
        text, hex_data = batch

        if DEBUG:
            print(f"TRAINING_STEP: text = {text}")
            print(f"TRAINING_STEP: hex_data.shape = {hex_data.shape}")

        # Pass the text through the transformer
        transformer_output = self.forward(text)
        if DEBUG:
            print(f"TRAINING_STEP: transformer_output.shape = {transformer_output.shape}")

        # Calculate the loss
        loss = self.loss(transformer_output, hex_data.float())
        print(f"TRAINING_STEP: loss = {loss}")

        return loss
    
    def validation_step(self, batch, batch_idx):
        text, hex_data = batch
        if DEBUG:
            print(f"VALIDATION_STEP: text = {text}")
            print(f"VALIDATION_STEP: hex_data = {hex_data}")
        
        # Pass the text through the transformer
        transformer_output = self.forward(text)
        if DEBUG:
            print(f"VALIDATION_STEP: transformer_output.shape = {transformer_output.shape}")

        # Calculate the loss
        loss = self.loss(transformer_output, hex_data.float())
        print(f"VALIDATION_STEP: loss = {loss}")

        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.05)



model = TransformerModel()

# load the weights
model.load_state_dict(torch.load('bart_model.pt'))

trainer = pl.Trainer(max_epochs=5, enable_checkpointing=False, logger=False)
trainer.fit(model, train_loader, val_loader)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]VALIDATION_STEP: loss = 9.817407608032227
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  8.26it/s]VALIDATION_STEP: loss = 9.684063911437988
Epoch 0:   0%|          | 0/2500 [00:00<?, ?it/s]                          TRAINING_STEP: loss = 10.058135032653809
Epoch 0:   0%|          | 1/2500 [00:00<08:46,  4.75it/s]TRAINING_STEP: loss = 23061.859375
Epoch 0:   0%|          | 2/2500 [00:01<35:41,  1.17it/s]TRAINING_STEP: loss = 127007.265625
Epoch 0:   0%|          | 3/2500 [00:02<41:09,  1.01it/s]TRAINING_STEP: loss = 41278.96484375
Epoch 0:   0%|          | 4/2500 [00:04<45:22,  0.92it/s]TRAINING_STEP: loss = 26156.712890625
Epoch 0:   0%|          | 5/2500 [00:05<48:40,  0.85it/s]TRAINING_STEP: loss = 28129.978515625
Epoch 0:   0%|          | 6/2500 [00:07<51:24,  0.81it/s]TRAINING_STEP: loss = 24468.435546875
Epoch 0:   0%|          | 7/2500 [00:08<52:31,  0.79it/s]TRAINING_STEP: loss = 1295.2961425781

In [None]:
#save weights
# torch.save(model.state_dict(), 'bart_model.pt')

In [None]:
model.to(device)

# Test the model
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

def decimal_to_hexadecimal(decimal):    
    hex_digits = "0123456789abcdefghilmnopqrstuvz"
    return hex_digits[decimal]

for text, gold in val_loader:

    prediction = model.forward(text[0]).tolist()[0]

    #round every elem in prediction
    prediction = [round(x) for x in prediction]

    #set to 0 every negative elem in prediction
    prediction = [max(0, x) for x in prediction]

    #convert every 0 to 0, 1 to 1, ... , 15 to f
    hex_prediction = [decimal_to_hexadecimal(x) for x in prediction]

    hex_gold = [decimal_to_hexadecimal(x) for x in gold[0].tolist()]

    #convert to string
    hex_prediction = ''.join(hex_prediction)
    hex_gold = ''.join(hex_gold)

    print(f"Prediction: {hex_prediction}")
    print(f"Gold: {hex_gold}")



    
    break

Prediction: 789c2cab820d82311835a3876986a585769716865537864743669577626789c8a575867d788567877a27477675677776678ac934958a8896567977ab884a786a9798597b737777278aaa8a969855686a5766861855822554343552552534434544464435351233120100010101000000000000000000110100010011000001000010000000001000000001000001000100100000010000000000100000000000000000000000000001000000000100000010000000000100010000010000000001000100000000000000100000000000000100000100011000010000000000000000100000100000000000000100000000000100000000000010000000100001
Gold: 789c1dccd10d83301004d156b600520865386281957c67e0ce48ee1e92df37d2cc0e1b70debc3ef445bee1ec8c4436048912b0e2e3c5529503d66e317e7cb4087d2ba1ffa26a65ca384188ecf69605fd688edc155855ed016233278f000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000