The Event-based GRU (EGRU) was published as a conference paper at ICLR 2023:

**Efficient recurrent architectures through activity sparsity and sparse back-propagation through time (notable-top-25%)**

![egru_qr](media/egru_paper_qr.png "egru_qr")

In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import json
import matplotlib.pyplot as plt

In [None]:
%pip install git+https://github.com/Efficient-Scalable-Machine-Learning/EvNN.git@feature/egru_cell

In [None]:
from evnn_pytorch import EGRU

<!-- ![EGRUanim](https://github.com/Efficient-Scalable-Machine-Learning/EvNN/raw/main/media/videos/anim/1080p60/EvNNPlot_ManimCE_v0.17.2.gif "egru") -->

<img src="https://github.com/Efficient-Scalable-Machine-Learning/EvNN/raw/main/media/videos/anim/1080p60/EvNNPlot_ManimCE_v0.17.2.gif" alt="egru" width="1000"/>

In [None]:
# Download and unzip the trained model
!wget -q -O download.zip https://datashare.tu-dresden.de/s/jbzaoqFXwCLYHJF/download
!unzip -o download.zip

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def _read_json(path):
    """Read a UTF-8 encoded JSON file and return the decoded object."""
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# load ascii mapping: index -> token (i2w) and token -> index (w2i)
filename = "Enwik8/index2word.json"
i2w = _read_json(filename)

filename = "Enwik8/word2index.json"
w2i = _read_json(filename)

In [None]:
# Pull the end-of-sequence marker out of the mapping (its index is kept in `eos`),
# then re-key the remaining entries: the JSON keys are decimal character codes,
# so convert each one to the character it denotes.
eos = w2i.pop("<eos>")
w2i = {chr(int(code)): idx for code, idx in w2i.items()}

In [None]:
# Vocabulary size = number of entries in the index -> token mapping.
# It sizes both the embedding table and the decoder output of the model.
n_vocab = len(i2w)
print("Total Vocab: ", n_vocab)

In [None]:
from typing import Union


class Decoder(nn.Module):
    def __init__(self,
                 ninp: int,
                 ntokens: int,
                 project: bool = False,
                 nemb: Union[None, int] = None,
                 dropout: float = 0.0):
        """
        Takes hidden states of RNNs, optionally applies a projection operation and decodes to output tokens.

        :param ninp: Input dimension (size of the last axis of the incoming hidden states)
        :param ntokens: Number of tokens of the language model (size of the output logits)
        :param project: If True, applies a linear projection followed by ReLU onto the embedding dimension
        :param nemb: If projection is True, specifies the dimension of the projection. Ignored when
                     project is False; the decoder then consumes the ``ninp`` features directly.
        :param dropout: Dropout rate applied to the projector output (active only in training mode)
        """
        super().__init__()

        if project:
            assert nemb, "If projection is True, must specify nemb!"

        self.ninp = ninp
        # Without a projection the features reaching the decoder are still ``ninp`` wide,
        # so a stray ``nemb`` must not change the decoder's input width.
        self.nemb = nemb if project else ninp
        self.nout = ntokens

        self.dropout = dropout

        # projector
        self.project = project
        if project:
            self.projection = nn.Linear(ninp, nemb)
        else:
            self.projection = nn.Identity()

        # word embedding decoder
        self.decoder = nn.Linear(self.nemb, self.nout)
        nn.init.zeros_(self.decoder.bias)

    def forward(self, x):
        """
        Decode a batch of hidden-state sequences into token logits.

        :param x: Tensor of shape (batch, seq_len, ninp); other ranks raise ValueError on unpacking
        :return: Logits of shape (batch * seq_len, ntokens)
        """
        _, _, ninp = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. transposed sequences, are accepted
        x = x.reshape(-1, ninp)
        if self.project:
            x = F.relu(self.projection(x))
            # the documented projector dropout; a no-op in eval mode
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.decoder(x)
        return x

In [None]:
class CharModel(nn.Module):
    """Character-level language model: embedding -> three stacked EGRU layers -> decoder.

    The layer widths are fixed literals so that the module/parameter layout matches
    the checkpoint that is loaded into it.
    """

    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(n_vocab, 400)
        self.rnns = nn.ModuleList([
            EGRU(400, 800, batch_first=False),
            EGRU(800, 800, batch_first=False),
            EGRU(800, 800, batch_first=False),
        ])
        self.decoder = Decoder(ninp=800, ntokens=n_vocab,
                               project=True, nemb=400)

    def forward(self, x, y_pre=None, h_pre=None):
        """Advance every EGRU layer by one step and decode the top layer's output.

        :param x: token indices for the current step; they are embedded and ``squeeze(0)``-ed
                  before entering the first layer (exact expected shape: TODO confirm)
        :param y_pre: per-layer outputs returned by the previous call, or None to start
                      with no previous output for every layer
        :param h_pre: per-layer internal states returned by the previous call, or None
        :return: tuple (logits, y_new, h_new); y_new / h_new hold a detached copy of each
                 layer's output and internal state, to be passed back as y_pre / h_pre
        """
        n_layers = len(self.rnns)
        # fresh lists per call instead of shared mutable defaults
        if y_pre is None:
            y_pre = [None] * n_layers
        if h_pre is None:
            h_pre = [None] * n_layers

        y_new = []
        h_new = []
        x = self.embeddings(x).squeeze(0)
        for i, rnn in enumerate(self.rnns):
            x, h, _ = rnn.step(x, y_pre[i], h_pre[i])
            # detached copies: the carried-over state does not reference the autograd graph
            y_new.append(x.detach().clone())
            h_new.append(h.detach().clone())

        # produce output
        x = self.decoder(x.unsqueeze(0))
        return x, y_new, h_new

In [None]:
# Build the network on the selected device and switch it to inference mode;
# both .to() and .eval() return the module itself, so the calls chain.
model = CharModel().to(device).eval()
model

In [None]:
# Load the trained weights used for generation.
# The checkpoint comes from a downloaded archive and is only used as a state dict,
# so weights_only=True is passed: torch.load then refuses to unpickle arbitrary objects.
best_model = torch.load(
    "Enwik8/2024-05-16-Enwik8-EGRU-trained/checkpoints/EGRU_best_model.cpt",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(best_model)

In [None]:
# Define a prompt to start the generation
prompt = "William Shakespeare was an English playwright, poet and actor. He is widely regarded as the greatest writer in the English language and the world's pre-eminent "

# Map every character of the prompt to its vocabulary index
# (a character missing from w2i raises KeyError).
x = list(map(w2i.__getitem__, prompt))

In [None]:
# Turn the token list into an int tensor of shape (len(prompt), 1) on the target device;
# presumably (seq_len, batch=1), matching the batch_first=False EGRU layers.
x = torch.tensor(x, dtype=torch.int, device=device).unsqueeze(1)

# initialize EGRU hidden states: one slot per layer, None meaning "no previous state"
state = [None] * 3
internal_state = [None] * 3

# Task 1: Run the model on your prompt
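To familiarize yourself with the inner workings of the model, try to measure its activity sparsity on your prompt (for example, the fraction of units in each layer whose output is zero at each step).
This involves looping over the prompt tokens one at a time and feeding the hidden states returned by the model back into it at the next step.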

# Task 2: Write a text generator
The model produces logits of a distribution over the vocabulary at each time step.
To generate text, we can sample from this distribution.
First, apply softmax to the logits, and then sample from the distribution.
A simple strategy is to greedily pick the most likely token at every step.
However, sampling from the distribution with a temperature parameter (dividing the logits by the temperature before the softmax) can produce more diverse text.