In [None]:
import torch
from torch.utils.data import Dataset
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F

import wandb



In [None]:
import os

class AsciiTokenizer:
    """Character-level tokenizer that maps each character to its code point.

    Token id 0 is reserved for padding and token id 1 for unknown characters,
    so the characters ``chr(0)`` and ``chr(1)`` cannot be round-tripped.

    Args:
        vocab_size: Number of token ids. Characters whose code point is
            ``>= vocab_size`` are encoded as ``unk_token``.
    """

    def __init__(self, vocab_size=128):
        self.vocab_size = vocab_size  # ASCII range
        self.pad_token = 0  # Padding token index
        self.unk_token = 1  # Unknown token index

    def encode(self, text, seq_length=None):
        """Convert ``text`` into a list of integer token ids.

        Args:
            text: String to encode, one token per character.
            seq_length: If given, the result is truncated or right-padded with
                ``pad_token`` to exactly this length. If ``None``, the tokens
                are returned at their natural length.

        Returns:
            A list of ints.
        """
        tokens = [ord(c) if ord(c) < self.vocab_size else self.unk_token for c in text]
        if seq_length is None:
            return tokens
        tokens = tokens[:seq_length]  # Truncate
        tokens += [self.pad_token] * (seq_length - len(tokens))  # Pad (no-op when already long enough)
        return tokens

    def decode(self, tokens):
        """Convert token ids back into a string.

        Padding tokens are dropped. The unknown token and any id outside
        ``[0, vocab_size)`` are rendered as ``'?'`` so that characters lost
        during encoding show up as a visible placeholder instead of a control
        character.

        Args:
            tokens: Iterable of integer token ids.

        Returns:
            The decoded string.
        """
        chars = []
        for t in tokens:
            if t == self.pad_token:
                continue
            if t == self.unk_token or not 0 <= t < self.vocab_size:
                chars.append('?')
            else:
                chars.append(chr(t))
        return ''.join(chars)

class TextDataset(Dataset):
    """Random fixed-length character windows over articles, for next-token prediction.

    Each item is a dict with two ``torch.long`` tensors of length
    ``seq_length``: ``input_ids`` and ``target_ids``, where ``target_ids`` is
    ``input_ids`` shifted left by one position.

    Args:
        dataset_name: Only ``"wikipedia"`` is supported; it loads the ``train``
            split of ``wikimedia/wikipedia`` (``20231101.en``) through
            ``datasets.load_dataset``.
        tokeknizer: Tokenizer instance (the misspelled parameter name is kept
            for backward compatibility). Defaults to ``AsciiTokenizer()``.
        seq_length: Length of the input and target sequences.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """

    def __init__(self, dataset_name = "wikipedia", tokeknizer=None, seq_length=64):
        self.tokenizer = tokeknizer if tokeknizer else AsciiTokenizer()
        self.seq_length = seq_length
        if dataset_name == "wikipedia":
            # Imported lazily so the dependency is only needed when this dataset is requested.
            from datasets import load_dataset
            self.dataset = load_dataset("wikimedia/wikipedia", "20231101.en")['train']
        else:
            raise ValueError(f"Dataset {dataset_name} not supported.")

    def __len__(self):
        """Return the number of articles (one random window is drawn per article)."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Sample one input/target pair from article ``idx``.

        A window of ``seq_length + 1`` characters is taken at a random offset
        (offset 0 when the article is not longer than the window). Articles
        shorter than the window are right-padded with the pad token.

        Raises:
            ValueError: If the tokenizer is not an ``AsciiTokenizer``.
        """
        if not isinstance(self.tokenizer, AsciiTokenizer):
            raise ValueError("Only AsciiTokenizer is supported now.")

        # One extra token is needed so inputs and targets can be shifted by one.
        window = self.seq_length + 1
        article = self.dataset[idx]['text']

        if len(article) > window:
            # torch.randint's upper bound is exclusive; "+ 1" keeps the last valid
            # window (the one ending on the article's final character) reachable.
            start_idx = torch.randint(0, len(article) - window + 1, (1,)).item()
        else:
            start_idx = 0
        snippet = article[start_idx:start_idx + window]

        # Encode to the full window length: encoding to only seq_length tokens
        # would drop the final character and leave padding as the last target.
        code = self.tokenizer.encode(snippet, window)
        tokens = torch.tensor(code, dtype=torch.long)

        return {"input_ids": tokens[:-1], "target_ids": tokens[1:]}  # Shifted for language modeling

