In [2]:
import argparse
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import random

# Pick the first GPU when one is visible, otherwise fall back to the CPU,
# and report which one was chosen.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('GPU' if device.type == 'cuda' else 'CPU')

CPU


In [3]:
def parse_args(argv=None):
  """Build the experiment configuration.

  Every option has a default, so calling with no flags reproduces the
  notebook's settings. Each option declares a ``type`` (and ``nargs`` for
  lists) so values supplied on a real command line arrive as numbers/lists
  instead of raw strings.

  Args:
    argv: optional list of argument strings. ``None`` (the default) parses
      ``sys.argv[1:]``, exactly as before.

  Returns:
    argparse.Namespace with one attribute per option below.
  """
  def str2bool(value):
    # type=bool would map every non-empty string (including "False") to True.
    if isinstance(value, bool):
      return value
    return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y')

  parser = argparse.ArgumentParser()

  # Synthetic data
  parser.add_argument('--directed', type=str2bool, default=True)
  parser.add_argument('--time_step', type=int, default=40)
  parser.add_argument('--CP_true', type=int, nargs='+', default=[10, 20, 30])

  # Model sizes
  parser.add_argument('--latent_dim', type=int, default=50)
  parser.add_argument('--w_dim', type=int, default=50)
  parser.add_argument('--num_node', type=int, default=30)

  parser.add_argument('--shared_layer', type=int, nargs='+', default=[128, 256, 512])
  parser.add_argument('--output_layer', type=int, nargs='+', default=[64, 128, 256])

  # Langevin posterior sampling
  parser.add_argument('--num_samples', type=int, default=100)
  parser.add_argument('--langevin_K', type=int, default=50)
  parser.add_argument('--langevin_s', type=float, default=0.5)

  # Optimisation
  parser.add_argument('--decoder_lr', type=float, default=0.01)
  parser.add_argument('--decay_rate', type=float, default=0.01)
  parser.add_argument('--penalty', type=float, default=10)
  parser.add_argument('--mu_lr', type=float, default=0.01)
  parser.add_argument('--epoch', type=int, default=100)
  parser.add_argument('-f', required=False)  # needed in Colab

  return parser.parse_args(argv)

###################
# Parse the configuration once; every option has a default, so no flags are
# required. The `-f` option (marked "needed in Colab") presumably absorbs the
# kernel launcher's own `-f` argument so parsing does not fail on it.
args = parse_args()

In [None]:
# Synthetic dynamic SBM with K communities. Edge probabilities alternate between
# P (within 0.5 / between 0.3) and Q (within 0.45 / between 0.2) at every true
# change point in args.CP_true, starting with P. At t = 0 and at each change
# point a fresh graph is drawn; between change points every entry is redrawn
# from the current matrix with persistence rho (rho = 0.0 -> independent snapshots).
# NOTE(review): args.directed is not used here; A is never symmetrised, so the
# snapshots are directed graphs.
torch.manual_seed(0)
np.random.seed(0)  # the generator below only uses torch; seeded for any downstream numpy/random use
random.seed(0)

rho = 0.0
n = args.num_node
num_time = args.time_step
K = 3
v = sorted(args.CP_true)
data = torch.zeros(num_time, n, n)
sum_holder = []


def sbm_prob_matrix(num_node, num_blocks, p_in, p_out):
    """Return a (num_node, num_node) SBM edge-probability matrix with zero diagonal.

    Nodes are split into `num_blocks` consecutive blocks of size
    num_node // num_blocks; the last block also takes the remainder.
    """
    M = torch.full((num_node, num_node), p_out)
    size = num_node // num_blocks
    for b in range(num_blocks):
        lo = b * size
        hi = num_node if b == num_blocks - 1 else (b + 1) * size
        M[lo:hi, lo:hi] = p_in
    torch.diagonal(M).zero_()  # no self-loops
    return M


P = sbm_prob_matrix(n, K, 0.5, 0.3)
Q = sbm_prob_matrix(n, K, 0.45, 0.2)

for t in range(num_time):
    # Even-numbered regimes (0, 2, ...) use P, odd-numbered ones use Q.
    regime = sum(cp <= t for cp in v)  # change points already reached
    M = P if regime % 2 == 0 else Q

    if t == 0 or t in v:
        A = torch.bernoulli(M)
    else:
        aux1 = torch.bernoulli((1 - M) * rho + M)  # 1 = keep an existing edge
        aux2 = torch.bernoulli(M * (1 - rho))      # 1 = create a missing edge
        A = aux1 * A + aux2 * (1 - A)

    torch.diagonal(A).zero_()
    data[t, :, :] = A.clone()
    sum_holder.append(torch.sum(A))

