In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!cp -r "/content/drive/My Drive/recurrent-transformer-models/model_artifacts/StatefulTransformer_best_model.pth" "StatefulTransformer_best_model.pth"

In [3]:
!rm -rf Recurrent-Neuron-Transformer

In [4]:
!git clone -b add-dev-container-debug-and-run-name https://github.com/ChrisHayduk/Recurrent-Neuron-Transformer.git

Cloning into 'Recurrent-Neuron-Transformer'...
remote: Enumerating objects: 698, done.[K
remote: Counting objects: 100% (250/250), done.[K
remote: Compressing objects: 100% (140/140), done.[K
remote: Total 698 (delta 186), reused 166 (delta 109), pack-reused 448[K
Receiving objects: 100% (698/698), 27.82 MiB | 10.89 MiB/s, done.
Resolving deltas: 100% (443/443), done.


In [5]:
!cd Recurrent-Neuron-Transformer/ && pip install -r requirements.txt



In [6]:
import torch
from torch import nn

class Neurons(nn.Module):
    def __init__(self, n_neurons, device):
        super(Neurons, self).__init__()
        self.device = device

        # Initialize matrix neuron parameters and number of neurons to create
        self.n_neurons = n_neurons
        self.params = nn.Parameter(torch.rand(n_neurons, 3, 3) * 2 - 1)
        self.gelu = nn.GELU()

    def forward(self, inputs, hidden_state=None):
        if hidden_state is not None:
            hidden_state = hidden_state.detach()
        else:
            hidden_state = torch.zeros(1, self.n_neurons, 1, device=self.device)

        batch_size = inputs.shape[0]
        seq_len = inputs.shape[1]

        hidden_batch = hidden_state.expand(batch_size, seq_len, self.n_neurons, 1)
        inputs = inputs.view(batch_size, seq_len, -1, 1)
        ones = torch.ones_like(inputs)


        # Concatenate along the last dimension
        stacked = torch.cat((inputs, hidden_batch, ones), dim=3)

        # Reshape stacked for matrix multiplication: [batch_size, seq_len, n_neurons, 3]
        stacked = stacked.view(batch_size, seq_len, self.n_neurons, 3)

        # Perform matrix multiplication
        dot = self.gelu(torch.matmul(self.params, stacked.unsqueeze(4)).squeeze(4))

        # Update hidden state without in-place operation
        new_hidden = dot[:, :, :, 1].unsqueeze(3).detach()

        return dot[:, :, :, 0], new_hidden

class RecurrentNeuronLayer(nn.Module):
    def __init__(self, input_size, output_size, device):
        super(RecurrentNeuronLayer, self).__init__()
        self.neurons = Neurons(output_size, device)
        self.weights = nn.Linear(input_size, output_size)
        self.device = device

    def forward(self, x, hidden_state=None):
        batch_size = x.shape[0]
        seq_len = x.shape[1]

        x = self.weights(x)
        x, updated_hidden_state = self.neurons(x, hidden_state)

        # Reshape the output to ensure it has the shape [batch_size, n_classes]
        final_output = x.view(batch_size, seq_len, -1)

        return final_output, updated_hidden_state

In [7]:
import numpy as np
import math
import torch
from torch import nn
import random
import torch.functional as F
from dataclasses import dataclass

@dataclass
class RecurrentModelConfig:
    max_length: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    num_heads: int = 12
    hidden_dim: int = 768
    dropout: float = 0.0
    device: str = "cuda"
    recurrent_layers: str = "all"

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        if "proj" == config.recurrent_layers or "all" == config.recurrent_layers:
            self.c_fc = RecurrentNeuronLayer(config.hidden_dim, 4 * config.hidden_dim, config.device)
        else:
            self.c_fc = nn.Linear(config.hidden_dim, 4 * config.hidden_dim)

        self.gelu = nn.GELU()

        if "proj" == config.recurrent_layers or "all" == config.recurrent_layers:
            self.c_proj = RecurrentNeuronLayer(4 * config.hidden_dim, config.hidden_dim, config.device)
        else:
            self.c_proj = nn.Linear(4 * config.hidden_dim, config.hidden_dim)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, hidden_layers = None, layer_num=0):
        if isinstance(self.c_fc, RecurrentNeuronLayer):
            x, hidden_layers[f"c_fc_{layer_num}"] = self.c_fc(x, hidden_layers.get(f"c_fc_{layer_num}"))
        else:
            x =  self.c_fc(x)

        x = self.gelu(x)

        if isinstance(self.c_proj, RecurrentNeuronLayer):
            x, hidden_layers[f"c_proj_{layer_num}"] = self.c_proj(x, hidden_layers.get(f"c_proj_{layer_num}"))
        else:
            x =  self.c_proj(x)

        x = self.dropout(x)
        return x, hidden_layers

class RecurrentCausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.hidden_dim % config.num_heads == 0
        # key, query, value projections for all heads, but in a batch
        if "qkv" == config.recurrent_layers or "all" == config.recurrent_layers:
            self.c_attn = RecurrentNeuronLayer(config.hidden_dim, 3 * config.hidden_dim, config.device)
        else:
            self.c_attn = nn.Linear(config.hidden_dim, 3 * config.hidden_dim)
        # output projection

        if "qkv" == config.recurrent_layers or "all" == config.recurrent_layers:
            self.c_proj = RecurrentNeuronLayer(config.hidden_dim, config.hidden_dim, config.device)
        else:
            self.c_proj = nn.Linear(config.hidden_dim, config.hidden_dim)

        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.num_heads
        self.n_embd = config.hidden_dim
        self.dropout = config.dropout
        self.max_length = config.max_length

        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(self.max_length, self.max_length))
                                        .view(1, 1, self.max_length, self.max_length))

    def forward(self, x, hidden_layers=None, layer_num=0):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        if isinstance(self.c_attn, RecurrentNeuronLayer):
            proj_output, hidden_layers[f"c_attn_{layer_num}"]  = self.c_attn(x, hidden_layers.get(f"c_attn_{layer_num}"))
        else:
            proj_output = self.c_attn(x)

        q, k, v = proj_output.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        if isinstance(self.c_proj, RecurrentNeuronLayer):
            y, hidden_layers[f"c_proj_{layer_num}"] = self.c_proj(y, hidden_layers.get(f"c_proj_{layer_num}"))
        else:
            y = self.c_proj(y)

        y = self.resid_dropout(y)
        return y, hidden_layers

class RecurrentTransformerBlock(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.hidden_dim)
        self.attn = RecurrentCausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.hidden_dim)
        self.mlp = MLP(config)

    def forward(self, x, hidden_layers = None, layer_num = 0):
        new_x, hidden_layers = self.attn(self.ln_1(x), hidden_layers, layer_num)
        x = x + new_x
        new_x, hidden_layers = self.mlp(self.ln_2(x), hidden_layers, layer_num)
        x = x + new_x
        return x, hidden_layers

