<a href="https://www.kaggle.com/code/billhensen/uas-nlp-qna-gru?scriptVersionId=238862187" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# QnA GRU

## Dataset

In [1]:
!pip install -q evaluate

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.12.0 which is incompatible.
torch 2.5.1+cu124 requires nvidia-cublas-cu12==12.4.5.8; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cublas-cu12 12.8.4.1 which is incompatible.
torch 2.5.1+cu124 requires nvidia-cudnn-cu12==9.1.0.70; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cudnn-cu12 9.3.0.75 which is incompatible.
torch 2.5.1+cu124 requires nvidia-cufft-cu12==11.2.1.3; platform_system == "Linux" and platform_machine == "x86

In [2]:
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
import pandas as pd
from sklearn.model_selection import train_test_split

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import evaluate
from tqdm.auto import tqdm

2025-05-10 05:44:33.545775: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746855873.700446      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746855873.747967      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
def load_data(
    dataset_name: str = "lib3m/lib3m_qa_dataset_v1",
    split: str = "train",
    lang: str = "en"
) -> pd.DataFrame:
    ds = load_dataset(dataset_name, split=split)
    df = ds.to_pandas()
    df = df[df.language == lang].reset_index(drop=True)
    return df


def split_dataframe(
    df,
    test_size: float = 0.2,
    random_state: int = 42
) -> tuple:

    train_df, val_df = train_test_split(
        df,
        test_size=test_size,
        random_state=random_state,
        shuffle=True
    )
    
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True)

class QADataset(Dataset):
    def __init__(
        self,
        dataframe,
        tokenizer: AutoTokenizer,
        max_length: int = 512
    ):
        self.df = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        question = row['question']
        content = row['content']
        answer = row['answer']

        text = f"<question> {question} <context> {content} <answer>"
        tokenized = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        labels = self.tokenizer(
            answer,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        ).input_ids
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            'input_ids': tokenized.input_ids.squeeze(),
            'attention_mask': tokenized.attention_mask.squeeze(),
            'labels': labels.squeeze()
        }

## Model

In [4]:
class GRUGenerator(nn.Module):
    def __init__(self, vocab_size, embed_dim=768, hidden_dim=768, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)
        outputs, _ = self.gru(x)
        logits = self.fc(outputs)
        return logits

## Configurations

In [5]:
MODEL_DIR = '/kaggle/working/gru_model'
BATCH_SIZE = 32
EPOCHS = 7
LR = 1e-3
MAX_LEN = 256
NUM_LAYERS=5
EMBED_DIM=512
HIDDEN_DIM=256
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Use both GPUs
if torch.cuda.device_count() > 1:
    MULTI_GPU = True
else:
    MULTI_GPU = False

## Data Preparation

In [6]:
df = load_data()
train_df, val_df = split_dataframe(df)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

train_ds = QADataset(train_df, tokenizer, max_length=MAX_LEN)
val_ds = QADataset(val_df, tokenizer, max_length=MAX_LEN)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True,num_workers=4)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=4)

README.md:   0%|          | 0.00/9.64k [00:00<?, ?B/s]

qa_pairs.parquet:   0%|          | 0.00/724M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/337525 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

## Training

In [7]:
model = GRUGenerator(len(tokenizer), EMBED_DIM, HIDDEN_DIM, NUM_LAYERS).to(DEVICE)
if MULTI_GPU:
    model = nn.DataParallel(model)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index=-100)


os.makedirs(MODEL_DIR, exist_ok=True)
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
    for batch in loop:
        input_ids = batch['input_ids'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)
        optimizer.zero_grad()
        logits = model(input_ids)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.4f}")
    torch.save(model.module.state_dict() if MULTI_GPU else model.state_dict(),
               f"{MODEL_DIR}/checkpoint_epoch{epoch+1}.pt")
# Save final
torch.save(model.module.state_dict() if MULTI_GPU else model.state_dict(), f"{MODEL_DIR}/final.pt")
tokenizer.save_pretrained(MODEL_DIR)

Epoch 1/7:   0%|          | 0/4640 [00:00<?, ?it/s]

Epoch 1/7, Loss: 6.8920


