## Pretraining on unlabeled data

### 1) Evaluating generative text models

In [2]:
from utils import GPTModel
import torch

GPT_CONFIG_124M = {
    "vocab_size": 200019,
    "context_length": 1024,
    "emb_dim": 768,
    "n_heads": 12,
    "n_layers": 12,
    "drop_rate": 0.1,
    "qkv_bias": False
}

In [3]:
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.eval();

In [4]:
import tiktoken
from utils import generate_text_simple

def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)
    return encoded_tensor

In [5]:
start_context = "Pierwszy dzień wiosny jest"
tokenizer = tiktoken.get_encoding("o200k_base")
input_ids = text_to_token_ids(start_context, tokenizer)
print("Input IDs:", input_ids)

Input IDs: tensor([[152687,   8811,   3705, 155653,    286,   2453,   3008,  12637]])


In [6]:
def token_ids_to_text(token_ids, tokenizer):
    decoded = tokenizer.decode(token_ids.squeeze(0).tolist())
    return decoded

In [7]:
token_ids_to_text(input_ids, tokenizer)
print("Token IDs to text:", token_ids_to_text)

Token IDs to text: <function token_ids_to_text at 0x7e2151c21e40>


In [8]:
token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids(start_context, tokenizer),
    max_new_tokens=10,
    context_size=GPT_CONFIG_124M["context_length"],
)

In [9]:
token_ids.squeeze(0)

tensor([152687,   8811,   3705, 155653,    286,   2453,   3008,  12637,   2944,
         32600,  10819,  29864,  14338, 160118,  91249, 189492, 135305,  65540])

In [10]:
token_ids_to_text(token_ids, tokenizer)

'Pierwszy dzień wiosny jest monthibileosingGP branch Wolfs hướng serien தொகету'

### 2) Calculating the text generation loss: cross-entropy and perplexity

In [30]:
inputs = text_to_token_ids("Wszystkie drogi prowadzą do", tokenizer)
print(inputs)
targets = text_to_token_ids("drogi prowadzą do Rzymu", tokenizer)
print(targets)

tensor([[    54, 148556,  51201,   6517,   6248, 104788,  21589,    621]])
tensor([[100256,   6248, 104788,  21589,    621,    460,  28178,     84]])


In [31]:
with torch.no_grad():
    logits = model(inputs)

In [32]:
logits.shape

torch.Size([1, 8, 200019])

In [33]:
probas = torch.softmax(logits, dim=-1)
print(probas.shape)

torch.Size([1, 8, 200019])


In [34]:
probas

tensor([[[5.4583e-06, 5.1334e-06, 4.4685e-06,  ..., 7.7346e-06,
          4.0741e-06, 2.8078e-06],
         [8.1417e-06, 4.7167e-06, 1.7018e-05,  ..., 1.8737e-05,
          1.5956e-05, 3.2252e-06],
         [1.1007e-05, 5.4918e-06, 4.6825e-06,  ..., 3.5426e-06,
          3.4522e-06, 4.5152e-06],
         ...,
         [6.6834e-06, 1.9917e-06, 5.3639e-06,  ..., 6.2616e-06,
          4.3297e-06, 2.4443e-06],
         [4.5537e-06, 3.7555e-06, 7.1191e-06,  ..., 3.2698e-06,
          5.6970e-06, 4.2523e-06],
         [6.2523e-06, 2.0468e-06, 2.9966e-06,  ..., 2.4268e-06,
          3.7407e-06, 9.4331e-06]]])

In [35]:
token_ids = torch.argmax(probas, dim=-1, keepdim=True)
print("Token IDs:", token_ids)

Token IDs: tensor([[[ 99330],
         [ 26909],
         [  4071],
         [ 21347],
         [183194],
         [191723],
         [113220],
         [122132]]])


In [36]:
print(f"Targets batch: {token_ids_to_text(targets[0], tokenizer)}")
print(f"Outputs batch: {token_ids_to_text(token_ids[0].flatten(), tokenizer)}")

Targets batch: drogi prowadzą do Rzymu
Outputs batch: ZIPৰু dire debut משום goofy vroegerartig


In [37]:
text_idx = 0
target_probas = probas[text_idx, -1, targets[text_idx]]
print("Target probabilities:", target_probas)

Target probabilities: tensor([2.0819e-06, 8.7382e-06, 3.2365e-06, 3.3469e-06, 1.8273e-06, 2.4464e-06,
        5.1523e-06, 4.0019e-06])


In [38]:
log_probs = torch.log(target_probas)
print("Log probabilities:", log_probs)

Log probabilities: tensor([-13.0822, -11.6478, -12.6410, -12.6075, -13.2127, -12.9209, -12.1761,
        -12.4287])


In [39]:
-1 * torch.mean(log_probs)

tensor(12.5896)

In [40]:
torch.nn.functional.cross_entropy(
    logits.flatten(0, 1),
    targets.flatten())

tensor(12.6042)