In [1]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
print("Device:", device)


Device: cuda


In [3]:
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn.functional as F

# Custom Dataset
class TextHexDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        text = self.dataframe.iloc[idx]['text']

        # Convert hex into tensor
        hex_data = self.dataframe.iloc[idx]['deflate_hex']
        hex_data = [int(x, 16) for x in hex_data]
        
        #pad to reach 512
        padded_hex_data = hex_data + [0] * (512 - len(hex_data))
        tensor_hex_data = torch.tensor(padded_hex_data)
        return text, tensor_hex_data

# Load the dataset
df = pd.read_csv('../../Datasets/new_dataset_deflate.csv')

# Create datasets
dataset = TextHexDataset(df)

# Split the dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# DataLoaders with batch_size = 1
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

In [4]:
import pytorch_lightning as pl
from transformers import BartTokenizer, BartModel
import torch
import torch.nn.functional as F

DEBUG = False


class TransformerModel(pl.LightningModule):
    def __init__(self):
        super(TransformerModel, self).__init__()
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
        self.transformer = BartModel.from_pretrained('facebook/bart-base')

        #set the transformer padding character to 0
        self.transformer.config.pad_token_id = 0

        self.flatten = nn.Flatten()
        self.linear = nn.Linear(128 * 768, 512)

        self.loss = torch.nn.MSELoss()

        self.to(device)

    def forward(self, text):

        # Tokenize the text
        input_ids = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
        attention_mask = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).attention_mask.to(device)

        # Pad input_ids and attention_mask to a fixed length of 512
        padded_input_ids = F.pad(input_ids, (0, 128 - input_ids.shape[1]), 'constant', 0)
        padded_attention_mask = F.pad(attention_mask, (0, 128 - attention_mask.shape[1]), 'constant', 0)

        # Ensure padding is on the device
        padded_input_ids = padded_input_ids.to(device)
        padded_attention_mask = padded_attention_mask.to(device)

        if DEBUG:
            print(f"FORWARD: padded_input_ids.shape = {padded_input_ids.shape}")
            print(f"FORWARD: padded_attention_mask.shape = {padded_attention_mask.shape}")

        # Pass tokenized and padded text through the transformer
        transformer_output = self.transformer(input_ids=padded_input_ids, attention_mask=padded_attention_mask).last_hidden_state
        if DEBUG:
            print(f"FORWARD: transformer_output.shape = {transformer_output.shape}")

        # Pooling over the sequence dimension
        flattened_output = self.flatten(transformer_output)
        if DEBUG:
            print(f"FORWARD: flattened_output.shape = {flattened_output.shape}")

        # Apply the linear layer
        final_output = self.linear(flattened_output)
        if DEBUG:
            print(f"FORWARD: final_output.shape = {final_output.shape}")

        return final_output
    
    def training_step(self, batch, batch_idx):
        text, hex_data = batch

        if DEBUG:
            print(f"TRAINING_STEP: text = {text}")
            print(f"TRAINING_STEP: hex_data.shape = {hex_data.shape}")

        # Pass the text through the transformer
        transformer_output = self.forward(text)
        if DEBUG:
            print(f"TRAINING_STEP: transformer_output.shape = {transformer_output.shape}")

        # Calculate the loss
        loss = self.loss(transformer_output, hex_data.float())
        if DEBUG:
            print(f"TRAINING_STEP: loss = {loss}")

        return loss
    
    def validation_step(self, batch, batch_idx):
        text, hex_data = batch
        if DEBUG:
            print(f"VALIDATION_STEP: text = {text}")
            print(f"VALIDATION_STEP: hex_data = {hex_data}")
        
        # Pass the text through the transformer
        transformer_output = self.forward(text)
        if DEBUG:
            print(f"VALIDATION_STEP: transformer_output.shape = {transformer_output.shape}")

        # Calculate the loss
        loss = self.loss(transformer_output, hex_data.float())
        print(f"VALIDATION_STEP: loss = {loss}")

        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)



model = TransformerModel()
trainer = pl.Trainer(max_epochs=1, enable_checkpointing=False, logger=False)
trainer.fit(model, train_loader, val_loader)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


VALIDATION_STEP: loss = 29.07815170288086
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  1.23it/s]VALIDATION_STEP: loss = 28.568357467651367
                                                                           

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 0/625 [00:00<?, ?it/s] 

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
model.to(device)

# Test the model
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

for text, hex in val_loader:

    print(f"Text: {text[0]}")
    print(f"Hex: {hex[0]}")

    prediction = model.forward(text[0]).tolist()[0]
    prediction = [0 if x < 0 else round(x) for x in prediction]
    print(f"Tokenized prediction: {prediction}")
    
    #tokenize gold[0]
    encoded_gold = tokenizer(hex[0], return_tensors="pt", padding=True, truncation=True).input_ids.to(device)

    print(f"Tokenized gold: {encoded_gold[0].tolist()}")

    decoded_prediction = tokenizer.decode(prediction)
    print(f"Decoded prediction: {decoded_prediction}")

    decoded_gold = tokenizer.decode(encoded_gold[0].tolist())
    print(f"Decoded gold: {decoded_gold}")

    break

Text: In my never-ending quest to see as many quality movies as possible in my lifetime, i stumbled upon this film
Hex: 789c1dccd10d83301004d156b600520865386281957c67e0ce48ee1e92df37d2cc0e1b70debc3ef445bee1ec8c4436048912b0e2e3c5529503d66e317e7cb4087d2ba1ffa26a65ca384188ecf69605fd688edc155855ed016233278f
Tokenized prediction: [0, 39297, 460, 6470, 9892, 8780, 7797, 7377, 6276, 7823, 7703, 6944, 6500, 8512, 6993, 7050, 7429, 6982, 6458, 7754, 6914, 7126, 8348, 8155, 7893, 7209, 6558, 7590, 6840, 7220, 7655, 7373, 6613, 7786, 6738, 7538, 6866, 6614, 6937, 7844, 6657, 7182, 7320, 6401, 6481, 7131, 7669, 7396, 7202, 7091, 6956, 7255, 7199, 6681, 6999, 7308, 6854, 7439, 6709, 6317, 6906, 6469, 7593, 7064, 6949, 7198, 6692, 6604, 7421, 6928, 6911, 6680, 6538, 6886, 7005, 6648, 6939, 7094, 7154, 7163, 7619, 5987, 6663, 6507, 6794, 7669, 7417, 6116, 6683, 6309, 6296, 7166, 6588, 6535, 6057, 7037, 6337, 6129, 7532, 5780, 5660, 5873, 6108, 5549, 5243, 4491, 4688, 4840, 4397, 4071, 3607, 3893, 300