In [1]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [None]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
print("Device:", device)


Device: cuda


In [None]:
import pandas as pd
from torch.utils.data import DataLoader, Dataset

# Custom Dataset
class TextHexDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        text = self.dataframe.iloc[idx]['text']
        hex_data = self.dataframe.iloc[idx]['deflate_hex']
        return text, hex_data

# Load the dataset
df = pd.read_csv('../Datasets/new_dataset_deflate.csv')

# Create datasets
dataset = TextHexDataset(df)

# Split the dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# DataLoaders with batch_size = 1
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [None]:
import pytorch_lightning as pl
from transformers import BartTokenizer, BartModel
import torch
import torch.nn.functional as F

DEBUG = False


class TransformerModel(pl.LightningModule):
    def __init__(self):
        super(TransformerModel, self).__init__()
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
        self.transformer = BartModel.from_pretrained('facebook/bart-base')

        self.flatten = nn.Flatten()
        self.linear = nn.Linear(128 * 768, 128)

        self.loss = torch.nn.MSELoss()

        self.to(device)

    def forward(self, text):

        # Tokenize the text
        input_ids = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
        attention_mask = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).attention_mask.to(device)

        # Pad input_ids and attention_mask to a fixed length of 128
        padded_input_ids = F.pad(input_ids, (0, 128 - input_ids.shape[1]), 'constant', 0)
        padded_attention_mask = F.pad(attention_mask, (0, 128 - attention_mask.shape[1]), 'constant', 0)

        # Ensure padding is on the device
        padded_input_ids = padded_input_ids.to(device)
        padded_attention_mask = padded_attention_mask.to(device)

        if DEBUG:
            print(f"FORWARD: padded_input_ids.shape = {padded_input_ids.shape}")
            print(f"FORWARD: padded_attention_mask.shape = {padded_attention_mask.shape}")

        # Pass tokenized and padded text through the transformer
        transformer_output = self.transformer(input_ids=padded_input_ids, attention_mask=padded_attention_mask).last_hidden_state
        if DEBUG:
            print(f"FORWARD: transformer_output.shape = {transformer_output.shape}")

        # Pooling over the sequence dimension
        flattened_output = self.flatten(transformer_output)
        if DEBUG:
            print(f"FORWARD: flattened_output.shape = {flattened_output.shape}")

        # Apply the linear layer
        final_output = self.linear(flattened_output)
        if DEBUG:
            print(f"FORWARD: final_output.shape = {final_output.shape}")

        return final_output
    
    def training_step(self, batch, batch_idx):
        text, hex_data = batch

        if DEBUG:
            print(f"TRAINING_STEP: text = {text}")
            print(f"TRAINING_STEP: hex_data = {hex_data}")

        #encode hex_data
        encoded_hex_data = self.tokenizer(hex_data, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
        if DEBUG:
            print(f"TRAINING_STEP: encoded_hex_data.shape = {encoded_hex_data.shape}")

        #padd the hex_data
        padded_encoded_hex_data = F.pad(encoded_hex_data, (0, 128 - encoded_hex_data.shape[1]), 'constant', 0)
        if DEBUG:
            print(f"TRAINING_STEP: padded_encoded_hex_data.shape = {padded_encoded_hex_data.shape}")

        # Pass the text through the transformer
        transformer_output = self.forward(text)
        if DEBUG:
            print(f"TRAINING_STEP: transformer_output.shape = {transformer_output.shape}")

        # Calculate the loss
        loss = self.loss(transformer_output, padded_encoded_hex_data.float())
        if DEBUG:
            print(f"TRAINING_STEP: loss = {loss}")

        return loss
    
    def validation_step(self, batch, batch_idx):
        text, hex_data = batch
        if DEBUG:
            print(f"VALIDATION_STEP: text = {text}")
            print(f"VALIDATION_STEP: hex_data = {hex_data}")

        #encode hex_data
        encoded_hex_data = self.tokenizer(hex_data, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
        if DEBUG:
            print(f"VALIDATION_STEP: encoded_hex_data.shape = {encoded_hex_data.shape}")

        #padd the hex_data
        padded_encoded_hex_data = F.pad(encoded_hex_data, (0, 128 - encoded_hex_data.shape[1]), 'constant', 0)
        if DEBUG:
            print(f"VALIDATION_STEP: padded_encoded_hex_data.shape = {padded_encoded_hex_data.shape}")
        
        # Pass the text through the transformer
        transformer_output = self.forward(text)
        if DEBUG:
            print(f"VALIDATION_STEP: transformer_output.shape = {transformer_output.shape}")

        # Calculate the loss
        loss = self.loss(transformer_output, padded_encoded_hex_data.float())
        print(f"VALIDATION_STEP: loss = {loss}")

        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)



model = TransformerModel()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_loader, val_loader)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


VALIDATION_STEP: loss = 172628288.0
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  1.45it/s]VALIDATION_STEP: loss = 184529344.0
                                                                           

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 1250/1250 [16:16<00:00,  1.28it/s, v_num=42]VALIDATION_STEP: loss = 120623312.0
VALIDATION_STEP: loss = 128157208.0
VALIDATION_STEP: loss = 116489544.0
VALIDATION_STEP: loss = 128303304.0
VALIDATION_STEP: loss = 121086440.0
VALIDATION_STEP: loss = 115566424.0
VALIDATION_STEP: loss = 127646448.0
VALIDATION_STEP: loss = 119994928.0
VALIDATION_STEP: loss = 121542488.0
VALIDATION_STEP: loss = 122171808.0
VALIDATION_STEP: loss = 122247120.0
VALIDATION_STEP: loss = 120634488.0
VALIDATION_STEP: loss = 124290512.0
VALIDATION_STEP: loss = 123832160.0
VALIDATION_STEP: loss = 121963120.0
VALIDATION_STEP: loss = 121057264.0
VALIDATION_STEP: loss = 124814656.0
VALIDATION_STEP: loss = 122875776.0
VALIDATION_STEP: loss = 117025872.0
VALIDATION_STEP: loss = 129087688.0
VALIDATION_STEP: loss = 119365600.0
VALIDATION_STEP: loss = 117803232.0
VALIDATION_STEP: loss = 123178144.0
VALIDATION_STEP: loss = 116724728.0
VALIDATION_STEP: loss = 125109856.0
VALIDATION_STEP: loss = 121503

In [None]:
for text, hex in val_loader:

    prediction = model(text)
    gold = hex

    print(f"Prediction: {prediction}")
    print(f"Gold: {gold}")

    break

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)