In [1]:
import torch
from torch import nn, optim
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from torch.distributions import Uniform
# from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR, OneCycleLR

import pytorch_lightning as pl
from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import torchmetrics

import numpy as np
import matplotlib.pyplot as plt
import random

import wandb

# Dataset path constant; it is not referenced by any cell visible in this notebook.
PATH_DATASETS = "."
# 1 when at least one CUDA device is visible, otherwise 0 (passed to the Trainer later).
AVAIL_GPUS = 1 if torch.cuda.device_count() >= 1 else 0
# Larger batches when a GPU is used, smaller ones on CPU.
BATCH_SIZE = 64 if AVAIL_GPUS == 0 else 512
# BATCH_SIZE=1

In [2]:
# Fix the global seed through Lightning so the random data and weight init below are repeatable.
pl.seed_everything(125)

Global seed set to 125


125

In [3]:
# Experiment constants. Among the cells visible here only LEARNING_RATE is read
# (by NAC_MLP.configure_optimizers); the others are kept as they were.
LEARNING_RATE = 0.01
NUM_ITERS = 100_000
NUM_LAYERS = 2
HIDDEN_DIM = 2
NORMALIZE = True
RANGE = [5, 10]

## Problem Description

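Five input features $a, b, c, d, e$ are sampled, but only $a$, $c$ and $d$ influence the target. The NAC-style layer is meant to pick out those three, and the MLP then regresses the target from them:

$$a,\, b,\, c,\, d,\, e \overset{NAC}{\Longrightarrow} a,\,c,\,d \overset{MLP}{\Longrightarrow} a^2 + c^2 - \sqrt{d}$$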

## Generate Data

In [4]:
class DataWrapper(Dataset):
    """Map-style dataset that pairs a feature container with a target container.

    Samples are addressed along the first axis of ``X``; ``y`` is indexed with
    the same key.
    """

    def __init__(self, X, y):
        # Store the inputs as given; nothing is copied or converted here.
        self.X, self.y = X, y

    def __len__(self):
        """Return the number of samples, i.e. the size of the first axis of ``X``."""
        n_samples = self.X.shape[0]
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair selected by ``idx``."""
        features = self.X[idx]
        target = self.y[idx]
        return features, target

In [5]:
def generate_data(range_feature=(0, 1), size=None):
    """Sample a synthetic regression dataset for the feature-selection task.

    Draws ``size`` rows of five features i.i.d. from ``Uniform(low, high)`` and
    computes the target ``x0**2 + x2**2 - sqrt(x3)``. Features 1 and 4 do not
    enter the target.

    Args:
        range_feature: Two-element sequence ``(low, high)`` with the bounds of
            the uniform distribution. ``low`` must be non-negative, otherwise
            ``sqrt`` would put NaNs into the target.
        size: Number of rows to draw. Required despite the ``None`` default,
            which is kept only so the signature stays unchanged.

    Returns:
        DataWrapper: dataset whose ``X`` has shape ``(size, 5)`` and whose
        ``y`` has shape ``(size, 1)``.

    Raises:
        ValueError: if ``size`` is ``None`` or negative, or if ``low`` is
            negative.
    """
    if size is None or size < 0:
        raise ValueError(f"size must be a non-negative integer, got {size!r}")

    low, high = range_feature
    # sqrt of a negative feature would silently yield NaN targets.
    if bool(torch.as_tensor(low).lt(0).any()):
        raise ValueError(
            f"range_feature lower bound must be >= 0 because the target uses sqrt, got {low!r}"
        )

    U = Uniform(low, high)
    X = U.sample((size, 5))
    target = X[:, 0] ** 2 + X[:, 2] ** 2 - torch.sqrt(X[:, 3])
    # Keep the target as a column so it lines up with the (batch, 1) model output.
    Y = target.unsqueeze(1)
    return DataWrapper(X, Y)

In [6]:
# Separate 100-sample dataset with features in [0, 1]; it is not passed to the trainer
# and is read again at the end of the notebook when the filtered features are exported.
ds_test = generate_data(range_feature=[0, 1], size=100)

In [7]:
# Peek at one sample: a (features, target) pair.
ds_test[0]

(tensor([0.7219, 0.3854, 0.7279, 0.1047, 0.6605]), tensor([0.7272]))

In [8]:
# Sanity check: the dataset length should equal the requested size.
len(ds_test)

100

## Network Setup

In [9]:
class NAC_MLP(pl.LightningModule):
    """Gated linear (NAC-style) input layer followed by an MLP regressor.

    The first stage multiplies the input by ``tanh(W_hat)**2 * sigmoid(M_hat)``,
    a non-negative weight matrix of shape ``(n_out, n_in)``. Its ``n_out``
    outputs feed a three-layer MLP (ReLU followed by BatchNorm1d after each
    hidden linear layer) that produces one regression value per sample.

    Args:
        n_in: Number of input features.
        n_out: Number of outputs of the gated linear stage (MLP input width).
        lr: Optional Adam learning rate. When ``None`` (default) the
            module-level ``LEARNING_RATE`` is read at optimizer-creation time,
            which is what the class did before this argument existed.
    """

    def __init__(self, n_in, n_out, lr=None):
        super().__init__()
        # torch.empty replaces the legacy torch.Tensor(rows, cols) constructor;
        # the values are filled in by reset_parameters().
        self.W_hat = Parameter(torch.empty(n_out, n_in))
        self.M_hat = Parameter(torch.empty(n_out, n_in))
        self.reset_parameters()

        self.mlp = nn.Sequential(
            nn.Linear(n_out, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1)
        )

        self.lr = lr
        # Filled lazily by setup(); None means "not generated yet".
        self.ds_train = None
        self.ds_val = None

    def reset_parameters(self):
        """(Re-)initialise both gate parameter matrices with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.W_hat)
        nn.init.kaiming_uniform_(self.M_hat)

    def nac_weights(self):
        """Return the effective input-mixing matrix ``tanh(W_hat)**2 * sigmoid(M_hat)``.

        Both factors are non-negative and bounded by 1, so every entry is too.
        """
        return torch.tanh(self.W_hat) ** 2 * torch.sigmoid(self.M_hat)

    def forward(self, x):
        """Mix the inputs with the gated weights, then regress with the MLP."""
        principals = F.linear(x, self.nac_weights())
        return self.mlp(principals)

    def _shared_step(self, batch, stage):
        """Compute and log MSE and mean absolute error for one batch; return the MSE."""
        X, y = batch

        out = self(X)
        loss = F.mse_loss(out, y)
        mae = torch.mean(torch.abs(y - out))
        # The metric keys keep their original '<stage>_mea' spelling so that
        # already-logged runs and dashboards stay comparable.
        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_mea', mae)

        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: loss for one training batch (logs train_loss / train_mea)."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Lightning hook: loss for one validation batch (logs val_loss / val_mea)."""
        return self._shared_step(batch, 'val')

    def configure_optimizers(self):
        """Lightning hook: Adam over all parameters."""
        lr = LEARNING_RATE if self.lr is None else self.lr
        return optim.Adam(self.parameters(), lr=lr)

    def prepare_data(self):
        """Backward-compatible entry point; dataset creation lives in ``setup``.

        Lightning calls ``prepare_data`` on a single process only, so state
        assigned exclusively here would be missing on the other processes of a
        multi-process run. Delegating keeps direct callers working while
        ``setup`` guarantees the datasets exist everywhere.
        """
        self.setup()

    def setup(self, stage=None):
        """Lightning hook: create the train/val datasets once per module instance.

        The datasets are only generated when they do not exist yet, so calling
        the hook repeatedly (or after ``prepare_data``) reuses the same data.
        """
        if self.ds_train is None:
            self.ds_train = generate_data(range_feature=[0,1], size=10000)
        if self.ds_val is None:
            self.ds_val = generate_data(range_feature=[0,1], size=1000)

    def train_dataloader(self):
        """Lightning hook: training batches of 256 samples."""
        return DataLoader(self.ds_train, batch_size=256)

    def val_dataloader(self):
        """Lightning hook: validation batches of 128 samples."""
        return DataLoader(self.ds_val, batch_size=128)

In [10]:
# Five raw inputs are mixed down to three gated-linear outputs before the MLP.
model = NAC_MLP(n_in=5, n_out=3)

# Values passed to self.log(...) inside the module are sent to this W&B project.
wandb_logger = WandbLogger(project='NALU_PCA')

# 200 epochs on one GPU when available (AVAIL_GPUS is 0 or 1); progress bar disabled.
trainer = Trainer(
    gpus=AVAIL_GPUS,
    max_epochs=200,
    logger=wandb_logger,
    enable_progress_bar=False,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [11]:
# Show the module tree. The W_hat / M_hat parameters are not sub-modules, so only the MLP is listed.
model

NAC_MLP(
  (mlp): Sequential(
    (0): Linear(in_features=3, out_features=512, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Linear(in_features=512, out_features=1, bias=True)
  )
)

In [12]:
# Train. No dataloaders are passed because the module supplies them through its own hooks.
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params
------------------------------------
0 | mlp  | Sequential | 267 K 
------------------------------------
267 K     Trainable params
0         Non-trainable params
267 K     Total params
1.069     Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
  rank_zero_warn(
Global seed set to 125
  rank_zero_warn(
  rank_zero_warn(
[34m[1mwandb[0m: Currently logged in as: [33maxect[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [13]:
# Close the active W&B run so the logged metrics are finalised and the run summary is shown.
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▆▅▄▄▂▄▆▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
train_mea,▅▄▄▄▃▄▅▄▃▃▃▃▃▃▂▃▂▂▂▁▂▂▁▁▂▁▂▂▂▁▁▁▁▁▁▁▂▂▂█
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss,▄▅▂▂▃▃▃▄▆█▂▂▂▂▂▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_mea,▆▆▄▄▅▅▅▆▇█▃▂▃▃▃▂▃▂▂▃▃▃▂▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train_loss,0.03739
train_mea,0.18647
trainer/global_step,7999.0
val_loss,0.00025
val_mea,0.01281


In [14]:
# Raw (pre-tanh) gate parameter after training.
model.W_hat

Parameter containing:
tensor([[ 3.6704e-01, -1.4468e-07,  1.0847e+00, -1.6784e-01, -1.2082e-02],
        [ 1.7089e-40,  5.0595e-03, -1.0072e-02,  1.1546e+00,  8.6971e-03],
        [-1.3451e+00,  4.0758e-02,  5.0243e-01,  9.0300e-01, -4.4044e-02]],
       requires_grad=True)

In [15]:
# Raw (pre-sigmoid) gate parameter after training.
model.M_hat

Parameter containing:
tensor([[-0.5983,  0.3591,  0.2443,  0.5103,  0.2418],
        [-0.3731,  0.0088,  0.4614,  0.4620, -1.2909],
        [ 0.8197, -0.6701, -0.9186,  0.7601, -1.3341]], requires_grad=True)

In [16]:
# Put the model in evaluation mode before inspecting it.
model.eval()

# Effective mixing matrix used by forward(): one row per gated output, one column per input.
torch.sigmoid(model.M_hat) * torch.tanh(model.W_hat).pow(2)

tensor([[4.3803e-02, 1.2325e-14, 3.5437e-01, 1.7277e-02, 8.1761e-05],
        [0.0000e+00, 1.2855e-05, 6.2211e-05, 4.1180e-01, 1.6315e-05],
        [5.2891e-01, 5.6166e-04, 6.1418e-02, 3.5102e-01, 4.0392e-04]],
       grad_fn=<MulBackward0>)

In [17]:
# Snapshot of the effective mixing matrix for post-processing. It is detached so the
# analysis below does not keep (or write into) the training autograd graph.
W = (torch.tanh(model.W_hat)**2 * torch.sigmoid(model.M_hat)).detach()

In [18]:
# Turn the soft mixing matrix into a hard 0/1 mask and list the selected input columns.
# Everything runs without autograd tracking and builds new tensors instead of writing
# into W through chained indexing.
REL_THRESHOLD = 0.01  # entries below 1% of the largest remaining weight are treated as noise

with torch.no_grad():
    # Remove dependencies: every input column keeps only its largest entry.
    col_max = W.max(dim=0, keepdim=True).values
    W_sel = torch.where(W < col_max, torch.zeros_like(W), W)

    # Remove low order: drop survivors that are tiny relative to the overall maximum.
    W_sel = torch.where(W_sel < W_sel.max() * REL_THRESHOLD, torch.zeros_like(W_sel), W_sel)

    # Make one: whatever is still positive becomes exactly 1.
    W = (W_sel > 0.0).to(W_sel.dtype)

# To index: input columns selected by at least one row, in ascending order.
ics = torch.nonzero((W == 1.0).any(dim=0)).flatten().tolist()

In [19]:
# Binary selection mask: rows are gated outputs, columns are input features.
W

tensor([[0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0.]], grad_fn=<IndexPutBackward0>)

In [25]:
# Indices of the input features that survived the filtering.
ics

[0, 2, 3]

In [28]:
# Slice the whole test dataset at once: DataWrapper indexes X and y with the same key.
X, Y = ds_test[:]

In [29]:
# Keep only the feature columns selected by the mask.
X_hat = X[:,ics]

In [30]:
# Convert to NumPy for export. torch.as_tensor accepts both tensors and arrays, so the
# cell can be re-run after the names already hold ndarrays without raising AttributeError.
X = torch.as_tensor(X).detach().numpy()
X_hat = torch.as_tensor(X_hat).detach().numpy()
Y = torch.as_tensor(Y).detach().numpy()

In [31]:
# Arrays passed positionally are stored under NumPy's automatic keys:
# arr_0 = X (all features), arr_1 = X_hat (selected features), arr_2 = Y (targets).
np.savez("nac_filtered.npz", X, X_hat, Y)