In [1]:
import torch
from torch import nn, optim
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from torch.distributions import Uniform
# from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR, OneCycleLR

import pytorch_lightning as pl
from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import torchmetrics

import numpy as np
import matplotlib.pyplot as plt
import random

import wandb

PATH_DATASETS = "."
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 512 if AVAIL_GPUS else 64
# BATCH_SIZE=1

In [2]:
pl.seed_everything(125)

Global seed set to 125


125

In [3]:
NORMALIZE = True
NUM_LAYERS = 2
HIDDEN_DIM = 2
LEARNING_RATE = 1e-2
NUM_ITERS = int(1e5)
RANGE = [5, 10]

## Problem Description

$$a,\, b,\, c,\, d,\, e \overset{NAC}{\Longrightarrow} a+c,\,b+d,\,e \overset{MLP}{\Longrightarrow} (a+c)^2 + \cos(b+d) - e$$

## Generate Data

In [4]:
class DataWrapper(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __len__(self):
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [21]:
def generate_data(range_feature=[0,1], size=None):
    U = Uniform(*range_feature)
    X = U.sample((size, 5))
    Y = (X[:,0] + X[:,2]) ** 2 + torch.cos(X[:,1] + X[:,3]) - X[:,4]
    ds = DataWrapper(X, Y)
    return ds

In [22]:
ds_test = generate_data(range_feature=[0, 1], size=100)

In [23]:
ds_test[0]

(tensor([0.0138, 0.7151, 0.6569, 0.5526, 0.7840]), tensor(-0.0355))

In [24]:
len(ds_test)

100

## Network Setup

In [25]:
class NAC_MLP(pl.LightningModule):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.W_hat = Parameter(torch.Tensor(n_out, n_in))
        self.M_hat = Parameter(torch.Tensor(n_out, n_in))
        self.reset_parameters()
        
        self.mlp = nn.Sequential(
            nn.Linear(n_out, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1)
        )
        
    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.W_hat)
        nn.init.kaiming_uniform_(self.M_hat)
    
    def forward(self, x):
        weights = torch.tanh(self.W_hat)**2 * torch.sigmoid(self.M_hat)
        principals = F.linear(x, weights)
        return self.mlp(principals)
    
    def training_step(self, batch, batch_idx):
        X, y = batch
        
        out = self(X)
        loss = F.mse_loss(out, y)
        mea = torch.mean(torch.abs(y - out))
        self.log('train_loss', loss)
        self.log('train_mea', mea)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        X, y = batch
        
        out = self(X)
        loss = F.mse_loss(out, y)
        mea = torch.mean(torch.abs(y - out))
        self.log('val_loss', loss)
        self.log('val_mea', mea)
        
        return loss
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=LEARNING_RATE)
        return optimizer
    
    def prepare_data(self):
        self.ds_train = generate_data(range_feature=[0,1], size=10000)
        self.ds_val = generate_data(range_feature=[0,1], size=1000)
        
    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=256)
    
    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=128)

In [26]:
model = NAC_MLP(
    n_in=5,
    n_out=3
)

wandb_logger = WandbLogger(
    project='NALU_PCA'
)

trainer = Trainer(
    logger=wandb_logger,
    max_epochs=200,
    gpus=AVAIL_GPUS,
    enable_progress_bar=False,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [27]:
model

NAC_MLP(
  (mlp): Sequential(
    (0): Linear(in_features=3, out_features=512, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Linear(in_features=512, out_features=1, bias=True)
  )
)

In [28]:
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params
------------------------------------
0 | mlp  | Sequential | 267 K 
------------------------------------
267 K     Trainable params
0         Non-trainable params
267 K     Total params
1.069     Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
  rank_zero_warn(
  loss = F.mse_loss(out, y)
Global seed set to 125
  rank_zero_warn(
  rank_zero_warn(
  loss = F.mse_loss(out, y)
  loss = F.mse_loss(out, y)
  loss = F.mse_loss(out, y)
[34m[1mwandb[0m: wandb version 0.12.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  loss = F.mse_loss(out, y)
  loss = F.mse_loss(out, y)
  loss = F.mse_loss(out, y)
  loss = F.mse_loss(out, y)


In [29]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,███████████████████████████████████████▁
train_mea,▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss,▂▁▁▁▁▁▁▁▁▁▃▁▂▃█▁▂▂▂▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
val_mea,▁▁▁▁▁▁▁▁▁▁▃▁▂▃█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train_loss,0.46938
train_mea,0.62022
trainer/global_step,7999.0
val_loss,0.83808
val_mea,0.72466


In [30]:
model.W_hat

Parameter containing:
tensor([[-5.9468e-01,  5.2503e-01,  8.4526e-01, -6.5455e-01, -5.4809e-05],
        [-3.3708e-01,  8.0540e-01,  1.1748e-01, -1.0082e+00, -2.6944e-01],
        [-1.3375e-01,  2.7850e-01,  5.5430e-01,  3.2299e-01,  7.1952e-01]],
       requires_grad=True)

In [31]:
model.M_hat

Parameter containing:
tensor([[ 0.4514, -0.2281,  0.2362,  0.7311, -0.6150],
        [-0.0545,  1.1733, -0.8530,  0.8046, -0.1618],
        [ 0.3106, -0.9981, -0.4327,  0.8649,  1.2519]], requires_grad=True)

In [32]:
model.eval()

torch.tanh(model.W_hat)**2 * torch.sigmoid(model.M_hat)

tensor([[1.7373e-01, 1.0279e-01, 2.6495e-01, 2.2297e-01, 1.0542e-09],
        [5.1332e-02, 3.3983e-01, 4.0866e-03, 4.0437e-01, 3.1817e-02],
        [1.0200e-02, 1.9854e-02, 9.9848e-02, 6.8590e-02, 2.9566e-01]],
       grad_fn=<MulBackward0>)

In [33]:
W = torch.tanh(model.W_hat)**2 * torch.sigmoid(model.M_hat)

In [34]:
# Remove dependencies
for i in range(W.shape[1]):
    # Find maximum
    m = torch.max(W[:,i])
    W[:,i][W[:,i] < m] = 0.0

# Remove low order
m = torch.max(W)
W[W < m * 0.01] = 0.0

# Make one
W[W > 0.0] = 1.0

In [35]:
W

tensor([[1., 0., 1., 0., 0.],
        [0., 1., 0., 1., 0.],
        [0., 0., 0., 0., 1.]], grad_fn=<IndexPutBackward0>)

In [36]:
X, Y = ds_test[:]

In [37]:
X_hat = F.linear(X, W)

In [38]:
X = X.detach().numpy()
X_hat = X_hat.detach().numpy()
Y = Y.detach().numpy()

In [39]:
np.savez("nac_complicated.npz", X, X_hat, Y)