In [1]:
import torch
import numpy as np
import tensorflow as tf

In [2]:
# Benchmark settings shared by every cell below.
n_samples = 200    # rows of the synthetic data set
n_features = 50    # columns of the synthetic data set (one sub-network per column)
n_nodes = 100      # hidden-layer width passed to the sub-networks
n_epochs = 500     # optimisation steps per benchmark loop

In [3]:
# Synthetic regression problem: inputs are uniform on [0, 1) and the target is the
# squared row mean of the inputs.
# A seeded generator makes the data identical after every kernel restart, so all
# benchmarks below are run on the same problem.
rng = np.random.default_rng(0)
x = rng.uniform(size=(n_samples, n_features))
y = np.mean(x, axis=1) ** 2

## TensorFlow

In [4]:
class subNNtf(tf.keras.layers.Layer):
    """Bias-free three-layer network: two ReLU hidden layers and one linear output unit.

    Args:
        n_hidden_nodes: number of units in each of the two hidden layers.
    """

    def __init__(self, n_hidden_nodes=100):
        super().__init__()

        def dense(units, activation):
            # Every layer gets its own RandomUniform(-0.1, 0.1) initializer object.
            return tf.keras.layers.Dense(
                units,
                activation=activation,
                use_bias=False,
                kernel_initializer=tf.keras.initializers.RandomUniform(-0.1, 0.1),
            )

        self.lin1 = dense(n_hidden_nodes, "relu")
        self.lin2 = dense(n_hidden_nodes, "relu")
        self.lin3 = dense(1, tf.identity)

    def call(self, inputs, sample_weight=None, training=False):
        """Apply the three layers in sequence.

        `sample_weight` and `training` are accepted but not used.
        """
        return self.lin3(self.lin2(self.lin1(inputs)))
    
class pyGAM1(tf.keras.Model):
    """Additive model: one `subNNtf` per input column, outputs summed per sample.

    Args:
        K: number of input columns, and therefore of sub-networks.
        n_hidden_nodes: hidden width forwarded to every `subNNtf`.
    """

    def __init__(self, K=50, n_hidden_nodes=100):
        super(pyGAM1, self).__init__()
        self.K = K
        self.subnn = [subNNtf(n_hidden_nodes=n_hidden_nodes) for _ in range(self.K)]
        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_fn = tf.keras.losses.MeanSquaredError()

    # NOTE(review): this overrides __call__ itself rather than the usual Keras
    # `call` hook -- confirm that this is intentional.
    def __call__(self, inputs):
        """Return the per-sample sum of the K sub-network outputs, shape (n,)."""
        out = []
        for k, subnn in enumerate(self.subnn):
            # Gathering with a one-element index list keeps the column axis: (n, 1).
            xk = tf.gather(inputs, [k], axis=1)
            out.append(subnn(xk))
        stacked = tf.stack(out, 1)  # (n, K, 1)
        # Squeeze only the trailing unit axis. An axis-less squeeze would also drop
        # the batch or feature axis when n == 1 or K == 1 and break the reduction.
        return tf.reduce_sum(tf.squeeze(stacked, axis=2), 1)

    @tf.function
    def train(self, inputs, label):
        """Run one Adam step on the mean squared error and return that loss."""
        with tf.GradientTape() as tape:
            pred = self.__call__(inputs)
            total_loss = self.loss_fn(label, pred)

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return total_loss
        
