In [1]:
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

In [2]:
@dataclass
class ModelDimensions:
    """Bundle of ten integer size fields, grouped into mel/audio/vocab/text names.

    NOTE(review): none of these fields are read in the visible part of the
    notebook, so the per-field notes below are inferred from the names only
    and should be confirmed against the code that consumes them.
    """

    n_mels: int  # presumably the number of mel-spectrogram bins per input frame — TODO confirm
    n_audio_ctx: int  # presumably the audio-side context length (positions) — TODO confirm
    n_audio_state: int  # presumably the audio-side hidden width — TODO confirm
    n_audio_head: int  # presumably the number of audio-side attention heads — TODO confirm
    n_audio_layer: int  # presumably the number of audio-side blocks — TODO confirm
    n_vocab: int  # presumably the vocabulary size — TODO confirm
    n_text_ctx: int  # presumably the text-side context length (positions) — TODO confirm
    n_text_state: int  # presumably the text-side hidden width — TODO confirm
    n_text_head: int  # presumably the number of text-side attention heads — TODO confirm
    n_text_layer: int  # presumably the number of text-side blocks — TODO confirm

In [3]:
class LayerNorm(nn.LayerNorm):
    """Layer normalisation that always computes in float32.

    The input is upcast with ``.float()`` before the parent ``forward`` runs,
    and the normalised result is cast back to the dtype the caller passed in.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Normalise ``x`` in float32 and return it in ``x``'s original dtype."""
        caller_dtype = x.dtype
        normalised = super().forward(x.float())
        return normalised.type(caller_dtype)

In [4]:
class Linear(nn.Linear):
    """Linear layer whose parameters are cast to the input's dtype on every call."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply the affine map with weight (and bias, if any) cast to ``x.dtype``."""
        weight = self.weight.to(x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return F.linear(x, weight, bias)

In [5]:
class Conv1d(nn.Conv1d):
    """1-D convolution whose parameters are cast to the input's dtype on every call."""

    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        """Run the parent convolution with ``weight``/``bias`` cast to ``x.dtype``."""
        cast_weight = weight.to(x.dtype)
        cast_bias = bias.to(x.dtype) if bias is not None else None
        return super()._conv_forward(x, cast_weight, cast_bias)


In [6]:
# Quick shape check for the Conv1d wrapper: a batch of one single-channel
# sequence of 16000 samples goes through a layer with 1 input channel,
# 16 output channels and kernel size 25. With the default stride and no
# padding the output length is 16000 - 25 + 1 = 15976.
x = torch.randn(1, 1, 16000)
c = Conv1d(1, 16, 25)
o = c(x)
o.shape

torch.Size([1, 16, 15976])

In [7]:

def sinusoids(length, channels, max_timescale=10000):
    """Build a ``(length, channels)`` table of sinusoidal position features.

    Row ``t`` holds ``sin(t * w_i)`` in the first ``channels // 2`` columns and
    ``cos(t * w_i)`` in the remaining columns, where the ``w_i`` fall
    geometrically from ``1`` down to ``1 / max_timescale``.

    Args:
        length: number of positions (rows).
        channels: number of output columns; must be even.
        max_timescale: ratio between the largest and the smallest frequency.

    Returns:
        A tensor of shape ``(length, channels)``.

    Raises:
        AssertionError: if ``channels`` is odd.
    """
    # The channel axis is split in half: sines first, cosines second.
    assert channels % 2 == 0

    n_freqs = channels // 2

    # Log-space step between consecutive frequencies. With channels == 2 there
    # is a single frequency and ``n_freqs - 1`` is zero; dividing by it gave an
    # infinite step and an all-NaN table, so the denominator is clamped to 1
    # (the single frequency is then exactly 1).
    log_timescale_increment = np.log(max_timescale) / max(n_freqs - 1, 1)

    # Frequencies w_i = max_timescale ** (-i / (n_freqs - 1)).
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(n_freqs))

    # Outer product position x frequency -> (length, n_freqs).
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]

    # Concatenate along the channel axis to obtain (length, channels).
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
    

In [8]:
# Shape check: 5 positions by 4 channels.
sinusoids(5, 4).shape

torch.Size([5, 4])

