In [1]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cuda


In [3]:
import torch
import pandas as pd

# Load the dataset
df = pd.read_csv('../../Datasets/new_dataset_deflate.csv')

# Tokenize the text column
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
df['text'] = df['text'].apply(lambda x: tokenizer.encode(x, truncation=True))
df['deflate_hex'] = df['deflate_hex'].apply(lambda x: tokenizer.encode(x, truncation=True))

# Function to calculate minimum length
def calculate_minimum_length(column):
    return min(len(seq) for seq in column)

# Calculate average lengths
text_min_len = calculate_minimum_length(df['text'])
hex_min_len = calculate_minimum_length(df['deflate_hex'])

# Function to pad and truncate sequences
def truncate_sequences(sequences, target_length):
    return [sequence[:target_length] for sequence in sequences] 

# Pad and truncate text sequences

print(f"df['text'] before: {df['text']}")
print(f"text_min_len: {text_min_len}")
print(f"df['deflate_hex'] before: {df['deflate_hex']}")
print(f"hex_min_len: {hex_min_len}")

df['text'] = truncate_sequences(df['text'], text_min_len)
df['deflate_hex'] = truncate_sequences(df['deflate_hex'], hex_min_len)

# Convert to tensors
text_tensor = torch.tensor(df['text'].tolist(), dtype=torch.float32)
print(text_tensor.shape)

hex_tensor = torch.tensor([list(map(int, list(bin_seq))) for bin_seq in df['deflate_hex']], dtype=torch.float32)
print(hex_tensor.shape)

train_size = int(0.8 * len(text_tensor))
val_size = int(0.8 * train_size)

# Split into train, validation, and test sets
train_text, train_hex = text_tensor[:train_size], hex_tensor[:train_size]
val_text, val_hex = train_text[:val_size], train_hex[:val_size]
test_text, test_hex = text_tensor[train_size:], hex_tensor[train_size:]

batch_size = 1024

train_data = TensorDataset(train_text, train_hex)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

val_data = TensorDataset(val_text, val_hex)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

test_data = TensorDataset(test_text, test_hex)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# Test the data loaders
for text, hex in train_loader:
    print(text.shape)
    print(hex.shape)
    break


