# Lab 3.2: Wanda Pruning - Environment Setup

**Goal:** Prepare the environment for Wanda pruning experiments.

**You will learn to:**
- Verify GPU and PyTorch sparse tensor support
- Install the dependencies for Wanda (the algorithm itself is implemented in the next notebook)
- Load a baseline dense model for pruning
- Prepare calibration data for activation statistics

---

## Why Environment Verification Matters

**Wanda pruning has specific requirements**:
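- **GPU Memory**: The FP16 weights of Llama-2-7B alone take ~13.5 GB. Activations and calibration buffers come on top of that.
- **PyTorch Sparse**: Sparse tensor support is useful for inspecting and storing pruned weights. Wanda itself only zeroes entries in dense tensors.
- **Calibration Data**: A small representative dataset is needed to compute activation statistics.
- **Hardware Acceleration**: Optional. 2:4 semi-structured sparsity can be accelerated by Sparse Tensor Cores on Ampere or newer NVIDIA GPUs (e.g., A100).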
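
Weight memory is simply parameter count × bytes per parameter. The sketch below makes two assumptions: it uses Llama-2-7B's published count of 6,738,415,616 parameters, and it counts weights only. Activations, the CUDA context, and calibration buffers add to this, so treat the result as a lower bound.

```python
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}

def weight_memory_gb(num_params: int, dtype: str = "fp16") -> float:
    """Weights-only memory footprint in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

LLAMA2_7B_PARAMS = 6_738_415_616  # published Llama-2-7B parameter count

for dtype in ("fp32", "fp16", "int8"):
    print(f"{dtype}: {weight_memory_gb(LLAMA2_7B_PARAMS, dtype):.2f} GB")
```

In FP16 this comes to roughly 13.5 GB. That is why a 16 GB GPU is workable but leaves little headroom.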

**Time investment**: 5-10 minutes (one-time setup)

---
## Step 1: Hardware Verification

First, let's verify GPU availability and specifications.

In [None]:
# Check NVIDIA GPU status
!nvidia-smi

In [None]:
import torch

print("=" * 60)
print("GPU Configuration Check")
print("=" * 60)

# PyTorch and CUDA versions
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

if torch.cuda.is_available():
    # GPU details
    gpu_id = 0
    gpu_props = torch.cuda.get_device_properties(gpu_id)
    
    print(f"\n✅ GPU Detected:")
    print(f"   Name: {torch.cuda.get_device_name(gpu_id)}")
    print(f"   Total Memory: {gpu_props.total_memory / 1e9:.2f} GB")
    print(f"   Compute Capability: SM {gpu_props.major}.{gpu_props.minor}")
    
    # Check 2:4 sparse support (Ampere+ = SM 8.0+)
    if gpu_props.major >= 8:
        print(f"   ✅ 2:4 Sparse Support: YES (SM {gpu_props.major}.{gpu_props.minor} >= 8.0)")
        print(f"      → Sparse Tensor Cores can accelerate 2:4 sparse matmuls (up to 2x peak math throughput)")
    else:
        print(f"   ⚠️  2:4 Sparse Support: NO (SM {gpu_props.major}.{gpu_props.minor} < 8.0)")
        print(f"      → Pruning still works, but 2:4 sparsity won't be hardware-accelerated")
    
    # Memory recommendation
    if gpu_props.total_memory / 1e9 >= 16:
        print(f"   ✅ Memory: Sufficient for pruning (>= 16GB)")
    else:
        print(f"   ⚠️  Memory: Limited (<16GB). May need smaller model.")
else:
    print("\n❌ No GPU detected!")
    print("   Wanda pruning requires GPU for efficient computation.")
    print("   Please run on a machine with NVIDIA GPU.")

print("=" * 60)
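
The 2:4 pattern that Sparse Tensor Cores accelerate works like this: in every group of 4 consecutive weights, at most 2 may be non-zero. The sketch below builds such a mask in pure Python. It assumes groups run along each row and that the 2 largest-magnitude entries are kept. Magnitude is used here only for illustration, since Wanda ranks weights by its own score. `mask_2_4` and `is_2_4` are hypothetical helper names.

```python
from typing import List

def mask_2_4(row: List[float]) -> List[float]:
    """Zero the 2 smallest-magnitude entries in each group of 4 (N:M = 2:4)."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out = list(row)
    for g in range(0, len(row), 4):
        # indices of the two smallest |w| within this group of 4
        drop = sorted(range(g, g + 4), key=lambda i: abs(row[i]))[:2]
        for i in drop:
            out[i] = 0.0
    return out

def is_2_4(row: List[float]) -> bool:
    """True if every group of 4 has at most 2 non-zeros."""
    return all(sum(1 for x in row[g:g + 4] if x != 0) <= 2
               for g in range(0, len(row), 4))

w = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
print(mask_2_4(w))  # [0.9, 0.0, 0.0, -0.7, 0.0, 0.3, -0.4, 0.0]
```

A 2:4 mask always yields exactly 50% sparsity, which is why this pattern is also called 50% semi-structured sparsity.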

---
## Step 2: Check PyTorch Sparse Tensor Support

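Wanda produces models in which many weights are exactly zero. Those zeros live in ordinary dense tensors, but PyTorch's sparse formats are handy for inspecting and storing them. Let's verify that sparse functionality works.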

In [None]:
print("=" * 60)
print("PyTorch Sparse Tensor Support Check")
print("=" * 60)

# Test sparse tensor creation
try:
    # Create a sparse tensor in COO format: 3 non-zeros in a 2x3 matrix
    indices = torch.tensor([[0, 1, 1], [2, 0, 2]], dtype=torch.long)
    values = torch.tensor([3.0, 4.0, 5.0])
    sparse = torch.sparse_coo_tensor(indices, values, (2, 3))
    
    print("✅ Sparse tensor creation: SUCCESS")
    print(f"   Sparse tensor shape: {sparse.shape}")
    print(f"   Sparse tensor format: {sparse.layout}")
    
    # Test sparse to dense conversion
    dense_from_sparse = sparse.to_dense()
    print("✅ Sparse to dense conversion: SUCCESS")
    
    # Test sparsity measurement: for a standard normal, |x| < 0.6745 holds
    # about 50% of the time (0.6745 is the median of |N(0, 1)|)
    test_tensor = torch.randn(100, 100)
    test_tensor[test_tensor.abs() < 0.6745] = 0  # ~50% sparsity
    sparsity = (test_tensor == 0).sum().item() / test_tensor.numel()
    print(f"✅ Sparsity measurement: {sparsity:.2%} (expected ~50%)")
    
    print("\n✅ All sparse tensor operations working!")
    
except Exception as e:
    print(f"❌ Sparse tensor test failed: {e}")
    print("   Please update PyTorch to version >= 2.0")

print("=" * 60)
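
`torch.sparse_coo_tensor` stores only the coordinates and values of the non-zero entries. The pure-Python sketch below shows that COO idea using the same 2×3 example as the cell above. `to_coo` and `from_coo` are hypothetical helpers for illustration, not PyTorch APIs.

```python
from typing import List, Tuple

Matrix = List[List[float]]

def to_coo(dense: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Return ([row_indices, col_indices], values) for the non-zero entries."""
    rows, cols, vals = [], [], []
    for r, row in enumerate(dense):
        for c, x in enumerate(row):
            if x != 0:
                rows.append(r)
                cols.append(c)
                vals.append(x)
    return [rows, cols], vals

def from_coo(indices: List[List[int]], values: List[float],
             shape: Tuple[int, int]) -> Matrix:
    """Rebuild the dense matrix (the analogue of .to_dense())."""
    dense = [[0.0] * shape[1] for _ in range(shape[0])]
    for r, c, v in zip(indices[0], indices[1], values):
        dense[r][c] = v
    return dense

dense = [[0.0, 0.0, 3.0], [4.0, 0.0, 5.0]]
indices, values = to_coo(dense)
print(indices, values)  # [[0, 1, 1], [2, 0, 2]] [3.0, 4.0, 5.0]
```

COO spends two integer indices on every stored value, so it only saves memory at high sparsity. At the 50% level Wanda targets, pruned weights are typically kept as dense tensors containing zeros.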

---
## Step 3: Install Wanda and Dependencies

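We'll install:
- **transformers**: Model loading and inference
- **datasets**: Calibration data loading
- **accelerate**: Required for `device_map="auto"` (automatic placement of weights across GPU/CPU)
- **packaging**: Version comparisons in Step 4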

**Note**: Wanda is typically implemented as a standalone script. We'll implement it ourselves in the next notebook.

In [None]:
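# Install core libraries (quote version specifiers: an unquoted ">" is a shell redirect)
!pip install -q "transformers>=4.35.0"  # Model support
!pip install -q datasets  # Calibration data
!pip install -q accelerate  # Needed for device_map="auto"
!pip install -q packaging  # Version comparisons in Step 4

# If an already-imported library was upgraded, restart the kernel before continuing
print("✅ Installation complete!")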

---
## Step 4: Verify Library Versions

In [None]:
import transformers
import accelerate
import datasets

print("=" * 60)
print("Library Version Check")
print("=" * 60)
print(f"PyTorch:      {torch.__version__}")
print(f"Transformers: {transformers.__version__}")
print(f"Accelerate:   {accelerate.__version__}")
print(f"Datasets:     {datasets.__version__}")
print("=" * 60)

# Version checks
def check_version(name, current, required):
    from packaging import version
    if version.parse(current) >= version.parse(required):
        print(f"✅ {name}: {current} >= {required}")
    else:
        print(f"⚠️  {name}: {current} < {required} (may cause issues)")

check_version("Transformers", transformers.__version__, "4.35.0")
check_version("PyTorch", torch.__version__.split("+")[0], "2.0.0")

print("\n✅ Version check complete (review any ⚠️ warnings above)")
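
Why use `packaging.version` instead of comparing strings? Plain string comparison works character by character, so it gets multi-digit components wrong. The simplified stdlib-only sketch below compares versions numerically instead. Unlike `packaging`, it deliberately ignores pre-release and dev tags such as `rc1` or `.dev0`, and `release_tuple` and `at_least` are hypothetical names.

```python
import re
from typing import Tuple

def release_tuple(v: str) -> Tuple[int, ...]:
    """'2.1.0+cu121' -> (2, 1, 0). Drops local '+...' suffixes and pre-release tags."""
    parts = []
    for piece in v.split("+")[0].split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break
        parts.append(int(m.group()))
        if m.group() != piece:  # e.g. '0rc1': keep the numeric prefix, then stop
            break
    return tuple(parts)

def at_least(current: str, required: str) -> bool:
    a, b = release_tuple(current), release_tuple(required)
    width = max(len(a), len(b))
    # pad with zeros so that (2, 0) compares equal to (2, 0, 0)
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))

print("4.9.0" >= "4.35.0")          # True  (wrong: character-by-character comparison)
print(at_least("4.9.0", "4.35.0"))  # False (correct: 9 < 35)
```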

---
## Step 5: Load Baseline Model (Dense)

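Let's load the baseline **Llama-2-7B** model in FP16 precision. The `meta-llama` checkpoints are gated on the Hugging Face Hub. Before running this cell, accept the license on the model page and authenticate, for example with `huggingface-cli login`.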

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model configuration
MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # Gated: accept the license and log in to the HF Hub first
# Smaller alternative if you run out of memory: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print("=" * 60)
print(f"Loading Baseline Model: {MODEL_NAME}")
print("=" * 60)
print("⏳ This may take 1-3 minutes...\n")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token  # Llama ships without a pad token; reuse EOS
print("✅ Tokenizer loaded")

# Load model in FP16 (Llama is built into transformers, so trust_remote_code is not needed)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,  # FP16 precision
    device_map="auto",          # Fill available GPU(s) first, offload the rest to CPU
)
model.eval()  # Inference mode (disables dropout)

print("✅ Model loaded in FP16")

# Memory usage
if torch.cuda.is_available():
    memory_allocated = torch.cuda.memory_allocated() / 1e9
    print(f"\n📊 GPU Memory Usage:")
    print(f"   Allocated: {memory_allocated:.2f} GB")
    print(f"   Reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Model info
num_params = sum(p.numel() for p in model.parameters())
print(f"\n📝 Model Info:")
print(f"   Parameters: {num_params / 1e9:.2f}B")
print(f"   Precision: FP16 (2 bytes/param)")
print(f"   Estimated size: {num_params * 2 / 1e9:.2f} GB")

print("\n" + "=" * 60)
print("✅ Baseline model ready for pruning!")
print("=" * 60)
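
You can check the parameter count printed above by hand from the architecture. The sketch below assumes Llama-2-7B's published config: hidden size 4096, MLP size 11008, 32 layers, and a 32,000-token vocabulary. It also assumes untied input/output embeddings, no biases, and full-size k/v projections (no grouped-query attention). The function separates the decoder `nn.Linear` weights, which are the matrices Wanda prunes, from everything else. `llama_param_counts` is a hypothetical helper.

```python
def llama_param_counts(hidden=4096, intermediate=11008, layers=32, vocab=32000):
    """Return (decoder linear weights, all other parameters) for a Llama-style model."""
    attn = 4 * hidden * hidden         # q_proj, k_proj, v_proj, o_proj
    mlp = 3 * hidden * intermediate    # gate_proj, up_proj, down_proj
    norms = 2 * hidden                 # input + post-attention RMSNorm per layer
    linear = layers * (attn + mlp)
    # per-layer norms + embed_tokens + lm_head + final norm
    other = layers * norms + 2 * vocab * hidden + hidden
    return linear, other

linear, other = llama_param_counts()
total = linear + other
print(f"Decoder linear weights: {linear:,}")
print(f"Other parameters:       {other:,}")
print(f"Total:                  {total:,} ({total * 2 / 1e9:.2f} GB in FP16)")
```

Models with grouped-query attention have smaller `k_proj`/`v_proj` matrices and need a different `attn` term.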

---
## Step 6: Prepare Calibration Data

Wanda requires calibration data to collect activation statistics. We'll use **WikiText-2**. The Wanda paper calibrates on C4; WikiText-2 is a smaller, easy-to-download stand-in.
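
Which statistics does Wanda actually need? For each linear layer, it needs the L2 norm of every input feature across all calibration tokens, `||X_j||_2`. Each weight's importance is then `|W_ij| * ||X_j||_2`, and the lowest-scoring weights are removed within each output row. The pure-Python sketch below applies this scoring and per-row pruning to toy numbers. The helper names are illustrative, not taken from the Wanda codebase.

```python
import math
from typing import List

Matrix = List[List[float]]

def feature_norms(X: Matrix) -> List[float]:
    """L2 norm of each input feature (column) over all calibration tokens (rows)."""
    return [math.sqrt(sum(row[j] ** 2 for row in X)) for j in range(len(X[0]))]

def wanda_prune(W: Matrix, X: Matrix, sparsity: float) -> Matrix:
    """Zero the lowest |W_ij| * ||X_j|| scores within each output row of W (out x in)."""
    norms = feature_norms(X)
    k = int(len(W[0]) * sparsity)  # number of weights to remove per row
    pruned = []
    for row in W:
        scores = [abs(w) * n for w, n in zip(row, norms)]
        drop = set(sorted(range(len(row)), key=lambda j: scores[j])[:k])
        pruned.append([0.0 if j in drop else w for j, w in enumerate(row)])
    return pruned

# 3 calibration tokens x 4 input features; feature 3 has large activations
X = [[1.0, 0.1, 1.0, 10.0],
     [1.0, 0.1, 1.0, 10.0],
     [1.0, 0.1, 1.0, 10.0]]
# 2 output neurons x 4 input features
W = [[0.5, 2.0, 0.4, 0.1],
     [0.3, 0.2, 0.6, 0.05]]
print(wanda_prune(W, X, sparsity=0.5))  # [[0.5, 0.0, 0.0, 0.1], [0.0, 0.0, 0.6, 0.05]]
```

In the first row, the small weight 0.1 survives because it multiplies a large activation; plain magnitude pruning would have dropped it and kept 2.0. This is why the calibration text should resemble the data the model will actually see.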

In [None]:
from datasets import load_dataset

print("=" * 60)
print("Loading Calibration Data")
print("=" * 60)

# Load WikiText-2 dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

print(f"✅ Dataset loaded")
print(f"   Total samples: {len(dataset)}")

# Filter out empty lines (many WikiText-2 rows, including the first one, are blank)
dataset = dataset.filter(lambda x: len(x['text'].strip()) > 0)
print(f"\n✅ Filtered dataset: {len(dataset)} non-empty samples")
print(f"   Sample text: {dataset[0]['text'][:200]}...")

print("=" * 60)

---
## Step 7: Prepare Calibration Samples

Create tokenized calibration samples for activation collection.

In [None]:
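import random
from tqdm import tqdm

# Configuration
NSAMPLES = 128  # Number of calibration sequences
SEQLEN = 2048   # Tokens per sequence
SEED = 0        # Fixed seed for a reproducible calibration set

print("=" * 60)
print("Preparing Calibration Samples")
print("=" * 60)
print(f"Samples: {NSAMPLES}")
print(f"Sequence length: {SEQLEN}\n")

# WikiText-2 rows are single lines or paragraphs, mostly far shorter than SEQLEN.
# Padding each row to SEQLEN would fill the batch with pad tokens, so instead we
# concatenate the corpus and cut random SEQLEN-token windows from it.
full_text = "\n\n".join(dataset["text"])
token_ids = tokenizer(full_text, return_tensors="pt").input_ids  # (1, total_tokens)
print(f"Corpus length: {token_ids.shape[1]:,} tokens")

random.seed(SEED)
calibration_samples = []
for _ in tqdm(range(NSAMPLES), desc="Sampling windows"):
    start = random.randint(0, token_ids.shape[1] - SEQLEN - 1)
    calibration_samples.append(token_ids[:, start:start + SEQLEN])

# Stack into a (NSAMPLES, SEQLEN) batch of token IDs
calibration_batch = torch.cat(calibration_samples, dim=0)

print(f"\n✅ Calibration data prepared")
print(f"   Shape: {calibration_batch.shape}")
size_mb = calibration_batch.numel() * calibration_batch.element_size() / 1e6
print(f"   Size: {size_mb:.2f} MB (int64 token IDs)")
print("=" * 60)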
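
How large does this calibration set get once it flows through the model? The back-of-the-envelope sketch below assumes Llama-2-7B's hidden size of 4096 and FP16 activations. Holding one layer's input hidden states for all samples at once costs NSAMPLES × SEQLEN × hidden × 2 bytes. That cost is why pruning code typically processes the model one decoder layer at a time rather than running the whole batch end to end. `calibration_budget` is a hypothetical helper.

```python
def calibration_budget(nsamples=128, seqlen=2048, hidden=4096, bytes_per_value=2):
    """Return (total tokens, GB to hold one layer's hidden states for all samples)."""
    tokens = nsamples * seqlen
    hidden_state_gb = tokens * hidden * bytes_per_value / 1e9
    return tokens, hidden_state_gb

tokens, gb = calibration_budget()
print(f"Calibration tokens: {tokens:,}")
print(f"One layer's hidden states (FP16): {gb:.2f} GB")
```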

---
## Step 8: Test Baseline Inference

Verify the model works correctly before pruning.

In [None]:
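import time

# Test prompt
prompt = "The future of artificial intelligence is"

print("=" * 60)
print("Baseline Inference Test")
print("=" * 60)
print(f"Prompt: {prompt}\n")

# Tokenize input and move it to the device holding the model's first layers
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

torch.manual_seed(0)  # Make the sampled output reproducible
if torch.cuda.is_available():
    torch.cuda.synchronize()
start_time = time.perf_counter()

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,  # Explicit pad token for generate()
    )

if torch.cuda.is_available():
    torch.cuda.synchronize()
latency = time.perf_counter() - start_time

# Decode output
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# generate() returns prompt + new tokens; count only the new ones
new_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]

print(f"Output: {generated_text}\n")
print(f"⏱️  Latency: {latency:.2f} seconds")
print(f"📊 Tokens/sec: {new_tokens / latency:.2f} ({new_tokens} new tokens)")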
print("=" * 60)
print("✅ Inference test passed!")
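
`model.generate` returns the prompt followed by the newly generated tokens, so throughput should count only the generated part. The tiny sketch below shows the arithmetic with made-up numbers; `decode_tokens_per_sec` is an illustrative name.

```python
def decode_tokens_per_sec(prompt_len: int, output_len: int, seconds: float) -> float:
    """Generated tokens per second; output_len includes the prompt tokens."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return (output_len - prompt_len) / seconds

# e.g. an 8-token prompt, 58 tokens returned, 2.0 s wall time
print(decode_tokens_per_sec(8, 58, 2.0))  # 25.0 (dividing all 58 tokens would give 29.0)
```

The measured time also includes processing the prompt (prefill), so this is an end-to-end figure rather than pure decode speed. Generation can also stop early at EOS, producing fewer than `max_new_tokens` tokens.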

---
## Step 9: Calculate Current Sparsity (Should be 0%)

Let's measure the baseline model's sparsity (should be near 0% for dense model).

In [None]:
import torch.nn as nn

def calculate_sparsity(model, threshold=1e-6):
    """
    Fraction of near-zero weights (|w| < threshold) in the model's nn.Linear
    layers, excluding lm_head: these are the matrices Wanda prunes.
    Embeddings and norm weights are skipped so later results are comparable.
    """
    total_params = 0
    zero_params = 0
    
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and "lm_head" not in name:
            weight = module.weight
            total_params += weight.numel()
            zero_params += (weight.abs() < threshold).sum().item()
    
    sparsity = zero_params / total_params if total_params > 0 else 0.0
    return sparsity, total_params, zero_params

print("=" * 60)
print("Baseline Model Sparsity")
print("=" * 60)

sparsity, total, zeros = calculate_sparsity(model)

print(f"Linear-layer weights: {total / 1e9:.2f}B")
print(f"Near-zero weights:    {zeros / 1e6:.2f}M")
print(f"Sparsity: {sparsity:.4%}")
print(f"\n✅ Baseline sparsity is near 0% (expected for a dense model)")
print("=" * 60)

---
## ✅ Setup Complete!

**Summary**:
- ✅ GPU verified (CUDA available)
- ✅ PyTorch sparse tensor support confirmed
- ✅ Libraries installed (transformers, datasets, accelerate)
- ✅ Baseline FP16 model loaded
- ✅ Calibration data prepared (128 samples)
- ✅ Inference test passed
- ✅ Baseline sparsity measured (~0%)

**Next Steps**:
1. Proceed to **02-Prune.ipynb** to apply Wanda pruning
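2. Target: 50% sparsity in the decoder's linear layers (~3.5B non-zero parameters remain)
3. Expected quality impact: a small perplexity increase (target: roughly +0.44, under ~8% relative); the actual change is measured after pruning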
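
Where does the "3.5B" figure come from? Wanda prunes only the decoder `nn.Linear` weights, not the embeddings, `lm_head`, or norms. The sketch below works through the arithmetic. It assumes Llama-2-7B's published config (hidden size 4096, MLP size 11008, 32 layers, 6,738,415,616 total parameters) and assumes exactly half of each linear layer's weights are removed.

```python
HIDDEN, INTERMEDIATE, LAYERS = 4096, 11008, 32
TOTAL_PARAMS = 6_738_415_616  # published Llama-2-7B parameter count

# q/k/v/o projections + gate/up/down projections in every decoder layer
prunable = LAYERS * (4 * HIDDEN * HIDDEN + 3 * HIDDEN * INTERMEDIATE)
removed = prunable // 2  # 50% sparsity in each linear layer
remaining = TOTAL_PARAMS - removed
model_wide_sparsity = removed / TOTAL_PARAMS

print(f"Prunable linear weights: {prunable / 1e9:.2f}B")
print(f"Non-zero after pruning:  {remaining / 1e9:.2f}B")
print(f"Model-wide sparsity:     {model_wide_sparsity:.2%}")
```

A model reported as "50% sparse" is therefore only about 48% sparse once embeddings and norms are included in the denominator. When reporting results, state which convention you use.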

**Key Variables Available**:
- `model`: Baseline dense model (Llama-2-7B unless `MODEL_NAME` was changed)
- `tokenizer`: Tokenizer for text processing
- `calibration_batch`: Tokenized calibration data (128 samples)
- `calculate_sparsity()`: Function to measure sparsity

---

**⏭️ Continue to**: [02-Prune.ipynb](./02-Prune.ipynb)