In [1]:
import torch
from torch import nn, optim
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from torch.distributions import Uniform
# from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR, OneCycleLR

import pytorch_lightning as pl
from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import torchmetrics

import numpy as np
import matplotlib.pyplot as plt
import random

import wandb

# Directory that dataset files are resolved against.
PATH_DATASETS = "."
# 1 when CUDA reports at least one device, otherwise 0.
AVAIL_GPUS = 1 if torch.cuda.device_count() >= 1 else 0
# Batch size depends on whether a GPU was detected.
BATCH_SIZE = 64 if not AVAIL_GPUS else 512
# BATCH_SIZE=1

In [2]:
# Fix the global seed (125) so sampled data and parameter initialisation are repeatable.
pl.seed_everything(125)

Global seed set to 125


125

In [3]:
# Experiment settings. LEARNING_RATE is read by NAC.configure_optimizers;
# NOTE(review): the other constants are not referenced in the visible cells — confirm they are still needed.
NORMALIZE = True
NUM_LAYERS, HIDDEN_DIM = 2, 2
LEARNING_RATE = 0.01
NUM_ITERS = 100_000
RANGE = [5, 10]

## Problem Description

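Each sample has four inputs $a, b, c, d$, where $d$ is constrained to equal $-c$. The two regression targets are $a - c$ and $b + c$, which can equivalently be written as $a + d$ and $b - d$:

$$a,\, b,\, c,\, d(=-c) \Rightarrow a-c(= a+d),~b+c(=b-d)$$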

## Generate Data

In [4]:
class DataWrapper(Dataset):
    """Map-style dataset pairing a feature container ``X`` with targets ``y``.

    Both containers are stored as given (no copy, no validation) and are
    indexed with the same key.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        """Return the number of samples, i.e. the leading dimension of ``X``."""
        n_samples = self.X.shape[0]
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair selected by ``idx``."""
        features = self.X[idx]
        target = self.y[idx]
        return features, target

In [5]:
def degenerate_data(range_feature=(0, 1), size=None):
    """Build a synthetic dataset for the degenerate add/subtract task.

    Three features ``a, b, c`` are sampled from ``Uniform(*range_feature)`` and
    a fourth feature ``d = -c`` is appended. The targets are ``(a - c, b + c)``.

    Args:
        range_feature: Sequence unpacked into ``Uniform(...)``; normally a
            ``(low, high)`` pair. Lists and tuples are both accepted.
        size: Number of samples (rows) to generate. Required.

    Returns:
        DataWrapper whose ``X`` has shape ``(size, 4)`` and whose ``y`` has
        shape ``(size, 2)``.

    Raises:
        TypeError: If ``size`` is left as ``None``.
    """
    # Fail early with a readable message instead of passing None into the
    # sample shape.
    if size is None:
        raise TypeError(
            "degenerate_data() requires an integer `size` (number of samples)"
        )

    U = Uniform(*range_feature)
    X = U.sample((size, 3))
    # Fourth column is the negation of the third, so a - c == a + d and
    # b + c == b - d.
    X = torch.column_stack([X, -X[:, 2]])
    Y = torch.column_stack([X[:, 0] - X[:, 2], X[:, 1] + X[:, 2]])
    return DataWrapper(X, Y)

In [6]:
# Separate 100-sample set with features drawn from [0, 1]; used below to compare predictions with targets.
ds_test = degenerate_data(range_feature=[0, 1], size=100)

In [7]:
# Inspect the first (inputs, targets) pair.
ds_test[0]

(tensor([ 0.7219,  0.3854,  0.7279, -0.7279]), tensor([-0.0060,  1.1132]))

In [8]:
# Number of samples in the test set.
len(ds_test)

100

## Network Setup

In [9]:
class NAC(pl.LightningModule):
    """Bias-free linear layer with constrained weights, plus its training loop.

    The effective weight matrix is ``tanh(W_hat)**2 * sigmoid(M_hat)``, so every
    effective weight lies in ``[0, 1]``.

    NOTE(review): the usual NAC formulation uses ``tanh(W_hat)`` without the
    square, which also allows negative weights. The square is kept here
    unchanged — confirm it is intentional for this experiment.

    Args:
        n_in: Number of input features.
        n_out: Number of output features.
        lr: Learning rate for RMSprop. When ``None`` (default) the module-level
            ``LEARNING_RATE`` is used.
    """

    def __init__(self, n_in, n_out, lr=None):
        super().__init__()
        # torch.empty only allocates; reset_parameters() fills in the values.
        self.W_hat = Parameter(torch.empty(n_out, n_in))
        self.M_hat = Parameter(torch.empty(n_out, n_in))
        self.lr = lr
        # Populated by setup().
        self.ds_train = None
        self.ds_val = None
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both parameter matrices with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.W_hat)
        nn.init.kaiming_uniform_(self.M_hat)

    def forward(self, x):
        """Apply the constrained linear map to ``x`` (no bias term)."""
        weights = torch.tanh(self.W_hat)**2 * torch.sigmoid(self.M_hat)
        return F.linear(x, weights)

    def _shared_step(self, batch, prefix):
        """Compute and log MSE and mean absolute error for one batch.

        Metrics are logged as ``{prefix}_loss`` and ``{prefix}_mea``; the
        ``mea`` spelling is kept so metric names match earlier logged runs.
        Returns the MSE loss.
        """
        X, Y = batch
        out = self(X)
        loss = F.mse_loss(out, Y)
        mae = F.l1_loss(out, Y)
        self.log(f'{prefix}_loss', loss)
        self.log(f'{prefix}_mea', mae)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the MSE loss for a training batch (logs ``train_*`` metrics)."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Return the MSE loss for a validation batch (logs ``val_*`` metrics)."""
        return self._shared_step(batch, 'val')

    def configure_optimizers(self):
        """Create the RMSprop optimizer over all module parameters."""
        lr = LEARNING_RATE if self.lr is None else self.lr
        return optim.RMSprop(self.parameters(), lr=lr)

    def setup(self, stage=None):
        """Create the training and validation datasets.

        The datasets are built here rather than in ``prepare_data`` because
        Lightning calls ``prepare_data`` on a single process only and advises
        against assigning state there. Datasets are generated once and reused
        on later calls (for example a second ``fit`` or a ``validate`` run).
        """
        if self.ds_train is None:
            self.ds_train = degenerate_data(range_feature=[0, 1], size=500)
        if self.ds_val is None:
            self.ds_val = degenerate_data(range_feature=[0, 1], size=100)

    def train_dataloader(self):
        """Return the training loader (batches of 50, no shuffling)."""
        return DataLoader(self.ds_train, batch_size=50)

    def val_dataloader(self):
        """Return the validation loader (batches of 10)."""
        return DataLoader(self.ds_val, batch_size=10)

In [10]:
# Four inputs (a, b, c, d) mapped to two outputs (a - c, b + c).
model = NAC(n_in=4, n_out=2)

# Metrics are logged to the 'NALU_Degenerate_Test' W&B project.
wandb_logger = WandbLogger(project='NALU_Degenerate_Test')

# 200 epochs, on the GPU when AVAIL_GPUS is 1, with the progress bar disabled.
trainer = Trainer(logger=wandb_logger, max_epochs=200, gpus=AVAIL_GPUS, enable_progress_bar=False)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [11]:
# Show the module repr.
model

NAC()

In [12]:
# Train the model; no dataloaders are passed, so the module's own dataloader hooks supply the data.
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params
------------------------------
------------------------------
16        Trainable params
0         Non-trainable params
16        Total params
0.000     Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
  rank_zero_warn(
Global seed set to 125
  rank_zero_warn(
  rank_zero_warn(
[34m[1mwandb[0m: Currently logged in as: [33maxect[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [13]:
# End the active W&B run; the run history and summary are printed below.
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▅▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_mea,█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_mea,█▅▄▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,199.0
train_loss,0.0
train_mea,0.00095
trainer/global_step,1999.0
val_loss,0.0
val_mea,0.00076


In [14]:
# Raw W_hat parameter after training (before the tanh is applied).
model.W_hat

Parameter containing:
tensor([[-5.2111e+00, -1.9225e-02, -1.2612e-44, -4.5368e+00],
        [ 7.6311e-02, -4.2474e+00, -4.1179e+00,  1.4013e-44]],
       requires_grad=True)

In [15]:
# Raw M_hat parameter after training (before the sigmoid is applied).
model.M_hat

Parameter containing:
tensor([[ 6.1629, -1.7320, -0.8251,  6.3726],
        [-0.1471,  6.9857,  7.3735,  0.5990]], requires_grad=True)

In [16]:
# Switch the module to evaluation mode.
model.eval()

# Effective weight matrix, using the same expression as NAC.forward.
torch.tanh(model.W_hat)**2 * torch.sigmoid(model.M_hat)

tensor([[9.9778e-01, 5.5548e-05, 0.0000e+00, 9.9784e-01],
        [2.6875e-03, 9.9826e-01, 9.9831e-01, 0.0000e+00]],
       grad_fn=<MulBackward0>)

In [17]:
# Unpack the whole test set into its feature and target tensors.
X, Y = ds_test[:]

In [18]:
# Side-by-side comparison: the first two columns are predictions, the last two are targets.
torch.column_stack([model(X), Y])

tensor([[-0.0060,  1.1133, -0.0060,  1.1132],
        [-0.6683,  1.4328, -0.6698,  1.4350],
        [ 0.3748,  1.1730,  0.3756,  1.1725],
        [ 0.0609,  1.7957,  0.0610,  1.7961],
        [ 0.0281,  1.2784,  0.0282,  1.2786],
        [-0.4923,  1.8754, -0.4934,  1.8775],
        [ 0.9638,  0.4117,  0.9659,  0.4098],
        [-0.3148,  0.3814, -0.3155,  0.3820],
        [ 0.1133,  0.8897,  0.1135,  0.8907],
        [ 0.0775,  1.7429,  0.0777,  1.7436],
        [-0.2845,  1.4282, -0.2852,  1.4291],
        [-0.5382,  0.7611, -0.5394,  0.7624],
        [ 0.4155,  0.9116,  0.4164,  0.9113],
        [-0.3926,  1.1610, -0.3934,  1.1619],
        [-0.5392,  1.7279, -0.5403,  1.7300],
        [ 0.0423,  1.0069,  0.0424,  1.0071],
        [-0.7705,  1.4032, -0.7722,  1.4050],
        [ 0.1416,  1.5502,  0.1419,  1.5505],
        [-0.4670,  1.0198, -0.4680,  1.0214],
        [-0.7045,  1.9481, -0.7060,  1.9508],
        [ 0.1447,  0.5793,  0.1451,  0.5791],
        [-0.3278,  0.9954, -0.3286