# In [None]:
import torch
from torch import nn
from torch import optim
import torchtext
from torchtext import data
from torchtext import datasets
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.progress import track
from rich.text import Text





# Fields describe how raw examples are turned into tensors: text is a
# lower-cased token sequence laid out batch-first, labels are categorical.
LABEL = data.LabelField()
TEXT = data.Field(lower=True, batch_first=True, sequential=True)

# Train / validation / test splits of SST, parsed with the fields above.
train_data, val_data, test_data = datasets.SST.splits(TEXT, LABEL)

# Vocabularies are derived from the training split only.
LABEL.build_vocab(train_data)
TEXT.build_vocab(train_data)

# Hyperparameters
embedding_dim = 128
hidden_dim = 128
vocab_size = len(TEXT.vocab)
label_size = len(LABEL.vocab)  # Number of classes
padding_idx = TEXT.vocab.stoi[TEXT.pad_token]  # Index of the pad token in the text vocabulary


# Batch iterators: examples are sorted by token count inside every batch.
train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (train_data, val_data, test_data),
    sort_key=lambda example: len(example.text),
    sort_within_batch=True,
    batch_size=32,
)

class RNN(nn.Module):
    """Embedding -> multi-layer ``nn.RNN`` -> linear classifier over token ids.

    Args:
        i_size: Number of rows in the embedding table (vocabulary size).
        output_size: Number of output logits (classes).
        hidden_dim: Size of the recurrent hidden state.
        number_layers: Number of stacked recurrent layers.
        padding_idx: Index of the padding token, or ``None`` when the input
            is never padded. It is forwarded to ``nn.Embedding`` and used in
            ``forward`` to locate the last real token of every sequence.
        embedding_dim: Size of each embedding vector. Defaults to 128, the
            value this script uses, so the class no longer depends on a
            module-level variable.
    """

    def __init__(self, i_size, output_size, hidden_dim, number_layers, padding_idx,
                 embedding_dim=128):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Must read the ``number_layers`` parameter; there is no ``n_layers``
        # name in scope here.
        self.n_layers = number_layers
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(i_size, embedding_dim, padding_idx=padding_idx)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, number_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_size)

    def forward(self, x, hidden=None):
        """Return classification logits for a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids, shape ``(batch, seq_len)``.
            hidden: Optional initial hidden state of shape
                ``(number_layers, batch, hidden_dim)``; zeros when omitted.

        Returns:
            Tensor of shape ``(batch, output_size)`` holding raw logits.
        """
        batch_size = x.size(0)
        if hidden is None:
            # Initialize hidden state with zeros if not provided
            hidden = torch.zeros(
                self.n_layers, batch_size, self.hidden_dim, device=x.device
            )

        embedded = self.embedding(x)
        r_out, hidden = self.rnn(embedded, hidden)

        if self.padding_idx is None:
            last_step = r_out[:, -1, :]
        else:
            # Index -1 would read the state after the pad tokens for every
            # sequence shorter than the batch maximum. Pick the last non-pad
            # step instead. Assumes padding is appended on the right.
            # clamp(min=1) keeps an all-padding row at a valid index.
            lengths = (x != self.padding_idx).sum(dim=1).clamp(min=1)
            rows = torch.arange(batch_size, device=x.device)
            last_step = r_out[rows, lengths - 1]

        return self.fc(last_step)  # Return logits

# Initialize model, criterion, and optimizer
# Depth of the recurrent stack. This name was never defined, so the RNN(...)
# call below could not run. 1 matches nn.RNN's own default depth.
# NOTE(review): confirm the intended number of layers.
number_layers = 1
model = RNN(vocab_size, label_size, hidden_dim, number_layers, padding_idx)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define training function
def train_model(model, iterator, criterion, optimizer):
    """Train ``model`` for one pass over ``iterator``.

    Args:
        model: Module called as ``model(batch.text)`` to produce class logits.
        iterator: Iterable of batches exposing ``.text`` and ``.label``.
        criterion: Loss called as ``criterion(predictions, labels)``.
            Assumed to return the mean over the batch (the default reduction
            of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer stepped once per batch.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen, or
        ``(0.0, 0.0)`` when ``iterator`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch in iterator:
        text, labels = batch.text, batch.label
        optimizer.zero_grad()
        predictions = model(text)
        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()

        # Weight every batch by its size: averaging per-batch means would
        # over-weight a shorter final batch.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (predictions.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0
    return total_loss / total_samples, total_correct / total_samples

# Define evaluation function
def evaluate_model(model, iterator, criterion):
    """Score ``model`` on ``iterator`` without updating its weights.

    Args:
        model: Module called as ``model(batch.text)`` to produce class logits.
        iterator: Iterable of batches exposing ``.text`` and ``.label``.
        criterion: Loss called as ``criterion(predictions, labels)``.
            Assumed to return the mean over the batch (the default reduction
            of ``nn.CrossEntropyLoss``).

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen, or
        ``(0.0, 0.0)`` when ``iterator`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch in iterator:
            text, labels = batch.text, batch.label
            predictions = model(text)
            loss = criterion(predictions, labels)

            # Weight every batch by its size: averaging per-batch means
            # would over-weight a shorter final batch.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_correct += (predictions.argmax(dim=1) == labels).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0
    return total_loss / total_samples, total_correct / total_samples





# In [None]:
def train_and_evaluate(model, train_iter, val_iter, criterion, optimizer, n_epochs=5,
                       console=None):
    """Train for ``n_epochs`` epochs, printing train/validation metrics after each.

    Args:
        model: Model passed through to ``train_model`` / ``evaluate_model``.
        train_iter: Batch iterator used for the training pass of every epoch.
        val_iter: Batch iterator used for the validation pass of every epoch.
        criterion: Loss function passed to both passes.
        optimizer: Optimizer passed to the training pass.
        n_epochs: Number of epochs to run.
        console: Optional ``rich`` console to print to. A new ``Console()``
            is created when omitted.
    """
    if console is None:
        # No console object exists at module level, so build one on demand
        # instead of failing with a NameError on the first print.
        console = Console()

    for epoch in range(n_epochs):
        train_loss, train_acc = train_model(model, train_iter, criterion, optimizer)
        val_loss, val_acc = evaluate_model(model, val_iter, criterion)

        # Header panel showing the 1-based epoch counter.
        epoch_panel = Panel(
            f"Epoch {epoch + 1:02}/{n_epochs:02}",
            title="Training Progress",
            title_align="left",
            border_style="bold cyan",
            padding=(1, 2)
        )

        # Create a summary panel
        summary_panel = Panel(
            f"Train Loss: {train_loss:.4f}\nTrain Accuracy: {train_acc:.4f}\n"
            f"Validation Loss: {val_loss:.4f}\nValidation Accuracy: {val_acc:.4f}",
            title="Summary",
            border_style="bold red",
            padding=(1, 2)
        )

        # Display the results
        console.print(epoch_panel)
        console.print(summary_panel)
        console.print("\n" + "--" * 57)

# Run the full train/validate loop for ten epochs on the objects built above.
train_and_evaluate(
    model=model,
    train_iter=train_iter,
    val_iter=val_iter,
    criterion=criterion,
    optimizer=optimizer,
    n_epochs=10,
)