<a href="https://colab.research.google.com/github/EnmingTigerZhang/EnmingTigerZhang.github.io/blob/main/src/test_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import sys
from google.colab import drive
from types import SimpleNamespace

In [2]:
# Mount Google Drive and make the project's `src/` directory the working
# directory, so project modules (train.py, wrapper_model.py, ...) are
# importable and relative paths like `../models/...` resolve under the project.
import os

PROJECT_SRC_DIR = '/content/drive/MyDrive/CS182 Project/limits-of-icl/src'

drive.mount('/content/drive')
# Absolute path on purpose: a relative `cd drive/MyDrive/...` only works when
# the cwd is /content, so re-running this cell (cwd already == src) failed.
os.chdir(PROJECT_SRC_DIR)
print(os.getcwd())

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/CS182 Project/limits-of-icl/src


In [3]:
!pip install wandb munch



In [4]:
import torch
import os
import sys
import uuid
import yaml
from munch import Munch

In [5]:
# --- 1. Setup Environment ---
# Ensure Python can find your modules (train.py, wrapper_model.py, ...), which
# live in the working directory selected by the drive-mount cell.
# Register the *absolute* path: a bare '.' entry in sys.path is resolved
# against whatever the cwd is at each later import, so any subsequent chdir
# would silently change where project modules are looked up.
SRC_DIR = os.path.abspath('.')
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# Import the necessary functions from your project
from train import train
from wrapper_model import build_model
# All other necessary modules (tasks, samplers, etc.) will be imported by the above.

# --- 2. Define the Configuration as a Python Dictionary ---
# This replaces the need for YAML files and the quinine library.
# It combines the settings from softmax_test.yaml and models/standard.yaml.
#
# The model sizes below also appear as the final stage of the curriculum
# (the curriculum ends at the model's full size). Each is defined once so the
# "model" and "curriculum" sections cannot drift apart.
N_DIMS = 20        # model input dimension == curriculum dims["end"]
N_POSITIONS = 41   # model context length  == curriculum points["end"]

config_dict = {
    "out_dir": "../models/nanogpt_softmax_test",
    "test_run": False,
    "model": {
        "family": "nanogpt",
        "attention_type": "softmax_causal",
        "attention_kwargs": {},
        "n_dims": N_DIMS,
        "n_positions": N_POSITIONS,
        "n_embd": 256,
        "n_layer": 12,
        "n_head": 8,
    },
    "training": {
        "task": "linear_regression",
        "task_kwargs": {},
        "data": "gaussian",
        "batch_size": 64,
        "learning_rate": 0.0001,
        "train_steps": 100001,  # full-length run; lower this for a quick smoke test
        "save_every_steps": 3000,
        "curriculum": {
            "dims": {"start": 5, "end": N_DIMS, "inc": 1, "interval": 2000},
            "points": {"start": 11, "end": N_POSITIONS, "inc": 2, "interval": 2000},
        },
    },
    "wandb": {
        "project": "in-context-training-colab",
        "entity": "your-wandb-entity",  # CHANGE THIS if using wandb
        "name": "nanogpt-softmax-test-run-no-quinine",
        "log_every_steps": 10,
    },
}

In [6]:
# --- 3. Prepare for Training ---
# Convert the nested config dictionary into SimpleNamespace objects so the
# training code can use attribute access (e.g. args.model.family). Only the
# sections converted explicitly below become namespaces; leaf dicts such as
# `attention_kwargs` and `task_kwargs` intentionally stay plain dicts.


def _fill_missing(namespace, defaults):
    """Set each attribute in `defaults` that `namespace` does not already have.

    Values already present (i.e. taken from `config_dict`) are never
    overwritten. Returns the names that were filled in, in `defaults` order.
    """
    filled = []
    for name, value in defaults.items():
        if not hasattr(namespace, name):
            setattr(namespace, name, value)
            filled.append(name)
    return filled


curriculum_cfg = config_dict["training"]["curriculum"]

