# In [None]:
import csv
import re
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
from torch import nn
from torch.optim import AdamW
import numpy as np
from datasets import load_dataset

# Select the compute device: use a CUDA GPU when one is available,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Load the Penn Treebank text corpus. The three requested splits are indexed
# positionally later on: ptb[0] = train, ptb[1] = validation, ptb[2] = test.
ptb = load_dataset(
    'ptb_text_only',
    split=['train', 'validation', 'test'],
    trust_remote_code=True,
)

# Tokenization and Vocabulary Building
def tokenize(text):
    """Split *text* into lowercase word tokens.

    The text is lowercased first, then every maximal run of regex word
    characters (letters, digits, underscore) becomes one token. Punctuation
    and whitespace act purely as separators and are discarded.

    Args:
        text: The raw string to tokenize.

    Returns:
        A list of token strings, in order of appearance (empty for text
        that contains no word characters).
    """
    lowered = text.lower()
    word_pattern = re.compile(r'\w+')
    return word_pattern.findall(lowered)

def build_vocab(dataset):
    """Build a token-to-index vocabulary from *dataset*.

    Every distinct token produced by ``tokenize`` over the ``'sentence'``
    field of each example receives a consecutive integer id, assigned in
    order of first appearance. The special ``'<PAD>'`` entry is added last
    and therefore gets the id that follows all word ids.

    Args:
        dataset: Iterable of mappings that each carry a ``'sentence'`` string.

    Returns:
        A dict mapping token strings (plus ``'<PAD>'``) to integer ids.
    """
    token_counts = Counter(
        token
        for example in dataset
        for token in tokenize(example['sentence'])
    )
    # A Counter keeps keys in first-insertion order, so zipping its keys with
    # a running range numbers the tokens by first appearance. The counts
    # themselves are not used.
    vocab = dict(zip(token_counts, range(len(token_counts))))
    vocab['<PAD>'] = len(vocab)
    return vocab

# The vocabulary is fitted on the training split (ptb[0]) only.
vocab = build_vocab(ptb[0])
# Id of the padding token; reused for batching, the embedding layer and the loss.
pad_token_idx = vocab['<PAD>']
# Total number of ids, including the padding entry.
vocab_size = len(vocab)

# Convert text to sequences of indices
def encode_text(text, vocab):
    """Convert *text* into a list of vocabulary indices.

    The text is tokenized with ``tokenize``; tokens that are missing from
    *vocab* are silently dropped rather than mapped to an unknown-word id.

    Args:
        text: The raw string to encode.
        vocab: Mapping from token string to integer id.

    Returns:
        A list of integer ids, one per in-vocabulary token, in order.
    """
    indices = []
    for token in tokenize(text):
        if token in vocab:
            indices.append(vocab[token])
    return indices

# Process each split
def _encode_split(split):
    """Encode every sentence of *split* as a 1-D tensor of vocabulary ids.

    The dtype is pinned to ``torch.long``: without it, a sentence with no
    in-vocabulary tokens would become ``torch.tensor([])``, which defaults to
    float32 and can turn a whole padded batch into floats that the embedding
    layer cannot index with.

    Args:
        split: Iterable of examples that each carry a ``'sentence'`` string.

    Returns:
        A list with one ``torch.long`` tensor per example (possibly empty).
    """
    return [
        torch.tensor(encode_text(example['sentence'], vocab), dtype=torch.long)
        for example in split
    ]


train_data = _encode_split(ptb[0])
val_data = _encode_split(ptb[1])
test_data = _encode_split(ptb[2])

# DataLoader preparation
def collate_batch(batch):
    """Pad a batch of id sequences and split it into inputs and targets.

    The sequences are right-padded with ``pad_token_idx`` to a common length
    and stacked batch-first. The targets are the inputs shifted one position
    to the left, so position ``t`` of the inputs is trained to predict
    position ``t + 1`` of the padded sequence.

    Args:
        batch: List of 1-D tensors of token ids.

    Returns:
        A tuple ``(inputs, targets)``: the padded batch without its last
        column and the padded batch without its first column.
    """
    padded = pad_sequence(batch, batch_first=True, padding_value=pad_token_idx)
    inputs = padded[:, :-1]
    targets = padded[:, 1:]
    return inputs, targets

