In [1]:
import torch
import gpt1

torch.manual_seed(10)

<torch._C.Generator at 0x11ed91810>

# GPT



In [2]:
device = 'cpu'
batch_size = 2
seq_len = 40

max_seq_len = 300
vocab_size = 100
emb_size = 128

head_size = 4
num_heads = 3
num_layers = 5
dropout = 0.3

In [3]:
# на входе тензор токенов (batch_size,seq_len)
x = torch.randint(low=0,high=vocab_size-1,size=(batch_size,seq_len), dtype=torch.long, device=device)
x.shape, x.dtype

(torch.Size([2, 40]), torch.int64)

In [4]:
gpt = gpt1.GPT(
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    emb_size=emb_size,
    num_heads=num_heads,
    head_size=head_size,
    num_layers=num_layers,
    dropout=dropout,
    device=device)

In [5]:
#  на выходе loggits batch_size x seq_len x vocab_size
y = gpt.forward(x)
y.shape, y.dtype, torch.isnan(y).any().item()

(torch.Size([2, 40, 100]), torch.float32, False)

# GPT.generate

In [6]:
max_new_tokens = 20
#  на выходе []int batch_size x (seq_len + max_new_tokens)
new_seq = gpt.generate(x=x,max_new_tokens=max_new_tokens)
new_seq.shape, new_seq.dtype, torch.isnan(new_seq).any().item(), torch.unique(new_seq), len(torch.unique(new_seq))

(torch.Size([2, 60]),
 torch.int64,
 False,
 tensor([ 1,  4,  6,  7,  8,  9, 10, 11, 12, 13, 15, 17, 18, 20, 21, 22, 23, 24,
         25, 26, 27, 31, 32, 33, 34, 35, 40, 41, 42, 43, 45, 47, 48, 49, 50, 52,
         53, 54, 55, 57, 58, 60, 62, 63, 65, 67, 68, 69, 70, 72, 74, 75, 76, 79,
         80, 81, 82, 83, 84, 85, 89, 91, 94, 95, 96, 98]),
 66)

In [7]:
# заготовка, чтоб посмотреть как работает softmax
# batch_size x seq_len x vocab_size
logits = torch.rand([2,4,3])
logits.shape,logits.dtype

(torch.Size([2, 4, 3]), torch.float32)

In [8]:
# берём последний токен
last_log = logits[:,-1,:]
last_log.shape, last_log

(torch.Size([2, 3]),
 tensor([[0.7192, 0.8887, 0.5660],
         [0.4660, 0.9185, 0.3428]]))

In [9]:
# сумма по каждой строке 1
# для каждого batch один следующий токен
prob = torch.softmax(last_log,dim=-1)
prob

tensor([[0.3286, 0.3894, 0.2820],
        [0.2893, 0.4549, 0.2558]])

In [10]:
0.5210+ 0.2427+ 0.2363

1.0

In [11]:
# keep_dim чтоб потом можно было сделать cat c batch_size x seq_len -> batch_size x (seq_len+1)
next_token = torch.argmax(prob,dim=-1,keepdim=True)
next_token, next_token.shape

(tensor([[1],
         [1]]),
 torch.Size([2, 1]))

# GPT.generate (multinominal sample)

In [12]:
# включили сэмплирование - набор токенов в ответе стал разнообразней (не только самые вероятные)
new_seq = gpt.generate(x=x,max_new_tokens=max_new_tokens,do_sample=True)
new_seq.shape, new_seq.dtype, torch.isnan(new_seq).any().item(), torch.unique(new_seq),len(torch.unique(new_seq))

(torch.Size([2, 60]),
 torch.int64,
 False,
 tensor([ 1,  3,  4,  5,  7,  8,  9, 10, 11, 12, 13, 15, 18, 19, 20, 21, 22, 23,
         24, 25, 26, 27, 29, 32, 33, 35, 36, 39, 40, 41, 42, 44, 45, 46, 47, 48,
         49, 50, 52, 53, 54, 55, 57, 58, 60, 62, 64, 65, 67, 68, 69, 74, 76, 79,
         80, 81, 82, 83, 84, 89, 94, 95]),
 62)

In [13]:
# заготовка для Multinomial sampling (вероятностного сэмплирования)
prob

tensor([[0.3286, 0.3894, 0.2820],
        [0.2893, 0.4549, 0.2558]])