df['text'] before: 0        [0, 3762, 9, 5, 97, 34910, 34, 2801, 14, 71, 2...
1        [0, 250, 4613, 410, 931, 4, 28696, 3809, 1589,...
2        [0, 100, 802, 42, 21, 10, 4613, 169, 7, 1930, ...
3        [0, 34480, 89, 18, 10, 284, 147, 10, 410, 2143...
4        [0, 28970, 1334, 21129, 118, 18, 22, 16587, 11...
                               ...                        
49995    [0, 100, 802, 42, 1569, 222, 10, 159, 235, 205...
49996    [0, 26954, 6197, 6, 1099, 6054, 6, 1099, 3501,...
49997    [0, 100, 524, 10, 4019, 5850, 11, 2242, 4306, ...
49998    [0, 100, 437, 164, 7, 33, 7, 11967, 19, 5, 986...
49999    [0, 3084, 65, 3352, 5, 2141, 20351, 4133, 7, 2...
Name: text, Length: 50000, dtype: object
text_min_len: 10
df['deflate_hex'] before: 0        [0, 39413, 438, 288, 417, 3245, 23417, 242, 39...
1        [0, 39413, 438, 26866, 7309, 698, 34836, 1922,...
2        [0, 39413, 438, 288, 417, 398, 28690, 18616, 4...
3        [0, 39413, 438, 306, 417, 398, 23219, 428, 288...
4        [0,

In [4]:
DEBUG = False

class LSTM(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        # LSTM layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, dropout=0.2)
        self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True, dropout=0.2)

        # Fully connected layers
        self.fc = nn.Linear(hidden_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(hidden_dim)  # Batch normalization after fully connected layer
        self.dropout1 = nn.Dropout(0.2)  # Dropout for regularization

        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu2 = nn.ReLU()
        self.bn2 = nn.BatchNorm1d(output_dim)
        
        self.loss = nn.MSELoss()

    def forward(self, x):
        if DEBUG:
            print(f"FORWARD - x.shape: {x.shape}")
        out, _ = self.lstm(x)
        if DEBUG:
            print(f"FORWARD - out.shape: {out.shape}")
        out = self.fc(out)
        if DEBUG:
            print(f"FORWARD - out.shape: {out.shape}")
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        if DEBUG:
            print(f"TRAINING_STEP - x.shape: {x.shape}")
            print(f"TRAINING_STEP - y.shape: {y.shape}")

        y_hat = self(x)
        if DEBUG:
            print(f"TRAINING_STEP - y_hat.shape: {y_hat.shape}")
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        print('val_loss:', loss.item())  # Add this line for debugging
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)
    

input_dim = 10
hidden_dim = 46
num_layers = 2
output_dim = 46
model = LSTM(input_dim, hidden_dim, num_layers, output_dim)

trainer = pl.Trainer(max_epochs=10, enable_checkpointing=False, logger=False)
trainer.fit(model, train_loader, val_loader)



Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]val_loss: 218759104.0
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 12.90it/s]val_loss: 221764976.0
                                                                           

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bo

Epoch 0: 100%|██████████| 40/40 [00:00<00:00, 56.89it/s]val_loss: 220630688.0
val_loss: 219465952.0
val_loss: 220290992.0
val_loss: 218038768.0
val_loss: 219814832.0
val_loss: 219205728.0
val_loss: 218368000.0
val_loss: 216166176.0
val_loss: 214549664.0
val_loss: 220892816.0
val_loss: 219742608.0
val_loss: 218755792.0
val_loss: 217150336.0
val_loss: 220281728.0
val_loss: 219545152.0
val_loss: 221283328.0
val_loss: 219296992.0
val_loss: 216364656.0
val_loss: 219938688.0
val_loss: 218419696.0
val_loss: 221002464.0
val_loss: 220014880.0
val_loss: 217844672.0
val_loss: 217194768.0
val_loss: 219567776.0
val_loss: 219486992.0
val_loss: 221463696.0
val_loss: 220031952.0
val_loss: 218036032.0
val_loss: 218882432.0
val_loss: 224930464.0
val_loss: 224104624.0
Epoch 1: 100%|██████████| 40/40 [00:00<00:00, 58.95it/s]val_loss: 218003280.0
val_loss: 218022560.0
val_loss: 216780672.0
val_loss: 218999008.0
val_loss: 217580944.0
val_loss: 218935328.0
val_loss: 221127776.0
val_loss: 219520608.0
val_loss

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
# Test the model
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

for text, hex in test_loader:
    prediction = model(text)
    gold = hex

    prediction = prediction[0].tolist()
    prediction = [round(x) for x in prediction]
    print(f"prediction: {prediction}")
    prediction = tokenizer.decode(prediction)
    
    gold = gold[0].tolist()
    gold = [round(x) for x in gold]
    print(f"gold: {gold}")
    gold = tokenizer.decode(gold)

    print(f"Prediction: {prediction}")
    print(f"Gold: {gold}")


    break

5134
bed
prediction: [0, 7123, 418, 5134, 6558, 6459, 5429, 5831, 5230, 5812, 5666, 5727, 5038, 6163, 5402, 5282, 5127, 5359, 5429, 5362, 5275, 5382, 5438, 5482, 5568, 5609, 5632, 5609, 5562, 5580, 5553, 5469, 5386, 5424, 5401, 5429, 5285, 5293, 5306, 5409, 5399, 5347, 5341, 5290, 5380, 5350]
gold: [0, 39413, 438, 996, 38133, 23417, 242, 39134, 612, 438, 3761, 242, 134, 873, 25484, 134, 428, 288, 1610, 102, 2146, 242, 398, 438, 5379, 417, 245, 438, 466, 438, 40847, 23219, 6232, 4111, 4027, 7309, 5134, 134, 428, 5607, 34099, 1755, 4015, 428, 398, 417]
Prediction: <s> reviewed moneybed Intelligencefully FC toll streaming begun quarterssell Atlantic� principalkey ninth Maybe FC publication MitchellENT Acc Stone somewhat keenwith keen Lineistic 76 Christopher dramatic wounded proven FCante Make Carter Bruce WatsonaresNEWils Robinson cur
Gold: <s>789c15cb310e83300c46e1abfc1b0bea21e8c62d5c9c044bc846764ccbed1b96373ce95b8d
e


ValueError: invalid literal for int() with base 10: 'e'