# 1. Define the model

In [2]:
import os
import math
import time
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import inspect
import requests
from dataclasses import dataclass

# -----------------------------------------------------------------------------

class MultiHeadCausalAttention(nn.Module):
    """Implements multi-head attention with causal masking."""
    def __init__(self, config):
        super(MultiHeadCausalAttention, self).__init__()
        if config.embedding_dim % config.num_heads != 0:
            raise ValueError("Embedding size must be divisible by the number of heads.")
        # Linear transformations for query, key, and value
        self.qkv_projection = nn.Linear(config.embedding_dim, 3 * config.embedding_dim)
        # Linear transformation for output
        self.final_projection = nn.Linear(config.embedding_dim, config.embedding_dim)
        self.final_projection.SCALE_FACTOR = 1  # Custom scaling for initialization

        # Number of attention heads and dimensions per head
        self.num_heads = config.num_heads
        self.dim_per_head = config.embedding_dim // config.num_heads

    def forward(self, inputs):
        batch_size, seq_length, embed_dim = inputs.size()
        # Compute query, key, and value tensors
        qkv = self.qkv_projection(inputs)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        # Reshape tensors for multi-head attention
        q = q.view(batch_size, seq_length, self.num_heads, self.dim_per_head).transpose(1, 2)
        k = k.view(batch_size, seq_length, self.num_heads, self.dim_per_head).transpose(1, 2)
        v = v.view(batch_size, seq_length, self.num_heads, self.dim_per_head).transpose(1, 2)

        # Perform scaled dot-product attention with causal masking
        #att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        #att = att.masked_fill(torch.tril(torch.ones_like(att)) == 0, float('-inf'))
        #attn_output = F.softmax(att, dim=-1) @ v
        # use fast attention instead, which is faster
        attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Reshape and project back to original embedding size
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_length, embed_dim)
        return self.final_projection(attn_output)

class PositionwiseFeedForward(nn.Module):
    """Applies a feed-forward network to the input tensor."""
    def __init__(self, config):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(config.embedding_dim, 4 * config.embedding_dim)
        self.activation_fn = nn.GELU(approximate='tanh')
        self.linear2 = nn.Linear(4 * config.embedding_dim, config.embedding_dim)
        self.linear2.SCALE_FACTOR = 1  # Custom scaling for initialization

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation_fn(x)
        return self.linear2(x)

class TransformerLayer(nn.Module):
    """Defines a single layer of the transformer."""
    def __init__(self, config):
        super(TransformerLayer, self).__init__()
        # Normalization layers
        self.pre_attn_norm = nn.LayerNorm(config.embedding_dim)
        self.pre_ffn_norm = nn.LayerNorm(config.embedding_dim)

        # Attention and feed-forward components
        self.causal_attention = MultiHeadCausalAttention(config)
        self.feedforward = PositionwiseFeedForward(config)

    def forward(self, x):
        # Residual connection with attention
        x = x + self.causal_attention(self.pre_attn_norm(x))
        # Residual connection with feed-forward
        x = x + self.feedforward(self.pre_ffn_norm(x))
        return x

@dataclass
class TransformerConfig:
    """Configuration for the Transformer model."""
    seq_len: int = 1024  # Maximum sequence length
    vocab_size: int = 50257  # Vocabulary size
    num_layers: int = 12  # Number of transformer layers
    num_heads: int = 12  # Number of attention heads
    embedding_dim: int = 768  # Dimensionality of embeddings

    def get_block_size(self):
        """Returns the block size for the transformer."""
        return self.seq_len

