In [1]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cuda


In [3]:
import torch
import pandas as pd

# Load the dataset
df = pd.read_csv('../Datasets/new_dataset_deflate.csv')

# Tokenize the text column
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
df['text'] = df['text'].apply(lambda x: tokenizer.encode(x, truncation=True))
df['deflate_hex'] = df['deflate_hex'].apply(lambda x: tokenizer.encode(x, truncation=True))

# Function to calculate minimum length
def calculate_minimum_length(column):
    return min(len(seq) for seq in column)

# Calculate average lengths
text_min_len = calculate_minimum_length(df['text'])
hex_min_len = calculate_minimum_length(df['deflate_hex'])

# Function to pad and truncate sequences
def truncate_sequences(sequences, target_length):
    return [sequence[:target_length] for sequence in sequences] 

# Pad and truncate text sequences

print(f"df['text'] before: {df['text']}")
print(f"text_min_len: {text_min_len}")
print(f"df['deflate_hex'] before: {df['deflate_hex']}")
print(f"hex_min_len: {hex_min_len}")

df['text'] = truncate_sequences(df['text'], text_min_len)
df['deflate_hex'] = truncate_sequences(df['deflate_hex'], hex_min_len)

# Convert to tensors
text_tensor = torch.tensor(df['text'].tolist(), dtype=torch.float32)
print(text_tensor.shape)

hex_tensor = torch.tensor([list(map(int, list(bin_seq))) for bin_seq in df['deflate_hex']], dtype=torch.float32)
print(hex_tensor.shape)

train_size = int(0.8 * len(text_tensor))
val_size = int(0.8 * train_size)

# Split into train, validation, and test sets
train_text, train_hex = text_tensor[:train_size], hex_tensor[:train_size]
val_text, val_hex = train_text[:val_size], train_hex[:val_size]
test_text, test_hex = text_tensor[train_size:], hex_tensor[train_size:]

batch_size = 1024

train_data = TensorDataset(train_text, train_hex)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

val_data = TensorDataset(val_text, val_hex)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

test_data = TensorDataset(test_text, test_hex)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# Test the data loaders
for text, hex in train_loader:
    print(text.shape)
    print(hex.shape)
    break


