In [1]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cuda


In [12]:
# Load and preprocess dataset
df = pd.read_csv('../../Datasets/new_dataset_deflate_binary.csv')

# Tokenization and Padding Functions
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

def tokenize(text):
    tokens = tokenizer.encode(text, truncation=True, return_tensors="pt")
    return tokens.squeeze(0).float()

# Tokenize text
df['text'] = df['text'].apply(tokenize)

# Preprocess binary data
def clean_binary(binary):
    return ''.join(filter(lambda x: x in '01', binary))

df['deflate_binary'] = df['deflate_binary'].apply(clean_binary)

# Convert binary strings to tensor
def binary_to_tensor(binary):
    return torch.tensor([int(b) for b in binary], dtype=torch.float)

df['deflate_binary'] = df['deflate_binary'].apply(binary_to_tensor)

def collate_fn(batch):
    texts, binaries = zip(*batch)
    lengths = torch.tensor([len(text) for text in texts], dtype=torch.float)
    max_length = max(lengths)
    texts_padded = torch.zeros(len(texts), int(max_length), dtype=torch.float)  # Initialize padded tensor
    for i, text in enumerate(texts):
        texts_padded[i, :len(text)] = text  # Fill with tokenized text
    binaries_padded = nn.utils.rnn.pad_sequence(binaries, batch_first=True)
    print("Collate - texts_padded shape:", texts_padded.shape, "binaries_padded shape:", binaries_padded.shape)  # Add this line
    return texts_padded, binaries_padded, lengths

# Split dataset and create DataLoader
train, val, test = np.split(df.sample(frac=1, random_state=SEED), [int(.6*len(df)), int(.8*len(df))])

train_dataloader = DataLoader(list(zip(train['text'], train['deflate_binary'])), batch_size=64, collate_fn=collate_fn, shuffle=True)
val_dataloader = DataLoader(list(zip(val['text'], val['deflate_binary'])), batch_size=64, collate_fn=collate_fn)
test_dataloader = DataLoader(list(zip(test['text'], test['deflate_binary'])), batch_size=64, collate_fn=collate_fn)


  return bound(*args, **kwds)


In [16]:
class LSTM(pl.LightningModule):
    def __init__(self, input_dim = 1, hidden_dim=256, output_dim=42, num_layers=2, dropout_rate=0.1, learning_rate=0.1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True, dropout=dropout_rate)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # Multiply by 2 for bidirectional
        self.learning_rate = learning_rate
        
    def forward(self, text, lengths):
        lengths = lengths.cpu()
        packed_text = pack_padded_sequence(text, lengths, batch_first=True, enforce_sorted=False)
        print("Forward - packed_text shape:", packed_text.data.shape)  # Add this line
        lstm_out, _ = self.lstm(packed_text)
        print("Forward - lstm_out shape:", lstm_out.data.shape)  # Add this line
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)
        print("Forward - lstm_out shape:", lstm_out.shape)  # Add this line
        output = self.fc(lstm_out)
        print("Forward - output shape:", output.shape)
        return output
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y, x_lengths = batch
        print("Batch input shape:", x.shape)  # Add this line
        x = x.to(device)
        y = y.to(device)
        y_hat = self.forward(x, x_lengths)
        print("y_hat shape:", y_hat.shape, "y shape:", y.shape)  # Add this line
        loss = nn.functional.mse_loss(y_hat, y)
        self.log("train loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, x_lengths = batch
        x = x.to(device)
        y = y.to(device)
        y_hat = self.forward(x, x_lengths)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y, x_lengths = batch
        x = x.to(device)
        y = y.to(device)
        y_hat = self.forward(x, x_lengths)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log('test_loss', loss)

model = LSTM().to(device)

# Define the trainer with 50 epochs and showing eval results every 10 epochs
trainer = pl.Trainer(max_epochs=50)

# Train the model ⚡
trainer.fit(model, train_dataloader, val_dataloader)



Sanity Checking: |          | 0/? [00:00<?, ?it/s]Collate - texts_padded shape: torch.Size([64, 42]) binaries_padded shape: torch.Size([64, 936])
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]Forward - packed_text shape: torch.Size([1736])
Forward - lstm_out shape: torch.Size([1736, 512])
Forward - lstm_out shape: torch.Size([64, 42, 512])
Forward - output shape: torch.Size([64, 42, 42])


  loss = nn.functional.mse_loss(y_hat, y)


RuntimeError: The size of tensor a (42) must match the size of tensor b (936) at non-singleton dimension 2

In [None]:
#test the model
for batch in test_dataloader:
    x, y = batch
    y_hat = model(x)

    #print the shapes of the tensors
    print(x.shape)
    print(y.shape)
    print(y_hat.shape)
    
    #convert x and y into the original strings
    x = x.tolist()
    y = y.tolist()
    y_hat = y_hat.tolist()

    x = tokenizer.decode(x[0])
    y = ''.join([str(int(i)) for i in y[0]])
    y_hat = ''.join([str(int(i)) for i in y_hat[0]])
    print(x)
    print(y)
    print(y_hat)
   
    break



torch.Size([1024, 83])
torch.Size([1024, 1232])
torch.Size([1024, 1232])
<s>The actors are so bland that it's almost impossible to tell them apart (Pauline Kael said of this movie: "The</s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>
011110001001110000001101110010101011000100001101110000110011000000001100010001001101000101010101000011100110111011100010000110011011001001000010100110100001010001011110111000000001110001011111001000000000001010010100011010011000100001110100111001101000111110111010100011111000111110110111001101011000000110011111100010101001000111100000000100000011001010110000001110111100111100000011110101010101100010110000011110101100110011101101001111010111001001110110101111110010001011010011011101100001011100101010010100000111001010011111010001100001110110111100001110000000101011101011100110111011011111011011001010011011110000101000010001111101001000001