print(data.shape)

fig, ax = plt.subplots()
ax.plot(np.arange(num_time), sum_holder)
ax.set(title='Number of edges per snapshot', xlabel='time step t', ylabel='edge count')
plt.show()

In [None]:
# Optional sanity check: show the adjacency matrix just before and at each true
# change point (frames 9/10, 19/20, 29/30 for the default CP_true = [10, 20, 30]).
# Disabled by default to keep the notebook output short.
SHOW_CP_FRAMES = False


def cp_frame_indices(change_points, num_frames):
    """Return the sorted frame indices {cp - 1, cp} for every change point, restricted to [0, num_frames)."""
    frames = set()
    for cp in change_points:
        frames.update(f for f in (cp - 1, cp) if 0 <= f < num_frames)
    return sorted(frames)


if SHOW_CP_FRAMES:
    for frame in cp_frame_indices(args.CP_true, data.shape[0]):
        plt.imshow(data[frame, :, :].cpu().numpy(), cmap="Greys")
        plt.title(f"adjacency at t = {frame}")
        plt.show()

In [5]:
class CPD(nn.Module):
  """Latent-variable decoder mapping a latent code z to a matrix of edge probabilities.

  A shared layer `l1` feeds three two-layer branches that produce the factors
  W_left (num_node x w_dim), W_middle (w_dim x w_dim) and W_right
  (num_node x w_dim); the output is sigmoid(W_left @ W_middle @ W_right^T).

  Args:
    args: configuration namespace. Reads latent_dim, output_layer[0],
      output_layer[1], num_node and w_dim to build the layers, and
      langevin_K / langevin_s in `infer_z`. The namespace is stored by
      reference, so later edits to it are still seen by `infer_z`.
  """

  def __init__(self, args):
    super(CPD, self).__init__()
    # Use the config this module was built from instead of a global `args`.
    self.args = args

    self.l1 = nn.Linear( args.latent_dim, args.output_layer[0] )
    self.left1 = nn.Linear( args.output_layer[0], args.output_layer[1] )
    self.left2 = nn.Linear( args.output_layer[1], args.num_node * args.w_dim )
    self.middle1 = nn.Linear( args.output_layer[0], args.output_layer[1] )
    self.middle2 = nn.Linear( args.output_layer[1], args.w_dim * args.w_dim )
    self.right1 = nn.Linear( args.output_layer[0], args.output_layer[1] )
    self.right2 = nn.Linear( args.output_layer[1], args.num_node * args.w_dim )

  def forward(self, z):
    """Decode latent codes into edge probabilities.

    Args:
      z: (m, latent_dim) tensor of latent codes; m may be any batch size.

    Returns:
      (m, num_node, num_node) tensor of probabilities.
    """
    m = z.shape[0]  # batch size from the input, not from args.num_samples
    hidden = self.l1(z).tanh()
    w_left = self.left2(self.left1(hidden).tanh()).tanh()
    w_middle = self.middle2(self.middle1(hidden).tanh()).tanh()
    w_right = self.right2(self.right1(hidden).tanh()).tanh()

    w_left = w_left.reshape(m, self.args.num_node, self.args.w_dim)
    w_middle = w_middle.reshape(m, self.args.w_dim, self.args.w_dim)
    w_right = w_right.reshape(m, self.args.num_node, self.args.w_dim)
    # (m, n, w) @ (m, w, w) @ (m, w, n) -> (m, n, n)
    return torch.bmm(torch.bmm(w_left, w_middle), w_right.transpose(1, 2)).sigmoid()

  def infer_z(self, z, adj_gt_vec, mu_t):
    """Draw approximate posterior samples of z with Langevin dynamics.

    Runs args.langevin_K steps of
      z <- z - s * (grad_z NLL(z) + (z - mu_t)) + sqrt(2 s) * eps,   eps ~ N(0, I)
    with s = args.langevin_s. NLL is the summed binary cross-entropy between
    the decoded probabilities and the observed adjacency; (z - mu_t) is the
    gradient of the negative log-density of the N(mu_t, I) prior.

    Args:
      z: (m, latent_dim) initial samples.
      adj_gt_vec: flattened observed adjacency repeated m times (length m*n*n).
      mu_t: prior mean, (latent_dim,) or (m, latent_dim); broadcast against z.

    Returns:
      (m, latent_dim) detached samples.
    """
    criterion = nn.BCELoss(reduction='sum')
    step = self.args.langevin_s
    noise_scale = (2.0 * step) ** 0.5

    for _ in range(self.args.langevin_K):
      z = z.detach().clone().requires_grad_(True)

      adj_prob = self.forward(z)  # (m, n, n)
      nll = criterion(adj_prob.view(-1), adj_gt_vec)  # both flattened to m*n*n
      # Gradient w.r.t. z only; decoder parameters accumulate no .grad here.
      z_grad_nll = torch.autograd.grad(nll, z)[0]

      # Noise is drawn directly with z's shape, dtype and device.
      z = z - step * (z_grad_nll + (z - mu_t)) + noise_scale * torch.randn_like(z)

    return z.detach().clone()


