In [1]:
import torch
from schnetpack.datasets import QM9
import schnetpack as spk
import os
from my_config import config_args
from Model.HGDM import HyperbolicAE,HyperbolicDiffusion
import optimizers
import numpy as np
import logging
import time
from tqdm import tqdm
from Model import Encoders, Decoders

In [2]:
qm9data = QM9('./data/qm9.db', download=True,load_only=[QM9.U0])
qm9split = './data/qm9split'
print(len(qm9data))

train, val, test = spk.train_test_split(
        data=qm9data,
        num_train=20000,
        num_val=10000,
        split_file=os.path.join(qm9split, "split20000-10000.npz"),
    )
print(len(train),len(val),len(test))

train_loader = spk.AtomsLoader(train, batch_size=256, shuffle=False)
val_loader = spk.AtomsLoader(val, batch_size=256)

133885
20000 10000 103885


In [3]:
import json
class obj(object):
    def __init__(self, dict_):
        self.__dict__.update(dict_)
args = json.loads(json.dumps(config_args), object_hook=obj)

In [4]:
encoder_path = './saved_model/'+args.model+'-encoder_KL.pt'
decoder_path = './saved_model/'+args.model+'-decoder_KL.pt'

encoder = getattr(Encoders, args.model)(args)
encoder.load_state_dict(torch.load(encoder_path))
if args.model == 'MLP':
    decoder = Decoders.model2decoder[args.model](None, args)
else:
    decoder = Decoders.model2decoder[args.model](encoder.curvatures, args)
decoder.load_state_dict(torch.load(decoder_path))

device = torch.device('cuda')
model = HyperbolicDiffusion(args,encoder,decoder)
model = model.to(device)

In [5]:
optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(), lr=args.lr,
                                                    weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_reduce_freq,
    gamma=float(args.gamma)
)
tot_params = sum([np.prod(p.size()) for p in model.parameters()])
print(f"Total number of parameters: {tot_params}")

# Train model
t_total = time.time()

Total number of parameters: 1625833


In [6]:

for epoch in range(args.epochs):
    model.train()
    counter = 0
    loss_sum,n = 0.0,0
    for input in tqdm(train_loader):
        for key in input:
            input[key] = input[key].to(torch.device('cuda'))
        t = time.time()
        model.train()
        optimizer.zero_grad()
        loss = model(input).sum()
        print(loss,lr_scheduler.get_last_lr())
        loss.backward()
        loss_sum += loss
        n += 1
        # if args.grad_clip is not None:
        #     grad_clip = float(args.grad_clip)
        #     all_params = list(model.parameters())
        #     for param in all_params:
        #         torch.nn.utils.clip_grad_value_(param, grad_clip)
        optimizer.step()
        lr_scheduler.step()
    if (epoch + 1) % args.log_freq == 0:
        str = " ".join(['Epoch: {:04d}'.format(epoch + 1),
                               'lr: {}'.format(lr_scheduler.get_last_lr()[0]),
                               'loss: {:.4f}'.format(loss_sum/n),
                               'time: {:.4f}s'.format(time.time() - t)
                               ])
        print(str)

    model.eval()
    with torch.no_grad():
        loss_sum,n = 0.0,0
        for input in tqdm(val_loader):
            for key in input:
                input[key] = input[key].to(torch.device('cuda'))
            t = time.time()
            model.train()
            optimizer.zero_grad()
            loss = model(input).sum()
            n += 1
            loss_sum += loss

        print('val_loss:',loss_sum.item()/n)


  1%|▏         | 1/79 [00:02<02:38,  2.03s/it]

