In [1]:
import torch
import torch.nn as nn
from config import device
import time

In [2]:
class stable_log_softmax(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x: torch.Tensor, dim: int=-1):
        max_values = torch.max(x, dim=dim, keepdim=True).values
        shifted_x = x - max_values
        log_probs = shifted_x - torch.logsumexp(shifted_x, dim=dim, keepdim=True)
        return log_probs

In [20]:
gen = torch.manual_seed(42)
x = torch.randn((1, 1, 50), requires_grad=True, generator=gen)

In [21]:
x = torch.nn.init.xavier_uniform_(x)

In [22]:
x

tensor([[[ 0.2032, -0.1452, -0.1461, -0.1461,  0.2203,  0.0816,  0.2357,
          -0.2022, -0.2430, -0.1916, -0.1648,  0.0992,  0.0877,  0.2035,
          -0.1265, -0.1670,  0.1300, -0.0990,  0.1487, -0.0581,  0.1401,
          -0.1903, -0.1236,  0.0747,  0.0518, -0.0625,  0.1460,  0.1665,
          -0.1776, -0.1308,  0.2243, -0.0827, -0.0868, -0.2370, -0.1403,
           0.0612, -0.0323, -0.1778,  0.0057, -0.1673, -0.2078, -0.1349,
          -0.2144, -0.1560,  0.2449,  0.0463,  0.0755, -0.2285, -0.1609,
          -0.0815]]], requires_grad=True)

In [23]:
m = stable_log_softmax()

In [24]:
seq_len = 10
total_log_prob = 1.0

In [30]:
torch.max(x, dim=-1).values.item()

0.2448531687259674

In [31]:
batch_size = 1
seq_len = 10
vocab_size = 50
gen = torch.manual_seed(42)
logits = torch.randn((batch_size, seq_len, vocab_size), generator=gen)
target_ids = torch.randint(0, vocab_size, (batch_size, seq_len), generator=gen)

In [33]:
logits[0, 0, :], target_ids[0, :]

(tensor([ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431, -1.6047,
         -0.7521,  1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688,  0.7624,
          1.6423, -0.1596, -0.4974,  0.4396, -0.7581,  1.0783,  0.8008,  1.6806,
          1.2791,  1.2964,  0.6105,  1.3347, -0.2316,  0.0418, -0.2516,  0.8599,
         -1.3847, -0.8712, -0.2234,  1.7174,  0.3189, -0.4245,  0.3057, -0.7746,
         -1.5576,  0.9956, -0.8798, -0.6011, -1.2742,  2.1228, -1.2347, -0.4879,
         -0.9138, -0.6581]),
 tensor([37, 39, 29,  7, 35, 38, 17, 14,  7, 24]))

In [34]:
out = torch.gather(logits, dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)

In [35]:
out.shape

torch.Size([1, 10])

In [36]:
out

tensor([[-0.4245, -0.4816,  1.0119, -0.4175, -1.0811,  0.9733, -0.0553, -1.7735,
         -0.0388,  0.4151]])

In [38]:
logits[0, 1, target_ids[0, 1]]

tensor(-0.4816)

In [39]:
torch.sum(out, dim=-1)

tensor([-1.8721])

In [13]:
from maths.sequence_score import sequence_logprob, compute_nll
from maths.softmax import stable_log_softmax

In [4]:
m_sftmax = stable_log_softmax()

In [9]:
gen1 = torch.manual_seed(42)
gen2 = torch.manual_seed(43)
logits_a = torch.randn((10, 2, 50), generator=gen1)
target_ids_a = torch.randint(0, 50, (10, 2), generator=gen1)
logits_b = torch.randn((10, 10, 50), generator=gen2)
target_ids_b = torch.randint(0, 50, (10, 10), generator=gen2)

In [10]:
seq_probs_a = sequence_logprob(logits_a, target_ids_a)
seq_probs_b = sequence_logprob(logits_b, target_ids_b)

In [11]:
seq_probs_a

tensor([ -9.5571,  -6.9069, -12.4230,  -8.7082, -11.6055, -10.6919,  -9.3203,
        -10.5201,  -8.8899,  -6.3690])

In [12]:
seq_probs_b

tensor([-46.3163, -44.3665, -44.4756, -40.5072, -49.0028, -42.1155, -37.8311,
        -42.5875, -46.1512, -46.3185])

In [14]:
seq_nll_a = compute_nll(logits_a, target_ids_a)
seq_nll_b = compute_nll(logits_b, target_ids_b)

In [15]:
seq_nll_a, seq_nll_b

(tensor(4.7496), tensor(4.3967))