In [1]:
import torch
from schnetpack.datasets import QM9
import schnetpack as spk
import os
from my_config import config_args
from Model.HGDM import HyperbolicAE,HyperbolicDiffusion
import optimizers
import numpy as np
import logging
import time
from tqdm import tqdm
from Model import Encoders, Decoders

In [2]:
qm9data = QM9('./data/qm9.db', download=True,load_only=[QM9.U0])
qm9split = './data/qm9split'
print(len(qm9data))

train, val, test = spk.train_test_split(
        data=qm9data,
        num_train=20000,
        num_val=10000,
        split_file=os.path.join(qm9split, "split20000-10000.npz"),
    )
print(len(train),len(val),len(test))

train_loader = spk.AtomsLoader(train, batch_size=256, shuffle=False)
val_loader = spk.AtomsLoader(val, batch_size=256)

133885
20000 10000 103885


In [3]:
import json
class obj(object):
    def __init__(self, dict_):
        self.__dict__.update(dict_)
args = json.loads(json.dumps(config_args), object_hook=obj)

In [4]:
encoder_path = './saved_model/'+args.model+'-encoder_KL.pt'
decoder_path = './saved_model/'+args.model+'-decoder_KL.pt'

encoder = getattr(Encoders, args.model)(args)
encoder.load_state_dict(torch.load(encoder_path))
if args.model == 'MLP':
    decoder = Decoders.model2decoder[args.model](None, args)
else:
    decoder = Decoders.model2decoder[args.model](encoder.curvatures, args)
decoder.load_state_dict(torch.load(decoder_path))

device = torch.device('cuda')
model = HyperbolicDiffusion(args,encoder,decoder)
model = model.to(device)

In [5]:
optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(), lr=args.lr,
                                                    weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_reduce_freq,
    gamma=float(args.gamma)
)
tot_params = sum([np.prod(p.size()) for p in model.parameters()])
print(f"Total number of parameters: {tot_params}")

# Train model
t_total = time.time()

Total number of parameters: 1625684


In [6]:

for epoch in range(args.epochs):
    model.train()
    counter = 0
    loss_sum,n = 0.0,0
    for input in tqdm(train_loader):
        for key in input:
            input[key] = input[key].to(torch.device('cuda'))
        t = time.time()
        model.train()
        optimizer.zero_grad()
        loss = model(input).sum()
        print(loss,lr_scheduler.get_last_lr())
        loss.backward()
        loss_sum += loss
        n += 1
        # if args.grad_clip is not None:
        #     grad_clip = float(args.grad_clip)
        #     all_params = list(model.parameters())
        #     for param in all_params:
        #         torch.nn.utils.clip_grad_value_(param, grad_clip)
        optimizer.step()
        lr_scheduler.step()
    if (epoch + 1) % args.log_freq == 0:
        str = " ".join(['Epoch: {:04d}'.format(epoch + 1),
                               'lr: {}'.format(lr_scheduler.get_last_lr()[0]),
                               'loss: {:.4f}'.format(loss_sum/n),
                               'time: {:.4f}s'.format(time.time() - t)
                               ])
        print(str)

    model.eval()
    with torch.no_grad():
        loss_sum,n = 0.0,0
        for input in tqdm(val_loader):
            for key in input:
                input[key] = input[key].to(torch.device('cuda'))
            t = time.time()
            model.train()
            optimizer.zero_grad()
            loss = model(input).sum()
            n += 1
            loss_sum += loss

        print('val_loss:',loss_sum.item()/n)


  1%|▏         | 1/79 [00:01<02:30,  1.93s/it]