In [14]:
# здесь с размерностью всё хорошо batch_size x 1
next_token=torch.multinomial(prob,num_samples=1)
next_token,next_token.shape

(tensor([[2],
         [1]]),
 torch.Size([2, 1]))

# GPT.generate (temperature)

In [15]:
# с температурой можем получить промежуточный результат
new_seq = gpt.generate(x=x,max_new_tokens=max_new_tokens,do_sample=True, temperature=0.3)
new_seq.shape, new_seq.dtype, torch.isnan(new_seq).any().item(), torch.unique(new_seq), len(torch.unique(new_seq))

(torch.Size([2, 60]),
 torch.int64,
 False,
 tensor([ 1,  4,  6,  7,  9, 10, 11, 12, 13, 15, 17, 18, 20, 21, 22, 23, 24, 25,
         26, 27, 29, 32, 33, 34, 35, 37, 40, 41, 42, 47, 48, 49, 50, 52, 53, 54,
         55, 57, 58, 60, 62, 63, 65, 66, 67, 68, 69, 70, 74, 79, 80, 81, 82, 83,
         84, 85, 87, 89, 90, 93, 94, 95, 96, 98]),
 64)

# GPT.generate (top_k)

In [16]:
# не всё так однозначно - кол-во различных токенов может и нерасти растёт с ростом top_k
new_seq = gpt.generate(x=x,max_new_tokens=max_new_tokens,do_sample=True, top_k=2)
new_seq.shape, new_seq.dtype, torch.isnan(new_seq).any().item(), torch.unique(new_seq), len(torch.unique(new_seq))

(torch.Size([2, 60]),
 torch.int64,
 False,
 tensor([ 1,  2,  4,  6,  7,  8,  9, 10, 11, 12, 13, 15, 17, 18, 20, 21, 22, 23,
         24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 35, 40, 41, 42, 43, 45, 46, 47,
         48, 49, 50, 52, 53, 54, 55, 57, 58, 60, 62, 65, 67, 68, 69, 70, 71, 74,
         79, 80, 81, 82, 83, 84, 85, 89, 91, 93, 94, 95, 96]),
 67)

In [17]:
new_seq = gpt.generate(x=x,max_new_tokens=max_new_tokens,do_sample=True, top_k=25)
new_seq.shape, new_seq.dtype, torch.isnan(new_seq).any().item(), torch.unique(new_seq), len(torch.unique(new_seq))

(torch.Size([2, 60]),
 torch.int64,
 False,
 tensor([ 0,  1,  2,  3,  4,  5,  7,  8,  9, 10, 11, 12, 13, 15, 18, 20, 21, 22,
         23, 24, 25, 26, 27, 29, 32, 33, 35, 40, 41, 42, 46, 47, 48, 49, 50, 51,
         52, 53, 54, 55, 57, 58, 59, 60, 62, 65, 66, 67, 68, 69, 70, 71, 72, 74,
         75, 79, 80, 81, 82, 83, 84, 86, 89, 90, 94, 95, 98]),
 67)

In [18]:
# заготовка
# batch_size x seq_len x vocab_size
logits = torch.rand([2,4,5])
# batch_size x vocab_size
logits = logits[:,-1,:]
logits.shape,logits.dtype,logits

(torch.Size([2, 5]),
 torch.float32,
 tensor([[0.5369, 0.8912, 0.2275, 0.7969, 0.5160],
         [0.5575, 0.4103, 0.9600, 0.8359, 0.5377]]))

In [19]:
# индексы трёх максимальных токенов
_, top_k_ind = torch.topk(logits,k=3, dim=-1)
top_k_ind, top_k_ind.shape

(tensor([[1, 3, 0],
         [2, 3, 0]]),
 torch.Size([2, 3]))

In [20]:
filtered = torch.full_like(logits,float('-inf'))
filtered.scatter_(dim=-1,index=top_k_ind, src=logits)
filtered

tensor([[0.2275, 0.5369,   -inf, 0.8912,   -inf],
        [0.9600,   -inf, 0.5575, 0.4103,   -inf]])

# GPT.generate (top_p)

In [30]:
new_seq = gpt.generate(x=x,max_new_tokens=max_new_tokens,do_sample=True, top_p=0.8)
new_seq.shape, new_seq.dtype, torch.isnan(new_seq).any().item(), torch.unique(new_seq), len(torch.unique(new_seq))

