In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from models.args import *
from models.dataset import *
from models.discriminator import *
from models.generator import *
from models.inverter import *

from tqdm import tqdm
import warnings

In [2]:
def fxn():
    """Emit a single ``UserWarning`` whose message is "deprecated"."""
    warnings.warn("deprecated", category=UserWarning)

# NOTE(review): the "ignore" filter is scoped to this ``with`` block, so it only
# silences this one call to ``fxn``; warnings raised by later cells are unaffected.
with warnings.catch_warnings():
    warnings.simplefilter(action="ignore")
    fxn()

In [3]:
def choose_device():
    """Pick the torch device name to train on.

    Returns 'cuda' when CUDA is available, otherwise 'mps' when the MPS backend
    is available, otherwise 'cpu'.
    """
    if torch.cuda.is_available():
        return 'cuda'
    # Only consulted when CUDA is absent.
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'

In [9]:
def _critic_step(netD, netG, optimizerD, real, noise, one, mone,
                 clip_weights=False, clamp_lower=-0.1, clamp_upper=0.1):
    """Run one critic (discriminator) update; return (errD_real, errD_fake) as floats.

    ``noise`` is resampled in place to produce the fake batch.
    """
    if clip_weights:
        # Weight clipping: clamp critic parameters to a cube (WGAN Lipschitz constraint).
        with torch.no_grad():
            for p in netD.parameters():
                p.clamp_(clamp_lower, clamp_upper)
    netD.zero_grad()

    errD_real = netD(real)
    errD_real.backward(one)

    # The generator only supplies inputs to this update, so don't build its autograd
    # graph; otherwise this backward pass would also (wastefully) fill netG's grads.
    with torch.no_grad():
        fake = netG(noise.normal_(0, 1))
    errD_fake = netD(fake)
    errD_fake.backward(mone)

    # The two backward calls accumulate d(errD_real - errD_fake); the optimizer step
    # descends on it, i.e. pushes the critic to score real batches below fake ones.
    optimizerD.step()
    return errD_real.item(), errD_fake.item()


def _generator_step(netD, netG, G_optimizer_rnn, G_optimizer_output, noise, one):
    """Run one generator update against the current critic; return errG as a float.

    ``noise`` is resampled in place. Gradients also accumulate in netD's parameters
    here; ``_critic_step`` clears them with ``netD.zero_grad()`` before its backward.
    """
    G_optimizer_rnn.zero_grad()
    G_optimizer_output.zero_grad()
    fake = netG(noise.normal_(0, 1))
    errG = netD(fake)
    # Descend on the critic's score for fakes (the direction real scores are pushed).
    errG.backward(one)
    G_optimizer_rnn.step()
    G_optimizer_output.step()
    return errG.item()