tensor(31.8704, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  3%|▎         | 2/79 [00:02<01:15,  1.03it/s]

tensor(7.0160, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  4%|▍         | 3/79 [00:02<00:51,  1.49it/s]

tensor(3.0742, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  5%|▌         | 4/79 [00:02<00:40,  1.87it/s]

tensor(5.5457, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  6%|▋         | 5/79 [00:03<00:33,  2.19it/s]

tensor(5.6615, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  8%|▊         | 6/79 [00:03<00:30,  2.43it/s]

tensor(5.3542, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


  9%|▉         | 7/79 [00:03<00:27,  2.59it/s]

tensor(3.4763, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


 10%|█         | 8/79 [00:04<00:26,  2.68it/s]

tensor(2.2886, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


 11%|█▏        | 9/79 [00:04<00:25,  2.74it/s]

tensor(1.7962, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


 13%|█▎        | 10/79 [00:04<00:24,  2.83it/s]

tensor(1.7880, device='cuda:0', grad_fn=<SumBackward0>) [0.0001]


 14%|█▍        | 11/79 [00:05<00:23,  2.90it/s]

tensor(2.1700, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 15%|█▌        | 12/79 [00:05<00:22,  2.93it/s]

tensor(2.3370, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 16%|█▋        | 13/79 [00:05<00:21,  3.02it/s]

tensor(2.1420, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 18%|█▊        | 14/79 [00:06<00:21,  3.02it/s]

tensor(1.8114, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 19%|█▉        | 15/79 [00:06<00:21,  2.98it/s]

tensor(1.5388, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 20%|██        | 16/79 [00:06<00:20,  3.00it/s]

tensor(1.2613, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 22%|██▏       | 17/79 [00:07<00:20,  2.97it/s]

tensor(1.1888, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 23%|██▎       | 18/79 [00:07<00:20,  2.95it/s]

tensor(1.2496, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 24%|██▍       | 19/79 [00:07<00:20,  2.93it/s]

tensor(1.4043, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 25%|██▌       | 20/79 [00:08<00:20,  2.91it/s]

tensor(1.4084, device='cuda:0', grad_fn=<SumBackward0>) [8e-05]


 27%|██▋       | 21/79 [00:08<00:19,  2.90it/s]

tensor(1.4203, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 28%|██▊       | 22/79 [00:08<00:19,  2.87it/s]

tensor(1.3507, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 29%|██▉       | 23/79 [00:09<00:19,  2.84it/s]

tensor(1.2299, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 30%|███       | 24/79 [00:09<00:19,  2.82it/s]

tensor(1.1400, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 32%|███▏      | 25/79 [00:10<00:19,  2.77it/s]

tensor(1.0963, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 33%|███▎      | 26/79 [00:10<00:19,  2.79it/s]

tensor(1.0775, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 34%|███▍      | 27/79 [00:10<00:18,  2.74it/s]

tensor(1.0905, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 35%|███▌      | 28/79 [00:11<00:18,  2.72it/s]

tensor(1.1246, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 37%|███▋      | 29/79 [00:11<00:18,  2.76it/s]

tensor(1.1327, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 38%|███▊      | 30/79 [00:11<00:17,  2.79it/s]

tensor(1.1380, device='cuda:0', grad_fn=<SumBackward0>) [6.400000000000001e-05]


 39%|███▉      | 31/79 [00:12<00:17,  2.81it/s]

tensor(1.1223, device='cuda:0', grad_fn=<SumBackward0>) [5.120000000000001e-05]


 41%|████      | 32/79 [00:12<00:16,  2.83it/s]

tensor(1.1042, device='cuda:0', grad_fn=<SumBackward0>) [5.120000000000001e-05]


 42%|████▏     | 33/79 [00:12<00:16,  2.84it/s]

tensor(1.0707, device='cuda:0', grad_fn=<SumBackward0>) [5.120000000000001e-05]


 43%|████▎     | 34/79 [00:13<00:15,  2.84it/s]

tensor(1.0409, device='cuda:0', grad_fn=<SumBackward0>) [5.120000000000001e-05]


 44%|████▍     | 35/79 [00:13<00:15,  2.80it/s]

tensor(1.0250, device='cuda:0', grad_fn=<SumBackward0>) [5.120000000000001e-05]


 46%|████▌     | 36/79 [00:13<00:15,  2.76it/s]

tensor(1.0275, device='cuda:0', grad_fn=<SumBackward0>) [5.120000000000001e-05]


 47%|████▋     | 37/79 [00:14<00:15,  2.78it/s]

tensor(1.0401, device='cuda:0', grad_fn=<SumBackward0>) [5.120000000000001e-05]


 48%|████▊     | 38/79 [00:14<00:14,  2.77it/s]

tensor(1.0433, device='cuda:0', grad_fn=<SumBackward0>) [5.120000000000001e-05]


 49%|████▉     | 39/79 [00:15<00:14,  2.76it/s]

tensor(1.0413, device='cuda:0', grad_fn=<SumBackward0>) [5.120000000000001e-05]


 51%|█████     | 40/79 [00:15<00:14,  2.77it/s]

tensor(1.0419, device='cuda:0', grad_fn=<SumBackward0>) [5.120000000000001e-05]


 52%|█████▏    | 41/79 [00:15<00:13,  2.76it/s]

tensor(1.0395, device='cuda:0', grad_fn=<SumBackward0>) [4.0960000000000014e-05]


 53%|█████▎    | 42/79 [00:16<00:13,  2.78it/s]

tensor(1.0337, device='cuda:0', grad_fn=<SumBackward0>) [4.0960000000000014e-05]


 54%|█████▍    | 43/79 [00:16<00:12,  2.80it/s]

tensor(1.0190, device='cuda:0', grad_fn=<SumBackward0>) [4.0960000000000014e-05]


 56%|█████▌    | 44/79 [00:16<00:12,  2.81it/s]

tensor(1.0117, device='cuda:0', grad_fn=<SumBackward0>) [4.0960000000000014e-05]


 57%|█████▋    | 45/79 [00:17<00:12,  2.79it/s]

tensor(1.0110, device='cuda:0', grad_fn=<SumBackward0>) [4.0960000000000014e-05]


 58%|█████▊    | 46/79 [00:17<00:11,  2.80it/s]

tensor(1.0092, device='cuda:0', grad_fn=<SumBackward0>) [4.0960000000000014e-05]


 59%|█████▉    | 47/79 [00:17<00:11,  2.84it/s]

tensor(1.0174, device='cuda:0', grad_fn=<SumBackward0>) [4.0960000000000014e-05]


 61%|██████    | 48/79 [00:18<00:10,  2.86it/s]

tensor(1.0160, device='cuda:0', grad_fn=<SumBackward0>) [4.0960000000000014e-05]


 62%|██████▏   | 49/79 [00:18<00:10,  2.85it/s]

tensor(1.0170, device='cuda:0', grad_fn=<SumBackward0>) [4.0960000000000014e-05]


 63%|██████▎   | 50/79 [00:18<00:10,  2.84it/s]

tensor(1.0162, device='cuda:0', grad_fn=<SumBackward0>) [4.0960000000000014e-05]


 65%|██████▍   | 51/79 [00:19<00:09,  2.84it/s]

tensor(1.0108, device='cuda:0', grad_fn=<SumBackward0>) [3.2768000000000016e-05]


 66%|██████▌   | 52/79 [00:19<00:09,  2.86it/s]

tensor(1.0065, device='cuda:0', grad_fn=<SumBackward0>) [3.2768000000000016e-05]


 67%|██████▋   | 53/79 [00:20<00:09,  2.84it/s]

tensor(1.0028, device='cuda:0', grad_fn=<SumBackward0>) [3.2768000000000016e-05]


 68%|██████▊   | 54/79 [00:20<00:08,  2.82it/s]

tensor(1.0014, device='cuda:0', grad_fn=<SumBackward0>) [3.2768000000000016e-05]


 70%|██████▉   | 55/79 [00:20<00:08,  2.79it/s]

tensor(0.9994, device='cuda:0', grad_fn=<SumBackward0>) [3.2768000000000016e-05]


 71%|███████   | 56/79 [00:21<00:08,  2.76it/s]

tensor(1.0047, device='cuda:0', grad_fn=<SumBackward0>) [3.2768000000000016e-05]


 72%|███████▏  | 57/79 [00:21<00:07,  2.75it/s]

tensor(1.0032, device='cuda:0', grad_fn=<SumBackward0>) [3.2768000000000016e-05]


 73%|███████▎  | 58/79 [00:21<00:07,  2.75it/s]

tensor(1.0078, device='cuda:0', grad_fn=<SumBackward0>) [3.2768000000000016e-05]


 75%|███████▍  | 59/79 [00:22<00:07,  2.76it/s]

tensor(1.0036, device='cuda:0', grad_fn=<SumBackward0>) [3.2768000000000016e-05]


 76%|███████▌  | 60/79 [00:22<00:06,  2.74it/s]

tensor(1.0063, device='cuda:0', grad_fn=<SumBackward0>) [3.2768000000000016e-05]


 77%|███████▋  | 61/79 [00:22<00:06,  2.78it/s]

tensor(1.0053, device='cuda:0', grad_fn=<SumBackward0>) [2.6214400000000015e-05]


 78%|███████▊  | 62/79 [00:23<00:06,  2.80it/s]

tensor(1.0106, device='cuda:0', grad_fn=<SumBackward0>) [2.6214400000000015e-05]


 80%|███████▉  | 63/79 [00:23<00:05,  2.85it/s]

tensor(1.0022, device='cuda:0', grad_fn=<SumBackward0>) [2.6214400000000015e-05]


 81%|████████  | 64/79 [00:23<00:05,  2.83it/s]

tensor(0.9991, device='cuda:0', grad_fn=<SumBackward0>) [2.6214400000000015e-05]


 82%|████████▏ | 65/79 [00:24<00:04,  2.83it/s]

tensor(0.9940, device='cuda:0', grad_fn=<SumBackward0>) [2.6214400000000015e-05]


 84%|████████▎ | 66/79 [00:24<00:04,  2.84it/s]

tensor(0.9975, device='cuda:0', grad_fn=<SumBackward0>) [2.6214400000000015e-05]


 85%|████████▍ | 67/79 [00:25<00:04,  2.79it/s]

tensor(1.0027, device='cuda:0', grad_fn=<SumBackward0>) [2.6214400000000015e-05]


 86%|████████▌ | 68/79 [00:25<00:03,  2.79it/s]

tensor(1.0009, device='cuda:0', grad_fn=<SumBackward0>) [2.6214400000000015e-05]


 87%|████████▋ | 69/79 [00:25<00:03,  2.57it/s]

tensor(1.0036, device='cuda:0', grad_fn=<SumBackward0>) [2.6214400000000015e-05]


 89%|████████▊ | 70/79 [00:26<00:03,  2.54it/s]

tensor(0.9889, device='cuda:0', grad_fn=<SumBackward0>) [2.6214400000000015e-05]


 90%|████████▉ | 71/79 [00:26<00:03,  2.65it/s]

tensor(1.0014, device='cuda:0', grad_fn=<SumBackward0>) [2.0971520000000012e-05]


 91%|█████████ | 72/79 [00:26<00:02,  2.73it/s]

tensor(0.9928, device='cuda:0', grad_fn=<SumBackward0>) [2.0971520000000012e-05]


 92%|█████████▏| 73/79 [00:27<00:02,  2.77it/s]

tensor(0.9951, device='cuda:0', grad_fn=<SumBackward0>) [2.0971520000000012e-05]


 94%|█████████▎| 74/79 [00:27<00:01,  2.81it/s]

tensor(0.9919, device='cuda:0', grad_fn=<SumBackward0>) [2.0971520000000012e-05]


 95%|█████████▍| 75/79 [00:27<00:01,  2.80it/s]

tensor(1.0047, device='cuda:0', grad_fn=<SumBackward0>) [2.0971520000000012e-05]


 96%|█████████▌| 76/79 [00:28<00:01,  2.83it/s]

tensor(0.9969, device='cuda:0', grad_fn=<SumBackward0>) [2.0971520000000012e-05]


 97%|█████████▋| 77/79 [00:28<00:00,  2.76it/s]

tensor(0.9993, device='cuda:0', grad_fn=<SumBackward0>) [2.0971520000000012e-05]


100%|██████████| 79/79 [00:29<00:00,  2.71it/s]


tensor(0.9980, device='cuda:0', grad_fn=<SumBackward0>) [2.0971520000000012e-05]
tensor(0.9793, device='cuda:0', grad_fn=<SumBackward0>) [2.0971520000000012e-05]
Epoch: 0001 lr: 2.0971520000000012e-05 loss: 1.8474 time: 0.0178s


 12%|█▎        | 5/40 [00:01<00:12,  2.83it/s]


KeyboardInterrupt: 