In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import os
import copy

# Note: the model names used in the paper differ from the class names used in this code.
# NNPhD (paper) -> SUPERLNN (code)
# UAN (paper)   -> BASE (code)

In [2]:
class LNN(nn.Module):
    """Scalar-valued network over ``2*d`` inputs plus autograd helpers.

    ``forward`` maps each row of ``x`` (``2*d`` columns) to one scalar.
    ``predict`` differentiates that scalar with respect to the inputs and
    combines the derivatives into the quantity the notebook calls ``qtt``
    (presumably the acceleration given by the Euler-Lagrange equation, with
    columns ``0:d`` read as q and ``d:2*d`` as its time derivative qt).

    Args:
        d: number of coordinates; the network expects ``2*d`` input columns.
        w: width of each hidden linear layer.
    """

    def __init__(self, d = 2, w=200):
        super(LNN, self).__init__()

        self.d = d
        # l11 is never used in forward(). It is kept so that state_dict keys
        # and the order of random parameter initialisation stay unchanged.
        self.l11 = nn.Linear(2*d,w)
        self.l12 = nn.Linear(2*d,w)
        self.l13 = nn.Linear(2*d,w)
        self.l3 = nn.Linear(2*w,1)

    def forward(self, x):
        """Return the scalar output, shape ``(batch, 1)``.

        The hidden features are a squared linear map and a plain linear map
        of ``x``, concatenated and reduced by ``l3``. Intermediate results
        are stored on the module as ``x12``, ``x13``, ``x1`` and ``x3``.
        """
        self.x12 = self.l12(x)**2
        self.x13 = self.l13(x)
        self.x1 = torch.cat([self.x12,self.x13],dim=1)
        self.x3 = self.l3(self.x1)
        return self.x3

    def Lq(self, y, x):
        """Gradient of ``y`` (summed over its elements) with respect to ``x``.

        ``x`` must require grad and ``y`` must have been computed from it.
        The result has the shape of ``x`` and stays differentiable.
        """
        return torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True, retain_graph=True)[0]

    def Lqq(self, y, x):
        """Per-row second derivatives of ``y`` with respect to ``x``.

        Returns a tensor of shape ``(batch, n, n)`` (``n`` = columns of ``x``)
        whose entry ``[i, j, k]`` is the derivative of ``Lq(y, x)[i, j]`` with
        respect to ``x[i, k]``. One autograd call is issued per
        (row, column) pair, so the cost grows with ``batch * n``.
        """
        grads = self.Lq(y, x)
        n_rows, n_cols = grads.shape
        hessians = []
        for i in range(n_rows):
            # grads[i, j] is a scalar, so no explicit grad_outputs is needed;
            # this avoids creating a float32 CPU constant that would clash
            # with other dtypes or devices.
            rows = [
                torch.autograd.grad(grads[i, j], x, create_graph=True, retain_graph=True)[0][i]
                for j in range(n_cols)
            ]
            hessians.append(torch.stack(rows))
        return torch.stack(hessians)

    def qtt(self, Lqq_, Lq_, qt_):
        """Solve ``Lvv @ qtt = Lq_[:, :d] - Lxv @ qt_`` for ``qtt``.

        Args:
            Lqq_: second derivatives from ``Lqq``, shape ``(batch, 2*d, 2*d)``.
            Lq_: first derivatives from ``Lq``, shape ``(batch, 2*d)``.
            qt_: columns ``d:2*d`` of the input, shape ``(batch, d)``.

        Returns:
            Tensor of shape ``(batch, d, 1)``.
        """
        d = self.d
        # Lvv is fixed to the identity (the learned block is disabled):
        # Lqq_[:, d:2*d, d:2*d] + identity
        # Built from self.d and the dtype/device of Lq_ so that d != 2,
        # float64 and non-CPU tensors work as well.
        Lvv = torch.eye(d, dtype=Lq_.dtype, device=Lq_.device)
        Lxv = Lqq_[:, d:2*d, 0:d]
        rhs = torch.unsqueeze(Lq_[:, 0:d], dim=2) - torch.matmul(Lxv, torch.unsqueeze(qt_, dim=2))
        return torch.matmul(torch.inverse(Lvv), rhs)

    def predict(self, x):
        """Evaluate the network on ``x`` and return ``qtt``, shape ``(batch, d, 1)``.

        ``x`` must require grad because the derivatives are taken with
        respect to it.
        """
        outputs = self.forward(x)
        Lqq_ = self.Lqq(outputs,x)
        Lq_ = self.Lq(outputs,x)
        # The qt columns enter only as data: detach them from x, as the legacy
        # Variable(...) wrapper did, without tracking a gradient nobody reads.
        qt_ = x[:, self.d:2*self.d].detach()
        return self.qtt(Lqq_, Lq_, qt_)

