In [1]:
print("Loading imports...\n")
import torch
import torch.nn.functional as F

import tiktoken
import random

from train_gpt2 import GPT, GPTConfig, load_model_from_save

device = "cuda" if torch.cuda.is_available() else "cpu"

model = GPT(GPTConfig(vocab_size=50304)).to(device)

optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device=device)
_, _ = load_model_from_save(model, "log/model_19072.pt", optimizer, device)


Loading imports...

Configured optimizer with weight decay
- num decayed param tensors: 50, totaling 124,354,560 parameters
- num non-decayed parameter tensors: 98, totaling 121,344 parameters
- fused AdamW: False

Loaded model from train step 19073 with val loss 3.09375


In [2]:
def generate_text(
        model,
        prompt,
        device,
        num_return_sequences=1,
        max_new_tokens=50,
        temp=0.8,
        top_k = 50
    ):
        
        model.eval()

        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(random.randint(0, 2**32 - 1))
        # sample_rng.manual_seed(42) # get different results for each process w/o affecting global rng 

        enc = tiktoken.get_encoding("gpt2")
        tokens = enc.encode(prompt)

        tokens = torch.tensor(tokens, dtype=torch.long) # (sequence_length, )
        tokens = tokens.unsqueeze(dim=0).repeat(num_return_sequences, 1) # (num_return_sequences, sequence_length)
        x = tokens.to(device)

        for t in range(max_new_tokens):
            with torch.inference_mode():
                logits, _ = model(x) # (B, T, vocab_size)
                logits = logits[:, -1, :] # (B, vocab_size)
                logits = logits / temp
                probs = F.softmax(logits, dim=-1)


                topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1) # (B, 50)

                # select a token
                ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
                xcol = torch.gather(topk_indices, -1, ix) # get the element at the index for each batch
                # xcol = torch.argmax(probs, dim=-1).unsqueeze(-1)
                x = torch.cat((x, xcol), dim=1)

            if num_return_sequences == 1:
                token_id = xcol[0, 0].item()
                decoded_token = enc.decode([token_id])

                if t == 0:
                    print(f"{prompt}{decoded_token}", end="")
                else:
                    print(decoded_token, end="")

        if num_return_sequences > 1:
            for i in range(num_return_sequences):
                tokens = x[i].tolist()
                decoded = enc.decode(tokens)
                print(f"\nSequence #{i+1}: {decoded}")

In [3]:
prompt = "I think that Artifical Intelligence (AI) is"

print("\nGenerating text...")
# model.generate_text(prompt, device=device, num_return_sequences=3, max_new_tokens=30)
generate_text(model, prompt, device=device, max_new_tokens=100)



Generating text...
I think that Artifical Intelligence (AI) is the future (if I can make that argument) because of its use in a business context and its use in everyday life.
AI is already used for other purposes, such as helping businesses solve complex business problems.
But how can it assist in a business context?
AI can assist in a business context by facilitating:
- Decision making by making decisions by comparing and contrasting information
- Analyzing information with data that is meaningful for you, or better yet, which is still only useful for