In [1]:
import torch
import datasets

from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [2]:
model_path = "google/gemma-3-1b-it"
model_name = "gemma_1b_reasonning"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Since I don't have enough ressources
quantization_config = BitsAndBytesConfig(load_in_4bit=True,
                                        bnb_4bit_use_double_quant=True,
                                        bnb_4bit_quant_type="nf4",
                                        bnb_4bit_compute_dtype=torch.bfloat16) 

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quantization_config,
    device_map="auto",
    attn_implementation="eager"  # ← add this
)

tokenizer = AutoTokenizer.from_pretrained(model_path)

if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token

In [3]:
def generate_response(model, tokenizer, prompt, max_new_tokens=50, temperature=0.7):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    inputs = tokenizer(prompt, return_tensors='pt').to(device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False
    )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return response

In [4]:
prompt = "What's the capital of France?"
generate_response(model, tokenizer, prompt)

You have set `use_cache` to `False`, but cache_implementation is set to hybrid. cache_implementation will have no effect.


"What's the capital of France?\n\n**Paris**\n\n"

In [5]:
import torch.nn as nn

def prune_layers(model, start_layer, end_layer):
    
    # Keep only layers outside the pruning range
    pruned_layers = nn.ModuleList(
        [layer for idx, layer in enumerate(model.model.layers) 
         if idx < start_layer or idx >= end_layer]
    )
    print(len(pruned_layers))
    # Assign back to the model
    model.model.layers = pruned_layers
    model.config.num_hidden_layers = len(pruned_layers)

    return model

In [6]:
pruned_model = prune_layers(model, 9, 14)

21


In [7]:
prompt = "what's the capital of france?"
generate_response(pruned_model, tokenizer, prompt)

"what's the capital of france?\n\nThe following are the themes:\n\n*   **A** \n*   **Irn-Up**\n*   **The** *a* (This is the level of detail is the the the the the).\n\n*   "

In [8]:
pruned_model.model.layers[15]

