In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F

from models.dataset import *
from models.discriminator import *
from models.generator import *
from models.inverter import *
from gw_loss import *

from models.GAM.src.param_parser import *
from models.GAM.src.gam import *

from tqdm import tqdm
import warnings

from ot.gromov import gromov_wasserstein
from models.args import *

In [2]:
def fxn():
    warnings.warn("deprecated", UserWarning)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fxn()

def choose_device():
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    else:
        return 'cpu'

In [4]:
args = Args()
train, labels = get_dataset_with_label(args.graph_type) # entire dataset as train
train_dataset = Graph_sequence_sampler_pytorch(train, labels, args)
train_loader, adj_shape = get_dataloader_labels(train_dataset, args)
noise_dim = args.hidden_size_rnn

GAMachineTrainer = GAMTrainer(args, args.graph_type) # maps from graphs to latent space (of embeddings)
gam_optimizer = torch.optim.Adam(GAMachineTrainer.model.parameters(),
                                lr=args.learning_rate,
                                weight_decay=args.weight_decay)
netG = GraphRNN(args=args)
optimizer_trainer, G_optimizer_output, G_scheduler_rnn, G_scheduler_output = netG.init_optimizer(lr=0.1)
optimizer_generator = None

all_data = []
for d in train_loader:
    all_data.append(d)

calculating max previous node, total iteration: 12000
iter 0 times
iter 2400 times
iter 4800 times
iter 7200 times
iter 9600 times
max previous node: 10


In [21]:
loss = nn.MSELoss()
simple = nn.Linear(29, 29)
optimizer = optim.SGD(simple.parameters(), lr=0.01, momentum=0.9)

