diff --git a/README.md b/README.md
index 1e8bd1c..a34cfee 100644
--- a/README.md
+++ b/README.md
@@ -66,8 +66,8 @@ pip install -e .
 Use the following dummy script to make sure everything is working as expected
 
 ```bash
-WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset
-WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset
+WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset --device cpu
+WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset --device cpu
 ```
 
 Then make use of the training script
@@ -81,7 +81,8 @@ python -m kan_gpt.train
 - [x] Dataset downloading script for [WebText](https://github.com/openai/gpt-2-output-dataset)
 - [x] PyTorch Dataset parser for [WebText](https://github.com/openai/gpt-2-output-dataset)
 - [ ] Mini training POC for KAN-GPT
-  - [ ] Integrate KAN training logic from `KAN.train_kan`
+  - [x] Integrate KAN training logic from `KAN.train_kan`
+  - [ ] Train a dummy batch
 - [x] Mini training POC for MLP-GPT
   - [x] Train MLP-GPT on the webtext dataset as a baseline
 - [ ] Auto Save checkpoints
@@ -89,9 +90,9 @@ python -m kan_gpt.train
 - [ ] Script to load checkpoint in interactive mode
 - [ ] Training script to PyTorch Lighting
 - [ ] Test Cases
-  - [ ] KAN: Forward-Backward test
-  - [ ] GPT: Forward-Backward test
-  - [ ] KAN_GPT: Forward-Backward test
+  - [x] KAN: Forward-Backward test
+  - [x] GPT: Forward-Backward test
+  - [x] KAN_GPT: Forward-Backward test
 
 ## Development
 
diff --git a/kan_gpt/model.py b/kan_gpt/model.py
index 0e01057..37f72e3 100644
--- a/kan_gpt/model.py
+++ b/kan_gpt/model.py
@@ -230,6 +230,61 @@ def __init__(self, config):
         n_params = sum(p.numel() for p in self.transformer.parameters())
         print("number of parameters: %.2fM" % (n_params / 1e6,))
 
+    def kan_loss(
+        self,
+        x: torch.Tensor,
+        lamb_l1=1.0,
+        lamb_entropy=2.0,
+        lamb_coef=0.0,
+        lamb_coefdiff=0.0,
+        small_mag_threshold=1e-16,
+        small_reg_factor=1.0,
+    ):
+
+        def reg(mod):
+
+            def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
+                return (x < th) * x * factor + (x > th) * (
+                    x + (factor - 1) * th
+                )
+
+            reg_ = 0.0
+            for i in range(len(mod.acts_scale)):
+                vec = mod.acts_scale[i].reshape(
+                    -1,
+                )
+
+                p = vec / torch.sum(vec)
+                l1 = torch.sum(nonlinear(vec))
+                entropy = -torch.sum(p * torch.log2(p + 1e-4))
+                reg_ += (
+                    lamb_l1 * l1 + lamb_entropy * entropy
+                )  # both l1 and entropy
+
+            # regularize coefficient to encourage spline to be zero
+            for i in range(len(mod.act_fun)):
+                coeff_l1 = torch.sum(
+                    torch.mean(torch.abs(mod.act_fun[i].coef), dim=1)
+                )
+                coeff_diff_l1 = torch.sum(
+                    torch.mean(
+                        torch.abs(torch.diff(mod.act_fun[i].coef)), dim=1
+                    )
+                )
+                reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1
+
+            return reg_
+
+        total_reg = torch.tensor(0.0).to(device=x.device, dtype=torch.float32)
+        size = 0
+        for mod in self.modules():
+            if isinstance(mod, KAN):
+                total_reg += reg(mod)
+                size += 1
+
+        mean_reg = total_reg / size
+        return mean_reg
+
     def _init_weights(self, module):
         if isinstance(module, KAN):
             # torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -359,7 +414,18 @@ def configure_optimizers(self, train_config):
         )
         return optimizer
 
-    def forward(self, idx, targets=None):
+    def forward(
+        self,
+        idx,
+        targets=None,
+        lamb=0.01,
+        lamb_l1=1.0,
+        lamb_entropy=2.0,
+        lamb_coef=0.0,
+        lamb_coefdiff=0.0,
+        small_mag_threshold=1e-16,
+        small_reg_factor=1.0,
+    ):
         device = idx.device
         b, t = idx.size()
         assert t <= self.block_size, (
@@ -392,6 +458,17 @@ def forward(self, idx, targets=None):
                 ignore_index=-1,
             )
 
+            reg = self.kan_loss(
+                x=idx,
+                lamb_l1=lamb_l1,
+                lamb_entropy=lamb_entropy,
+                lamb_coef=lamb_coef,
+                lamb_coefdiff=lamb_coefdiff,
+                small_mag_threshold=small_mag_threshold,
+                small_reg_factor=small_reg_factor,
+            )
+            loss = loss + lamb * reg
+
         return logits, loss
 
     @torch.no_grad()
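The `kan_loss` penalty follows the shape of the regularization the README checklist refers to under `KAN.train_kan`: for every `KAN` module it sums an L1 term and an entropy term over the normalized activation scales, plus optional penalties on the spline coefficients, and the extended `forward` folds the mean over modules into the loss as `loss + lamb * reg`. Below is a minimal sketch of what a single layer contributes, with made-up activation scales and the small-magnitude shaping (`nonlinear`) omitted for clarity.

```python
import torch

# Hypothetical activation scales for one KAN layer (a stand-in for mod.acts_scale[i])
vec = torch.tensor([0.6, 0.3, 0.1, 1e-3])

lamb_l1, lamb_entropy = 1.0, 2.0  # defaults used by kan_loss above

p = vec / torch.sum(vec)                        # scales normalized to a distribution
l1 = torch.sum(vec)                             # sparsity: total activation magnitude
entropy = -torch.sum(p * torch.log2(p + 1e-4))  # low when a few activations dominate

layer_reg = lamb_l1 * l1 + lamb_entropy * entropy
print(f"l1={l1:.3f} entropy={entropy:.3f} layer_reg={layer_reg:.3f}")
```

With the `forward` default of `lamb=0.01`, this term nudges the spline activations toward sparsity without overwhelming the cross-entropy loss.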
diff --git a/tests/test_gpt_kan.py b/tests/test_gpt_kan.py
new file mode 100644
index 0000000..af2ce76
--- /dev/null
+++ b/tests/test_gpt_kan.py
@@ -0,0 +1,95 @@
+import torch
+from kan_gpt.model import GPT as KAN_GPT
+
+VOCAB_SIZE = 8
+BLOCK_SIZE = 16
+MODEL_TYPE = "gpt-nano"
+
+
+def get_gpt_model() -> KAN_GPT:
+    model_config = KAN_GPT.get_default_config()
+    model_config.model_type = MODEL_TYPE
+    model_config.vocab_size = VOCAB_SIZE
+    model_config.block_size = BLOCK_SIZE
+    model = KAN_GPT(model_config)
+    return model
+
+
+def test_forward():
+    with torch.no_grad():
+        model = get_gpt_model()
+        x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)
+
+        y, loss = model.forward(x)
+
+        assert y.shape == (
+            1,
+            BLOCK_SIZE,
+            VOCAB_SIZE,
+        ), f"Shape mismatch: {y.shape}"
+
+
+def test_backward():
+    model = get_gpt_model()
+    x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)
+
+    # Make sure grads exist
+    requires_grad_set = set()
+    for param in model.parameters():
+        if param.requires_grad:
+            requires_grad_set.add(param)
+    assert len(requires_grad_set) > 0, "requires_grad is not set"
+
+    y, loss = model.forward(x)
+
+    assert y.shape == (1, BLOCK_SIZE, VOCAB_SIZE), f"Shape mismatch: {y.shape}"
+
+    loss = y.mean()
+    loss.backward()
+
+    # Make sure grads exist
+    grad_set = set()
+    for param in model.parameters():
+        if isinstance(param.grad, torch.Tensor):
+            grad_set.add(param)
+    assert len(grad_set) > 0, f"Tensor.grad missing"
+
+
+def test_forward_batched():
+    with torch.no_grad():
+        model = get_gpt_model()
+        x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)
+
+        y, loss = model.forward(x)
+
+        assert y.shape == (
+            2,
+            BLOCK_SIZE,
+            VOCAB_SIZE,
+        ), f"Shape mismatch: {y.shape}"
+
+
+def test_backward_batched():
+    model = get_gpt_model()
+    x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)
+
+    # Make sure grads exist
+    requires_grad_set = set()
+    for param in model.parameters():
+        if param.requires_grad:
+            requires_grad_set.add(param)
+    assert len(requires_grad_set) > 0, "requires_grad is not set"
+
+    y, loss = model.forward(x)
+
+    assert y.shape == (2, BLOCK_SIZE, VOCAB_SIZE), f"Shape mismatch: {y.shape}"
+
+    loss = y.mean()
+    loss.backward()
+
+    # Make sure grads exist
+    grad_set = set()
+    for param in model.parameters():
+        if isinstance(param.grad, torch.Tensor):
+            grad_set.add(param)
+    assert len(grad_set) > 0, f"Tensor.grad missing"
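The tests above only exercise the path without `targets`, where the regularizer is not computed. A rough sketch of how the extended `forward` signature is used when a loss is requested, reusing the same configuration as the test module; the zero-filled `targets` tensor is just a placeholder, not real data:

```python
import torch

from kan_gpt.model import GPT as KAN_GPT

model_config = KAN_GPT.get_default_config()
model_config.model_type = "gpt-nano"
model_config.vocab_size = 8
model_config.block_size = 16
model = KAN_GPT(model_config)

idx = torch.zeros((1, 16), dtype=torch.long)      # dummy token ids
targets = torch.zeros((1, 16), dtype=torch.long)  # dummy next-token targets

# loss = cross entropy + lamb * mean KAN regularization (see kan_loss above)
logits, loss = model(idx, targets, lamb=0.01, lamb_entropy=2.0)
loss.backward()
```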
diff --git a/tests/test_gpt_mlp.py b/tests/test_gpt_mlp.py
new file mode 100644
index 0000000..29182e7
--- /dev/null
+++ b/tests/test_gpt_mlp.py
@@ -0,0 +1,95 @@
+import torch
+from kan_gpt.mingpt.model import GPT as MLP_GPT
+
+VOCAB_SIZE = 8
+BLOCK_SIZE = 16
+MODEL_TYPE = "gpt-nano"
+
+
+def get_gpt_model() -> MLP_GPT:
+    model_config = MLP_GPT.get_default_config()
+    model_config.model_type = MODEL_TYPE
+    model_config.vocab_size = VOCAB_SIZE
+    model_config.block_size = BLOCK_SIZE
+    model = MLP_GPT(model_config)
+    return model
+
+
+def test_forward():
+    with torch.no_grad():
+        model = get_gpt_model()
+        x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)
+
+        y, loss = model.forward(x)
+
+        assert y.shape == (
+            1,
+            BLOCK_SIZE,
+            VOCAB_SIZE,
+        ), f"Shape mismatch: {y.shape}"
+
+
+def test_backward():
+    model = get_gpt_model()
+    x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)
+
+    # Make sure grads exist
+    requires_grad_set = set()
+    for param in model.parameters():
+        if param.requires_grad:
+            requires_grad_set.add(param)
+    assert len(requires_grad_set) > 0, "requires_grad is not set"
+
+    y, loss = model.forward(x)
+
+    assert y.shape == (1, BLOCK_SIZE, VOCAB_SIZE), f"Shape mismatch: {y.shape}"
+
+    loss = y.mean()
+    loss.backward()
+
+    # Make sure grads exist
+    grad_set = set()
+    for param in model.parameters():
+        if isinstance(param.grad, torch.Tensor):
+            grad_set.add(param)
+    assert len(grad_set) > 0, f"Tensor.grad missing"
+
+
+def test_forward_batched():
+    with torch.no_grad():
+        model = get_gpt_model()
+        x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)
+
+        y, loss = model.forward(x)
+
+        assert y.shape == (
+            2,
+            BLOCK_SIZE,
+            VOCAB_SIZE,
+        ), f"Shape mismatch: {y.shape}"
+
+
+def test_backward_batched():
+    model = get_gpt_model()
+    x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)
+
+    # Make sure grads exist
+    requires_grad_set = set()
+    for param in model.parameters():
+        if param.requires_grad:
+            requires_grad_set.add(param)
+    assert len(requires_grad_set) > 0, "requires_grad is not set"
+
+    y, loss = model.forward(x)
+
+    assert y.shape == (2, BLOCK_SIZE, VOCAB_SIZE), f"Shape mismatch: {y.shape}"
+
+    loss = y.mean()
+    loss.backward()
+
+    # Make sure grads exist
+    grad_set = set()
+    for param in model.parameters():
+        if isinstance(param.grad, torch.Tensor):
+            grad_set.add(param)
+    assert len(grad_set) > 0, f"Tensor.grad missing"
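`tests/test_gpt_kan.py` and `tests/test_gpt_mlp.py` differ only in which `GPT` class they import. If the duplication ever becomes a burden, the same coverage could be expressed once with pytest parametrization; a sketch, not part of this change:

```python
import pytest
import torch

from kan_gpt.mingpt.model import GPT as MLP_GPT
from kan_gpt.model import GPT as KAN_GPT

VOCAB_SIZE, BLOCK_SIZE, MODEL_TYPE = 8, 16, "gpt-nano"


@pytest.mark.parametrize("gpt_cls", [KAN_GPT, MLP_GPT])
@pytest.mark.parametrize("batch_size", [1, 2])
def test_forward_backward(gpt_cls, batch_size):
    # Same shape and gradient checks as the two modules above, shared
    config = gpt_cls.get_default_config()
    config.model_type = MODEL_TYPE
    config.vocab_size = VOCAB_SIZE
    config.block_size = BLOCK_SIZE
    model = gpt_cls(config)

    x = torch.zeros((batch_size, BLOCK_SIZE), dtype=torch.long)
    y, _ = model(x)
    assert y.shape == (batch_size, BLOCK_SIZE, VOCAB_SIZE)

    y.mean().backward()
    assert any(p.grad is not None for p in model.parameters())
```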
diff --git a/tests/test_kan.py b/tests/test_kan.py
new file mode 100644
index 0000000..3718cd9
--- /dev/null
+++ b/tests/test_kan.py
@@ -0,0 +1,74 @@
+import torch
+from kan_gpt.kan.KAN import KAN
+
+
+def test_forward():
+    with torch.no_grad():
+        model = KAN(width=[2, 5, 2])
+        x = torch.zeros((1, 1, 2), dtype=torch.float32)
+
+        y = model.forward(x)
+
+        assert y.shape == (1, 1, 2), f"Shape mismatch: {y.shape}"
+
+
+def test_backward():
+    model = KAN(width=[2, 5, 2])
+    x = torch.zeros((1, 1, 2), dtype=torch.float32)
+
+    # Make sure grads exist
+    requires_grad_set = set()
+    for param in model.parameters():
+        if param.requires_grad:
+            requires_grad_set.add(param)
+    assert len(requires_grad_set) > 0, "requires_grad is not set"
+
+    y = model.forward(x)
+
+    assert y.shape == (1, 1, 2), f"Shape mismatch: {y.shape}"
+
+    loss = y.mean()
+    loss.backward()
+
+    # Make sure grads exist
+    grad_set = set()
+    for param in model.parameters():
+        if isinstance(param.grad, torch.Tensor):
+            grad_set.add(param)
+    assert len(grad_set) > 0, f"Tensor.grad missing"
+
+
+def test_forward_batched():
+    with torch.no_grad():
+        model = KAN(width=[2, 5, 2])
+        x = torch.zeros((2, 1, 2), dtype=torch.float32)
+
+        y = model.forward(x)
+
+        assert y.shape == (2, 1, 2), f"Shape mismatch: {y.shape}"
+
+
+def test_backward_batched():
+    model = KAN(width=[2, 5, 2])
+    x = torch.zeros((2, 1, 2), dtype=torch.float32)
+
+    # Make sure grads exist
+    requires_grad_set = set()
+    for param in model.parameters():
+        if param.requires_grad:
+            requires_grad_set.add(param)
+    assert len(requires_grad_set) > 0, "requires_grad is not set"
+
+    y = model.forward(x)
+
+    assert y.shape == (2, 1, 2), f"Shape mismatch: {y.shape}"
+
+    loss = y.mean()
+    loss.backward()
+
+    # Make sure grads exist
+    grad_set = set()
+    for param in model.parameters():
+        if isinstance(param.grad, torch.Tensor):
+            grad_set.add(param)
+    assert len(grad_set) > 0, f"Tensor.grad missing"
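All three new modules follow standard pytest conventions, so the suite can be run from the repository root with pytest, or programmatically; a minimal runner sketch:

```python
import sys

import pytest

if __name__ == "__main__":
    # Equivalent to running `pytest tests -q` from the repository root
    sys.exit(pytest.main(["tests", "-q"]))
```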