class GPT(nn.Module):
    """Full GPT model with transformer architecture."""
    def __init__(self, config):
        super(GPT, self).__init__()
        self.config = config

        # Embedding layers
        self.token_embed = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.positional_embed = nn.Embedding(config.seq_len, config.embedding_dim)

        # Transformer blocks
        self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_layers)])
        self.final_norm = nn.LayerNorm(config.embedding_dim)

        # Output layer
        self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_embed.weight  # Weight sharing with token embedding

        # Initialize parameters
        self._init_weights()

    def _init_weights(self):
        """Custom initialization for model parameters."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                std = 0.02
                if hasattr(module, 'SCALE_FACTOR'):
                    std *= (2 * self.config.num_layers) ** -0.5
                nn.init.normal_(module.weight, mean=0.0, std=std)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, targets=None):
        """Forward pass of the GPT model."""
        B, T = input_ids.size()
        if T > self.config.get_block_size():
            raise ValueError(f"Input sequence length {T} exceeds block size {self.config.get_block_size()}.")

        # Embedding lookup
        positions = torch.arange(T, device=input_ids.device).unsqueeze(0)
        x = self.token_embed(input_ids) + self.positional_embed(positions)

        # Pass through transformer layers
        for layer in self.layers:
            x = layer(x)

        # Apply final normalization
        x = self.final_norm(x)
        logits = self.lm_head(x)

        # Compute loss if targets are provided
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @classmethod
    def load_pretrained(cls, model_type):
        """Load pretrained GPT-2 weights from HuggingFace."""
        from transformers import GPT2LMHeadModel

        # Supported GPT-2 models
        model_mapping = {
            "gpt2": {"num_layers": 12, "num_heads": 12, "embedding_dim": 768},
            "gpt2-medium": {"num_layers": 24, "num_heads": 16, "embedding_dim": 1024},
            "gpt2-large": {"num_layers": 36, "num_heads": 20, "embedding_dim": 1280},
            "gpt2-xl": {"num_layers": 48, "num_heads": 25, "embedding_dim": 1600},
        }

        # Check if model type is valid
        if model_type not in model_mapping:
            raise ValueError(f"Unsupported model type: {model_type}")

        # Set configuration
        config_params = model_mapping[model_type]
        config_params.update({"vocab_size": 50257, "seq_len": 1024})
        config = TransformerConfig(**config_params)

        # Initialize our model
        model = cls(config)

        # Load HuggingFace GPT-2 model
        hf_model = GPT2LMHeadModel.from_pretrained(model_type)
        hf_state_dict = hf_model.state_dict()

        # Transfer weights from HuggingFace model
        own_state_dict = model.state_dict()
        for name, param in hf_state_dict.items():
            if name in own_state_dict and param.size() == own_state_dict[name].size():
                own_state_dict[name].copy_(param)

        return model

    def configure_optimizers(self, weight_decay, lr, device_type):
        """Configure optimizer with weight decay for model parameters."""
        # Separate parameters into groups for weight decay
        decay_params = []
        no_decay_params = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                if param.ndimension() >= 2:
                    decay_params.append(param)
                else:
                    no_decay_params.append(param)

        optimizer_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        # Use fused AdamW if supported
        try:
            use_fused = "fused" in inspect.signature(torch.optim.AdamW).parameters and device_type == "cuda"
        except Exception:
            use_fused = False

        print(f"Using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(
            optimizer_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused
        )
        return optimizer


## 1.1 Show some sample sequences using pretrained gpt2 weights

In [7]:
!pip install tiktoken



In [9]:
from codecs import encode
# check if gpu is available
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# generate some sample sequence using pretrained gpt2 weights
num_return_sequence =5
max_length = 15
# generate some sample sequence using pretrained gpt2 weights
model = GPT.load_pretrained("gpt2")
# generate using random non-trained model
#model = GPT(TransformerConfig())
model.eval()
model = model.to(device)

# tokenize the text
import tiktoken
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode("I am a language model,")
x = torch.tensor(tokens,dtype = torch.long).unsqueeze(0).repeat(num_return_sequence,1).to(device)

torch.manual_seed(42)
torch.cuda.manual_seed(42)
while x.size(1) < max_length:
  with torch.no_grad():
    logits,_  = model(x)
    logits = logits[:, -1, :]
    probs = F.softmax(logits, dim=-1)
    # select next token from top k probabilities
    topkprobs, topkindices = torch.topk(probs, 50, dim=-1)
    next_token_id = torch.multinomial(topkprobs, num_samples=1)
    next_token = topkindices.gather(dim=-1, index=next_token_id)
    # add next token to sequence
    x = torch.cat((x, next_token), dim=1)
# print generated sequence
for i in range(num_return_sequence):
  print(enc.decode(x[i].tolist()))

Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


I am a language model,,,,,:::::
I am a language model,,,,,,,,,,
I am a language model,,,,,,,,,,
I am a language model,,,,,,,:::
I am a language model,,,,,,,,,,


# 2. Train the model

## 2.1 Prepare the data (using shakespear dataset)

In [10]:
# prepare the data
# download the tiny shakespeare dataset and separate it for train and validation
# this part of code is copied from github https://github.com/karpathy/nanoGPT
input_file_path = os.path.join(os.path.dirname("/content/sample_data"), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname("/content/sample_data"), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname("/content/sample_data"), 'val.bin'))

# train.bin has 301,966 tokens
# val.bin has 36,059 tokens

train has 301,966 tokens
val has 36,059 tokens


## 2.2 Training on Shakespear dataset

In [11]:
# first define the DataLoader class to load the data in batches
class DataLoader:
  #B is for batch size, T is for block size
  def __init__(self, data, B, T):
    self.tokens = torch.tensor(data,dtype = torch.int64)
    self.B = B
    self.T = T
    self.num_batches = len(data) // (B * T)
    print(f"Number of batches: {self.num_batches}")
    print(f"Number of tokens: {len(data)}")
    # set position at the begin
    self.current_position = 0

  def next_batch(self):
    B = self.B
    T = self.T
    buf = self.tokens [self.current_position:self.current_position + B * T+1]
    x = buf[:-1].view(B, T)
    y = buf[1:].view(B, T)
    self.current_position += B * T

    # if the remaining data does not enough for one batch, reset
    if self.current_position + B * T+1 > len(self.tokens):
      self.current_position = 0
    return x, y

In [14]:
pip install triton

Collecting triton
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.1.0


In [13]:
# this part of code is to playwith the method for speeding up and get the average  runtime
import triton
# check if gpu is available
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")
torch.manual_seed(42)
torch.cuda.manual_seed(42)

B = 16
T = 1024
train_loader = DataLoader(train_ids, B,T)

# to speed up, downgrade the precision for matrix multiplication
torch.set_float32_matmul_precision('high')

# change the vocab size to nice number 50304, which can be divided by 2
# the original vocab size is 50257
model = GPT(TransformerConfig(vocab_size= 50304))
model = model.to(device)

#using torch.compile also helps for improving the speed
model = torch.compile(model)

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)



time_sents = []
for i in range(50):
  t0 = time.time()
  x, y = train_loader.next_batch()
  x = x.to(device)
  y = y.to(device)
  optimizer.zero_grad()
  # another method for speed up, lower the precision for some of the operation
  with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits, loss = model(x, y)
  loss.backward()
  optimizer.step()
  torch.cuda.synchronize()
  t1 = time.time()
  dt = t1-t0
  time_sents.append(dt)
  print(f"step: {i}, loss: {loss.item()}, time: {dt:.4f}")

# get the average time
avg_time = np.mean(time_sents[1:50])
print(f"Average time: {avg_time:.4f}")

Using device: cuda
Number of batches: 18
Number of tokens: 301966
step: 0, loss: 10.934246063232422, time: 24.3486
step: 1, loss: 9.439720153808594, time: 0.0944
step: 2, loss: 9.042047500610352, time: 0.0943
step: 3, loss: 8.909623146057129, time: 0.0943
step: 4, loss: 8.652402877807617, time: 0.0945
step: 5, loss: 8.543598175048828, time: 0.0943
step: 6, loss: 8.34595775604248, time: 0.0944
step: 7, loss: 8.137788772583008, time: 0.0945
step: 8, loss: 7.843873500823975, time: 0.0946
step: 9, loss: 7.622587203979492, time: 0.0941
step: 10, loss: 7.421063423156738, time: 0.0944
step: 11, loss: 7.318807601928711, time: 0.0943
step: 12, loss: 7.150790214538574, time: 0.0941
step: 13, loss: 7.102771759033203, time: 0.0949
step: 14, loss: 7.051050662994385, time: 0.0946
step: 15, loss: 6.89571475982666, time: 0.0945
step: 16, loss: 6.862709999084473, time: 0.0944
step: 17, loss: 6.815102577209473, time: 0.0943
step: 18, loss: 6.6193695068359375, time: 0.0943
step: 19, loss: 6.4099822044372