In [3]:
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

In [4]:
@dataclass
class ModelDimensions:
    """Integer size hyperparameters, grouped into mel, audio, and text fields.

    NOTE(review): the per-field notes below are inferred from the field names
    only; no visible cell reads these values. Confirm each against the model
    code that consumes this dataclass.
    """

    n_mels: int  # presumably the number of mel-spectrogram channels -- TODO confirm
    n_audio_ctx: int  # presumably the audio-side context length -- TODO confirm
    n_audio_state: int  # presumably the audio-side hidden width -- TODO confirm
    n_audio_head: int  # presumably the audio-side attention head count -- TODO confirm
    n_audio_layer: int  # presumably the audio-side layer count -- TODO confirm
    n_vocab: int  # presumably the vocabulary size -- TODO confirm
    n_text_ctx: int  # presumably the text-side context length -- TODO confirm
    n_text_state: int  # presumably the text-side hidden width -- TODO confirm
    n_text_head: int  # presumably the text-side attention head count -- TODO confirm
    n_text_layer: int  # presumably the text-side layer count -- TODO confirm

In [5]:
class LayerNorm(nn.LayerNorm):
    """Layer normalization that computes in float32 and returns the input's dtype."""

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` in float32, then cast the result back to ``x.dtype``.

        Args:
            x: Input tensor of any floating dtype.

        Returns:
            The normalized tensor, with the same dtype as ``x``.
        """
        input_dtype = x.dtype
        # Upcast first so the normalization itself runs in float32.
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)

In [6]:
class Linear(nn.Linear):
    """Linear layer that casts its parameters to the input's dtype on every call."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply the affine map using weight and bias converted to ``x.dtype``.

        Args:
            x: Input tensor; its dtype decides the dtype of the computation.

        Returns:
            ``x @ weight.T + bias`` (bias omitted when the layer has none).
        """
        target_dtype = x.dtype
        weight = self.weight.to(target_dtype)
        # The layer may have been built with bias=False.
        if self.bias is None:
            bias = None
        else:
            bias = self.bias.to(target_dtype)
        return F.linear(x, weight, bias)

In [7]:
class Conv1d(nn.Conv1d):
    """1-D convolution that casts its weight and bias to the input's dtype per call."""

    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        """Run the parent convolution with parameters converted to ``x.dtype``.

        Args:
            x: Input tensor; its dtype decides the dtype of the computation.
            weight: Convolution kernel handed in by the parent ``forward``.
            bias: Optional bias handed in by the parent ``forward``.

        Returns:
            The convolution output produced by ``nn.Conv1d._conv_forward``.
        """
        target_dtype = x.dtype
        cast_weight = weight.to(target_dtype)
        # Bias is None when the layer was built with bias=False.
        cast_bias = bias.to(target_dtype) if bias is not None else None
        return super()._conv_forward(x, cast_weight, cast_bias)


In [8]:
# Peek at Conv1d output shapes. With nn.Conv1d's default stride (1) and
# padding (0), a width-25 kernel shortens 16000 input samples to
# 16000 - 25 + 1 = 15976, and the channel axis grows from 1 to 16.
x = torch.randn(1, 1, 16000)  # random input of shape (1, 1, 16000)
c = Conv1d(1, 16, 25)  # in_channels=1, out_channels=16, kernel_size=25
o = c(x)
o.shape

torch.Size([1, 16, 15976])

In [9]:

def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> Tensor:
    """Build a fixed sinusoidal position-embedding table.

    Each position ``p`` in ``range(length)`` is multiplied by ``channels // 2``
    inverse timescales that are geometrically spaced from ``1`` down to
    ``1 / max_timescale``. The sines of those products fill the first half of
    the channel axis and the cosines fill the second half.

    Args:
        length: Number of positions (rows of the result).
        channels: Embedding width (columns of the result); must be even.
        max_timescale: Largest timescale, i.e. the reciprocal of the smallest
            inverse timescale.

    Returns:
        A tensor of shape ``(length, channels)``.

    Raises:
        AssertionError: If ``channels`` is odd.
    """
    # The channel axis is split in half: one half for sin, the other for cos.
    assert channels % 2 == 0, "channels must be even"

    half_channels = channels // 2

    # Step between neighbouring timescales in log space. With a single
    # timescale (channels == 2) there are no neighbours and the plain divisor
    # would be 0, which turns the whole table into NaN (inf * 0); clamping it
    # to 1 leaves that single inverse timescale at exp(0) == 1.
    log_timescale_increment = np.log(max_timescale) / max(half_channels - 1, 1)

    # Inverse timescales: exp(0) == 1 for the first channel, 1 / max_timescale
    # for the last.
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half_channels))

    # Outer product via broadcasting: (length, 1) * (1, half_channels).
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]

    # Concatenate along the channel axis: [sin | cos].
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
    

In [11]:
# Shape check: one row per position (5) and one column per channel (4).
sinusoids(5, 4).shape

torch.Size([5, 4])