In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import numpy as np

#------------------------------------------------
class CausalSelfAttention(nn.Module):
  """Multi-head causal self-attention (GPT-2 layout) built on fused scaled-dot-product attention.

  ``config`` must provide ``n_embd``, ``n_head`` and ``block_size``; ``n_embd`` has to be
  divisible by ``n_head``. Input and output are both (B, T, n_embd).
  """

  def __init__(self, config):
    super().__init__()
    assert config.n_embd % config.n_head == 0
    width, ctx = config.n_embd, config.block_size
    # a single projection yields query, key and value for all heads at once
    self.c_attn = nn.Linear(width, 3 * width)
    # projection back into the residual stream; flagged for the scaled init in GPT._init_weights
    self.c_proj = nn.Linear(width, width)
    self.c_proj.NANOGPT_SCALE_INIT = 1
    self.n_head = config.n_head
    self.n_embd = width
    # Lower-triangular causal mask, registered under the OpenAI/HF name "bias". forward()
    # uses is_causal=True and never reads it, but it stays a (persistent) buffer so state
    # dicts containing this key still load with a strict load_state_dict.
    mask = torch.tril(torch.ones(ctx, ctx))
    self.register_buffer("bias", mask.view(1, 1, ctx, ctx))

  def forward(self, x):
    batch, seq_len, channels = x.size()
    head_dim = channels // self.n_head

    def split_heads(t):
      # (B, T, C) -> (B, nh, T, hs)
      return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

    query, key, value = (split_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))
    merged = F.scaled_dot_product_attention(query, key, value, is_causal=True)
    # (B, nh, T, hs) -> (B, T, C): lay the head outputs side by side again
    merged = merged.transpose(1, 2).contiguous().view(batch, seq_len, channels)
    return self.c_proj(merged)

class MLP(nn.Module):
  """Position-wise feed-forward block: expand to 4*n_embd, tanh-approximated GELU, project back."""

  def __init__(self, config):
    super().__init__()
    hidden = 4 * config.n_embd
    self.c_fc = nn.Linear(config.n_embd, hidden)
    self.gelu = nn.GELU(approximate='tanh')
    self.c_proj = nn.Linear(hidden, config.n_embd)
    # marks this residual-stream projection for the scaled init in GPT._init_weights
    self.c_proj.NANOGPT_SCALE_INIT = 1

  def forward(self, x):
    return self.c_proj(self.gelu(self.c_fc(x)))


class Block(nn.Module):
  """Pre-LayerNorm transformer block: x + attn(ln_1(x)), followed by x + mlp(ln_2(x))."""

  def __init__(self, config):
    super().__init__()
    width = config.n_embd
    self.ln_1 = nn.LayerNorm(width)
    self.attn = CausalSelfAttention(config)
    self.ln_2 = nn.LayerNorm(width)
    self.mlp = MLP(config)

  def forward(self, x):
    # two residual sub-layers, each normalizing its own input first
    attended = x + self.attn(self.ln_1(x))
    return attended + self.mlp(self.ln_2(attended))


@dataclass
class GPTConfig:
  """Hyperparameters of the GPT model; the defaults are the GPT-2 (124M) sizes.

  The checkpoint loaded later stores its config under checkpoint['config'] (presumably a
  pickled GPTConfig), so this class must be defined before that torch.load call runs.
  """
  block_size: int = 1024 # max sequence length (rows of the position-embedding table)
  vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|>
  n_layer: int = 12 # number of transformer blocks
  n_head: int = 12 # number of attention heads (must divide n_embd)
  n_embd: int = 768 # embedding dimension

