In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange, repeat
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
import numpy as np
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [3]:
class CONFIG:
    vocab_size = tokenizer.vocab_size + 1
    ignored_index = tokenizer.vocab_size
    block_size = 20
    emb_dim = 16
    d_model = 16
    m_model = 16
    n_heads = 4
    n_ssm_heads = 2
    n_layer = 2
    lr = 0.01
    n_epochs = 5
    batch_size = 16

In [4]:
class TextDataset(Dataset):
    def __init__(self, split, block_size, vocab_size, transform=None, target_transform=None) -> None:
        super().__init__()
        self.split = split
        self.block_size = block_size
        self.vocab_size = vocab_size
        
        dataset = load_dataset("glue", "mrpc", split=split)
        dataset = [tokenizer(i['sentence1'])['input_ids'] for i in dataset]
        self.dataset = [self.pad(x) for x in dataset]
        del dataset
    
    def pad(self, x):
        x = torch.tensor(x, dtype=torch.long)
        if len(x) == self.block_size + 1:
            return x
        elif len(x) > self.block_size + 1:
            idx = torch.randint(len(x) - self.block_size - 1, (1,))
            return x[idx: idx + self.block_size + 1]
        else:
            n_to_pad = self.block_size + 1 - len(x)
            x = F.pad(x, (0, n_to_pad), 'constant', self.vocab_size - 1)
            return x

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx][:-1], self.dataset[idx][1:]

In [5]:
train_ds = TextDataset('train', CONFIG.block_size, CONFIG.vocab_size)
valid_ds = TextDataset('validation', CONFIG.block_size, CONFIG.vocab_size)