(torch.Size([2, 60]),
 torch.int64,
 False,
 tensor([ 1,  2,  3,  4,  7,  8,  9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22,
         23, 24, 25, 26, 27, 30, 32, 33, 34, 35, 40, 41, 42, 47, 48, 49, 50, 52,
         53, 54, 55, 56, 57, 58, 60, 62, 65, 67, 68, 69, 71, 73, 74, 77, 79, 80,
         81, 82, 83, 84, 85, 88, 89, 90, 91, 94, 95, 98]),
 66)

In [21]:
# заготовка
# batch_size x seq_len x vocab_size
logits = torch.rand([2,4,5])
# batch_size x vocab_size
logits = logits[:,-1,:]
logits.shape,logits.dtype,logits

(torch.Size([2, 5]),
 torch.float32,
 tensor([[0.9162, 0.5494, 0.5673, 0.0287, 0.5526],
         [0.6475, 0.9534, 0.4948, 0.2036, 0.8602]]))

In [22]:
prob = torch.softmax(logits, dim=-1)
sorted_prob, sorted_prob_ind = torch.sort(prob, descending=True, dim=-1)
prob,sorted_prob_ind

(tensor([[0.2853, 0.1977, 0.2013, 0.1174, 0.1983],
         [0.1963, 0.2665, 0.1685, 0.1259, 0.2428]]),
 tensor([[0, 2, 4, 1, 3],
         [1, 4, 0, 2, 3]]))

In [23]:
cumulative_prob = torch.cumsum(sorted_prob, dim=-1)
sorted_prob,cumulative_prob

(tensor([[0.2853, 0.2013, 0.1983, 0.1977, 0.1174],
         [0.2665, 0.2428, 0.1963, 0.1685, 0.1259]]),
 tensor([[0.2853, 0.4865, 0.6849, 0.8826, 1.0000],
         [0.2665, 0.5093, 0.7056, 0.8741, 1.0000]]))

In [24]:
# если не осталось ни одного токена, то хотя бы самый вероятный оставим
top_p = 0.2
to_remove = cumulative_prob > top_p
to_remove

tensor([[True, True, True, True, True],
        [True, True, True, True, True]])

In [25]:
to_remove[:,0]=False
to_remove

tensor([[False,  True,  True,  True,  True],
        [False,  True,  True,  True,  True]])

In [26]:
top_p = 0.5
to_remove = cumulative_prob > top_p
to_remove

tensor([[False, False,  True,  True,  True],
        [False,  True,  True,  True,  True]])

In [27]:
# to_remove и sorted_prob, sorted_prob_ind индексированы одинаково
# sorted_prob_ind указывает на элемент в исходном prob
# построим обраный индекс, чтоб бежать по исходному prob и получать индекс в sorted_prob
inverse_ind = torch.argsort(sorted_prob_ind, dim=-1)
sorted_prob_ind, inverse_ind

(tensor([[0, 2, 4, 1, 3],
         [1, 4, 0, 2, 3]]),
 tensor([[0, 3, 1, 4, 2],
         [2, 0, 3, 4, 1]]))

In [28]:
# Преобразуем to_remove из отсортированного пространства в исходное
mask = to_remove.gather(dim=-1, index=inverse_ind)
# если теперь взять prob проб по маске, то в сумме по строке меньше top_p=0.5
mask, prob, logits

(tensor([[False,  True, False,  True,  True],
         [ True, False,  True,  True,  True]]),
 tensor([[0.2853, 0.1977, 0.2013, 0.1174, 0.1983],
         [0.1963, 0.2665, 0.1685, 0.1259, 0.2428]]),
 tensor([[0.9162, 0.5494, 0.5673, 0.0287, 0.5526],
         [0.6475, 0.9534, 0.4948, 0.2036, 0.8602]]))

In [29]:
prob[mask]=0
logits[mask]=float('-inf')
prob,logits

(tensor([[0.2853, 0.0000, 0.2013, 0.0000, 0.0000],
         [0.0000, 0.2665, 0.0000, 0.0000, 0.0000]]),
 tensor([[0.9162,   -inf, 0.5673,   -inf,   -inf],
         [  -inf, 0.9534,   -inf,   -inf,   -inf]]))