class RecurrentNeuronTransformer(nn.Module):
    """
    A single-layer Transformer which encodes a sequence of text and
    performs binary classification.

    The model has a vocab size of V, works on
    sequences of length T, has an hidden dimension of H, uses word vectors
    also of dimension H, and operates on minibatches of size N.
    """
    def __init__(self, config):
        """
        :config
        """
        super(RecurrentNeuronTransformer, self).__init__()
        assert config.hidden_dim % config.num_heads == 0
        assert config.recurrent_layers in set(["qkv", "proj", "all", "none"])

        print(config)

        self.num_heads = config.num_heads
        self.word_embedding_dim = config.hidden_dim
        self.hidden_dim = config.hidden_dim
        self.max_length = config.max_length
        self.vocab_size = config.vocab_size
        self.device = config.device
        self.dropout = config.dropout

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(self.vocab_size, self.word_embedding_dim),
            wpe = nn.Embedding(self.max_length, self.word_embedding_dim),
            drop = nn.Dropout(self.dropout),
            h = nn.ModuleList([RecurrentTransformerBlock(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(self.hidden_dim),
        ))


        self.lm_head = RecurrentNeuronLayer(self.hidden_dim, self.vocab_size, self.device)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("Number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, RecurrentNeuronLayer):
            neuron_module = module.neurons
            torch.nn.init.normal_(neuron_module.params, mean=0.0, std=0.02)
            linear_module = module.weights
            torch.nn.init.normal_(linear_module.weight, mean=0.0, std=0.02)
            if linear_module.bias is not None:
                torch.nn.init.zeros_(linear_module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


    def forward(self, inputs, hidden_layers):
        """
        This function computes the full Transformer forward pass.
        Put together all of the layers you've developed in the correct order.

        :param inputs: a PyTorch tensor of shape (N,T). These are integer lookups.

        :returns: the model outputs. Should be scores of shape (N,T,output_size).
        """

        embeddings = self.embed(inputs)
        x = self.transformer.drop(embeddings)
        for idx, block in enumerate(self.transformer.h):
            x, hidden_layers = block(x, hidden_layers, idx)
        x = self.transformer.ln_f(x)
        outputs, hidden_layers["lm_output"] = self.lm_head(x, hidden_layers.get("lm_output"))


        return outputs, hidden_layers


    def embed(self, inputs):
        """
        :param inputs: intTensor of shape (N,T)
        :returns embeddings: floatTensor of shape (N,T,H)
        """

        pos = torch.arange(0, self.max_length, dtype=torch.long, device=self.device) # shape (t)
        tok_emb = self.transformer.wte(inputs) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        embeddings  = tok_emb + pos_emb

        return embeddings

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, hidden_layers = None, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, hidden_layers = self(idx_cond, hidden_layers)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import tiktoken
import os

class TextDataset(Dataset):
    def __init__(self, tokens, seq_length, bpe_tokenizer, vocab_size, device):
        self.tokens = tokens
        self.tokenizer = tiktoken.get_encoding(bpe_tokenizer)
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.device = device

    def __len__(self):
        return len(self.tokens) - self.seq_length - 1

    def __getitem__(self, idx):
        input_seq = torch.tensor(self.tokens[idx : idx+self.seq_length], device=self.device)
        target_seq = torch.tensor(self.tokens[idx+1 : idx+1+self.seq_length], device=self.device)
        return input_seq, target_seq


class TextDataLoader:
    def __init__(self, file_path, seq_length, bpe_tokenizer, batch_size, vocab_size, device, split_ratio=0.8):
        self.file_path = file_path
        self.seq_length = seq_length
        self.bpe_tokenizer = bpe_tokenizer
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.device = device
        self.split_ratio = split_ratio

    def load_and_tokenize(self):
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                text = f.read()
            return text
        except IOError:
            print(f"Error opening/reading {self.file_path}")
            return None

    def _create_datasets(self):
        text = self.load_and_tokenize()
        tokenizer = tiktoken.get_encoding(self.bpe_tokenizer)
        tokens = tokenizer.encode_ordinary(text)
        split_index = int(len(tokens) * self.split_ratio)
        train_tokens = tokens[:split_index]
        test_tokens = tokens[split_index:]
        train_dataset = TextDataset(train_tokens, self.seq_length, self.bpe_tokenizer, self.vocab_size, self.device)
        test_dataset = TextDataset(test_tokens, self.seq_length, self.bpe_tokenizer, self.vocab_size, self.device)
        return train_dataset, test_dataset

    def create_loaders(self):
        train_dataset, test_dataset = self._create_datasets()
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, drop_last=True)
        return train_loader, test_loader


In [9]:
# Device configuration
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

Using device: cuda


In [10]:
model_config = RecurrentModelConfig(max_length=512, vocab_size=50257,
                                    n_layer=8, num_heads=8, hidden_dim=768,
                                    dropout=0.1, device=DEVICE, recurrent_layers="all")

model = RecurrentNeuronTransformer(config=model_config).to(DEVICE)

RecurrentModelConfig(max_length=512, vocab_size=50257, n_layer=8, num_heads=8, hidden_dim=768, dropout=0.1, device=device(type='cuda'), recurrent_layers='all')
Number of parameters: 134.90M


In [11]:
model.load_state_dict(torch.load('StatefulTransformer_best_model.pth'))

<All keys matched successfully>

In [12]:
data_loader = TextDataLoader(file_path="Recurrent-Neuron-Transformer/data/shakespeare/tinyshakespeare.txt",
                             seq_length=4096,
                             bpe_tokenizer='gpt2',
                             batch_size=12,
                             vocab_size=50257,
                             split_ratio=0.8,
                             device=DEVICE)
train_loader, test_loader = data_loader.create_loaders()

In [13]:
def recurrent_transformer_forward(model, input_seq, hidden_layers, target_seq):
    # Forward pass
    outputs, hidden_layers = model(inputs=input_seq, hidden_layers=hidden_layers)
    outputs = outputs.reshape(-1, outputs.size(-1))
    target_seq = target_seq.reshape(-1)

    # Calculate loss
    loss = nn.CrossEntropyLoss()(outputs, target_seq)

    return (outputs, hidden_layers), loss

In [14]:
import wandb

wandb.init(
    # set the wandb project where this run will be logged
    project="transformer-testing",
)

wandb.run.name = "StatefulTransformer_Stepwise-Eval"
wandb.define_metric("eval_batch")

wandb.define_metric("epoch/*", step_metric="epoch")
wandb.define_metric("eval_batch/*", step_metric="eval_batch")



[34m[1mwandb[0m: Currently logged in as: [33mchayduk[0m ([33mrecurrent-neuron-transformer[0m). Use [1m`wandb login --relogin`[0m to force relogin


<wandb.sdk.wandb_metric.Metric at 0x7daa409e8130>

In [None]:
from tqdm import tqdm

context_window = 512
step_size = 511
distributed = False
rank = 0


model.eval()
epoch_val_loss = 0
eval_progress_bar = tqdm(test_loader, desc=f'Evaluating: Epoch 1', leave=False)
num_steps = (4096/511)

if int(num_steps) < num_steps:
  num_steps = int(num_steps) + 1
else:
  num_steps = int(num_steps)

step_losses = [0] * num_steps
raw_losses = [0] * num_steps
with torch.no_grad():
    for batch_idx, (input_chunk, target_chunk) in enumerate(eval_progress_bar):
        batch_loss = 0
        hidden_layers = dict()

        for i in range(0, input_chunk.size(1) - context_window, step_size):
            index = i//step_size
            # Create the input and target sequences
            input_seq = input_chunk[:, i:i+context_window].to(DEVICE)
            target_seq = target_chunk[:, i+1:i+context_window+1].to(DEVICE)

            outputs = None
            loss = None
            (outputs, hidden_layers), loss = recurrent_transformer_forward(model, input_seq, hidden_layers, target_seq)

            batch_loss += loss.item()

            raw_losses[index] += loss.item()

            step_losses[index] = raw_losses[index] / (batch_idx + 1)

            if (rank == 0 or not distributed):
                wandb.log({'eval_batch': batch_idx, f'eval_batch/step_{index}_loss': loss.item()})
                wandb.log({'eval_batch': batch_idx, f'eval_batch/step_{index}_average_loss': step_losses[index]})

        epoch_val_loss += batch_loss
        eval_progress_bar.set_postfix(loss=batch_loss)

avg_val_loss = epoch_val_loss / len(test_loader)

wandb.log({'epoch': 1, 'epoch/val_loss': avg_val_loss})


Evaluating: Epoch 1:  77%|███████▋  | 4054/5292 [7:42:08<2:21:08,  6.84s/it, loss=143]

In [None]:
torch.cuda.empty_cache()