class BASE(nn.Module):
    """Three-layer LeakyReLU MLP with ``d`` outputs whose inputs depend on ``mode``.

    Args:
        d: number of outputs; the input is expected to have ``2*d + 1`` columns.
        w: width of the two hidden layers.
        mode: which input columns feed the first layer:
            ``"t"``   - only the last column,
            ``"qq"``  - the first ``2*d`` columns,
            ``"qqt"`` - all columns.
            For any other value (including the default ``"0"``) no input
            layer is created and ``forward`` cannot be used.
    """

    def __init__(self, d = 2, w=200, mode="0"):
        super(BASE, self).__init__()
        self.mode = mode
        self.d = d
        if mode == "t":
            self.l1 = nn.Linear(1,w)
        elif mode == "qq":
            self.l1 = nn.Linear(2*d,w)
        elif mode == "qqt":
            self.l1 = nn.Linear(2*d+1,w)
        # Other modes deliberately get no l1: the module must still be
        # constructible (SUPERLNN builds it in mode "0" without calling it).
        self.l2 = nn.Linear(w,w)
        self.l3 = nn.Linear(w,d)

    def forward(self, x):
        """Return the network output, shape ``(batch, d)``.

        Intermediate activations are stored on the module as ``x1``, ``x2``
        and ``x3``.

        Raises:
            ValueError: if ``mode`` is not one of ``"t"``, ``"qq"``, ``"qqt"``.
        """
        if self.mode == "t":
            features = torch.unsqueeze(x[:,-1],dim=1)
        elif self.mode == "qq":
            features = x[:,:2*self.d]
        elif self.mode == "qqt":
            features = x
        else:
            # Previously this surfaced as an unrelated AttributeError on x1.
            raise ValueError(
                "BASE.forward is not available for mode {!r}; "
                "expected one of 't', 'qq', 'qqt'".format(self.mode))

        self.x1 = F.leaky_relu(self.l1(features), negative_slope=0.1)
        self.x2 = F.leaky_relu(self.l2(self.x1), negative_slope=0.1)
        self.x3 = self.l3(self.x2)
        return self.x3

class SUPERLNN(nn.Module):
    """Container holding an ``LNN`` and a ``BASE`` network.

    Args:
        d: number of coordinates, passed to both sub-networks.
        w: hidden width, passed to both sub-networks.
        mode: input mode of the ``BASE`` network; ``"0"`` means ``forward``
            uses the ``LNN`` alone.
    """

    def __init__(self, d = 2, w=200, mode="0"):
        super(SUPERLNN, self).__init__()
        # w was previously accepted but never forwarded, so both sub-networks
        # always used their default width regardless of the argument.
        self.lnn = LNN(d=d, w=w)
        self.d = d
        self.base = BASE(d=d, w=w, mode=mode)
        self.mode = mode

    def forward(self, x):
        """Return the LNN output, plus the BASE output unless ``mode == "0"``.

        The LNN sees only the first ``2*d`` columns of ``x``; BASE receives
        all of ``x``. Results are stored as ``x1`` (LNN) and ``x2`` (BASE).
        """
        self.x1 = self.lnn(x[:,:2*self.d])
        if self.mode == "0":
            return self.x1
        # NOTE(review): x1 is the (batch, 1) scalar LNN output and x2 is
        # (batch, d), so this sum broadcasts. Confirm this is intended.
        self.x2 = self.base(x)
        return self.x1 + self.x2




