In [1]:
# Automatically re-import edited modules (e.g. the local `data` and `model`
# modules imported below) before each cell runs, so edits there take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass
from pathlib import Path

import lightning as L
import torch as t
from data import EOT_TOKEN, WikitextDataset
from model import LangNet
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from torch.utils.data import DataLoader
import wandb

In [3]:
# Where the raw WikiText-2 files live, and where training artefacts
# (Lightning default_root_dir / logger save_dir) are written.
DATAROOT = Path.home() / "mldata" / "wikitext-2-raw"
RUNROOT = Path.home() / "mlruns" / "makemore-2"

# Seed torch's global RNG so torch-driven randomness (DataLoader shuffling,
# and presumably LangNet's weight init) is reproducible on Restart & Run All.
# Trailing ';' suppresses the returned Generator's repr in the cell output.
SEED = 42
t.manual_seed(SEED);

In [4]:
@dataclass
class HyperParams:
    """Hyperparameters for the dataset, model and trainer in this notebook.

    Every attribute carries a type annotation so it is a genuine dataclass
    field: it can be overridden through the constructor (e.g.
    ``HyperParams(emb_dim=128)``) and appears in ``repr`` / ``dataclasses.asdict``.
    """

    context_len: int = 5  # context window length; passed to WikitextDataset and LangNet
    batch_size: int = 1000  # DataLoader batch size (also passed to LangNet for logging)
    learning_rate: float = 0.1
    hidden_len: int = 100
    max_epochs: int = 3  # passed to L.Trainer(max_epochs=...)
    # Kept after the fields above so existing positional construction is unchanged.
    emb_dim: int = 256
    learning_rate_exp_decay: float = 0.8  # presumably a per-epoch LR decay factor — see LangNet


hparams = HyperParams()

In [5]:
# Training split of WikiText-2, windowed with `hparams.context_len`
# (see data.WikitextDataset for the exact example format).
dataset = WikitextDataset(
    DATAROOT / "wiki.train.raw",
    context_len=hparams.context_len,
)

# An empty dataset would give the DataLoader no batches, so the training loop
# would have nothing to do; fail loudly here instead.
assert len(dataset) > 0, f"no training examples loaded from {DATAROOT / 'wiki.train.raw'}"

# Number of training examples (rich display as the cell's output).
len(dataset)

2075677


In [6]:
# Mini-batch iterator over the training set, reshuffled on every pass.
# NOTE(review): Lightning warns that num_workers=0 may bottleneck loading
# (see the fit output below); left unchanged here.
dl = DataLoader(
    dataset,
    shuffle=True,
    batch_size=hparams.batch_size,
)

In [7]:
# Lightning module for the language model; vocabulary size is taken from the
# dataset built above, everything else from `hparams`.
lang_net = LangNet(
    # model dimensions
    vocab_len=len(dataset.vocab),
    emb_dim=hparams.emb_dim,
    context_len=hparams.context_len,
    hidden_len=hparams.hidden_len,
    # learning-rate settings
    learning_rate=hparams.learning_rate,
    learning_rate_exp_decay=hparams.learning_rate_exp_decay,
    # passed only so it gets logged together with the other hparams
    batch_size=hparams.batch_size,
)

In [8]:
# Experiment logger selection: Weights & Biases by default; set to False to
# write CSV logs locally under RUNROOT instead (no W&B login required).
USE_WANDB = True

# Close any W&B run left open by a previous execution of this cell so the
# logger below starts a fresh run.
wandb.finish()
if USE_WANDB:
    # log_model="all" uploads every checkpoint Lightning saves as a W&B artifact.
    logger = WandbLogger(project="Makemore2", save_dir=RUNROOT, log_model="all")
else:
    logger = CSVLogger(save_dir=RUNROOT)

trainer = L.Trainer(
    default_root_dir=RUNROOT,
    max_epochs=hparams.max_epochs,
    logger=logger,
    # log the current learning rate once per epoch alongside other metrics
    callbacks=[LearningRateMonitor(logging_interval="epoch")],
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
# Train `lang_net` on `dl` for hparams.max_epochs epochs; no validation loader is given.
trainer.fit(model=lang_net, train_dataloaders=dl)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mavilay[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name    | Type             | Params
---------------------------------------------
0 | model   | LangModel        | 27.5 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
27.5 M    Trainable params
0         Non-trainable params
27.5 M    Total params
109.921   Total estimated model params size (MB)
/Users/avilayparekh/miniconda3/envs/ai/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
