In [1]:
import torch
import torch.nn as nn
from config import device
import time

In [2]:
class stable_log_softmax(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x: torch.Tensor, dim: int=-1):
        max_values = torch.max(x, dim=dim, keepdim=True).values
        shifted_x = x - max_values
        log_probs = shifted_x - torch.logsumexp(shifted_x, dim=dim, keepdim=True)
        return log_probs

In [20]:
gen = torch.manual_seed(42)
x = torch.randn((1, 1, 50), requires_grad=True, generator=gen)

In [21]:
x = torch.nn.init.xavier_uniform_(x)

In [22]:
x

tensor([[[ 0.2032, -0.1452, -0.1461, -0.1461,  0.2203,  0.0816,  0.2357,
          -0.2022, -0.2430, -0.1916, -0.1648,  0.0992,  0.0877,  0.2035,
          -0.1265, -0.1670,  0.1300, -0.0990,  0.1487, -0.0581,  0.1401,
          -0.1903, -0.1236,  0.0747,  0.0518, -0.0625,  0.1460,  0.1665,
          -0.1776, -0.1308,  0.2243, -0.0827, -0.0868, -0.2370, -0.1403,
           0.0612, -0.0323, -0.1778,  0.0057, -0.1673, -0.2078, -0.1349,
          -0.2144, -0.1560,  0.2449,  0.0463,  0.0755, -0.2285, -0.1609,
          -0.0815]]], requires_grad=True)

In [23]:
m = stable_log_softmax()

In [24]:
seq_len = 10
total_log_prob = 1.0

In [30]:
torch.max(x, dim=-1).values.item()

0.2448531687259674

In [31]:
batch_size = 1
seq_len = 10
vocab_size = 50
gen = torch.manual_seed(42)
logits = torch.randn((batch_size, seq_len, vocab_size), generator=gen)
target_ids = torch.randint(0, vocab_size, (batch_size, seq_len), generator=gen)

In [33]:
logits[0, 0, :], target_ids[0, :]

(tensor([ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431, -1.6047,
         -0.7521,  1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688,  0.7624,
          1.6423, -0.1596, -0.4974,  0.4396, -0.7581,  1.0783,  0.8008,  1.6806,
          1.2791,  1.2964,  0.6105,  1.3347, -0.2316,  0.0418, -0.2516,  0.8599,
         -1.3847, -0.8712, -0.2234,  1.7174,  0.3189, -0.4245,  0.3057, -0.7746,
         -1.5576,  0.9956, -0.8798, -0.6011, -1.2742,  2.1228, -1.2347, -0.4879,
         -0.9138, -0.6581]),
 tensor([37, 39, 29,  7, 35, 38, 17, 14,  7, 24]))

In [34]:
out = torch.gather(logits, dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)

In [35]:
out.shape

torch.Size([1, 10])

In [36]:
out

tensor([[-0.4245, -0.4816,  1.0119, -0.4175, -1.0811,  0.9733, -0.0553, -1.7735,
         -0.0388,  0.4151]])

In [38]:
logits[0, 1, target_ids[0, 1]]

tensor(-0.4816)

In [39]:
torch.sum(out, dim=-1)

tensor([-1.8721])

In [13]:
from maths.sequence_score import sequence_logprob, compute_nll
from maths.softmax import stable_log_softmax

In [4]:
m_sftmax = stable_log_softmax()

In [9]:
gen1 = torch.manual_seed(42)
gen2 = torch.manual_seed(43)
logits_a = torch.randn((10, 2, 50), generator=gen1)
target_ids_a = torch.randint(0, 50, (10, 2), generator=gen1)
logits_b = torch.randn((10, 10, 50), generator=gen2)
target_ids_b = torch.randint(0, 50, (10, 10), generator=gen2)

In [10]:
seq_probs_a = sequence_logprob(logits_a, target_ids_a)
seq_probs_b = sequence_logprob(logits_b, target_ids_b)

In [11]:
seq_probs_a

tensor([ -9.5571,  -6.9069, -12.4230,  -8.7082, -11.6055, -10.6919,  -9.3203,
        -10.5201,  -8.8899,  -6.3690])

In [12]:
seq_probs_b

tensor([-46.3163, -44.3665, -44.4756, -40.5072, -49.0028, -42.1155, -37.8311,
        -42.5875, -46.1512, -46.3185])

