In [2]:
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

In [3]:
@dataclass
class ModelDimensions:
    """Integer size hyperparameters for a two-part model.

    Judging by the field names (to be confirmed against the code that consumes
    this class), the ``n_audio_*`` fields configure the audio-side stack and the
    ``n_text_*`` fields configure the text-side stack.
    Field order is significant: it defines the positional constructor order.
    """

    # presumably the number of mel bins per input frame
    n_mels: int

    # audio-side sizes: context length, hidden width, attention heads, layer count
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int

    # text-side sizes: vocabulary, context length, hidden width, attention heads, layer count
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int

In [4]:
class LayerNorm(nn.LayerNorm):
    """nn.LayerNorm that normalizes a float32 copy of the input.

    The result is cast back to the caller's dtype. Reduced-precision inputs
    (e.g. float16) are therefore normalized in float32, while the output dtype
    still matches the input.
    """

    def forward(self, x: Tensor) -> Tensor:
        caller_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(caller_dtype)

In [6]:
class Linear(nn.Linear):
    """nn.Linear whose weight and bias are cast to the input's dtype on every call."""

    def forward(self, x: Tensor) -> Tensor:
        target_dtype = x.dtype
        weight = self.weight.to(target_dtype)
        # bias may be absent when the layer was built with bias=False
        bias = self.bias.to(target_dtype) if self.bias is not None else None
        return F.linear(x, weight, bias)

In [13]:
class Conv1d(nn.Conv1d):
    """nn.Conv1d whose kernel and bias are cast to the input's dtype before convolving."""

    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        dtype = x.dtype
        # bias is None when the layer was built with bias=False
        cast_bias = bias.to(dtype) if bias is not None else None
        return super()._conv_forward(x, weight.to(dtype), cast_bias)


In [17]:
# Peek at Conv1d: with the default padding=0, stride=1 and dilation=1, the
# output length is input_len - kernel_size + 1 (here 16000 - 25 + 1 = 15976).
x = torch.randn(1, 1, 16000)
c = Conv1d(1, 16, 25)
o = c(x)
assert o.shape == (1, 16, x.shape[-1] - c.kernel_size[0] + 1)
o.shape

torch.Size([1, 16, 15976])

In [21]:
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding.

    Builds a ``(length, channels)`` float tensor. The first ``channels // 2``
    columns are ``sin(pos * inv_timescale)`` and the last ``channels // 2``
    columns are ``cos(pos * inv_timescale)``. When there are at least two
    frequencies, the inverse timescales decay geometrically from 1 down to
    ``1 / max_timescale``. With a single frequency (``channels == 2``) the only
    inverse timescale is 1.

    Args:
        length: number of positions (rows).
        channels: encoding width (columns). Must be even, because it is split
            in half between the sine and cosine parts.
        max_timescale: largest timescale; sets the slowest frequency.

    Raises:
        AssertionError: if ``channels`` is odd.
    """
    assert channels % 2 == 0, "channels must be even (half sine, half cosine)"
    half = channels // 2
    # Spacing (in log space) between consecutive timescales. Clamp the divisor
    # so channels == 2 does not divide by zero; that case produced inf and then
    # NaN. With half == 1 the increment is multiplied by 0 anyway.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    # Outer product: row i holds position i multiplied by every inverse timescale.
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


In [43]:
# Scratch check of the sinusoid math on a tiny example (5 positions, 4 channels).
channels = 4
length = 5
max_timescale = 10000

# Inverse timescales: a geometric sequence from 1 down to 1 / max_timescale.
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_t = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
print(inv_t)

# Outer product: row i holds position i multiplied by each inverse timescale.
scaled_time = torch.arange(length)[:, np.newaxis] * inv_t[np.newaxis, :]
print(scaled_time)

# Sine half followed by cosine half -> (length, channels).
torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).shape

tensor([1.0000e+00, 1.0000e-04])
tensor([[0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-04],
        [2.0000e+00, 2.0000e-04],
        [3.0000e+00, 3.0000e-04],
        [4.0000e+00, 4.0000e-04]])


torch.Size([5, 4])

In [None]:
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding.

    Computes everything from the arguments; it does not read any notebook
    globals. The result is a ``(length, channels)`` float tensor:
    ``sin(pos * inv_timescale)`` in the first ``channels // 2`` columns,
    concatenated along dim=1 with ``cos(pos * inv_timescale)`` in the last
    ``channels // 2`` columns. When there are at least two frequencies, the
    inverse timescales decay geometrically from 1 to ``1 / max_timescale``.
    With a single frequency (``channels == 2``) the only inverse timescale is 1.

    Args:
        length: number of positions (rows).
        channels: encoding width (columns). Must be even.
        max_timescale: largest timescale; sets the slowest frequency.

    Raises:
        AssertionError: if ``channels`` is odd.
    """
    assert channels % 2 == 0, "channels must be even (half sine, half cosine)"
    half = channels // 2
    # Clamp the divisor so channels == 2 does not divide by zero (inf -> NaN);
    # with a single frequency the increment is multiplied by 0 anyway.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    # Local, not the global `scaled_time` from the scratch cell:
    # row i = position i times each inverse timescale.
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

# Example usage: encode 10 positions into a 6-channel positional table.
length, channels, max_timescale = 10, 6, 1000  # positions, encoding width, slowest timescale

pos_encodings = sinusoids(length, channels, max_timescale=max_timescale)
print(pos_encodings.shape)  # expected: torch.Size([10, 6])
print(pos_encodings)  # the full table of positional encodings

In [None]:
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding.

    Builds a ``(length, channels)`` float tensor. The first ``channels // 2``
    columns are ``sin(pos * inv_timescale)`` and the last ``channels // 2``
    columns are ``cos(pos * inv_timescale)``, for ``pos`` in ``0..length-1``.
    When there are at least two frequencies, the inverse timescales decay
    geometrically from 1 down to ``1 / max_timescale``. With a single frequency
    (``channels == 2``) the only inverse timescale is 1.

    Args:
        length: number of positions (rows).
        channels: encoding width (columns). Must be even, because it is split
            between the sine and cosine halves.
        max_timescale: largest timescale; sets the slowest frequency.

    Raises:
        AssertionError: if ``channels`` is odd.
    """
    assert channels % 2 == 0, "channels must be even (half sine, half cosine)"
    half = channels // 2
    # Clamp the divisor: for channels == 2, half - 1 == 0 used to give an inf
    # increment and then NaN (-inf * 0) encodings. With a single frequency the
    # increment is multiplied by 0, so any finite divisor yields inv = [1.0].
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    # Outer product: row i holds position i multiplied by every inverse timescale.
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