# Only the training split is shuffled; validation and test keep their order.
train_loader = DataLoader(
    train_data,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_batch,
)
val_loader = DataLoader(
    val_data,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_batch,
)
# Batch size 1 for sentence-level perplexity
test_loader = DataLoader(
    test_data,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_batch,
)

# Define the LSTM Language Model with dropout
class LanguageModelLSTM(nn.Module):
    """Single-layer LSTM language model with dropout.

    Token ids are embedded, passed through one LSTM layer and projected to
    vocabulary-sized logits at every time step.

    ``nn.LSTM``'s own ``dropout`` argument only acts *between* stacked
    layers, so with ``num_layers=1`` it is a no-op (PyTorch merely warns).
    Dropout is therefore applied explicitly to the embeddings and to the
    LSTM outputs. ``nn.Dropout`` has no parameters, so the state-dict keys
    are the same as without it.

    Args:
        vocab_size: Number of ids in the vocabulary (size of the output layer).
        embedding_dim: Dimensionality of the token embeddings.
        hidden_dim: Size of the LSTM hidden state.
        pad_idx: Id of the padding token; its embedding is kept at zero.
        dropout_prob: Probability of zeroing an activation during training.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, pad_idx, dropout_prob=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return next-token logits for a batch of id sequences.

        Args:
            x: Integer tensor of token ids, batch-first (the LSTM is built
                with ``batch_first=True``).

        Returns:
            A tensor with the same leading dimensions as *x* and a trailing
            dimension of ``vocab_size`` holding unnormalised scores.
        """
        embedded = self.dropout(self.embedding(x))
        lstm_out, _ = self.lstm(embedded)
        logits = self.fc(self.dropout(lstm_out))
        return logits

# Model Initialization
embedding_dim = 100
hidden_dim = 128
model = LanguageModelLSTM(vocab_size, embedding_dim, hidden_dim, pad_token_idx, dropout_prob=0.5).to(device)

# Loss and Optimizer
# Padding positions are excluded from the loss through ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_idx)
optimizer = AdamW(model.parameters(), lr=0.0005)
# Halve the learning rate once the monitored value (lower is better) has
# stopped improving for more than `patience` epochs.
# NOTE: `verbose` is deliberately not passed. It is deprecated since
# PyTorch 2.2 and may be rejected by newer releases; query
# scheduler.get_last_lr() to inspect the current learning rate instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

def train_model(model, dataloader, optimizer, criterion):
    """Train *model* for one epoch and return the mean per-batch loss.

    The function is self-contained: batches are moved to the device the
    model's parameters live on, and the logits are flattened using their own
    last dimension, so no module-level ``device`` or ``vocab_size`` is needed.

    Args:
        model: Module mapping a batch of inputs to per-position logits.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss taking ``(flattened_logits, flattened_targets)``.

    Returns:
        The average of the per-batch losses over the batches that were
        actually trained on, or ``nan`` if there were none.
    """
    model.train()
    model_device = next(model.parameters()).device
    total_loss = 0.0
    num_batches = 0
    for inputs, targets in dataloader:
        # A batch with zero time steps has no prediction targets; its mean
        # loss would be NaN, so it is skipped instead of poisoning the update.
        if inputs.size(1) == 0:
            continue
        inputs, targets = inputs.to(model_device), targets.to(model_device)
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

