[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](
https://colab.research.google.com/github/CMU-IDeeL/CMU-IDeeL.github.io/blob/master/F25/document/Recitation_0_Series/0.24/0_24_Checkpointing.ipynb)

#**Recitation 0: Checkpointing**

We will show you how to checkpoint and load your model :D

##**Section 0: Setup**

Let's define a quick dummy model, optimizer, and scheduler that we'll be saving and loading :)

In [None]:
!pip install wandb --quiet # Install WandB

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
import os

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", device)

os.environ['WANDB_API_KEY'] = ""  # your key here
wandb.login()

Device:  cpu


In [None]:
# A simple submodule
class DummySubmodule(nn.Module):
    def __init__(self):
        super(DummySubmodule, self).__init__()
        self.layer = nn.Linear(in_features = 32, out_features = 32)

    def forward(self, x):
        return self.layer(x)

# A simple network
class DummyNetwork(nn.Module):

    def __init__(self):
        super(DummyNetwork, self).__init__()

        self.lower_layer = nn.Sequential(
            nn.Linear(in_features = 32, out_features = 64),
            nn.ReLU(),
            nn.Linear(in_features = 64, out_features = 64),  # input size must match the 64 outputs of the previous layer
            nn.ReLU(),
        )
        self.upper_layer = nn.Sequential(
            nn.Linear(in_features = 64, out_features = 32),
            nn.ReLU(),
        )

        self.module1 = DummySubmodule()

    def forward(self, x):
        res = self.lower_layer(x)
        res = self.upper_layer(res)
        res = self.module1(res)  # the submodule is registered as self.module1 in __init__
        return res

# Declare the model, optimizer, and scheduler
model = DummyNetwork().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

Let's take a look at some of the information we can checkpoint

In [None]:
# Print model's state_dict
print("==============================================================")
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("==============================================================")
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
print("==============================================================")
# Print scheduler's state_dict
print("\nScheduler's state_dict:")
for var_name, value in scheduler.state_dict().items():
    print(var_name, "\t", value)
print("==============================================================")

---

##**Section 1: Saving a checkpoint**

*(This is typically placed inside the training loop, often at the end of each epoch or after validation.)*

Checkpointing locally

In [None]:
# Let's pretend we're in the middle of our training
epoch = 6  # pretend we're in our 6th epoch
loss = 0.78  # pretend this is our model's loss at the moment

checkpoint_path = f"<run_name>_{epoch}.pth"

# Saving your states locally with torch.save
torch.save({
    'model_state_dict': model.state_dict(),   # saving the model state
    # If the model is wrapped in nn.DataParallel, save model.module.state_dict() instead
    'optimizer_state_dict': optimizer.state_dict(),   # saving the optimizer state
    'scheduler_state_dict': scheduler.state_dict(),   # saving the scheduler state
    'epoch': epoch,
    'current_loss': loss
    }, checkpoint_path
)
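
The `nn.DataParallel` note in the cell above can be made concrete. Here is a minimal sketch, assuming a hypothetical helper named `unwrap_model` (it is not part of PyTorch) and a tiny stand-in layer instead of `DummyNetwork`. It shows why unwrapping matters: the wrapper prefixes every key with `module.`, so a checkpoint saved from the wrapped model will not load directly into an unwrapped one.

```python
import torch.nn as nn

def unwrap_model(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the inner module if `model` is wrapped in nn.DataParallel."""
    return model.module if isinstance(model, nn.DataParallel) else model

# Tiny stand-in network (for illustration only)
net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

# Keys of the wrapped model carry a "module." prefix
wrapped_keys = list(wrapped.state_dict().keys())
# Keys of the unwrapped model match a plain, single-device model
plain_keys = list(unwrap_model(wrapped).state_dict().keys())

print(wrapped_keys)  # ['module.weight', 'module.bias']
print(plain_keys)    # ['weight', 'bias']
```

Saving `unwrap_model(model).state_dict()` keeps the checkpoint loadable whether or not the model is wrapped when you load it later.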


Checkpointing and saving to wandb as an artifact  
❗<strong><small>Remember to run <code>wandb.init()</code> before logging, or saving will fail.</small></strong>


In [None]:
# Before the run, you need to have started a run like so....
run = wandb.init(
    project="wandb-quickstart",
    name="<run_name>",
    )

# ...
# ...
# ...
# Within a training loop (or wherever else you want)....

# Option 1:
# create artifacts (keeps track of versioning, and is much more organized to work with between collaborators)
checkpoint_artifact = wandb.Artifact("<run_name>", type="checkpoint") # You can switch to type="model" if you only want to save a model

checkpoint_artifact.add_file(checkpoint_path)

run.log_artifact(checkpoint_artifact)

# Option 2:
# directly save the model to wandb
wandb.save(checkpoint_path, base_path=os.path.dirname(checkpoint_path))

---

##**Section 2: Loading a checkpoint file into our current model**

*(This should be placed before training starts, after defining your model and optimizer.)*

Downloading a model from wandb

In [None]:
# METHOD 1: Download from wandb Artifact
# The public API lets you fetch artifacts without an active run
api = wandb.Api()
# The names below can be read off the wandb link address:
# https://wandb.ai/<USERNAME>/<PROJECT_NAME>/runs/<RUN_ID>?nw=nwuser<USERNAME>

# To retrieve the artifact....
# Get the artifact (choose which version of the model you want)
artifact = api.artifact('<USERNAME>/<PROJECT_NAME>/<run_name>:latest')
# Downloading the artifact
artifact_dir = artifact.download()
# Loading the checkpoint dict (the file keeps the name it was added with, i.e. <run_name>_<epoch>.pth)
checkpoint_dict = torch.load(os.path.join(artifact_dir, '<run_name>_<epoch>.pth'), map_location=device)


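# METHOD 2: Download the directly saved file from wandb to a local file
# The first argument is the name of the saved file (here <run_name>_<epoch>.pth)
checkpoint_file = wandb.restore('<run_name>_<epoch>.pth', run_path="<USERNAME>/<PROJECT_NAME>/<RUN_ID>").name
checkpoint_dict = torch.load(checkpoint_file, map_location=device)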

Loading a .pth checkpoint file from our local directory to our model

In [None]:
# The checkpoint can come from a locally saved .pth file, or you can reuse the checkpoint_dict from the wandb download above :)
checkpoint_path = "/content/<run_name>_6.pth"
# map_location keeps loading working when the checkpoint was saved on a different device (e.g. GPU -> CPU)
checkpoint_dict = torch.load(checkpoint_path, map_location=device)


# Loading model weights
# If the model is wrapped in nn.DataParallel, call model.module.load_state_dict(checkpoint_dict['model_state_dict']) instead
model.load_state_dict(checkpoint_dict['model_state_dict'])
# Loading optimizer state
optimizer.load_state_dict(checkpoint_dict['optimizer_state_dict'])
# Loading the scheduler state
scheduler.load_state_dict(checkpoint_dict['scheduler_state_dict'])

# Done!!!!!!

In [None]:
# If you want to load specific parts of your model (in our case, we can load just the lower layers or just the upper layers)
specific_weights = { # Creates dictionary of only desired weights
    key: value
    for key, value in checkpoint_dict['model_state_dict'].items()
    if 'lower_layer' in key
}

model.load_state_dict(specific_weights, strict=False)
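
With `strict=False`, mismatched keys are skipped silently, so it is worth checking what was actually loaded. A minimal sketch, assuming a hypothetical two-layer `TinyNet` standing in for `DummyNetwork`: `load_state_dict` returns an object whose `missing_keys` and `unexpected_keys` fields report what was left untouched.

```python
import torch
import torch.nn as nn

# Hypothetical two-part model standing in for DummyNetwork
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.lower_layer = nn.Linear(8, 8)
        self.upper_layer = nn.Linear(8, 4)

source = TinyNet()  # plays the role of the checkpointed model
target = TinyNet()  # the model we are loading into

# Keep only the lower-layer weights
lower_only = {
    key: value
    for key, value in source.state_dict().items()
    if key.startswith("lower_layer")
}

result = target.load_state_dict(lower_only, strict=False)

# Keys the model has but the dictionary did not provide
print("missing:", result.missing_keys)        # ['upper_layer.weight', 'upper_layer.bias']
# Keys the dictionary provided but the model does not have
print("unexpected:", result.unexpected_keys)  # []
```

If `missing_keys` lists layers you meant to load, the key filter probably did not match anything, which is easy to miss because no error is raised.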

---

## Summary and Some Reminders

### Checkpoint Saving — Where

You can save to either or both of the following locations. **Saving to both is safer**, but if you're saving many checkpoints (e.g., every few epochs without overwriting), wandb storage may become a problem.

- **wandb**: Cloud backup, useful against crashes. But has storage limits.  
- **local**: More flexible, but make sure to save in a persistent path (e.g., Google Drive, Kaggle working dir, or PSC persistent storage). ❗**Do not save to temporary environments**: files are lost once the session ends or the runtime crashes.
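
One way to make local saving a bit more crash-tolerant is to write to a temporary file first and then rename it. The sketch below is an assumption on our part rather than something the recitation requires: `save_checkpoint_atomic` is a hypothetical helper, and the temporary directory only stands in for your persistent path.

```python
import os
import tempfile
import torch

def save_checkpoint_atomic(state: dict, directory: str, filename: str) -> str:
    """Hypothetical helper: save to <file>.tmp, then rename it over the final path.

    If the session dies mid-save, the previous checkpoint file is left intact.
    """
    os.makedirs(directory, exist_ok=True)   # create the folder if it does not exist yet
    final_path = os.path.join(directory, filename)
    tmp_path = final_path + ".tmp"
    torch.save(state, tmp_path)             # the slow part happens on the temp file
    os.replace(tmp_path, final_path)        # rename over the final path
    return final_path

# Stand-in for a persistent location; swap in your own persistent directory
ckpt_dir = os.path.join(tempfile.mkdtemp(), "checkpoints")

saved_path = save_checkpoint_atomic(
    {"epoch": 6, "current_loss": 0.78}, ckpt_dir, "demo_6.pth"
)
print(os.path.exists(saved_path))  # True
```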

---
Below are some checkpoint save and load strategies that may be useful in HWP2. Feel free to try other approaches that work best for you :)


### Checkpoint Saving — Strategy

- **Save only the best and the last**: most common, low storage usage, but you can’t retrieve intermediate epochs.  
- **Save every N epochs** (e.g., every 5 or 10): useful during early experimentation or for monitoring overfitting, but it takes more space, so it is usually turned off later and is not recommended unless you specifically need it.
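
Both strategies can be sketched in a few lines. This is a minimal illustration, not a full training loop: the validation losses are made-up numbers, `save_every` is a hypothetical setting, and the temporary directory stands in for a persistent path.

```python
import os
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 1)                      # tiny stand-in model
optimizer = optim.SGD(model.parameters(), lr=0.01)
ckpt_dir = tempfile.mkdtemp()                # stand-in for a persistent directory

def save_checkpoint(path, epoch, val_loss):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'current_loss': val_loss,
    }, path)

fake_val_losses = [0.9, 0.7, 0.8, 0.6, 0.65]  # made-up numbers for illustration
best_loss = float('inf')
save_every = 2                                # hypothetical N; set to 0 to disable

for epoch, val_loss in enumerate(fake_val_losses):
    # (training and validation would happen here)

    # "Last": overwritten every epoch, so it costs only one file
    save_checkpoint(os.path.join(ckpt_dir, "last.pth"), epoch, val_loss)

    # "Best": overwritten only when validation loss improves
    if val_loss < best_loss:
        best_loss = val_loss
        save_checkpoint(os.path.join(ckpt_dir, "best.pth"), epoch, val_loss)

    # "Every N epochs": one extra file each time, so storage grows
    if save_every and (epoch + 1) % save_every == 0:
        save_checkpoint(os.path.join(ckpt_dir, f"epoch_{epoch}.pth"), epoch, val_loss)

saved_files = sorted(os.listdir(ckpt_dir))
print(saved_files)  # ['best.pth', 'epoch_1.pth', 'epoch_3.pth', 'last.pth']
```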

---

### Checkpoint Loading — Strategy

- **Load all components and continue training**: must ensure model architecture, optimizer, scheduler, and (if used) AMP scaler match exactly.  
- **Load only part of the checkpoint** (e.g., model weights): allows resetting the optimizer or changing the learning rate schedule.  
- **Load pretrained weights for partial initialization**: common in transformers; can be combined with freezing and unfreezing certain layers.
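
The partial-initialization strategy with freezing might look like the sketch below. It assumes a hypothetical `TinyNet` and treats a second, randomly initialized copy as the "pretrained" model. Freezing is done by turning off `requires_grad`, and unfreezing adds the parameters back to the optimizer with `add_param_group`.

```python
import torch
import torch.nn as nn

# Hypothetical two-part model for illustration
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.lower_layer = nn.Linear(8, 8)
        self.upper_layer = nn.Linear(8, 4)

    def forward(self, x):
        return self.upper_layer(torch.relu(self.lower_layer(x)))

pretrained = TinyNet()  # stands in for a model restored from a checkpoint
model = TinyNet()

# 1) Initialize only the lower layer from the pretrained weights
lower_weights = {
    k: v for k, v in pretrained.state_dict().items() if k.startswith("lower_layer.")
}
model.load_state_dict(lower_weights, strict=False)

# 2) Freeze the loaded layer so it is not updated
for p in model.lower_layer.parameters():
    p.requires_grad = False

# 3) Give the optimizer only the parameters that are still trainable
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.01
)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['upper_layer.weight', 'upper_layer.bias']

# 4) Later: unfreeze and register the layer with the optimizer
for p in model.lower_layer.parameters():
    p.requires_grad = True
optimizer.add_param_group({'params': list(model.lower_layer.parameters())})
```

Since the optimizer here is created after loading, its state starts fresh, which is usually what you want when only part of the model is pretrained.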

❗If continuing training, make sure to **restore the previous epoch count**. Otherwise, logging will restart from epoch 0, which can break plots in wandb or mislead learning rate scheduling.
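
Restoring the epoch count can be sketched as follows. One assumption to flag: here the stored `'epoch'` is treated as the index of the last *completed* epoch, so training resumes at `epoch + 1`. If you store the epoch differently, adjust the offset. The tiny model and temporary file are stand-ins.

```python
import os
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim

def build():
    """Create a fresh model, optimizer, and scheduler (as at the start of a session)."""
    model = nn.Linear(4, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    return model, optimizer, scheduler

# --- First session: pretend epochs 0..5 finished, then save ---
model, optimizer, scheduler = build()
for _ in range(6):
    optimizer.step()   # no gradients here, so this is a no-op stand-in for training
    scheduler.step()

ckpt_path = os.path.join(tempfile.mkdtemp(), "resume_demo.pth")
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'epoch': 5,        # assumption: index of the last completed epoch
}, ckpt_path)

# --- New session: rebuild everything, then restore ---
model, optimizer, scheduler = build()
checkpoint = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

start_epoch = checkpoint['epoch'] + 1   # continue after the last completed epoch
num_epochs = 10                         # hypothetical total
resumed_epochs = list(range(start_epoch, num_epochs))

print(start_epoch, scheduler.last_epoch)  # 6 6
print(resumed_epochs)                     # [6, 7, 8, 9]
```

Using `range(start_epoch, num_epochs)` in the training loop keeps the logged epoch numbers and the scheduler's internal counter in step.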

---

### wandb.init Reminder

❗**Always run `wandb.init()` before logging or saving to wandb**. If skipped, it may cause runtime errors or failed uploads.



<p align="center">
Have fun navigating the world of saving, loading, and training —<br>
&emsp;&emsp;Train well, save wisely, load carefully, and may your final model be worth it all. 🙂
</p>
