In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import urllib.request

# Load small text corpus (tiny Shakespeare)
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
# Use a context manager so the HTTP connection is closed even if the read or
# decode fails, and a timeout so a stalled download cannot hang the kernel.
with urllib.request.urlopen(url, timeout=30) as response:
    text = response.read().decode('utf-8')

# Preprocess the text
# Vocabulary = every distinct character, sorted so the char <-> index mapping
# is deterministic from run to run.
chars = sorted(set(text))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = dict(enumerate(chars))
vocab_size = len(chars)

# Encode the whole corpus as a list of integer character ids.
data = [char_to_idx[ch] for ch in text]

# Define dataset
class CharDataset(Dataset):
    """Sliding-window dataset for next-character prediction.

    Item ``i`` is the pair ``(data[i:i+seq_len], data[i+1:i+seq_len+1])``:
    an input window and the same window shifted one step ahead as the target.

    Args:
        data: Sequence of integer character ids.
        seq_len: Number of time steps in each window.
    """

    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        # Clamp at zero: a corpus shorter than one window has no samples, and
        # len() raises ValueError if __len__ returns a negative number.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair for window ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self)
        if idx < 0:
            idx = idx + n_items
        # Without this check an out-of-range index silently yields short or
        # mismatched tensors, and plain ``for item in dataset`` iteration
        # (which relies on IndexError to stop) would never terminate.
        if not 0 <= idx < n_items:
            raise IndexError(f'index out of range for dataset of size {n_items}')
        stop = idx + self.seq_len
        return (
            torch.tensor(self.data[idx:stop]),
            torch.tensor(self.data[idx + 1:stop + 1]),
        )

# Windows of 100 characters, served in shuffled mini-batches of 64.
seq_len = 100
dataset = CharDataset(data=data, seq_len=seq_len)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=64)

# Define simple character-level RNN
class CharRNN(nn.Module):
    """Character-level language model: embedding -> vanilla RNN -> linear.

    Args:
        input_size: Number of distinct input tokens (rows of the embedding).
        hidden_size: Width of both the embedding and the RNN hidden state.
        output_size: Number of logits produced per time step.
        n_layers: Number of stacked RNN layers.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run the network over a batch of token-id sequences.

        Args:
            x: Integer token ids, batch dimension first (``batch_first=True``).
            hidden: Initial state shaped ``(n_layers, batch, hidden_size)``.
                ``None`` lets ``nn.RNN`` start from zeros.

        Returns:
            ``(logits, hidden)``: per-step logits over the output classes and
            the final hidden state.
        """
        embedded = self.embedding(x)
        out, hidden = self.rnn(embedded, hidden)
        return self.fc(out), hidden

    def init_hidden(self, batch_size):
        """Return a zero state shaped ``(n_layers, batch_size, hidden_size)``."""
        # Allocate on the parameters' device/dtype so the state stays usable
        # after ``model.to(device)`` or ``model.double()``.
        ref = self.fc.weight
        return torch.zeros(
            self.n_layers, batch_size, self.hidden_size,
            device=ref.device, dtype=ref.dtype,
        )

