### Step 1: Install necessary packages


In [None]:
#!pip install matplotlib
#!pip install torch numpy transformers datasets tiktoken wandb tqdm

### Step 2: Package imports and configuration


In [None]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt

# Configuration
beta = 0.5          # divisor of the pos/neg log-prob margin in the Step 7 DPO loss
device = 'cuda' if torch.cuda.is_available() else 'cpu'
base_lr = 1e-4
epochs = 5
batch_size = 64     # training pairs per batch (get_batches drops a trailing partial batch)
max_length = 64     # every training sequence is padded/truncated to this many tokens
num_samples = 1     # NOTE(review): the generation cell reassigns this to its target count (up to 100k)
max_new_tokens = 200  # sampling settings used by the Step 8 generation test
temperature = 0.8
top_k = 200
seed = 1337         # seeds Python's `random` (batch shuffling) and torch (sampling, dropout)

# Seed early so shuffling and sampling are reproducible under Restart & Run All.
random.seed(seed)
torch.manual_seed(seed)

print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Load tokenizer with error handling.
# meta.pkl provides `stoi` (character -> id, used by encode) and `itos`
# (id -> character, used by decode); both stay module-level globals.
# NOTE(review): pickle.load can execute arbitrary code — only load meta.pkl
# from a trusted source.
try:
    with open("../sft/meta.pkl", "rb") as f:
        meta = pickle.load(f)
    stoi, itos = meta["stoi"], meta["itos"]
    print(f"✅ Loaded tokenizer with {len(itos)} tokens")
except Exception as e:
    # Report, then re-raise: nothing downstream works without a tokenizer.
    print(f"❌ Error loading tokenizer: {e}")
    raise

def encode(s):
    """Map each character of ``s`` to its token id through the global ``stoi`` table.

    Characters absent from the vocabulary become id 0, the same id used as
    padding elsewhere in this notebook. If encoding raises (e.g. ``s`` is not
    iterable), the error is printed and ``[0]`` is returned instead.
    """
    try:
        token_ids = list(map(lambda ch: stoi.get(ch, 0), s))
    except Exception as err:
        print(f"Encoding error for '{s}': {err}")
        token_ids = [0]
    return token_ids

def decode(l, vocab=None):
    """Decode a sequence of token ids back into a string.

    Args:
        l: iterable of token ids. Python ints, NumPy integer scalars and
            single-element integer tensors are accepted; an object with
            ``.tolist()`` (tensor / ndarray) is converted first.
        vocab: optional id -> character table (list or dict keyed 0..n-1);
            defaults to the global ``itos`` loaded with the tokenizer.

    Returns:
        The decoded string. Ids that are not integers or that fall outside
        ``[0, len(vocab))`` are skipped. If decoding fails (e.g. ``l`` is not
        iterable), the error is printed and ``""`` is returned.
    """
    import operator

    try:
        table = itos if vocab is None else vocab
        # Tensors/ndarrays iterate as 0-d objects that are not `int`; the old
        # isinstance(i, int) filter silently dropped all of them.
        ids = l.tolist() if hasattr(l, "tolist") else l
        chars = []
        for tok in ids:
            try:
                idx = operator.index(tok)  # accepts any integer-like type
            except TypeError:
                continue  # floats, strings, None, ... are skipped
            if 0 <= idx < len(table):
                chars.append(table[idx])
        return ''.join(chars)
    except Exception as e:
        print(f"Decoding error for {l}: {e}")
        return ""

print("✅ Configuration and tokenizer loaded successfully!")

### Step 3: Define helper functions