Found cached dataset glue (/Users/piotrgabrys/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Found cached dataset glue (/Users/piotrgabrys/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [6]:
train_loader = DataLoader(train_ds, CONFIG.batch_size, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_ds, CONFIG.batch_size, shuffle=True, drop_last=True)

In [7]:
class SSM(nn.Module):
    def __init__(self, d_model, m_model, mode) -> None:
        super().__init__()
        self.d_model = d_model
        self.m_model = m_model
        self.mode = mode
        
        if self.mode == 'diag':
            A = torch.randn((self.d_model, self.m_model)) / 100
        elif self.mode == 'shift':
            A = torch.cat([torch.randn((self.d_model, self.m_model - 1)), torch.zeros((self.d_model, 1))], dim=1)
        else:
            raise ValueError('not such mode')
        self.A = nn.Parameter(A)
        B = torch.randn((self.d_model, self.m_model)) / 100
        self.B = nn.Parameter(B)
        C = torch.randn((self.d_model, self.m_model)) / 100
        self.C = nn.Parameter(C)
        D = torch.randn((self.d_model)) / 100
        self.D = nn.Parameter(D)
        E = torch.eye(self.m_model)
        if self.mode == 'shift':
            E = torch.cat([E[:,1:], torch.zeros(m_model, 1)], dim=1)
        E = repeat(E, 'h w -> n h w', n=self.d_model)
        self.register_buffer('E', E)

    def forward(self, x):
        # B T C
        block_size = x.shape[1]
        filter = [(self.C * (self.A ** i) * self.B).sum(1) for i in range(block_size)]
        filter = torch.stack(filter, dim=1).squeeze()
        filter = rearrange(filter, 'b c -> c b')
        x = x.flip(1)
        conv_part = filter * x
        conv = conv_part.cumsum(dim=1)
        conv = conv.flip(1)

        y = conv + self.D * x

        # batch, block, emb_dim
        return y

In [8]:
class H3Head(nn.Module):
    """ one head of H3 """

    def __init__(self, emb_size, head_size, n_ssm_heads, m_model):
        super().__init__()
        self.emb_size = emb_size
        self.head_size = head_size
        self.n_ssm_heads = n_ssm_heads
        self.m_model = m_model

        self.key = nn.Linear(self.emb_size, head_size, bias=False)
        self.ln_key = nn.LayerNorm(head_size)
        self.query = nn.Linear(self.emb_size, head_size, bias=False)
        self.ln_query = nn.LayerNorm(head_size)
        self.value = nn.Linear(self.emb_size, head_size, bias=False)
        self.ln_value = nn.LayerNorm(head_size)
        self.ssm_shift = SSM(self.head_size, self.m_model, 'shift')

        self.values = nn.ModuleList([nn.Linear(self.head_size, int(self.head_size / self.n_ssm_heads), bias=False) for _ in range(self.n_ssm_heads)])
        self.ssm_diags = nn.ModuleList([SSM(int(self.head_size / self.n_ssm_heads), self.m_model, 'diag') for _ in range(self.n_ssm_heads)])
        

    def forward(self, x):
        k = self.key(x)
        k = self.ln_key(k)
        k = self.ssm_shift(k)
        v = self.value(x)
        v = self.ln_key(v)
        v = v * k
        v = torch.cat([ssm_diag(value(v)) for value, ssm_diag in zip(self.values, self.ssm_diags)], dim=2)
        q = self.query(x)
        q = self.ln_key(q)
        q = v * q
        return q

In [9]:
class MultiHeadH3(nn.Module):
    """ multiple heads of H3 in parallel """

    def __init__(self, emb_size, num_heads, head_size, n_ssm_heads, m_model, dropout=0.2):
        super().__init__()
        self.heads = nn.ModuleList([H3Head(emb_size, head_size, n_ssm_heads, m_model) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [10]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [11]:
class H3Block(nn.Module):
    """ H3 block: communication followed by computation """

    def __init__(self, emb_size, num_heads, n_ssm_heads, m_model):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = emb_size // num_heads
        self.sa = MultiHeadH3(emb_size, num_heads, head_size, n_ssm_heads, m_model)
        self.ffwd = FeedFoward(emb_size)
        self.ln1 = nn.LayerNorm(emb_size)
        self.ln2 = nn.LayerNorm(emb_size)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [12]:
class H3LanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, d_model, m_model, n_heads, n_ssm_heads, n_layer) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.d_model = d_model
        self.m_model = m_model
        self.n_heads = n_heads

        self.emb = nn.Embedding(self.vocab_size, self.emb_dim)

        self.blocks = nn.Sequential(*[H3Block(emb_dim, n_heads, n_ssm_heads, m_model) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(emb_dim) # final layer norm
        self.lm_head = nn.Linear(emb_dim, vocab_size)

        self.apply(self._init_weights)


    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


    def forward(self, x):
        # B, T
        x = self.emb(x)
        
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        x = self.lm_head(x) # (B,T,vocab_size)

        x = rearrange(x, 'B T vocab -> (B T) vocab')

        return x

In [13]:
model = H3LanguageModel(vocab_size=CONFIG.vocab_size, emb_dim=CONFIG.emb_dim, d_model=CONFIG.d_model, m_model=CONFIG.m_model, n_heads=CONFIG.n_heads, n_ssm_heads=CONFIG.n_ssm_heads, n_layer=CONFIG.n_layer)

In [14]:
loss_func = nn.CrossEntropyLoss(ignore_index=CONFIG.ignored_index)
optimizer = torch.optim.Adam(model.parameters(), CONFIG.lr)

In [15]:
@torch.no_grad()
def estimate_loss():
    model.eval()

    train_loss = 0
    for x, y in train_loader:
        preds = model(x)
        train_loss += loss_func(preds, y.ravel())
    train_loss /= len(train_loader)

    valid_loss = 0
    for x, y in valid_loader:
        preds = model(x)
        valid_loss += loss_func(preds, y.ravel())
    valid_loss /= len(valid_loader)

    model.train()

    return train_loss, valid_loss

In [16]:
for i in range(CONFIG.n_epochs):
    for x, y in train_loader:
        optimizer.zero_grad(set_to_none=True)
        preds = model(x)
        loss = loss_func(preds, y.ravel())
        loss.backward()
        optimizer.step()
    
    train_loss, valid_loss = estimate_loss()
    print(f'EPOCH {i}, train loss: {train_loss:.6f}, valid loss: {valid_loss:.6f}')

EPOCH 0, train loss: 6.720024, valid loss: 7.207736
EPOCH 1, train loss: 6.026439, valid loss: 6.909051
EPOCH 2, train loss: 5.356634, valid loss: 6.824624
EPOCH 3, train loss: 4.846189, valid loss: 6.923506
EPOCH 4, train loss: 4.437585, valid loss: 7.060393
