In [32]:
# Inspect a given model
import os
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import yaml

import metaparc.run.train as train_module
from metaparc.model.transformer.model import get_model
from metaparc.run.train import get_lr_scheduler

def inspect_model(model: nn.Module):
    """Print a parameter-count summary of ``model``.

    The first line reports the total number of parameters, both as a
    comma-grouped integer and in millions. It is followed by one line per
    entry of ``model.named_modules()`` in the form ``<name>: <count>``.

    Per-module counts come from ``module.parameters()``, which recurses into
    submodules, so a container's count already includes its children. The
    first entry yielded by ``named_modules()`` is the model itself under an
    empty name, hence the leading ``": <count>"`` line.

    Args:
        model: The module to summarise.

    Returns:
        None. The summary is written to standard output.
    """
    total_params = sum(param.numel() for param in model.parameters())
    print(f"Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")

    # One line per (sub)module, in named_modules() traversal order.
    for module_name, submodule in model.named_modules():
        module_param_count = sum(param.numel() for param in submodule.parameters())
        print(f"{module_name}: {module_param_count}")

In [None]:
# Locate the run configuration without depending on one developer's home
# directory: the METAPARC_CONFIG environment variable wins if set, otherwise
# the config.yaml sitting next to metaparc/run/train.py in the imported
# package is used.
# NOTE(review): assumes config.yaml ships alongside train.py (true for a
# source checkout / editable install) — set METAPARC_CONFIG if it lives
# elsewhere.
default_config_path = Path(train_module.__file__).with_name("config.yaml")
config_path = Path(os.environ.get("METAPARC_CONFIG", default_config_path))

# Explicit encoding so parsing does not depend on the platform's locale default.
with config_path.open("r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
model_config = config["model"]

# Build the model from the "model" section, then show its structure and size.
model = get_model(model_config)
print(model)
inspect_model(model)

In [None]:
# Inspect learning rate scheduler
# Simulates a run of NUM_STEPS optimizer steps and plots the learning rate
# produced by cosine annealing with warm restarts.
WARM_RESTART_PERIOD = 10_000  # T_0: scheduler steps before the first restart
NUM_STEPS = 100_000  # number of simulated optimizer steps

optimizer = optim.Adam(model.parameters(), lr=config["training"]["lr_scheduler"]["learning_rate"])
cosine_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=WARM_RESTART_PERIOD
)

# plot learning rate
lr = []
for _ in range(NUM_STEPS):
    # Record before stepping so lr[i] is the rate in effect for step i and the
    # initial learning rate is part of the curve.
    lr.append(cosine_scheduler.get_last_lr()[0])
    # No backward pass is run in this cell; optimizer.step() is called so the
    # scheduler is stepped in the usual optimizer -> scheduler order.
    optimizer.step()
    cosine_scheduler.step()

fig, ax = plt.subplots()
ax.plot(lr)
ax.set(
    xlabel="Optimizer step",
    ylabel="Learning rate",
    title=f"CosineAnnealingWarmRestarts (T_0={WARM_RESTART_PERIOD})",
)
plt.show()
