# LoRA implementation

## Set up LoRA

In [1]:
from transformers import GPT2Model, GPT2Tokenizer, GPT2LMHeadModel

model = GPT2Model.from_pretrained('gpt2')
model

  from .autonotebook import tqdm as notebook_tqdm


GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [2]:
model.h[0]

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2SdpaAttention(
    (c_attn): Conv1D()
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D()
    (c_proj): Conv1D()
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [3]:
model.h[0].attn.c_attn.weight.shape

torch.Size([768, 2304])

Let's do the low rank matrices as `B` (2304 by 4) and `A` (4 by 768).

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.lora_A = nn.Parameter(nn.init.normal_(torch.randn(rank, in_features)))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.rank = rank
        self.scaling = 0.005 

    def forward(self, x):
        return (x@self.lora_A.T@self.lora_B.T) * self.scaling/self.rank


In [6]:
lora_layers = nn.ModuleList([LoRALayer(768, 2304) for _ in range(12)])

In [7]:
import types

In [8]:
for i, block in enumerate(model.h):
    original_forward = block.attn.c_attn.forward

    def new_forward(self, x):
        x = original_forward(x) + lora_layers[i](x)
        return x
    
    block.attn.c_attn.forward = types.MethodType(new_forward, block.attn.c_attn)

In [9]:
# Freeze all parameters of the original model
for param in model.parameters():
    param.requires_grad = False

# Unfreeze LoRA parameters
for layer in lora_layers:
    for param in layer.parameters():
        param.requires_grad = True

In [10]:
optimizer = torch.optim.AdamW(lora_layers.parameters(), lr=1e-3)

In [11]:
model

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

## Load data

In [12]:
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

In [13]:
class ShakespeareDataset(Dataset):
    def __init__(self, file_path, tokenizer, seq_length=128):
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
            # DEBUG
            text = text[:1000]
        
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.tokens = tokenizer.encode(text)
        
    def __len__(self):
        return len(self.tokens) - self.seq_length

    def __getitem__(self, idx):
        chunk = self.tokens[idx:idx+self.seq_length+1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

In [14]:
def create_dataloaders(dataset, batch_size, train_split=0.9):
    train_size = int(train_split * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader

In [15]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')



In [16]:
dataset = ShakespeareDataset('tiny-shakespeare.txt', tokenizer, seq_length=128)
train_loader, val_loader = create_dataloaders(dataset, batch_size=32)

## Train LoRA

In [17]:
import torch
import torch.nn as nn
import math

In [18]:
class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, alpha=1):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.rank = rank
        self.scaling = alpha / rank

        # Initialize LoRA parameters
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        return (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

In [19]:
class LoRAWrapper(nn.Module):
    def __init__(self, model, rank=4, alpha=1):
        super().__init__()
        self.model = model
        self.lora_layers = nn.ModuleList([
            LoRALayer(768, 2304, rank, alpha) for _ in range(len(model.transformer.h))
        ])
        
        # Freeze original model parameters
        for param in self.model.parameters():
            param.requires_grad = False
        
        # Enable grad for LoRA parameters
        for layer in self.lora_layers:
            for param in layer.parameters():
                param.requires_grad = True

    def forward(self, *args, **kwargs):
        # Store original forward methods
        original_forwards = [block.attn.c_attn.forward for block in self.model.transformer.h]
        lora_layers = self.lora_layers

        # Replace forward methods with LoRA-augmented versions
        for i, block in enumerate(self.model.transformer.h):
            def new_forward(self, x, i=i):
                return original_forwards[i](x) + lora_layers[i](x)
            block.attn.c_attn.forward = new_forward.__get__(block.attn.c_attn, type(block.attn.c_attn))

        # Call the model with LoRA
        output = self.model(*args, **kwargs)

        # Restore original forward methods
        # for i, block in enumerate(self.model.transformer.h):
        #     block.attn.c_attn.forward = original_forwards[i]

        return output

    def get_lora_parameters(self):
        return self.lora_layers.parameters()


In [20]:
from tqdm import tqdm

In [21]:
def train_lora(model, train_loader, val_loader, optimizer, scheduler, num_epochs, device):
    model.to(device)
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs, labels=targets)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                inputs, targets = batch
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs, labels=targets)
                val_loss += outputs.loss.item()
        
        avg_val_loss = val_loss / len(val_loader)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        

In [45]:
batch_size = 128
seq_length = 16
num_epochs = 5
learning_rate = 5e-3
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [46]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

In [47]:
lora_model = LoRAWrapper(model, rank=4, alpha=1)

In [48]:
optimizer = torch.optim.AdamW(lora_model.get_lora_parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

In [49]:
train_lora(lora_model, train_loader, val_loader, optimizer, scheduler, num_epochs, device)

Epoch 1/5: 100%|██████████| 5/5 [02:41<00:00, 32.22s/it]


Epoch 1/5, Average Loss: 7.7928
Validation Loss: 6.5757


Epoch 2/5: 100%|██████████| 5/5 [03:08<00:00, 37.77s/it]


Epoch 2/5, Average Loss: 6.0483
Validation Loss: 5.4927


Epoch 3/5: 100%|██████████| 5/5 [03:42<00:00, 44.55s/it]


Epoch 3/5, Average Loss: 5.4157
Validation Loss: 5.1772


Epoch 4/5: 100%|██████████| 5/5 [02:50<00:00, 34.10s/it]


Epoch 4/5, Average Loss: 4.9419
Validation Loss: 4.6396


Epoch 5/5: 100%|██████████| 5/5 [04:37<00:00, 55.53s/it]


Epoch 5/5, Average Loss: 4.4962
Validation Loss: 4.1927


In [50]:
torch.save(lora_model.lora_layers.state_dict(), 'lora_weights.pth')

## Load model


In [51]:
lora_weights = torch.load('lora_weights.pth')
lora_model.lora_layers.load_state_dict(lora_weights)

  lora_weights = torch.load('lora_weights.pth')


<All keys matched successfully>

In [52]:
lora_model

LoRAWrapper(
  (model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2SdpaAttention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (lora_layers): ModuleList(
   

## Inference

In [53]:
lora_model.to("mps")

LoRAWrapper(
  (model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2SdpaAttention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (lora_layers): ModuleList(
   

In [54]:
def generate_text_with_lora(lora_model, tokenizer, prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.95, num_return_sequences=1, device='mps', return_full_text=False):
    lora_model.eval()
    
    # Encode the prompt
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    prompt_length = input_ids.size(1)
    
    generated_sequences = []
    
    for _ in range(num_return_sequences):
        current_input_ids = input_ids.clone()
        
        with torch.no_grad():
            for _ in range(max_length):
                try:
                    # Get the model's output
                    outputs = lora_model(current_input_ids)
                    next_token_logits = outputs.logits[:, -1, :]
                    
                    # Apply temperature
                    next_token_logits = next_token_logits / temperature
                    
                    # Apply top-k filtering
                    top_k = min(top_k, next_token_logits.size(-1))
                    top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                    next_token_logits[next_token_logits < top_k_logits[:, [-1]]] = float('-inf')
                    
                    # Apply top-p (nucleus) filtering
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = torch.zeros_like(next_token_logits, dtype=torch.bool).scatter_(
                        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
                    )
                    next_token_logits[indices_to_remove] = float('-inf')
                    
                    # Sample from the filtered distribution
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    
                    # Append the new token to the input_ids
                    current_input_ids = torch.cat([current_input_ids, next_token], dim=-1)
                    
                    # Check if we've generated an EOS token
                    if next_token.item() == tokenizer.eos_token_id:
                        break
                
                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        print(f"WARNING: ran out of memory in iteration {_}. This might result in inferior results. Try a smaller model or reduce batch size.")
                        if torch.backends.mps.is_available():
                            torch.backends.mps.empty_cache()
                        break
                    else:
                        raise e
        
        # Decode the generated sequence
        generated_sequence = current_input_ids[0].tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        
        if not return_full_text:
            # Remove the prompt from the generated text
            text = text.split(prompt)[-1]
        
        generated_sequences.append(text.strip())
    
    return generated_sequences

In [61]:
prompt = "Antoine: Thou art as fat as butter."

In [62]:
generated_texts = generate_text_with_lora(
    lora_model, 
    tokenizer, 
    prompt, 
    max_length=20, 
    temperature=0.7, 
    top_k=50, 
    top_p=0.6, 
    num_return_sequences=3, 
    device='mps',
    return_full_text=True
)


In [63]:
generated_texts[0]

'Antoine: Thou art as fat as butter.\n\ncius killed killedcius killed citizens killed killed killed Citizen\n killed killed killed killed killFirst citizens'

Reflection: As we can see, the model overfits on our small subset of data. As such, we should be using a bigger dataset. However, we are limited by GPUs for now.