## Recap – Autoregressive Models

**Chain Rule Decomposition**  
$$
p(x_1, \ldots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_1, \ldots, x_{t-1})
$$

**Maximum Likelihood Estimation (MLE)**  
$$
\widehat{\theta} = \arg \max_{\theta} \sum_{t=1}^T \log p_\theta(x_t \mid x_{1}, x_{2}, \ldots, x_{t-1})
$$

- Factorizes the joint distribution into conditional probabilities.  
- Training objective: maximize log-likelihood of observed sequence.  
- Examples: Language modeling with RNNs, Transformers, masked CNNs.

In [None]:
# PyTorch demo: compute autoregressive log-likelihood for a toy sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy setup
vocab_size = 10
seq_len = 6
batch_size = 2

# Toy "model" = embedding + GRU + linear classifier over the vocabulary
# (a unidirectional GRU only sees the prefix, so the model is autoregressive)
class ToyARModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        h = self.embed(x)
        out, _ = self.rnn(h)
        logits = self.fc(out)  # shape: (batch, seq_len, vocab_size)
        return logits

# Sample data: two sequences
x = torch.randint(0, vocab_size, (batch_size, seq_len))
print("Input sequences:\n", x)

model = ToyARModel(vocab_size)
logits = model(x[:, :-1])  # predict next token given prefix

# Compute autoregressive log-likelihood
log_probs = F.log_softmax(logits, dim=-1)
target = x[:, 1:]  # predict x_t given x_{<t}
nll = F.nll_loss(
    log_probs.reshape(-1, vocab_size),
    target.reshape(-1),
    reduction="mean"
)

print("Negative Log-Likelihood (to minimize):", nll.item())


Input sequences:
 tensor([[4, 7, 0, 4, 3, 6],
        [9, 7, 5, 9, 6, 0]])
Negative Log-Likelihood (to minimize): 2.16798996925354


In [None]:
print(log_probs.shape)

torch.Size([2, 5, 10])


## Autoregressive Models

Autoregressive models factorize the joint distribution into conditionals  
and can be implemented in several ways:

| Model Type                   | Example Architectures                                  | Parallel Training | Parameter Sharing |
|------------------------------|--------------------------------------------------------|-------------------|-------------------|
| **Recurrent**                | RNN, LSTM                                              | ❌ No             | ✅ Yes            |
| **Masked MLP**               | MADE (Masked Autoencoder for Distribution Estimation)  | ✅ Yes            | ❌ No             |
| **Masked CNN / Transformer** | PixelCNN, Transformer, Linear Attention                | ✅ Yes            | ✅ Yes            |

- **Parallel training:** Can multiple time steps be computed at once?  
- **Parameter sharing:** Do different positions reuse the same parameters?


## Limitation of CNNs

- Convolutional Neural Networks (CNNs) capture **local dependencies** via small receptive fields.  
- Stacking layers enlarges the receptive field, but only **slowly**: with stride-1 convolutions it grows linearly with depth.  
- To model **long-range dependencies**, CNNs require:
  - Very deep networks, or  
  - Dilated convolutions (to expand receptive field faster).  
- Transformers address this limitation with **self-attention**, which directly connects all positions.
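To make "grows slowly" concrete, here is a small calculation of the receptive field of stacked stride-1 1D convolutions with kernel size $k$. It assumes the standard formulas: an ordinary layer widens the window by $k-1$ positions, while a layer with dilation $r$ widens it by $r(k-1)$. The doubling dilation schedule $1, 2, 4, \ldots$ is an illustrative assumption, not something prescribed by the notes.

```python
def receptive_field(num_layers: int, kernel_size: int, dilations=None) -> int:
    """Receptive field (in positions) of stacked stride-1 1D convolutions."""
    if dilations is None:
        dilations = [1] * num_layers  # standard (undilated) convolutions
    assert len(dilations) == num_layers
    rf = 1
    for r in dilations:
        rf += r * (kernel_size - 1)  # each layer widens the window by r * (k - 1)
    return rf

k = 3
for L in [1, 2, 4, 8, 10]:
    standard = receptive_field(L, k)
    dilated = receptive_field(L, k, dilations=[2 ** l for l in range(L)])
    print(f"layers={L:2d}  standard RF={standard:5d}  dilated RF={dilated:5d}")
```

Standard stacks grow linearly in depth ($1 + L(k-1)$), doubling dilations grow exponentially; a single self-attention layer already covers the whole sequence.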


## Masked Self-Attention

Convolutional models face a recurring problem:
- **Limited receptive field** → hard to capture long-range dependencies.

Self-Attention provides an alternative with key advantages:
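- **Unlimited receptive field**: in a single layer, every position can attend to all earlier positions (unlike PixelCNN).  
- **Parameter efficiency**: the number of parameters is $O(1)$ in the sequence length, i.e. it does not grow with the number of input dimensions (unlike MADE).  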
- **Parallel computation** across positions (unlike RNNs).  

**Masked Attention** ensures autoregressive behavior:  
Each position $t$ can only attend to $\{1, \ldots, t\}$.
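A minimal sketch of the masking rule, assuming for simplicity that queries, keys and values are the raw inputs (no learned projections): scores above the diagonal are set to $-\infty$ before the softmax, so position $t$ gets zero weight on positions $> t$. The last lines check that perturbing a future token leaves earlier outputs unchanged.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d = 5, 8
x = torch.randn(1, T, d)  # (batch, seq_len, d_model); Q = K = V = x is a simplifying assumption

def causal_attention(q, k, v):
    """Scaled dot-product attention where position t only sees positions <= t."""
    n = q.size(-2)
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5                 # (batch, T, T)
    future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)  # True above the diagonal
    scores = scores.masked_fill(future, float("-inf"))                    # block attention to the future
    weights = F.softmax(scores, dim=-1)                                   # each row still sums to 1
    return weights @ v, weights

out, w = causal_attention(x, x, x)
print(w[0])  # lower-triangular attention matrix

# Changing a future token must not change earlier outputs
x2 = x.clone()
x2[:, -1] += 10.0  # perturb only the last position
out2, _ = causal_attention(x2, x2, x2)
print(torch.allclose(out[:, :-1], out2[:, :-1], atol=1e-6))
```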


## Transformer Architecture

![Transformer architecture](https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQ_zVcjcs2fTW_1qGoPk9yO5_oRfacPwnCbmg&s)

![Transformer architecture (more details)](https://i.sstatic.net/eAKQu.png)



## Attention
![Transformer Architecture](https://lilianweng.github.io/posts/2018-06-24-attention/transformer.png)

- **Inputs:**  
  - A sequence of **(key, value)** pairs.  
    - *Values* are hidden states from a previous layer.  
    - In encoder–decoder attention,  
      - **Keys and values = final encoder hidden states** (each with its own projection).  
      - **Queries = decoder hidden states** of the target sequence.  
  - A sequence of **queries**, representing the *current focus*.  

- **Mechanism:**  
  - Compute **scores** between each query and all keys.  
  - Normalize scores with softmax.  
  - Produce **weighted sum of values**.  

**Scaled Dot-Product Attention**  
$$
\text{Attention}(Q, K, V)
= \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V
$$

where  
- $Q \in \mathbb{R}^{n_q \times d_k}$ (queries)  
- $K \in \mathbb{R}^{n_k \times d_k}$ (keys)  
- $V \in \mathbb{R}^{n_k \times d_v}$ (values)  





![Input](https://miro.medium.com/v2/resize:fit:1400/0*yGkD4RobNX5VUsxP.png)


![Input](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/encoder_decoder/EncoderDecoder_step_by_step.png)




In [None]:
# PyTorch demo: scaled dot-product attention

import torch
import torch.nn.functional as F

torch.manual_seed(42)

# Example dimensions
batch_size = 1
n_q, n_k = 3, 5
d_k, d_v = 4, 6

# Random queries, keys, values
Q = torch.randn(batch_size, n_q, d_k)
K = torch.randn(batch_size, n_k, d_k)
V = torch.randn(batch_size, n_k, d_v)

# Compute scaled dot-product attention
scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)   # (batch, n_q, n_k)
weights = F.softmax(scores, dim=-1)               # attention distribution
output = weights @ V                              # (batch, n_q, d_v)

print("Attention scores (before softmax):\n", scores[0])
print("\nAttention weights (after softmax):\n", weights[0])
print("\nAttention output shape:", output.shape)


## Attention: Parallelization & Multi-Head

![Multihead](https://substackcdn.com/image/fetch/$s_!Q6zJ!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc1497d87-2b8c-45eb-b7b9-3a0a8ebe0d3d_1783x747.png)

- For each query position $j$, the attention computation is **independent** of the other query positions.  
- This allows us to parallelize using **matrix multiplications**.  

**Scaled Dot-Product Attention**  
$$
\text{Attention}(Q,K,V) \;=\; \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

- Queries $Q \in \mathbb{R}^{n_q \times d_k}$  
- Keys $K \in \mathbb{R}^{n_k \times d_k}$  
- Values $V \in \mathbb{R}^{n_k \times d_v}$  

---

### Multi-Head Attention
- Run **several attention functions in parallel** ("heads").  
- Each head has its own learned projections of $Q, K, V$.  
- Outputs from all heads are concatenated and projected by $W^O$.  

$$
\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O,
\qquad
\text{head}_i = \text{Attention}(QW_i^Q,\, KW_i^K,\, VW_i^V)
$$
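To see what the heads actually do, the sketch below re-implements the forward pass of PyTorch's `nn.MultiheadAttention` by hand (no masks, no dropout) and checks it against the built-in module. It relies on the module's packed `in_proj_weight` / `in_proj_bias` (query, key and value projections stacked in that order) and its `out_proj` layer; each head works on a slice of size $d/h$, and all heads run as one batched matrix multiplication.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def multi_head_attention(x_q, x_kv, mha: nn.MultiheadAttention):
    """Hand-written multi-head attention using the weights of an nn.MultiheadAttention."""
    d, h = mha.embed_dim, mha.num_heads
    hd = d // h                                           # per-head dimension d_k = d_v = d / h
    W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)    # packed [W^Q; W^K; W^V]
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
    B, n_q, _ = x_q.shape
    n_k = x_kv.shape[1]

    def split_heads(t, n):                                # (B, n, d) -> (B, h, n, hd)
        return t.view(B, n, h, hd).transpose(1, 2)

    q = split_heads(F.linear(x_q, W_q, b_q), n_q)
    k = split_heads(F.linear(x_kv, W_k, b_k), n_k)
    v = split_heads(F.linear(x_kv, W_v, b_v), n_k)

    # All heads at once: one batched scaled dot-product attention
    weights = F.softmax(q @ k.transpose(-2, -1) / hd ** 0.5, dim=-1)  # (B, h, n_q, n_k)
    heads = weights @ v                                                # (B, h, n_q, hd)

    concat = heads.transpose(1, 2).reshape(B, n_q, d)     # Concat(head_1, ..., head_h)
    return mha.out_proj(concat)                           # multiply by W^O (plus bias)

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(2, 5, 16)
with torch.no_grad():
    ref, _ = mha(x, x, x)
    ours = multi_head_attention(x, x, mha)
print(ours.shape, torch.allclose(ours, ref, atol=1e-5))
```

Splitting $d$ into $h$ slices keeps the total cost close to that of one full-width head, while letting each head learn a different attention pattern.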


![Multihead](https://data-science-blog.com/wp-content/uploads/2022/01/mha_visualization-930x1030.png)

In [None]:
# PyTorch demo: Scaled Dot-Product Attention (matrix form) and Multi-Head

import torch
import torch.nn.functional as F
import torch.nn as nn

torch.manual_seed(123)

batch_size, n_q, n_k = 2, 4, 5
d_model, d_k, d_v, n_heads = 16, 16, 16, 4

# Random Q, K, V
Q = torch.randn(batch_size, n_q, d_model)
K = torch.randn(batch_size, n_k, d_model)
V = torch.randn(batch_size, n_k, d_model)

# ---- 1. Scaled Dot-Product Attention ----
scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)  # (batch, n_q, n_k)
weights = F.softmax(scores, dim=-1)
out = weights @ V  # (batch, n_q, d_model)

print("Scaled Dot-Product Attention output shape:", out.shape)

# ---- 2. Multi-Head Attention (PyTorch built-in) ----
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)
mha_out, attn_weights = mha(Q, K, V)

print("Multi-Head Attention output shape:", mha_out.shape)
print("Attention weights shape (batch, n_q, n_k):", attn_weights.shape)


Scaled Dot-Product Attention output shape: torch.Size([2, 4, 16])
Multi-Head Attention output shape: torch.Size([2, 4, 16])
Attention weights shape (batch, n_q, n_k): torch.Size([2, 4, 5])


## Encoder–Decoder Transformer: Output Distribution

At decoding step $t$, the model predicts the probability of the next token:

$$
P(y_t = i \mid x_{1:T}, y_{1:t-1}) =
\frac{\exp \big( (E \widehat{y}_t)[i] \big)}
     {\sum_j \exp \big( (E \widehat{y}_t)[j] \big)}
$$

- $\widehat{y}_t$: decoder hidden state at position $t$.  
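- $E \in \mathbb{R}^{|V| \times d}$ ($|V|$ = vocabulary size, $d$ = model dimension): output projection matrix, often tied to (shared with) the input token-embedding matrix.  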
- Output is a **softmax distribution** over the vocabulary.  
- This connects decoder hidden states to predicted tokens.
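A minimal sketch of the formula above, assuming weight tying (the output matrix $E$ is the input embedding table, as in GPT-2). The hidden state $\widehat{y}_t$ is just a random stand-in; the point is that $(E\widehat{y}_t)[i]$ is one logit per vocabulary entry and the softmax turns those logits into a distribution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_model = 8, 16

# Assumption: E is the input token-embedding matrix, reused (tied) as the output projection
embedding = nn.Embedding(vocab_size, d_model)
E = embedding.weight                      # shape (|V|, d_model)

y_hat_t = torch.randn(d_model)            # decoder hidden state at step t (random stand-in)

logits = E @ y_hat_t                      # (E y_hat_t)[i] for every vocabulary index i
probs_manual = torch.exp(logits) / torch.exp(logits).sum()  # the softmax formula, written out
probs = F.softmax(logits, dim=-1)         # numerically stable built-in

print(probs)
print(torch.allclose(probs, probs_manual, atol=1e-6), probs.sum().item())
```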


In [None]:
# PyTorch demo: Transformer output layer (projection to vocab)

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Suppose vocab of size 8 and decoder hidden dim = 16
vocab_size = 8
d_model = 16

# Decoder hidden state at time t
y_t_hat = torch.randn(1, d_model)  # (batch=1, d_model)

# Embedding / projection matrix
E = nn.Linear(d_model, vocab_size, bias=False)

# Logits and softmax distribution
logits = E(y_t_hat)                  # (1, vocab_size)
probs = F.softmax(logits, dim=-1)    # (1, vocab_size)

print("Logits:\n", logits)
print("\nSoftmax distribution over vocabulary:\n", probs)
print("\nPredicted token index:", torch.argmax(probs, dim=-1).item())


Logits:
 tensor([[-0.6512,  0.2615,  0.4667,  0.1221,  0.5225,  0.0916, -0.4690,  0.7660]],
       grad_fn=<MmBackward0>)

Softmax distribution over vocabulary:
 tensor([[0.0516, 0.1286, 0.1578, 0.1118, 0.1669, 0.1085, 0.0619, 0.2129]],
       grad_fn=<SoftmaxBackward0>)

Predicted token index: 7


## Positional Encoding

- Self-attention is **permutation equivariant**: shuffling the input tokens just shuffles the outputs the same way, so on its own it ignores token order.  
- Word order matters in language:  
  - *"The mouse ate the cat"* vs *"The cat ate the mouse"*.  
- To inject sequence order, we add a **positional encoding** to each token embedding.
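A quick check of the equivariance claim, assuming a single `nn.MultiheadAttention` layer applied to token embeddings with no positional information: permuting the inputs produces exactly the permuted outputs, so the layer cannot distinguish "mouse ate cat" from "cat ate mouse".

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, seq_len = 16, 6
mha = nn.MultiheadAttention(d_model, num_heads=2, batch_first=True)

x = torch.randn(1, seq_len, d_model)  # token embeddings, no positional encoding added
perm = torch.randperm(seq_len)        # shuffle the "word order"

with torch.no_grad():
    out, _ = mha(x, x, x)
    out_perm, _ = mha(x[:, perm], x[:, perm], x[:, perm])

# Shuffling the input only shuffles the output in the same way (equivariance),
# so the layer itself carries no notion of order.
print(torch.allclose(out_perm, out[:, perm], atol=1e-5))
```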


![Positional Encoding](https://substackcdn.com/image/fetch/$s_!IcmJ!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F4b94aaa0-59fb-45b8-86e5-9c2ce9c1cb12_1937x1388.png)

![Attention](https://deepgram.com/_next/image?url=https%3A%2F%2Fwww.datocms-assets.com%2F96965%2F1684228093-7-transformers-explained.png&w=3840&q=90)

**Key idea:**  
$$
\text{Input}_t = \text{TokenEmbedding}(x_t) + \text{PositionalEncoding}(t)
$$

Common choices:
- **Sinusoidal encoding** (fixed, deterministic).  
- **Learned positional embeddings** (trainable).

![Add](https://substackcdn.com/image/fetch/$s_!Vo8t!,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F58a77f49-ed6d-4614-9c64-505455bd0c83_2043x1300.png)

## 🌊 Sinusoidal Positional Encoding

Transformers need a way to represent **token positions**, since self-attention itself is order-agnostic.  
The original *Attention Is All You Need* paper uses **sinusoidal encodings**:

$$
PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d}}\right), \quad
PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d}}\right)
$$

- $pos$: position in the sequence $(0, 1, 2, …)$  
- $i$: index of the sine/cosine pair, $i = 0, 1, \ldots, d/2 - 1$ (dimensions $2i$ and $2i+1$ share a frequency)  
- $d$: embedding dimension  

Even indices use sine, odd indices use cosine.

---

### ✨ Key Properties
- **Nearby positions** have similar encodings.  
- **Multiple frequencies** capture both coarse and fine positional info.  
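- **Extrapolation**: encodings are defined for any position, so they can be computed beyond the training length (whether the model generalizes to those positions is a separate question).  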



In [None]:
# PyTorch demo: sinusoidal positional encoding

import torch
import math

def positional_encoding(seq_len, d_model):
    """Return sinusoidal positional encodings."""
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe

seq_len, d_model = 10, 16
pe = positional_encoding(seq_len, d_model)

print("Shape of positional encoding:", pe.shape)
print("\nFirst position encoding:\n", pe[0])
print("\nSecond position encoding:\n", pe[1])
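
To check the "nearby positions have similar encodings" property, the sketch below computes dot products between encodings at offset $k$. Because $\sin a \sin b + \cos a \cos b = \cos(a-b)$, the dot product $PE(p)\cdot PE(p+k) = \sum_i \cos(\omega_i k)$ depends only on the offset $k$, not on $p$. It uses float64 (an assumption made only to keep the comparison exact) and a `similarity` helper introduced here for illustration.

```python
import math
import torch

def positional_encoding(seq_len, d_model):
    """Sinusoidal encodings (same formula as above); d_model must be even."""
    position = torch.arange(seq_len, dtype=torch.float64).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

pe = positional_encoding(50, 16)

def similarity(p, k):
    """Dot product between the encodings of positions p and p + k (hypothetical helper)."""
    return (pe[p] @ pe[p + k]).item()

for k in range(4):
    print(f"k={k}: <PE(0),PE(k)> = {similarity(0, k):.4f}   <PE(20),PE(20+k)> = {similarity(20, k):.4f}")
```

For small offsets the similarity shrinks as $k$ grows; for larger offsets it oscillates rather than decaying strictly monotonically.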


![Cos](https://erdem.pl/static/3dbeb6eb036c37502841a6a0238b0ebd/21b4d/position-values-45k.png)

## Layer Normalization & Residual Connections in Transformers

- **Normalization** dramatically improves trainability.  
  - **Post-norm (original)**: apply LayerNorm *after* residual addition.  
  - **Pre-norm (modern)**: apply LayerNorm *before* sublayer (stabilizes training for deep Transformers).  

- **Residual connections** ensure:  
  - Input and output of each sublayer have the **same shape**.  
  - The sublayer computes a *residual update*, following the ResNet idea:  

$$
\text{Output} = \text{Input} + \text{Sublayer}(\text{Input})
$$
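A minimal sketch of what `nn.LayerNorm` computes: each token's feature vector is normalized to zero mean and unit (biased) variance, then scaled by a learned gain $\gamma$ and shifted by a learned bias $\beta$. The residual line then applies $\text{Output} = \text{Input} + \text{Sublayer}(\text{Input})$, with a plain `nn.Linear` standing in for the sublayer (an illustrative assumption).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 5, 16)  # (batch, seq_len, d_model)

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize each token's features to zero mean / unit variance, then scale and shift."""
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance, as nn.LayerNorm uses
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

ln = nn.LayerNorm(16)  # gamma initialized to 1, beta to 0
ours = layer_norm(x, ln.weight, ln.bias, eps=ln.eps)
print(torch.allclose(ours, ln(x), atol=1e-5))

# Residual sublayer: Output = Input + Sublayer(Input); shapes must match
sublayer = nn.Linear(16, 16)
out = x + sublayer(x)
print(out.shape == x.shape)
```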


![ResNet](https://i.ytimg.com/vi/r0HvOIjziw4/maxresdefault.jpg)

## Loss Surface Comparison
![resnet](https://miro.medium.com/v2/resize:fit:723/1*_Qd_txKxRlsMdfuH2J-k4g.png)

![normalization](https://miro.medium.com/1*GwxdzLlnWf1NQ5yHnTSAHA.png)

In [None]:
# PyTorch demo: residual + layer norm (pre-norm vs post-norm)

import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(2, 5, 16)  # (batch, seq_len, d_model)
layer_norm = nn.LayerNorm(16)
linear = nn.Linear(16, 16)

# --- Post-norm (original Transformer) ---
out_post = layer_norm(x + linear(x))

# --- Pre-norm (modern Transformers) ---
out_pre = x + linear(layer_norm(x))

print("Input shape:", x.shape)
print("Post-norm output shape:", out_post.shape)
print("Pre-norm output shape:", out_pre.shape)


Input shape: torch.Size([2, 5, 16])
Post-norm output shape: torch.Size([2, 5, 16])
Pre-norm output shape: torch.Size([2, 5, 16])


## Transformer Encoder Block: Step by Step

Each encoder layer applies the following sequence:

1. **Multi-Head Attention (MHA)**  
   - Self-attention over the input sequence.  

2. **Add & Norm**  
   - Residual connection + LayerNorm.  

3. **Feed Forward (FFN)**  
   - Position-wise MLP applied independently to each token.  

4. **Add & Norm**  
   - Residual connection + LayerNorm.  

---
![Repeat](https://miro.medium.com/v2/resize:fit:2000/1*XV0h6aTDdEPmL8Gn4gYFoA.png)




## Transformer Decoder Block: Step by Step

Each decoder layer applies the following sequence:

1. **Masked Multi-Head Attention (MHA)**  
   - Self-attention over the *target sequence so far*.  
   - Causal mask prevents attending to future tokens.  

2. **Add & Norm**  
   - Residual connection + LayerNorm.  

3. **Encoder–Decoder Multi-Head Attention**  
   - Queries come from the decoder.  
   - Keys/Values come from the encoder output.  

4. **Add & Norm**  

5. **Feed Forward (FFN)**  
   - Position-wise MLP.  

6. **Add & Norm**  

---

![More details](https://i.sstatic.net/eAKQu.png)



![Decoder](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/encoder_decoder/encoder_decoder_detail.png)

## RNNs vs Transformers

### RNNs
- (+) LSTMs work reasonably well for **long sequences**.  
- (–) Require **ordered inputs** (cannot handle sets).  
- (–) **Sequential computation**: each hidden state depends on the previous one.  

### Transformers
- (+) Handle **long sequences** effectively: attention looks at all inputs at once.  
- (+) Can operate on **unordered sets** or **ordered sequences** (with positional encodings).  
- (+) **Parallel computation**: all attention scores are computed simultaneously.  
- (–) **Memory expensive**: must compute and store $N \times M$ attention scores per head ($N$ queries, $M$ keys).  


In [None]:
# PyTorch demo: Sequential RNN vs Parallel Transformer Attention

import torch
import torch.nn as nn

torch.manual_seed(42)
batch_size, seq_len, d_model = 1, 6, 8

x = torch.randn(batch_size, seq_len, d_model)

# --- RNN: sequential processing ---
rnn = nn.RNN(d_model, d_model, batch_first=True)
h, _ = rnn(x)  # output (batch, seq_len, d_model)

print("RNN processes sequentially:")
for t in range(seq_len):
    print(f" Step {t}: hidden state depends on all x[:{t+1}]")

# --- Transformer: parallel attention ---
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=2, batch_first=True)
attn_out, attn_weights = mha(x, x, x)

print("\nTransformer processes in parallel:")
print(" Attention weights shape:", attn_weights.shape)  # (batch, seq_len, seq_len): averaged over heads by default
print(" Each position attends to ALL positions simultaneously.")


RNN processes sequentially:
 Step 0: hidden state depends on all x[:1]
 Step 1: hidden state depends on all x[:2]
 Step 2: hidden state depends on all x[:3]
 Step 3: hidden state depends on all x[:4]
 Step 4: hidden state depends on all x[:5]
 Step 5: hidden state depends on all x[:6]

Transformer processes in parallel:
 Attention weights shape: torch.Size([1, 6, 6])
 Each position attends to ALL positions simultaneously.


## Problems with Vanilla Transformers

1. **Quadratic Complexity**  
   - Attention is $O(n^2)$ in sequence length $n$.  
   - Training with long sequences becomes expensive.  

2. **Memory Bottleneck**  
   - Attention requires storing all $n \times n$ alignment scores.  
   - Inference on long inputs hits GPU memory limits.  

3. **Positional Encoding Limitations**  
   - Standard sinusoidal encodings may not extrapolate well to sequences longer than those seen in training.  
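The memory bottleneck is easy to quantify. The sketch below counts only the $n \times n$ score matrices, assuming they are materialized in fp32 (4 bytes each) and kept for every head and layer, as in a naive training implementation; fused attention kernels can avoid storing the full matrix, so treat this as the naive cost.

```python
def attention_score_bytes(seq_len, n_heads=1, n_layers=1, batch_size=1, bytes_per_elem=4):
    """Memory for the n x n attention-score matrices alone (fp32 by default)."""
    return batch_size * n_layers * n_heads * seq_len * seq_len * bytes_per_elem

for n in [512, 1024, 2048, 4096, 8192]:
    mib = attention_score_bytes(n) / 2**20
    print(f"n = {n:5d}: {mib:8.1f} MiB per head per layer")

# Example: 12 heads x 12 layers (GPT-2 small sized) at n = 1024
total_mib = attention_score_bytes(1024, n_heads=12, n_layers=12) / 2**20
print(f"12 heads x 12 layers at n = 1024: {total_mib:.0f} MiB per sequence")
```

Doubling the sequence length quadruples the memory: 1 MiB per head per layer at $n = 512$ becomes 64 MiB at $n = 4096$.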


In [None]:
# PyTorch demo: Attention cost grows quadratically with sequence length

import torch
import torch.nn as nn
import time

def attention_cost(seq_len, d_model=64, n_heads=4):
    mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
    x = torch.randn(1, seq_len, d_model)
    start = time.time()
    _ = mha(x, x, x)  # forward pass
    elapsed = time.time() - start
    return elapsed

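# Note: the first call also pays one-time warm-up costs (memory allocation, kernel
# selection), which likely explains the inflated n = 64 timing below. Averaging
# several runs after a warm-up call would give cleaner numbers.
for n in [64, 128, 256, 512, 1024]: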
    t = attention_cost(n)
    print(f"Seq Len = {n:4d}, Time = {t*1000:.2f} ms")


Seq Len =   64, Time = 18.29 ms
Seq Len =  128, Time = 1.07 ms
Seq Len =  256, Time = 3.46 ms
Seq Len =  512, Time = 11.30 ms
Seq Len = 1024, Time = 39.70 ms


## Transformers, BERT, and GPT-2

- **Full Transformer (Seq2Seq)**  
  - Encoder stack processes input.  
  - Decoder stack generates output, attending to encoder output.  
  - Example: **Machine Translation**.  

- **BERT (Encoder-only)**  
  - Use only the **encoder stack**.  
  - Pretrain with **masked language modeling** (bidirectional).  
  - Good for representation learning (classification, QA, etc.).  

- **GPT (Decoder-only)**  
  - Use only the **decoder stack**.  
  - Train with **causal/next-word prediction** (autoregressive).  
  - Good for generative tasks (text completion, chat, etc.).  


## Decoder-Only Models: GPT-2 Scaling

GPT-2 is a **decoder-only Transformer** trained autoregressively.  
Its performance improves significantly with scale:

| Model     | Parameters |
|-----------|------------|
| GPT-2 Small  | 117M |
| GPT-2 Medium | 345M |
| GPT-2 Large  | 762M |
| GPT-2 XL     | 1542M |

**Scaling trends:**  
- Larger models reach lower language-modeling loss and stronger zero-shot performance.  
- The gains require scaling compute and data along with the parameter count.  
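Where do these parameter counts come from? A rough count, assuming the publicly released GPT-2 small configuration (12 layers, $d = 768$, a 50257-token vocabulary, a 1024-token context, learned positional embeddings) and an output projection tied to the token embedding. Each block contributes about $12d^2$ parameters, so depth and width dominate the total.

```python
def gpt2_param_count(n_layers, d_model, vocab_size, n_ctx):
    """Parameter count of a GPT-2-style decoder-only model with tied input/output embeddings."""
    d = d_model
    attn = (d * 3 * d + 3 * d) + (d * d + d)     # fused Q/K/V projection + output projection (weights + biases)
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)  # d -> 4d -> d feed-forward
    norms = 2 * (2 * d)                          # two LayerNorms (gain + bias) per block
    per_block = attn + mlp + norms               # = 12 d^2 + 13 d
    embeddings = vocab_size * d + n_ctx * d      # token + learned positional embeddings
    final_norm = 2 * d
    return n_layers * per_block + embeddings + final_norm

# GPT-2 small configuration (12 layers, d_model = 768, 50257-token vocabulary, 1024-token context)
print(f"{gpt2_param_count(12, 768, 50257, 1024):,}")
```

This gives about 124M rather than the 117M in the table; the released checkpoints are commonly labeled 124M, 355M, 774M and 1558M, slightly above the figures originally reported in the GPT-2 paper.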


## Pretraining Task: Next-Token Prediction

- **Training objective:** predict the next token given all previous tokens.  
  $$
  P(y_t \mid y_{<t})
  $$

- **Naive decoding:**  
  - At step $t$, recompute queries, keys, values for the **entire prefix**.  
  - Cost per step grows **quadratically** with the prefix length $t$ (a $t \times t$ score matrix is recomputed every step).  

- **Caching trick (used in GPT-style models):**  
  - Store keys/values from previous steps.  
  - At step $t$, compute only $(q_t, k_t, v_t)$ for the new token.  
  - Additional cost per step becomes **linear** instead of quadratic.  
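A minimal single-head sketch (no batch, random projection matrices `W_q`, `W_k`, `W_v` introduced here for illustration) showing that caching is exact: decoding one token at a time, reusing stored keys and values, reproduces full causal attention over the whole sequence, while each step computes only $t+1$ scores.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d = 6, 8
x = torch.randn(T, d)  # token embeddings for one sequence (single head, no batch)
W_q, W_k, W_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))  # hypothetical projection weights

# --- Naive: full causal attention over the whole sequence ---
q, k, v = x @ W_q, x @ W_k, x @ W_v
scores = q @ k.T / d ** 0.5
scores = scores.masked_fill(torch.triu(torch.ones(T, T, dtype=torch.bool), 1), float("-inf"))
full_out = F.softmax(scores, dim=-1) @ v  # (T, d)

# --- Incremental decoding with a KV cache ---
k_cache, v_cache, step_outs = [], [], []
for t in range(T):
    x_t = x[t:t + 1]                 # only the newest token, shape (1, d)
    q_t = x_t @ W_q                  # one new query...
    k_cache.append(x_t @ W_k)        # ...one new key and value, appended to the cache
    v_cache.append(x_t @ W_v)
    K, V = torch.cat(k_cache), torch.cat(v_cache)  # (t + 1, d): past keys/values are reused, not recomputed
    w_t = F.softmax(q_t @ K.T / d ** 0.5, dim=-1)  # t + 1 scores instead of (t + 1)^2
    step_outs.append(w_t @ V)

cached_out = torch.cat(step_outs)
print(torch.allclose(full_out, cached_out, atol=1e-5))
```

No mask is needed in the cached loop, because the cache only ever holds past and current tokens.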


In [None]:
# PyTorch demo: caching key/value states for autoregressive decoding

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

d_model, nhead = 16, 2
mha = nn.MultiheadAttention(d_model, nhead, batch_first=True)

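# Decode tokens step by step.
# Simplification: nn.MultiheadAttention applies its own Q/K/V projections internally,
# so this toy caches the raw token embeddings and re-projects them every step.
# A real KV cache stores the already-projected keys and values instead.
seq_len = 5
cache_k, cache_v = None, None

for t in range(seq_len):
    x_t = torch.randn(1, 1, d_model)  # new token embedding at step t

    # Only the new token provides a query
    q_t = x_t

    # Append the new token to the key/value cache
    k_t, v_t = x_t, x_t
    cache_k = k_t if cache_k is None else torch.cat([cache_k, k_t], dim=1)
    cache_v = v_t if cache_v is None else torch.cat([cache_v, v_t], dim=1)

    # Attention: one query against all cached positions; no causal mask is needed
    # because the cache only contains past and current tokens
    out, _ = mha(q_t, cache_k, cache_v)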

    print(f"Step {t}: cache size = {cache_k.shape[1]}, output shape = {out.shape}")


Step 0: cache size = 1, output shape = torch.Size([1, 1, 16])
Step 1: cache size = 2, output shape = torch.Size([1, 1, 16])
Step 2: cache size = 3, output shape = torch.Size([1, 1, 16])
Step 3: cache size = 4, output shape = torch.Size([1, 1, 16])
Step 4: cache size = 5, output shape = torch.Size([1, 1, 16])


## Downstream Tasks with GPT-2

- **No task-specific fine-tuning** (unlike BERT).  
- GPT-2 is pretrained on large, diverse text corpora with **next-token prediction**.  
- At inference, the same model can be applied directly to many tasks:  

**Examples:**
- 🌐 Translation  
- ❓ Question Answering  
- 📝 Summarization  
- 📖 Reading Comprehension  
- 🔮 Language Modeling  

**Key Idea:**  
Pretraining on broad data gives GPT-2 *zero-shot* and *few-shot* abilities: no architectural changes, just different prompts.
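"Just different prompts" can be made concrete with a small prompt builder. The summarization suffix `TL;DR:` and the `english sentence = french sentence` pair format follow the GPT-2 paper's descriptions; the `Q:`/`A:` layout for QA and the `build_prompt` helper itself are hypothetical. The resulting string would be tokenized and fed to the same next-token predictor, and whatever the model generates after the final marker is read as the answer.

```python
def build_prompt(task, text, examples=()):
    """Turn a task into plain text for a next-token predictor (hypothetical helper)."""
    if task == "summarize":
        return f"{text}\nTL;DR:"
    if task == "translate":
        context = "\n".join(f"{en} = {fr}" for en, fr in examples)  # example pairs as context
        return f"{context}\n{text} ="
    if task == "qa":
        context = "\n".join(f"Q: {q}\nA: {a}" for q, a in examples)  # hypothetical QA format
        return f"{context}\nQ: {text}\nA:"
    raise ValueError(f"unknown task: {task}")

print(build_prompt("translate", "good morning", examples=[("thank you", "merci")]))
print(build_prompt("summarize", "A long article about transformers."))
```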


In [None]:
# PyTorch demo: one GPT-2-style decoder-only model reused for multiple tasks

import torch
import torch.nn as nn

# Toy GPT-like decoder-only model. A decoder-only block has masked self-attention
# and an FFN but no cross-attention, so nn.TransformerEncoderLayer (self-attention + FFN)
# combined with a causal mask is the matching building block.
class MiniGPT(nn.Module):
    def __init__(self, vocab_size=100, d_model=32, nhead=4, nlayers=2, max_len=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)   # token embeddings
        self.pos = nn.Embedding(max_len, d_model)        # learned positional embeddings
        block = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.blocks = nn.TransformerEncoder(block, num_layers=nlayers)
        self.out = nn.Linear(d_model, vocab_size)        # projection to vocabulary logits

    def forward(self, x):
        T = x.size(1)
        h = self.embed(x) + self.pos(torch.arange(T, device=x.device))
        # Causal mask: -inf above the diagonal, so position t only attends to positions <= t
        causal_mask = torch.triu(torch.full((T, T), float("-inf"), device=x.device), diagonal=1)
        h = self.blocks(h, mask=causal_mask)
        return self.out(h)

# Same model reused for different tasks
model = MiniGPT()

# Example: predict next word in a sentence (language modeling)
x = torch.randint(0, 100, (1, 6))  # toy input sequence
logits = model(x)
probs = torch.softmax(logits, dim=-1)

print("Output vocab distribution shape:", probs.shape)  # (batch, seq_len, vocab_size)
print("Predicted next token at last step:", torch.argmax(probs[:, -1, :], dim=-1).item())

# The same model could be prompted differently for QA, summarization, translation, etc.


Output vocab distribution shape: torch.Size([1, 6, 100])
Predicted next token at last step: 4


![GPT4](https://www.stylefactoryproductions.com/wp-content/uploads/2023/04/chatgpt-4-training-data-size.png)