Gemma3DecoderLayer(
  (self_attn): Gemma3Attention(
    (q_proj): Linear4bit(in_features=1152, out_features=1024, bias=False)
    (k_proj): Linear4bit(in_features=1152, out_features=256, bias=False)
    (v_proj): Linear4bit(in_features=1152, out_features=256, bias=False)
    (o_proj): Linear4bit(in_features=1024, out_features=1152, bias=False)
    (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
    (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
  )
  (mlp): Gemma3MLP(
    (gate_proj): Linear4bit(in_features=1152, out_features=6912, bias=False)
    (up_proj): Linear4bit(in_features=1152, out_features=6912, bias=False)
    (down_proj): Linear4bit(in_features=6912, out_features=1152, bias=False)
    (act_fn): PytorchGELUTanh()
  )
  (input_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
  (post_attention_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
  (pre_feedforward_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
  (post_feedforward_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
)

In [9]:
from peft import LoraConfig, get_peft_model, PeftModel

lora_config = LoraConfig(
    r=8,                          # LoRA rank, lower is lighter-weight
    lora_alpha=8,                # alpha scaling factor (usually 2*r)
    lora_dropout=0.05,             # dropout regularization
    target_modules=[
        # "q_proj", "k_proj", "v_proj", "o_proj",   # attention layers
        "gate_proj", "up_proj", "down_proj"       # MLP layers
    ],
    bias="none",                  # LoRA typically ignores biases
    layers_to_transform=list(range(len(pruned_model.model.layers))),
    task_type="CAUSAL_LM"
)

In [10]:
pruned_model = get_peft_model(pruned_model, lora_config)

In [11]:
prompt = "what's the capital of france?"
generate_response(pruned_model, tokenizer, prompt)

'what\'s the capital of france?\n\nThe date of issuance,\n\nThe capitalizes the “date”\n\n| leading the way right with the "50 year+50" |\n\n| | Leading the year 2000 is the “5 year”|\n\n'

In [12]:
def format_instruction(sample):
    if sample["context"].strip():
        prompt = f"Instruction: {sample['instruction']}\nContext: {sample['context']}\nAnswer:"
    else:
        prompt = f"Instruction: {sample['instruction']}\nAnswer:"
    
    return {"prompt": prompt, "completion": sample["response"]}

def tokenize(sample, max_length=512):
    prompt = sample["prompt"]
    completion = sample["completion"]

    prompt_tokens = tokenizer(prompt, truncation=True, max_length=max_length)
    completion_tokens = tokenizer(completion, truncation=True, max_length=max_length, add_special_tokens=False)

    input_ids = prompt_tokens["input_ids"] + completion_tokens["input_ids"]

    attention_mask = [1] * len(input_ids)

    labels = [-100] * len(prompt_tokens["input_ids"]) + completion_tokens["input_ids"]

    # No padding/truncation here; let DataCollator handle that
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

In [13]:
from torch.nn.utils.rnn import pad_sequence

class CustomDataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, features):
        input_ids = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
        attention_mask = [torch.tensor(f["attention_mask"], dtype=torch.long) for f in features]
        labels = [torch.tensor(f["labels"], dtype=torch.long) for f in features]

        # Pad sequences
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

In [None]:
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset, load_from_disk
from transformers import DataCollatorForLanguageModeling

# Tokenization function
def tokenize_function(example):
    return tokenizer(example["text"], padding="max_length", truncation=True, max_length=128)

batch_size = 2

local_dataset = load_from_disk("./c4_100k_subset")
# local_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

# prepared_dataset = local_dataset.map(format_instruction, remove_columns=["instruction", "context", "response", "category"])


# Tokenize dataset
tokenized_dataset = local_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
# tokenized_dataset = prepared_dataset.map(tokenize, remove_columns=["prompt", "completion"])
# tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

# data_collator = CustomDataCollator(
#     tokenizer=tokenizer
# )

dataloader = DataLoader(tokenized_dataset, batch_size, shuffle=True)
print(len(dataloader))

Map:   0%|          | 0/15011 [00:00<?, ? examples/s]

7506


In [15]:
from tqdm import tqdm
from torch.optim import AdamW

optimizer = AdamW(pruned_model.parameters(), lr=1e-4)

In [16]:

# Fine-tuning loop
epochs = 1  # 1 epoch is often enough for healing
n_interm = len(dataloader)//100

pruned_model.train()
for epoch in range(epochs):
    epoch_loss = 0
    idx = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        idx+=1
        batch = {k: v.to(pruned_model.device) for k, v in batch.items()}
        outputs = pruned_model(**batch)
        loss = outputs.loss
        
        if torch.isnan(loss):
            print("Loss is NaN at batch", idx, batch)
            continue
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_loss += loss.item()
     
        
        if (idx)%n_interm==0 :
            print(f"batch {idx}/{len(dataloader)} completed. Average Loss: {epoch_loss / idx:.4f}")
            
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")

# Save fine-tuned model adapters
pruned_model.save_pretrained("gemma_pruned_lora_2")

  input_ids = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
  attention_mask = [torch.tensor(f["attention_mask"], dtype=torch.long) for f in features]
  labels = [torch.tensor(f["labels"], dtype=torch.long) for f in features]
Epoch 1:   1%|          | 75/7506 [00:41<30:55,  4.01it/s]  

batch 75/7506 completed. Average Loss: 3.1048


Epoch 1:   2%|▏         | 150/7506 [01:14<39:12,  3.13it/s]  

batch 150/7506 completed. Average Loss: 2.8065


Epoch 1:   3%|▎         | 225/7506 [01:54<45:58,  2.64it/s]  

batch 225/7506 completed. Average Loss: 2.6745


Epoch 1:   4%|▍         | 300/7506 [02:33<48:47,  2.46it/s]  

batch 300/7506 completed. Average Loss: 2.5860


Epoch 1:   5%|▍         | 375/7506 [03:02<35:41,  3.33it/s]  

batch 375/7506 completed. Average Loss: 2.5461


Epoch 1:   6%|▌         | 450/7506 [03:38<59:20,  1.98it/s]  

batch 450/7506 completed. Average Loss: 2.5208


Epoch 1:   7%|▋         | 525/7506 [04:16<33:01,  3.52it/s]  

batch 525/7506 completed. Average Loss: 2.4913


Epoch 1:   8%|▊         | 600/7506 [05:03<40:04,  2.87it/s]  

batch 600/7506 completed. Average Loss: 2.4674


Epoch 1:   9%|▉         | 675/7506 [05:44<35:49,  3.18it/s]  

batch 675/7506 completed. Average Loss: 2.4541


Epoch 1:  10%|▉         | 750/7506 [06:17<38:50,  2.90it/s]  

batch 750/7506 completed. Average Loss: 2.4320


Epoch 1:  11%|█         | 825/7506 [06:57<1:09:11,  1.61it/s]

batch 825/7506 completed. Average Loss: 2.4039


Epoch 1:  12%|█▏        | 900/7506 [07:29<1:09:35,  1.58it/s]

batch 900/7506 completed. Average Loss: 2.3953


Epoch 1:  13%|█▎        | 975/7506 [08:04<56:32,  1.93it/s]  

batch 975/7506 completed. Average Loss: 2.3883


Epoch 1:  14%|█▍        | 1050/7506 [08:38<55:09,  1.95it/s]  

batch 1050/7506 completed. Average Loss: 2.3825


Epoch 1:  15%|█▌        | 1126/7506 [09:06<34:44,  3.06it/s]  

batch 1125/7506 completed. Average Loss: 2.3783


Epoch 1:  16%|█▌        | 1201/7506 [09:39<27:27,  3.83it/s]  

batch 1200/7506 completed. Average Loss: 2.3690


Epoch 1:  17%|█▋        | 1275/7506 [10:12<38:22,  2.71it/s]  

batch 1275/7506 completed. Average Loss: 2.3694


Epoch 1:  18%|█▊        | 1350/7506 [10:52<1:03:09,  1.62it/s]

batch 1350/7506 completed. Average Loss: 2.3646


Epoch 1:  19%|█▉        | 1425/7506 [11:38<29:05,  3.48it/s]  

batch 1425/7506 completed. Average Loss: 2.3615


Epoch 1:  20%|█▉        | 1500/7506 [12:22<46:12,  2.17it/s]  

batch 1500/7506 completed. Average Loss: 2.3564


Epoch 1:  21%|██        | 1560/7506 [13:07<50:00,  1.98it/s]  


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 8.00 GiB of which 0 bytes is free. Of the allocated memory 11.53 GiB is allocated by PyTorch, and 1.60 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
model.eval()
prompt = "Quesion: What is the capital of France? A:"

generate_response(pruned_model, tokenizer, prompt)

'Q: What is the capital of France? A: All of France\'s capital cities are named after France, which means "France".\nA: All of France\'s capital cities are named after France.\nA: All of France\'s capital cities are named after France, which means "'

In [None]:
prompt = "What is the capital of france?"
response = model.generate(
    **tokenizer(prompt, return_tensors="pt").to(model.device),
    max_new_tokens=30,
    do_sample=False,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.2,  # this helps reduce repetitions
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    use_cache = False,
)
print(tokenizer.decode(response[0], skip_special_tokens=True))



What is the capital of france?
France has a population 1.2 billion people, and it covers almost half-the continent: France represents about one third (about two percent