def train(args, train_inverter=False, num_layers=4, clamp_lower=-0.1, clamp_upper=0.1, lr=1e-4, betas=1e-5, lamb=0.1, loss_func='MSE', device=choose_device(), clip_weights=False):
    """Train the GraphRNN generator against the NetD critic with a WGAN-style objective.

    Args:
        args: experiment configuration; ``graph_type``, ``hidden_size_rnn``, ``batch_size``
            and ``epochs`` are read here, and it is forwarded to the dataset/model helpers.
        train_inverter: currently unused (inverter training is not implemented yet).
        num_layers: currently unused.
        clamp_lower, clamp_upper: critic weight-clipping bounds, applied only when
            ``clip_weights`` is True.
        lr: learning rate for the inverter and critic Adam optimizers. The generator's
            optimizers come from ``netG.init_optimizer(lr=0.1)`` instead.
        betas: value used for both Adam beta coefficients of the critic optimizer.
        lamb, loss_func: forwarded to ``WGAN_ReconLoss``.
        device: device for models, noise and batches. The default is evaluated once,
            when this cell is executed.
        clip_weights: clamp critic weights to [clamp_lower, clamp_upper] before every
            critic update. Off by default (clipping had been disabled for testing).

    Returns:
        (dloss_lst, gloss_lst): per-epoch means of the critic loss
        (errD_real - errD_fake) and of the generator loss errG. An epoch in which every
        batch was skipped records NaN.

    Batches whose first dimension differs from ``args.batch_size`` are skipped, because
    the noise buffer has a fixed batch size.
    """
    dloss_lst = []
    gloss_lst = []

    # The entire dataset is used for training.
    train_graphs, labels = get_dataset_with_label(args.graph_type)
    train_dataset = Graph_sequence_sampler_pytorch(train_graphs, labels, args)
    train_loader = get_dataloader_labels(train_dataset, args)
    noise_dim = args.hidden_size_rnn
    print('noise dimension is: ', noise_dim)

    # Models live on `device` so they match the noise/batch tensors fed into them.
    netI = Inverter(input_dim=128, output_dim=args.hidden_size_rnn, hidden_dim=64).to(device)
    netG = GraphRNN(args=args).to(device)
    netD = NetD(stat_input_dim=128, stat_hidden_dim=64, num_stat=2).to(device)

    # TODO: the inverter (netI / optimizerI / lossI) and graph2vec are set up but not
    # trained yet; they are kept as placeholders for the train_inverter path.
    graph2vec = get_graph2vec(args.graph_type, dim=512)  # use infer() to generate new graph embedding
    optimizerI = optim.Adam(netI.parameters(), lr=lr)
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(betas, betas))
    lossI = WGAN_ReconLoss(device, lamb, loss_func)
    G_optimizer_rnn, G_optimizer_output, G_scheduler_rnn, G_scheduler_output = netG.init_optimizer(lr=0.1)

    noise = torch.randn(args.batch_size, noise_dim, device=device)
    one = torch.tensor(1, dtype=torch.float, device=device)
    mone = torch.tensor(-1, dtype=torch.float, device=device)

    # Critic updates per generator update. TODO: earlier experiments considered more
    # critic iterations early in training (e.g. 20 for the first 25 generator steps).
    Diters = 1

    for e in range(args.epochs):
        sum_errD, sum_errG, count_batch = 0.0, 0.0, 0
        progress = tqdm(train_loader, desc=f"Training epoch#{e+1}")
        for data in progress:
            adj_mat = data['adj_mat']
            if adj_mat.size(0) != args.batch_size:
                continue

            # ---------- critic update ----------
            netD.train(True)
            netG.train(False)
            for _ in range(Diters):
                # Fresh copy on `device`, so the critic never touches the loader's tensor.
                real = adj_mat.to(device, copy=True)
                errD_real, errD_fake = _critic_step(
                    netD, netG, optimizerD, real, noise, one, mone,
                    clip_weights, clamp_lower, clamp_upper)

            # ---------- generator update ----------
            netD.train(False)
            netG.train(True)
            errG = _generator_step(netD, netG, G_optimizer_rnn, G_optimizer_output, noise, one)

            # Losses from the last critic iteration of this batch.
            sum_errD += errD_real - errD_fake
            sum_errG += errG
            count_batch += 1
            progress.set_postfix(errD_real=errD_real, errD_fake=errD_fake, errG=errG)

        dloss_lst.append(sum_errD / count_batch if count_batch else float('nan'))
        gloss_lst.append(sum_errG / count_batch if count_batch else float('nan'))

    return dloss_lst, gloss_lst


In [5]:
# Experiment configuration built from Args()'s defaults; `train` reads graph_type,
# hidden_size_rnn, batch_size and epochs from it.
args = Args()

In [6]:
# Scratch check: torch.flip with dims=(1,) reverses the column order of every row.
# Built with torch.arange so the cell does not depend on `np` leaking in via a star import.
a = torch.zeros((5, 10))
a[0, :] = torch.arange(0, 10, dtype=torch.float)
a[1, :] = torch.arange(10, 20, dtype=torch.float)
torch.flip(a, dims=(1,))

tensor([[ 9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.,  0.],
        [19., 18., 17., 16., 15., 14., 13., 12., 11., 10.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])

In [7]:
# train, labels = get_dataset_with_label(args.graph_type) # entire dataset as train
# train_dataset = Graph_sequence_sampler_pytorch(train, labels, args)
# train_loader = get_dataloader_labels(train_dataset, args)
# e = 0
# for i, data in tqdm(enumerate(train_loader), desc=f"Training epoch#{e+1}", total=len(train_loader)):
#     X = data['x']
#     Y = data['y']

#     print(X.size(), Y.size())
#     # print(X)
#     print('-'*50)
#     # print(Y)
#     break

In [1]:
# Launch training. This needs the cells defining `choose_device`, `train` and `args` to
# have run in the current kernel; the saved NameError below (execution count 1) indicates
# this cell was executed on a fresh kernel without them.
train(args=args)

NameError: name 'train' is not defined