hidden_size = 128
n_layers = 1
model = CharRNN(vocab_size, hidden_size, vocab_size, n_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
n_epochs = 10  # Small number for laptop-scale
model.train()
for epoch in range(n_epochs):
    epoch_loss = 0.0
    n_batches = 0
    for x, y in dataloader:
        # Start each batch from a zero state sized to *this* batch.
        # A fixed batch size of 64 breaks on the final batch of an epoch, which
        # DataLoader keeps (smaller) whenever the dataset size is not a
        # multiple of 64. Batches are also shuffled, unrelated windows, so
        # carrying state from one batch into the next has no meaning.
        hidden = model.init_hidden(x.size(0))
        outputs, hidden = model(x, hidden)
        # Flatten batch and time so every position is one classification case.
        loss = criterion(outputs.view(-1, vocab_size), y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Drop the autograd graph so the notebook-level `hidden` does not keep
        # the last batch's activations alive.
        hidden = hidden.detach()
        epoch_loss += loss.item()
        n_batches += 1
    # Report the epoch mean rather than the (noisy) loss of the last batch.
    print(f'Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss / max(n_batches, 1):.4f}')

In [None]:
def prune_and_sample(logits, threshold=0.1, hidden_state=None, net=None):
    """Sample a token id using a low-rank ("pruned") copy of the output layer.

    The output layer's weight matrix is decomposed with an SVD, singular
    values at or below ``threshold * largest`` are zeroed, and the logits are
    recomputed from the last hidden state with the reconstructed matrix.

    Args:
        logits: Accepted for interface compatibility; the sampled distribution
            is recomputed from ``hidden_state`` and does not read this value.
        threshold: Fraction of the largest singular value below which
            singular values are discarded.
        hidden_state: RNN state shaped ``(n_layers, batch, hidden_size)``; the
            last layer / last batch row is used. When omitted, falls back to
            the notebook-level ``hidden`` variable (legacy behaviour).
        net: Model whose ``fc`` layer is pruned. Defaults to the
            notebook-level ``model``.

    Returns:
        The sampled token id as a Python int.
    """
    if net is None:
        net = model
    if hidden_state is None:
        # Legacy fallback kept so existing one-argument calls still run; this
        # is whatever state the notebook last bound to `hidden`, not the state
        # of an ongoing generation.
        hidden_state = hidden

    weight = net.fc.weight.detach()  # (output_size, hidden_size)
    # torch.linalg.svd replaces the deprecated torch.svd and returns V^H
    # directly, so the reconstruction is U @ diag(S) @ Vh.
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    keep = S > (threshold * S.max())
    # Scaling U's columns by the masked singular values equals U @ diag(S).
    weight_pruned = (U * (S * keep)) @ Vh

    last_state = hidden_state[-1, -1]
    logits_pruned = weight_pruned @ last_state + net.fc.bias
    probs = torch.softmax(logits_pruned, dim=0)
    return torch.multinomial(probs, num_samples=1).item()

def generate_text(model, seed, length=200, use_qiep=False):
    """Generate ``length`` characters that continue ``seed``.

    Args:
        model: Trained network exposing ``init_hidden`` and returning
            ``(logits, hidden)`` when called.
        seed: Non-empty priming string; every character must be a key of
            ``char_to_idx``.
        length: Number of characters to sample after the seed.
        use_qiep: Sample through ``prune_and_sample`` instead of the plain
            softmax over the model's logits.

    Returns:
        The seed followed by the sampled characters.

    Raises:
        ValueError: If ``seed`` is empty while ``length`` is positive (there
            is no character to feed into the first step).
        KeyError: If ``seed`` contains a character missing from the vocabulary.
    """
    if not seed and length > 0:
        raise ValueError('seed must contain at least one character')

    model.eval()
    device = next(model.parameters()).device
    hidden = model.init_hidden(1).to(device)
    generated = [char_to_idx[ch] for ch in seed]

    # Inference only: skip autograd bookkeeping so the graph does not grow
    # with every generated character.
    with torch.no_grad():
        # Warm the state up on all seed characters except the last, which is
        # consumed by the first sampling step below.
        for ch in generated[:-1]:
            _, hidden = model(torch.tensor([[ch]], device=device), hidden)
        for _ in range(length):
            step_input = torch.tensor([[generated[-1]]], device=device)
            output, hidden = model(step_input, hidden)
            if use_qiep:
                # Hand over the live state and model explicitly:
                # prune_and_sample cannot see this function's local `hidden`.
                next_char = prune_and_sample(output[0], hidden_state=hidden, net=model)
            else:
                probs = torch.softmax(output[0, -1], dim=0)
                next_char = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_char)
    return ''.join(idx_to_char[i] for i in generated)

# Generate samples
# One seed, two decoders: plain softmax sampling first, then the pruned variant.
seed = 'To be or not to be'
standard_text = generate_text(model=model, seed=seed, use_qiep=False)
qiep_text = generate_text(model=model, seed=seed, use_qiep=True)
for label, sample in (
    ('Standard Generation:', standard_text),
    ('QIEP Generation:', qiep_text),
):
    print(label, sample)