In [None]:
!pip install transformers peft datasets torch bitsandbytes tensorboard

In [2]:
import bitsandbytes

# Sanity check: show which bitsandbytes build the kernel picked up.
print(f"{bitsandbytes.__version__}")

0.45.5


In [3]:
!pip install wandb -qU

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.4/21.4 MB[0m [31m74.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h

In [6]:
import wandb
import random
import math

In [None]:
# Never hardcode the API key in the notebook. Without an explicit key, wandb.login()
# reads WANDB_API_KEY from the environment (or an existing login) and otherwise prompts.
wandb.login()

In [None]:
import io
from PIL import Image

def log_lora_weights_to_wandb(model, pattern="lora_A", step=None):
    """Render every parameter whose name contains `pattern` as a heatmap and log it to wandb.

    All matching heatmaps are sent in a single ``wandb.log`` call, keyed by the
    parameter name. They therefore land on one wandb step instead of advancing
    the step counter once per parameter.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a torch / PEFT model).
        pattern: substring a parameter name must contain to be plotted.
        step: optional wandb step forwarded to ``wandb.log``. ``None`` lets wandb
            choose the next step.
    """
    # Imported here: the notebook only imports pyplot locally inside main(), so a
    # module-level `plt` is not defined when this function runs.
    import matplotlib.pyplot as plt

    images = {}
    for name, param in model.named_parameters():
        if pattern not in name:
            continue

        # .float() first: numpy has no bfloat16, so .numpy() on bf16 weights fails.
        param_data = param.detach().float().cpu().numpy()
        # imshow needs a 2-D array: show vectors/scalars as one row, fold extra dims.
        if param_data.ndim < 2:
            param_data = param_data.reshape(1, -1)
        elif param_data.ndim > 2:
            param_data = param_data.reshape(param_data.shape[0], -1)

        fig, ax = plt.subplots(figsize=(6, 3))
        mappable = ax.imshow(param_data, aspect='auto', cmap='viridis')
        fig.colorbar(mappable, ax=ax)
        ax.set_title(f"{name} - shape: {tuple(param.shape)}")
        ax.set_xlabel("Columns")
        ax.set_ylabel("Rows")
        fig.tight_layout()

        # Save plot to an in-memory PNG and close the figure right away.
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        plt.close(fig)
        buf.seek(0)
        image = Image.open(buf)
        image.load()  # decode now so the image no longer depends on `buf`

        images[name] = wandb.Image(image)

    if images:
        wandb.log(images, step=step)


In [20]:
from transformers import TrainerCallback

class LogLoRAWeightsCallback(TrainerCallback):
    """Trainer callback that logs LoRA weight histograms to wandb on every Trainer log event.

    Every parameter whose name contains ``lora_A`` or ``lora_B`` is logged as a
    ``wandb.Histogram`` under ``weights/<parameter name>``. The histograms are
    logged at the Trainer's ``global_step`` in one ``wandb.log`` call.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        model = kwargs.get("model")
        # Nothing to log without a model. In multi-process runs, only the main
        # process is expected to own the wandb run.
        if model is None or not state.is_world_process_zero:
            return

        # .float() before .numpy(): numpy cannot represent bfloat16 tensors.
        histograms = {
            f"weights/{name}": wandb.Histogram(param.detach().float().cpu().numpy())
            for name, param in model.named_parameters()
            if "lora_A" in name or "lora_B" in name
        }
        if histograms:
            wandb.log(histograms, step=state.global_step)


In [None]:
import torch
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime



# TensorBoard writer whose event files go under runs/lora_experiment.
# NOTE(review): `writer` is not referenced by any code visible in this notebook.
writer = SummaryWriter("runs/lora_experiment")



for handler in logging.root.handlers[:]:
    print(handler)
    logging.root.removeHandler(handler)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("training.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

class LoRAHook:
    """Capture layer inputs and output gradients of LoRA submodules via torch hooks.

    ``attach`` registers, on every submodule whose name contains ``lora_A`` or
    ``lora_B`` and that owns a ``weight``:

    - a forward hook that stores ``input[0]``;
    - a full backward hook that stores ``grad_output[0]``.

    Only the most recent tensor per module name is kept, detached from the
    autograd graph, in ``self.inputs`` / ``self.grads``. The registered hook
    handles are kept in ``self.handles`` so that ``remove`` can detach them.
    """

    def __init__(self, writer = None):
        # Optional summary writer kept for callers; the hooks themselves do not use it.
        self.writer = writer
        self.inputs = {}
        self.grads = {}
        # RemovableHandle objects returned by register_*_hook.
        self.handles = []

    def capture_input(self, name):
        """Return a forward hook that records the module's first positional input under `name`."""
        def hook(module, input, output):
            # `input` can be an empty tuple when the module is called with kwargs only.
            if input and input[0] is not None:
                self.inputs[name] = input[0].detach()
        return hook

    def capture_grad(self, name):
        """Return a full backward hook that records the gradient w.r.t. the module output."""
        def hook(module, grad_input, grad_output):
            if grad_output and grad_output[0] is not None:
                self.grads[name] = grad_output[0].detach()
        return hook

    def check_grad(self,name):
        """Return a debugging backward hook that prints whether `name` received a gradient."""
        def hook(module, grad_input, grad_output):
            print(f"[HOOK] {name} got gradient: {grad_output[0] is not None}")
        return hook

    def attach(self, model):
        """Register the capture hooks on the LoRA submodules of `model`.

        Returns:
            The list of all handles registered so far (also kept in ``self.handles``).
        """
        for name, module in model.named_modules():
            if "lora_A" not in name and "lora_B" not in name:
                continue
            # Skip containers (e.g. a ModuleDict named "...lora_A"). They are never
            # called themselves, so hooks registered on them would never fire.
            if not hasattr(module, "weight"):
                continue
            self.handles.append(module.register_forward_hook(self.capture_input(name)))
            self.handles.append(module.register_full_backward_hook(self.capture_grad(name)))
        return self.handles

    def remove(self):
        """Remove every hook registered by `attach`; already captured tensors are kept."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()




def fisher_kfac_lora(input_act, output_grad, damping=1e-3):
    """Compute the damped K-FAC Kronecker factors for one linear layer.

    All leading dimensions (e.g. batch and sequence) are flattened into a single
    sample axis of length N. The factors are then::

        A = a^T a / N + damping * I   (D_in  x D_in)
        G = g^T g / N + damping * I   (D_out x D_out)

    Args:
        input_act: layer input activations of shape (..., D_in).
        output_grad: gradients w.r.t. the layer output of shape (..., D_out), with
            the same number of samples as `input_act` once flattened.
        damping: Tikhonov damping added to the diagonal of both factors.

    Returns:
        ``(A, G)``. On any failure (empty or mismatched inputs, bad shapes) the
        error is logged and ``(None, None)`` is returned, which callers check for.
    """
    try:
        d_in = input_act.shape[-1]
        d_out = output_grad.shape[-1]
        # Accumulate in at least float32: half-precision sums over N samples can overflow.
        compute_dtype = torch.promote_types(input_act.dtype, torch.float32)

        input_flat = input_act.reshape(-1, d_in).to(compute_dtype)
        output_flat = output_grad.reshape(-1, d_out).to(compute_dtype)

        n = input_flat.shape[0]
        if n == 0:
            raise ValueError("empty activations")
        if output_flat.shape[0] != n:
            raise ValueError(
                f"sample count mismatch: {n} activations vs {output_flat.shape[0]} gradients"
            )

        A = (input_flat.T @ input_flat) / n + damping * torch.eye(
            d_in, device=input_act.device, dtype=compute_dtype
        )
        G = (output_flat.T @ output_flat) / n + damping * torch.eye(
            d_out, device=output_grad.device, dtype=compute_dtype
        )
        return A, G
    except Exception as e:
        logger.error(f"Error in fisher_kfac_lora: {e}")
        return None, None

def apply_natural_gradient(model, lora_hook, damping=1e-3, lr=0.01):
    """Apply one in-place K-FAC natural-gradient step to every hooked LoRA linear layer.

    For a weight W of shape (d_out, d_in) with gradient dW, the K-FAC
    preconditioned direction is ``G^{-1} @ dW @ A^{-1}``. Here A (d_in x d_in)
    and G (d_out x d_out) are the factors built by ``fisher_kfac_lora`` from the
    captured layer inputs and output gradients. The weight is then updated as
    ``W -= lr * direction``.

    Args:
        model: model whose LoRA submodules were hooked by `lora_hook`.
        lora_hook: object with ``inputs`` / ``grads`` dicts keyed by module name.
        damping: damping added to both Kronecker factors before inversion.
        lr: step size of the update (previously hard-coded to 0.01).

    Layers without captured statistics, without a ``weight`` or without
    ``weight.grad`` are skipped. Per-layer failures are logged and do not abort
    the loop.
    """
    logger.info("Applying Natural Gradient update to LoRA weights")
    device = model.device
    for name, module in model.named_modules():
        if "lora_A" not in name and "lora_B" not in name:
            continue
        if name not in lora_hook.inputs or name not in lora_hook.grads:
            continue
        weight = getattr(module, "weight", None)
        if weight is None:
            continue
        # Check the gradient first so we do not build/invert factors for nothing.
        if weight.grad is None:
            logger.warning(f"No gradient for {name}")
            continue

        a = lora_hook.inputs[name].to(device)
        g = lora_hook.grads[name].to(device)
        A, G = fisher_kfac_lora(a, g, damping)
        if A is None or G is None:
            continue
        try:
            A_inv = torch.linalg.inv(A)
            G_inv = torch.linalg.inv(G)
            grad = weight.grad.detach()
            # grad is (d_out, d_in): the output factor multiplies from the left and
            # the input factor from the right. The former `A_inv @ grad @ G_inv` is
            # only shape-valid when d_in == d_out, which never holds for LoRA A/B.
            ng_update = G_inv.to(grad.dtype) @ grad @ A_inv.to(grad.dtype)
            with torch.no_grad():
                weight.add_(ng_update, alpha=-lr)
            logger.info(f"Applied NG update to {name}")
        except Exception as e:
            logger.error(f"Error in NG update for {name}: {e}")

def generate_sample(model, tokenizer, prompt, max_length=50):
    """Generate one continuation of `prompt` and return it as decoded text.

    The model is put in eval mode for generation, so dropout layers (such as the
    LoRA dropout configured in main) are inactive. Its previous train/eval mode
    is restored afterwards, even if generation raises.

    Args:
        model: causal LM exposing ``device``, ``training``, ``eval``, ``train`` and ``generate``.
        tokenizer: tokenizer used both to encode `prompt` and to decode the output.
        prompt: input text.
        max_length: total length budget passed to ``generate`` (prompt tokens included).

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    was_training = model.training
    model.eval()
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_length=max_length, num_return_sequences=1)
    finally:
        model.train(was_training)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)



 

