In [17]:
from src.efficient_kan import KAN

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

In [19]:
import numpy as np
import random

In [20]:
config = {}
config['input_dim'] = 128
config['batch_size'] = 64
config['hidden_layers'] = [64, 5]
config['lr'] = 1e-3



In [21]:
class kae_dataset(Dataset):
    def __init__(self, num_data=100, input_dim=128):
        super(kae_dataset, self).__init__()
        self.num_data = num_data
        self.target = []
        for i in range(num_data):
            scale = random.random()
            phase_left = random.random() + 0.1
            phase_right = random.random() + 0.1
            x=np.arange(-phase_left*2*np.pi,phase_right*2*np.pi,(phase_right*2*np.pi+phase_left*2*np.pi)/input_dim)
            x=x.reshape(len(x),1)
            y=scale * np.sin(x)
            self.target.append(y)

    def __getitem__(self, index):
        target = torch.from_numpy(self.target[index]).squeeze().unsqueeze(0).to(torch.float32)
        return target

    def __len__(self):
        return len(self.target)

trainset = kae_dataset(10000, config['input_dim'])
valset = kae_dataset(100, config['input_dim'])
trainloader = DataLoader(trainset, batch_size=config['batch_size'], shuffle=True)
valloader = DataLoader(valset, batch_size=config['batch_size'], shuffle=False)

In [22]:
class KAE(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        input_dim,
    ):
        super(KAE, self).__init__()
        self.encoder = KAN([input_dim, layers_hidden[0], layers_hidden[1]])
        self.decoder = KAN([layers_hidden[1], layers_hidden[0], input_dim])

    def forward(self, x: torch.Tensor):
        x = self.decoder(self.encoder(x))
        return x

In [23]:
# Define model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = KAE(layers_hidden=config['hidden_layers'], input_dim=config['input_dim'])
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-4)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Define loss
criterion = nn.MSELoss()

In [24]:
for epoch in range(10):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for i, targets in enumerate(pbar):
            targets = targets.view(-1, config['input_dim']).to(device)
            optimizer.zero_grad()
            output = model(targets)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for targets in valloader:
            targets = targets.view(-1, config['input_dim']).to(device)
            output = model(targets)
            val_loss += criterion(output, targets).item()
            
    val_loss /= len(valloader)

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}"
    )

100%|██████████| 157/157 [00:03<00:00, 45.01it/s, loss=0.0194, lr=0.001] 


Epoch 1, Val Loss: 0.0056625246070325375


100%|██████████| 157/157 [00:03<00:00, 45.12it/s, loss=0.00203, lr=0.0008] 


Epoch 2, Val Loss: 0.000848717987537384


100%|██████████| 157/157 [00:03<00:00, 45.15it/s, loss=0.000341, lr=0.00064]


Epoch 3, Val Loss: 0.00034628852154128253


100%|██████████| 157/157 [00:03<00:00, 47.27it/s, loss=0.000256, lr=0.000512]


Epoch 4, Val Loss: 0.00020428435527719557


100%|██████████| 157/157 [00:03<00:00, 46.46it/s, loss=0.000103, lr=0.00041]


Epoch 5, Val Loss: 0.00015337508375523612


100%|██████████| 157/157 [00:03<00:00, 44.89it/s, loss=6.7e-5, lr=0.000328]  


Epoch 6, Val Loss: 0.0001200814986077603


100%|██████████| 157/157 [00:03<00:00, 47.26it/s, loss=0.000109, lr=0.000262]


Epoch 7, Val Loss: 0.00010019696128438227


100%|██████████| 157/157 [00:03<00:00, 47.15it/s, loss=8.18e-5, lr=0.00021] 


Epoch 8, Val Loss: 8.816278932499699e-05


100%|██████████| 157/157 [00:03<00:00, 44.28it/s, loss=0.000205, lr=0.000168]


Epoch 9, Val Loss: 8.009559678612277e-05


100%|██████████| 157/157 [00:03<00:00, 46.47it/s, loss=5.29e-5, lr=0.000134] 


Epoch 10, Val Loss: 7.336376620514784e-05
