In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange, repeat
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
import numpy as np
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [3]:
class CONFIG:
    vocab_size = tokenizer.vocab_size + 1
    ignored_index = tokenizer.vocab_size
    block_size = 20
    emb_dim = 16
    d_model = 16
    m_model = 8
    n_heads = 4
    lr = 0.01
    n_epochs = 5
    batch_size = 16

In [4]:
class TextDataset(Dataset):
    def __init__(self, split, block_size, vocab_size, transform=None, target_transform=None) -> None:
        super().__init__()
        self.split = split
        self.block_size = block_size
        self.vocab_size = vocab_size
        
        dataset = load_dataset("glue", "mrpc", split=split)
        dataset = [tokenizer(i['sentence1'])['input_ids'] for i in dataset]
        self.dataset = [self.pad(x) for x in dataset]
        del dataset
    
    def pad(self, x):
        x = torch.tensor(x, dtype=torch.long)
        if len(x) == self.block_size + 1:
            return x
        elif len(x) > self.block_size + 1:
            idx = torch.randint(len(x) - self.block_size - 1, (1,))
            return x[idx: idx + self.block_size + 1]
        else:
            n_to_pad = self.block_size + 1 - len(x)
            x = F.pad(x, (0, n_to_pad), 'constant', self.vocab_size - 1)
            return x

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx][:-1], self.dataset[idx][1:]

In [5]:
train_ds = TextDataset('train', CONFIG.block_size, CONFIG.vocab_size)
valid_ds = TextDataset('validation', CONFIG.block_size, CONFIG.vocab_size)

Found cached dataset glue (/Users/piotrgabrys/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Found cached dataset glue (/Users/piotrgabrys/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [6]:
train_loader = DataLoader(train_ds, CONFIG.batch_size, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_ds, CONFIG.batch_size, shuffle=True, drop_last=True)

In [7]:
class SSM(nn.Module):
    def __init__(self, d_model, m_model, mode) -> None:
        super().__init__()
        self.d_model = d_model
        self.m_model = m_model
        self.mode = mode
        
        if self.mode == 'diag':
            A = torch.randn((self.d_model, self.m_model)) / 100
        elif self.mode == 'shift':
            A = torch.cat([torch.randn((self.d_model, self.m_model - 1)), torch.zeros((self.d_model, 1))], dim=1)
        else:
            raise ValueError('not such mode')
        self.A = nn.Parameter(A)
        B = torch.randn((self.d_model, self.m_model)) / 100
        self.B = nn.Parameter(B)
        C = torch.randn((self.d_model, self.m_model)) / 100
        self.C = nn.Parameter(C)
        D = torch.randn((self.d_model)) / 100
        self.D = nn.Parameter(D)


    def forward(self, x):
        # B T C
        block_size = x.shape[1]
        filter = [(self.C * (self.A ** i) * self.B).sum(1) for i in range(block_size)]
        filter = torch.stack(filter, dim=1).squeeze()
        filter = rearrange(filter, 'b c -> c b')
        x = x.flip(1)
        conv_part = filter * x
        conv = conv_part.cumsum(dim=1)
        conv = conv.flip(1)

        y = conv + self.D * x

        # batch, block, emb_dim
        return y

In [8]:
class H3Model(nn.Module):
    def __init__(self, vocab_size, emb_dim, d_model, m_model, n_heads) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.d_model = d_model
        self.m_model = m_model
        self.n_heads = n_heads

        self.emb = nn.Embedding(self.vocab_size, self.emb_dim)

        self.k_proj = nn.Linear(self.emb_dim, self.d_model, bias=False)
        nn.init.xavier_normal_(self.k_proj.weight)
        self.ln1 = nn.LayerNorm(self.d_model)
        self.ssm_shift = SSM(self.d_model, self.m_model, 'shift')

        self.v_proj = nn.Linear(self.emb_dim, self.d_model, bias=False)
        nn.init.xavier_normal_(self.v_proj.weight)
        self.ln2 = nn.LayerNorm(self.d_model)

        self.vhs_proj = nn.ModuleList([nn.Linear(self.d_model, int(self.d_model / n_heads), bias=False) for _ in range(n_heads)])
        for vh_proj in self.vhs_proj:
            nn.init.xavier_normal_(vh_proj.weight)

        self.ssm_diags = nn.ModuleList([SSM(int(self.d_model / n_heads), self.m_model, 'diag') for _ in range(self.n_heads)])

        self.q_proj = nn.Linear(self.emb_dim, self.d_model, bias=False)
        nn.init.xavier_normal_(self.q_proj.weight)
        self.ln3 = nn.LayerNorm(self.d_model)

        self.output_layer = nn.Linear(self.d_model, self.vocab_size)
        nn.init.xavier_normal_(self.output_layer.weight)


    def forward(self, x):
        # B, T
        x = self.emb(x)
        
        k = self.ln1(self.k_proj(x))
        k = self.ssm_shift(k)

        v = self.ln2(self.v_proj(x))

        v = v * k
        v = [ssm_diag(self.vhs_proj[i](v)) for i, ssm_diag in enumerate(self.ssm_diags)]
        v = torch.cat(v, dim=2)

        q = self.ln3(self.q_proj(x))
        out = q * v

        out = self.output_layer(out)

        out = rearrange(out, 'B T C -> (B T) C')

        return out

In [9]:
model = H3Model(vocab_size=CONFIG.vocab_size, emb_dim=CONFIG.emb_dim, d_model=CONFIG.d_model, m_model=CONFIG.m_model, n_heads=CONFIG.n_heads)

In [10]:
loss_func = nn.CrossEntropyLoss(ignore_index=CONFIG.ignored_index)
optimizer = torch.optim.Adam(model.parameters(), CONFIG.lr)

In [11]:
@torch.no_grad()
def estimate_loss():
    model.eval()

    train_loss = 0
    for x, y in train_loader:
        preds = model(x)
        train_loss += loss_func(preds, y.ravel())
    train_loss /= len(train_loader)

    valid_loss = 0
    for x, y in valid_loader:
        preds = model(x)
        valid_loss += loss_func(preds, y.ravel())
    valid_loss /= len(valid_loader)

    model.train()

    return train_loss, valid_loss

In [12]:
for i in range(CONFIG.n_epochs):
    for x, y in train_loader:
        optimizer.zero_grad(set_to_none=True)
        preds = model(x)
        loss = loss_func(preds, y.ravel())
        loss.backward()
        optimizer.step()
    
    train_loss, valid_loss = estimate_loss()
    print(f'EPOCH {i}, train loss: {train_loss:.6f}, valid loss: {valid_loss:.6f}')

EPOCH 0, train loss: 6.869299, valid loss: 7.389570
EPOCH 1, train loss: 6.076054, valid loss: 7.140264
EPOCH 2, train loss: 5.549601, valid loss: 7.070755
EPOCH 3, train loss: 5.123650, valid loss: 7.155102
EPOCH 4, train loss: 4.780811, valid loss: 7.360164