In [10]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional key/value cache.

    Holds four ``Linear`` projections of width ``n_state``: ``query``, ``key``
    (without bias), ``value`` and ``out``. The same module serves
    self-attention (``xa`` omitted) and cross-attention (``xa`` given).
    """

    def __init__(self, n_state: int, n_head: int):
        """Create the projections for ``n_head`` heads over ``n_state`` features."""
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
        """Attend from ``x`` to itself, or to ``xa`` when it is given.

        Args:
            x: source of the queries; rank 3, last axis is the feature axis.
            xa: optional source of keys/values; ``x`` is used when omitted.
            mask: optional additive mask forwarded to ``qkv_attention``.
            kv_cache: optional dict keyed by the ``key`` / ``value`` modules
                themselves. It is only read when ``xa`` is given and the
                ``key`` module is already present in the dict.

        Returns:
            ``(output, qk)``: the projected attention output and the detached
            float32 attention logits.
        """
        q = self.query(x)

        cache_hit = kv_cache is not None and xa is not None and self.key in kv_cache
        if cache_hit:
            # Reuse the projections stored under the module objects.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            # Project from xa when cross-attending, otherwise from x itself.
            kv_source = x if xa is None else xa
            k = self.key(kv_source)
            v = self.value(kv_source)

        attended, qk = self.qkv_attention(q, k, v, mask)
        return self.out(attended), qk

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        """Split into heads, apply softmax attention, and merge the heads back.

        ``q`` and ``k`` are each multiplied by ``head_dim ** -0.25`` so that
        their product carries the usual ``head_dim ** -0.5`` factor.

        Returns:
            ``(merged, qk)``: the per-position attention result with heads
            flattened back into the last axis, and the detached float32
            logits (mask already added when one was given).
        """
        _, n_ctx, n_state = q.shape
        head_dim = n_state // self.n_head
        scale = head_dim ** -0.25

        def split_heads(t: Tensor) -> Tensor:
            # (batch, ctx, state) -> (batch, ctx, head, head_dim)
            return t.view(*t.shape[:2], self.n_head, -1)

        q = split_heads(q).permute(0, 2, 1, 3) * scale
        # Keys are laid out (batch, head, head_dim, ctx) so that q @ k is the score matrix.
        k = split_heads(k).permute(0, 2, 3, 1) * scale
        v = split_heads(v).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            # Only the top-left n_ctx x n_ctx corner of the mask is applied.
            qk = qk + mask[:n_ctx, :n_ctx]
        # Softmax is evaluated on float32 logits, then cast back to q's dtype.
        qk = qk.float()

        weights = F.softmax(qk, dim=-1).to(q.dtype)
        merged = (weights @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

        return merged, qk.detach()

In [13]:
# Instantiate with n_state=5 and n_head=5 to display the registered submodules.
MultiHeadAttention(5, 5)

MultiHeadAttention(
  (query): Linear(in_features=5, out_features=5, bias=True)
  (key): Linear(in_features=5, out_features=5, bias=False)
  (value): Linear(in_features=5, out_features=5, bias=True)
  (out): Linear(in_features=5, out_features=5, bias=True)
)

In [15]:
class ResidualAttentionBlock(nn.Module):
    """Residual block: self-attention, optional cross-attention, then an MLP.

    Each sub-layer is applied to a layer-normalised copy of the running
    activation and its result is added back to that activation.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        """Build the sub-layers.

        Args:
            n_state: feature width of the block.
            n_head: number of attention heads.
            cross_attention: when True, also create a cross-attention
                sub-layer and its layer norm; otherwise both are ``None``.
        """
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        # The MLP expands to four times the block width and projects back.
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
        """Run the block on ``x``.

        Args:
            x: input activation.
            xa: tensor handed to the cross-attention sub-layer as its
                key/value source; unused when the block has no cross-attention.
            mask: optional mask forwarded to the self-attention sub-layer only.
            kv_cache: optional dict forwarded unchanged to both attention
                sub-layers (it was previously annotated as a Tensor, which did
                not match how the attention module uses it).

        Returns:
            The updated activation.
        """
        # Index [0] keeps the attention output and drops the returned logits.
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn is not None:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x

In [18]:
# cross_attention defaults to False, so the repr lists no cross-attention submodules.
ResidualAttentionBlock(5, 5)

ResidualAttentionBlock(
  (attn): MultiHeadAttention(
    (query): Linear(in_features=5, out_features=5, bias=True)
    (key): Linear(in_features=5, out_features=5, bias=False)
    (value): Linear(in_features=5, out_features=5, bias=True)
    (out): Linear(in_features=5, out_features=5, bias=True)
  )
  (attn_ln): LayerNorm((5,), eps=1e-05, elementwise_affine=True)
  (mlp): Sequential(
    (0): Linear(in_features=5, out_features=20, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=20, out_features=5, bias=True)
  )
  (mlp_ln): LayerNorm((5,), eps=1e-05, elementwise_affine=True)
)