In [1]:
import torch
import datasets

from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [2]:
model_path = "google/gemma-3-1b-it"
model_name = "gemma_1b_reasonning"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Since I don't have enough ressources
quantization_config = BitsAndBytesConfig(load_in_4bit=True,
                                        bnb_4bit_use_double_quant=True,
                                        bnb_4bit_quant_type="nf4",
                                        bnb_4bit_compute_dtype=torch.bfloat16) 

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quantization_config,
    device_map="auto",
    attn_implementation="eager"  # ← add this
)

tokenizer = AutoTokenizer.from_pretrained(model_path)

if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token

In [3]:
def generate_response(model, tokenizer, prompt, max_new_tokens=50, temperature=0.7):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    inputs = tokenizer(prompt, return_tensors='pt').to(device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return response

In [4]:
prompt = "What's the capital of France?"
generate_response(model, tokenizer, prompt)

"What's the capital of France?\n\n**Paris**\n"

In [5]:
import torch.nn as nn

def prune_layers(model, start_layer, end_layer):
    
    # Keep only layers outside the pruning range
    pruned_layers = nn.ModuleList(
        [layer for idx, layer in enumerate(model.model.layers) 
         if idx < start_layer or idx >= end_layer]
    )
    print(len(pruned_layers))
    # Assign back to the model
    model.model.layers = pruned_layers
    model.config.num_hidden_layers = len(pruned_layers)

    return model

In [6]:
pruned_model = prune_layers(model, 9, 14)

21


In [7]:
prompt = "what's the capital of france?"
generate_response(pruned_model, tokenizer, prompt)

"what's the capital of france?\n\nThe capital's France**\n\nThe following are the common years of the capital's France?\n\n0100128400128400000000000000"

In [8]:
pruned_model.model.layers[15]

Gemma3DecoderLayer(
  (self_attn): Gemma3Attention(
    (q_proj): Linear4bit(in_features=1152, out_features=1024, bias=False)
    (k_proj): Linear4bit(in_features=1152, out_features=256, bias=False)
    (v_proj): Linear4bit(in_features=1152, out_features=256, bias=False)
    (o_proj): Linear4bit(in_features=1024, out_features=1152, bias=False)
    (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
    (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
  )
  (mlp): Gemma3MLP(
    (gate_proj): Linear4bit(in_features=1152, out_features=6912, bias=False)
    (up_proj): Linear4bit(in_features=1152, out_features=6912, bias=False)
    (down_proj): Linear4bit(in_features=6912, out_features=1152, bias=False)
    (act_fn): PytorchGELUTanh()
  )
  (input_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
  (post_attention_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
  (pre_feedforward_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
  (post_feedforward_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
)

In [9]:
from peft import LoraConfig, get_peft_model, PeftModel

lora_config = LoraConfig(
    r=8,                          # LoRA rank, lower is lighter-weight
    lora_alpha=16,                # alpha scaling factor (usually 2*r)
    lora_dropout=0.1,             # dropout regularization
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # attention layers
        "gate_proj", "up_proj", "down_proj"       # MLP layers
    ],
    bias="none",                  # LoRA typically ignores biases
    layers_to_transform=list(range(len(pruned_model.model.layers))),
    task_type="CAUSAL_LM"
)

In [10]:
pruned_model = get_peft_model(pruned_model, lora_config)

In [11]:
prompt = "what's the capital of france?"
generate_response(pruned_model, tokenizer, prompt)

"what's the capital of france?\n\nThe date of issuance of a royal/official proclamation is a common occurrence.\n\n**French**\n\n**13th, 13th, 14th, 14th, 17th, 14th"

In [None]:
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset, load_from_disk

# Tokenization function
def tokenize_function(example):
    return tokenizer(example["text"], padding="max_length", truncation=True, max_length=128)

batch_size = 8

local_dataset = load_from_disk("./c4_10k_subset")

# Tokenize dataset
tokenized_dataset = local_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

dataloader = DataLoader(tokenized_dataset, batch_size=8, shuffle=True)

In [13]:
from tqdm import tqdm
from torch.optim import AdamW

optimizer = AdamW(pruned_model.parameters(), lr=1e-4)

In [None]:

# Fine-tuning loop
epochs = 1  # 1 epoch is often enough for healing

pruned_model.train()
for epoch in range(epochs):
    epoch_loss = 0
    for idx, batch in tqdm(enumerate(dataloader), desc=f"Epoch {epoch+1}"):
        
        batch = {k: v.to(pruned_model.device) for k, v in batch.items()}
        outputs = pruned_model(**batch, labels=batch["input_ids"])
        loss = outputs.loss
        
        if torch.isnan(loss):
            print("Loss is NaN at batch", idx, batch)
            continue
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_loss += loss.item()
     
        
        if (idx+1)%100==0 :
            print(f"batch {idx}/{len(dataloader)} completed. Average Loss: {epoch_loss / idx:.4f}")
            
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")

# Save fine-tuned model adapters
pruned_model.save_pretrained("gemma_pruned_lora")

Epoch 1: 100it [01:10,  1.46it/s]

batch 99/125 completed. Average Loss: 4.2245


Epoch 1: 125it [01:27,  1.42it/s]


Epoch 1 completed. Average Loss: 4.0607


In [15]:
pruned_model.eval()
prompt = "What is the capital of France? "

generate_response(pruned_model, tokenizer, prompt)

'What is the capital of France? 1500 1200 1800 1994 1995 1996 1998 1999 1999 2000 '