In [None]:
def compute_logprob(input_ids, model=None):
    """Mean per-token log-probability of each sequence under a causal LM.

    Args:
        input_ids: LongTensor indexed as ``[:, :-1]`` / ``[:, 1:]``, so it must
            be 2-D (batch, seq). Token id 0 is treated as padding and ignored.
        model: callable invoked as ``model(inputs, full_seq=True)`` that returns
            ``(logits, _)`` with logits of shape (B, T, V). Defaults to the
            global ``gpt``; pass another network (e.g. a frozen reference
            model) to score the same batch under it.

    Returns:
        Tensor of shape (B,): the mean of log p(token | prefix) over the
        non-padding target positions of each row. A row whose targets are all
        padding yields 0 instead of NaN.
    """
    net = gpt if model is None else model
    inputs = input_ids[:, :-1]
    targets = input_ids[:, 1:]
    logits, _ = net(inputs, full_seq=True)
    B, T, V = logits.size()
    token_nll = F.cross_entropy(
        logits.reshape(-1, V), targets.reshape(-1), ignore_index=0, reduction='none'
    ).reshape(B, T)
    mask = (targets != 0).float()
    # clamp(min=1): an all-padding row would otherwise compute 0/0 = NaN, and the
    # NaN would spread through .mean() into every gradient of the batch.
    mean_nll = (token_nll * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    return -mean_nll

def pad_or_truncate(seq, max_length, pad_id=0):
    """Return a new list of exactly ``max_length`` token ids.

    Longer sequences keep their LAST ``max_length`` ids (the start is dropped);
    shorter ones are right-padded with ``pad_id``.
    NOTE(review): for long prompt+answer texts this cuts the beginning of the
    prompt — confirm that left-truncation is intended.

    Raises:
        ValueError: if ``max_length`` is negative.
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")
    if len(seq) > max_length:
        # Not seq[-max_length:]: with max_length == 0 that returns the whole list.
        return seq[len(seq) - max_length:]
    return seq + [pad_id] * (max_length - len(seq))

def get_batches(lines, batch_size):
    """Yield shuffled (negative, positive) token batches for preference training.

    Args:
        lines: list of dicts with ``'negative'`` and ``'positive'`` text fields.
            The list itself is left untouched; a shuffled copy is iterated.
        batch_size: pairs per batch. A trailing partial batch is dropped.

    Yields:
        ``(neg_tensor, pos_tensor)``: LongTensors of shape (batch_size, max_length)
        on ``device``. Each text gets four newline characters appended before it
        is encoded and then padded/truncated to ``max_length``.
    """
    # Shuffle a copy: shuffling `lines` in place silently reordered the caller's
    # dataset (hidden state across notebook cells). RNG consumption is identical.
    shuffled = list(lines)
    random.shuffle(shuffled)

    def _to_tensor(batch, key):
        # Encode one side (`key`) of every pair into a (batch, max_length) LongTensor.
        rows = [pad_or_truncate(encode(pair[key] + '\n\n\n\n'), max_length) for pair in batch]
        return torch.tensor(rows, dtype=torch.long, device=device)

    # Stop before a final batch that would hold fewer than batch_size pairs.
    for start in range(0, len(shuffled) - batch_size + 1, batch_size):
        batch = shuffled[start:start + batch_size]
        yield _to_tensor(batch, 'negative'), _to_tensor(batch, 'positive')

### Step 4: Load the pretrained NanoGPT model


In [None]:
# Load the SFT checkpoint and rebuild the GPT it describes. `ckpt` stays a
# global because the training cell reuses ckpt['model_args'] when saving.
try:
    print("Loading pretrained NanoGPT model...")
    # NOTE(review): weights_only=False un-pickles arbitrary Python objects, which
    # can execute code — acceptable only because this is a local, trusted file.
    ckpt = torch.load("../sft/gpt.pt", map_location=device, weights_only=False)
    
    gptconf = GPTConfig(**ckpt['model_args'])
    gpt = GPT(gptconf)
    
    # Clean state dict: strip the '_orig_mod.' key prefix (presumably added when
    # the SFT model was wrapped by torch.compile) so keys match the plain GPT.
    state_dict = ckpt['model']
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    
    gpt.load_state_dict(state_dict)
    # Move to the target device and switch to training mode (affects layers such
    # as dropout, if the model has any).
    gpt.to(device).train()
    
    print("✅ Model loaded successfully!")
    print(f"Model parameters: {sum(p.numel() for p in gpt.parameters()):,}")
    print(f"Model config: {gptconf}")
    
    # Test model with a simple forward pass on random ids drawn from the
    # tokenizer's vocabulary; a model/tokenizer size mismatch would likely fail here.
    test_input = torch.randint(0, len(itos), (1, 10), device=device)
    with torch.no_grad():
        logits, _ = gpt(test_input, full_seq=True)
        print(f"✅ Model forward pass test successful! Output shape: {logits.shape}")
    
except Exception as e:
    # Print a hint, then re-raise: every later cell needs the model.
    print(f"❌ Error loading model: {e}")
    print("Please check if ../sft/gpt.pt exists and is accessible")
    raise

### Generate 100K Negative Responses with GPT Model


In [None]:
import time
from tqdm import tqdm
import gc
import psutil
import os

print("=== ROBUST 100K GENERATION WITH AUTO-RESTART CAPABILITY ===")

# Load all questions. The generation loop below reads each item as a dict with a
# "negative" field (question text that receives a GPT continuation) and a
# "positive" field (kept as the positive sample).
try:
    with open("questions_empty.json", "r") as f:
        questions = json.load(f)
    print(f"✅ Loaded {len(questions)} questions from questions_empty.json")
except Exception as e:
    # Report and re-raise: nothing can be generated without the questions.
    print(f"❌ Error loading questions: {e}")
    raise

def find_latest_checkpoint():
    """Locate the newest generation checkpoint in the current directory.

    Checkpoints are files named exactly ``temp_responses_<count>.json``. Other
    files that merely start with ``temp_responses_`` (e.g.
    ``temp_responses_final.json``) are ignored; previously their suffix was fed
    to ``int()`` and crashed the lookup.

    Returns:
        ``(filename, count)`` for the checkpoint with the highest count, using the
        file's real name, or ``(None, 0)`` when there is none.
    """
    import re

    pattern = re.compile(r"temp_responses_(\d+)\.json")
    best_name, best_count = None, 0
    for name in os.listdir('.'):
        match = pattern.fullmatch(name)
        if match is None:
            continue
        count = int(match.group(1))
        if best_name is None or count > best_count:
            best_name, best_count = name, count
    return best_name, best_count

def load_existing_progress():
    """Load the responses stored in the most recent readable checkpoint.

    Tries the file reported by ``find_latest_checkpoint()`` first. If it cannot
    be read or parsed, older ``temp_responses_<count>.json`` files are tried,
    newest first.

    Returns:
        The list loaded from the first readable checkpoint, or ``[]`` when none
        exists or none can be read (generation then starts from scratch).
    """
    import re

    latest_checkpoint, latest_count = find_latest_checkpoint()
    existing_responses = []

    if not latest_checkpoint:
        return existing_responses

    try:
        with open(latest_checkpoint, "r") as f:
            existing_responses = json.load(f)
        print(f"🔄 Resuming from: {latest_checkpoint} ({len(existing_responses)} responses)")

        # Verify data integrity; on mismatch the actual length wins, since the
        # caller resumes from len(existing_responses).
        if len(existing_responses) != latest_count:
            print(f"⚠️ Data mismatch: file has {len(existing_responses)}, expected {latest_count}")
        return existing_responses
    except Exception as e:
        print(f"⚠️ Could not load {latest_checkpoint}: {e}")

    # Fall back to older checkpoints, newest first. Only exact
    # temp_responses_<count>.json names are considered, so stray files cannot
    # crash the sort key. The failed file is skipped by name, not by position.
    pattern = re.compile(r"temp_responses_(\d+)\.json")
    backups = sorted(
        (name for name in os.listdir('.')
         if pattern.fullmatch(name) and name != latest_checkpoint),
        key=lambda name: int(pattern.fullmatch(name).group(1)),
        reverse=True,
    )
    for backup_file in backups:
        try:
            with open(backup_file, "r") as f:
                existing_responses = json.load(f)
        except (OSError, ValueError):
            # Unreadable file or corrupt JSON (JSONDecodeError is a ValueError):
            # try the next older one. Unlike a bare `except:`, Ctrl-C still works.
            continue
        print(f"✅ Loaded backup: {backup_file} ({len(existing_responses)} responses)")
        return existing_responses

    print("⚠️ No readable checkpoint found; starting from scratch")
    return []

# Configuration for stability
# NOTE(review): this overwrites the `num_samples = 1` set in the configuration
# cell; every later cell that reads num_samples sees this generation target.
num_samples = min(100000, len(questions))        # generate at most 100k samples
existing_responses = load_existing_progress()    # [] when no checkpoint is usable
start_idx = len(existing_responses)              # resume index into `questions`
remaining_samples = num_samples - start_idx      # <= 0 means nothing is left to do

print(f"🎯 Target: {num_samples:,} samples")
print(f"📊 Progress: {len(existing_responses):,} completed")
print(f"🚀 Remaining: {remaining_samples:,} samples")

def ultra_lightweight_generation(prompt, max_new_tokens=20, temperature=0.7, top_k=30,
                                 max_prompt_chars=30, max_prompt_tokens=20):
    """Generate a short continuation of ``prompt`` with the global ``gpt`` model.

    Never raises for ordinary errors: any failure (empty or over-long prompt,
    generation error, ...) yields the fallback string ``"I don't know"``.

    Args:
        prompt: question text. It is stripped and cut to its first
            ``max_prompt_chars`` characters before encoding.
        max_new_tokens, temperature, top_k: sampling settings forwarded to
            ``gpt.generate``. The defaults are the conservative values used for
            negative-response generation. They intentionally do NOT follow the
            notebook-level config variables of the same names.
        max_prompt_chars: prompt truncation length in characters.
        max_prompt_tokens: prompts that encode to more tokens than this get the
            fallback string instead of a generation.
            NOTE(review): ``encode`` yields one id per character, so with the
            defaults (30 chars / 20 tokens) every prompt of 21-30 characters is
            rejected. Confirm that this is intended.

    Returns:
        The continuation with newlines replaced by spaces, cut to about 30
        characters on a word boundary, or ``"I don't know"``.

    The model's train/eval mode is restored before returning. The old version
    left ``gpt`` in eval mode, which leaked into the training cell.
    """
    was_training = False
    try:
        was_training = gpt.training
        gpt.eval()

        # Very short prompt processing
        clean_prompt = prompt.strip()[:max_prompt_chars]
        prompt_tokens = encode(clean_prompt)
        valid_tokens = [t for t in prompt_tokens if isinstance(t, int) and 0 <= t < len(itos)]

        if len(valid_tokens) == 0 or len(valid_tokens) > max_prompt_tokens:
            return "I don't know"

        with torch.no_grad():
            input_ids = torch.tensor([valid_tokens], dtype=torch.long, device=device)

            try:
                generated_ids, _ = gpt.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k
                )

                # Keep only the newly generated ids (drop the echoed prompt).
                new_tokens = generated_ids[0].tolist()[len(valid_tokens):]
                valid_new_tokens = [t for t in new_tokens if isinstance(t, int) and 0 <= t < len(itos)]

                if len(valid_new_tokens) == 0:
                    return "I don't know"

                new_text = decode(valid_new_tokens).strip()

                # Simple cleanup
                new_text = new_text.replace('\n', ' ').strip()

                # Short length limit: cut to 30 chars and drop the (possibly
                # truncated) last word when there is more than one word.
                if len(new_text) > 30:
                    new_text = new_text[:30].split()
                    new_text = ' '.join(new_text[:-1]) if len(new_text) > 1 else new_text[0]

                return new_text if new_text else "I don't know"

            except Exception:
                # Free cached GPU memory after a failed generation.
                torch.cuda.empty_cache()
                return "I don't know"

    except Exception:
        return "I don't know"
    finally:
        # Undo gpt.eval() so a model that was training stays in training mode.
        if was_training:
            gpt.train()

def emergency_cleanup():
    """Free as much Python and CUDA memory as possible.

    Runs the garbage collector several times and, when CUDA is available,
    empties the CUDA cache and synchronizes. Exceptions are printed as a
    warning instead of being raised.
    """
    try:
        gc.collect()
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        # One more pass after releasing the CUDA cache.
        gc.collect()
    except Exception as err:
        print(f"⚠️ Cleanup warning: {err}")

def safe_save(data, filename):
    """Write ``data`` as JSON to ``filename`` without leaving a half-written file.

    The JSON is written to ``temp_<basename>`` in the same directory and then
    moved over ``filename`` with ``os.replace``, which is atomic within one
    filesystem. An existing ``filename`` therefore keeps its old content if the
    save fails.

    Returns:
        True on success. On failure, the error is printed, any partial temp file
        is removed, and False is returned.
    """
    # Build the temp path from the directory + basename. The old f"temp_{filename}"
    # produced e.g. "temp_out/x.json", which points into a different directory.
    directory, base = os.path.split(filename)
    temp_file = os.path.join(directory, f"temp_{base}")
    try:
        with open(temp_file, "w") as f:
            json.dump(data, f, indent=1)
        os.replace(temp_file, filename)
        return True
    except Exception as e:
        print(f"⚠️ Save error: {e}")
        try:
            os.remove(temp_file)
        except OSError:
            pass  # temp file was never created or is already gone
        return False

# Main generation loop with crash recovery.
# For every question, the negative sample is the question followed by a short GPT
# continuation, and the positive sample is the dataset answer. Progress is
# checkpointed to temp_responses_<count>.json so load_existing_progress() can
# resume after a restart.
# NOTE(review): interrupt/error saves below are not picked up by the resume logic.
if remaining_samples <= 0:
    print(f"✅ All {num_samples:,} samples already completed!")
else:
    print(f"\n🚀 Starting generation from sample {start_idx:,}")

    # Ultra conservative settings for stability.
    # The generation batch size uses its own name: assigning to the global
    # `batch_size` here would silently replace the training batch size from the
    # configuration cell, which the training loop reads later.
    gen_batch_size = 25      # Small batches
    save_frequency = 100     # checkpoint after at least this many new samples
    cleanup_frequency = 20   # emergency_cleanup() after at least this many new samples

    print(f"🔧 Settings: batch={gen_batch_size}, save_freq={save_frequency}, cleanup_freq={cleanup_frequency}")

    negative_responses = existing_responses.copy()
    start_time = time.time()
    # Track "count when X last happened" instead of testing `count % N == 0`.
    # With 25-item batches, a modulo test on 20 only fired every 100 samples, and
    # an unaligned resume point would never hit a checkpoint.
    last_reported = len(negative_responses)    # count already shown in the progress bar
    last_checkpoint = len(negative_responses)  # count at the last successful checkpoint
    last_cleanup = len(negative_responses)     # count at the last periodic cleanup

    try:
        # Create progress bar
        with tqdm(total=remaining_samples, desc="Generating", unit="samples",
                  initial=0, position=0, leave=True) as pbar:

            for i in range(start_idx, num_samples, gen_batch_size):
                batch_end = min(i + gen_batch_size, num_samples)
                batch_questions = questions[i:batch_end]

                batch_results = []
                for j, q in enumerate(batch_questions):
                    # Reset per item. Otherwise, when q["negative"] fails, the
                    # fallback would reuse the PREVIOUS item's question text.
                    question = None
                    try:
                        question = q["negative"].strip()

                        # Generate with minimal resources (never raises)
                        gpt_response = ultra_lightweight_generation(question)

                        batch_results.append({
                            "negative": f"{question} {gpt_response}.",
                            "positive": q["positive"].strip()
                        })
                    except Exception:
                        # Malformed item: fall back to a canned negative.
                        safe_q = question if question is not None else f"Q{i+j}"
                        batch_results.append({
                            "negative": f"{safe_q} I don't know.",
                            "positive": q.get("positive", "Answer").strip()
                        })

                    # Micro-cleanup every few generations
                    if (j + 1) % 10 == 0:
                        emergency_cleanup()

                # Add batch results
                negative_responses.extend(batch_results)
                current_count = len(negative_responses)

                # Update progress
                pbar.update(current_count - last_reported)
                last_reported = current_count

                # Show current stats
                elapsed = time.time() - start_time
                if elapsed > 0:
                    rate = (current_count - start_idx) / elapsed
                    pbar.set_postfix({
                        'count': current_count,
                        'rate': f'{rate:.1f}/s'
                    })

                # Periodic cleanup
                if current_count - last_cleanup >= cleanup_frequency:
                    emergency_cleanup()
                    last_cleanup = current_count

                # Periodic checkpointing; a failed save is retried after the next batch.
                if current_count - last_checkpoint >= save_frequency:
                    checkpoint_file = f"temp_responses_{current_count}.json"
                    if safe_save(negative_responses, checkpoint_file):
                        print(f"\n💾 Checkpoint: {current_count} samples saved")
                        last_checkpoint = current_count
                    else:
                        print(f"\n⚠️ Checkpoint save failed at {current_count}")

                # Brief pause to prevent overheating
                if (current_count - start_idx) % 100 == 0:
                    time.sleep(0.05)

                # Progress check
                if current_count >= num_samples:
                    break

        # Final processing
        final_count = len(negative_responses)
        elapsed_total = time.time() - start_time

        print(f"\n✅ Generation completed!")
        print(f"📊 Final count: {final_count:,} samples")
        print(f"⏱️ Total time: {elapsed_total/60:.1f} minutes")

        # Save final results
        final_output = "negative_responses_100k_gpt.json"
        print(f"\n💾 Saving final results...")

        if safe_save(negative_responses, final_output):
            print(f"✅ Final save successful: {final_output}")
        else:
            # Emergency save with timestamp
            emergency_file = f"emergency_final_{int(time.time())}.json"
            if safe_save(negative_responses, emergency_file):
                print(f"💾 Emergency save: {emergency_file}")
            else:
                print(f"❌ All saves failed! Data in memory only.")

    except KeyboardInterrupt:
        current_count = len(negative_responses)
        print(f"\n⚠️ Interrupted at {current_count:,} samples")

        # Save progress
        interrupt_file = f"interrupt_save_{current_count}.json"
        if safe_save(negative_responses, interrupt_file):
            print(f"💾 Progress saved: {interrupt_file}")

    except Exception as e:
        current_count = len(negative_responses)
        print(f"\n❌ Error at {current_count:,} samples: {str(e)[:100]}")

        # Persist what was generated so far; otherwise everything produced since
        # the last checkpoint would be lost.
        error_file = f"error_save_{current_count}.json"
        if safe_save(negative_responses, error_file):
            print(f"💾 Progress saved: {error_file}")

    finally:
        # Final cleanup
        emergency_cleanup()
        print("🧹 Cleanup completed")

### Step 5: Load Data (**students are required to complete this part!**)


In [None]:
# Load data from ./data/pos_neg_pairs.json

### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)


In [None]:
# recommend to use the AdamW optimizer 

### Step 7: Begin training (**students are required to complete this part!**)


In [None]:
total_steps = len(lines) // batch_size
for epoch in range(epochs):
    pbar = tqdm(get_batches(lines, batch_size))
    for step, (neg_tensor,pos_tensor) in enumerate(pbar):
        ###########################################################
        # Please complete the training code here!
        # Examples: 
        # ...
        # neg_logprob
        # pos_logprob 
        # loss = -F.logsigmoid((pos_logprob - neg_logprob) / beta).mean() - pos_logprob.mean() * 0.1 
        # ...
        ###########################################################
    ckpt_path = f"./dpo.pt"
    torch.save({
        "model_state_dict": gpt.state_dict(),
        "model_args": ckpt['model_args'],
    }, ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}")

### Step 8: Begin testing (**students are required to complete this part!**)


In [None]:
# Load the fine-tuned model
ckpt_path = "../dpo/dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).cuda()
try:
    state_dict = checkpoint['model']
except:
    state_dict = checkpoint['model_state_dict']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)
# Test
gpt.eval()
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "3*17=?", "72/4=?", "72-x=34,x=?"]
with torch.no_grad():
    for prompt in test_set: 
        prompt_ids = encode(prompt)
        ###########################################################
        # Please complete the test code here!
        # ...
        # gpt.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # ...
        ###########################################################