In [None]:
from datasets import load_dataset
import re
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
from torch import nn, optim
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import warnings
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from seqeval.metrics import f1_score, classification_report
# Load the Penn Treebank (text-only variant). Passing a list to `split` returns
# a list of datasets in the same order: ptb[0] = train, ptb[1] = validation,
# ptb[2] = test.
# NOTE(review): trust_remote_code=True runs the dataset's loading script from
# the Hub; keep it only for sources you trust.
ptb = load_dataset(
    'ptb_text_only',
    split=['train', 'validation', 'test'],
    trust_remote_code=True,
)


# Tokenization and Vocabulary Building
def tokenize(text):
    r"""Lower-case ``text`` and return its runs of word characters (regex ``\w+``).

    Everything that is not a word character is discarded, so e.g. ``'<unk>'``
    becomes ``'unk'`` and ``"it's"`` becomes ``['it', 's']``.
    """
    word_pattern = r'\w+'
    lowered = text.lower()
    return re.findall(word_pattern, lowered)

def build_vocab(dataset):
    """Map every token seen in ``dataset`` to a contiguous integer id.

    Each example is expected to expose its text under the ``'sentence'`` key.
    Ids follow first-occurrence order of the tokens, and ``'<PAD>'`` is added
    last, so its id is ``len(vocab) - 1``.
    """
    token_counts = Counter(
        token
        for example in dataset
        for token in tokenize(example['sentence'])
    )
    # Iterating a Counter yields keys in first-insertion order.
    vocab = {token: position for position, token in enumerate(token_counts)}
    vocab['<PAD>'] = len(vocab)
    return vocab

# The vocabulary is built from the training split only; '<PAD>' is its last id.
vocab = build_vocab(ptb[0])
vocab_size, pad_token_idx = len(vocab), vocab['<PAD>']

# Convert text to sequences of indices
def encode_text(text, vocab):
    """Convert ``text`` into a list of ids looked up in ``vocab``.

    Tokens absent from ``vocab`` are skipped (there is no unknown-word id),
    so the result can be shorter than the token sequence.
    """
    ids = []
    for token in tokenize(text):
        if token not in vocab:
            continue  # out-of-vocabulary: dropped silently
        ids.append(vocab[token])
    return ids

# Process each split into a list of 1-D int64 tensors of token ids.
# dtype is explicit because torch.tensor([]) infers float32: a sentence with no
# in-vocabulary tokens would otherwise become a float tensor, which can make
# pad_sequence emit a float batch that nn.Embedding rejects.
train_data = [torch.tensor(encode_text(example['sentence'], vocab), dtype=torch.long) for example in ptb[0]]
val_data = [torch.tensor(encode_text(example['sentence'], vocab), dtype=torch.long) for example in ptb[1]]
test_data = [torch.tensor(encode_text(example['sentence'], vocab), dtype=torch.long) for example in ptb[2]]

# DataLoader preparation
def collate_batch(batch):
    """Pad a batch of token-id sequences and split it into next-token pairs.

    Assumes each element of ``batch`` is a 1-D tensor (TODO confirm for other
    callers). Sequences are right-padded with ``pad_token_idx`` into one
    batch-first tensor; ``targets`` is that tensor shifted one step left
    relative to ``inputs``.

    Returns:
        (inputs, targets): all columns but the last, and all but the first.
    """
    padded = pad_sequence(batch, batch_first=True, padding_value=pad_token_idx)
    inputs = padded[:, :-1]
    targets = padded[:, 1:]
    return inputs, targets

# Both loaders share batch size and collation; only training data is shuffled.
_loader_kwargs = dict(batch_size=32, collate_fn=collate_batch)
train_loader = DataLoader(train_data, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_data, shuffle=False, **_loader_kwargs)