In [None]:
# Build the character-level dataset; every sample is a 128-token input/target pair.
dataset = TextDataset(
    seq_length=128,
)

In [None]:
# Sanity check: decode the target side of the first sample back into text.
first_sample_targets = dataset[0]["target_ids"].tolist()
dataset.tokenizer.decode(first_sample_targets)

In [None]:
class ResidualRNN(nn.Module):
    """
    Multi-layer ``nn.RNN`` whose output is summed with a linear projection
    of the input before being mapped to the output size.

    Args:
        input_size: Number of features per input step
        hidden_size: Width of the RNN hidden state
        output_size: Number of features per output step
        num_layers: How many RNN layers are stacked (default: 1)

    Raises:
        ValueError: If any size or ``num_layers`` is not positive
    """
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()

        # Reject non-positive dimensions before building any layer
        if any(size <= 0 for size in (input_size, hidden_size, output_size)):
            raise ValueError("All sizes must be positive integers")
        if num_layers <= 0:
            raise ValueError("num_layers must be positive")

        self.input_size, self.hidden_size = input_size, hidden_size
        self.output_size, self.num_layers = output_size, num_layers

        # Recurrent core (batch-first layout)
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)

        # Maps the raw input to hidden_size so it can be added to the RNN output
        self.input_projection = nn.Linear(input_size, hidden_size)

        # Output head
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, last_hidden=None):
        """
        Run the RNN, add the projected input, and apply the output head.

        Args:
            x: Tensor of shape (batch_size, seq_len, input_size)
            last_hidden: Optional initial hidden state of shape
                        (num_layers, batch_size, hidden_size); zeros are used
                        when it is None

        Returns:
            out: Tensor of shape (batch_size, seq_len, output_size)
            hidden: Final hidden state of shape (num_layers, batch_size, hidden_size)

        Raises:
            ValueError: If ``x`` is not 3D, its feature size differs from
                ``input_size``, or ``last_hidden`` has the wrong shape
        """
        if x.dim() != 3:
            raise ValueError(f"Expected 3D input (batch, seq, features), got {x.dim()}D")

        batch_size, _, feat_size = x.shape
        if feat_size != self.input_size:
            raise ValueError(
                f"Input feature size {feat_size} doesn't match expected {self.input_size}"
            )

        expected_hidden_shape = (self.num_layers, batch_size, self.hidden_size)
        if last_hidden is None:
            # Start from an all-zero state matching the input's device and dtype
            last_hidden = torch.zeros(expected_hidden_shape, device=x.device, dtype=x.dtype)
        elif last_hidden.shape != expected_hidden_shape:
            raise ValueError(
                f"Hidden state shape {last_hidden.shape} doesn't match expected "
                f"({self.num_layers}, {batch_size}, {self.hidden_size})"
            )

        rnn_out, hidden = self.rnn(x, last_hidden)

        # Residual path: RNN output plus projected input, then the output head
        out = self.fc(rnn_out + self.input_projection(x))

        return out, hidden


import torch
import torch.nn as nn

