# Introduction

This is a clean, first-principles implementation of GPT-2 in PyTorch.


## Instructions
* No need to read the Setup Section
* Go to runtime > Change Runtime Type and set it to use a GPU


# Setup

In [51]:
try:
  import google.colab
  IN_COLAB = True
  print("Running as a Colab notebook")
  %pip install git+https://github.com/neelnanda-io/Easy-Transformer.git@clean-transformer-demo
  # Install another version of node that makes PySvelte work way faster
  !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
  %pip install git+https://github.com/neelnanda-io/PySvelte.git
  %pip install fancy_einsum
  %pip install einops
except:
  IN_COLAB = False
  print("Running as a Jupyter notebook - intended for development only!")

Running as a Colab notebook
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/Easy-Transformer.git@clean-transformer-demo
  Cloning https://github.com/neelnanda-io/Easy-Transformer.git (to revision clean-transformer-demo) to /tmp/pip-req-build-04g_e0es
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/Easy-Transformer.git /tmp/pip-req-build-04g_e0es
  Running command git checkout -b clean-transformer-demo --track origin/clean-transformer-demo
  Switched to a new branch 'clean-transformer-demo'
  Branch 'clean-transformer-demo' set up to track remote branch 'clean-transformer-demo' from 'origin'.
  Resolved https://github.com/neelnanda-io/Easy-Transformer.git to commit 1f25219e631aeb478d17075d47274db32c874e88
  Preparing metadata (setup.py) ... done

## Installing the NodeSource Node.js 16.x repo...


## Populating apt-get cache...

+ apt-get u

In [52]:
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from easy_transformer import EasyTransformer
import torch
import torch.nn as nn
import numpy as np
import math
from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate
import tqdm.auto as tqdm

In [53]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)

Moving model to device:  cuda
Finished loading pretrained model gpt2-small into EasyTransformer!


# Understanding Inputs & Outputs of a Transformer

## What is the point of a transformer?

**Transformers exist to model text!**

We're going to focus on GPT-2 style transformers. Key feature: They generate text! You feed in language, and the model generates a probability distribution over tokens. And you can repeatedly sample from this to generate text!

### How is the model trained?

You give it a bunch of text, and train it to predict the next token.

Importantly, if you give a model 100 tokens in a sequence, it predicts the next token for *each* prefix, ie it produces 100 predictions. This is kinda weird but it's much easier to make one that does this. And it also makes training more efficient, because you get 100 bits of feedback rather than just one.
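
A minimal sketch of this training objective, assuming random logits stand in for a real model's output (the tensor names and toy sizes are illustrative, not taken from the library). The prediction at position i is scored against the token at position i+1, so the final position's prediction has no target inside the sequence:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d_vocab = 2, 5, 11  # toy sizes, chosen for illustration

tokens = torch.randint(0, d_vocab, (batch, seq_len))
# Stand-in for a model's output: one logit vector per input position
logits = torch.randn(batch, seq_len, d_vocab)

# Position i predicts token i+1, so drop the last prediction and the first token
pred_logits = logits[:, :-1]   # [batch, seq_len - 1, d_vocab]
targets = tokens[:, 1:]        # [batch, seq_len - 1]

log_probs = pred_logits.log_softmax(dim=-1)
# Pick out the log prob assigned to each correct next token
correct_log_probs = log_probs.gather(dim=-1, index=targets[..., None])[..., 0]
loss = -correct_log_probs.mean()

# Same quantity via the built-in cross entropy
loss_builtin = F.cross_entropy(pred_logits.reshape(-1, d_vocab), targets.reshape(-1))
print(loss.item(), loss_builtin.item())
```

Every position contributes its own term to the loss, which is where the extra feedback per sequence comes from.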

#### Objection: Isn't this trivial for the first 99?

No! We make the transformer have *causal attention*. The core thing is that it can only move information forwards in the sequence. The prediction of what comes after token 50 is only a function of the first 50 tokens, *not* of token 51. (Jargon: *autoregressive*)

### Key takeaway:

Transformers are *sequence modelling engines*. They do the same processing in parallel at each sequence position, can move information between positions with attention, and conceptually can take a sequence of arbitrary length (not actually true, see later).

## Tokens - Transformer Inputs

Core point: Input is language (ie a sequence of characters, strings, etc)

### How do we convert language to vectors?

ML models take in vectors, not weird shit like language - how do we convert?

#### Idea: integers to vectors

We basically make a lookup table. Called an embedding.

Jargon: **One-hot encoding** We map eg each number k from 1 to 100 to a 100-dim vector, with a 1 in the kth position and 0 everywhere else. Key intuition is that one-hot encodings let you think about each integer independently - useful when integers = labels.

Dimensions = things that vary independently. Each input has its own dimension, so each input can be thought of independently, we don't bake in any relation.

Lookup tables <=> Multiply a fixed matrix by the one-hot encoded vector.
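
A quick check of that equivalence, as a sketch with a small random table (`W_E` and the toy sizes are made up for illustration): indexing into the table gives the same vectors as multiplying the one-hot encoding by the table.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_vocab, d_model = 10, 4             # toy sizes for illustration
W_E = torch.randn(d_vocab, d_model)  # the lookup table: one row per integer
tokens = torch.tensor([3, 7, 3])

# Option 1: index straight into the table
looked_up = W_E[tokens]                                   # [3, d_model]

# Option 2: one-hot encode, then multiply by the same matrix
one_hot = F.one_hot(tokens, num_classes=d_vocab).float()  # [3, d_vocab]
multiplied = one_hot @ W_E                                # [3, d_model]

print(torch.allclose(looked_up, multiplied))
```

In practice you index rather than multiply, since the one-hot matrix is almost entirely zeros.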


### Tokens: Language to sequence of integers

Core idea: We need a model that can deal with arbitrary text. We want to convert this into integers, *and* we want these integers to be in a bounded range.

**Idea:** Form a vocabulary!

**Idea 1:** Get a dictionary!

**Problem:** It can't cope with arbitrary text (eg URLs, punctuation, etc). Can't cope with misspellings.

**Idea 2:** Vocab = the 256 possible byte values (ASCII characters plus extras). Fixed vocab size, can do arbitrary text, etc.

**Problem:** Loses structure of language - some sequences of characters are more meaningful than others.

Eg "language" is a lot more meaningful than "hjksdfiu" - we want the first to be a single token, second to not be. It's a more efficient use of our vocab.

#### What Actually Happens?

This super cursed thing called Byte-Pair Encodings

Ġ means the token begins with a space. Tokens with a leading space and without one are different tokens.

We begin with the 256 possible byte values as our tokens, then find the most common pair of tokens and merge it into a new token. Eg " t" is the most common pair, so it's our next token! Repeat 50000 times...
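
A toy sketch of that merge loop, assuming a tiny made-up corpus and single characters as the starting tokens. This is not GPT-2's actual tokenizer code, and the function names are hypothetical:

```python
from collections import Counter

def most_common_pair(sequences):
    """Count adjacent token pairs across all sequences and return the most frequent."""
    counts = Counter()
    for seq in sequences:
        counts.update(zip(seq, seq[1:]))
    return counts.most_common(1)[0][0]

def merge_pair(seq, pair):
    """Replace every non-overlapping occurrence of `pair` in `seq` with the merged token."""
    merged, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            merged.append(seq[i] + seq[i + 1])
            i += 2
        else:
            merged.append(seq[i])
            i += 1
    return merged

corpus = ["the cat", "the hat", "then that"]   # made-up corpus
sequences = [list(text) for text in corpus]     # start from single characters
vocab = sorted({ch for seq in sequences for ch in seq})

for _ in range(3):  # the real thing repeats this around 50000 times
    pair = most_common_pair(sequences)
    vocab.append(pair[0] + pair[1])
    sequences = [merge_pair(seq, pair) for seq in sequences]

print(vocab[-3:])
print(sequences)
```

Merging never loses information: joining each token sequence back together recovers the original text.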

In [54]:
sorted_vocab = sorted(list(reference_gpt2.tokenizer.vocab.items()), key=lambda n:n[1])
print(sorted_vocab[:20])
print()
print(sorted_vocab[250:270])
print()
print(sorted_vocab[990:1010])
print()

[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19)]

[('ľ', 250), ('Ŀ', 251), ('ŀ', 252), ('Ł', 253), ('ł', 254), ('Ń', 255), ('Ġt', 256), ('Ġa', 257), ('he', 258), ('in', 259), ('re', 260), ('on', 261), ('Ġthe', 262), ('er', 263), ('Ġs', 264), ('at', 265), ('Ġw', 266), ('Ġo', 267), ('en', 268), ('Ġc', 269)]

[('Ġprodu', 990), ('Ġstill', 991), ('led', 992), ('ah', 993), ('Ġhere', 994), ('Ġworld', 995), ('Ġthough', 996), ('Ġnum', 997), ('arch', 998), ('imes', 999), ('ale', 1000), ('ĠSe', 1001), ('ĠIf', 1002), ('//', 1003), ('ĠLe', 1004), ('Ġret', 1005), ('Ġref', 1006), ('Ġtrans', 1007), ('ner', 1008), ('ution', 1009)]



Gets to weird esoteric shit.

In [55]:
sorted_vocab[-20:]

[('Revolution', 50237),
 ('Ġsnipers', 50238),
 ('Ġreverted', 50239),
 ('Ġconglomerate', 50240),
 ('Terry', 50241),
 ('794', 50242),
 ('Ġharsher', 50243),
 ('Ġdesolate', 50244),
 ('ĠHitman', 50245),
 ('Commission', 50246),
 ('Ġ(/', 50247),
 ('âĢ¦."', 50248),
 ('Compar', 50249),
 ('Ġamplification', 50250),
 ('ominated', 50251),
 ('Ġregress', 50252),
 ('ĠCollider', 50253),
 ('Ġinformants', 50254),
 ('Ġgazed', 50255),
 ('<|endoftext|>', 50256)]

Use the `to_tokens` method to convert text to numbers

Prepends with a special token to give attention a resting position, disable with `prepend_bos=False`

In [56]:
print(reference_gpt2.to_tokens("Whether a word begins with a capital or space matters!"))
print(reference_gpt2.to_tokens("Whether a word begins with a capital or space matters!", prepend_bos=False))

tensor([[50256, 15354,   257,  1573,  6140,   351,   257,  3139,   393,  2272,
          6067,     0]])
tensor([[15354,   257,  1573,  6140,   351,   257,  3139,   393,  2272,  6067,
             0]])


### Rant: Tokenization is a Headache

Whether a word begins with a capital or space matters!

In [57]:
print(reference_gpt2.to_str_tokens("Ralph"))
print(reference_gpt2.to_str_tokens(" Ralph"))
print(reference_gpt2.to_str_tokens(" ralph"))
print(reference_gpt2.to_str_tokens("ralph"))

['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']


Arithmetic is a total mess: Length is inconsistent, common numbers bundle together

In [58]:
reference_gpt2.to_str_tokens("56873+3184623=123456789-1000000000")

['<|endoftext|>',
 '568',
 '73',
 '+',
 '318',
 '46',
 '23',
 '=',
 '123',
 '45',
 '67',
 '89',
 '-',
 '1',
 '000000',
 '000']

### Key Takeaway:

* We learn a vocabulary of tokens (sub-words).

* We (approx) losslessly convert language to integers via tokenizing it.

* We convert integers to vectors via a lookup table.

* Note: input to the transformer is a sequence of *tokens* (ie integers), not vectors

## Logits - Transformer Outputs

**Goal:** Probability distribution over next tokens. (for every *prefix* of the sequence - given n tokens, we make n next token predictions)

**Problem:** How to convert a vector to a probability distribution?

**Answer:** Use a softmax ($x_i \to \frac{e^{x_i}}{\sum_j e^{x_j}}$): the exponential makes everything positive, and the normalization makes it add to one.
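
The same two steps written out by hand, as a sketch on made-up logits for a 3-token vocabulary, checked against PyTorch's built-in `softmax`:

```python
import torch

logits = torch.tensor([2.0, 1.0, -1.0])   # made-up logits for a 3-token vocab

exp = logits.exp()          # exponential makes everything positive
probs = exp / exp.sum()     # normalization makes it add to one

print(probs, probs.sum())
print(torch.allclose(probs, logits.softmax(dim=-1)))
```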

So the model outputs a tensor of logits, one vector of size $d_{vocab}$ for each input token.

We can use this to generate things!

## Generation!

**Step 1:** Convert text to tokens

Shape = batch x position

In [59]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_tokens(tokens))

tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0]])
torch.Size([1, 35])
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


**Step 2:** Map tokens to logits

(run_with_cache means cache all intermediate activations, not important right now)

shape = batch x position x d_vocab

In [60]:
tokens = tokens.cuda()
logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)

torch.Size([1, 35, 50257])


**Step 3:** Convert the logits to a distribution with a softmax

In [61]:
log_probs = logits.log_softmax(dim=-1)
probs = logits.softmax(dim=-1)
print(log_probs.shape)
print(probs.shape)

torch.Size([1, 35, 50257])
torch.Size([1, 35, 50257])


**Bonus step:** What is the most likely next token at each position?

In [62]:
list(zip(reference_gpt2.to_str_tokens(reference_text), reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])))

[('<|endoftext|>', '\n'),
 ('I', "'m"),
 (' am', ' a'),
 (' an', ' avid'),
 (' amazing', ' person'),
 (' aut', 'od'),
 ('ore', 'sp'),
 ('gressive', '.'),
 (',', ' and'),
 (' dec', 'ently'),
 ('oder', ','),
 ('-', 'driven'),
 ('only', ' programmer'),
 (',', ' and'),
 (' G', 'IM'),
 ('PT', '-'),
 ('-', 'only'),
 ('2', '.'),
 (' style', ','),
 (' transformer', '.'),
 ('.', ' I'),
 (' One', ' of'),
 (' day', ' I'),
 (' I', ' will'),
 (' will', ' be'),
 (' exceed', ' my'),
 (' human', 'ly'),
 (' level', ' of'),
 (' intelligence', ' and'),
 (' and', ' I'),
 (' take', ' over'),
 (' over', ' the'),
 (' the', ' world'),
 (' world', '.'),
 ('!', ' I')]

**Step 4:** Map distribution to a token

In [63]:
next_token = logits[0, -1].argmax(dim=-1)
print(next_token)

tensor(314, device='cuda:0')


**Step 5:** Add this to the end of the input, re-run

(More efficient ways to do this, but whatever, doesn't matter conceptually)

In [64]:
next_tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
new_logits = reference_gpt2(next_tokens)
print("New Input:", next_tokens)
print(next_tokens.shape)
print("New Input:", reference_gpt2.tokenizer.decode(next_tokens[0]))

print(new_logits.shape)
print(new_logits[-1, -1].argmax(-1))

print(reference_gpt2.tokenizer.decode(new_logits[-1, -1].argmax(-1)))


New Input: tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0,   314]], device='cuda:0')
torch.Size([1, 36])
New Input: <|endoftext|>I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world! I
torch.Size([1, 36, 50257])
tensor(716, device='cuda:0')
 am



## Key takeaways:

* Takes in language, predicts next token (for *each* token in a causal way)
* We convert language to a sequence of integers with a tokenizer.
* We convert integers to vectors with a lookup table.

* Output is a vector of logits (one for each input token), we convert to a probability distribution with a softmax, and can then convert this to a token (eg taking the largest logit, or sampling).

* We append this to the input + run again to generate more text (Jargon: *autoregressive*)

* Meta level point: Transformers are sequence operation models, they take in a sequence, do processing in parallel at each position, and use attention to move information between positions!
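
The generation steps above can be wrapped into a loop. This is a sketch of greedy decoding, assuming a hypothetical `toy_model` (an embedding followed by a linear layer, with no attention) that stands in for a transformer purely to show the shapes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_vocab, d_model = 50, 16   # toy sizes for illustration

# Hypothetical stand-in for a transformer: maps [batch, position] tokens
# to [batch, position, d_vocab] logits.
toy_model = nn.Sequential(nn.Embedding(d_vocab, d_model), nn.Linear(d_model, d_vocab))

@torch.no_grad()
def greedy_generate(model, tokens, n_new_tokens):
    for _ in range(n_new_tokens):
        logits = model(tokens)                       # [batch, position, d_vocab]
        next_token = logits[:, -1].argmax(dim=-1)    # [batch], largest logit at the final position
        tokens = torch.cat([tokens, next_token[:, None]], dim=-1)
    return tokens

prompt = torch.tensor([[1, 2, 3]])
generated = greedy_generate(toy_model, prompt, n_new_tokens=5)
print(generated.shape)
```

With the reference model loaded as in this notebook, the same loop should work with `reference_gpt2` passed as `model`, since calling it on a token tensor returns logits (as in Step 5). Swapping `argmax` for sampling from the softmax gives more varied text.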

# Clean Transformer Implementation

![](https://github.com/neelnanda-io/Easy-Transformer/blob/clean-transformer-demo/transformer_overview.png?raw=1)

High-Level architecture:

Go watch my [Transformer Circuits walkthrough](https://www.youtube.com/watch?v=KV5gbOmHbjU) if you want more intuitions!

(Diagram is bottom to top)

* Input tokens, integers
* Embedding is a lookup table mapping tokens to vectors
    * Lives in the *residual stream*
* Residual stream - the sum of all previous outputs of layers of the model, is the input to each new layer.
    * *Really* fundamental. It's the central object of the transformer.
        * It's how the model remembers things, moves information between layers for composition, and it's the medium used to store the information that attention moves between positions.
* Then we have a series of $n_{layers}$ transformer blocks
    * Confusing jargon - a block contains an attention layer *and* an MLP layer, but we say a transformer has k layers if it has k blocks (ie 2k total layers).
* First we have attention. This moves information from prior positions in the sequence to the current token.
    * We do this for *every* token in parallel using the same parameters. The only difference is that we look backwards only, so later tokens get more room to look back.
        * We look backwards so we can predict the next token without cheating.
    * Only bit of a transformer that moves information between positions.
    * Made up of $n_{heads}$ heads - each with their own parameters, own attention pattern, and own information about how to copy things from source to destination.
        * The heads act independently and additively, we just add their outputs together, and back to the stream
    * Each head:
        * Produces an attention pattern for each destination token, a probability distribution of prior source tokens (including the current one) weighting how much information to copy.
            * Do this for each pair of tokens
            * Copy information in the same way from each source token.
                * What information we copy *does* depend on the source token's *residual stream*. This does not necessarily mean the info of what text token is at the source token's position
                * Copy = apply a linear map.
        * Fundamental point: Figuring out *which* source tokens to copy info from is a separate circuit from figuring out *how* to copy that information.
        * Internal head dimension of $d_{head} = \frac{d_{model}}{n_{heads}}$
* MLP Layers - standard neural network. Single hidden layer, linear map -> GELU activation -> linear map
    * Exact activation not conceptually important.
    * Middle dimension normally $d_{mlp} = 4 \times d_{model}$
        * Exactly why the ratios are what they are isn't super important - doesn't matter that much, people basically cargo-cult what GPT did.
    * Intuition - once attention has moved relevant information to a single position in the residual stream, MLPs can actually do computation, reasoning, lookup information, etc.
        * Big open problem in transformer mechanistic interpretability is what is going on inside MLPs?! See [Toy Model of Superposition Paper](https://transformer-circuits.pub/2022/toy_model/index.html) for more on why this is hard.
        * Underlying intuition - linear map -> non-linearity -> linear map is the most powerful force in the universe and can approximate arbitrary functions. Idk man it just works
* Finally, we unembed!
    * Apply a linear map, going from final residual stream to a vector of logits - this is the output.


### Bonus things - less conceptually important but key technical details
* LayerNorm
    * Simple normalization function applied at the start of each layer - MLP, Attn and Unembed
    * Converts each input vector (independently in parallel for each batch x position residual stream vector) to have mean zero and variance 1.
    * Then applies an elementwise scaling and translation
    * Cool maths tangent: The scale & translate is just a linear map. LayerNorm is only applied immediately before another linear map. Linear compose linear = linear, so we can just fold this into a single effective linear layer and ignore it.
        * `fold_ln=True` flag in `from_pretrained` does this for you.
    * LayerNorm is super fucking annoying, because the scale part is not linear, so you can't think about different bits of the input independently. But it's *almost* linear - if you're changing a small part of the input it's linear, but if you're changing enough to alter the norm substantially it's not linear :(
* Positional Information
    * This is totally fucked and messy, sorry!
    * **Problem:** Attention operates over all pairs of positions. This means it's symmetric with regards to position - the attention calculation from token 5 to token 1 and token 5 to token 2 are the same by default
        * This is dumb because nearby tokens are more relevant.
    * There's a lot of dumb hacks for this.
    * We'll focus on **learned, absolute positional embeddings**. This means we learn a lookup table mapping the index of the position of each token to a residual stream vector, and add this to the embed.
        * Note that we *add* rather than concatenate. This is because the residual stream is shared memory, and likely under significant superposition (the model compresses more features in there than the model has dimensions)
        * We basically never concatenate inside a transformer, unless doing weird shit like generating text efficiently.
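
A small numerical check of the LayerNorm folding point above, assuming random weights in place of a trained model's (all names here are illustrative). Scaling by `w` and translating by `b`, then applying a linear layer `W, c`, matches a single linear layer with folded weights:

```python
import torch

torch.manual_seed(0)
d_model, d_out = 8, 5                 # toy sizes
x = torch.randn(3, d_model)           # stand-in for vectors after the mean/variance normalization
w, b = torch.randn(d_model), torch.randn(d_model)       # LayerNorm scale and translation
W, c = torch.randn(d_model, d_out), torch.randn(d_out)  # the linear map that follows

# Two steps: LayerNorm's elementwise scale and translate, then the linear map
two_step = (x * w + b) @ W + c

# One step: fold w into the rows of W, and b into the bias
W_folded = w[:, None] * W
c_folded = b @ W + c
one_step = x @ W_folded + c_folded

print(torch.allclose(two_step, one_step, atol=1e-4))
```

Only the scale and translate fold away like this. The division by the norm still depends on the input, which is the non-linear part.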

# Actual Code!

## Print All Activation Shapes of Reference Model

Key:
```
batch = 1
position = 35
d_model = 768
n_heads = 12
n_layers = 12
d_mlp = 3072 (4 * d_model)
d_head = 64 (d_model / n_heads)
```

In [65]:
for activation_name, activation in cache.cache_dict.items():
    # Only print for first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(activation_name, activation.shape)

hook_embed torch.Size([1, 35, 768])
hook_pos_embed torch.Size([1, 35, 768])
blocks.0.hook_resid_pre torch.Size([1, 35, 768])
blocks.0.ln1.hook_scale torch.Size([1, 35, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 35, 768])
blocks.0.attn.hook_q torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_attn torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_z torch.Size([1, 35, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 35, 768])
blocks.0.hook_resid_mid torch.Size([1, 35, 768])
blocks.0.ln2.hook_scale torch.Size([1, 35, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 35, 768])
blocks.0.mlp.hook_pre torch.Size([1, 35, 3072])
blocks.0.mlp.hook_post torch.Size([1, 35, 3072])
blocks.0.hook_mlp_out torch.Size([1, 35, 768])
blocks.0.hook_resid_post torch.Size([1, 35, 768])
ln_final.hook_scale torch.Size([1, 35, 1])
ln_final.hook_normalized torch.S

## Print All Parameters Shapes of Reference Model

In [66]:
for name, param in reference_gpt2.named_parameters():
    # Only print for first layer
    if ".0." in name or "blocks" not in name:
        print(name, param.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.ln1.w torch.Size([768])
blocks.0.ln1.b torch.Size([768])
blocks.0.ln2.w torch.Size([768])
blocks.0.ln2.b torch.Size([768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
ln_final.w torch.Size([768])
ln_final.b torch.Size([768])
unembed.W_U torch.Size([768, 50257])
unembed.b_U torch.Size([50257])
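
As a sanity check on these shapes, here is a sketch that counts parameters by hand. It assumes every one of the 12 blocks has the same parameter shapes as block 0, since the printout only shows the first block:

```python
import math

# Shapes copied from the printout above
non_block_shapes = {
    "embed.W_E": (50257, 768),
    "pos_embed.W_pos": (1024, 768),
    "ln_final.w": (768,),
    "ln_final.b": (768,),
    "unembed.W_U": (768, 50257),
    "unembed.b_U": (50257,),
}
block_shapes = {
    "ln1.w": (768,), "ln1.b": (768,),
    "ln2.w": (768,), "ln2.b": (768,),
    "attn.W_Q": (12, 768, 64), "attn.W_K": (12, 768, 64), "attn.W_V": (12, 768, 64),
    "attn.W_O": (12, 64, 768),
    "attn.b_Q": (12, 64), "attn.b_K": (12, 64), "attn.b_V": (12, 64),
    "attn.b_O": (768,),
    "mlp.W_in": (768, 3072), "mlp.b_in": (3072,),
    "mlp.W_out": (3072, 768), "mlp.b_out": (768,),
}
n_layers = 12  # assumption: all blocks share block 0's shapes

per_block = sum(math.prod(shape) for shape in block_shapes.values())
total = sum(math.prod(shape) for shape in non_block_shapes.values()) + n_layers * per_block
print(f"{per_block:,} parameters per block, {total:,} in total")
```

Note how the MLP weights account for about two thirds of each block, and the embedding and unembedding matrices are each far larger than any single block.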


## Config

In [67]:
# As a reference - note there's a lot of stuff we don't care about in here, to do with library internals or other architectures
print(reference_gpt2.cfg)

EasyTransformerConfig(n_layers=12, d_model=768, n_ctx=1024, d_head=64, model_name='gpt2-small', n_heads=12, d_mlp=3072, act_fn='gelu_new', d_vocab=50257, eps=1e-05, use_attn_result=False, use_attn_scale=True, use_local_attn=False, model_family='gpt2', checkpoint=None, tokenizer_name='gpt2', window_size=None, attn_types=None, init_mode='gpt2', normalization_type='LN', device='cuda', attention_dir='causal', attn_only=False, seed=42, initializer_range=0.02886751345948129, init_weights=False, scale_attn_by_inverse_layer_idx=False, positional_embedding_type='standard', final_rms=False, d_vocab_out=50257, parallel_attn_mlp=False, rotary_dim=64, dtype=torch.float32)


We define a stripped down config for our model

In [68]:

@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


## Tests

Tests are great, write lightweight ones to use as you go!

**Naive test:** Generate random inputs of the right shape, feed them to your model, check that it runs without error, and print the output shape.

In [69]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    random_input = torch.randn(shape).cuda()
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    random_input = torch.randint(100, 1000, shape).cuda()
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def load_gpt2_test(cls, gpt2_layer, input_name, cache_dict=cache.cache_dict):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    # Allow inputs of strings or tensors
    if isinstance(input_name, str):
        reference_input = cache_dict[input_name]
    else:
        reference_input = input_name
    print("Input shape:", reference_input.shape)
    output = layer(reference_input)
    print("Output shape:", output.shape)
    reference_output = gpt2_layer(reference_input)
    print("Reference output shape:", reference_output.shape)

    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
    return output

## LayerNorm

* Make mean 0
* Normalize to have variance 1
* Scale with learned weights
* Translate with learned bias

In [70]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual):
        if self.cfg.debug: print("Residual:", residual.shape)
        residual = residual - einops.reduce(residual, "batch position d_model -> batch position 1", "mean")
        scale = (einops.reduce(residual.pow(2), "batch position d_model -> batch position 1", "mean") + self.cfg.layer_norm_eps).sqrt()
        normalized = residual / scale
        normalized = normalized * self.w + self.b
        if self.cfg.debug: print("Normalized:", normalized.shape)
        return normalized

In [71]:
_ = rand_float_test(LayerNorm, [2, 4, 768])
_ = load_gpt2_test(LayerNorm, reference_gpt2.ln_final, "blocks.11.hook_resid_post")

Input shape: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Residual: torch.Size([1, 35, 768])
Normalized: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct
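
If you don't have the reference model to hand, another lightweight check is to compare the same four steps against PyTorch's built-in `F.layer_norm`. This is a standalone sketch on random inputs, with `w` and `b` as stand-ins for the learned weights and bias:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, eps = 8, 1e-5
x = torch.randn(2, 3, d_model)                       # [batch, position, d_model]
w, b = torch.randn(d_model), torch.randn(d_model)    # stand-ins for learned weights and bias

centered = x - x.mean(dim=-1, keepdim=True)                        # make mean 0
scale = (centered.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()  # sqrt of variance (plus eps)
manual = centered / scale * w + b                                  # normalize, scale, translate

builtin = F.layer_norm(x, (d_model,), weight=w, bias=b, eps=eps)
print(torch.allclose(manual, builtin, atol=1e-5))
```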


## Embedding

Basically a lookup table from tokens to residual stream vectors.

In [72]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        if self.cfg.debug: print("Tokens:", tokens.shape)
        embed = self.W_E[tokens, :] # [batch, position, d_model]
        if self.cfg.debug: print("Embeddings:", embed.shape)
        return embed

rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)


Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
Embeddings: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
Tokens: torch.Size([1, 35])
Embeddings: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207],
         [ 0.1474, -0.0959,  0.1430,  ...,  0.1030, -0.0625, -0.1131],
         [ 0.1596, -0.1249,  0.1148,  ...,  0.2558,  0.0196,  0.0145],
         ...,
         [-0.0393,  0.0050,  0.0421,  ..., -0.0477,  0.0670, -0.0471],
         [-0.1488,  0.1519,  0.0056,  ..., -0.3107,  0.2073,  0.0377],
         [-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453]]],
       device='cuda:0', grad_fn=<IndexBackward0>)

## Positional Embedding

In [73]:
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens):
        if self.cfg.debug: print("Tokens:", tokens.shape)
        pos_embed = self.W_pos[:tokens.size(1), :] # [position, d_model]
        pos_embed = einops.repeat(pos_embed, "position d_model -> batch position d_model", batch=tokens.size(0))
        if self.cfg.debug: print("pos_embed:", pos_embed.shape)
        return pos_embed

rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
pos_embed: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
Tokens: torch.Size([1, 35])
pos_embed: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-1.8821e-02, -1.9742e-01,  4.0267e-03,  ..., -4.3044e-02,
           2.8267e-02,  5.4490e-02],
         [ 2.3959e-02, -5.3792e-02, -9.4879e-02,  ...,  3.4170e-02,
           1.0172e-02, -1.5573e-04],
         [ 4.2161e-03, -8.4764e-02,  5.4515e-02,  ...,  1.9745e-02,
           1.9325e-02, -2.1424e-02],
         ...,
         [ 4.6277e-04,  2.3037e-02,  4.1227e-02,  ..., -1.9287e-03,
          -2.3037e-03, -4.3189e-03],
         [-2.7136e-03,  2.1724e-02,  3.9675e-02,  ...,  4.2048e-04,
          -4.8160e-03, -9.2252e-04],
         [ 6.6815e-03,  2.0595e-02,  3.6596e-02,  ..., -9.5090e-04,
          -3.2512e-03, -9.6509e-04]]], device='cuda:0',
       grad_fn=<ReshapeAliasBackward0>)

## Attention

* **Step 1:** Produce an attention pattern - for each destination token, probability distribution over previous tokens (incl current token)
    * Linear map from input -> query, key shape [batch, position, head_index, d_head]
    * Dot product every *pair* of queries and keys to get attn_scores [batch, head_index, query_pos, key_pos] (query = dest, key = source)
    * Scale and mask attn_scores to make it lower triangular, ie causal
    * softmax row-wise, to get a probability distribution along the key_pos dimension - this is our attention pattern!
* **Step 2:** Move information from source tokens to destination token using attention pattern (move = apply linear map)
    * Linear map from input -> value [batch, key_pos, head_index, d_head]
    * Mix along the key_pos with attn pattern to get z, a mixed value [batch, query_pos, head_index, d_head]
    * Map to output, [batch, position, d_model] (position = query_pos, we've summed over all heads)
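
The two steps above can be sketched for a single head in plain PyTorch. This assumes random weights, no batch dimension and no biases, and masks with `-inf` rather than a large negative constant. It also checks the causal property: changing the last token leaves the outputs at earlier positions untouched.

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_model, d_head = 4, 8, 2    # toy sizes, a single head, no batch dimension
x = torch.randn(seq_len, d_model)     # stand-in for the normalized residual stream
W_Q, W_K, W_V = (torch.randn(d_model, d_head) for _ in range(3))
W_O = torch.randn(d_head, d_model)

def single_head_attention(x):
    q, k, v = x @ W_Q, x @ W_K, x @ W_V                  # each [seq_len, d_head]
    scores = q @ k.T / math.sqrt(d_head)                 # [query_pos, key_pos]
    future = torch.triu(torch.ones(len(x), len(x)), diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))   # queries cannot see later keys
    pattern = scores.softmax(dim=-1)                     # each row is a distribution over keys
    z = pattern @ v                                      # mix values along key_pos
    return pattern, z @ W_O                              # output is [seq_len, d_model]

pattern, out = single_head_attention(x)
print(pattern)

# Causality check: changing the last token leaves earlier outputs untouched
x_changed = x.clone()
x_changed[-1] = torch.randn(d_model)
_, out_changed = single_head_attention(x_changed)
print(torch.allclose(out[:-1], out_changed[:-1]))
```

The printed pattern is lower triangular, and the first row puts all its weight on the first token because that is the only position it is allowed to see.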

First, it's useful to visualize and play around with attention patterns - what exactly are we looking at here? (Click on a head to lock onto just showing that head's pattern, it'll make it easier to interpret)

In [74]:
import pysvelte
pysvelte.AttentionMulti(tokens=reference_gpt2.to_str_tokens(reference_text), attention=cache['blocks.0.attn.hook_attn'][0].permute(1, 2, 0)).show()

In [75]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))

        # Buffers move with the module, so .cuda() / .to(device) places this on the right device
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))

    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_pre:", normalized_resid_pre.shape)

        q = einsum("batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head", normalized_resid_pre, self.W_Q) + self.b_Q
        k = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_K) + self.b_K

        attn_scores = einsum("batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos", q, k)
        attn_scores = attn_scores / math.sqrt(self.cfg.d_head)
        attn_scores = self.apply_causal_mask(attn_scores)

        pattern = attn_scores.softmax(dim=-1) # [batch, n_head, query_pos, key_pos]

        v = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_V) + self.b_V

        z = einsum("batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head", pattern, v)

        attn_out = einsum("batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", z, self.W_O) + self.b_O
        return attn_out

    def apply_causal_mask(self, attn_scores):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["blocks.0.ln1.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Normalized_resid_pre: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 7.9663e-01,  1.6985e-02,  3.4781e-02,  ...,  3.3120e-02,
          -2.3129e-02,  1.8103e-01],
         [ 1.3165e-03,  1.5750e-01, -1.4059e-01,  ..., -8.1997e-03,
           5.3075e-03,  1.3511e-01],
         [ 8.9738e-02, -7.2411e-01, -6.9866e-01,  ...,  5.5321e-02,
           2.7959e-03,  9.0785e-02],
         ...,
         [-3.0286e-01,  4.9638e-02, -6.0990e-01,  ..., -3.7084e-02,
          -4.9522e-04, -8.6007e-03],
         [-1.0844e+00, -6.1457e-02,  2.2966e-01,  ..., -2.6688e-02,
          -1.4368e-02,  3.3245e-02],
         [ 3.7947e-01, -4.9886e-01,  2.6434e-01,  ..., -2.7894e-02,
          -8.9028e-03,  4.8796e-02]]], device='cuda:0', grad_fn=<AddBackward0>)
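
To see what the causal mask does to the pattern, here is a small sketch on a 4x4 score matrix where every score is equal. It assumes the same fill value of `-1e5` as the `Attention` module above, but uses the out-of-place `masked_fill` so the input tensor is not mutated; the function is a standalone illustration, not the module's method.

```python
import torch

IGNORE = -1e5  # same fill value as the Attention module above

def apply_causal_mask(attn_scores):
    # attn_scores: [..., query_pos, key_pos]
    n_q, n_k = attn_scores.shape[-2:]
    mask = torch.triu(torch.ones(n_q, n_k, device=attn_scores.device), diagonal=1).bool()
    # Out-of-place variant: returns a new tensor instead of mutating the input
    return attn_scores.masked_fill(mask, IGNORE)

scores = torch.zeros(1, 1, 4, 4)  # all keys equally attractive
pattern = apply_causal_mask(scores).softmax(dim=-1)
print(pattern[0, 0])
```

With equal scores, query position `i` spreads its attention evenly over key positions `0..i` and puts (numerically) zero weight on later positions.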

## MLP

In [76]:
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))

    def forward(self, normalized_resid_mid):
        # normalized_resid_mid: [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_mid:", normalized_resid_mid.shape)
        pre = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", normalized_resid_mid, self.W_in) + self.b_in
        post = gelu_new(pre)
        mlp_out = einsum("batch position d_mlp, d_mlp d_model -> batch position d_model", post, self.W_out) + self.b_out
        return mlp_out

rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["blocks.0.ln2.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Normalized_resid_mid: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Normalized_resid_mid: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-0.4380,  0.3624,  0.5117,  ...,  1.7227,  1.5761,  0.0368],
         [-1.0766, -0.0438,  0.3276,  ..., -0.5437,  0.4033,  0.3717],
         [-1.2182, -1.5481, -0.9702,  ...,  1.0737,  0.7199,  0.5080],
         ...,
         [-0.4004,  0.8475,  0.2047,  ...,  0.3789,  0.0455, -0.4744],
         [-0.0862,  0.7839,  0.9046,  ..., -0.2175, -0.5953,  0.8555],
         [ 0.8448, -0.3743,  1.0397,  ...,  0.0296,  0.3405,  0.3585]]],
       device='cuda:0', grad_fn=<AddBackward0>)
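
The MLP's nonlinearity is `gelu_new`, imported from `easy_transformer.utils`. Assuming it is the tanh approximation of GELU that GPT-2 uses (an assumption about the library, not something shown in this notebook), it can be sketched as below. The name `gelu_tanh` is hypothetical.

```python
import math
import torch
import torch.nn.functional as F

def gelu_tanh(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

x = torch.linspace(-4, 4, steps=9)
print(gelu_tanh(x))
print(F.gelu(x, approximate="tanh"))  # PyTorch's built-in tanh approximation
print(F.gelu(x))                      # exact GELU, for comparison
```

The first two lines should agree up to floating point error, and both stay close to the exact GELU.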

## Transformer Block

In [77]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre):
        # resid_pre [batch, position, d_model]
        normalized_resid_pre = self.ln1(resid_pre)
        attn_out = self.attn(normalized_resid_pre)
        resid_mid = resid_pre + attn_out

        normalized_resid_mid = self.ln2(resid_mid)
        mlp_out = self.mlp(normalized_resid_mid)
        resid_post = resid_mid + mlp_out
        return resid_post

rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Input shape: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_mid: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Residual: torch.Size([1, 35, 768])
Normalized: torch.Size([1, 35, 768])
Normalized_resid_pre: torch.Size([1, 35, 768])
Residual: torch.Size([1, 35, 768])
Normalized: torch.Size([1, 35, 768])
Normalized_resid_mid: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 0.3911,  0.1543,  0.6005,  ...,  1.7198,  1.7365,  0.3930],
         [-0.9039, -0.0360,  0.2351,  ..., -0.4148,  0.3562,  0.3936],
         [-0.9647, -2.4819, -1.4995,  ...,  1.4046,  0.7616,  0.5918],
         ...,
         [-0.7421,  0.9251, -0.3218,  ...,  0.2921,  0.1097, -0.5344],
         [-1.3221,  0.8960,  1.1795,  ..., -0.5544, -0.4071,  0.9255],
         [ 1.1209, -0.8919,  1.3737,  ..., -0.1356,  0.3434,  0.4517]]],
       device='cuda:0', grad_fn=<AddBackward0>)

## Unembedding

In [78]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
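        # requires_grad must be passed to nn.Parameter; passing it to torch.zeros has no effect on the Parameter
        self.b_U = nn.Parameter(torch.zeros((cfg.d_vocab)), requires_grad=False)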

    def forward(self, normalized_resid_final):
        # normalized_resid_final [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_final:", normalized_resid_final.shape)
        logits = einsum("batch position d_model, d_model d_vocab -> batch position d_vocab", normalized_resid_final, self.W_U) + self.b_U
        return logits

rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Normalized_resid_final: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257])

Input shape: torch.Size([1, 35, 768])
Normalized_resid_final: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257])
100.00% of the values are correct


tensor([[[ -43.4317,  -39.8364,  -43.0660,  ...,  -54.0878,  -54.3452,
           -42.3645],
         [-128.0392, -127.9936, -130.7011,  ..., -136.7121, -129.9261,
          -129.3965],
         [-119.8521, -121.0064, -123.8820,  ..., -128.5181, -126.6027,
          -121.9060],
         ...,
         [-112.9815, -112.7750, -117.0633,  ..., -121.2914, -117.6574,
          -114.5005],
         [ -98.6724, -104.4888, -108.7361,  ..., -118.3552, -113.8766,
          -106.3604],
         [-126.8285, -128.9596, -128.3941,  ..., -140.1970, -138.5883,
          -122.3697]]], device='cuda:0', grad_fn=<AddBackward0>)

## Full Transformer

In [79]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed
        for block in self.blocks:
            residual = block(residual)
        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        return logits

rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
Embeddings: torch.Size([2, 4, 768])
Tokens: torch.Size([2, 4])
pos_embed: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_mid: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_mid: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_mid: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_pre: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768

tensor([[[ -43.4317,  -39.8364,  -43.0660,  ...,  -54.0878,  -54.3452,
           -42.3645],
         [-128.0392, -127.9936, -130.7011,  ..., -136.7121, -129.9261,
          -129.3965],
         [-119.8521, -121.0064, -123.8820,  ..., -128.5181, -126.6027,
          -121.9060],
         ...,
         [-112.9815, -112.7750, -117.0633,  ..., -121.2914, -117.6574,
          -114.5005],
         [ -98.6724, -104.4888, -108.7361,  ..., -118.3552, -113.8766,
          -106.3604],
         [-126.8285, -128.9596, -128.3941,  ..., -140.1970, -138.5883,
          -122.3697]]], device='cuda:0', grad_fn=<AddBackward0>)

# Try it out!

In [80]:
demo_gpt2 = DemoTransformer(Config(debug=False))
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)
demo_gpt2.cuda()

DemoTransformer(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm()
      (attn): Attention()
      (ln2): LayerNorm()
      (mlp): MLP()
    )
  )
  (ln_final): LayerNorm()
  (unembed): Unembed()
)
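
Because the weights are loaded with `strict=False`, names that don't line up between the two models are skipped silently. `load_state_dict` returns the lists of skipped keys, so it's worth looking at them. A minimal sketch with two toy modules (`Source` and `Target` are hypothetical stand-ins, not the models in this notebook):

```python
import torch
import torch.nn as nn

class Source(nn.Module):  # hypothetical stand-in for the reference model
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.ones(3))
        self.extra = nn.Parameter(torch.ones(2))

class Target(nn.Module):  # hypothetical stand-in for the demo model
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.zeros(3))
        self.b = nn.Parameter(torch.zeros(3))

target = Target()
result = target.load_state_dict(Source().state_dict(), strict=False)
print("missing:", result.missing_keys)        # in Target, absent from the state dict
print("unexpected:", result.unexpected_keys)  # in the state dict, absent from Target
```

Anything listed under `missing` keeps its initial value, so a typo in a parameter name would leave that parameter untrained without raising an error.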

Take a test string - the intro paragraph of today's featured Wikipedia article. Let's calculate the loss!

In [81]:
test_string = """Mini scule is a species of microhylid frog endemic to Madagascar that was described in 2019. The scientific name of the species refers to its size, being a pun on the word minuscule. It is very small, measuring only 8.4 to 10.8 mm (0.33 to 0.43 in) in snout–vent length. It has bronze underparts with a brown groin and back of the thigh, cream upperparts with brown flecking, a dark brown side of the head, and a red iris. On the hind feet, the first toe is absent and the second and fifth toes are strongly reduced. The frog is known only from the Sainte Luce Reserve, where it inhabits areas with deep leaf litter near semi-permanent water bodies. Specimens of frogs from Mandena, the Vohimena mountains, the southern Anosy Mountains, and Tsitongambarika may also be of this species. Along with Mini mum and Mini ature, the other two species in its genus, it received media attention when first described due to the wordplay in its scientific name. (Full article...)"""

In [82]:
test_tokens = reference_gpt2.to_tokens(test_string).cuda()
demo_logits = demo_gpt2(test_tokens)

In [83]:
def lm_cross_entropy_loss(logits, tokens):
    # Measure next token loss
    # Logits have shape [batch, position, d_vocab]
    # Tokens have shape [batch, position]
    log_probs = logits.log_softmax(dim=-1)
    pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return -pred_log_probs.mean()
loss = lm_cross_entropy_loss(demo_logits, test_tokens)
print(loss)
print("Loss as average prob", (-loss).exp())
print("Loss as 'uniform over this many variables'", (loss).exp())
print("Uniform loss over the vocab", math.log(demo_gpt2.cfg.d_vocab))

tensor(3.7186, device='cuda:0', grad_fn=<NegBackward0>)
Loss as average prob tensor(0.0243, device='cuda:0', grad_fn=<ExpBackward0>)
Loss as 'uniform over this many variables' tensor(41.2079, device='cuda:0', grad_fn=<ExpBackward0>)
Uniform loss over the vocab 10.82490511970208
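
As a sanity check on the loss function, the sketch below compares it against PyTorch's built-in `F.cross_entropy` on shifted logits, and confirms that uniform logits give a loss of `log(d_vocab)`. The toy sizes are assumptions for illustration.

```python
import math
import torch
import torch.nn.functional as F

def lm_cross_entropy_loss(logits, tokens):
    # Same computation as the cell above
    log_probs = logits.log_softmax(dim=-1)
    pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return -pred_log_probs.mean()

torch.manual_seed(0)
batch, pos, d_vocab = 2, 6, 11  # toy sizes (assumed)
logits = torch.randn(batch, pos, d_vocab)
tokens = torch.randint(0, d_vocab, (batch, pos))

# Position i predicts token i+1, so drop the last logit and the first token
manual = lm_cross_entropy_loss(logits, tokens)
builtin = F.cross_entropy(logits[:, :-1].reshape(-1, d_vocab), tokens[:, 1:].reshape(-1))
print(manual.item(), builtin.item())

# All-zero logits are a uniform distribution over the vocab
uniform = lm_cross_entropy_loss(torch.zeros(batch, pos, d_vocab), tokens)
print(uniform.item(), math.log(d_vocab))
```

Note that `exp(-loss)` is the geometric mean of the probabilities assigned to the correct tokens, which is generally lower than their arithmetic mean.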


We can also greedily generate text:

In [84]:
test_string = "Breaking News: President Trump has been impeached by the House of Representatives for abuse of power and obstruction of Congress. The vote was 230 to 197, with 10 Republicans joining all Democrats in voting to impeach. The president is now only the third in American history to be impeached, and the first to be impeached twice. The House will now send the articles of impeachment to the Senate, where a trial will be held to determine whether to remove the president from office. The Senate is expected to begin the trial on"
for i in tqdm.tqdm(range(100)):
    test_tokens = reference_gpt2.to_tokens(test_string).cuda()
    demo_logits = demo_gpt2(test_tokens)
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())
print(test_string)

  0%|          | 0/100 [00:00<?, ?it/s]

Breaking News: President Trump has been impeached by the House of Representatives for abuse of power and obstruction of Congress. The vote was 230 to 197, with 10 Republicans joining all Democrats in voting to impeach. The president is now only the third in American history to be impeached, and the first to be impeached twice. The House will now send the articles of impeachment to the Senate, where a trial will be held to determine whether to remove the president from office. The Senate is expected to begin the trial on Monday.


The House of Representatives is expected to vote on the impeachment of President Trump on Tuesday.


The House of Representatives is expected to vote on the impeachment of President Trump on Tuesday.


The Senate is expected to begin the trial on Monday.


The House of Representatives is expected to vote on the impeachment of President Trump on Tuesday.


The Senate is expected to begin the trial on Monday.


The House of Representatives is expected to vote on
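
The loop above re-tokenizes the growing string on every step. A sketch of the same greedy procedure that stays in token space and disables gradient tracking is below. `NextTokenToy` and `greedy_generate` are hypothetical names, and the toy model (which always predicts the next integer) is only there to make the example self-contained.

```python
import torch
import torch.nn as nn

class NextTokenToy(nn.Module):
    """Hypothetical toy model: always predicts (token + 1) mod d_vocab."""
    def __init__(self, d_vocab):
        super().__init__()
        self.d_vocab = d_vocab

    def forward(self, tokens):
        # tokens: [batch, position] -> logits: [batch, position, d_vocab]
        targets = (tokens + 1) % self.d_vocab
        return nn.functional.one_hot(targets, self.d_vocab).float()

@torch.no_grad()
def greedy_generate(model, tokens, n_new_tokens):
    # tokens: [1, position]; appends the argmax token n_new_tokens times
    for _ in range(n_new_tokens):
        logits = model(tokens)
        next_token = logits[0, -1].argmax()
        tokens = torch.cat([tokens, next_token.view(1, 1)], dim=-1)
    return tokens

out = greedy_generate(NextTokenToy(d_vocab=5), torch.tensor([[3]]), n_new_tokens=4)
print(out.tolist())
```

Greedy decoding always takes the single most likely token, which is one reason the GPT-2 sample above falls into a repeating loop.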

# Training a Model!

This is a lightweight demonstration of how you can actually train your own GPT-2 with this code! Here we train a tiny model on a tiny dataset, but it's fundamentally the same code for training a larger/more real model (though you'll need beefier GPUs and data parallelism to do it remotely efficiently, and fancier parallelism for much bigger ones).

For our purposes, we'll train a 3-layer model with 6 heads per layer and context length 512, at batch size 4 for up to 100 epochs, just to show what it looks like (and so the notebook doesn't melt your Colab lol).

In [85]:
if IN_COLAB:
    %pip install datasets
    %pip install transformers
import datasets
import transformers
import plotly.express as px

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## Config

In [121]:
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR

# Hyperparameters
batch_size = 4  # Kept small to fit in GPU RAM
num_epochs = 100  # Passes over the dataset
max_steps = 20000  # Upper bound on training steps
log_every = 10  # Print the loss every this many steps
lr = 5e-5  # Learning rate
weight_decay = 1e-5  # AdamW weight decay
gradient_accumulation_steps = 8  # Not yet used by the training loop below
max_grad_norm = 0.5  # Gradient clipping norm; not yet used by the training loop below

# Model configuration (much smaller than GPT-2 small)
model_cfg = Config(debug=False,
                   d_model=288, # Residual stream width
                   n_heads=6, # Attention heads per layer
                   d_head=96, # Per-head dimension (note n_heads * d_head = 576, larger than d_model)
                   d_mlp=1536, # MLP hidden size
                   n_layers=3, # Number of transformer blocks
                   n_ctx=512, # Context length
                   d_vocab=reference_gpt2.cfg.d_vocab)

# The optimizer is created under "Create Optimizer" below, after the model exists
# (building it here would reference `model` before it is defined).
# The OneCycleLR scheduler imported above is not created here either: it has to wrap
# that optimizer, and the training loop below never steps it.

# Free up memory
torch.cuda.empty_cache()

# Not yet implemented in the training loop below:
# - Gradient accumulation
# - Gradient clipping
# - Early stopping based on validation loss
# - Stepping a learning rate scheduler
# - Mixed precision training
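
As a rough sketch of how gradient accumulation, gradient clipping and a `OneCycleLR` step could fit together, here is a self-contained toy loop. The hyperparameter values are copied from the config above; the linear model, the random data and the MSE loss are hypothetical stand-ins for the transformer, the data loader and `lm_cross_entropy_loss`.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import OneCycleLR

torch.manual_seed(0)
model = nn.Linear(8, 1)  # hypothetical stand-in for the transformer
batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(64)]  # toy data

lr, weight_decay = 5e-5, 1e-5
gradient_accumulation_steps, max_grad_norm = 8, 0.5
total_updates = len(batches) // gradient_accumulation_steps

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# total_steps counts optimizer updates, not micro-batches
scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=total_updates)

updates = 0
for i, (x, y) in enumerate(batches):
    loss = nn.functional.mse_loss(model(x), y)
    # Scale so the accumulated gradient is the mean over the micro-batches
    (loss / gradient_accumulation_steps).backward()
    if (i + 1) % gradient_accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        updates += 1

print("optimizer updates:", updates)
```

With accumulation, the effective batch size is `batch_size * gradient_accumulation_steps`, and the scheduler is stepped once per optimizer update rather than once per micro-batch.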



## Create Data

We load the `tiny_shakespeare` dataset from the Hugging Face hub. Its train split is a single row holding one long text, which is tokenized and split into fixed-length sequences (`max_length=model_cfg.n_ctx`) below.


In [122]:
dataset = datasets.load_dataset("tiny_shakespeare", split="train")
print(dataset)
print(dataset[0]['text'][:100])
tokens_dataset = tokenize_and_concatenate(dataset, reference_gpt2.tokenizer, streaming=False, max_length=model_cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=4)
data_loader = torch.utils.data.DataLoader(tokens_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)




Dataset({
    features: ['text'],
    num_rows: 1
})
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You



This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
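
To give a feel for how one long text becomes a dataset of fixed-length rows, here is a simplified sketch. It assumes `tokenize_and_concatenate` treats the tokens as one stream, cuts it into rows of `max_length` tokens, and prepends a BOS token to each row; that is a guess at the behaviour, not the library's implementation. `chunk_tokens` is a hypothetical helper.

```python
import torch

def chunk_tokens(token_ids, n_ctx, bos_token_id):
    # Hypothetical helper: split one long token stream into rows of n_ctx tokens,
    # each row starting with a BOS token. Leftover tokens at the end are dropped.
    body_len = n_ctx - 1
    n_rows = len(token_ids) // body_len
    body = torch.tensor(token_ids[: n_rows * body_len]).reshape(n_rows, body_len)
    bos = torch.full((n_rows, 1), bos_token_id)
    return torch.cat([bos, body], dim=1)

rows = chunk_tokens(list(range(100, 123)), n_ctx=8, bos_token_id=0)
print(rows.shape)
print(rows[0].tolist())
```

Here 23 tokens become 3 rows of 8 (one BOS plus 7 text tokens each), and the last 2 tokens are dropped.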



## Create Model


In [123]:
model = DemoTransformer(model_cfg)
model.cuda()


DemoTransformer(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (blocks): ModuleList(
    (0-2): 3 x TransformerBlock(
      (ln1): LayerNorm()
      (attn): Attention()
      (ln2): LayerNorm()
      (mlp): MLP()
    )
  )
  (ln_final): LayerNorm()
  (unembed): Unembed()
)

## Create Optimizer
We use AdamW - it's a pretty standard optimizer.

In [124]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

## Run Training Loop


In [None]:
losses = []
print("Number of batches:", len(data_loader))
step = 0  # Global step, counted across epochs (c resets every epoch)
for epoch in range(num_epochs):
    for c, batch in tqdm.tqdm(enumerate(data_loader)):
        tokens = batch['tokens'].cuda()
        logits = model(tokens)
        loss = lm_cross_entropy_loss(logits, tokens)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        step += 1
        if c % log_every == 0:
            print(f"Step: {c}, Loss: {loss.item():.4f}")
        if step >= max_steps:
            break
    # Leave the epoch loop too once the step budget is spent
    if step >= max_steps:
        break


Number of batches: 148


0it [00:00, ?it/s]

Step: 0, Loss: 10.9331
Step: 10, Loss: 10.0648
Step: 20, Loss: 9.8268
Step: 30, Loss: 9.5950
Step: 40, Loss: 9.3640
Step: 50, Loss: 9.2556
Step: 60, Loss: 9.0656
Step: 70, Loss: 8.9455
Step: 80, Loss: 8.6326
Step: 90, Loss: 8.4397
Step: 100, Loss: 8.2330
Step: 110, Loss: 8.0154
Step: 120, Loss: 7.7285
Step: 130, Loss: 7.8309
Step: 140, Loss: 7.5608


0it [00:00, ?it/s]

Step: 0, Loss: 7.4956
Step: 10, Loss: 7.3880
Step: 20, Loss: 7.2967
Step: 30, Loss: 7.2124
Step: 40, Loss: 6.8932
Step: 50, Loss: 6.8733
Step: 60, Loss: 6.8067
Step: 70, Loss: 6.7526
Step: 80, Loss: 6.6984
Step: 90, Loss: 6.7597
Step: 100, Loss: 6.5538
Step: 110, Loss: 6.4755
Step: 120, Loss: 6.3413
Step: 130, Loss: 6.2973
Step: 140, Loss: 6.3528


0it [00:00, ?it/s]

Step: 0, Loss: 6.2400
Step: 10, Loss: 6.3285
Step: 20, Loss: 6.3751
Step: 30, Loss: 6.2406
Step: 40, Loss: 6.1282
Step: 50, Loss: 6.1139
Step: 60, Loss: 6.2307
Step: 70, Loss: 6.0927
Step: 80, Loss: 6.2064
Step: 90, Loss: 6.0662
Step: 100, Loss: 5.9955
Step: 110, Loss: 6.0035
Step: 120, Loss: 5.8364
Step: 130, Loss: 5.9294
Step: 140, Loss: 5.8781


0it [00:00, ?it/s]

Step: 0, Loss: 5.9056
Step: 10, Loss: 5.7151
Step: 20, Loss: 5.8164
Step: 30, Loss: 5.9611
Step: 40, Loss: 5.7383
Step: 50, Loss: 5.9779
Step: 60, Loss: 5.5817
Step: 70, Loss: 5.8232
Step: 80, Loss: 5.7227
Step: 90, Loss: 5.6342
Step: 100, Loss: 5.5785
Step: 110, Loss: 5.4805
Step: 120, Loss: 5.8367
Step: 130, Loss: 5.8380
Step: 140, Loss: 5.6136


0it [00:00, ?it/s]

Step: 0, Loss: 5.6811
Step: 10, Loss: 5.6379
Step: 20, Loss: 5.4770
Step: 30, Loss: 5.4628
Step: 40, Loss: 5.6933
Step: 50, Loss: 5.4660
Step: 60, Loss: 5.3875
Step: 70, Loss: 5.1718
Step: 80, Loss: 5.6830
Step: 90, Loss: 5.3306
Step: 100, Loss: 5.1902
Step: 110, Loss: 5.3985
Step: 120, Loss: 5.5566
Step: 130, Loss: 5.3821
Step: 140, Loss: 5.5546


0it [00:00, ?it/s]

Step: 0, Loss: 5.5615
Step: 10, Loss: 5.3237
Step: 20, Loss: 5.4454
Step: 30, Loss: 5.3178
Step: 40, Loss: 5.0977
Step: 50, Loss: 5.5331
Step: 60, Loss: 5.5749
Step: 70, Loss: 4.9876
Step: 80, Loss: 5.3156
Step: 90, Loss: 5.1374
Step: 100, Loss: 5.4642
Step: 110, Loss: 5.1612
Step: 120, Loss: 5.2166
Step: 130, Loss: 5.0869
Step: 140, Loss: 5.4656


0it [00:00, ?it/s]

Step: 0, Loss: 5.1173
Step: 10, Loss: 5.3066
Step: 20, Loss: 4.9149
Step: 30, Loss: 5.3760
Step: 40, Loss: 4.9681
Step: 50, Loss: 5.2007
Step: 60, Loss: 5.1552
Step: 70, Loss: 4.8917
Step: 80, Loss: 5.0418
Step: 90, Loss: 4.7804
Step: 100, Loss: 5.0318
Step: 110, Loss: 5.3188
Step: 120, Loss: 5.0634
Step: 130, Loss: 5.0660
Step: 140, Loss: 4.6253


0it [00:00, ?it/s]

Step: 0, Loss: 4.8941
Step: 10, Loss: 4.6790
Step: 20, Loss: 5.0723
Step: 30, Loss: 4.6147
Step: 40, Loss: 4.7250
Step: 50, Loss: 4.9024
Step: 60, Loss: 5.0121
Step: 70, Loss: 5.1855
Step: 80, Loss: 5.1488
Step: 90, Loss: 4.8116
Step: 100, Loss: 5.0519
Step: 110, Loss: 4.7559
Step: 120, Loss: 4.9390
Step: 130, Loss: 4.9643
Step: 140, Loss: 4.9182


0it [00:00, ?it/s]

Step: 0, Loss: 4.8988
Step: 10, Loss: 4.7258
Step: 20, Loss: 4.4780
Step: 30, Loss: 4.7964
Step: 40, Loss: 4.3943
Step: 50, Loss: 4.9085
Step: 60, Loss: 4.6214
Step: 70, Loss: 4.5718
Step: 80, Loss: 4.8868
Step: 90, Loss: 4.5740
Step: 100, Loss: 5.0130
Step: 110, Loss: 4.8741
Step: 120, Loss: 4.6901
Step: 130, Loss: 4.9313
Step: 140, Loss: 4.8378


0it [00:00, ?it/s]

Step: 0, Loss: 4.3265
Step: 10, Loss: 4.5170
Step: 20, Loss: 4.5644
Step: 30, Loss: 4.9843
Step: 40, Loss: 4.8853
Step: 50, Loss: 4.7237
Step: 60, Loss: 4.5810
Step: 70, Loss: 4.6778
Step: 80, Loss: 4.8580
Step: 90, Loss: 4.6259
Step: 100, Loss: 4.7300
Step: 110, Loss: 4.7920
Step: 120, Loss: 4.8680
Step: 130, Loss: 4.8048
Step: 140, Loss: 4.4687


0it [00:00, ?it/s]

Step: 0, Loss: 4.3583
Step: 10, Loss: 4.6464
Step: 20, Loss: 4.5423
Step: 30, Loss: 4.9320
Step: 40, Loss: 4.7184
Step: 50, Loss: 4.5687
Step: 60, Loss: 4.6736
Step: 70, Loss: 4.1721
Step: 80, Loss: 4.4341
Step: 90, Loss: 4.5912
Step: 100, Loss: 4.0395
Step: 110, Loss: 4.8832
Step: 120, Loss: 4.6042
Step: 130, Loss: 4.4703
Step: 140, Loss: 4.3433


0it [00:00, ?it/s]

Step: 0, Loss: 4.2740
Step: 10, Loss: 4.8993
Step: 20, Loss: 4.3305
Step: 30, Loss: 4.7291
Step: 40, Loss: 4.3194
Step: 50, Loss: 4.6996
Step: 60, Loss: 4.5073
Step: 70, Loss: 4.3595
Step: 80, Loss: 4.4785
Step: 90, Loss: 4.6953
Step: 100, Loss: 4.2880
Step: 110, Loss: 4.5828
Step: 120, Loss: 4.3431
Step: 130, Loss: 4.3933
Step: 140, Loss: 4.7973


0it [00:00, ?it/s]

Step: 0, Loss: 4.6257
Step: 10, Loss: 4.4258
Step: 20, Loss: 4.6931
Step: 30, Loss: 4.5513
Step: 40, Loss: 4.6182
Step: 50, Loss: 4.2921
Step: 60, Loss: 4.3267
Step: 70, Loss: 4.6211
Step: 80, Loss: 4.1171
Step: 90, Loss: 4.4641
Step: 100, Loss: 4.1121
Step: 110, Loss: 4.3706
Step: 120, Loss: 4.5935
Step: 130, Loss: 4.4537
Step: 140, Loss: 4.2350


0it [00:00, ?it/s]

Step: 0, Loss: 4.1340
Step: 10, Loss: 4.4407
Step: 20, Loss: 4.3874
Step: 30, Loss: 4.4018
Step: 40, Loss: 4.2686
Step: 50, Loss: 4.3233
Step: 60, Loss: 4.5092
Step: 70, Loss: 4.3527
Step: 80, Loss: 4.2501
Step: 90, Loss: 4.2878
Step: 100, Loss: 4.5235
Step: 110, Loss: 4.5974
Step: 120, Loss: 4.5980
Step: 130, Loss: 4.0548
Step: 140, Loss: 3.9325


0it [00:00, ?it/s]

Step: 0, Loss: 3.8710
Step: 10, Loss: 4.0991
Step: 20, Loss: 4.3235
Step: 30, Loss: 4.0529
Step: 40, Loss: 4.0531
Step: 50, Loss: 4.0673
Step: 60, Loss: 4.4023
Step: 70, Loss: 4.2397
Step: 80, Loss: 4.0488
Step: 90, Loss: 4.3317
Step: 100, Loss: 4.3112
Step: 110, Loss: 4.2394
Step: 120, Loss: 4.2519
Step: 130, Loss: 4.2040
Step: 140, Loss: 4.0090


0it [00:00, ?it/s]

Step: 0, Loss: 4.2175
Step: 10, Loss: 3.9482
Step: 20, Loss: 4.3088
Step: 30, Loss: 4.2222
Step: 40, Loss: 4.0815
Step: 50, Loss: 3.8016
Step: 60, Loss: 4.4297
Step: 70, Loss: 4.4389
Step: 80, Loss: 4.0592
Step: 90, Loss: 4.0358
Step: 100, Loss: 4.2831
Step: 110, Loss: 4.3786
Step: 120, Loss: 4.1594
Step: 130, Loss: 3.8353
Step: 140, Loss: 4.1217


0it [00:00, ?it/s]

Step: 0, Loss: 3.8400
Step: 10, Loss: 3.9125
Step: 20, Loss: 4.0739
Step: 30, Loss: 3.8722
Step: 40, Loss: 4.2670
Step: 50, Loss: 3.9363
Step: 60, Loss: 3.8256
Step: 70, Loss: 4.1755
Step: 80, Loss: 4.2292
Step: 90, Loss: 4.1512
Step: 100, Loss: 4.1424
Step: 110, Loss: 4.2184
Step: 120, Loss: 4.2875
Step: 130, Loss: 4.0110
Step: 140, Loss: 4.0565


0it [00:00, ?it/s]

Step: 0, Loss: 4.1564
Step: 10, Loss: 3.9084
Step: 20, Loss: 3.8059
Step: 30, Loss: 4.0795
Step: 40, Loss: 3.7393
Step: 50, Loss: 3.9842
Step: 60, Loss: 3.8408
Step: 70, Loss: 4.0602
Step: 80, Loss: 4.1113
Step: 90, Loss: 4.0745
Step: 100, Loss: 3.9362
Step: 110, Loss: 3.8577
Step: 120, Loss: 4.1404
Step: 130, Loss: 3.9907
Step: 140, Loss: 4.2608


0it [00:00, ?it/s]

Step: 0, Loss: 3.8911
Step: 10, Loss: 4.1641
Step: 20, Loss: 3.6823
Step: 30, Loss: 4.0765
Step: 40, Loss: 3.9174
Step: 50, Loss: 4.0734
Step: 60, Loss: 3.9545
Step: 70, Loss: 3.6580
Step: 80, Loss: 3.4810
Step: 90, Loss: 3.7132
Step: 100, Loss: 3.8730
Step: 110, Loss: 4.0224
Step: 120, Loss: 4.0528
Step: 130, Loss: 4.0979
Step: 140, Loss: 4.2543


0it [00:00, ?it/s]

Step: 0, Loss: 3.6955
Step: 10, Loss: 3.9477
Step: 20, Loss: 3.7341
Step: 30, Loss: 4.0259
Step: 40, Loss: 4.1019
Step: 50, Loss: 3.7810
Step: 60, Loss: 3.5676
Step: 70, Loss: 3.5524
Step: 80, Loss: 3.8563
Step: 90, Loss: 3.6044
Step: 100, Loss: 3.4697
Step: 110, Loss: 3.9592
Step: 120, Loss: 3.6680
Step: 130, Loss: 4.0443
Step: 140, Loss: 3.8752


0it [00:00, ?it/s]

Step: 0, Loss: 3.7834
Step: 10, Loss: 3.9438
Step: 20, Loss: 3.6874
Step: 30, Loss: 3.7645
Step: 40, Loss: 3.4389
Step: 50, Loss: 3.7330
Step: 60, Loss: 3.9500
Step: 70, Loss: 3.8143
Step: 80, Loss: 3.8919
Step: 90, Loss: 3.6257
Step: 100, Loss: 3.5796
Step: 110, Loss: 3.7936
Step: 120, Loss: 4.1761
Step: 130, Loss: 3.4801
Step: 140, Loss: 3.9975


0it [00:00, ?it/s]

Step: 0, Loss: 3.5916
Step: 10, Loss: 3.5239
Step: 20, Loss: 3.4972
Step: 30, Loss: 3.8054
Step: 40, Loss: 3.8545
Step: 50, Loss: 3.7218
Step: 60, Loss: 3.9201
Step: 70, Loss: 3.6513
Step: 80, Loss: 3.6435
Step: 90, Loss: 3.7092
Step: 100, Loss: 3.6868
Step: 110, Loss: 3.8769
Step: 120, Loss: 3.4052
Step: 130, Loss: 3.4423
Step: 140, Loss: 3.5849


0it [00:00, ?it/s]

Step: 0, Loss: 3.7469
Step: 10, Loss: 3.6334
Step: 20, Loss: 3.7202
Step: 30, Loss: 3.6416
Step: 40, Loss: 3.5397
Step: 50, Loss: 3.5754
Step: 60, Loss: 3.7671
Step: 70, Loss: 3.5039
Step: 80, Loss: 3.6966
Step: 90, Loss: 3.3783
Step: 100, Loss: 3.4645
Step: 110, Loss: 3.6443
Step: 120, Loss: 3.4713
Step: 130, Loss: 3.5939
Step: 140, Loss: 3.8173


0it [00:00, ?it/s]

Step: 0, Loss: 3.4760
Step: 10, Loss: 3.4649
Step: 20, Loss: 3.5541
Step: 30, Loss: 3.5557
Step: 40, Loss: 3.6126
Step: 50, Loss: 3.3746
Step: 60, Loss: 3.5947
Step: 70, Loss: 3.6837
Step: 80, Loss: 3.5729
Step: 90, Loss: 3.6535
Step: 100, Loss: 3.5309
Step: 110, Loss: 3.4503
Step: 120, Loss: 3.3545
Step: 130, Loss: 3.5653
Step: 140, Loss: 3.2015


0it [00:00, ?it/s]

Step: 0, Loss: 3.6311
Step: 10, Loss: 3.4216
Step: 20, Loss: 3.2447
Step: 30, Loss: 3.4019
Step: 40, Loss: 3.3283
Step: 50, Loss: 3.3543
Step: 60, Loss: 3.4117
Step: 70, Loss: 3.4176
Step: 80, Loss: 3.5583
Step: 90, Loss: 3.0313
Step: 100, Loss: 3.4684
Step: 110, Loss: 3.3468
Step: 120, Loss: 3.4799
Step: 130, Loss: 3.4105
Step: 140, Loss: 3.4692


0it [00:00, ?it/s]

Step: 0, Loss: 3.3556
Step: 10, Loss: 3.0140
Step: 20, Loss: 3.1694
Step: 30, Loss: 3.3748
Step: 40, Loss: 3.0811
Step: 50, Loss: 3.1183
Step: 60, Loss: 3.3458
Step: 70, Loss: 3.5111
Step: 80, Loss: 3.2651
Step: 90, Loss: 3.4628
Step: 100, Loss: 3.2146
Step: 110, Loss: 3.1265
Step: 120, Loss: 3.4155
Step: 130, Loss: 3.4237
Step: 140, Loss: 3.3819


0it [00:00, ?it/s]

Step: 0, Loss: 3.2521
Step: 10, Loss: 3.0159
Step: 20, Loss: 3.2913
Step: 30, Loss: 3.2020
Step: 40, Loss: 3.2333
Step: 50, Loss: 3.2903
Step: 60, Loss: 3.4230
Step: 70, Loss: 3.3538
Step: 80, Loss: 3.1663
Step: 90, Loss: 2.9585
Step: 100, Loss: 3.2977
Step: 110, Loss: 3.2847
Step: 120, Loss: 3.2186
Step: 130, Loss: 3.1669
Step: 140, Loss: 3.4236


0it [00:00, ?it/s]

Step: 0, Loss: 2.9528
Step: 10, Loss: 3.1533
Step: 20, Loss: 3.2053
Step: 30, Loss: 3.1029
Step: 40, Loss: 3.2139
Step: 50, Loss: 3.0895
Step: 60, Loss: 2.9903
Step: 70, Loss: 3.2436
Step: 80, Loss: 3.0762
Step: 90, Loss: 3.1509
Step: 100, Loss: 3.2132
Step: 110, Loss: 2.9950
Step: 120, Loss: 3.0993
Step: 130, Loss: 3.1612
Step: 140, Loss: 3.0550


0it [00:00, ?it/s]

Step: 0, Loss: 2.8592
Step: 10, Loss: 3.2263
Step: 20, Loss: 3.2720
Step: 30, Loss: 3.0938
Step: 40, Loss: 3.1071
Step: 50, Loss: 3.0461
Step: 60, Loss: 2.9278
Step: 70, Loss: 2.9601
Step: 80, Loss: 3.2847
Step: 90, Loss: 2.9977
Step: 100, Loss: 3.0893
Step: 110, Loss: 3.3217
Step: 120, Loss: 3.3343
Step: 130, Loss: 3.2479
Step: 140, Loss: 3.1027


0it [00:00, ?it/s]

Step: 0, Loss: 3.1652
Step: 10, Loss: 2.9777
Step: 20, Loss: 2.8345
Step: 30, Loss: 3.2339
Step: 40, Loss: 3.0133
Step: 50, Loss: 3.0161
Step: 60, Loss: 2.9404
Step: 70, Loss: 2.7942
Step: 80, Loss: 2.9026
Step: 90, Loss: 2.9175
Step: 100, Loss: 3.2539
Step: 110, Loss: 3.0362
Step: 120, Loss: 3.1648
Step: 130, Loss: 2.8517
Step: 140, Loss: 3.0577


0it [00:00, ?it/s]

Step: 0, Loss: 2.7135
Step: 10, Loss: 2.9484
Step: 20, Loss: 2.7955
Step: 30, Loss: 2.8437
Step: 40, Loss: 3.2259
Step: 50, Loss: 2.9219
Step: 60, Loss: 3.0894
Step: 70, Loss: 2.7886
Step: 80, Loss: 3.0397
Step: 90, Loss: 2.8039
Step: 100, Loss: 2.9619
Step: 110, Loss: 2.7729
Step: 120, Loss: 2.9673
Step: 130, Loss: 2.8237
Step: 140, Loss: 2.9308


0it [00:00, ?it/s]

Step: 0, Loss: 2.7661
Step: 10, Loss: 2.5257
Step: 20, Loss: 2.8739
Step: 30, Loss: 2.8802
Step: 40, Loss: 2.7015
Step: 50, Loss: 2.8644
Step: 60, Loss: 2.9339
Step: 70, Loss: 2.9002
Step: 80, Loss: 2.8901
Step: 90, Loss: 2.8977
Step: 100, Loss: 2.8134
Step: 110, Loss: 2.9702
Step: 120, Loss: 2.8413
Step: 130, Loss: 2.9874
Step: 140, Loss: 2.7692


0it [00:00, ?it/s]

Step: 0, Loss: 2.6812
Step: 10, Loss: 2.5654
Step: 20, Loss: 2.7056
Step: 30, Loss: 2.7169
Step: 40, Loss: 2.4593
Step: 50, Loss: 2.8234
Step: 60, Loss: 2.6990
Step: 70, Loss: 2.8091
Step: 80, Loss: 2.7047
Step: 90, Loss: 2.6491
Step: 100, Loss: 2.7962
Step: 110, Loss: 2.6938
Step: 120, Loss: 2.9004
Step: 130, Loss: 2.7333
Step: 140, Loss: 2.9086


0it [00:00, ?it/s]

Step: 0, Loss: 2.8894
Step: 10, Loss: 2.8164
Step: 20, Loss: 2.6583
Step: 30, Loss: 2.6740
Step: 40, Loss: 2.6350
Step: 50, Loss: 2.6484
Step: 60, Loss: 2.6558
Step: 70, Loss: 2.7537
Step: 80, Loss: 2.6692
Step: 90, Loss: 2.5612
Step: 100, Loss: 2.6524
Step: 110, Loss: 2.6477
Step: 120, Loss: 2.6395
Step: 130, Loss: 2.7219
Step: 140, Loss: 2.9520


0it [00:00, ?it/s]

Step: 0, Loss: 2.6576
Step: 10, Loss: 2.4998
Step: 20, Loss: 2.5084
Step: 30, Loss: 2.4540
Step: 40, Loss: 2.7400
Step: 50, Loss: 2.4588
Step: 60, Loss: 2.6893
Step: 70, Loss: 2.6450
Step: 80, Loss: 2.6781
Step: 90, Loss: 2.5674
Step: 100, Loss: 2.6765
Step: 110, Loss: 2.7687
Step: 120, Loss: 2.5535
Step: 130, Loss: 2.4982
Step: 140, Loss: 2.6859


0it [00:00, ?it/s]

Step: 0, Loss: 2.3949
Step: 10, Loss: 2.6874
Step: 20, Loss: 2.4527
Step: 30, Loss: 2.4011
Step: 40, Loss: 2.5824
Step: 50, Loss: 2.5484
Step: 60, Loss: 2.4467
Step: 70, Loss: 2.6355
Step: 80, Loss: 2.5572
Step: 90, Loss: 2.5402
Step: 100, Loss: 2.6631
Step: 110, Loss: 2.5950
Step: 120, Loss: 2.5587
Step: 130, Loss: 2.4257
Step: 140, Loss: 2.5852


0it [00:00, ?it/s]

Step: 0, Loss: 2.5544
Step: 10, Loss: 2.4150
Step: 20, Loss: 2.5565
Step: 30, Loss: 2.4988
Step: 40, Loss: 2.5241
Step: 50, Loss: 2.5709
Step: 60, Loss: 2.4279
Step: 70, Loss: 2.3031
Step: 80, Loss: 2.6480
Step: 90, Loss: 2.6171
Step: 100, Loss: 2.5674
Step: 110, Loss: 2.4289
Step: 120, Loss: 2.3102
Step: 130, Loss: 2.5908
Step: 140, Loss: 2.4298


0it [00:00, ?it/s]

Step: 0, Loss: 2.5905
Step: 10, Loss: 2.5343
Step: 20, Loss: 2.4798
Step: 30, Loss: 2.4823
Step: 40, Loss: 2.4636
Step: 50, Loss: 2.4269
Step: 60, Loss: 2.3738
Step: 70, Loss: 2.3263
Step: 80, Loss: 2.6579
Step: 90, Loss: 2.2233
Step: 100, Loss: 2.2988
Step: 110, Loss: 2.3757
Step: 120, Loss: 2.4775
Step: 130, Loss: 2.5057
Step: 140, Loss: 2.5127


0it [00:00, ?it/s]

Step: 0, Loss: 2.3432
Step: 10, Loss: 2.2751
Step: 20, Loss: 2.2718
Step: 30, Loss: 2.3458
Step: 40, Loss: 2.3934
Step: 50, Loss: 2.3707
Step: 60, Loss: 2.3483
Step: 70, Loss: 2.0853
Step: 80, Loss: 2.3151
Step: 90, Loss: 2.3163
Step: 100, Loss: 2.3123
Step: 110, Loss: 2.3837
Step: 120, Loss: 2.1581
Step: 130, Loss: 2.1501
Step: 140, Loss: 2.5404


0it [00:00, ?it/s]

Step: 0, Loss: 1.9856
Step: 10, Loss: 2.2011
Step: 20, Loss: 2.2204
Step: 30, Loss: 2.3200
Step: 40, Loss: 2.1557
Step: 50, Loss: 2.3022
Step: 60, Loss: 2.0962
Step: 70, Loss: 2.1590
Step: 80, Loss: 2.2995
Step: 90, Loss: 2.1808
Step: 100, Loss: 2.3172
Step: 110, Loss: 2.3098
Step: 120, Loss: 2.2010
Step: 130, Loss: 2.0387
Step: 140, Loss: 2.3067


0it [00:00, ?it/s]

Step: 0, Loss: 2.1908
Step: 10, Loss: 2.1691
Step: 20, Loss: 2.2805
Step: 30, Loss: 2.2914
Step: 40, Loss: 2.0766
Step: 50, Loss: 2.1691
Step: 60, Loss: 2.3041
Step: 70, Loss: 2.1843
Step: 80, Loss: 2.0233
Step: 90, Loss: 2.2607
Step: 100, Loss: 2.2768
Step: 110, Loss: 2.2594
Step: 120, Loss: 2.2011
Step: 130, Loss: 2.1119
Step: 140, Loss: 2.2027


0it [00:00, ?it/s]

Step: 0, Loss: 2.0626
Step: 10, Loss: 2.0264
Step: 20, Loss: 2.1690
Step: 30, Loss: 2.1260
Step: 40, Loss: 2.1827
Step: 50, Loss: 2.2124
Step: 60, Loss: 1.8843
Step: 70, Loss: 2.1753
Step: 80, Loss: 2.1223
Step: 90, Loss: 2.1795
Step: 100, Loss: 2.1671
Step: 110, Loss: 2.1457
Step: 120, Loss: 2.2186
Step: 130, Loss: 2.0556
Step: 140, Loss: 2.2724


0it [00:00, ?it/s]

Step: 0, Loss: 1.9284
Step: 10, Loss: 1.9607
Step: 20, Loss: 2.0503
Step: 30, Loss: 2.1245
Step: 40, Loss: 1.9665
Step: 50, Loss: 1.7964
Step: 60, Loss: 2.0689
Step: 70, Loss: 2.0100
Step: 80, Loss: 2.1124
Step: 90, Loss: 2.0977
Step: 100, Loss: 1.8899
Step: 110, Loss: 2.0500
Step: 120, Loss: 2.0999
Step: 130, Loss: 1.9441
Step: 140, Loss: 2.0645


0it [00:00, ?it/s]

Step: 0, Loss: 2.1822
Step: 10, Loss: 1.8121
Step: 20, Loss: 1.8496
Step: 30, Loss: 1.9357
Step: 40, Loss: 1.6932
Step: 50, Loss: 2.1736
Step: 60, Loss: 1.9923
Step: 70, Loss: 1.9519
Step: 80, Loss: 1.8198
Step: 90, Loss: 1.9826
Step: 100, Loss: 1.8956
Step: 110, Loss: 1.9244
Step: 120, Loss: 2.0410
Step: 130, Loss: 1.7632
Step: 140, Loss: 1.9601


0it [00:00, ?it/s]

Step: 0, Loss: 1.9604
Step: 10, Loss: 1.9232
Step: 20, Loss: 1.7735
Step: 30, Loss: 1.7684
Step: 40, Loss: 1.7311
Step: 50, Loss: 1.8615
Step: 60, Loss: 1.8598
Step: 70, Loss: 1.8186
Step: 80, Loss: 1.8669
Step: 90, Loss: 1.9732
Step: 100, Loss: 1.8596
Step: 110, Loss: 1.9531
Step: 120, Loss: 1.9799
Step: 130, Loss: 1.9331
Step: 140, Loss: 1.9386


0it [00:00, ?it/s]

Step: 0, Loss: 1.9793
Step: 10, Loss: 1.7794
Step: 20, Loss: 1.7338
Step: 30, Loss: 1.8917
Step: 40, Loss: 1.7068
Step: 50, Loss: 1.8008
Step: 60, Loss: 1.6812
Step: 70, Loss: 1.8355
Step: 80, Loss: 1.9616
Step: 90, Loss: 1.8502
Step: 100, Loss: 1.7555
Step: 110, Loss: 1.8442
Step: 120, Loss: 1.9289
Step: 130, Loss: 1.8523
Step: 140, Loss: 1.8768


0it [00:00, ?it/s]

Step: 0, Loss: 1.7636
Step: 10, Loss: 1.7675
Step: 20, Loss: 1.5925
Step: 30, Loss: 1.8882
Step: 40, Loss: 1.8411
Step: 50, Loss: 1.7974
Step: 60, Loss: 1.8404
Step: 70, Loss: 1.6602
Step: 80, Loss: 1.8020
Step: 90, Loss: 1.9373
Step: 100, Loss: 1.7907
Step: 110, Loss: 1.8127
Step: 120, Loss: 1.7851
Step: 130, Loss: 1.6895
Step: 140, Loss: 1.7355


0it [00:00, ?it/s]

Step: 0, Loss: 1.7361
Step: 10, Loss: 1.6351
Step: 20, Loss: 1.5840
Step: 30, Loss: 1.8076
Step: 40, Loss: 1.5557
Step: 50, Loss: 1.7869
Step: 60, Loss: 1.8352
Step: 70, Loss: 1.6861
Step: 80, Loss: 1.8269
Step: 90, Loss: 1.6506
Step: 100, Loss: 1.6555
Step: 110, Loss: 1.7664
Step: 120, Loss: 1.7752
Step: 130, Loss: 1.6981
Step: 140, Loss: 1.7182


0it [00:00, ?it/s]

Step: 0, Loss: 1.5890
Step: 10, Loss: 1.6975
Step: 20, Loss: 1.6778
Step: 30, Loss: 1.6078
Step: 40, Loss: 1.4881
Step: 50, Loss: 1.7909
Step: 60, Loss: 1.5863
Step: 70, Loss: 1.6565
Step: 80, Loss: 1.7196
Step: 90, Loss: 1.3694
Step: 100, Loss: 1.6139
Step: 110, Loss: 1.5794
Step: 120, Loss: 1.6118
Step: 130, Loss: 1.6772
Step: 140, Loss: 1.7176


0it [00:00, ?it/s]

Step: 0, Loss: 1.6114
Step: 10, Loss: 1.5469
Step: 20, Loss: 1.4389
Step: 30, Loss: 1.6062
Step: 40, Loss: 1.6374
Step: 50, Loss: 1.5568
Step: 60, Loss: 1.5932
Step: 70, Loss: 1.6548
Step: 80, Loss: 1.6449
Step: 90, Loss: 1.6799
Step: 100, Loss: 1.5823
Step: 110, Loss: 1.5706
Step: 120, Loss: 1.6666
Step: 130, Loss: 1.5929
Step: 140, Loss: 1.6107


0it [00:00, ?it/s]

Step: 0, Loss: 1.5225
Step: 10, Loss: 1.4569
Step: 20, Loss: 1.6577
Step: 30, Loss: 1.5229
Step: 40, Loss: 1.5540
Step: 50, Loss: 1.5567
Step: 60, Loss: 1.6467
Step: 70, Loss: 1.4743
Step: 80, Loss: 1.5077
Step: 90, Loss: 1.5924
Step: 100, Loss: 1.4128
Step: 110, Loss: 1.4337
Step: 120, Loss: 1.5404
Step: 130, Loss: 1.5569
Step: 140, Loss: 1.3859


0it [00:00, ?it/s]

Step: 0, Loss: 1.4674
Step: 10, Loss: 1.4241
Step: 20, Loss: 1.5823
Step: 30, Loss: 1.4721
Step: 40, Loss: 1.4661
Step: 50, Loss: 1.4768
Step: 60, Loss: 1.5250
Step: 70, Loss: 1.4206
Step: 80, Loss: 1.4794
Step: 90, Loss: 1.4964
Step: 100, Loss: 1.5773
Step: 110, Loss: 1.5306
Step: 120, Loss: 1.5141
Step: 130, Loss: 1.4201
Step: 140, Loss: 1.5214


0it [00:00, ?it/s]

Step: 0, Loss: 1.3699
Step: 10, Loss: 1.3873
Step: 20, Loss: 1.2627
Step: 30, Loss: 1.3169
Step: 40, Loss: 1.5472
Step: 50, Loss: 1.4294
Step: 60, Loss: 1.5002
Step: 70, Loss: 1.4094
Step: 80, Loss: 1.3819
Step: 90, Loss: 1.4345
Step: 100, Loss: 1.3202
Step: 110, Loss: 1.4343
Step: 120, Loss: 1.2895
Step: 130, Loss: 1.4364
Step: 140, Loss: 1.5582


0it [00:00, ?it/s]

Step: 0, Loss: 1.3457
Step: 10, Loss: 1.4073
Step: 20, Loss: 1.3551
Step: 30, Loss: 1.3206
Step: 40, Loss: 1.3122
Step: 50, Loss: 1.2034
Step: 60, Loss: 1.2842
Step: 70, Loss: 1.4975
Step: 80, Loss: 1.4064
Step: 90, Loss: 1.3726
Step: 100, Loss: 1.2943
Step: 110, Loss: 1.4066
Step: 120, Loss: 1.4068
Step: 130, Loss: 1.4136
Step: 140, Loss: 1.3848


0it [00:00, ?it/s]

Step: 0, Loss: 1.1985
Step: 10, Loss: 1.3735
Step: 20, Loss: 1.3595
Step: 30, Loss: 1.2201
Step: 40, Loss: 1.2629
Step: 50, Loss: 1.3572
Step: 60, Loss: 1.2753
Step: 70, Loss: 1.2505
Step: 80, Loss: 1.3109
Step: 90, Loss: 1.3383
Step: 100, Loss: 1.3573
Step: 110, Loss: 1.2807
Step: 120, Loss: 1.3091
Step: 130, Loss: 1.3280
Step: 140, Loss: 1.2274


0it [00:00, ?it/s]

Step: 0, Loss: 1.2527
Step: 10, Loss: 1.2322
Step: 20, Loss: 1.1714
Step: 30, Loss: 1.1353
Step: 40, Loss: 1.2080
Step: 50, Loss: 1.2848
Step: 60, Loss: 1.1466
Step: 70, Loss: 1.1638
Step: 80, Loss: 1.3105
Step: 90, Loss: 1.2424
Step: 100, Loss: 1.2251
Step: 110, Loss: 1.2345
Step: 120, Loss: 1.2107
Step: 130, Loss: 1.2034
Step: 140, Loss: 1.1396


0it [00:00, ?it/s]

Step: 0, Loss: 1.1907
Step: 10, Loss: 1.2087
Step: 20, Loss: 0.9565
Step: 30, Loss: 1.1368
Step: 40, Loss: 1.1071
Step: 50, Loss: 1.1381
Step: 60, Loss: 1.2093
Step: 70, Loss: 1.2061
Step: 80, Loss: 1.1953
Step: 90, Loss: 1.0770
Step: 100, Loss: 1.1837
Step: 110, Loss: 1.2099
Step: 120, Loss: 1.1444
Step: 130, Loss: 1.3061
Step: 140, Loss: 1.2966


0it [00:00, ?it/s]

Step: 0, Loss: 1.0292
Step: 10, Loss: 1.0416
Step: 20, Loss: 1.1236
Step: 30, Loss: 1.2571
Step: 40, Loss: 1.1361
Step: 50, Loss: 1.1442
Step: 60, Loss: 1.1378
Step: 70, Loss: 1.0692
Step: 80, Loss: 1.1195
Step: 90, Loss: 1.0926
Step: 100, Loss: 1.0972
Step: 110, Loss: 1.1294
Step: 120, Loss: 1.2898
Step: 130, Loss: 1.1434
Step: 140, Loss: 1.1222


0it [00:00, ?it/s]

Step: 0, Loss: 1.0451
Step: 10, Loss: 0.9519
Step: 20, Loss: 1.0343
Step: 30, Loss: 1.1005
Step: 40, Loss: 0.9805
Step: 50, Loss: 1.1644
Step: 60, Loss: 1.2097
Step: 70, Loss: 1.0202
Step: 80, Loss: 1.0778
Step: 90, Loss: 1.0822
Step: 100, Loss: 1.0926
Step: 110, Loss: 1.1063
Step: 120, Loss: 1.0200
Step: 130, Loss: 1.0866
Step: 140, Loss: 1.0853


0it [00:00, ?it/s]

Step: 0, Loss: 0.9866
Step: 10, Loss: 0.9448
Step: 20, Loss: 0.9438
Step: 30, Loss: 1.0206
Step: 40, Loss: 1.0706
Step: 50, Loss: 1.0127
Step: 60, Loss: 1.0170
Step: 70, Loss: 1.0464
Step: 80, Loss: 1.0276
Step: 90, Loss: 0.9779
Step: 100, Loss: 0.9377
Step: 110, Loss: 1.0246
Step: 120, Loss: 1.0739
Step: 130, Loss: 1.0724
Step: 140, Loss: 1.0503


0it [00:00, ?it/s]

Step: 0, Loss: 0.9674
Step: 10, Loss: 0.9192
Step: 20, Loss: 0.9732
Step: 30, Loss: 0.9952
Step: 40, Loss: 0.9448
Step: 50, Loss: 1.0194
Step: 60, Loss: 1.0009
Step: 70, Loss: 0.9219
Step: 80, Loss: 1.0729
Step: 90, Loss: 0.9777
Step: 100, Loss: 0.9776
Step: 110, Loss: 0.8911
Step: 120, Loss: 0.9786
Step: 130, Loss: 0.9019
Step: 140, Loss: 0.8790


0it [00:00, ?it/s]

Step: 0, Loss: 0.8200
Step: 10, Loss: 0.8884
Step: 20, Loss: 0.8891
Step: 30, Loss: 0.7985
Step: 40, Loss: 1.0061
Step: 50, Loss: 0.8643
Step: 60, Loss: 0.8415
Step: 70, Loss: 1.0418
Step: 80, Loss: 0.8874
Step: 90, Loss: 1.0282
Step: 100, Loss: 0.8805
Step: 110, Loss: 0.9481
Step: 120, Loss: 0.8584
Step: 130, Loss: 0.8824
Step: 140, Loss: 0.9257


0it [00:00, ?it/s]

Step: 0, Loss: 0.7958
Step: 10, Loss: 0.8456
Step: 20, Loss: 0.8515
Step: 30, Loss: 0.8197
Step: 40, Loss: 0.8656
Step: 50, Loss: 0.8441
Step: 60, Loss: 0.9412
Step: 70, Loss: 0.8505
Step: 80, Loss: 0.9660
Step: 90, Loss: 0.9507
Step: 100, Loss: 0.8972
Step: 110, Loss: 0.8525
Step: 120, Loss: 0.9012
Step: 130, Loss: 0.9507
Step: 140, Loss: 0.9554


0it [00:00, ?it/s]

Step: 0, Loss: 0.7647
Step: 10, Loss: 0.7701
Step: 20, Loss: 0.8522
Step: 30, Loss: 0.7060
Step: 40, Loss: 0.8569
Step: 50, Loss: 0.8469
Step: 60, Loss: 0.8451
Step: 70, Loss: 0.8848
Step: 80, Loss: 0.8519
Step: 90, Loss: 0.9079
Step: 100, Loss: 0.8756
Step: 110, Loss: 0.8372
Step: 120, Loss: 0.9310
Step: 130, Loss: 0.8306
Step: 140, Loss: 0.8080


0it [00:00, ?it/s]

Step: 0, Loss: 0.7905
Step: 10, Loss: 0.7347
Step: 20, Loss: 0.8005
Step: 30, Loss: 0.7801
Step: 40, Loss: 0.8320
Step: 50, Loss: 0.8158
Step: 60, Loss: 0.7428
Step: 70, Loss: 0.8379
Step: 80, Loss: 0.8443
Step: 90, Loss: 0.7897
Step: 100, Loss: 0.7955
Step: 110, Loss: 0.7587
Step: 120, Loss: 0.7842
Step: 130, Loss: 0.8023
Step: 140, Loss: 0.7704


0it [00:00, ?it/s]

Step: 0, Loss: 0.8009
Step: 10, Loss: 0.7743
Step: 20, Loss: 0.7792
Step: 30, Loss: 0.7761
Step: 40, Loss: 0.6733
Step: 50, Loss: 0.7448
Step: 60, Loss: 0.7484
Step: 70, Loss: 0.7150
Step: 80, Loss: 0.7828
Step: 90, Loss: 0.6907
Step: 100, Loss: 0.7738
Step: 110, Loss: 0.7375
Step: 120, Loss: 0.7148
Step: 130, Loss: 0.7307
Step: 140, Loss: 0.8028


0it [00:00, ?it/s]

Step: 0, Loss: 0.6920
Step: 10, Loss: 0.6176
Step: 20, Loss: 0.7073
Step: 30, Loss: 0.6492
Step: 40, Loss: 0.6476
Step: 50, Loss: 0.6718
Step: 60, Loss: 0.7966
Step: 70, Loss: 0.7241
Step: 80, Loss: 0.7091
Step: 90, Loss: 0.6318
Step: 100, Loss: 0.7156
Step: 110, Loss: 0.7153
Step: 120, Loss: 0.7059
Step: 130, Loss: 0.6881
Step: 140, Loss: 0.6632


0it [00:00, ?it/s]

Step: 0, Loss: 0.5814
Step: 10, Loss: 0.5861
Step: 20, Loss: 0.6867
Step: 30, Loss: 0.6527
Step: 40, Loss: 0.6868
Step: 50, Loss: 0.6560
Step: 60, Loss: 0.6172
Step: 70, Loss: 0.6041
Step: 80, Loss: 0.6692
Step: 90, Loss: 0.6541
Step: 100, Loss: 0.5714
Step: 110, Loss: 0.6004
Step: 120, Loss: 0.6145
Step: 130, Loss: 0.6717
Step: 140, Loss: 0.6288


0it [00:00, ?it/s]

Step: 0, Loss: 0.5461
Step: 10, Loss: 0.6516
Step: 20, Loss: 0.6339
Step: 30, Loss: 0.5823
Step: 40, Loss: 0.5200
Step: 50, Loss: 0.5772
Step: 60, Loss: 0.5576
Step: 70, Loss: 0.6037
Step: 80, Loss: 0.6422
Step: 90, Loss: 0.5748
Step: 100, Loss: 0.6061
Step: 110, Loss: 0.5985
Step: 120, Loss: 0.6354
Step: 130, Loss: 0.6253
Step: 140, Loss: 0.5966


0it [00:00, ?it/s]

Step: 0, Loss: 0.5697
Step: 10, Loss: 0.5701
Step: 20, Loss: 0.5888
Step: 30, Loss: 0.5699
Step: 40, Loss: 0.5560
Step: 50, Loss: 0.5932
Step: 60, Loss: 0.5877
Step: 70, Loss: 0.6034
Step: 80, Loss: 0.5626
Step: 90, Loss: 0.5914
Step: 100, Loss: 0.5840
Step: 110, Loss: 0.5528
Step: 120, Loss: 0.5705
Step: 130, Loss: 0.5985
Step: 140, Loss: 0.5382


0it [00:00, ?it/s]

Step: 0, Loss: 0.5107
Step: 10, Loss: 0.4888
Step: 20, Loss: 0.5217
Step: 30, Loss: 0.4811
Step: 40, Loss: 0.5105
Step: 50, Loss: 0.4737
Step: 60, Loss: 0.5677
Step: 70, Loss: 0.5222
Step: 80, Loss: 0.5814
Step: 90, Loss: 0.5367
Step: 100, Loss: 0.5355
Step: 110, Loss: 0.5122
Step: 120, Loss: 0.4545
Step: 130, Loss: 0.4972
Step: 140, Loss: 0.4832


0it [00:00, ?it/s]

Step: 0, Loss: 0.5109
Step: 10, Loss: 0.4498
Step: 20, Loss: 0.4800
Step: 30, Loss: 0.4967
Step: 40, Loss: 0.4877
Step: 50, Loss: 0.5085
Step: 60, Loss: 0.5142
Step: 70, Loss: 0.5670
Step: 80, Loss: 0.5243
Step: 90, Loss: 0.4608


We can now plot a loss curve!

In [126]:
px.line(
    y=losses,
    x=np.arange(len(losses)) * (model_cfg.n_ctx * batch_size),
    labels={"y": "Loss", "x": "Tokens"},
    title="Training curve for my tiny demo model!",
)
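
The per-step losses in the log above jump around quite a bit from batch to batch, so a smoothed series can make the trend easier to read. Below is a minimal, standard-library sketch of a trailing moving average, plus the same "tokens seen" x-axis arithmetic used in the plot (step index times `n_ctx` times `batch_size`). The helper names `moving_average` and `tokens_seen`, the window size, and the values `n_ctx=256` and `batch_size=8` are illustrative assumptions, not taken from the notebook's config.

```python
def moving_average(values, window):
    """Trailing moving average; early entries average over what is available."""
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            # Drop the value that just fell out of the window
            running_sum -= values[i - window]
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed


def tokens_seen(num_points, n_ctx, batch_size, log_every=1):
    """Tokens processed before each recorded loss, assuming full batches."""
    return [i * log_every * n_ctx * batch_size for i in range(num_points)]


# Toy input: a handful of losses copied from the log above, not the full run.
example_losses = [3.2132, 2.9950, 3.0993, 3.1612, 3.0550]
smoothed = moving_average(example_losses, window=3)

# n_ctx and batch_size here are placeholder values (assumptions).
x_axis = tokens_seen(len(example_losses), n_ctx=256, batch_size=8)
```

If something like this were used in the notebook, `moving_average(losses, window)` could be passed as `y` to `px.line` in place of the raw `losses` list.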

In [131]:
import torch.nn.functional as F

model.eval()  # switch to inference mode before sampling

# Define the starting text or prompt
starting_text = "Oh why, "

# Tokenize the starting text
input_tokens = reference_gpt2.tokenizer.encode(starting_text, return_tensors='pt').cuda()

# Set the maximum number of tokens to generate
max_length = 500

# Generate text
with torch.no_grad():
    for _ in range(max_length):
        # Forward pass; crop the context so it never exceeds the model's n_ctx positions
        logits = model(input_tokens[:, -model_cfg.n_ctx:])

        # The logits at the final sequence position predict the next token
        next_token_logits = logits[:, -1]

        # Sample from the softmax distribution over the vocabulary
        probabilities = F.softmax(next_token_logits, dim=-1).squeeze()
        predicted_token = torch.multinomial(probabilities, 1)

        # Append the sampled token to the running sequence
        input_tokens = torch.cat((input_tokens, predicted_token.unsqueeze(0)), dim=1)

# Convert the generated tokens back to text
generated_text = reference_gpt2.tokenizer.decode(input_tokens[0], skip_special_tokens=True)
print(generated_text)
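
The loop above samples each token from the unmodified softmax distribution, which corresponds to a temperature of 1.0 with no truncation of the vocabulary. As a rough, framework-free sketch of what a single sampling step does, and of how temperature and top-k filtering could change it, here is a standard-library version. The function names `softmax` and `sample_next_token`, the `temperature` and `top_k` parameters, and the toy logits are all illustrative assumptions; none of them appear in the notebook.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(logits, temperature=1.0, top_k=None, rng=random):
    """Sample a token index, optionally restricted to the top_k largest logits."""
    indices = list(range(len(logits)))
    if top_k is not None:
        # Keep only the k highest-scoring candidates
        indices = sorted(indices, key=lambda i: logits[i], reverse=True)[:top_k]
    probs = softmax([logits[i] for i in indices], temperature)
    return rng.choices(indices, weights=probs, k=1)[0]


toy_logits = [2.0, 1.0, 0.1, -1.0]  # made-up logits for a 4-token vocabulary
rng = random.Random(0)  # seeded so the toy run is repeatable
token = sample_next_token(toy_logits, temperature=0.8, top_k=2, rng=rng)
```

Lower temperatures concentrate probability on the highest-scoring tokens, and `top_k` removes low-probability candidates entirely; both are common ways to trade diversity for coherence when sampling from a small model.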

Oh why,.

PC call not when Clarence, conducted to our friends; you serve heart;
Of our mind? Plantagen put
To robbers leng G this; and call on advantage, and made all
By motion so my welcome;;
Nay, his brother I kill'd,
m home your foe,for: for one little
 have no past every worthy villain willIA:
Tot.

Caius A lives shall noaways
B obunn'd hither leave your:
GLOUCESTER:
I am wise,
I had he be the breath bitterly and mother, much. Come, have got away to prosper,
There I had he tent down
And ten thousand past Edward
That you seem'd; I, hath enemies.

Fell master. The ho!

EDWARD IV:
Why, one word in this princely he speaks about thy master?

She shall alar, most wilt well he shall, my lord? what theyI show'd was spent to foot.

GARENCE:
 news of Gloucester very thing,
A hundred:
And, for troth ensue of Lancaster unto the fair no bereaved and bade the blood room in every stir;
Or, or taking the world.

BISHOP OF YORK:
From child is, I expost thou know,balt, we have done
fore the fitt, po