In [1]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cuda


In [3]:
#retrieve this dataset ../Datasets/new_dataset_deflate_binary.csv'
df = pd.read_csv('../Datasets/new_dataset_deflate_binary.csv')

#the dataset has two columns: text and binary
#tokenize the text column using barttokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
df['text'] = df['text'].apply(lambda x: tokenizer.encode(x, truncation=True))

#now we need to pad the sequences to the same length
#first find the max length
text_max_len = 0
for i in df['text'].values:
    if len(i) > text_max_len:
        text_max_len = len(i)

#now pad the sequences
df['text'] = df['text'].apply(lambda x: x + [0]*(text_max_len-len(x)))

#now we need to convert the dataset into a tensor
#first convert the dataset into a list
text = df['text'].values.tolist()

#remove all the characters that are different from 0 and 1 from the binary column
deflate_binary = df['deflate_binary'].apply(lambda x: ''.join([i for i in x if i in ['0', '1']]))

#now we need to pad the sequences to the same length
#first find the max length
binary_max_len = 0
for i in deflate_binary.values:
    if len(i) > binary_max_len:
        binary_max_len = len(i)

deflate_binary = deflate_binary.apply(lambda x: x + '0'*(binary_max_len-len(x)))

#now convert the binary column into a list of lists
deflate_binary_list = []
for i in deflate_binary:
    deflate_binary_list.append([int(j) for j in i])

# Create a PyTorch tensor of float32s from the list of lists
deflate_binary_tensor = torch.tensor(deflate_binary_list, dtype=torch.float32)
text_tensor = torch.tensor(text, dtype=torch.float32)


#test if all the text tensors in the dataset have the same length
for i in text_tensor:
    if i.shape != torch.Size([text_max_len]):
        print('error')

#test if all the deflate_binary_tensor in the dataset have the same length
for i in deflate_binary_tensor:
    if i.shape != torch.Size([binary_max_len]):
        print('error')

#print the shapes of the tensors
print(text_tensor[0].shape)
print(deflate_binary_tensor[0].shape)


#now split the dataset into training, validation and test sets
#first split the dataset into training and test sets
train_text = text_tensor[:int(len(text_tensor)*0.8)]
train_binary = deflate_binary_tensor[:int(len(deflate_binary_tensor)*0.8)]

test_text = text_tensor[int(len(text_tensor)*0.8):]
test_binary = deflate_binary_tensor[int(len(deflate_binary_tensor)*0.8):]

#now split the training set into training and validation sets
train_text = train_text[:int(len(train_text)*0.8)]
train_binary = train_binary[:int(len(train_binary)*0.8)]

val_text = train_text[int(len(train_text)*0.8):]
val_binary = train_binary[int(len(train_binary)*0.8):]

#now create the dataloaders
batch_size = 1024

train_data = TensorDataset(train_text, train_binary)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

val_data = TensorDataset(val_text, val_binary)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

test_data = TensorDataset(test_text, test_binary)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# #test the dataloaders
# for batch in train_dataloader:
#     print(batch)
#     break


torch.Size([83])
torch.Size([1232])


In [7]:
class LSTM(pl.LightningModule):
    def __init__(self, input_dim=83, hidden_dim=256, output_dim=1232, num_layers=2, dropout_rate=0.1, learning_rate=0.1):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True, dropout=dropout_rate)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # Multiply by 2 for bidirectional
        self.learning_rate = learning_rate
        self.loss = nn.MSELoss()
        
    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        x.to(device)
        y.to(device)
        y_hat = self(x)
        loss = self.loss(y_hat, y)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x.to(device)
        y.to(device)
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        print('val_loss', loss)

model = LSTM().to(device)

# Define the trainer with 50 epochs and showing eval results every 10 epochs
trainer = pl.Trainer(max_epochs=50)

# Train the model ⚡
trainer.fit(model, train_dataloader, val_dataloader)



Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]val_loss tensor(0.3008, device='cuda:0')
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 12.27it/s]val_loss tensor(0.3014, device='cuda:0')
Epoch 0: 100%|██████████| 32/32 [00:06<00:00,  5.09it/s, v_num=2]          val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
Epoch 1: 100%|██████████| 32/32 [00:06<00:00,  5.15it/s, v_num=2]val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
val_loss tensor(nan, device='cuda:0')
Epoch 2: 100%|██████████| 32/32 [00:06<00:00,  5.08it/s, v_num=2]val_loss tensor(nan, device='cuda:0')
v

In [None]:
#test the model
for batch in test_dataloader:
    x, y = batch
    y_hat = model(x)

    #print the shapes of the tensors
    print(x.shape)
    print(y.shape)
    print(y_hat.shape)
    
    #convert x and y into the original strings
    x = x.tolist()
    y = y.tolist()
    y_hat = y_hat.tolist()

    x = tokenizer.decode(x[0])
    y = ''.join([str(int(i)) for i in y[0]])
    y_hat = ''.join([str(int(i)) for i in y_hat[0]])
    print(x)
    print(y)
    print(y_hat)
   
    break



torch.Size([256, 83])
torch.Size([256, 1232])
torch.Size([256, 1232])
<s>Just reading why this show got canceled makes me rather steamed. This was a favorite of mine as a kid</s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>
01111000100111000001110111001010110000010000110110000000001000000001000000000100110000000101011010110110000000100000101111110001011011010000001100010111010110001110000010000010010000001100001010011101000100101011101100110111111110101001110111001100011111101001100101100011010100101010001011110110100011000101010100011110011110000101000110000011100101011011000110010000100001110010001101001000000011110011110000011001110100011010010011010010110100001000100000101001010111100011100001100001010011100110100110001100000110111000111011101111001011110011000100001000100100101101110001100011101010100001001100100011101000010110100100100111011111101010