tensor([0.0000, 0.0000, 3.3260, 0.5211, 1.7913, 0.0000, 2.1762, 3.1931, 3.4874,
        1.9030, 1.9651, 0.0000, 0.0000, 0.2339, 2.6870, 1.8458, 0.0000, 1.6993,
        4.9039, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(40.4763, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  3%|▎         | 2/79 [00:02<01:19,  1.03s/it]

tensor([0.0000, 0.0000, 1.0696, 0.0000, 1.7869, 0.0000, 2.7751, 0.8385, 4.9359,
        0.4971, 2.7933, 0.0000, 0.0000, 1.1113, 2.8633, 1.9674, 0.0000, 0.6262,
        2.1722, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(13.8018, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  4%|▍         | 3/79 [00:02<00:54,  1.39it/s]

tensor([0.0000, 0.0000, 1.0538, 1.0707, 0.6570, 0.5592, 0.5998, 4.7808, 1.9529,
        0.1118, 0.4543, 0.0000, 0.0000, 0.1517, 3.1344, 0.5669, 0.0000, 0.0000,
        4.7344, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(5.0007, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  5%|▌         | 4/79 [00:03<00:42,  1.75it/s]

tensor([0.0000, 0.0000, 2.3951, 0.0000, 2.5513, 0.0000, 2.4473, 1.9643, 5.2811,
        1.1091, 1.7846, 0.0000, 0.0000, 1.2789, 2.8809, 1.9397, 0.0000, 0.9752,
        3.8050, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(4.5446, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  6%|▋         | 5/79 [00:03<00:36,  2.01it/s]

tensor([0.0000, 0.0000, 3.3396, 0.8473, 1.6250, 0.0000, 2.0978, 2.9876, 3.3698,
        2.0298, 2.1048, 0.0000, 0.0000, 0.1945, 2.6885, 1.5392, 0.0000, 1.5029,
        4.9258, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(5.0921, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  8%|▊         | 6/79 [00:03<00:33,  2.21it/s]

tensor([0.0000, 0.0000, 1.6706, 0.0000, 1.0919, 0.0000, 2.8188, 0.4794, 4.9429,
        1.3007, 3.0138, 0.0000, 0.0000, 0.8803, 3.0998, 2.1155, 0.0000, 0.9583,
        2.8535, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(5.4866, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  9%|▉         | 7/79 [00:04<00:30,  2.38it/s]

tensor([0.0000, 0.0000, 0.0000, 0.7141, 0.4475, 1.4526, 0.4679, 4.0629, 1.6529,
        0.3198, 0.0000, 0.0000, 0.0000, 0.0000, 3.1053, 1.4336, 0.0000, 0.0000,
        3.5206, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(5.3301, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


 10%|█         | 8/79 [00:04<00:28,  2.51it/s]

tensor([0.0000, 0.0000, 0.0538, 1.4221, 0.0000, 0.4543, 0.3918, 3.5832, 0.7628,
        0.0918, 0.3938, 0.0000, 0.0000, 0.0000, 2.8702, 0.2550, 0.4378, 0.0000,
        3.6938, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(4.7772, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


 11%|█▏        | 9/79 [00:04<00:26,  2.60it/s]

tensor([0.0000, 0.0000, 0.8162, 0.9349, 0.8880, 0.8859, 0.3299, 4.7422, 1.8456,
        0.2464, 0.2688, 0.0000, 0.0000, 0.0000, 3.1344, 1.0137, 0.0000, 0.0000,
        4.5246, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(3.4132, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


 13%|█▎        | 10/79 [00:05<00:26,  2.58it/s]

tensor([0.0000, 0.0000, 0.7198, 0.9155, 0.9242, 0.9705, 0.3900, 4.7907, 1.9432,
        0.2514, 0.2320, 0.0000, 0.0000, 0.0000, 3.1671, 1.1742, 0.0000, 0.0000,
        4.3844, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(2.7832, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


 14%|█▍        | 11/79 [00:05<00:25,  2.62it/s]

tensor([0.0000, 0.0000, 0.5525, 0.8855, 0.8983, 0.9569, 0.7786, 4.7215, 2.2216,
        0.3505, 0.1988, 0.0000, 0.0000, 0.1396, 3.2342, 1.4674, 0.0000, 0.0000,
        3.9591, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(2.0604, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 15%|█▌        | 12/79 [00:05<00:25,  2.66it/s]

tensor([0.0000, 0.0000, 0.5082, 0.8597, 0.8677, 1.1051, 0.4526, 4.6616, 1.9483,
        0.2941, 0.1425, 0.0000, 0.0000, 0.0000, 3.1772, 1.3472, 0.0000, 0.0000,
        4.1146, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.6398, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 16%|█▋        | 13/79 [00:06<00:25,  2.61it/s]

tensor([0.0000, 0.0000, 2.2893, 0.0000, 2.4903, 0.0000, 2.3443, 1.9023, 5.2340,
        1.0442, 1.7634, 0.0000, 0.0000, 1.1421, 2.8834, 2.0162, 0.0000, 1.1074,
        3.7626, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.5178, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 18%|█▊        | 14/79 [00:06<00:25,  2.57it/s]

tensor([0.0000, 0.0000, 2.6696, 0.0000, 2.6673, 0.0000, 2.6895, 2.1020, 5.3464,
        1.3900, 1.8162, 0.0000, 0.0000, 1.6276, 2.8777, 1.7644, 0.0000, 0.5997,
        3.8547, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.4045, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 19%|█▉        | 15/79 [00:07<00:24,  2.59it/s]

tensor([0.0000, 0.0000, 0.0000, 0.7520, 0.3665, 1.4815, 0.5033, 4.0274, 1.6151,
        0.3340, 0.0000, 0.0000, 0.0000, 0.0000, 3.1357, 1.5612, 0.0573, 0.0000,
        3.3092, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.5132, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 20%|██        | 16/79 [00:07<00:24,  2.57it/s]

tensor([0.0000, 0.0000, 2.4154, 0.0000, 2.4644, 0.0000, 2.7066, 1.9847, 5.3155,
        1.2334, 2.3564, 0.0000, 0.0000, 1.6420, 3.0613, 1.8031, 0.0000, 0.5612,
        3.6372, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.5744, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 22%|██▏       | 17/79 [00:08<00:25,  2.45it/s]

tensor([0.0000, 0.0000, 1.4498, 0.0000, 0.7213, 0.0000, 2.7070, 0.3295, 4.8489,
        1.3746, 2.9432, 0.0000, 0.0000, 0.8134, 3.1736, 2.0756, 0.0000, 0.8127,
        2.7675, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.6969, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 23%|██▎       | 18/79 [00:08<00:25,  2.37it/s]

tensor([0.0000, 0.0000, 2.2452, 0.0000, 2.6399, 0.0000, 2.4683, 2.1274, 5.3211,
        0.8382, 2.1393, 0.0000, 0.0000, 1.3649, 2.9813, 1.9103, 0.0000, 0.9536,
        3.7119, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.6243, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 24%|██▍       | 19/79 [00:08<00:25,  2.32it/s]

tensor([0.0000, 0.0000, 3.1778, 0.9098, 1.4992, 0.0000, 2.2469, 3.0001, 3.2775,
        1.8862, 2.2563, 0.0000, 0.0000, 0.0425, 2.8755, 1.9457, 0.0000, 1.6676,
        4.8941, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.6200, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 25%|██▌       | 20/79 [00:09<00:25,  2.29it/s]

tensor([0.0000, 0.0000, 0.5449, 0.8753, 0.9419, 1.0315, 0.6219, 4.7392, 2.1269,
        0.3386, 0.1757, 0.0000, 0.0000, 0.0643, 3.2174, 1.4695, 0.0000, 0.0000,
        4.0240, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.5592, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 27%|██▋       | 21/79 [00:09<00:25,  2.29it/s]

tensor([0.0000, 0.0000, 0.0000, 0.6701, 0.7809, 1.3650, 0.8101, 4.4013, 2.1835,
        0.4228, 0.0000, 0.0000, 0.0000, 0.0058, 3.2269, 1.9124, 0.0000, 0.0000,
        3.2943, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.3830, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 28%|██▊       | 22/79 [00:10<00:25,  2.28it/s]

tensor([0.0000, 0.0000, 0.3188, 0.7562, 0.7651, 1.2653, 0.2504, 4.3798, 1.6730,
        0.3379, 0.0174, 0.0000, 0.0000, 0.0000, 3.1098, 1.3506, 0.0000, 0.0000,
        4.0023, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.2767, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 29%|██▉       | 23/79 [00:10<00:24,  2.27it/s]

tensor([0.0000, 0.0000, 3.1250, 0.3075, 1.0420, 0.0000, 2.6797, 1.7363, 3.4382,
        2.0051, 1.9625, 0.0000, 0.0000, 0.0000, 2.6718, 2.0939, 0.0000, 1.5993,
        4.6648, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.2232, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 30%|███       | 24/79 [00:11<00:24,  2.26it/s]

tensor([0.0000, 0.0000, 0.5302, 0.9593, 0.6913, 1.0002, 0.5685, 4.5233, 1.8843,
        0.2840, 0.2117, 0.0000, 0.0000, 0.0000, 3.1614, 1.1514, 0.0000, 0.0000,
        4.1048, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.1485, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 30%|███       | 24/79 [00:11<00:26,  2.07it/s]

tensor([0.0000, 0.0000, 0.1373, 0.7449, 0.6431, 1.3953, 0.2541, 4.2558, 1.6054,
        0.3159, 0.0000, 0.0000, 0.0000, 0.0000, 3.1002, 1.3779, 0.0000, 0.0000,
        3.8338, 0.0000], device='cuda:0', grad_fn=<SelectBackward0>)
tensor(1.0969, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]





KeyboardInterrupt: 