In [14]:
seq_nll_a = compute_nll(logits_a, target_ids_a)
seq_nll_b = compute_nll(logits_b, target_ids_b)

In [15]:
seq_nll_a, seq_nll_b

(tensor(4.7496), tensor(4.3967))

In [2]:
from maths.softmax import stable_log_softmax, stable_softmax

In [3]:
stb_sft = stable_softmax()
stb_log_sft = stable_log_softmax()

In [18]:
def compute_entropy(logits: torch.Tensor):
    probs = stb_sft(logits, dim=-1)
    log_probs = stb_log_sft(logits, dim=-1)
    out = - torch.sum(probs * log_probs, dim=-1)
    return out

In [19]:
gen = torch.manual_seed(42)
logits_a = torch.ones((1, 50))
logits_b = torch.ones((1, 50))
logits_b[0, 4], logits_b[0, 23] = 100, 100

In [20]:
logits_a, logits_b

(tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]),
 tensor([[  1.,   1.,   1.,   1., 100.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
            1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1., 100.,
            1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
            1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
            1.,   1.]]))

In [21]:
out_a = compute_entropy(logits_a)
out_b = compute_entropy(logits_b)

In [22]:
out_a, out_b

(tensor([3.9120]), tensor([0.6931]))

In [23]:
torch.exp(out_a), torch.exp(out_b)

(tensor([50.0000]), tensor([2.]))

In [27]:
temp_values = [0.01, 0.1, 0.2, 0.5, 0.6, 0.8, 1.0, 5.0, 10.0, 25.0, 60.0, 100.0]

In [28]:
for t in temp_values:
    print(compute_entropy(logits_b/t))

tensor([0.6931])
tensor([0.6931])
tensor([0.6931])
tensor([0.6931])
tensor([0.6931])
tensor([0.6931])
tensor([0.6931])
tensor([0.6931])
tensor([0.7063])
tensor([2.3129])
tensor([3.7734])
tensor([3.8777])


In [29]:
for t in temp_values:
    print(compute_entropy(logits_a/t))

tensor([3.9120])
tensor([3.9120])
tensor([3.9120])
tensor([3.9120])
tensor([3.9120])
tensor([3.9120])
tensor([3.9120])
tensor([3.9120])
tensor([3.9120])
tensor([3.9120])
tensor([3.9120])
tensor([3.9120])


In [30]:
import math
print(math.log(2))

0.6931471805599453


In [95]:
def sample(logits: torch.Tensor, temperature: float=1.0, top_k: int=5, top_p: float= 0.9):
    B, C = logits.shape
    probs = stb_sft(logits/temperature, dim=-1)
    sort_probs, sort_indices = torch.sort(probs, dim=-1, descending=True)
    sort_probs = sort_probs[:, :top_k]
    sort_indices = sort_indices[:, :top_k]
    valid_indices = torch.where(~(torch.cumsum(sort_probs, dim=-1) > top_p), sort_indices, -1)
    return valid_indices

In [96]:
gen = torch.manual_seed(42) 
logits = torch.randn((2, 50), generator=gen)

In [97]:
logits

