In [1]:
from typing import Callable, List, Tuple

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions.distribution import Distribution
from pytorch_lightning.loggers import TensorBoardLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
import torch
from torch import nn
import numpy as np

class L1(nn.Linear):
    def __init__(self,n):
        super().__init__(n*(n-1)//2,1,bias=False)
        self.n=n
        self.mask2d=torch.zeros((n*(n-1) // 2),dtype=int)
        for i in range(1,n):
            for j in range(i):
                self.mask2d[i*(i-1)//2+j]=i*n+j
        self.d_ind=[(n+1)*k for k in range(n)]
        self.ones=torch.ones((n))
    
    def anti_flatten(self):
        n=self.n
        L=torch.zeros((n**2)).to(device)
        L[self.mask2d]=self.weight.to(device)
        L[self.d_ind]=self.ones
        L=torch.reshape(L,(n,n)).to(device)
        return L        
            
    def log_abs_det(self):
        return 0
    
    def forward(self,x):
        Lwt=torch.t(self.anti_flatten())
        return torch.matmul(x,Lwt)
    
    def g(self,z):
        return self.forward(z), torch.zeros((z.shape[0]))
        
    def adj(self,mat):
        Lw=self.anti_flatten()
        Lwt=torch.t(Lw)
        D=torch.matmul(Lwt,torch.matmul(mat,Lw))
        return D
    

In [3]:
import torch
from torch import nn
import numpy as np

class D(nn.Linear):
    def __init__(self,n):
        super().__init__(n,1,bias=False)
        self.n=n
        self.d_ind=[(n+1)*k for k in range(n)]
    
    def anti_flatten(self):
        n=self.n
        D=torch.zeros((n**2)).to(device)
        D[self.d_ind]=self.weight.to(device)
        D=torch.reshape(D,(n,n)).to(device)
        return D        
            
    def log_abs_det(self):
        ABS=torch.abs(self.weight)
        l=torch.log(ABS)
        lad=torch.sum(l)
        return lad
    
    def forward(self,x):
        D=self.anti_flatten()
        return torch.matmul(x,D)
    
    def g(self,z):
        lad=self.log_abs_det()
        return self.forward(z), lad * torch.ones((z.shape[0]))
        
    def adj(self,mat):
        Lw=self.anti_flatten()
        Lwt=torch.t(Lw)
        D=torch.matmul(Lwt,torch.matmul(mat,Lw))
        return D
    

In [4]:
m=L1(10)
#print(m.anti_flatten())
d=D(10)
#print(d.anti_flatten())

In [5]:
z=torch.rand((10,10))
x, lad= m.g(z)
x, lad= d.g(z)
print(x, lad)

tensor([[-0.1542,  0.0047,  0.0486,  0.0127,  0.1100, -0.0590, -0.0167, -0.0297,
         -0.2712, -0.2027],
        [-0.0516,  0.0984,  0.0572,  0.1036,  0.1347, -0.1432, -0.0105, -0.0316,
         -0.2417, -0.0762],
        [-0.1121,  0.0695,  0.0009,  0.0702,  0.0379, -0.0906, -0.0063, -0.0245,
         -0.2078, -0.1494],
        [-0.1291,  0.1591,  0.0083,  0.1512,  0.0816, -0.1365, -0.0169, -0.0723,
         -0.1708, -0.0077],
        [-0.1400,  0.0687,  0.0496,  0.0928,  0.0417, -0.0191, -0.0006, -0.0902,
         -0.1710, -0.1761],
        [-0.1289,  0.0262,  0.0193,  0.0812,  0.0394, -0.0555, -0.0045, -0.1020,
         -0.1071, -0.0184],
        [-0.0656,  0.1773,  0.0030,  0.0091,  0.0302, -0.1210, -0.0071, -0.0334,
         -0.0228, -0.1690],
        [-0.1687,  0.1891,  0.0626,  0.1293,  0.1479, -0.1799, -0.0174, -0.0826,
         -0.2242, -0.1378],
        [-0.0242,  0.1924,  0.0221,  0.0611,  0.0257, -0.0664, -0.0079, -0.0229,
         -0.2622, -0.2476],
        [-0.0376,  

In [6]:
import torch
from torch import nn
import numpy as np

class L(nn.Linear):
    def __init__(self,n):
        super().__init__(n*(n+1)//2,1,bias=False)
        self.n=n
        self.diag_mask=torch.tensor([ (k+1) * (k+2) // 2 - 1 for k in range(n)])
        self.mask2d=torch.zeros((n*(n+1) // 2),dtype=int)
        for i in range(n):
            for j in range(i+1):
                self.mask2d[i*(i+1)//2+j]=i*n+j
    
    def anti_flatten(self):
        n=self.n
        L=torch.zeros((n**2)).to(device)
        L[self.mask2d]=self.weight.to(device) 
        L=torch.reshape(L,(n,n)).to(device)
        return L        
            
    def log_abs_det(self):
        diag=self.weight[0][self.diag_mask]
        la=torch.log(torch.abs(diag))
        lad=torch.sum(la)
        return lad
    
    def forward(self,x):
        Lwt=torch.t(self.anti_flatten())
        return torch.matmul(x,Lwt)
        
    def adj(self,mat):
        Lw=self.anti_flatten()
        Lwt=torch.t(Lw)
        D=torch.matmul(Lwt,torch.matmul(mat,Lw))
        return D
    

In [7]:
from NFandist import get_O, get_A
from NFconstants import N_nod, beta
O=(torch.tensor(get_O(N_nod)).float()).to(device)
Ot=torch.t(O)
print(Ot.requires_grad)

A=(torch.tensor(get_A(N_nod,beta)).float()).to(device)
I=(torch.eye(N_nod)).to(device)

False


In [8]:
class simple_nf(nn.Module):
     
    def __init__(self,layer,ort=True):
        super().__init__()
        self.layer=layer
        self.ort=ort

    def log_abs_det(self):
        lad=self.layer.log_abs_det()
        return lad
    
    def forward(self,x):
        y=self.layer(x)
        if self.ort:
            y=torch.matmul(y,Ot)
        return y
        
    def metric(self):
        with torch.no_grad():
            if self.ort:
                B=torch.matmul(Ot,torch.matmul(A,O))
            else:
                B=A
            met=torch.linalg.matrix_norm(self.layer.adj(B)-I)
        return met  

In [9]:
from NFconstants import N_nod, beta , N_traj
from NFandist import get_A, get_C
A=(torch.tensor(get_A(N_nod,beta)).float()).to(device)
I=(torch.eye(N_nod)).to(device)
def set_random_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)

def A_I(model):
    with torch.no_grad():
        A_D=model.adj(A)
    return torch.linalg.matrix_norm(A_D-I)

In [10]:
class Cheatloss(nn.Module):
    def __init__(self):
        super(Cheatloss, self).__init__()

    def forward(self, model,lad):
        A_D=model.adj(A)
        loss=0.5*torch.trace(A_D)-lad
        return loss
CL=Cheatloss()

In [11]:
class Pipeline(pl.LightningModule):
    def __init__(
        self,
        model,
        criterion,
        optimizer_class=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001,"weight_decay": 0.01}
    ) -> None:
        super().__init__()
        self.model = model
        self.loss = criterion
        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs


    def configure_optimizers(self):
        optimizer = self.optimizer_class(
            self.model.parameters(), **self.optimizer_kwargs
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        z = batch
        x, log_abs_det = self.model(z), self.model.log_abs_det()
        #log_abs_det=self.model.log_abs_det()
        loss = self.loss(x,log_abs_det)
        self.log('train_loss', loss, prog_bar=True)
        
        return loss


    def on_train_epoch_end(self):
        metric=self.model.metric()
        self.log("metric",metric, prog_bar=True)
        #self.log("metric",A_I(self.model), prog_bar=True)
        pass

In [12]:
from LOSS import KL_osc
from Data import train_loader

set_random_seed(42)
snf=simple_nf(D(N_nod),ort=True)
pipeline=Pipeline(model=snf,criterion=KL_osc, optimizer_class=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001, "weight_decay": 0.01})
trainer = pl.Trainer(
    max_epochs=1000,
    logger=TensorBoardLogger(save_dir=f"logs/nf"),
    num_sanity_val_steps=0,
)

trainer.fit(model=pipeline, train_dataloaders=train_loader)
torch.save(snf.state_dict(), "L_layer_weights.pth")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type      | Params
------------------------------------
0 | model | simple_nf | 16    
1 | loss  | KL_with_S | 0     
------------------------------------
16        Trainable params
0         Non-trainable params
16        Total params
0.000     Total estimated model params size (MB)
C:\ProgramData\anaconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |                                                                                      | 0/? [00:00<…

C:\ProgramData\anaconda3\Lib\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [56]:
L_layer=L(N_nod)
model=simple_nf(L_layer,ort=True)
#model=L_layer
model.load_state_dict(torch.load("L_layer_weights.pth"))
pipeline=Pipeline(model=model,criterion=KL_osc, optimizer_class=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001,"weight_decay": 0})

trainer = pl.Trainer(
    max_epochs=2000,
    logger=TensorBoardLogger(save_dir=f"logs/nf"),
    num_sanity_val_steps=0,
)

trainer.fit(model=pipeline, train_dataloaders=train_loader)
torch.save(model.state_dict(), "L_layer_weights.pth")

RuntimeError: Error(s) in loading state_dict for simple_nf:
	Missing key(s) in state_dict: "layer.weight". 
	Unexpected key(s) in state_dict: "weight". 

In [None]:
from NFandist import get_T
T=torch.tensor(get_T(N_nod)).float()
def G(X,n_p=N_nod):
    G=np.zeros((n_p))
    Y=X.clone()
    Xt=torch.t(X)
    for s in range(n_p):
        G[s]=torch.trace(torch.matmul(Y,Xt))
        Y=torch.matmul(Y,T)
    return G/(N_traj*N_nod)
g_nf=G(trajs)

In [None]:
from NFandist import calc_G
from NFconstants import N_nod, N_traj, NG_points,beta
g_osc=calc_G(N_nod,beta,N_nod)

In [None]:
#import sys
#sys.path.append('../')
import numpy as np
import matplotlib.pyplot as plt
from NFconstants import N_nod, N_traj, NG_points,beta
#from Value import G
#import ensemble
#from NFoscillator import basic_oscillator
#from time import time
#from NFandist import calc_G

"""
ens_nf=ensemble.ensemble.load("nf_ensemble.txt",basic_oscillator)
g_nf=np.vstack(ensemble.ensemble.Vaverage_and_sigma(ens_nf,G))
g_nf=g_nf.transpose()[0]
"""

g=g_ur
print(len(g))
fig=plt.figure()
MCMC_list=np.arange(len(g))/len(g)
NF_list=np.arange(len(g_nf))/len(g_nf)
plt.scatter(MCMC_list,g)
plt.scatter(NF_list,g_nf)
plt.legend(["MCMC","normalizing flow"])
plt.grid(True)
plt.show()

In [55]:
L_layer=L(N_nod)
L_layer.load_state_dict(torch.load("L_layer_weights.pth"))
print(L_layer.adj(A))

tensor([[ 9.7450e-01,  6.8725e-03,  1.6029e-03,  ...,  2.9996e-02,
          1.2987e-02, -1.0077e-02],
        [ 6.8728e-03,  1.0180e+00, -1.9002e-02,  ...,  4.0149e-04,
         -1.8119e-02,  1.3663e-03],
        [ 1.6030e-03, -1.9002e-02,  9.8187e-01,  ...,  1.0302e-02,
         -7.3258e-03,  9.5749e-03],
        ...,
        [ 2.9996e-02,  4.0149e-04,  1.0302e-02,  ...,  9.7388e-01,
         -4.3421e-03, -1.0112e-02],
        [ 1.2987e-02, -1.8119e-02, -7.3258e-03,  ..., -4.3421e-03,
          1.0155e+00, -2.4950e-02],
        [-1.0077e-02,  1.3663e-03,  9.5749e-03,  ..., -1.0112e-02,
         -2.4950e-02,  1.0105e+00]], grad_fn=<MmBackward0>)


In [54]:
C=(torch.tensor(get_C(N_nod,beta)).float()).to(device)
Ct=torch.t(C)
A_e=torch.matmul(Ct,torch.matmul(A,C))
print(A_e)
A_D=torch.matmul(Ot,torch.matmul(A,O))
print(A_D)
print(torch.linalg.matrix_norm(A_e-I))

tensor([[ 1.0000e+00,  5.8752e-08, -1.3914e-08,  ...,  2.2872e-08,
          2.5498e-07,  0.0000e+00],
        [ 4.0992e-08,  1.0000e+00, -2.7618e-08,  ...,  2.2800e-08,
         -2.4269e-08,  2.8902e-08],
        [-5.8108e-10, -9.3866e-08,  1.0000e+00,  ...,  4.4837e-09,
         -4.7684e-07, -5.5179e-08],
        ...,
        [-1.2623e-09,  5.3057e-09, -3.6523e-10,  ...,  1.0000e+00,
         -3.5937e-09,  1.9324e-08],
        [-7.7986e-10, -1.9018e-10,  1.1782e-10,  ...,  5.4074e-10,
          1.0000e+00, -2.6425e-09],
        [ 7.4506e-09, -1.3039e-08, -2.9977e-09,  ..., -1.6764e-08,
         -5.8208e-09,  1.0000e+00]])
tensor([[ 2.0000e-01, -7.7642e-09, -1.9571e-09,  ..., -1.1199e-07,
         -3.6891e-07,  0.0000e+00],
        [-6.7893e-10,  2.7885e-01, -2.3201e-08,  ...,  2.8581e-08,
          3.0639e-09,  3.8130e-08],
        [-4.9183e-10, -5.3233e-09,  2.7885e-01,  ..., -8.5684e-10,
         -9.5367e-07, -5.2030e-07],
        ...,
        [ 3.9386e-09, -8.6246e-11, -2.1778e-09