def _build_alpaca_dataset(tokenizer, split="train[:1000]", max_length=512):
    """Load an Alpaca split, render each row as a prompt and tokenize it.

    Rows with a non-empty ``input`` use an Instruction/Input/Response template;
    the others use an Instruction/Response template. Every example is
    truncated/padded to `max_length` tokens and gets a ``labels`` column copied
    from ``input_ids``.

    Returns:
        The tokenized dataset, containing only the tokenizer outputs plus ``labels``.
    """
    logger.info("Loading dataset")
    dataset = load_dataset("tatsu-lab/alpaca", split=split)

    def format_prompt(example):
        instruction = example["instruction"]
        input_text = example["input"]
        response = example["output"]
        if input_text:
            prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{response}"
        else:
            prompt = f"### Instruction:\n{instruction}\n\n### Response:\n{response}"
        return {"text": prompt}

    dataset = dataset.map(format_prompt)

    def tokenize(batch):
        # Return plain lists. The previous return_tensors="pt" + .squeeze()
        # collapsed a size-1 batch to 1-D and lost the batch axis.
        encoded = tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )
        encoded["labels"] = [list(ids) for ids in encoded["input_ids"]]
        return encoded

    return dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)


def _plot_lora_diffs(model, initial_weights, out_dir="lora_diff_plots"):
    """Save one PNG per LoRA parameter showing which entries differ from `initial_weights`."""
    import os

    import matplotlib.pyplot as plt

    os.makedirs(out_dir, exist_ok=True)
    for name, param in model.named_parameters():
        if "lora_" not in name or name not in initial_weights:
            continue
        before = initial_weights[name].to(param.device)
        after = param.detach()

        # Difference mask (1 where changed)
        diff_mask = (before != after).cpu().int().numpy()

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(diff_mask, cmap="Greens", interpolation="nearest")
        ax.set_title(f"Changed Params in {name}")
        ax.axis("off")
        fig.savefig(os.path.join(out_dir, f"{name.replace('.', '_')}.png"))
        plt.close(fig)