# Define the GRU Language Model
class LanguageModelGRU(nn.Module):
    """Next-token language model: embedding -> 2-layer GRU -> linear head.

    Args:
        vocab_size: number of token ids; also the width of the output logits.
        embedding_dim: size of each token embedding.
        hidden_dim: GRU hidden-state size.
        pad_idx: embedding row used for padding (passed as ``padding_idx``,
            so it receives no gradient).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.3,  # applied between the two stacked GRU layers
            bidirectional=False,  # left-to-right only: no peeking at future tokens
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq) to logits (batch, seq, vocab_size)."""
        hidden_states = self.gru(self.embedding(x))[0]  # drop the final hidden state
        return self.fc(hidden_states)

# Model, loss and optimizer.
# NOTE(review): nothing here moves the model or batches to a GPU; training runs
# on the default (CPU) device.
embedding_dim = 100
hidden_dim = 256
model = LanguageModelGRU(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    pad_idx=pad_token_idx,
)

# Padding targets are excluded from the loss via ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_idx)
optimizer = AdamW(model.parameters(), lr=1e-3)

def train_model(model, dataloader, optimizer, criterion, *, max_grad_norm=None):
    """Run one training epoch and return the mean of the per-batch losses.

    Args:
        model: module mapping an index tensor to logits whose last dimension
            is the vocabulary size.
        dataloader: iterable of ``(inputs, targets)`` index-tensor pairs.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(logits_2d, targets_1d)``.
        max_grad_norm: if given, clip the total gradient norm to this value
            before each optimizer step (useful for recurrent models).

    Returns:
        Average of ``loss.item()`` over the batches seen.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        logits = model(inputs)
        # Flatten to one row per position. The vocab size comes from the logits
        # themselves, not a global. reshape (not view) because targets may be a
        # non-contiguous slice.
        loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return total_loss / num_batches

def evaluate_model(model, dataloader, criterion):
    """Return the per-token perplexity of ``model`` over ``dataloader``.

    Perplexity is exp(total negative log-likelihood / number of scored target
    tokens). Averaging per-batch *mean* losses instead would weight every
    batch equally regardless of how many non-padding tokens it holds, which
    biases the result (e.g. a short final batch).

    Args:
        model: module mapping an index tensor to logits whose last dimension
            is the vocabulary size.
        dataloader: iterable of ``(inputs, targets)`` index-tensor pairs.
        criterion: loss called as ``criterion(logits_2d, targets_1d)``.
            Assumed to use ``reduction='mean'`` without class weights, as
            ``nn.CrossEntropyLoss(ignore_index=...)`` does, so that
            ``loss * n_tokens`` recovers the summed loss.

    Returns:
        Perplexity as a NumPy float.

    Raises:
        ValueError: if the dataloader yields no non-ignored target tokens.
    """
    model.eval()
    # Targets equal to this id are excluded by the criterion; -100 is the
    # CrossEntropyLoss default when none was configured.
    ignore_index = getattr(criterion, 'ignore_index', -100)
    total_nll = 0.0
    total_tokens = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            logits = model(inputs)
            flat_targets = targets.reshape(-1)
            n_tokens = int((flat_targets != ignore_index).sum())
            if n_tokens == 0:
                # An all-padding batch gives a NaN mean loss and scores nothing.
                continue
            loss = criterion(logits.reshape(-1, logits.size(-1)), flat_targets)
            total_nll += loss.item() * n_tokens
            total_tokens += n_tokens
    if total_tokens == 0:
        raise ValueError("dataloader produced no non-padding target tokens")
    return np.exp(total_nll / total_tokens)


# Training loop: one pass over the training data per epoch, then a validation
# pass whose perplexity is printed next to the training loss.
# NOTE(review): test_data is built above but not evaluated in this section.
epochs = 10
for epoch in range(epochs):
    train_loss = train_model(model, train_loader, optimizer, criterion)
    val_perplexity = evaluate_model(model, val_loader, criterion)
    print(
        f"Epoch {epoch+1}/{epochs}, "
        f"Train Loss: {train_loss:.4f}, "
        f"Validation Perplexity: {val_perplexity:.4f}"
    )