class GPT(nn.Module):
  """GPT-2 style decoder-only language model.

  ``config`` must provide ``vocab_size``, ``block_size``, ``n_layer``, ``n_head`` and
  ``n_embd`` (see GPTConfig). The token-embedding matrix is shared with ``lm_head``.
  """

  def __init__(self, config):
    super().__init__()
    self.config = config

    self.transformer = nn.ModuleDict(dict(
        wte = nn.Embedding(config.vocab_size, config.n_embd),
        wpe = nn.Embedding(config.block_size, config.n_embd),
        h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ln_f = nn.LayerNorm(config.n_embd),
    ))

    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    # weight sharing scheme: input embedding and output classifier use one matrix
    self.transformer.wte.weight = self.lm_head.weight

    # init params
    self.apply(self._init_weights)

  def _init_weights(self, module):
    """Initialise one submodule in place (invoked for every submodule via ``self.apply``).

    Linear weights ~ N(0, 0.02), shrunk by 1/sqrt(2 * n_layer) for layers flagged with
    ``NANOGPT_SCALE_INIT`` (the residual-stream projections). Each of the n_layer blocks
    adds two residual contributions (attention and MLP), so the shrink keeps the residual
    stream's variance from growing with depth. Linear biases are zeroed; embedding
    weights ~ N(0, 0.02).
    """
    if isinstance(module, nn.Linear):
      std = 0.02
      if hasattr(module, 'NANOGPT_SCALE_INIT'):
        # scale DOWN (exponent -0.5); a positive exponent would inflate std by sqrt(2*n_layer)
        std *= (2 * self.config.n_layer) ** -0.5
      torch.nn.init.normal_(module.weight, mean=0.0, std=std)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets = None):
    """Run the model on a batch of token ids.

    Args:
      idx: (B, T) tensor of token ids, with T <= config.block_size.
      targets: optional (B, T) tensor of next-token ids.

    Returns:
      (logits, loss): logits has shape (B, T, vocab_size); loss is the mean cross-entropy
      over all B*T positions when ``targets`` is given, otherwise None.
    """
    B, T = idx.size()
    assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block_size is {self.config.block_size}"

    # forward the token and position embeddings
    pos = torch.arange(0, T, dtype = torch.long, device = idx.device) # shape (T)
    pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
    tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
    x = tok_emb + pos_emb # broadcasts pos_emb over the batch dimension
    # forward the blocks of the transformer
    for block in self.transformer.h:
      x = block(x)
    # forward the final layernorm and the classifier
    x = self.transformer.ln_f(x)
    logits = self.lm_head(x) # (B, T, vocab_size)
    loss = None
    if targets is not None:
      loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    return logits, loss

In [2]:
from dataclasses import dataclass
import os
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
import tiktoken