class RNNetworkExplicit(nn.Module):
    """Stack of single-layer RNN blocks, each followed by an MLP.

    Every block applies an ``nn.RNN`` with a residual add and LayerNorm, then
    a two-layer GELU MLP (4x expansion) with a second residual add and
    LayerNorm. A linear layer projects the input to ``hidden_size`` before the
    stack and another projects the result to ``output_size`` after it.

    Args:
        input_size: Number of features per input step.
        hidden_size: Width used throughout the stack.
        output_size: Number of features per output step.
        num_rnn_layers: Number of RNN+MLP blocks.
    """

    def __init__(self, input_size, hidden_size, output_size, num_rnn_layers):
        super().__init__()
        mlp_width = 4 * hidden_size

        self.input_proj = nn.Linear(input_size, hidden_size)
        self.rnns = nn.ModuleList(
            [nn.RNN(hidden_size, hidden_size, batch_first=True) for _ in range(num_rnn_layers)]
        )
        self.lns = nn.ModuleList(
            [nn.LayerNorm(hidden_size) for _ in range(num_rnn_layers)]
        )
        self.mlps = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_size, mlp_width),
                    nn.GELU(),
                    nn.Linear(mlp_width, hidden_size),
                )
                for _ in range(num_rnn_layers)
            ]
        )
        self.lns2 = nn.ModuleList(
            [nn.LayerNorm(hidden_size) for _ in range(num_rnn_layers)]
        )
        self.output_proj = nn.Linear(hidden_size, output_size)
        self.num_rnn_layers = num_rnn_layers  # kept so forward() knows the depth

    def forward(self, x, hidden_state=None, return_hidden=False):
        """Apply the stack to ``x``.

        Args:
            x: Input whose last dimension is ``input_size``.
            hidden_state: Optional per-block initial hidden states, indexed by
                block number; ``None`` lets each RNN use its default.
            return_hidden: When true, also return the per-block final hidden
                states.

        Returns:
            The projected output, or ``(output, hidden_states)`` when
            ``return_hidden`` is true, where ``hidden_states`` is a list with
            one entry per block.
        """
        out = self.input_proj(x)
        hidden_states = []
        blocks = zip(self.rnns, self.lns, self.mlps, self.lns2)
        for layer_idx, (rnn, norm_after_rnn, mlp, norm_after_mlp) in enumerate(blocks):
            h_in = None if hidden_state is None else hidden_state[layer_idx]

            # Recurrent sub-block: residual add, then normalise
            rnn_out, h_out = rnn(out, h_in)
            out = norm_after_rnn(out + rnn_out)

            # Feed-forward sub-block: residual add, then normalise
            out = norm_after_mlp(out + mlp(out))

            hidden_states.append(h_out)

        out = self.output_proj(out)
        return (out, hidden_states) if return_hidden else out


In [None]:

# Train the residual RNN language model on one-hot encoded characters.

# --- Hyperparameters (every value used below is defined here and logged to wandb) ---
vocab_size = dataset.tokenizer.vocab_size  # width of the one-hot inputs and of the output logits
input_size = vocab_size
hidden_size = 768
batch_size = 1024
seq_len = dataset.seq_length  # log the length the dataset was actually built with
learning_rate = 0.0005
num_layers = 16
epochs = 10
max_grad_norm = 1.0  # Gradient clipping value
warmup_steps = 1000  # linear learning-rate warm-up length, in optimizer steps
checkpoint_every = 1000  # save a checkpoint every this many batches
checkpoint_dir = "rnn"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.amp only applies on CUDA; disable it explicitly elsewhere so the
# run config reflects what really happens.
use_amp = device.type == "cuda"

# Start a new wandb run to track this script.
run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    # Set the wandb project where this run will be logged.
    project="smaLLM",
    # Track hyperparameters and run metadata.
    config={
        "learning_rate": learning_rate,
        "architecture": "RNN with Residuals",
        "layer_size": hidden_size,
        "seq_length": seq_len,
        "batch_size": batch_size,
        "num_layers": num_layers,
        "dataset": "Wikipedia english",
        "epochs": epochs,
        "grad_clipping": max_grad_norm,
        "mixed_precision": use_amp,
        "optimizer": "AdamW",
        "scheduler": "LambdaLR with warmup",
        "warmup_steps": warmup_steps,
    },
)

print("Testing Sequential version:")
model1 = RNNetworkExplicit(input_size, hidden_size, vocab_size, num_rnn_layers=num_layers)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)
criterion = nn.CrossEntropyLoss()

# AdamW with a linear warm-up: the LR multiplier ramps from 1/warmup_steps to 1.
optimizer = torch.optim.AdamW(model1.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min((step + 1) / warmup_steps, 1)
)
model1.to(device)