def _log_fisher_norms(lora_hook, device):
    """Log the Frobenius norms of the K-FAC factors of every hooked layer to wandb in one call."""
    logger.info("Analyzing LoRA weights")
    metrics = {}
    for name, act in lora_hook.inputs.items():
        grad_out = lora_hook.grads.get(name)
        # An input can be captured without a matching output gradient
        # (e.g. a forward pass with no backward); skip instead of raising KeyError.
        if grad_out is None:
            logger.warning(f"No captured output gradient for {name}; skipping")
            continue
        A, G = fisher_kfac_lora(act.to(device), grad_out.to(device))
        if A is not None:
            metrics[f"{name}_A_norm"] = torch.norm(A).item()
            metrics[f"{name}_G_norm"] = torch.norm(G).item()
    if metrics:
        wandb.log(metrics)


def main():
    """Fine-tune TinyLlama with LoRA on a 1k-example Alpaca slice and log diagnostics.

    Steps:
    1. Load the model and tokenizer on CPU.
    2. Wrap the model with a LoRA adapter on q_proj/v_proj.
    3. Train with the HF Trainer, reporting to wandb.
    4. Apply one K-FAC natural-gradient step from the hooked statistics.
    5. Plot which LoRA entries changed and log the K-FAC factor norms.
    6. Save the adapter and tokenizer, and log one generated sample.

    The wandb run is always finished, and marked as failed if any step raises.
    """
    device = torch.device("cpu")
    logger.info(f"Using device: {device}")

    # Single source of truth for the run configuration: these values are logged to
    # wandb and fed to LoraConfig / TrainingArguments below.
    hyperparameters = {
        "model": "TinyLlama-1.1B-Chat-v1.0",
        "lr": 2e-4,
        "batch_size": 1,
        "epochs": 2,
        "lora_r": 8,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
    }

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    logger.info(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    lora_config = LoraConfig(
        r=hyperparameters["lora_r"],
        lora_alpha=hyperparameters["lora_alpha"],
        target_modules=["q_proj", "v_proj"],
        lora_dropout=hyperparameters["lora_dropout"],
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    wandb.init(
        project="tinyllama-lora-alpaca",
        name=f"experiment-{timestamp}",
        config=hyperparameters,
    )

    try:
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

        # Snapshot the LoRA weights so we can later plot which entries training changed.
        initial_weights = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if "lora_" in name
        }

        lora_hook = LoRAHook()
        lora_hook.attach(model)

        tokenized = _build_alpaca_dataset(tokenizer)

        training_args = TrainingArguments(
            output_dir="./lora-tinyllama-alpaca",
            per_device_train_batch_size=hyperparameters["batch_size"],
            gradient_accumulation_steps=1,
            num_train_epochs=hyperparameters["epochs"],
            learning_rate=hyperparameters["lr"],
            fp16=False,
            bf16=False,
            logging_steps=50,
            # NOTE(review): a checkpoint every 5 steps is very frequent compared with
            # logging_steps=50 — confirm this is intended.
            save_steps=5,
            save_total_limit=1,
            report_to="wandb",
            label_names=["labels"],
        )

        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        def custom_data_collator(features):
            # Collate, then move every tensor in the batch to the training device.
            batch = data_collator(features)
            return {k: v.to(device) for k, v in batch.items()}

        logger.info("Starting training")
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized,
            processing_class=tokenizer,
            data_collator=custom_data_collator,
            callbacks=[LogLoRAWeightsCallback()],
        )
        trainer.train()

        # NOTE(review): the Trainer may have cleared .grad after its last optimizer
        # step; if so, this logs "No gradient" per layer and changes nothing.
        apply_natural_gradient(model, lora_hook)

        _plot_lora_diffs(model, initial_weights)
        _log_fisher_norms(lora_hook, device)

        logger.info("Saving model")
        # NOTE(review): adapter and tokenizer go to different directories; the
        # reload cells further down load from exactly these two paths.
        model.save_pretrained("./lora-tinyllama-alpaca")
        tokenizer.save_pretrained("./tinyllama-lora-alpaca")

        logger.info("Generating sample output")
        prompt = "What is the capital of France?"
        response = generate_sample(model, tokenizer, prompt)
        logger.info(f"Sample prompt: {prompt}\nResponse: {response}")
        wandb.log({"sample_prompt": prompt, "sample_response": response})
    except BaseException:
        # Close the run as failed instead of leaving it open after an error/interrupt.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()  # ✅ Finish WandB run


if __name__ == "__main__":
    main()

In [27]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the artifacts written by main(): tokenizer and model were saved to two
# different directories.
tokenizer = AutoTokenizer.from_pretrained("./tinyllama-lora-alpaca")
model = AutoModelForCausalLM.from_pretrained("./lora-tinyllama-alpaca")

# Generate with a 50-token total budget and print the decoded text.
prompt = "What is the capital of France?"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=50)
first_sequence = outputs[0]
print(tokenizer.decode(first_sequence, skip_special_tokens=True))

What is the capital of France?
What is the currency of France?
What is the population of France?
What is the official language of France?
What is the religion of France?
What is the capital of Germany?



In [10]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the saved tokenizer and model (stored in separate directories by main()).
tokenizer = AutoTokenizer.from_pretrained("./tinyllama-lora-alpaca")
model = AutoModelForCausalLM.from_pretrained("./lora-tinyllama-alpaca")

# Ask a simple arithmetic question with a 50-token total budget.
prompt = "What is 2 + 2?"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=50)
first_sequence = outputs[0]
print(tokenizer.decode(first_sequence, skip_special_tokens=True))

What is 2 + 2?
2 + 2 = 4

### Response:
The answer is 4. 2 + 2 is 4. This is the sum of 2 and 2.
