In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

In [18]:
import numpy as np
import random

In [19]:
config = {}
config['input_dim'] = 128
config['batch_size'] = 64
config['hidden_layers'] = [64, 5]
config['lr'] = 1e-3

In [20]:
class kae_dataset(Dataset):
    def __init__(self, num_data=100, input_dim=128):
        super(kae_dataset, self).__init__()
        self.num_data = num_data
        self.target = []
        for i in range(num_data):
            scale = random.random()
            phase_left = random.random() + 0.1
            phase_right = random.random() + 0.1
            x=np.arange(-phase_left*2*np.pi,phase_right*2*np.pi,(phase_right*2*np.pi+phase_left*2*np.pi)/input_dim)
            x=x.reshape(len(x),1)
            y=scale * np.sin(x)
            self.target.append(y)

    def __getitem__(self, index):
        target = torch.from_numpy(self.target[index]).squeeze().unsqueeze(0).to(torch.float32)
        return target

    def __len__(self):
        return len(self.target)

trainset = kae_dataset(10000, config['input_dim'])
valset = kae_dataset(100, config['input_dim'])
trainloader = DataLoader(trainset, batch_size=config['batch_size'], shuffle=True)
valloader = DataLoader(valset, batch_size=config['batch_size'], shuffle=False)

In [21]:
class MAE(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        input_dim,
    ):
        super(MAE, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, layers_hidden[0]),
            nn.Linear(layers_hidden[0], layers_hidden[1]),
            nn.Linear(layers_hidden[1], layers_hidden[0]),
            nn.Linear(layers_hidden[0], input_dim)
            )

    def forward(self, x: torch.Tensor):
        x = self.model(x)
        return x

In [22]:
# Define model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MAE(layers_hidden=config['hidden_layers'], input_dim=config['input_dim'])
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-4)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Define loss
criterion = nn.MSELoss()

In [23]:
for epoch in range(10):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for i, targets in enumerate(pbar):
            targets = targets.view(-1, config['input_dim']).to(device)
            optimizer.zero_grad()
            output = model(targets)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for targets in valloader:
            targets = targets.view(-1, config['input_dim']).to(device)
            output = model(targets)
            val_loss += criterion(output, targets).item()
            
    val_loss /= len(valloader)

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}"
    )

100%|██████████| 157/157 [00:01<00:00, 101.81it/s, loss=0.00519, lr=0.001]


Epoch 1, Val Loss: 0.004964141175150871


100%|██████████| 157/157 [00:01<00:00, 99.16it/s, loss=0.00844, lr=0.0008]  


Epoch 2, Val Loss: 0.0026078533264808357


100%|██████████| 157/157 [00:01<00:00, 102.79it/s, loss=0.00119, lr=0.00064] 


Epoch 3, Val Loss: 0.0026384599623270333


100%|██████████| 157/157 [00:01<00:00, 101.56it/s, loss=0.00255, lr=0.000512]


Epoch 4, Val Loss: 0.0024633380817249417


100%|██████████| 157/157 [00:01<00:00, 97.51it/s, loss=0.00117, lr=0.00041] 


Epoch 5, Val Loss: 0.0024914569803513587


100%|██████████| 157/157 [00:01<00:00, 103.94it/s, loss=0.00847, lr=0.000328]


Epoch 6, Val Loss: 0.0026117953239008784


100%|██████████| 157/157 [00:01<00:00, 81.59it/s, loss=0.00344, lr=0.000262]


Epoch 7, Val Loss: 0.002459672396071255


100%|██████████| 157/157 [00:01<00:00, 80.13it/s, loss=0.00199, lr=0.00021] 


Epoch 8, Val Loss: 0.0025119019555859268


100%|██████████| 157/157 [00:02<00:00, 76.09it/s, loss=0.000763, lr=0.000168]


Epoch 9, Val Loss: 0.0025147462729364634


100%|██████████| 157/157 [00:01<00:00, 80.17it/s, loss=0.000616, lr=0.000134]


Epoch 10, Val Loss: 0.0025913557037711143