In [39]:
def trainAE(args, data):
    # TODO: figure out what is "data"

    ### from Barry's code
    X = data['x']
    Y = data['y']
    adj_mat = data['adj_mat']
    label = data['label']
    Y_len = data['len']
    ###

    ###############################################################################################
    # Embedding (GAMTrainer) Update
    ###############################################################################################
    
    GAMachineTrainer.model.train()
    netG.eval()
    for e in range(args.embedding_iters):
        noise = torch.randn(args.batch_size, args.hidden_size_rnn)
        fake_graphs = []
        # for b in range(args.batch_size):
        fake_graph = netG(noise, X, Y, Y_len) # noise[b, :]
        fake_graphs = fake_graph

        # print(fake_graph.size())
        # print(len(fake_graphs))
        # TODO: unpad (a.k.a. pack) fake_graphs

        true_graphs = adj_mat # TODO: sample from true dataset, e.g. MUTAG
        # print(true_graphs.size())

        # print(adj_mat[0])

        # from process_batch
        optimizer_trainer.zero_grad()
        batch_loss = 0
        counter = 0
        for adj in true_graphs:
            batch_loss = GAMachineTrainer.process_graph(
                batch_loss=batch_loss, 
                already_matrix=True, 
                adj=adj[0:Y_len[counter], 0:Y_len[counter]], target=0 # may be worth passing classes instead of generic "true" label
            )
            counter += 1
        counter = 0
        for adj in fake_graphs:
            # print(adj)
            batch_loss = GAMachineTrainer.process_graph(
                batch_loss=batch_loss, 
                already_matrix=True, 
                adj=adj[0:Y_len[counter], 0:Y_len[counter]], target=1
            )
            counter += 1
        batch_loss.backward(retain_graph=True)
        # print(batch_loss)
        gam_optimizer.step()
        
    # ###############################################################################################
    # # Generator (GraphRNN) Update
    # ###############################################################################################

    GAMachineTrainer.model.eval()
    netG.train(True)
    optimizer.zero_grad()

    ### generate fake graphs from noise
    noise = torch.randn(args.batch_size, args.hidden_size_rnn)
    fake_graphs = []
    # for b in range(args.batch_size):
    fake_graph = netG(noise, X, Y, Y_len)
    fake_graphs = fake_graph
    # fake_graphs.append(fake_graph) # 1x8x29x29

    # print(len(fake_graphs))
    # print(fake_graphs[0].size())
    # # TODO: unpad (a.k.a. pack) fake_graphs

    ### compute embeddings of fake_graphs, then pass embeddings to generator for reconstruction
    # recon_fake_graphs = []
    tem_embeddings = []
    for fake_adj in fake_graphs:
        datadict, features, node = GAMachineTrainer.get_datadict_features_node(fake_adj, target=1) # target unimportant
        fake_embedding = GAMachineTrainer.model(
            datadict, fake_adj, features, node, get_embedding=True
        )
        
        fake_embedding = fake_embedding.view(-1)
        tem_embeddings.append(fake_embedding)
    fake_embeddings = torch.stack(tem_embeddings)
    recon_fake_graphs = netG(fake_embeddings, X, Y, Y_len)
    # recon_fake_graphs.append(recon_graph)

    ### do same for true graphs
    true_graphs = adj_mat 
    
    # compute embeddings of true_graphs, then pass embeddings to generator for reconstruction
    recon_true_graphs = []
    true_embnd = []
    for true_adj in true_graphs:
        datadict, features, node = GAMachineTrainer.get_datadict_features_node(adj, target=0)
        true_adj = true_adj + torch.eye(true_adj.size(0), true_adj.size(1))
        true_embedding = GAMachineTrainer.model(
            datadict, true_adj, features, node, get_embedding=True
        )
        true_embedding = true_embedding.view(-1)
        true_embnd.append(true_embedding)
    ftrue_embeddings = torch.stack(true_embnd)
    recon_true_graphs = netG(ftrue_embeddings, X, Y, Y_len)
    # recon_true_graphs = simple(true_graphs.float()) # TODO
    # recon_true_graphs.append(recon_graph)
    # print(recon_true_graphs.size())

    # output = loss(true_graphs.float(), recon_true_graphs) # TODO
    # output.backward() # TODO
    # optimizer.step() # TODO

    batch_gw_loss = 0
    # compute GW distance on true_graphs/fake_graphs and recon_true_graphs/recon_fake_graphs
    for adj_o, adj_r in zip(true_graphs, recon_true_graphs):
        batch_gw_loss += GWLoss(adj_o, adj_r)
    for adj_o, adj_r in zip(fake_graphs, recon_fake_graphs):
        batch_gw_loss += GWLoss(adj_o, adj_r)
    batch_gw_loss = torch.tensor(batch_gw_loss, requires_grad=True)
    batch_gw_loss.backward()
    optimizer_trainer.step()
    G_optimizer_output.step()

    # print(list(netG.parameters()), list(GAMachineTrainer.model.parameters())[0])
    print(output.item())
    


In [19]:
# loss = nn.MSELoss()
# simple = nn.Linear(2, 2)
# optimizer = optim.SGD(simple.parameters(), lr=0.01, momentum=0.9)
# input = torch.randn(2, 2, 2, requires_grad=True)
# target = torch.randn(2, 2, 2)
# for _ in range(5):
#     optimizer.zero_grad()
#     pred = simple(input)
#     output = loss(pred, target)
#     output.backward()
#     optimizer.step()
#     print(output)


In [41]:
for d in all_data:
    trainAE(args, all_data[0])
# trainAE(args, all_data[0])

0.06757685542106628
0.06752537935972214
0.0674520954489708
0.06735934317111969
0.06724926829338074
0.06712380051612854
0.06698474287986755


KeyboardInterrupt: 

In [None]:
# a = torch.from_numpy(np.array([[[1,0], [0,0]], [[0,0], [0,0]], [[0,0], [0,0]]]))
# i = torch.eye(2,2)
# i = torch.unsqueeze(i, 0)
# addition = i.expand(3, 2, 2)  
# a + addition