scaler = GradScaler(enabled=use_amp)

# torch.save does not create missing directories; make sure the target exists
# before the first checkpoint is due.
os.makedirs(checkpoint_dir, exist_ok=True)

model1.train()
try:
    for epoch in range(epochs):
        for k, batch in enumerate(dataloader):
            inputs = batch['input_ids'].to(device)
            targets = batch['target_ids'].to(device)
            inputs_onehot = F.one_hot(inputs, num_classes=vocab_size).float()

            optimizer.zero_grad()

            # Mixed precision context
            with autocast(enabled=use_amp):
                outputs = model1(inputs_onehot)
                loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

            # Scale gradients, backward pass
            scaler.scale(loss).backward()

            # Gradient clipping
            scaler.unscale_(optimizer)  # Unscale before clipping
            torch.nn.utils.clip_grad_norm_(model1.parameters(), max_grad_norm)

            # Optimizer step and scaler update
            scaler.step(optimizer)
            scaler.update()

            # Scheduler step
            scheduler.step()

            run.log({"train/loss": loss.item(),
                     "epoch": epoch,
                     "batch": k,
                     "lr": scheduler.get_last_lr()[0]
                     })

            if k % checkpoint_every == 0 and k > 0:
                checkpoint_path = os.path.join(
                    checkpoint_dir, f"rnn_language_model_epoch{epoch}_batch{k}.pth"
                )
                torch.save(model1.state_dict(), checkpoint_path)
finally:
    # Close the wandb run even if training is interrupted or fails.
    run.finish()


In [None]:
def generate_sample(model, tokenizer, start_text, max_length):
    """Sample a continuation of ``start_text`` one character at a time.

    The prompt is fed to the model in a single pass; afterwards each sampled
    token is fed back together with the returned hidden state. Sampling stops
    after ``max_length`` tokens or as soon as the pad token is drawn.

    Args:
        model: Module called as ``model(one_hot, hidden, return_hidden=True)``
            and returning ``(logits, hidden)``; logits are indexed as
            ``[batch, step, vocab]``.
        tokenizer: Object providing ``vocab_size``, ``unk_token``,
            ``pad_token`` and ``decode``.
        start_text: Non-empty prompt string.
        max_length: Maximum number of tokens to sample.

    Returns:
        ``start_text`` followed by the decoded sampled tokens.

    Raises:
        ValueError: If ``start_text`` is empty (there would be no logits to
            sample the first token from).
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")

    # Build inputs on the model's own device so this also works when the model is on GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Encode start_text without forcing full length
    tokens = [ord(c) if ord(c) < tokenizer.vocab_size else tokenizer.unk_token for c in start_text]
    input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)  # (1, len(tokens))

    was_training = model.training
    model.eval()
    pieces = [start_text]
    hidden = None
    try:
        with torch.no_grad():
            for _ in range(max_length):
                input_onehot = F.one_hot(input_ids, num_classes=tokenizer.vocab_size).float()
                output, hidden = model(input_onehot, hidden, return_hidden=True)
                next_token_logits = output[0, -1, :]  # (vocab_size)
                # sample from the distribution
                probs = F.softmax(next_token_logits, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1).item()
                if next_token_id == tokenizer.pad_token:
                    break
                pieces.append(tokenizer.decode([next_token_id]))
                # Only the new token is fed next; the context lives in `hidden`.
                input_ids = torch.tensor([[next_token_id]], dtype=torch.long, device=device)  # (1, 1)
    finally:
        # Restore the caller's train/eval mode instead of leaving the model in eval.
        model.train(was_training)
    return ''.join(pieces)


In [None]:
# Move the trained model to the CPU and sample up to 100 characters after a fixed prompt.
model1.to("cpu")
out = generate_sample(
    model1,
    dataset.tokenizer,
    "The meaning of life is",
    max_length=100,
)
print(out)

In [None]:
# Persist the final weights (state_dict only, not the full module object).
final_state = model1.state_dict()
torch.save(final_state, "rnn_final.pth")