class DataLoaderLite:
    """Minimal sequential loader over one tokenized text file, for next-token training.

    The whole file is read as UTF-8, encoded with ``enc.encode`` and held in memory as a
    single 1-D tensor. ``next_batch`` walks through it in non-overlapping windows of
    ``B * T`` tokens and wraps back to the start when the next window would run off the end.

    Args:
        B: batch size (sequences per batch).
        T: sequence length (tokens per sequence).
        enc: tokenizer exposing ``encode(str) -> list[int]`` (e.g. a tiktoken encoding).
        filename: path of the text file to load.

    Raises:
        ValueError: if the file yields fewer than ``B * T + 1`` tokens, i.e. not even one
            (input, target) batch can be formed.
    """

    def __init__(self, B, T, enc, filename):
        self.B = B
        self.T = T

        # at init load tokens from disk and store them in memory; read as UTF-8
        # explicitly instead of depending on the platform's default encoding
        with open(filename, 'r', encoding='utf-8') as f:
            text = f.read()
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")
        # one batch needs B*T inputs plus one extra token for the shifted targets
        if len(self.tokens) < B * T + 1:
            raise ValueError(
                f"{filename} has {len(self.tokens)} tokens; need at least {B * T + 1} "
                f"for one batch with B={B}, T={T}"
            )

        # state
        self.current_position = 0

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each (B, T); ``y`` is ``x`` shifted left by one token."""
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T) # inputs
        y = buf[1:].view(B, T) # targets
        # advance the position in the tensor
        self.current_position += B * T
        # if loading the next batch would be out of bounds, reset
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y

# 

In [3]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Pick the best available accelerator: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"using device: {device}")

# seed the global torch RNG (and the CUDA RNG when a GPU is present) for reproducibility
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)





using device: cuda


In [4]:
# init a huggingface/transformers model (disabled: the local checkpoint is loaded in the next cell instead)
# model_type = 'gpt2'
# enc = tiktoken.get_encoding("gpt2")
# model = GPT2LMHeadModel.from_pretrained(model_type)
# model.to(device)


In [5]:
import torch
from dataclasses import dataclass

checkpoint_path = 'model/gpt2_model.pt'
# Load the checkpoint dictionary. Besides tensors it holds a pickled config object
# (checkpoint['config'], presumably a GPTConfig), which the weights-only loader rejects.
# PyTorch >= 2.6 defaults to weights_only=True, so full unpickling is requested explicitly.
# NOTE(security): full unpickling can execute arbitrary code -- only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

# Rebuild the architecture from the saved config (GPTConfig must already be defined so
# the pickled config can be resolved during torch.load above).
model = GPT(checkpoint['config'])
# Get the original state dict from the checkpoint
state_dict = checkpoint['model']

# State dicts saved from a torch.compile'd module prefix every key with '_orig_mod.';
# strip that prefix (only at the start of a key) so the keys match the plain model.
prefix = "_orig_mod."
new_state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in state_dict.items()
}

# Strict load: raises if any key is missing or unexpected
model.load_state_dict(new_state_dict)
model.to(device)
enc = tiktoken.get_encoding("gpt2")

print(f"Checkpoint loaded at step {checkpoint['step']} with validation loss {checkpoint['val_loss']}")


  checkpoint = torch.load(checkpoint_path, map_location=device)


Checkpoint loaded at step 15000 with validation loss 3.1931183338165283


In [6]:
# Build the prompt batch: the same prompt repeated once per sample to draw.
model.eval()
num_return_sequences = 5
max_length = 50  # total tokens per sample, prompt included
prompt_ids = enc.encode("We need to save the Planet.")
tokens = torch.tensor([prompt_ids] * num_return_sequences, dtype=torch.long)  # (num_return_sequences, T)
x = tokens.to(device)


In [7]:
def generate(prompt_tokens, model=None, enc=None, max_length=None,
             num_return_sequences=None, top_k=50, seed=42):
    """Sample continuations of ``prompt_tokens`` with top-k sampling and print them.

    Args:
        prompt_tokens: (B, T) tensor of token ids, already on the model's device.
        model: callable returning ``(logits, loss)`` with logits of shape (B, T, vocab);
            defaults to the notebook-global ``model``.
        enc: tokenizer with ``decode(list[int]) -> str``; defaults to the notebook-global ``enc``.
        max_length: total length (prompt + generated) of each row; defaults to the
            notebook-global ``max_length``.
        num_return_sequences: how many rows to print; defaults to every row of
            ``prompt_tokens`` (never more rows than exist).
        top_k: sample only among the ``top_k`` most likely tokens (50 is the HuggingFace
            pipeline default).
        seed: reseeds the global torch (and CUDA) RNG so runs are repeatable.

    The model is switched to eval mode for sampling and its previous train/eval mode is
    restored afterwards, so calling this right after a training loop is safe. If the model
    has ``config.block_size``, the context fed to it is cropped to that many tokens.
    """
    nb = globals()
    model = nb["model"] if model is None else model
    enc = nb["enc"] if enc is None else enc
    max_length = nb["max_length"] if max_length is None else max_length

    x = prompt_tokens.clone()
    if num_return_sequences is None:
        num_return_sequences = x.size(0)
    # context limit of the positional embeddings, when the model exposes one
    block_size = getattr(getattr(model, "config", None), "block_size", None)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            while x.size(1) < max_length:
                x_cond = x if block_size is None else x[:, -block_size:]
                logits, _ = model(x_cond)  # (B, T, vocab_size)
                logits = logits[:, -1, :]  # (B, vocab_size): last position only
                probs = F.softmax(logits, dim=-1)
                k = min(top_k, probs.size(-1))
                topk_probs, topk_indices = torch.topk(probs, k, dim=-1)  # (B, k) each
                # multinomial does not require its input to sum to 1
                ix = torch.multinomial(topk_probs, 1)  # (B, 1)
                xcol = torch.gather(topk_indices, -1, ix)  # (B, 1) token ids
                x = torch.cat((x, xcol), dim=1)
    finally:
        model.train(was_training)

    # print the generated text
    for row in x[:num_return_sequences]:
        print(">", enc.decode(row[:max_length].tolist()))

# Baseline: sample from the checkpoint before any LoRA fine-tuning.
prompt_tokens = tokens.to(device=device)
generate(prompt_tokens)

> We need to save the Planet.
The most important part of it is simply our collective responsibility to create a healthy environment and live with dignity and respect and with dignity.
However, what kind of environment is the problem that we can make the greatest
> We need to save the Planet.”
The world’s largest asteroid called 2017 Vesta, a giant, Earth-sized asteroid that’s the largest such impact globally. The asteroid is expected to be 5.5 to 7
> We need to save the Planet.” Our concern is a resounding “climate change problem.”
What we need is a ‘climate change solution’ that is designed to help the planet and its inhabitants as a whole come
> We need to save the Planet. We’re not talking to cars or dirty factories, but to be on guard against toxic pollution.
“We’re using technologies to take advantage of all the green options. And we’
> We need to save the Planet. The biggest problem will be getting more green energy that turns into power. “I don’t know for sure; I see people as i

In [8]:
import torch.nn.utils.parametrize as parametrize

#########################################
# 1) Define your LoRAParametrization
#########################################
class LoRAParametrization(nn.Module):
    """Low-rank (LoRA) additive update for a weight, for ``parametrize.register_parametrization``.

    The parametrized weight is ``W + (alpha / rank) * (lora_B @ lora_A)``, where ``W`` is the
    original weight of shape (out_features, in_features). ``lora_B`` starts at zero, so the
    parametrization is the identity until training moves it; ``lora_A`` starts at N(0, 0.02).
    Set ``enabled = False`` to fall back to the original weight.

    Args:
        out_features, in_features: shape of the weight being adapted.
        rank: inner dimension of the low-rank factors (must be >= 1).
        alpha: scaling numerator; the update is multiplied by ``alpha / rank``.
        device: device for ``lora_A``/``lora_B`` (default: torch's default device).
        dtype: dtype for ``lora_A``/``lora_B`` (default: torch's default dtype).

    Raises:
        ValueError: if ``rank < 1``.
    """
    def __init__(self, out_features, in_features, rank=1, alpha=1.0, device=None, dtype=None):
        super().__init__()
        if rank < 1:
            raise ValueError(f"LoRA rank must be >= 1, got {rank}")
        self.lora_B = nn.Parameter(torch.zeros((out_features, rank), device=device, dtype=dtype))
        self.lora_A = nn.Parameter(torch.zeros((rank, in_features), device=device, dtype=dtype))
        nn.init.normal_(self.lora_A, mean=0.0, std=0.02)

        self.scale = alpha / rank
        self.enabled = True

    def forward(self, base_weight: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return base_weight
        lora_update = torch.matmul(self.lora_B, self.lora_A)
        # Cast the delta to the base weight's dtype so a half-precision layer keeps a
        # half-precision weight (float32 factors would otherwise promote the sum).
        return base_weight + (self.scale * lora_update).to(base_weight.dtype)

In [9]:
def apply_lora_to_attention(model, rank=4, alpha=8.0, device=None):
    """Attach a LoRAParametrization to the weight of every attention Linear in ``model``.

    A module is targeted when it is an ``nn.Linear`` whose qualified name contains "attn"
    and either "c_attn" (the fused q/k/v projection) or "c_proj" (the attention output
    projection). The MLP's ``c_proj`` is skipped because its name has no "attn".

    Args:
        model: module modified in place.
        rank, alpha: forwarded to LoRAParametrization.
        device: where to allocate the LoRA factors. ``None`` (default) uses each target
            weight's own device, so the factors always sit next to the weight they adapt.

    Weights that already carry a LoRAParametrization are left alone, so calling this
    twice does not stack a second adapter on the same weight.
    """
    # Collect targets first: registering a parametrization adds child modules, and the
    # module tree should not be mutated while named_modules() is still walking it.
    targets = [
        module
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and "attn" in name
        and ("c_proj" in name or "c_attn" in name)
    ]
    for module in targets:
        if parametrize.is_parametrized(module, "weight") and any(
            isinstance(p, LoRAParametrization)
            for p in module.parametrizations["weight"]
        ):
            continue
        out_features, in_features = module.weight.shape
        parametrize.register_parametrization(
            module,
            "weight",
            LoRAParametrization(
                out_features,
                in_features,
                rank=rank,
                alpha=alpha,
                device=module.weight.device if device is None else device,
            ),
        )
def enable_disable_lora(enabled=True, model=None):
    """Switch the LoRA update on or off for every parametrized attention Linear.

    Uses the same name filter as ``apply_lora_to_attention``. ``model`` defaults to the
    notebook-global ``model``. Matching modules that have not been parametrized are
    skipped instead of raising AttributeError, and every parametrization on the weight
    that exposes an ``enabled`` flag is toggled.
    """
    if model is None:
        model = globals()["model"]
    for name, module in model.named_modules():
        if not (isinstance(module, nn.Linear) and "attn" in name
                and ("c_proj" in name or "c_attn" in name)):
            continue
        if not parametrize.is_parametrized(module, "weight"):
            continue
        for p in module.parametrizations["weight"]:
            if hasattr(p, "enabled"):
                p.enabled = enabled


In [10]:
# Attach rank-4 LoRA adapters to every attention projection, on the detected device
# (not a hard-coded 'cuda', which fails on CPU/MPS machines).
apply_lora_to_attention(model, rank=4, alpha=8.0, device=device)
enable_disable_lora(enabled=True)

In [11]:
def count_lora_and_non_lora_params(model):
    """Count trainable parameters, split into LoRA (``lora_A``/``lora_B``) and everything else.

    Parameters with ``requires_grad=False`` are ignored entirely.
    Returns ``(lora_params, non_lora_params)`` as element counts.
    """
    totals = {True: 0, False: 0}
    for param_name, tensor in model.named_parameters():
        if tensor.requires_grad:
            is_lora = any(tag in param_name for tag in ("lora_A", "lora_B"))
            totals[is_lora] += tensor.numel()
    return totals[True], totals[False]


lora_count, non_lora_count = count_lora_and_non_lora_params(model)
for label, count in (("LoRA", lora_count), ("Non-LoRA", non_lora_count)):
    print(f"{label} parameters: {count}")


LoRA parameters: 221184
Non-LoRA parameters: 124475904


In [12]:
# freeze the non-Lora parameters
# Freeze every non-LoRA parameter so only the adapters receive gradients.
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

Freezing non-LoRA parametertransformer.wte.weight
Freezing non-LoRA parametertransformer.wpe.weight
Freezing non-LoRA parametertransformer.h.0.ln_1.weight
Freezing non-LoRA parametertransformer.h.0.ln_1.bias
Freezing non-LoRA parametertransformer.h.0.attn.c_attn.bias
Freezing non-LoRA parametertransformer.h.0.attn.c_attn.parametrizations.weight.original
Freezing non-LoRA parametertransformer.h.0.attn.c_proj.bias
Freezing non-LoRA parametertransformer.h.0.attn.c_proj.parametrizations.weight.original
Freezing non-LoRA parametertransformer.h.0.ln_2.weight
Freezing non-LoRA parametertransformer.h.0.ln_2.bias
Freezing non-LoRA parametertransformer.h.0.mlp.c_fc.weight
Freezing non-LoRA parametertransformer.h.0.mlp.c_fc.bias
Freezing non-LoRA parametertransformer.h.0.mlp.c_proj.weight
Freezing non-LoRA parametertransformer.h.0.mlp.c_proj.bias
Freezing non-LoRA parametertransformer.h.1.ln_1.weight
Freezing non-LoRA parametertransformer.h.1.ln_1.bias
Freezing non-LoRA parametertransformer.h.1.a

In [13]:
# List every parameter that will still receive gradients after the freeze cell.
for n, p in filter(lambda item: item[1].requires_grad, model.named_parameters()):
    print("Trainable:", n, p.shape)


Trainable: transformer.h.0.attn.c_attn.parametrizations.weight.0.lora_B torch.Size([2304, 4])
Trainable: transformer.h.0.attn.c_attn.parametrizations.weight.0.lora_A torch.Size([4, 768])
Trainable: transformer.h.0.attn.c_proj.parametrizations.weight.0.lora_B torch.Size([768, 4])
Trainable: transformer.h.0.attn.c_proj.parametrizations.weight.0.lora_A torch.Size([4, 768])
Trainable: transformer.h.1.attn.c_attn.parametrizations.weight.0.lora_B torch.Size([2304, 4])
Trainable: transformer.h.1.attn.c_attn.parametrizations.weight.0.lora_A torch.Size([4, 768])
Trainable: transformer.h.1.attn.c_proj.parametrizations.weight.0.lora_B torch.Size([768, 4])
Trainable: transformer.h.1.attn.c_proj.parametrizations.weight.0.lora_A torch.Size([4, 768])
Trainable: transformer.h.2.attn.c_attn.parametrizations.weight.0.lora_B torch.Size([2304, 4])
Trainable: transformer.h.2.attn.c_attn.parametrizations.weight.0.lora_A torch.Size([4, 768])
Trainable: transformer.h.2.attn.c_proj.parametrizations.weight.0.lo

In [14]:
# Fine-tune with LoRA enabled (non-LoRA weights were frozen in an earlier cell).
enable_disable_lora(enabled=True)

train_filename = "data/train.txt"

B = 4    # micro-batch size (sequences)
T = 512  # tokens per sequence
train_loader = DataLoaderLite(B=B, T=T, enc=enc, filename=train_filename)

# Optimize only parameters that still require grad (the LoRA factors), so frozen
# weights never enter the optimizer's param groups.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=3e-4)
grad_accum_steps = 8
total_micro_steps = 300  # micro-batches processed over the whole run
num_steps = total_micro_steps // grad_accum_steps

model.train()

for step in range(num_steps):
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # scale so the accumulated gradient is the mean over all micro-batches
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()
    # clip the global grad norm of the trained parameters to 1.0
    norm = torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
    optimizer.step()

    print(f"step {step}, train loss: {loss_accum.item()}, grad norm: {norm.item():.4f}")


loaded 11286 tokens
1 epoch = 5 batches
step 0, train loss: 4.258481502532959
step 1, train loss: 4.251651287078857
step 2, train loss: 4.200958251953125
step 3, train loss: 4.242818832397461
step 4, train loss: 4.189047336578369
step 5, train loss: 4.203182220458984
step 6, train loss: 4.193680286407471
step 7, train loss: 4.142091274261475
step 8, train loss: 4.182614326477051
step 9, train loss: 4.130834579467773
step 10, train loss: 4.143932819366455
step 11, train loss: 4.135772228240967
step 12, train loss: 4.082915306091309
step 13, train loss: 4.120965003967285
step 14, train loss: 4.069603443145752
step 15, train loss: 4.079800128936768
step 16, train loss: 4.074286460876465
step 17, train loss: 4.019215106964111
step 18, train loss: 4.058935642242432
step 19, train loss: 4.006374835968018
step 20, train loss: 4.0183634757995605
step 21, train loss: 4.012117862701416
step 22, train loss: 3.9580013751983643
step 23, train loss: 3.9965455532073975
step 24, train loss: 3.94507026

In [15]:
# Same prompt and sampling seed as the baseline, now with the fine-tuned LoRA adapters active.
generate(prompt_tokens=prompt_tokens)

> We need to save the Planet.
The next time we got the hell out of the Earth, look in.
This is my favorite of the three!
I guess it is gonna pop up like this.
Yeah, the Earth's just
> We need to save the Planet.”
The Earth is going green. The atmosphere just got dirty. I don’t get much of anything.
I’ve become a big fan of the Universe. Can you trust us?
> We need to save the Planet.” Here, God's great lords, He's the Lord!
The planet is not the planet.
- The planet is not the planet
- The planet is not
- The moon is not the
> We need to save the Planet. The better the planet, the more of the life we have left. It is pretty important to save the Planet and everyone else in it.
Do you think the only way to reduce carbon emissions from the planet that
> We need to save the Planet. I can’t go to those big numbers (I didn’t want the planet to go bad on Earth); I was scared that’s not it. And there’s no way I
