test(kan_gpt,mlp_gpt,kan): test cases for forward-backward passes
AdityaNG committed May 4, 2024
Parent: 0615a04 · Commit: a037758
Showing 5 changed files with 349 additions and 7 deletions.
13 changes: 7 additions & 6 deletions README.md
@@ -66,8 +66,8 @@ pip install -e .

Use the following dummy script to make sure everything is working as expected
```bash
-WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset
-WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset
+WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset --device cpu
+WANDB_MODE=offline CUDA_VISIBLE_DEVICE="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset --device cpu
```

Then make use of the training script
@@ -81,17 +81,18 @@ python -m kan_gpt.train
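The full training command is collapsed in this hunk; as a minimal sketch (the batch size and device below are illustrative assumptions, not values taken from this diff), a non-dummy run would look like:

```bash
# Hypothetical full training run; only --architecture, --device, and WANDB_MODE
# are confirmed by the commands above
WANDB_MODE=offline python3 -m kan_gpt.train --architecture KAN --batch_size 2 --device cpu
```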
- [x] Dataset downloading script for [WebText](https://github.com/openai/gpt-2-output-dataset)
- [x] PyTorch Dataset parser for [WebText](https://github.com/openai/gpt-2-output-dataset)
- [ ] Mini training POC for KAN-GPT
-- [ ] Integrate KAN training logic from `KAN.train_kan`
+- [x] Integrate KAN training logic from `KAN.train_kan`
- [ ] Train a dummy batch
- [x] Mini training POC for MLP-GPT
- [x] Train MLP-GPT on the webtext dataset as a baseline
- [ ] Auto Save checkpoints
- [ ] Auto Save checkpoints to W&B
- [ ] Script to load checkpoint in interactive mode
- [ ] Training script to PyTorch Lightning
- [ ] Test Cases
-- [ ] KAN: Forward-Backward test
-- [ ] GPT: Forward-Backward test
-- [ ] KAN_GPT: Forward-Backward test
+- [x] KAN: Forward-Backward test
+- [x] GPT: Forward-Backward test
+- [x] KAN_GPT: Forward-Backward test

## Development

79 changes: 78 additions & 1 deletion kan_gpt/model.py
@@ -230,6 +230,61 @@ def __init__(self, config):
n_params = sum(p.numel() for p in self.transformer.parameters())
print("number of parameters: %.2fM" % (n_params / 1e6,))

def kan_loss(
self,
x: torch.Tensor,
lamb_l1=1.0,
lamb_entropy=2.0,
lamb_coef=0.0,
lamb_coefdiff=0.0,
small_mag_threshold=1e-16,
small_reg_factor=1.0,
):
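        # Regularization term for the KAN layers: an L1 + entropy penalty on the
        # activation scales, plus L1 penalties on the spline coefficients and their
        # first differences, averaged over all KAN sub-modules (this integrates the
        # training logic from `KAN.train_kan` noted in the README checklist).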

def reg(mod):

def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
return (x < th) * x * factor + (x > th) * (
x + (factor - 1) * th
)

reg_ = 0.0
for i in range(len(mod.acts_scale)):
vec = mod.acts_scale[i].reshape(
-1,
)

p = vec / torch.sum(vec)
l1 = torch.sum(nonlinear(vec))
entropy = -torch.sum(p * torch.log2(p + 1e-4))
reg_ += (
lamb_l1 * l1 + lamb_entropy * entropy
) # both l1 and entropy

# regularize coefficient to encourage spline to be zero
for i in range(len(mod.act_fun)):
coeff_l1 = torch.sum(
torch.mean(torch.abs(mod.act_fun[i].coef), dim=1)
)
coeff_diff_l1 = torch.sum(
torch.mean(
torch.abs(torch.diff(mod.act_fun[i].coef)), dim=1
)
)
reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1

return reg_

total_reg = torch.tensor(0.0).to(device=x.device, dtype=torch.float32)
size = 0
for mod in self.modules():
if isinstance(mod, KAN):
total_reg += reg(mod)
size += 1

mean_reg = total_reg / size
return mean_reg

def _init_weights(self, module):
if isinstance(module, KAN):
# torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -359,7 +414,18 @@ def configure_optimizers(self, train_config):
)
return optimizer

-    def forward(self, idx, targets=None):
+    def forward(
+        self,
+        idx,
+        targets=None,
+        lamb=0.01,
+        lamb_l1=1.0,
+        lamb_entropy=2.0,
+        lamb_coef=0.0,
+        lamb_coefdiff=0.0,
+        small_mag_threshold=1e-16,
+        small_reg_factor=1.0,
+    ):
device = idx.device
b, t = idx.size()
assert t <= self.block_size, (
@@ -392,6 +458,17 @@ def forward(self, idx, targets=None):
ignore_index=-1,
)

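            # Add the lamb-weighted KAN regularization on top of the
            # cross-entropy loss computed above.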
reg = self.kan_loss(
x=idx,
lamb_l1=lamb_l1,
lamb_entropy=lamb_entropy,
lamb_coef=lamb_coef,
lamb_coefdiff=lamb_coefdiff,
small_mag_threshold=small_mag_threshold,
small_reg_factor=small_reg_factor,
)
loss = loss + lamb * reg

return logits, loss

@torch.no_grad()
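Reading off the new `kan_loss` and `forward` code above, the objective being optimized is roughly the following (notation mine; the small-magnitude thresholding applied by `nonlinear` is omitted for clarity):

$$
\mathcal{L} = \mathrm{CE}(\text{logits}, \text{targets}) + \frac{\lambda}{M} \sum_{m=1}^{M} \left[ \sum_{i} \left( \lambda_{l1}\, \lVert \phi_i \rVert_1 + \lambda_{\mathrm{entropy}}\, H\!\left( \frac{\phi_i}{\sum_j \phi_{i,j}} \right) \right) + \sum_{l} \left( \lambda_{\mathrm{coef}}\, \lvert c_l \rvert_1 + \lambda_{\mathrm{coefdiff}}\, \lvert \Delta c_l \rvert_1 \right) \right]
$$

Here $M$ is the number of KAN sub-modules, $\phi_i$ are the activation-scale vectors `acts_scale`, $H$ is the entropy of the normalized scales, $c_l$ are the spline coefficients `act_fun[l].coef`, and $\lambda, \lambda_{l1}, \lambda_{\mathrm{entropy}}, \lambda_{\mathrm{coef}}, \lambda_{\mathrm{coefdiff}}$ correspond to the `lamb*` arguments of `forward`.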
95 changes: 95 additions & 0 deletions tests/test_gpt_kan.py
@@ -0,0 +1,95 @@
import torch
from kan_gpt.model import GPT as KAN_GPT

VOCAB_SIZE = 8
BLOCK_SIZE = 16
MODEL_TYPE = "gpt-nano"


def get_gpt_model() -> KAN_GPT:
model_config = KAN_GPT.get_default_config()
model_config.model_type = MODEL_TYPE
model_config.vocab_size = VOCAB_SIZE
model_config.block_size = BLOCK_SIZE
model = KAN_GPT(model_config)
return model


def test_forward():
with torch.no_grad():
model = get_gpt_model()
x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)

y, loss = model.forward(x)

assert y.shape == (
1,
BLOCK_SIZE,
VOCAB_SIZE,
), f"Shape mismatch: {y.shape}"


def test_backward():
model = get_gpt_model()
x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)

# Make sure grads exist
requires_grad_set = set()
for param in model.parameters():
if param.requires_grad:
requires_grad_set.add(param)
assert len(requires_grad_set) > 0, "requires_grad is not set"

y, loss = model.forward(x)

assert y.shape == (1, BLOCK_SIZE, VOCAB_SIZE), f"Shape mismatch: {y.shape}"

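    # No targets were passed, so the returned loss is None; backprop through
    # the mean of the logits to check gradient flow instead.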
loss = y.mean()
loss.backward()

# Make sure grads exist
grad_set = set()
for param in model.parameters():
if isinstance(param.grad, torch.Tensor):
grad_set.add(param)
assert len(grad_set) > 0, f"Tensor.grad missing"


def test_forward_batched():
with torch.no_grad():
model = get_gpt_model()
x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)

y, loss = model.forward(x)

assert y.shape == (
2,
BLOCK_SIZE,
VOCAB_SIZE,
), f"Shape mismatch: {y.shape}"


def test_backward_batched():
model = get_gpt_model()
x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)

# Make sure grads exist
requires_grad_set = set()
for param in model.parameters():
if param.requires_grad:
requires_grad_set.add(param)
assert len(requires_grad_set) > 0, "requires_grad is not set"

y, loss = model.forward(x)

assert y.shape == (2, BLOCK_SIZE, VOCAB_SIZE), f"Shape mismatch: {y.shape}"

loss = y.mean()
loss.backward()

# Make sure grads exist
grad_set = set()
for param in model.parameters():
if isinstance(param.grad, torch.Tensor):
grad_set.add(param)
assert len(grad_set) > 0, f"Tensor.grad missing"
95 changes: 95 additions & 0 deletions tests/test_gpt_mlp.py
@@ -0,0 +1,95 @@
import torch
from kan_gpt.mingpt.model import GPT as MLP_GPT

VOCAB_SIZE = 8
BLOCK_SIZE = 16
MODEL_TYPE = "gpt-nano"


def get_gpt_model() -> MLP_GPT:
model_config = MLP_GPT.get_default_config()
model_config.model_type = MODEL_TYPE
model_config.vocab_size = VOCAB_SIZE
model_config.block_size = BLOCK_SIZE
model = MLP_GPT(model_config)
return model


def test_forward():
with torch.no_grad():
model = get_gpt_model()
x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)

y, loss = model.forward(x)

assert y.shape == (
1,
BLOCK_SIZE,
VOCAB_SIZE,
), f"Shape mismatch: {y.shape}"


def test_backward():
model = get_gpt_model()
x = torch.zeros((1, BLOCK_SIZE), dtype=torch.long)

# Make sure grads exist
requires_grad_set = set()
for param in model.parameters():
if param.requires_grad:
requires_grad_set.add(param)
assert len(requires_grad_set) > 0, "requires_grad is not set"

y, loss = model.forward(x)

assert y.shape == (1, BLOCK_SIZE, VOCAB_SIZE), f"Shape mismatch: {y.shape}"

loss = y.mean()
loss.backward()

# Make sure grads exist
grad_set = set()
for param in model.parameters():
if isinstance(param.grad, torch.Tensor):
grad_set.add(param)
assert len(grad_set) > 0, f"Tensor.grad missing"


def test_forward_batched():
with torch.no_grad():
model = get_gpt_model()
x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)

y, loss = model.forward(x)

assert y.shape == (
2,
BLOCK_SIZE,
VOCAB_SIZE,
), f"Shape mismatch: {y.shape}"


def test_backward_batched():
model = get_gpt_model()
x = torch.zeros((2, BLOCK_SIZE), dtype=torch.long)

# Make sure grads exist
requires_grad_set = set()
for param in model.parameters():
if param.requires_grad:
requires_grad_set.add(param)
assert len(requires_grad_set) > 0, "requires_grad is not set"

y, loss = model.forward(x)

assert y.shape == (2, BLOCK_SIZE, VOCAB_SIZE), f"Shape mismatch: {y.shape}"

loss = y.mean()
loss.backward()

# Make sure grads exist
grad_set = set()
for param in model.parameters():
if isinstance(param.grad, torch.Tensor):
grad_set.add(param)
assert len(grad_set) > 0, f"Tensor.grad missing"