class pyGAM2(tf.keras.Model):
    """Additive model whose K sub-networks are evaluated with batched matmuls.

    The weights of all sub-networks are stored in three stacked tensors of shape
    (K, 1, H), (K, H, H) and (K, H, 1), where H is the hidden width.

    Args:
        K: number of input columns, and therefore of sub-networks.
        n_hidden_nodes: hidden width H. When None, the notebook-level `n_nodes`
            is used, which was the only option before this argument existed.
    """

    def __init__(self, K=50, n_hidden_nodes=None):
        super(pyGAM2, self).__init__()
        self.K = K
        if n_hidden_nodes is None:
            n_hidden_nodes = n_nodes
        self.tfw1 = self.add_weight(name="w1", shape=[K, 1, n_hidden_nodes], trainable=True,
                                    dtype=tf.float32,
                                    initializer=tf.keras.initializers.RandomUniform(-0.1, 0.1))
        self.tfw2 = self.add_weight(name="w2", shape=[K, n_hidden_nodes, n_hidden_nodes],
                                    trainable=True, dtype=tf.float32,
                                    initializer=tf.keras.initializers.RandomUniform(-0.1, 0.1))
        self.tfw3 = self.add_weight(name="w3", shape=[K, n_hidden_nodes, 1], trainable=True,
                                    dtype=tf.float32,
                                    initializer=tf.keras.initializers.RandomUniform(-0.1, 0.1))
        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_fn = tf.keras.losses.MeanSquaredError()

    # NOTE(review): this overrides __call__ itself rather than the usual Keras
    # `call` hook -- confirm that this is intentional.
    def __call__(self, inputs):
        """Return the per-sample sum of the K sub-network outputs, shape (n,)."""
        inputs = tf.cast(inputs, tf.float32)
        xs = tf.expand_dims(tf.transpose(inputs, [1, 0]), 2)   # (K, n, 1)
        h1 = tf.nn.relu(tf.linalg.matmul(xs, self.tfw1))       # (K, n, H)
        h2 = tf.nn.relu(tf.linalg.matmul(h1, self.tfw2))       # (K, n, H)
        h3 = tf.linalg.matmul(h2, self.tfw3)                   # (K, n, 1)
        # Squeeze only the trailing unit axis. An axis-less squeeze would also drop
        # the batch or feature axis when n == 1 or K == 1 and break the reduction.
        per_feature = tf.squeeze(tf.transpose(h3, [1, 0, 2]), axis=2)   # (n, K)
        return tf.reduce_sum(per_feature, 1)

    @tf.function
    def train(self, inputs, label):
        """Run one Adam step on the mean squared error and return that loss."""
        with tf.GradientTape() as tape:
            pred = self.__call__(inputs)
            total_loss = self.loss_fn(label, pred)

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return total_loss

In [5]:
%%time
# TensorFlow benchmark 1: one Keras sub-network per feature.
clf1 = pyGAM1(K=n_features)
for epoch in range(n_epochs):
    # Each call is one optimiser step on the full data set.
    clf1.train(x, y)

CPU times: user 41.8 s, sys: 5.51 s, total: 47.3 s
Wall time: 6.96 s


In [6]:
%%time
# TensorFlow benchmark 2: all sub-networks evaluated with batched matmuls.
clf2 = pyGAM2(K=n_features)
# The float32 cast does not depend on the epoch, so do it once instead of
# repeating it on every iteration of the loop.
x_tf = tf.cast(x, tf.float32)
for i in range(n_epochs):
    clf2.train(x_tf, y)

CPU times: user 37 s, sys: 11.2 s, total: 48.2 s
Wall time: 14.4 s


## PyTorch

In [7]:
import torch
from torch.nn import Linear

class subNN(torch.nn.Module):
    """Bias-free network for one scalar input: 1 -> H -> H -> 1 with ReLU in between.

    Args:
        n_hidden_nodes: width H of the two hidden layers.
    """

    def __init__(self, n_hidden_nodes=100):
        super().__init__()
        self.lin1 = Linear(1, n_hidden_nodes, bias=False)
        self.lin2 = Linear(n_hidden_nodes, n_hidden_nodes, bias=False)
        self.lin3 = Linear(n_hidden_nodes, 1, bias=False)
        # Replace the default initialisation with U(-0.1, 0.1), in layer order.
        for layer in (self.lin1, self.lin2, self.lin3):
            torch.nn.init.uniform_(layer.weight, -0.1, 0.1)

    def forward(self, inputs):
        """Apply the three layers; only the two hidden layers are followed by ReLU."""
        hidden = torch.relu(self.lin1(inputs))
        hidden = torch.relu(self.lin2(hidden))
        return self.lin3(hidden)

