In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import urllib.request

# Load small text corpus (tiny Shakespeare)
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
response = urllib.request.urlopen(url)
text = response.read().decode('utf-8')

# Preprocess the text
chars = sorted(list(set(text)))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
vocab_size = len(chars)

data = [char_to_idx[ch] for ch in text]

# Define dataset
class CharDataset(Dataset):
    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        return (
            torch.tensor(self.data[idx:idx+self.seq_len]),
            torch.tensor(self.data[idx+1:idx+self.seq_len+1])
        )

BATCH_SIZE = 64
seq_len = 100
dataset = CharDataset(data, seq_len)
# THE FIX: Set drop_last=True to avoid smaller, problematic batches
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Define simple character-level RNN
class CharRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super(CharRNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        x = self.embedding(x)
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(self.n_layers, batch_size, self.hidden_size)

hidden_size = 128
n_layers = 1
model = CharRNN(vocab_size, hidden_size, vocab_size, n_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
print('Starting RNN training...')
n_epochs = 5 # Reduced for quicker validation
for epoch in range(n_epochs):
    # THE FIX: Hidden state is now initialized correctly inside the loop
    for x, y in dataloader:
        hidden = model.init_hidden(BATCH_SIZE) # Initialize with the constant batch size
        
        outputs, hidden = model(x, hidden)
        loss = criterion(outputs.view(-1, vocab_size), y.view(-1))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}/{n_epochs}, Loss: {loss.item():.4f}')
print('Training complete.')

Starting RNN training...
Epoch 1/5, Loss: 1.4594
Epoch 2/5, Loss: 1.4896
Epoch 3/5, Loss: 1.4587
Epoch 4/5, Loss: 1.4407
Epoch 5/5, Loss: 1.4409
Training complete.


In [3]:
def prune_and_sample(logits, threshold=0.1):
    # Assuming logits is the output of the linear layer (batch_size, seq_len, vocab_size)
    if logits.dim() > 1:
        logits = logits[-1]
    # Treat logits as a matrix for SVD (though it's 1D, reshape if needed; here assume transition tensor is fc.weight)
    transition = model.fc.weight.data  # (vocab_size, hidden_size)
    U, S, Vh = torch.svd(transition)
    # Prune low singular values
    mask = S > (threshold * S.max())
    S_pruned = S * mask.float()
    transition_pruned = U @ torch.diag(S_pruned) @ Vh.T
    # Recompute logits with pruned transition (simplified; in practice, apply to hidden state)
    hidden_last = hidden[-1, -1]  # Last hidden state
    logits_pruned = hidden_last @ transition_pruned.T + model.fc.bias
    probs = torch.softmax(logits_pruned, dim=0)
    return torch.multinomial(probs, num_samples=1).item()

def generate_text(model, seed, length=200, use_qiep=False):
    model.eval()
    hidden = model.init_hidden(1)
    generated = [char_to_idx[ch] for ch in seed]
    for ch in generated[:-1]:
        output, hidden = model(torch.tensor([[ch]]), hidden)
    for _ in range(length):
        output, hidden = model(torch.tensor([[generated[-1]]]), hidden)
        if use_qiep:
            next_char = prune_and_sample(output[0])
        else:
            probs = torch.softmax(output[0, -1], dim=0)
            next_char = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_char)
    return ''.join([idx_to_char[i] for i in generated])

# Generate samples
seed = 'To be or not to be'
standard_text = generate_text(model, seed, use_qiep=False)
qiep_text = generate_text(model, seed, use_qiep=True)
print('Standard Generation:', standard_text)
print('QIEP Generation:', qiep_text)

Standard Generation: To be or not to be,-hating must as stilded king, break of the callide I honourse
have winds: and my our point.

BENVOLIO:
Up that I say, alonio thee,
And on she speve not him.

Third Anmator not, cure,
Send,
incelved, 
QIEP Generation: To be or not to beaieooieaoiiioaeiaiiiooiooeaiiiiiiaueeoeeaieoeeaooieoiaiaeaeiaeeioioooioiieioeoaiioaieieiaoiiiaieoiiiaeaoeaiieiieieaaeiiiieaeaiioiieeioeaiieioueoieeeiaeieeoiiiiaeoiiooiaeiaeaeeoeaeooooiiiiiuieioaiiauei


In [4]:
def shannon_entropy(probs):
    return -torch.sum(probs * torch.log2(probs + 1e-10))

def compute_average_entropy(model, seed, num_steps=100, use_qiep=False):
    model.eval()
    hidden = model.init_hidden(1)
    generated = [char_to_idx[ch] for ch in seed]
    for ch in generated[:-1]:
        output, hidden = model(torch.tensor([[ch]]), hidden)
    entropies = []
    for _ in range(num_steps):
        output, hidden = model(torch.tensor([[generated[-1]]]), hidden)
        if use_qiep:
            # Recompute pruned logits
            transition = model.fc.weight.data
            U, S, Vh = torch.svd(transition)
            mask = S > (0.1 * S.max())
            S_pruned = S * mask.float()
            transition_pruned = U @ torch.diag(S_pruned) @ Vh.T
            hidden_last = hidden[-1, -1]
            logits_pruned = hidden_last @ transition_pruned.T + model.fc.bias
            probs = torch.softmax(logits_pruned, dim=0)
        else:
            probs = torch.softmax(output[0, -1], dim=0)
        entropy = shannon_entropy(probs)
        entropies.append(entropy.item())
        # Sample next char to continue (but we don't use it for entropy calc)
        next_char = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_char)
    return np.mean(entropies)

seed = 'To be or not to be'
standard_entropy = compute_average_entropy(model, seed, use_qiep=False)
qiep_entropy = compute_average_entropy(model, seed, use_qiep=True)
print(f'Standard Average Entropy: {standard_entropy:.4f}')
print(f'QIEP Average Entropy: {qiep_entropy:.4f}')
assert qiep_entropy < standard_entropy, 'QIEP should have lower entropy'

Standard Average Entropy: 2.1177
QIEP Average Entropy: 2.0414