In [3]:
# data
# Synthetic training set. Columns of x are presumably [q (d), qt (d), t]:
# LNN.predict reads columns d:2*d as qt and BASE mode "t" reads the last column.
n_train = 100
d = 2

# Seed before sampling so the dataset is the same on every Restart & Run All.
torch.manual_seed(0)
x = torch.normal(0,1,size=(n_train,2*d+1))

# Targets, built from plain data before x starts tracking gradients so that y
# is not attached to the autograd graph. The column indices assume d == 2.
y = torch.stack([-x[:,0] + x[:,3],
                 -x[:,1] - x[:,2]], dim=1)

# The derivative helpers differentiate with respect to the inputs.
x.requires_grad_(True)

In [4]:
# Train SUPERLNN for every BASE input mode, sweeping the weight `lamb` of the
# penalty on the BASE output. Uses SUPERLNN, x, y, n_train and d from the
# cells above.
#modes = ["qqt","qq","t","0"]
modes = ["qqt"]

batch_size = 32
log = 10                      # print the losses every `log` epochs
lrs = [1e-2,1e-3,1e-4,1e-5]   # Adam learning rate for each successive stage of training

# torch.save / np.savetxt do not create missing directories; make them up
# front so a long run cannot fail at the very first save.
os.makedirs("./models/mag", exist_ok=True)
os.makedirs("./results/mag", exist_ok=True)


def _rms_norm(residual, n_rows):
    """2-norm of `residual` divided by sqrt(n_rows)."""
    return torch.norm(residual, p=2)/torch.sqrt(torch.tensor(n_rows,dtype=torch.float))


