# Lab 2.4.6: Mamba Fine-tuning - SOLUTIONS

Complete solutions for the Mamba fine-tuning exercises.

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

# PEFT is treated as optional: a failed import is recorded in HAS_PEFT and
# reported with an install hint instead of aborting.
try:
    from peft import LoraConfig, get_peft_model, TaskType
    HAS_PEFT = True
except ImportError:
    HAS_PEFT = False
    print('PEFT not installed. Run: pip install peft')

# Select the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Exercise Solution: Custom Domain Fine-tuning

In [None]:
# Example: Customer Service Domain Dataset

# Four hand-written instruction/input/output records, repeated 50 times to
# produce a 200-item list. The repetition reuses the same four dict objects;
# it does not copy them.
customer_service_examples = [
    {
        'instruction': 'Respond to a customer complaint about late delivery.',
        'input': 'My order was supposed to arrive 3 days ago and it still has not arrived!',
        'output': 'I sincerely apologize for the delay in your delivery. I understand how frustrating this must be. Let me check the status of your order right away and provide you with an updated delivery estimate. Would you like me to expedite the shipping at no extra cost?',
    },
    {
        'instruction': 'Help a customer with a refund request.',
        'input': 'I want to return this product. It does not work as advertised.',
        'output': 'I am sorry to hear the product did not meet your expectations. We would be happy to process a full refund for you. Could you please provide your order number? We can also arrange a prepaid return label to be sent to your email.',
    },
    {
        'instruction': 'Answer a question about store hours.',
        'input': 'What time do you close on Saturdays?',
        'output': 'Our store hours on Saturdays are 9 AM to 6 PM. Is there anything specific you are planning to visit us for? I would be happy to check if the item you need is in stock.',
    },
    {
        'instruction': 'Respond to a billing inquiry.',
        'input': 'I see a charge I do not recognize on my statement.',
        'output': 'I understand your concern about the unrecognized charge. For your security, let me help you review this. Could you please provide the last 4 digits of the card and the date of the charge? I will look into this immediately and ensure any unauthorized charges are reversed.',
    },
] * 50  # Repeat for training data

# Report the dataset size and preview the first record; the output text is
# truncated to its first 100 characters.
print(
    f'Customer Service Dataset: {len(customer_service_examples)} examples',
    '\nExample:',
    f'Instruction: {customer_service_examples[0]["instruction"]}',
    f'Input: {customer_service_examples[0]["input"]}',
    f'Output: {customer_service_examples[0]["output"][:100]}...',
    sep='\n',
)

In [None]:
# Format for training
def format_customer_service(example):
    """Render one example as a three-section training prompt.

    Args:
        example: Mapping with ``'input'``, ``'instruction'`` and ``'output'``
            keys.

    Returns:
        A string with the sections ``### Customer Query:``,
        ``### Instruction:`` and ``### Response:`` in that order. Each header
        is followed on the next line by its value, sections are separated by
        one blank line, and there is no trailing newline.
    """
    sections = (
        ('### Customer Query:', example['input']),
        ('### Instruction:', example['instruction']),
        ('### Response:', example['output']),
    )
    return '\n\n'.join(f'{header}\n{body}' for header, body in sections)

# Format every example into its prompt string, then show the first one.
formatted_data = list(map(format_customer_service, customer_service_examples))
print('Formatted example:', formatted_data[0], sep='\n')

In [None]:
# Auto-detect LoRA targets for Mamba
import torch.nn as nn

def find_lora_targets(model):
    """List the distinct leaf names of the model's ``nn.Linear`` submodules.

    Args:
        model: Object exposing ``named_modules()`` that yields
            ``(qualified_name, module)`` pairs.

    Returns:
        A list holding the last dot-separated component of each
        ``nn.Linear`` module's qualified name, de-duplicated and in
        first-seen order.
    """
    leaf_names = (
        qualified_name.rsplit('.', 1)[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, nn.Linear)
    )
    # dict.fromkeys drops repeats while preserving first-occurrence order.
    return list(dict.fromkeys(leaf_names))

# LoRA configuration for customer service fine-tuning
if HAS_PEFT:
    # Module names we would like LoRA to adapt when the model has them.
    preferred_targets = ['in_proj', 'out_proj', 'x_proj', 'dt_proj']

    # Informational only: nothing is configured here. The snippet is printed
    # as text and shows how the preferred names would be filtered against
    # find_lora_targets(model) before building a LoraConfig.
    print(
        'LoRA Configuration:',
        f'  Preferred targets: {preferred_targets}',
        '  Note: Actual targets will be auto-detected from model architecture',
        '  Example code:',
        '''
    available = find_lora_targets(model)
    valid_targets = [t for t in preferred_targets if t in available]
    
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=valid_targets,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    ''',
        sep='\n',
    )

In [None]:
# Training configuration (for reference)
# Reference-only snippet: the TrainingArguments call is stored as plain text
# and printed; it is never executed in this cell.
training_config = '''
TrainingArguments(
    output_dir="./mamba-customer-service",
    num_train_epochs=3,  # More epochs for domain adaptation
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    bf16=True,
    warmup_ratio=0.1,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="epoch",
    gradient_checkpointing=True,
)
'''
print('Recommended Training Configuration:', training_config, sep='\n')

In [None]:
# Evaluation prompts for customer service
# (instruction, customer query) pairs to try against the fine-tuned model.
eval_prompts = [
    ('Respond to angry customer', 'This is the worst service I have ever received!'),
    ('Handle refund request', 'I need my money back immediately.'),
    ('Answer shipping question', 'How long will it take to receive my order?'),
    ('Resolve technical issue', 'The website is not letting me complete my purchase.'),
]

# Banner, then one blank-line-separated entry per prompt pair.
print('Evaluation Prompts for Fine-tuned Model:', '=' * 60, sep='\n')
for instruction, query in eval_prompts:
    print(f'\n Instruction: {instruction}', f'   Query: {query}', sep='\n')

# Qualitative checklist for judging the generated responses.
print(
    '\n After fine-tuning, the model should:',
    '- Respond with professional, empathetic tone',
    '- Offer specific solutions',
    '- Follow customer service best practices',
    '- Use consistent formatting',
    sep='\n',
)

In [None]:
# Memory comparison for fine-tuning
# Rough memory table comparing full fine-tuning with LoRA.
print('\n Memory Requirements for Mamba Fine-tuning:', '=' * 60, sep='\n')

# (display name, params_b, lora_params_m) -- presumably total parameters in
# billions and LoRA parameters in millions, judging by the names and the
# /1000 conversion below; TODO confirm.
models = [
    ('Mamba-130M', 0.13, 0.5),
    ('Mamba-1.4B', 1.4, 5),
    ('Mamba-2.8B', 2.8, 10),
]

print(
    f'{"Model":<15} {"Full FT (GB)":<15} {"LoRA (GB)":<15} {"Savings":<10}',
    '-' * 55,
    sep='\n',
)

for name, params_b, lora_params_m in models:
    # Full fine-tuning scales the whole model by 2 * 6.
    full_ft = params_b * 2 * 6  # model + grad + optimizer
    # LoRA counts the base model once at a factor of 2 and adds the adapter
    # parameters at a factor of 6.
    # NOTE(review): the adapter term uses 6 where the full fine-tuning line
    # uses 2 * 6 -- confirm whether the different multiplier is intended.
    lora_ft = (params_b * 2) + (lora_params_m / 1000 * 6)
    savings = (1 - lora_ft / full_ft) * 100
    print(f'{name:<15} {full_ft:<15.1f} {lora_ft:<15.1f} {savings:.0f}%')