class pyGAM3(torch.nn.Module):
    """Additive model: one `subNN` per input column, outputs summed per sample.

    Args:
        K: number of input columns, and therefore of sub-networks.
        n_hidden_nodes: hidden width forwarded to every `subNN`.
        device: device the module is moved to at the end of construction.
    """

    def __init__(self, K=50, n_hidden_nodes=100, device="cpu"):
        super(pyGAM3, self).__init__()
        self.K = K
        self.subnn = torch.nn.ModuleList([subNN(n_hidden_nodes=n_hidden_nodes) for k in range(self.K)])
        self.lossfn = torch.nn.MSELoss()
        self.opt = torch.optim.Adam(self.parameters())
        self.to(device)

    def forward(self, inputs):
        """Return the per-sample sum of the K sub-network outputs, shape (n,)."""
        out = []
        for k, subnn in enumerate(self.subnn):
            # Indexing with a one-element list keeps the column axis: (n, 1).
            xk = inputs[:, [k]]
            out.append(subnn(xk))
        stacked = torch.stack(out, 1)  # (n, K, 1)
        # Squeeze only the trailing unit axis. An axis-less squeeze would also drop
        # the batch or feature axis when n == 1 or K == 1 and break the reduction.
        return torch.sum(torch.squeeze(stacked, 2), 1)
        
        
class pyGAM4(torch.nn.Module):
    """Additive model whose K sub-networks are evaluated with batched matmuls.

    The weights of all sub-networks are stored in three stacked parameters of
    shape (K, 1, H), (K, H, H) and (K, H, 1), where H is `n_hidden_nodes`.

    Args:
        K: number of input columns, and therefore of sub-networks.
        n_hidden_nodes: hidden width H.
        device: device on which the weights are allocated.
    """

    def __init__(self, K=50, n_hidden_nodes=100, device="cpu"):
        super(pyGAM4, self).__init__()
        self.K = K
        # The weights are wrapped in nn.Parameter so the module registers them:
        # parameters(), state_dict() and .to() then include them. Plain tensors
        # with requires_grad=True are not seen by any of those.
        self.ww1 = torch.nn.Parameter(
            torch.empty(size=(K, 1, n_hidden_nodes), dtype=torch.float, device=device))
        self.ww2 = torch.nn.Parameter(
            torch.empty(size=(K, n_hidden_nodes, n_hidden_nodes), dtype=torch.float, device=device))
        self.ww3 = torch.nn.Parameter(
            torch.empty(size=(K, n_hidden_nodes, 1), dtype=torch.float, device=device))

        for weight in (self.ww1, self.ww2, self.ww3):
            torch.nn.init.uniform_(weight, -0.1, 0.1)

    def forward(self, inputs):
        """Return the per-sample sum of the K sub-network outputs, shape (n,)."""
        xs = torch.unsqueeze(torch.transpose(inputs, 0, 1), 2)   # (K, n, 1)
        h1 = torch.relu(torch.matmul(xs, self.ww1))              # (K, n, H)
        h2 = torch.relu(torch.matmul(h1, self.ww2))              # (K, n, H)
        h3 = torch.matmul(h2, self.ww3)                          # (K, n, 1)
        # Squeeze only the trailing unit axis. An axis-less squeeze would also drop
        # the batch or feature axis when n == 1 or K == 1 and break the reduction.
        per_feature = torch.squeeze(torch.transpose(h3, 0, 1), 2)   # (n, K)
        return torch.sum(per_feature, 1)

In [8]:
# float32 torch copies of the NumPy inputs and targets.
xx, yy = (torch.tensor(array, dtype=torch.float) for array in (x, y))

In [9]:
# %%time
# clf1 = pyGAM3(K=n_features, n_hidden_nodes=n_nodes, device="cpu")
# lossfn = torch.nn.MSELoss()
# opt = torch.optim.Adam(clf1.parameters())
# for epoch in range(n_epochs):
#     opt.zero_grad()
#     loss = lossfn(clf1(xx), yy)
#     loss.backward()
#     opt.step()