Epoch 2/7:   0%|          | 0/4640 [00:00<?, ?it/s]

Epoch 2/7, Loss: 6.6484


Epoch 3/7:   0%|          | 0/4640 [00:00<?, ?it/s]

Epoch 3/7, Loss: 6.5357


Epoch 4/7:   0%|          | 0/4640 [00:00<?, ?it/s]

Epoch 4/7, Loss: 6.4607


Epoch 5/7:   0%|          | 0/4640 [00:00<?, ?it/s]

Epoch 5/7, Loss: 6.4058


Epoch 6/7:   0%|          | 0/4640 [00:00<?, ?it/s]

Epoch 6/7, Loss: 6.3616


Epoch 7/7:   0%|          | 0/4640 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b9da9862020>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    Exception ignored in: self._shutdown_workers()
<function _MultiProcessingDataLoaderIter.__del__ at 0x7b9da9862020>
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
        if w.is_alive():self._shutdown_workers()
 Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b9da9862020>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1550, in 

Epoch 7/7, Loss: 6.3255


('/kaggle/working/gru_model/tokenizer_config.json',
 '/kaggle/working/gru_model/special_tokens_map.json',
 '/kaggle/working/gru_model/vocab.json',
 '/kaggle/working/gru_model/merges.txt',
 '/kaggle/working/gru_model/added_tokens.json',
 '/kaggle/working/gru_model/tokenizer.json')

## Evaluation

In [8]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24935 sha256=a561a409a6ce208f55dd399917c5c11631492f206cfb3986ea388e3fa113336b
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [9]:
metric_rouge = evaluate.load('rouge')
model.eval()
preds, refs = [], []
loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
for batch in loop:
    input_ids = batch['input_ids'].to(DEVICE)
    with torch.no_grad():
        logits = model(input_ids)
    generated = torch.argmax(logits, dim=-1)
    for gen_ids, label_ids in zip(generated, batch['labels']):
        pred = tokenizer.decode(gen_ids.cpu(), skip_special_tokens=True)
        ref = tokenizer.decode(label_ids[label_ids!=-100].cpu(), skip_special_tokens=True)
        preds.append(pred)
        refs.append(ref)
results = metric_rouge.compute(predictions=preds, references=refs)
print("Evaluation (ROUGE):", results)

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

Epoch 7/7:   0%|          | 0/1160 [00:00<?, ?it/s]

Evaluation (ROUGE): {'rouge1': 0.08461288857231261, 'rouge2': 0.003961319565821067, 'rougeL': 0.07338301785231338, 'rougeLsum': 0.07337333205777707}


## Testing

In [10]:
sample = val_df.iloc[3]
prompt = f"<question> {sample['question']} <context> {sample['content']} <answer>"

inputs = tokenizer(
    prompt,
    return_tensors='pt',
    padding='max_length',
    max_length=MAX_LEN
).to(DEVICE)

model.eval()
generated_ids = []
input_ids = inputs['input_ids']
for _ in range(100):
    with torch.no_grad():
        logits = model(input_ids)
    next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
    generated_ids.append(next_token_id.item())
    input_ids = torch.cat([input_ids, next_token_id], dim=-1)
    if next_token_id.item() == tokenizer.eos_token_id:
        break

print("Prompt: ", prompt)

answer = tokenizer.decode(generated_ids, skip_special_tokens=True)
print("\nGenerated Answer:", answer)

Prompt:  <question> Why might the political focus on cost controls, electronic medical records, and preventing lawsuits be considered insufficient for solving the healthcare crisis? <context> # Chapter 4: Regulation: The Helping Hand That Harms
## Only Free Markets Can Solve The Healthcare Crisis

Every election cycle, we hear politicians talk only of cost controls, electronic medical records, and preventing lawsuits in order to solve our medical crisis. We do not hear from them discussions of the real problems of government-paid insurance and the third-party payer system, and of medical boards. Some pundits argue that technology increases medical costs. Though technology lowers costs in other industries, people think that it somehow increases costs in the healthcare industry. Indeed, Paul Krugman claims that healthcare costs rise simply "because of medical progress."140 With these kinds of backwards notions, our "leaders" set out to implement yet more regulation and price controls, wh