df['text'] before: 0        [0, 3762, 9, 5, 97, 34910, 34, 2801, 14, 71, 2...
1        [0, 250, 4613, 410, 931, 4, 28696, 3809, 1589,...
2        [0, 100, 802, 42, 21, 10, 4613, 169, 7, 1930, ...
3        [0, 34480, 89, 18, 10, 284, 147, 10, 410, 2143...
4        [0, 28970, 1334, 21129, 118, 18, 22, 16587, 11...
                               ...                        
49995    [0, 100, 802, 42, 1569, 222, 10, 159, 235, 205...
49996    [0, 26954, 6197, 6, 1099, 6054, 6, 1099, 3501,...
49997    [0, 100, 524, 10, 4019, 5850, 11, 2242, 4306, ...
49998    [0, 100, 437, 164, 7, 33, 7, 11967, 19, 5, 986...
49999    [0, 3084, 65, 3352, 5, 2141, 20351, 4133, 7, 2...
Name: text, Length: 50000, dtype: object
text_min_len: 10
df['deflate_hex'] before: 0        [0, 39413, 438, 288, 417, 3245, 23417, 242, 39...
1        [0, 39413, 438, 26866, 7309, 698, 34836, 1922,...
2        [0, 39413, 438, 288, 417, 398, 28690, 18616, 4...
3        [0, 39413, 438, 306, 417, 398, 23219, 428, 288...
4        [0,

In [7]:
DEBUG = False

class LSTM(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

        #use dropout
        self.dropout = nn.Dropout(0.2)
        #use batchnorm
        self.bn = nn.BatchNorm1d(hidden_dim)

        #enrich the model with more layers
        self.lstm2 = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        #use relu
        self.relu = nn.ReLU()
        
        self.loss = nn.MSELoss()

    def forward(self, x):
        if DEBUG:
            print(f"FORWARD - x.shape: {x.shape}")
        out, _ = self.lstm(x)
        if DEBUG:
            print(f"FORWARD - out.shape: {out.shape}")
        out = self.fc(out)
        if DEBUG:
            print(f"FORWARD - out.shape: {out.shape}")
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        if DEBUG:
            print(f"TRAINING_STEP - x.shape: {x.shape}")
            print(f"TRAINING_STEP - y.shape: {y.shape}")

        y_hat = self(x)
        if DEBUG:
            print(f"TRAINING_STEP - y_hat.shape: {y_hat.shape}")
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        print('val_loss:', loss.item())  # Add this line for debugging
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)
    

input_dim = 10
hidden_dim = 64
num_layers = 2
output_dim = 46
model = LSTM(input_dim, hidden_dim, num_layers, output_dim)

trainer = pl.Trainer(max_epochs=500)
trainer.fit(model, train_loader, val_loader)



Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]val_loss: 219443312.0
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 86.91it/s]val_loss: 219610368.0
                                                                           

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bo

Epoch 0: 100%|██████████| 40/40 [00:00<00:00, 69.98it/s, v_num=148]val_loss: 222180416.0
val_loss: 216325696.0
val_loss: 216279536.0
val_loss: 219750624.0
val_loss: 219336624.0
val_loss: 220945184.0
val_loss: 221015136.0
val_loss: 221013760.0
val_loss: 219137792.0
val_loss: 218263376.0
val_loss: 220741280.0
val_loss: 217459456.0
val_loss: 219337552.0
val_loss: 219558160.0
val_loss: 218597552.0
val_loss: 217127936.0
val_loss: 219183216.0
val_loss: 219504512.0
val_loss: 219201456.0
val_loss: 217741232.0
val_loss: 218056464.0
val_loss: 222016288.0
val_loss: 219245216.0
val_loss: 220588928.0
val_loss: 220301456.0
val_loss: 216301616.0
val_loss: 222332864.0
val_loss: 219709984.0
val_loss: 218439328.0
val_loss: 218123664.0
val_loss: 218346144.0
val_loss: 215746192.0
Epoch 1: 100%|██████████| 40/40 [00:00<00:00, 45.19it/s, v_num=148]val_loss: 219048160.0
val_loss: 220435104.0
val_loss: 217757648.0
val_loss: 218866560.0
val_loss: 214271296.0
val_loss: 218672496.0
val_loss: 218124560.0
val_loss

In [8]:
# Test the model
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

for text, hex in test_loader:
    prediction = model(text)
    gold = hex

    prediction = prediction[0].tolist()
    prediction = [round(x) for x in prediction]
    print(f"prediction: {prediction}")
    prediction = tokenizer.decode(prediction)
    
    gold = gold[0].tolist()
    gold = [round(x) for x in gold]
    print(f"gold: {gold}")
    gold = tokenizer.decode(gold)

    print(f"Prediction: {prediction}")
    print(f"Gold: {gold}")


    break

prediction: [0, 9540, 416, 5134, 7660, 7225, 5449, 5930, 5242, 5906, 5725, 5804, 5037, 6462, 5413, 5303, 5122, 5375, 5455, 5385, 5283, 5404, 5452, 5504, 5601, 5652, 5676, 5647, 5599, 5623, 5584, 5493, 5390, 5439, 5418, 5453, 5285, 5308, 5312, 5432, 5422, 5365, 5350, 5283, 5395, 5354]
gold: [0, 39413, 438, 134, 417, 398, 28690, 18616, 2940, 2965, 438, 39558, 38759, 5243, 612, 506, 406, 242, 398, 1366, 102, 134, 102, 5606, 102, 306, 438, 3546, 3079, 102, 466, 242, 176, 506, 401, 417, 4111, 3209, 417, 246, 428, 30042, 3204, 401, 2983, 417]
Prediction: <s> coat alreadybed taxpayers transferred63 locked establisheterUM therapyumastory Sport fled Euro copy Mount speaker pregnantEd crewsises survivedthan intense partly Tyler expense venue adjusted editing alert clothes sessionsante Tokyo recover Warren scenes Having cur pregnant prominent Industry
Gold: <s>789c1d8cd10980300c0557797400f7e818a1a61a4c0934a9e2f6d67ebd3b385ec631d