for mode in modes:
    np.random.seed(1)
    torch.manual_seed(1)
    loss_qtts = []
    print("***************mode={}***************".format(mode))
    superlnn = SUPERLNN(d=d, mode=mode)
    superlnn.train()
    epochs = 1000
    stage_len = max(1, epochs // len(lrs))
    if mode == "0":
        lambs = [0]
    else:
        lambs = [0.01,0.02,0.05,0.1,0.2,0.5,1,2,5,10,20,50,100]
        #lambs = [2]
    # NOTE(review): the same model instance keeps training across successive
    # lamb values (warm start); confirm that this is intended.
    for lamb in lambs:
        print("--------------lamb={}-------------".format(lamb))
        for epoch in range(epochs):
            # Step-wise learning-rate schedule: a fresh Adam optimizer (and
            # therefore fresh moment estimates) at the start of every stage.
            if epoch % stage_len == 0:
                stage = min(epoch // stage_len, len(lrs) - 1)
                optimizer = optim.Adam(superlnn.parameters(), lr = lrs[stage])
            optimizer.zero_grad()
            # Random mini-batches (sampled with replacement); the last epoch
            # uses the full training set so the recorded loss covers all data.
            if epoch < epochs-1:
                choices = np.random.choice(n_train, batch_size)
            else:
                choices = np.arange(n_train)
            n_batch = choices.shape[0]
            # Detached copies: inputs become a fresh leaf to differentiate
            # against, labels are plain data.
            inputs = x[choices].detach().to(torch.float).requires_grad_(True)
            labels = y[choices].detach().to(torch.float)
            # LNN predict
            lnn_pred = superlnn.lnn.predict(inputs[:,:2*superlnn.d])[:,:,0]
            if mode == "0":
                # LNN only: previously loss_qtt was read here before ever
                # being assigned.
                qtt_pred = lnn_pred
                loss_qtt = _rms_norm(qtt_pred-labels, n_batch)
                loss = loss_qtt
            else:
                # Base predict
                base_pred = superlnn.base(inputs)
                qtt_pred = lnn_pred + base_pred
                # qtt_pred and labels are both (batch, d), so no special case
                # is needed for d == 1.
                loss_qtt = _rms_norm(qtt_pred-labels, n_batch)
                loss_baseact = _rms_norm(base_pred, n_batch)
                loss = loss_qtt + lamb * loss_baseact
            # The graph is rebuilt every iteration, so it need not be retained.
            loss.backward()
            optimizer.step()
            if epoch%log == 0:
                if mode == "0":
                    print('Epoch:  %d | Loss_qtt: %.4f' %(epoch, loss_qtt))
                else:
                    print('Epoch:  %d | Loss_qtt: %.4f | Loss_base: %.4f' %(epoch, loss_qtt, loss_baseact))
        # Full-batch loss of the final epoch for this lamb.
        loss_qtts.append(loss_qtt.item())
        torch.save(superlnn.state_dict(), "./models/mag/mode_%s_lamb_%.2f_p2"%(mode, lamb))
    loss_qtts = np.array(loss_qtts)
    np.savetxt('./results/mag/mode_%s_p2'%mode, np.array([np.array(lambs), loss_qtts]))
    




***************mode=qqt***************
--------------lamb=0.01-------------




Epoch:  0 | Loss_qtt: 1.6502 | Loss_base: 0.0818
Epoch:  10 | Loss_qtt: 0.4974 | Loss_base: 0.5645
Epoch:  20 | Loss_qtt: 0.2045 | Loss_base: 0.1771
Epoch:  30 | Loss_qtt: 0.1540 | Loss_base: 0.1243
Epoch:  40 | Loss_qtt: 0.0844 | Loss_base: 0.0585
Epoch:  50 | Loss_qtt: 0.0688 | Loss_base: 0.0498
Epoch:  60 | Loss_qtt: 0.0917 | Loss_base: 0.0529
Epoch:  70 | Loss_qtt: 0.0607 | Loss_base: 0.0323
Epoch:  80 | Loss_qtt: 0.1270 | Loss_base: 0.0388
Epoch:  90 | Loss_qtt: 0.0488 | Loss_base: 0.0327
Epoch:  100 | Loss_qtt: 0.0761 | Loss_base: 0.0333
Epoch:  110 | Loss_qtt: 0.0317 | Loss_base: 0.0251
Epoch:  120 | Loss_qtt: 0.0549 | Loss_base: 0.0207
Epoch:  130 | Loss_qtt: 0.0883 | Loss_base: 0.0228
Epoch:  140 | Loss_qtt: 0.0643 | Loss_base: 0.0154
Epoch:  150 | Loss_qtt: 0.0697 | Loss_base: 0.0190
Epoch:  160 | Loss_qtt: 0.0417 | Loss_base: 0.0193
Epoch:  170 | Loss_qtt: 0.0550 | Loss_base: 0.0185
Epoch:  180 | Loss_qtt: 0.0453 | Loss_base: 0.0104
Epoch:  190 | Loss_qtt: 0.0545 | Loss_base

Epoch:  190 | Loss_qtt: 0.0304 | Loss_base: 0.0025
Epoch:  200 | Loss_qtt: 0.0339 | Loss_base: 0.0026
Epoch:  210 | Loss_qtt: 0.0224 | Loss_base: 0.0017
Epoch:  220 | Loss_qtt: 0.0208 | Loss_base: 0.0019
Epoch:  230 | Loss_qtt: 0.0100 | Loss_base: 0.0015
Epoch:  240 | Loss_qtt: 0.0269 | Loss_base: 0.0055
Epoch:  250 | Loss_qtt: 0.0263 | Loss_base: 0.0015
Epoch:  260 | Loss_qtt: 0.0088 | Loss_base: 0.0021
Epoch:  270 | Loss_qtt: 0.0045 | Loss_base: 0.0005
Epoch:  280 | Loss_qtt: 0.0017 | Loss_base: 0.0006
Epoch:  290 | Loss_qtt: 0.0023 | Loss_base: 0.0006
Epoch:  300 | Loss_qtt: 0.0022 | Loss_base: 0.0002
Epoch:  310 | Loss_qtt: 0.0021 | Loss_base: 0.0005
Epoch:  320 | Loss_qtt: 0.0028 | Loss_base: 0.0005
Epoch:  330 | Loss_qtt: 0.0018 | Loss_base: 0.0003
Epoch:  340 | Loss_qtt: 0.0023 | Loss_base: 0.0001
Epoch:  350 | Loss_qtt: 0.0029 | Loss_base: 0.0004
Epoch:  360 | Loss_qtt: 0.0024 | Loss_base: 0.0002
Epoch:  370 | Loss_qtt: 0.0040 | Loss_base: 0.0002
Epoch:  380 | Loss_qtt: 0.0023 

Epoch:  380 | Loss_qtt: 0.0018 | Loss_base: 0.0002
Epoch:  390 | Loss_qtt: 0.0015 | Loss_base: 0.0001
Epoch:  400 | Loss_qtt: 0.0048 | Loss_base: 0.0001
Epoch:  410 | Loss_qtt: 0.0018 | Loss_base: 0.0001
Epoch:  420 | Loss_qtt: 0.0033 | Loss_base: 0.0004
Epoch:  430 | Loss_qtt: 0.0028 | Loss_base: 0.0005
Epoch:  440 | Loss_qtt: 0.0034 | Loss_base: 0.0002
Epoch:  450 | Loss_qtt: 0.0025 | Loss_base: 0.0002
Epoch:  460 | Loss_qtt: 0.0007 | Loss_base: 0.0002
Epoch:  470 | Loss_qtt: 0.0015 | Loss_base: 0.0002
Epoch:  480 | Loss_qtt: 0.0017 | Loss_base: 0.0002
Epoch:  490 | Loss_qtt: 0.0023 | Loss_base: 0.0004
Epoch:  500 | Loss_qtt: 0.0038 | Loss_base: 0.0003
Epoch:  510 | Loss_qtt: 0.0010 | Loss_base: 0.0001
Epoch:  520 | Loss_qtt: 0.0004 | Loss_base: 0.0000
Epoch:  530 | Loss_qtt: 0.0004 | Loss_base: 0.0000
Epoch:  540 | Loss_qtt: 0.0002 | Loss_base: 0.0000
Epoch:  550 | Loss_qtt: 0.0001 | Loss_base: 0.0000
Epoch:  560 | Loss_qtt: 0.0001 | Loss_base: 0.0000
Epoch:  570 | Loss_qtt: 0.0001 

Epoch:  570 | Loss_qtt: 0.0003 | Loss_base: 0.0000
Epoch:  580 | Loss_qtt: 0.0002 | Loss_base: 0.0000
Epoch:  590 | Loss_qtt: 0.0001 | Loss_base: 0.0000
Epoch:  600 | Loss_qtt: 0.0003 | Loss_base: 0.0000
Epoch:  610 | Loss_qtt: 0.0003 | Loss_base: 0.0000
Epoch:  620 | Loss_qtt: 0.0002 | Loss_base: 0.0000
Epoch:  630 | Loss_qtt: 0.0004 | Loss_base: 0.0000
Epoch:  640 | Loss_qtt: 0.0002 | Loss_base: 0.0000
Epoch:  650 | Loss_qtt: 0.0002 | Loss_base: 0.0000
Epoch:  660 | Loss_qtt: 0.0004 | Loss_base: 0.0000
Epoch:  670 | Loss_qtt: 0.0001 | Loss_base: 0.0000
Epoch:  680 | Loss_qtt: 0.0002 | Loss_base: 0.0000
Epoch:  690 | Loss_qtt: 0.0001 | Loss_base: 0.0000
Epoch:  700 | Loss_qtt: 0.0002 | Loss_base: 0.0000
Epoch:  710 | Loss_qtt: 0.0002 | Loss_base: 0.0000
Epoch:  720 | Loss_qtt: 0.0001 | Loss_base: 0.0000
Epoch:  730 | Loss_qtt: 0.0001 | Loss_base: 0.0000
Epoch:  740 | Loss_qtt: 0.0003 | Loss_base: 0.0000
Epoch:  750 | Loss_qtt: 0.0002 | Loss_base: 0.0000
Epoch:  760 | Loss_qtt: 0.0001 

Epoch:  760 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  770 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  780 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  790 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  800 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  810 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  820 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  830 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  840 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  850 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  860 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  870 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  880 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  890 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  900 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  910 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  920 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  930 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  940 | Loss_qtt: 0.0000 | Loss_base: 0.0000
Epoch:  950 | Loss_qtt: 0.0000 