args = SimpleNamespace()
args.out_dir = config_dict["out_dir"]
# .get: a config without "test_run" falls back to False instead of raising
# KeyError (the old `hasattr` fallback for it could never be reached).
args.test_run = config_dict.get("test_run", False)
args.wandb = SimpleNamespace(**config_dict["wandb"])
args.training = SimpleNamespace(**config_dict["training"])
args.training.curriculum = SimpleNamespace(**curriculum_cfg)
args.training.curriculum.dims = SimpleNamespace(**curriculum_cfg["dims"])
args.training.curriculum.points = SimpleNamespace(**curriculum_cfg["points"])
args.model = SimpleNamespace(**config_dict["model"])

print("Checking for and applying default values...")

# Fallbacks for settings missing from `config_dict`; config values always win.
# The dict literals are rebuilt every time the cell runs, so mutable defaults
# (e.g. {}) are never shared between executions.
defaults_by_section = {
    "model": (args.model, {
        "attention_type": "softmax_causal",
        "attention_kwargs": {},
    }),
    "training": (args.training, {
        "num_tasks": None,
        "num_training_examples": None,
        "batch_size": 64,
        "learning_rate": 3e-4,
        "train_steps": 1000,
        "save_every_steps": 1000,
        # presumably a value <= 0 disables keeping extra checkpoints — TODO confirm in train.py
        "keep_every_steps": -1,
        "resume_id": None,
    }),
    "wandb": (args.wandb, {
        "project": "in-context-training",
        "entity": "in-context",
        "notes": "",
        "name": None,
        "log_every_steps": 10,
    }),
}
for section, (namespace, defaults) in defaults_by_section.items():
    filled = _fill_missing(namespace, defaults)
    if filled:
        print(f"  {section}: defaulted {', '.join(filled)}")

print("Defaults applied successfully.")
print("-" * 40)

# Create the output directory. exist_ok avoids the check-then-create race and
# keeps the cell safe to re-run.
os.makedirs(args.out_dir, exist_ok=True)
print(f"Output directory set to: {os.path.abspath(args.out_dir)}")
print("-" * 40)

Checking for and applying default values...
Defaults applied successfully.
----------------------------------------
Output directory set to: /content/drive/MyDrive/CS182 Project/limits-of-icl/models/nanogpt_softmax_test2
----------------------------------------


In [7]:
# --- 4. Run the Training ---
# Seed torch's RNGs (CPU and all CUDA devices) before building the model so
# weight initialization is reproducible across kernel restarts. Whether the
# project's data samplers also draw from torch's global RNG is an assumption —
# TODO confirm in train.py / samplers. Change SEED for an independent run.
SEED = 0
torch.manual_seed(SEED)

print("Building model...")
model = build_model(args.model)

# Move model to GPU if available
if torch.cuda.is_available():
    model.cuda()
model.train()

print("\nStarting training...")
train(model, args)

print("\nTraining finished successfully!")

Building model...
number of parameters: 9.48M

Starting training...


loss 7.3115644454956055:   0%|          | 0/100001 [00:00<?, ?it/s]

here?


loss 4.357875823974609:   1%|          | 505/100001 [00:22<1:02:56, 26.35it/s]

here?


loss 3.784058094024658:   1%|          | 1006/100001 [00:41<1:02:49, 26.26it/s]

here?


loss 2.952779769897461:   2%|▏         | 1504/100001 [01:00<1:04:04, 25.62it/s] 

here?


loss 3.96420955657959:   2%|▏         | 2005/100001 [01:19<1:03:47, 25.61it/s]

here?


loss 3.3156516551971436:   3%|▎         | 2503/100001 [01:40<1:09:42, 23.31it/s]

here?


loss 3.1475515365600586:   3%|▎         | 2998/100001 [02:02<1:08:20, 23.65it/s]

here?


loss 3.386669874191284:   4%|▎         | 3505/100001 [02:24<1:08:45, 23.39it/s]

here?


loss 2.582371711730957:   4%|▎         | 3620/100001 [02:29<1:06:17, 24.23it/s]


KeyboardInterrupt: 