In [6]:
# Observed snapshots go to the compute device; T is the number of time steps.
data = data.to(device)
T = data.shape[0]

# Decoder, its optimiser (weight decay via args.decay_rate is not enabled) and
# the summed BCE used by the training loop, which divides it by num_samples.
model = CPD(args).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.decoder_lr)
criterion = nn.BCELoss(reduction='sum')

# Latent means start at zero for decoder training (every prior is N(0, I)).
# A later cell re-initialises them randomly before the mu updates, because the
# fused penalty divides by ||mu[t+1] - mu[t]|| and identical rows give 0/0.
mu = torch.zeros(T, args.latent_dim, device=device)
mu_old = mu.clone()
loss_holder = []  # one decoder loss per epoch

model.train()

CPD(
  (l1): Linear(in_features=50, out_features=64, bias=True)
  (left1): Linear(in_features=64, out_features=128, bias=True)
  (left2): Linear(in_features=128, out_features=1500, bias=True)
  (middle1): Linear(in_features=64, out_features=128, bias=True)
  (middle2): Linear(in_features=128, out_features=2500, bias=True)
  (right1): Linear(in_features=64, out_features=128, bias=True)
  (right2): Linear(in_features=128, out_features=1500, bias=True)
)

In [None]:
# Decoder training with mu held fixed: each epoch, draw Langevin posterior
# samples of z for every snapshot t, accumulate the reconstruction loss over
# all t, then take a single optimiser step.
for learning_iter in range(args.epoch):

  loss = 0.0

  for t in range(T):
    mu_t = mu[t, :].clone()  # prior mean for snapshot t, shape (latent_dim,)

    adj_gt = data[t, :, :].clone()  # (n, n)
    adj_gt_vec = adj_gt.view(-1).repeat(args.num_samples)  # (m*n*n,), one copy per sample

    # Posterior samples z ~ p(z | A_t), Langevin chains started from N(0, I).
    init_z = torch.randn(args.num_samples, args.latent_dim).to(device)
    sampled_z = model.infer_z(init_z, adj_gt_vec, mu_t)  # (m, latent_dim)

    adj_prob = model(sampled_z)  # (m, n, n)
    # Summed BCE averaged over the m samples (Monte Carlo estimate of the expectation).
    loss += criterion(adj_prob.view(-1), adj_gt_vec) / args.num_samples

  loss_holder.append(loss.detach().cpu().numpy())

  # Update the decoder once, after all time steps.
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  if (learning_iter + 1) % 10 == 0:
    print('\n')
    print('learning iter =', learning_iter)
    print('decoder loss =', loss.item())

In [15]:
# Re-initialise the latent means randomly: the fused penalty in the mu update
# divides by ||mu[t+1] - mu[t]||, so rows must not start identical (as zeros would).
mu = torch.randn(T, args.latent_dim).to(device)
# Reset the convergence reference too; otherwise it still holds the all-zero mu
# from the setup cell and the first "relative difference" divides by a zero norm.
mu_old = mu.detach().clone()
print(mu)