def evaluate_model(model, dataloader, criterion):
    """Return the token-level perplexity of *model* over *dataloader*.

    Perplexity is ``exp(total negative log-likelihood / number of scored
    tokens)``. Each batch loss is weighted by the number of non-ignored
    target tokens it covers, so batches with few tokens (for example a short
    final batch) do not count as much as full ones. Averaging the per-batch
    means directly would not give the corpus perplexity.

    Batches are moved to the device of the model's parameters and the logits
    are flattened using their own last dimension, so no module-level
    ``device`` or ``vocab_size`` is required.

    Args:
        model: Module mapping a batch of inputs to per-position logits.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        criterion: Loss taking ``(flattened_logits, flattened_targets)``.
            Its ``ignore_index`` attribute (default -100) identifies targets
            that are not scored. Its ``reduction`` is assumed to be
            ``'mean'`` unless it is ``'sum'``.

    Returns:
        The perplexity as a float, or ``nan`` when no target token was scored.
    """
    model.eval()
    model_device = next(model.parameters()).device
    ignore_index = getattr(criterion, "ignore_index", -100)
    sum_reduced = getattr(criterion, "reduction", "mean") == "sum"
    total_nll = 0.0
    total_tokens = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            num_tokens = int((targets != ignore_index).sum())
            # Nothing to score (zero-length or all-padding batch): the mean
            # loss would be NaN, so the batch is skipped.
            if num_tokens == 0:
                continue
            inputs, targets = inputs.to(model_device), targets.to(model_device)
            logits = model(inputs)
            loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            # Turn a mean loss back into a summed NLL before accumulating.
            total_nll += loss.item() if sum_reduced else loss.item() * num_tokens
            total_tokens += num_tokens
    if total_tokens == 0:
        return float("nan")
    return np.exp(total_nll / total_tokens)

# Calculate sentence-wise perplexities on the test set
def evaluate_sentence_perplexities(model, dataloader, criterion, total_sentences):
    """Compute one perplexity per batch of *dataloader*.

    The loader is expected to yield one sentence per batch, so each batch's
    loss is exponentiated into that sentence's perplexity. Batches are moved
    to the device of the model's parameters and the logits are flattened
    using their own last dimension, so no module-level ``device`` or
    ``vocab_size`` is required.

    Args:
        model: Module mapping a batch of inputs to per-position logits.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        criterion: Loss taking ``(flattened_logits, flattened_targets)``.
        total_sentences: Maximum number of batches to score; iteration stops
            once this many have been processed.

    Returns:
        A list of ``(index, perplexity)`` tuples in loader order. A batch
        whose inputs have zero time steps gets the placeholder ``-1``.
    """
    model.eval()
    model_device = next(model.parameters()).device
    sentence_perplexities = []
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(dataloader):
            if idx >= total_sentences:  # Stop after required sentences
                break
            if inputs.size(1) == 0:
                # Nothing to predict for this sentence, so no perplexity is
                # defined; keep the row with a sentinel value.
                perplexity = -1  # Placeholder for empty sequence
            else:
                inputs, targets = inputs.to(model_device), targets.to(model_device)
                logits = model(inputs)
                loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
                perplexity = np.exp(loss.item())
            sentence_perplexities.append((idx, perplexity))
    return sentence_perplexities

# Training Loop with Early Stopping
epochs = 50
best_val_perplexity = float("inf")
best_state_dict = None  # Weights of the epoch with the lowest validation perplexity
epochs_no_improve = 0
early_stop_patience = 5

for epoch in range(epochs):
    train_loss = train_model(model, train_loader, optimizer, criterion)
    val_perplexity = evaluate_model(model, val_loader, criterion)

    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}")

    # Update learning rate based on validation perplexity
    scheduler.step(val_perplexity)

    # Check for early stopping
    if val_perplexity < best_val_perplexity:
        best_val_perplexity = val_perplexity
        epochs_no_improve = 0
        # state_dict() tensors share storage with the live parameters, so
        # they are cloned to get a snapshot that later updates cannot change.
        best_state_dict = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= early_stop_patience:
            print("Early stopping triggered.")
            break

# Early stopping fires several epochs after the best one; roll the model back
# to the best checkpoint so later evaluation does not use the degraded weights.
if best_state_dict is not None:
    model.load_state_dict(best_state_dict)

# Score the test split sentence by sentence (the test loader uses batch size 1).
total_sentences = 3761  # Expected number of test sentences
test_perplexities = evaluate_sentence_perplexities(model, test_loader, criterion, total_sentences)

# Write the results in the submission format: a header row followed by one
# "ID, ppl" row per scored sentence.
with open("submission.csv", mode="w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["ID", "ppl"])  # Header as per requirement
    # Each entry is already an (ID, perplexity) pair, i.e. one CSV row.
    writer.writerows(test_perplexities)

print("Submission file 'submission.csv' generated.")
