# Save & Load an LLM (Checkpointing)  

We’ll:
1. Train a tiny toy "model" (just a weight and bias) so we have something to save.
2. **Save only the model weights**.
3. **Load** those weights into fresh parameters.
4. **Save model + optimizer** together (best practice for resuming).
5. **Load and resume** training from that checkpoint.
6. Show how to make runs more **deterministic** (less randomness).

> Runs on **CPU**. No GPU required.


## 0) Setup (imports + seeds)

```python
# Import PyTorch and set a random seed so results are reproducible
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(123)  # Make random numbers repeatable for this demo
device = "cpu"          # Everything runs on CPU; used as map_location when loading checkpoints
```
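A seed only helps if it is set before the random calls you care about, and Python's own `random` module has a separate generator from PyTorch's. The sketch below uses a hypothetical helper, `seed_everything` (not a PyTorch function), and assumes only Python's `random` and PyTorch generators are in play. Add NumPy seeding if your code uses NumPy. For stricter guarantees PyTorch also offers `torch.use_deterministic_algorithms(True)`, which raises an error when an operation has no deterministic implementation.

```python
import random
import torch

def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed Python's RNG and PyTorch's RNGs."""
    random.seed(seed)        # Python's built-in RNG
    torch.manual_seed(seed)  # PyTorch RNG on all devices (CPU, and CUDA if present)

# Same seed -> same "random" numbers
seed_everything(123)
first_torch = torch.randn(3)
first_py = random.random()

seed_everything(123)
second_torch = torch.randn(3)
second_py = random.random()

print(torch.equal(first_torch, second_torch), first_py == second_py)  # True True
```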


## 1) A tiny toy dataset (so we can demonstrate saving/loading)
We pretend the true relationship is `y = 3x + 2`.  
We'll let a tiny "model" learn that (just a weight `w` and bias `b`).

```python
# Create simple inputs (x) and targets (y_true)
x_batch = torch.linspace(0, 1, steps=16).unsqueeze(1)  # 16 inputs between 0 and 1, shape [16, 1]
y_true  = 3.0 * x_batch + 2.0                          # True outputs for our demo
print(x_batch[:5].flatten(), "->", y_true[:5].flatten())
```
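Because this data has no noise, the best possible parameters are exactly `w = 3` and `b = 2`, and a closed-form least-squares solve recovers them. That gives a reference point for judging the trained values later. This is a small sketch using `torch.linalg.lstsq`, a swapped-in closed-form method that is not how the notebook trains:

```python
import torch

# Same toy data as above: y = 3x + 2 on 16 evenly spaced points in [0, 1]
x_batch = torch.linspace(0, 1, steps=16).unsqueeze(1)   # shape [16, 1]
y_true = 3.0 * x_batch + 2.0                            # shape [16, 1]

# Design matrix [x, 1] so that y ≈ A @ [w, b]
A = torch.cat([x_batch, torch.ones_like(x_batch)], dim=1)  # shape [16, 2]

# Least-squares solve: minimizes ||A @ theta - y_true||^2
theta = torch.linalg.lstsq(A, y_true).solution             # shape [2, 1]
w_best, b_best = theta[0].item(), theta[1].item()
print(f"best w = {w_best:.4f}, best b = {b_best:.4f}")
```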

## 2) A tiny "model" (no classes) + Optimizer + Loss
Think of this as the LLM's "brain" in miniature: just two parameters (`w` and `b`).

```python
# Initialize the model's parameters randomly
w = torch.randn(1, requires_grad=True)   # model weight
b = torch.randn(1, requires_grad=True)   # model bias

# AdamW updates w and b using adaptive per-parameter step sizes plus decoupled weight decay
optimizer = optim.AdamW([w, b], lr=1e-2, weight_decay=0.01)

# Measure error with Mean Squared Error (average squared gap between predictions and targets)
loss_fn = nn.MSELoss()

print("Initial w,b:", w.item(), b.item())
```
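Even before training, the optimizer has a `state_dict()`, but it holds only hyperparameters. The sketch below rebuilds the same `w`, `b`, and AdamW setup to show two details that matter later for checkpointing. First, per-parameter state is created lazily on the first `step()`. Second, parameters are identified by their position in the list passed to the optimizer, so a restored optimizer must receive its parameters in the same order.

```python
import torch
import torch.optim as optim

w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
optimizer = optim.AdamW([w, b], lr=1e-2, weight_decay=0.01)

sd = optimizer.state_dict()
print(sd["state"])                      # {} : AdamW creates its buffers lazily, on the first step()
print(sd["param_groups"][0]["params"])  # [0, 1] : parameters are referenced by position, not by name
print(sd["param_groups"][0]["lr"], sd["param_groups"][0]["weight_decay"])
```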

## 3) Quick training  
We run 200 update steps so the model learns something worth saving.

```python
# Train for 200 steps: predict -> compute loss -> backprop -> update
for step in range(200):
    optimizer.zero_grad()          # Clear gradients left over from the previous step
    y_pred = x_batch * w + b       # Model's current guess (y = w*x + b)
    loss = loss_fn(y_pred, y_true) # How wrong is the guess?
    loss.backward()                # Compute gradients of the loss w.r.t. w and b
    optimizer.step()               # Update w and b a little bit

print("After training: w,b:", w.item(), b.item())
print("Training loss (last step, before its update):", float(loss))
```
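The loss printed after the training loop comes from the forward pass of the final iteration, which runs before that iteration's `optimizer.step()`, so it lags one update behind the final `w` and `b`. A sketch of evaluating the final parameters directly; it repeats the setup so it runs on its own:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(123)
x_batch = torch.linspace(0, 1, steps=16).unsqueeze(1)
y_true = 3.0 * x_batch + 2.0
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
optimizer = optim.AdamW([w, b], lr=1e-2, weight_decay=0.01)
loss_fn = nn.MSELoss()

for step in range(200):
    optimizer.zero_grad()
    loss = loss_fn(x_batch * w + b, y_true)  # loss of the parameters *before* this step's update
    loss.backward()
    optimizer.step()

# Evaluate the final parameters without building an autograd graph
with torch.no_grad():
    final_loss = loss_fn(x_batch * w + b, y_true)

print("last in-loop loss:", loss.item())
print("loss of final w,b:", final_loss.item())
```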

## 4) Save **only** the model weights (like saving the chef's skills)
In a real `nn.Module` you'd save `state_dict()`.  
Here, since we have no class, we **pack the tensors ourselves**.

```python
# Build a simple "state dict" manually and save to disk
model_state = {
    "w": w.detach().clone(),   # store learned weight
    "b": b.detach().clone()    # store learned bias
}

torch.save(model_state, "model.pth")
print("Saved model weights to: model.pth")
```
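With an actual `nn.Module`, PyTorch builds this dictionary for you. This is a minimal sketch that uses `nn.Linear(1, 1)` as a stand-in for the toy `w`/`b` model and assumes PyTorch 1.13 or newer for the `weights_only` argument of `torch.load`. An in-memory `io.BytesIO` buffer replaces the file so the example leaves nothing on disk.

```python
import io
import torch
import torch.nn as nn

# A real module equivalent of our toy model: y = w*x + b
model = nn.Linear(1, 1)
print(list(model.state_dict().keys()))  # ['weight', 'bias']

# Save only the weights (an in-memory buffer stands in for a file like "model.pth")
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# "New session": a freshly initialized module with the same architecture
fresh = nn.Linear(1, 1)
buffer.seek(0)
state = torch.load(buffer, map_location="cpu", weights_only=True)  # load tensors only, no arbitrary objects
fresh.load_state_dict(state)  # strict=True by default: key names and shapes must match

print(torch.equal(fresh.weight, model.weight), torch.equal(fresh.bias, model.bias))  # True True
```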

## 5) Load the model weights into **fresh** parameters  
Imagine a new session: we start with random parameters and **load** the saved values into them.

```python
# Start with random parameters (as if in a new Python session)
w_loaded = torch.randn(1, requires_grad=True)
b_loaded = torch.randn(1, requires_grad=True)

# Load from disk and copy the saved values into the new parameters in place
checkpoint = torch.load("model.pth", map_location=device)
with torch.no_grad():  # in-place copies into leaf tensors that require grad must not be tracked
    w_loaded.copy_(checkpoint["w"])
    b_loaded.copy_(checkpoint["b"])

print("Loaded w,b:", w_loaded.item(), b_loaded.item())

# Placeholder only: there is no module here, so this flag changes nothing. With a real
# nn.Module you would call model.eval() before inference to disable dropout and similar layers.
is_eval_mode = True
print("Eval mode flag:", is_eval_mode)
```
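The `model.eval()` call matters because some layers behave differently during training. The small sketch below uses `nn.Dropout`, which is not part of the toy model, to show the difference. On a full model, `model.eval()` switches every submodule at once and `model.train()` switches them back; modules start in training mode when created.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(8)

dropout.train()           # training mode: randomly zeroes elements, scales survivors by 1/(1-p)
y_train = dropout(x)
print(y_train)            # a mix of 0.0 and 2.0 values

dropout.eval()            # eval mode: dropout becomes the identity
y_eval = dropout(x)
print(y_eval)             # all ones, on every call
```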

## 6) Save **model + optimizer** (best practice for resuming training)
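This also stores the optimizer's internal state (for AdamW: a step count plus running averages of each parameter's gradient and squared gradient), so training continues smoothly instead of restarting the optimizer from scratch.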

```python
# Bundle the learned parameters and the optimizer's state into one checkpoint
both_state = {
    "model": {"w": w.detach().clone(), "b": b.detach().clone()},
    "optim": optimizer.state_dict()
}
torch.save(both_state, "model_and_optimizer.pth")
print("Saved full checkpoint to: model_and_optimizer.pth")
```
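To see what "internal state" means for AdamW concretely, the sketch below inspects the optimizer's `state_dict()` after a rebuilt miniature of the training loop (five steps instead of 200). Each parameter gets a step counter and two tensors, `exp_avg` and `exp_avg_sq`, the running averages of the gradient and of its square. A fresh optimizer starts these from zero, which is why saving them lets training pick up where it left off.

```python
import torch
import torch.optim as optim

# Rebuild a tiny version of the setup and take a few AdamW steps
x = torch.linspace(0, 1, steps=16).unsqueeze(1)
y = 3.0 * x + 2.0
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
optimizer = optim.AdamW([w, b], lr=1e-2, weight_decay=0.01)

for _ in range(5):
    optimizer.zero_grad()
    loss = torch.mean((x * w + b - y) ** 2)  # same MSE as nn.MSELoss()
    loss.backward()
    optimizer.step()

sd = optimizer.state_dict()
print(sorted(sd["state"].keys()))     # [0, 1] : one entry per parameter, by position
print(sorted(sd["state"][0].keys()))  # step count plus the two running averages
print(float(sd["state"][0]["step"]))  # 5.0
```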

## 7) Load the full checkpoint and **resume training**
We create fresh parameters and a fresh optimizer, then restore both from the checkpoint.

```python
# New random params and optimizer (as if a fresh start)
w2 = torch.randn(1, requires_grad=True)
b2 = torch.randn(1, requires_grad=True)
# Pass parameters in the same order as when the checkpoint was saved: the optimizer state maps them by position
optimizer2 = optim.AdamW([w2, b2], lr=1e-2, weight_decay=0.01)

# Load checkpoint and restore
ckpt = torch.load("model_and_optimizer.pth", map_location=device)
with torch.no_grad():
    w2.copy_(ckpt["model"]["w"])
    b2.copy_(ckpt["model"]["b"])
optimizer2.load_state_dict(ckpt["optim"])

# Continue training for 50 more steps
for step in range(50):
    optimizer2.zero_grad()
    y_pred2 = x_batch * w2 + b2
    loss2 = loss_fn(y_pred2, y_true)
    loss2.backward()
    optimizer2.step()

print("Resumed w,b:", w2.item(), b2.item())
print("Loss after resume:", float(loss2))
```
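One way to check that a checkpoint captures everything needed to resume is to compare an interrupted run with an uninterrupted one. The sketch below assumes the same toy data, 20 total steps split 10 + 10, and an in-memory `io.BytesIO` buffer in place of `model_and_optimizer.pth`. The helper names `make_params` and `run` are hypothetical. On CPU the two runs should end with matching parameters.

```python
import io
import torch
import torch.optim as optim

x = torch.linspace(0, 1, steps=16).unsqueeze(1)
y = 3.0 * x + 2.0

def make_params(seed):
    """Hypothetical helper: reproducible random w, b from a private generator."""
    g = torch.Generator().manual_seed(seed)
    w = torch.randn(1, generator=g).requires_grad_()
    b = torch.randn(1, generator=g).requires_grad_()
    return w, b

def run(w, b, opt, steps):
    """Hypothetical helper: plain MSE training loop."""
    for _ in range(steps):
        opt.zero_grad()
        loss = torch.mean((x * w + b - y) ** 2)
        loss.backward()
        opt.step()

# Reference: 20 uninterrupted steps
w_ref, b_ref = make_params(0)
opt_ref = optim.AdamW([w_ref, b_ref], lr=1e-2, weight_decay=0.01)
run(w_ref, b_ref, opt_ref, 20)

# Interrupted: 10 steps, checkpoint into the buffer...
w_a, b_a = make_params(0)
opt_a = optim.AdamW([w_a, b_a], lr=1e-2, weight_decay=0.01)
run(w_a, b_a, opt_a, 10)
buf = io.BytesIO()
torch.save({"model": {"w": w_a.detach().clone(), "b": b_a.detach().clone()},
            "optim": opt_a.state_dict()}, buf)
buf.seek(0)

# ...then "restart" from different random values and restore everything
w_b, b_b = make_params(1)
opt_b = optim.AdamW([w_b, b_b], lr=1e-2, weight_decay=0.01)
ckpt = torch.load(buf)
with torch.no_grad():
    w_b.copy_(ckpt["model"]["w"])
    b_b.copy_(ckpt["model"]["b"])
opt_b.load_state_dict(ckpt["optim"])
run(w_b, b_b, opt_b, 10)

print(torch.allclose(w_b, w_ref), torch.allclose(b_b, b_ref))
```

If the `opt_b.load_state_dict(...)` line is skipped, AdamW restarts its running averages and step count from scratch, so the resumed run generally no longer tracks the uninterrupted one.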

## 8) Bonus: Deterministic generation/settings (no randomness)
To make runs repeatable:
- Fix random seeds (we did `torch.manual_seed(123)`).
- Use `model.eval()` in real modules to turn off dropout.
- For text generation: use greedy decoding, i.e. pick the **argmax** token at each step with no **top-k / top-p** sampling (many generation APIs expose this as **temperature=0**).

```python
# Demonstrate repeatable prediction with fixed weights
torch.manual_seed(123)      # seed fixed (not strictly needed: nothing below draws random numbers)
x_demo = torch.tensor([[0.25], [0.75]])
with torch.no_grad():
    y_demo = x_demo * w2 + b2   # depends only on w2, b2 and x_demo, so it is identical every time
print("Deterministic demo outputs:", y_demo.flatten().tolist())
```
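For text generation the randomness sits in the token-selection step, so that is where the seed matters. The sketch below uses hypothetical next-token scores (`logits`) over a five-token vocabulary and an illustrative `temperature` of 0.8. It contrasts greedy decoding, which involves no randomness, with temperature sampling, which repeats only when the same seed is reused. The helper name `sample` is hypothetical.

```python
import torch

# Hypothetical next-token scores over a 5-token vocabulary
logits = torch.tensor([1.0, 3.0, 0.5, 2.9, -1.0])

# Greedy decoding: always take the highest-scoring token; no randomness involved
greedy_token = torch.argmax(logits).item()
print("greedy:", greedy_token)  # 1

# Sampling: draw from softmax(logits / temperature); repeatable only with a fixed seed
temperature = 0.8
probs = torch.softmax(logits / temperature, dim=-1)

def sample(seed: int, n: int = 10) -> list:
    """Hypothetical helper: draw n token ids using a private, seeded generator."""
    gen = torch.Generator().manual_seed(seed)  # leaves the global RNG untouched
    return torch.multinomial(probs, num_samples=n, replacement=True, generator=gen).tolist()

print(sample(42) == sample(42))  # True: same seed, same draws
```

A private `torch.Generator` keeps the sampling repeatable even if other code consumes numbers from the global generator in between.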