tensor([[ 1.9269e+00,  1.4873e+00,  9.0072e-01, -2.1055e+00,  6.7842e-01,
         -1.2345e+00, -4.3067e-02, -1.6047e+00, -7.5214e-01,  1.6487e+00,
         -3.9248e-01, -1.4036e+00, -7.2788e-01, -5.5943e-01, -7.6884e-01,
          7.6245e-01,  1.6423e+00, -1.5960e-01, -4.9740e-01,  4.3959e-01,
         -7.5813e-01,  1.0783e+00,  8.0080e-01,  1.6806e+00,  1.2791e+00,
          1.2964e+00,  6.1047e-01,  1.3347e+00, -2.3162e-01,  4.1759e-02,
         -2.5158e-01,  8.5986e-01, -1.3847e+00, -8.7124e-01, -2.2337e-01,
          1.7174e+00,  3.1888e-01, -4.2452e-01,  3.0572e-01, -7.7459e-01,
         -1.5576e+00,  9.9564e-01, -8.7979e-01, -6.0114e-01, -1.2742e+00,
          2.1228e+00, -1.2347e+00, -4.8791e-01, -9.1382e-01, -6.5814e-01],
        [ 7.8024e-02,  5.2581e-01, -4.8799e-01,  1.1914e+00, -8.1401e-01,
         -7.3599e-01, -1.4032e+00,  3.6004e-02, -6.3477e-02,  6.7561e-01,
         -9.7807e-02,  1.8446e+00, -1.1845e+00,  1.3835e+00,  1.4451e+00,
          8.5641e-01,  2.2181e+00,  5

In [98]:
out = sample(logits, top_p=0.2)

In [99]:
out

tensor([[45,  0, -1, -1, -1],
        [16, 11, -1, -1, -1]])

In [6]:
from maths.softmax import stable_log_softmax, stable_softmax
import math

In [7]:
stb_sft = stable_softmax()
stb_log_sft = stable_log_softmax()

In [8]:
def compute_entropy(logits: torch.Tensor):
    probs = stb_sft(logits, dim=-1)
    log_probs = stb_log_sft(logits, dim=-1)
    out = -torch.sum(probs * log_probs, dim=-1)
    return out
    

In [9]:
gen = torch.manual_seed(42)
logits_a = torch.ones((1, 50))
logits_b = torch.ones((1, 50))
logits_b[0, 4] = 100
logits_b[0, 25] = 100
logits_b[0, 43] = 89
logits_c = torch.randn((1, 50), generator=gen)

In [10]:
logits_b

tensor([[  1.,   1.,   1.,   1., 100.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1., 100.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1.,   1.,  89.,   1.,   1.,   1.,   1.,
           1.,   1.]])

In [11]:
out_a = compute_entropy(logits_a)
out_b = compute_entropy(logits_b)
out_c = compute_entropy(logits_c)

In [12]:
out_a, out_b, out_c

(tensor([3.9120]), tensor([0.6932]), tensor([3.4912]))

In [13]:
torch.exp(out_a), torch.exp(out_b), torch.exp(out_c)

(tensor([50.0000]), tensor([2.0002]), tensor([32.8249]))

In [14]:
temp = [0.01, 0.1, 0.2, 0.3, 0.6, 0.9, 1.0, 5.0, 10.0, 15.0, 20.0, 50.0]

In [15]:
for t in temp:
    ent = compute_entropy(logits_b/t)
    print(t, ent, torch.exp(ent))

0.01 tensor([0.6931]) tensor([2.])
0.1 tensor([0.6931]) tensor([2.])
0.2 tensor([0.6931]) tensor([2.])
0.3 tensor([0.6931]) tensor([2.])
0.6 tensor([0.6931]) tensor([2.0000])
0.9 tensor([0.6932]) tensor([2.0001])
1.0 tensor([0.6932]) tensor([2.0002])
5.0 tensor([0.8626]) tensor([2.3692])
10.0 tensor([1.0149]) tensor([2.7591])
15.0 tensor([1.2381]) tensor([3.4492])
20.0 tensor([1.7435]) tensor([5.7173])
50.0 tensor([3.6309]) tensor([37.7482])


In [16]:
probs = stb_sft(logits_c, dim=-1)
torch.topk(probs, dim=-1, k = 5)

torch.return_types.topk(
values=tensor([[0.0759, 0.0699, 0.0593, 0.0575, 0.0571]]),
indices=tensor([[ 0, 45, 23,  9, 16]]))

In [54]:
def sample(logits: torch.Tensor, temperature: float=1.0, top_k: int=10, top_p: float=0.9):
    # top-k
    logits, logits_indices = torch.topk(logits, k=top_k, dim=-1)
    
    probs = stb_sft(logits/temperature, dim=-1)
    
    # top-p
    cumsum = torch.cumsum(probs, dim=-1)
    mask = cumsum <= top_p
    mask[:, 0] = True
    mask[:, 1:] |= (cumsum[:, :-1] <= top_p)
    probs = torch.where(mask, probs, torch.zeros_like(probs))
    probs = probs / torch.sum(probs, dim=-1, keepdim=True)
    indices = torch.multinomial(probs, num_samples=1)
    return torch.gather(logits_indices, dim=-1, index=indices)

In [57]:
def sample(logits: torch.Tensor, temperature: float=1.0, top_k: int=10, top_p: float=0.9):
    probs = stb_sft(logits/temperature, dim=-1)
    # top-k
    probs, prob_indices = torch.topk(probs, k=top_k, dim=-1)
    
    # top-p
    cumsum = torch.cumsum(probs, dim=-1)
    mask = cumsum <= top_p
    mask[:, 0] = True
    mask[:, 1:] |= (cumsum[:, :-1] <= top_p)
    probs = torch.where(mask, probs, torch.zeros_like(probs))
    probs = probs / torch.sum(probs, dim=-1, keepdim=True)
    indices = torch.multinomial(probs, num_samples=1)
    print(prob_indices.shape, indices.shape)
    return torch.gather(prob_indices, dim=-1, index=indices)
     

In [55]:
gen = torch.manual_seed(42)
logits = torch.randn((2, 50), generator=gen)

In [58]:
sample(logits, top_p=0.2, top_k=5)

torch.Size([2, 5]) torch.Size([2, 1])


tensor([[ 0],
        [14]])

In [61]:
def compute_nll(logits: torch.Tensor, target_ids: torch.Tensor):
    B, T, C = logits.shape
    B_t, T_t = target_ids.shape
    assert B==B_t and T == T_t
    log_probs = stb_log_sft(logits, dim=-1)
    log_probs = torch.gather(log_probs, dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)
    out = - torch.sum(log_probs)/(T*B)
    return out

In [63]:
gen = torch.manual_seed(42)
logits = torch.randn((2, 10, 50), generator=gen)
target_ids = torch.randint(0, 50, (2, 10), generator=gen)

In [70]:
out = compute_nll(logits, target_ids)
out_e = compute_entropy(logits)
out_e = torch.sum(out_e)/out_e.numel()
out, out_e

(tensor(4.4533), tensor(3.4519))

In [71]:
torch.exp(out), torch.exp(out_e)

(tensor(85.9116), tensor(31.5609))

In [72]:
gen = torch.manual_seed(42)
cheating = torch.randn((2, 1, 50), generator=gen)
cheating[0, 0, 10] = 25
cheating[1, 0, 21] = 100
target_ids = torch.tensor([[10], [21]])
clueless = torch.ones((2, 1, 50))
print(target_ids.shape, cheating.shape, clueless.shape)

torch.Size([2, 1]) torch.Size([2, 1, 50]) torch.Size([2, 1, 50])


In [73]:
out_cheating = compute_nll(cheating, target_ids)
out_clueless = compute_nll(clueless, target_ids)
out_cheating, out_clueless

(tensor(-0.), tensor(3.9120))

In [74]:
torch.exp(out_cheating), torch.exp(out_clueless)

(tensor(1.), tensor(50.0000))

In [None]:
from maths.softmax import stable_softmax
from config import device
import torch.nn as nn

In [33]:
class SelfAttention(nn.Module):
    def __init__(self, dim: int, seq_len: int, num_head:int, head_dim: int, is_causal: bool=True):
        super().__init__()
        self.dim = dim
        self.seq_len = seq_len
        self.num_head = num_head
        self.head_dim = head_dim
        self.is_causal = is_causal
        self.q_w = nn.Linear(self.dim, self.num_head * self.head_dim)
        self.k_w = nn.Linear(self.dim, self.num_head * self.head_dim)
        self.v_w = nn.Linear(self.dim, self.num_head * self.head_dim)

        self.w_o = nn.Linear(self.num_head * self.head_dim, self.dim)
        self.stb_sft = stable_softmax()
        self.register_buffer('tril', torch.tril(torch.ones((seq_len, seq_len)).bool()).unsqueeze(0).unsqueeze(1))
    
    def forward(self, x, mask = None):
        B, T, C = x.shape
        q = self.q_w(x)
        k = self.k_w(x)
        v = self.v_w(x)

        q = q.contiguous().view(B, T, self.num_head, self.head_dim).transpose(1, 2)
        k = k.contiguous().view(B, T, self.num_head, self.head_dim).transpose(1, 2)
        v = v.contiguous().view(B, T, self.num_head, self.head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        if self.is_causal:
            attn = attn.masked_fill(~self.tril[:, :, :T, :T], value=float('-inf'))
        
        if mask is not None:
            if mask.dim()==2:
                mask = mask.unsqueeze(1).unsqueeze(1)
            attn = attn.masked_fill(~mask, value=float('-inf'))
        
        attn_scores = self.stb_sft(attn, dim=-1)

        out = attn_scores @ v
        out = out.transpose(1, 2).reshape(B, T, self.num_head*self.head_dim)
        out = self.w_o(out)

        return out





In [34]:
B = 10
T = 256
C = 512
num_head = 4

In [35]:
attns_module = SelfAttention(dim=C, seq_len=T, num_head=num_head, head_dim=C//num_head).to(device)

In [36]:
gen = torch.Generator(device=device)
gen.manual_seed(42)

x = torch.randn((B, T, C), generator=gen, device=device)

In [37]:
x.shape

torch.Size([10, 256, 512])

In [38]:
out = attns_module(x, None)

In [39]:
out.shape

torch.Size([10, 256, 512])

In [26]:
out[0, 0, :]

tensor([-4.0777e-01, -2.8475e-01,  7.4862e-01, -1.8293e-01, -7.9914e-02,
        -4.7595e-02, -2.8990e-01,  7.4373e-02, -4.9075e-01,  6.1789e-01,
        -2.2332e-01,  3.9436e-01, -6.9272e-02, -1.1194e-01,  6.1884e-01,
        -4.0441e-01, -2.2336e-01, -2.0348e-01,  4.7485e-01,  1.1459e-01,
        -6.9173e-03, -1.3378e-01,  2.6216e-02, -4.1331e-01, -1.7349e-01,
         1.1603e-01, -4.9984e-01, -5.8381e-01,  7.0129e-02,  1.5205e-01,
        -8.9764e-02, -5.0892e-01,  5.5400e-01,  1.3133e-01, -5.0637e-03,
        -1.3070e-01, -2.9773e-01, -5.7534e-01,  4.4412e-01, -1.9810e-01,
        -3.9218e-02,  3.3281e-01,  2.7967e-02,  2.5573e-03, -1.5771e-01,
        -3.0615e-01,  8.6153e-02,  2.7363e-01,  6.2711e-01,  2.7965e-01,
         7.8249e-01,  2.4900e-01,  1.0584e-01,  4.2118e-01,  2.2307e-01,
         9.3330e-02,  7.1194e-01,  9.8989e-02, -2.0547e-01,  3.9773e-01,
         3.4572e-01,  3.6400e-01, -1.1486e-01,  2.2085e-01, -1.9148e-01,
        -1.1554e-01,  3.0440e-01, -2.8248e-01, -1.3

In [47]:
class ProjLayer(nn.Module):
    def __init__(self, dim: int, vocab_size: int):
        super().__init__()
        self.dim = dim
        self.vocab_size = vocab_size
        self.proj = nn.Linear(self.dim, self.vocab_size)
        self.stb_sft = stable_softmax()
    
    def forward(self, x):
        B, T, C = x.shape
        out = self.proj(x)
        out = self.stb_sft(out, dim=-1)
        return out


In [48]:
out.shape, C

(torch.Size([10, 256, 512]), 512)

In [51]:
proj_layer = ProjLayer(dim=C, vocab_size=1000).to(device)

In [52]:
final_out = proj_layer(out)

In [53]:
final_out.shape

torch.Size([10, 256, 1000])

In [56]:
class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.dim = dim
        self.vocab_size = vocab_size
        self.embddings = nn.Embedding(self.vocab_size, self.dim)
    
    def forward(self, x):
        return self.embddings(x)


In [55]:
B = 10
T = 256
C = 512
VOCAB_SIZE = 1000
num_head = 4
head_dim = C//num_head

In [58]:
embedding_layer = EmbeddingLayer(VOCAB_SIZE, C).to(device)

In [60]:
gen = torch.Generator(device)
gen.manual_seed(42)
x = torch.randint(0, VOCAB_SIZE, (B, T), device=device, generator=gen)

In [61]:
x.shape

torch.Size([10, 256])

In [62]:
x = embedding_layer(x)

In [63]:
x.shape

torch.Size([10, 256, 512])

In [66]:
attns_module = SelfAttention(dim=C, seq_len=T, num_head=num_head, head_dim=head_dim).to(device)

In [67]:
out = attns_module(x)

In [68]:
out.shape

torch.Size([10, 256, 512])

In [69]:
proj_layer = ProjLayer(dim=C, vocab_size=VOCAB_SIZE).to(device)

In [70]:
final_out = proj_layer(out)

In [71]:
final_out.shape

torch.Size([10, 256, 1000])