In [None]:
from datasets import load_dataset, Dataset

# Stream the dataset
streamed_dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)

# Take the first 10,000 samples and convert to a list (or list of dicts)
sampled_data = list(streamed_dataset.take(10000))

# Convert to a regular Hugging Face Dataset
dataset_to_save = Dataset.from_list(sampled_data)

# Save locally (Arrow format)
dataset_to_save.save_to_disk("./c4_10k_subset")

from datasets import load_from_disk

local_dataset = load_from_disk("./c4_10k_subset")

KeyboardInterrupt: 

In [None]:
import torch.nn as nn

def prune_layers(model, start_layer, end_layer):
    
    # Keep only layers outside the pruning range
    pruned_layers = nn.ModuleList(
        [layer for idx, layer in enumerate(model.model.layers) 
         if idx < start_layer or idx >= end_layer]
    )
    print(len(pruned_layers))
    # Assign back to the model
    model.model.layers = pruned_layers
    model.config.num_hidden_layers = len(pruned_layers)

    return model

In [2]:
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model_path = "google/gemma-3-1b-it"

# Load the base model
base_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quantization_config,
    device_map="auto"
)

model = prune_layers(base_model, 9, 14)
# Load the fine-tuned LoRA adapters
model = PeftModel.from_pretrained(base_model, "gemma_pruned_lora").to(device)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Generation
model.eval()



21


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma3ForCausalLM(
      (model): Gemma3TextModel(
        (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
        (layers): ModuleList(
          (0-20): 21 x Gemma3DecoderLayer(
            (self_attn): Gemma3Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=1152, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1152, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
   

In [3]:
print(f"Num layers after pruning: {model.config.num_hidden_layers}")


Num layers after pruning: 21


In [4]:
def generate_response(model, tokenizer, prompt, max_new_tokens=50, temperature=0.7):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    inputs = tokenizer(prompt, return_tensors='pt').to(device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False
    )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return response

In [5]:
prompt = "What is the capital of France? "

generate_response(model, tokenizer, prompt)

You have set `use_cache` to `False`, but cache_implementation is set to hybrid. cache_implementation will have no effect.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [4]:
prompt = "What is the capital of France? "
inputs = tokenizer(prompt, return_tensors='pt')
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs, use_cache=False)

# Logits tensor
logits = outputs.logits

print("Logits shape:", logits.shape)
print("Logits sample:", logits[0, -1, :5])  # Print the logits of the last token, first 5 logits

# Check for NaNs or Infs
if torch.isnan(logits).any():
    print("NaNs detected in logits!")
if torch.isinf(logits).any():
    print("Infs detected in logits!")

Logits shape: torch.Size([1, 9, 262144])
Logits sample: tensor([nan, nan, nan, nan, nan], device='cuda:0', dtype=torch.float16)
NaNs detected in logits!
