In [1]:
import torch
from torch import nn, optim
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from torch.distributions import Uniform
# from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR, OneCycleLR

import pytorch_lightning as pl
from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import torchmetrics

import numpy as np
import matplotlib.pyplot as plt
import random

import wandb

PATH_DATASETS = "."
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 512 if AVAIL_GPUS else 64
# BATCH_SIZE=1

In [2]:
pl.seed_everything(125)

Global seed set to 125


125

In [3]:
NORMALIZE = True
NUM_LAYERS = 2
HIDDEN_DIM = 2
LEARNING_RATE = 1e-2
NUM_ITERS = int(1e5)
RANGE = [5, 10]
ARITHMETIC_FUNCTIONS = {
    'add': lambda x, y: x + y,
    'sub': lambda x, y: x - y,
    'mul': lambda x, y: x * y,
    'div': lambda x, y: x / y,
    'squared': lambda x, y: torch.pow(x, 2),
    'root': lambda x, y: torch.sqrt(x),
}

## Generate Data

In [4]:
def generate_data(num_train, num_test, dim, num_sum, fn, support):
    data = torch.FloatTensor(dim).uniform_(*support).unsqueeze_(1)
    X, y = [], []
    for i in range(num_train + num_test):
        idx_a = random.sample(range(dim), num_sum)
        idx_b = random.sample([x for x in range(dim) if x not in idx_a], num_sum)
        a, b = data[idx_a].sum(), data[idx_b].sum()
        X.append([a, b])
        y.append(fn(a, b))
    X = torch.FloatTensor(X)
    y = torch.FloatTensor(y).unsqueeze_(1)
    indices = list(range(num_train + num_test))
    np.random.shuffle(indices)
    X_train, y_train = X[indices[num_test:]], y[indices[num_test:]]
    X_test, y_test = X[indices[:num_test]], y[indices[:num_test]]
    return X_train, y_train, X_test, y_test

In [5]:
class DataWrapper(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __len__(self):
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [6]:
X_train, y_train, X_val, y_val = generate_data(100, 10, 100, 5, lambda x, y: x + y, RANGE)
ds_train = DataWrapper(X_train, y_train)

In [7]:
ds_train[3]

(tensor([40.0543, 35.0587]), tensor([75.1130]))

In [8]:
len(ds_train)

100

## Network Setup

In [9]:
class NAC(pl.LightningModule):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.W_hat = Parameter(torch.Tensor(n_out, n_in))
        self.M_hat = Parameter(torch.Tensor(n_out, n_in))
        self.reset_parameters()
        
    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.W_hat)
        nn.init.kaiming_uniform_(self.M_hat)
    
    def forward(self, x):
        weights = torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat)
        return F.linear(x, weights)
    
    def training_step(self, batch, batch_idx):
        X, y = batch
        
        out = self(X)
        loss = F.mse_loss(out, y)
        mea = torch.mean(torch.abs(y - out))
        self.log('train_loss', loss)
        self.log('train_mea', mea)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        X, y = batch
        
        out = self(X)
        loss = F.mse_loss(out, y)
        mea = torch.mean(torch.abs(y - out))
        self.log('val_loss', loss)
        self.log('val_mea', mea)
        
        return loss
    
    def configure_optimizers(self):
        optimizer = optim.RMSprop(self.parameters(), lr=LEARNING_RATE)
        return optimizer
    
    def prepare_data(self):
        X_train, y_train, X_val, y_val = generate_data(
            500, 
            50, 
            100, 
            5,
            lambda x, y: x + y,
            RANGE
        )
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val
        self.y_val = y_val
        
    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.ds_train = DataWrapper(self.X_train, self.y_train)
            self.ds_val = DataWrapper(self.X_val, self.y_val)
        if stage == "test" or stage is None:
            self.ds_test = DataWrapper(self.X_val, self.y_val)
        
    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=1)
    
    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=1)

In [10]:
model = NAC(
    n_in=2,
    n_out=1
)

wandb_logger = WandbLogger(
    project='NALU_Test'
)

trainer = Trainer(
    logger=wandb_logger,
    max_epochs=100,
    gpus=AVAIL_GPUS,
    enable_progress_bar=False,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [11]:
model

NAC()

In [12]:
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params
------------------------------
------------------------------
4         Trainable params
0         Non-trainable params
4         Total params
0.000     Total estimated model params size (MB)
  rank_zero_warn(
Global seed set to 125
  rank_zero_warn(
[34m[1mwandb[0m: Currently logged in as: [33maxect[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.14 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [13]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_mea,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_mea,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
train_loss,0.0
train_mea,0.0
trainer/global_step,49999.0
val_loss,0.0
val_mea,0.0


In [14]:
X_0, y_0 = ds_train[0:10]

In [15]:
X_0

tensor([[37.8622, 40.9403],
        [42.4239, 40.7112],
        [38.4500, 39.7999],
        [40.0543, 35.0587],
        [39.8570, 41.2656],
        [45.1458, 42.6573],
        [37.1341, 40.7882],
        [35.9326, 34.9819],
        [39.1999, 41.4297],
        [37.5842, 44.6603]])

In [16]:
y_0

tensor([[78.8026],
        [83.1351],
        [78.2499],
        [75.1130],
        [81.1226],
        [87.8031],
        [77.9223],
        [70.9145],
        [80.6296],
        [82.2445]])

In [17]:
model(X_0)

tensor([[78.8026],
        [83.1351],
        [78.2499],
        [75.1130],
        [81.1226],
        [87.8031],
        [77.9223],
        [70.9145],
        [80.6296],
        [82.2445]], grad_fn=<MmBackward0>)

In [25]:
model.W_hat

Parameter containing:
tensor([[9.0121, 9.0110]], requires_grad=True)

In [19]:
model.M_hat

Parameter containing:
tensor([[16.6356, 16.6356]], requires_grad=True)

In [26]:
model.eval()

torch.tanh(model.W_hat) * torch.sigmoid(model.M_hat)

tensor([[1., 1.]], grad_fn=<MulBackward0>)