In [1]:
%%capture
%pip install torch pandas lightning trl

import torch
from torch import nn
import pytorch_lightning as pl
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
SEED = 999
torch.manual_seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cuda


In [3]:
df = pd.read_csv('../Datasets/new_dataset_deflate_binary.csv')

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

def tokenize(text):
    return tokenizer(text, truncation=True, padding='max_length', max_length=512, return_tensors='pt').input_ids.squeeze(0)

tokenized_text = df['text'].apply(tokenize)

# Convert tokenized text to LongTensor
tokenized_text_tensor = torch.stack(tokenized_text.tolist()).long()

deflate_binary = df['deflate_binary'].apply(lambda x: ''.join(filter(str.isdigit, x)))
binary_max_len = max(len(x) for x in deflate_binary)
deflate_binary_padded = deflate_binary.apply(lambda x: x + '0' * (binary_max_len - len(x)))

deflate_binary_tensor = torch.tensor([list(map(int, list(x))) for x in deflate_binary_padded], dtype=torch.long)

# Split the dataset
train_size = int(0.8 * len(df))
val_size = int(0.1 * len(df))
test_size = len(df) - train_size - val_size

train_dataset = TensorDataset(tokenized_text_tensor[:train_size], deflate_binary_tensor[:train_size])
val_dataset = TensorDataset(tokenized_text_tensor[train_size:train_size+val_size], deflate_binary_tensor[train_size:train_size+val_size])
test_dataset = TensorDataset(tokenized_text_tensor[train_size+val_size:], deflate_binary_tensor[train_size+val_size:])

# Create dataloaders
batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [4]:
# Define the BART model
class BartModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
        self.learning_rate = 0.1

    def forward(self, input_ids, labels):
        loss, logits = self.model(input_ids, labels=labels)[:2]
        return loss, logits

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
    
    def training_step(self, batch, batch_idx):
        input_ids, labels = batch
        loss, _ = self(input_ids.to(device), labels.to(device))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, labels = batch
        loss, _ = self(input_ids.to(device), labels.to(device))
        self.log("val_loss", loss)
        return loss
    
    def test_step(self, batch, batch_idx):
        input_ids, labels = batch
        loss, _ = self(input_ids.to(device), labels.to(device))
        self.log("test_loss", loss)
        return loss

# Initialize the model and trainer
model = BartModel().to(device)
trainer = pl.Trainer(
    max_epochs=50,
    precision=16,  # for 16-bit precision
    accumulate_grad_batches=4,  # for gradient accumulation
)
# Train the model
trainer.fit(model, train_dataloader, val_dataloader)

# Test the model
trainer.test(model, dataloaders=test_dataloader)

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\lightning_fabric\connector.py:558: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

C:\Users\tomma\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