In [10]:
%%time
# Eager-mode CPU benchmark of the batched-matmul model.
clf2 = pyGAM4(K=n_features, n_hidden_nodes=n_nodes, device="cpu")
opt = torch.optim.Adam([clf2.ww1, clf2.ww2, clf2.ww3])
lossfn = torch.nn.MSELoss()
for epoch in range(n_epochs):
    opt.zero_grad()            # clear gradients left over from the previous step
    prediction = clf2(xx)
    loss = lossfn(prediction, yy)
    loss.backward()
    opt.step()

CPU times: user 3min 20s, sys: 8.84 s, total: 3min 29s
Wall time: 15 s


## PyTorch compiled (TorchScript)

In [11]:
%%time
# CPU benchmark of the per-feature model after compiling it with torch.jit.script.
clf1 = pyGAM3(K=n_features, n_hidden_nodes=n_nodes, device="cpu")
clf1 = torch.jit.script(clf1)
opt = torch.optim.Adam(clf1.parameters())
lossfn = torch.nn.MSELoss()
for epoch in range(n_epochs):
    opt.zero_grad()            # clear gradients left over from the previous step
    prediction = clf1(xx)
    loss = lossfn(prediction, yy)
    loss.backward()
    opt.step()

CPU times: user 46min 59s, sys: 3min 45s, total: 50min 45s
Wall time: 4min 37s


In [12]:
%%time
# CPU benchmark of the batched-matmul model after compiling it with torch.jit.script.
clf2 = pyGAM4(K=n_features, n_hidden_nodes=n_nodes, device="cpu")
clf2 = torch.jit.script(clf2)
# The weights are handed to the optimiser explicitly, as in the eager benchmark.
opt = torch.optim.Adam([clf2.ww1, clf2.ww2, clf2.ww3])
lossfn = torch.nn.MSELoss()
for epoch in range(n_epochs):
    opt.zero_grad()
    prediction = clf2(xx)
    loss = lossfn(prediction, yy)
    loss.backward()
    opt.step()

CPU times: user 7min 58s, sys: 40.4 s, total: 8min 39s
Wall time: 39.4 s


## PyTorch on GPU

In [13]:
# Prefer the second GPU as before, but fall back to the first GPU when only one
# is present. An unconditional "cuda:1" is an invalid device on a single-GPU
# machine. With no GPU at all the run stays on the CPU.
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
device = torch.device("cuda:1" if n_gpus > 1 else "cuda:0" if n_gpus == 1 else "cpu")
clf1 = pyGAM3(K=n_features, n_hidden_nodes=n_nodes, device=device)

lossfn = torch.nn.MSELoss()
opt = torch.optim.Adam(clf1.parameters())
# The data does not change between epochs, so move it to the device once
# instead of once per iteration.
xx_dev = xx.to(device)
yy_dev = yy.to(device)
for epoch in range(n_epochs):
    opt.zero_grad()
    loss = lossfn(clf1(xx_dev), yy_dev)
    loss.backward()
    opt.step()

    Found GPU%d %s which is of cuda capability %d.%d.
    PyTorch no longer supports this GPU because it is too old.
    The minimum cuda capability supported by this library is %d.%d.
    


In [15]:
# Prefer the second GPU as before, but fall back to the first GPU when only one
# is present. An unconditional "cuda:1" is an invalid device on a single-GPU
# machine. With no GPU at all the run stays on the CPU.
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
device = torch.device("cuda:1" if n_gpus > 1 else "cuda:0" if n_gpus == 1 else "cpu")
clf2 = pyGAM4(K=n_features, n_hidden_nodes=n_nodes, device=device)

lossfn = torch.nn.MSELoss()
opt = torch.optim.Adam([clf2.ww1, clf2.ww2, clf2.ww3])
# The data does not change between epochs, so move it to the device once
# instead of once per iteration.
xx_dev = xx.to(device)
yy_dev = yy.to(device)
# NOTE(review): this run uses 100x the epochs of the other benchmarks --
# confirm that this is intentional.
for epoch in range(n_epochs * 100):
    opt.zero_grad()
    loss = lossfn(clf2(xx_dev), yy_dev)
    loss.backward()
    opt.step()