tensor([[ 0.7137,  0.9684,  0.2790,  ..., -1.1308, -0.6657,  1.1658],
        [ 0.1002, -1.1241, -0.6680,  ...,  0.1527,  0.6374, -0.0591],
        [ 1.1298, -0.0846, -2.2610,  ...,  0.4736,  0.5762,  0.8806],
        ...,
        [-0.0831, -0.9283, -0.5310,  ..., -0.4127, -0.3400, -0.6185],
        [ 0.6562,  0.3045, -0.0962,  ...,  0.8002, -0.9390, -1.0531],
        [-1.2125, -0.5826, -0.7662,  ..., -1.0184, -0.5725,  1.1501]],
       device='cuda:0')


In [None]:

def fused_penalty_grad(mu, t, penalty, eps=1e-12):
  """Subgradient w.r.t. mu[t] of penalty * sum_s ||mu[s+1] - mu[s]||_2.

  Only the (at most two) terms that involve mu[t] contribute, so T == 1 gives
  zero. Norms are clamped at `eps`, so identical neighbouring rows contribute
  zero instead of 0/0 = NaN.
  """
  grad = torch.zeros_like(mu[t])
  if t + 1 < mu.shape[0]:
    forward_diff = mu[t + 1] - mu[t]
    grad = grad - penalty * forward_diff / torch.norm(forward_diff, p=2).clamp_min(eps)
  if t > 0:
    backward_diff = mu[t] - mu[t - 1]
    grad = grad + penalty * backward_diff / torch.norm(backward_diff, p=2).clamp_min(eps)
  return grad


# Update the latent means with the decoder fixed: one (sub)gradient step per
# snapshot, sweeping t = 0..T-1 so each step sees already-updated neighbours.
num_mu_iters = 100

for learning_iter in range(num_mu_iters):

  for t in range(T):
    mu_t = mu[t, :].clone()  # (latent_dim,)

    adj_gt = data[t, :, :].clone()
    adj_gt_vec = adj_gt.view(-1).repeat(args.num_samples)

    # Posterior samples of z for snapshot t under the current prior mean.
    init_z = torch.randn(args.num_samples, args.latent_dim).to(device)
    sampled_z = model.infer_z(init_z, adj_gt_vec, mu_t)

    # Gaussian prior term averaged over samples, plus the fused (group) penalty.
    grad_mu_t = -(sampled_z - mu_t).mean(dim=0) + fused_penalty_grad(mu, t, args.penalty)
    mu[t, :] -= args.mu_lr * grad_mu_t  # gradient descent

  if (learning_iter + 1) % 10 == 0:
    print('\n')
    print('learning iter =', learning_iter)
    print('mu residual =', torch.mean((mu - mu_old) ** 2).item())
    print('mu relative difference =',
          (torch.norm(mu - mu_old, p='fro') / torch.norm(mu_old, p='fro')).item())
    mu_old = mu.detach().clone()

    signal = (torch.norm(mu, p=2, dim=1) ** 2).detach().cpu().numpy()
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(signal)), signal)
    ax.set(title=f'Squared norm of latent means after iteration {learning_iter}',
           xlabel='time step t', ylabel='squared norm')
    plt.show()
    plt.close(fig)

In [None]:
# Final change-point signal: squared L2 norm of each latent mean, ||mu_t||^2,
# with the true change points marked for comparison.
signal = (torch.norm(mu, p=2, dim=1) ** 2).detach().cpu().numpy()
fig, ax = plt.subplots()
ax.plot(np.arange(len(signal)), signal)
for cp in args.CP_true:
  ax.axvline(cp, color='grey', linestyle='--', linewidth=0.8)
ax.set(title='Squared norm of latent means (dashed: true change points)',
       xlabel='time step t', ylabel='squared norm')
plt.show()

In [None]:
# NOTE(review): `main` is not defined in the cells visible here; training and the
# mu updates run inline above. Kept as a reminder of a wrapper returning (mu, loss_holder).
#mu, loss_holder = main(args, data)

In [None]:
# Alternative change-point signal (not plotted): half the squared norm of
# consecutive differences, 0.5 * ||mu[t+1] - mu[t]||_2^2 for t = 0..T-2.
# torch.diff takes row t+1 minus row t; its sign is irrelevant after squaring:
#   signal = 0.5 * torch.norm(torch.diff(mu, dim=0), p=2, dim=1) ** 2   # length T-1

# Decoder training loss per epoch, as recorded in loss_holder by the training loop.
fig, ax = plt.subplots()
ax.plot(np.arange(len(loss_holder)), loss_holder)
ax.set(title='Decoder training loss', xlabel='epoch', ylabel='summed BCE / num_samples')
plt.show()