In [1]:
import torch
from schnetpack.datasets import QM9
import schnetpack as spk
import os
from my_config import config_args
from MyModel.HGCN import RegModel
import optimizers
import numpy as np
import logging
import time
from tqdm import tqdm

# 获得训练数据

In [2]:
qm9data = QM9('./data/qm9.db', download=True,load_only=[QM9.U0])
qm9split = './data/qm9split'
print(len(qm9data))

133885


In [3]:
train, val, test = spk.train_test_split(
        data=qm9data,
        num_train=20000,
        num_val=10000,
        split_file=os.path.join(qm9split, "split20000.npz"),
    )
print(len(train),len(test),len(val))

20000 103885 10000


In [4]:
train_loader = spk.AtomsLoader(train, batch_size=256, shuffle=False)
val_loader = spk.AtomsLoader(val, batch_size=10)

In [5]:
import sys
import json
class obj(object):
    def __init__(self, dict_):
        self.__dict__.update(dict_)
args = json.loads(json.dumps(config_args), object_hook=obj)

model = RegModel(args)
optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(), lr=args.lr,
                                                    weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_reduce_freq,
    gamma=float(args.gamma)
)
tot_params = sum([np.prod(p.size()) for p in model.parameters()])
logging.info(f"Total number of parameters: {tot_params}")
device = torch.device('cuda')
# Train model
t_total = time.time()
model = model.to(device)
loss_fun = torch.nn.MSELoss()
# torch.set_printoptions(threshold=np.inf)
for epoch in range(args.epochs):
    loss_sum,n,t = 0,0,0.0
    counter = 0
    for input in tqdm(train_loader):
        # counter+=1
        # if counter>10:
        #     sys.exit(0)
        for key in input:
            input[key] = input[key].to(torch.device('cuda'))
        t = time.time()
        model.train()
        optimizer.zero_grad()
        y_hat = model(input)
        # print('y_hat',y_hat[:2])
        # print('input[QM9.U0]',input[QM9.U0][:2])
        loss = loss_fun(y_hat,input[QM9.U0])
        # curvatures = list(model.get_submodule('curvatures'))
        # print(curvatures)

        # print(loss)
        loss.backward()
        loss_sum += loss
        n += 1
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
    lr_scheduler.step()
    if (epoch + 1) % args.log_freq == 0:
        str = " ".join(['Epoch: {:04d}'.format(epoch + 1),
                               'lr: {}'.format(lr_scheduler.get_last_lr()[0]),
                               'loss: {:.4f}'.format(loss_sum/n),
                               'time: {:.4f}s'.format(time.time() - t)
                               ])
        print(str)
        logging.info(str)
        curvatures = list(model.get_submodule('curvatures'))
        # print(curvatures)



    # if (epoch + 1) % args.eval_freq == 0:
    #     model.eval()
    #     embeddings = model.encode(data['features'], data['adj_train_norm'])
    #     val_metrics = model.compute_metrics(embeddings, data, 'val')
    #     if (epoch + 1) % args.log_freq == 0:
    #         logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1), format_metrics(val_metrics, 'val')]))
    #     if model.has_improved(best_val_metrics, val_metrics):
    #         best_test_metrics = model.compute_metrics(embeddings, data, 'test')
    #         best_emb = embeddings.cpu()
    #         if args.save:
    #             np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.detach().numpy())
    #         best_val_metrics = val_metrics
    #         counter = 0
    #     else:
    #         counter += 1
    #         if counter == args.patience and epoch > args.min_epochs:
    #             logging.info("Early stopping")
    #             break

"""
euclid
Epoch: 0020 lr: 0.005 loss: 377.5419 time: 0.0045s
Epoch: 0040 lr: 0.0025 loss: 66.3344 time: 0.0000s
Epoch: 0060 lr: 0.00125 loss: 46.7183 time: 0.0199s

hyperbolid
Epoch: 0012 lr: 0.01 loss: 33.4861 time: 0.0595s
Epoch: 0020 lr: 0.005 loss: 6.2979 time: 0.0600s
Epoch: 0040 lr: 0.0025 loss: 4.1167 time: 0.0697s
"""

100%|██████████| 79/79 [00:32<00:00,  2.46it/s]


Epoch: 0001 lr: 0.01 loss: 23098756.0000 time: 0.0565s


100%|██████████| 79/79 [00:29<00:00,  2.68it/s]


Epoch: 0002 lr: 0.01 loss: 5280187.5000 time: 0.0610s


100%|██████████| 79/79 [00:32<00:00,  2.46it/s]


Epoch: 0003 lr: 0.01 loss: 5281497.5000 time: 0.0821s


100%|██████████| 79/79 [00:31<00:00,  2.47it/s]


Epoch: 0004 lr: 0.01 loss: 5279981.0000 time: 0.0531s


100%|██████████| 79/79 [00:31<00:00,  2.53it/s]


Epoch: 0005 lr: 0.01 loss: 4568375.0000 time: 0.0670s


100%|██████████| 79/79 [00:31<00:00,  2.53it/s]


Epoch: 0006 lr: 0.01 loss: 552977.3750 time: 0.0710s


100%|██████████| 79/79 [00:31<00:00,  2.49it/s]


Epoch: 0007 lr: 0.01 loss: 24327.4922 time: 0.0538s


100%|██████████| 79/79 [00:30<00:00,  2.62it/s]


Epoch: 0008 lr: 0.01 loss: 23226.3047 time: 0.0548s


100%|██████████| 79/79 [00:30<00:00,  2.59it/s]


Epoch: 0009 lr: 0.01 loss: 23354.4453 time: 0.0678s


100%|██████████| 79/79 [00:30<00:00,  2.61it/s]


Epoch: 0010 lr: 0.01 loss: 22763.4648 time: 0.0601s


100%|██████████| 79/79 [00:30<00:00,  2.56it/s]


Epoch: 0011 lr: 0.01 loss: 23447.0410 time: 0.0533s


100%|██████████| 79/79 [00:30<00:00,  2.60it/s]


Epoch: 0012 lr: 0.01 loss: 26267.4160 time: 0.0560s


100%|██████████| 79/79 [00:30<00:00,  2.62it/s]


Epoch: 0013 lr: 0.01 loss: 30542.2598 time: 0.0572s


100%|██████████| 79/79 [00:29<00:00,  2.64it/s]


Epoch: 0014 lr: 0.01 loss: 22940.2930 time: 0.0568s


100%|██████████| 79/79 [00:30<00:00,  2.63it/s]


Epoch: 0015 lr: 0.01 loss: 12847.1035 time: 0.0577s


100%|██████████| 79/79 [00:30<00:00,  2.61it/s]


Epoch: 0016 lr: 0.01 loss: 1981.4211 time: 0.0630s


100%|██████████| 79/79 [00:30<00:00,  2.61it/s]


Epoch: 0017 lr: 0.01 loss: 114.1029 time: 0.0599s


100%|██████████| 79/79 [00:31<00:00,  2.50it/s]


Epoch: 0018 lr: 0.01 loss: 48.6787 time: 0.0564s


100%|██████████| 79/79 [00:32<00:00,  2.46it/s]


Epoch: 0019 lr: 0.01 loss: 30.7299 time: 0.0630s


100%|██████████| 79/79 [00:31<00:00,  2.48it/s]


Epoch: 0020 lr: 0.005 loss: 41.0017 time: 0.0679s


100%|██████████| 79/79 [00:31<00:00,  2.53it/s]


Epoch: 0021 lr: 0.005 loss: 23.5421 time: 0.0729s


100%|██████████| 79/79 [00:30<00:00,  2.55it/s]


Epoch: 0022 lr: 0.005 loss: 20.1379 time: 0.0637s


100%|██████████| 79/79 [00:29<00:00,  2.67it/s]


Epoch: 0023 lr: 0.005 loss: 19.9679 time: 0.0641s


100%|██████████| 79/79 [00:29<00:00,  2.65it/s]


Epoch: 0024 lr: 0.005 loss: 18.3718 time: 0.0589s


100%|██████████| 79/79 [00:30<00:00,  2.57it/s]


Epoch: 0025 lr: 0.005 loss: 17.2459 time: 0.0556s


100%|██████████| 79/79 [00:30<00:00,  2.58it/s]


Epoch: 0026 lr: 0.005 loss: 16.5226 time: 0.0697s


100%|██████████| 79/79 [00:29<00:00,  2.65it/s]


Epoch: 0027 lr: 0.005 loss: 15.9741 time: 0.0592s


100%|██████████| 79/79 [00:31<00:00,  2.48it/s]


Epoch: 0028 lr: 0.005 loss: 16.0803 time: 0.0538s


100%|██████████| 79/79 [00:32<00:00,  2.41it/s]


Epoch: 0029 lr: 0.005 loss: 16.0135 time: 0.0604s


100%|██████████| 79/79 [00:33<00:00,  2.38it/s]


Epoch: 0030 lr: 0.005 loss: 15.6306 time: 0.0666s


100%|██████████| 79/79 [00:32<00:00,  2.47it/s]


Epoch: 0031 lr: 0.005 loss: 15.4440 time: 0.0547s


100%|██████████| 79/79 [00:31<00:00,  2.51it/s]


Epoch: 0032 lr: 0.005 loss: 15.9313 time: 0.0550s


100%|██████████| 79/79 [00:32<00:00,  2.46it/s]


Epoch: 0033 lr: 0.005 loss: 16.9608 time: 0.0615s


100%|██████████| 79/79 [00:31<00:00,  2.47it/s]


Epoch: 0034 lr: 0.005 loss: 15.3255 time: 0.0575s


100%|██████████| 79/79 [00:31<00:00,  2.48it/s]


Epoch: 0035 lr: 0.005 loss: 18.7460 time: 0.0672s


100%|██████████| 79/79 [00:31<00:00,  2.49it/s]


Epoch: 0036 lr: 0.005 loss: 26.6611 time: 0.0628s


100%|██████████| 79/79 [00:31<00:00,  2.50it/s]


Epoch: 0037 lr: 0.005 loss: 23.7247 time: 0.0424s


100%|██████████| 79/79 [00:30<00:00,  2.58it/s]


Epoch: 0038 lr: 0.005 loss: 1281.0712 time: 0.0590s


100%|██████████| 79/79 [00:29<00:00,  2.64it/s]


Epoch: 0039 lr: 0.005 loss: 258.7288 time: 0.0500s


100%|██████████| 79/79 [00:29<00:00,  2.67it/s]


Epoch: 0040 lr: 0.0025 loss: 791.2199 time: 0.0504s


100%|██████████| 79/79 [00:29<00:00,  2.68it/s]


Epoch: 0041 lr: 0.0025 loss: 46.2953 time: 0.0618s


100%|██████████| 79/79 [00:29<00:00,  2.68it/s]


Epoch: 0042 lr: 0.0025 loss: 10.7854 time: 0.0525s


100%|██████████| 79/79 [00:29<00:00,  2.65it/s]


Epoch: 0043 lr: 0.0025 loss: 10.3416 time: 0.0671s


100%|██████████| 79/79 [00:29<00:00,  2.68it/s]


Epoch: 0044 lr: 0.0025 loss: 9.8477 time: 0.0645s


100%|██████████| 79/79 [00:29<00:00,  2.66it/s]


Epoch: 0045 lr: 0.0025 loss: 9.4973 time: 0.0574s


100%|██████████| 79/79 [00:31<00:00,  2.49it/s]


Epoch: 0046 lr: 0.0025 loss: 9.3065 time: 0.0625s


100%|██████████| 79/79 [00:30<00:00,  2.55it/s]


Epoch: 0047 lr: 0.0025 loss: 9.4747 time: 0.0531s


100%|██████████| 79/79 [00:30<00:00,  2.57it/s]


Epoch: 0048 lr: 0.0025 loss: 9.7508 time: 0.0571s


100%|██████████| 79/79 [00:29<00:00,  2.66it/s]


Epoch: 0049 lr: 0.0025 loss: 9.8264 time: 0.0595s


100%|██████████| 79/79 [00:29<00:00,  2.70it/s]


Epoch: 0050 lr: 0.0025 loss: 9.7564 time: 0.0583s


100%|██████████| 79/79 [00:29<00:00,  2.68it/s]


Epoch: 0051 lr: 0.0025 loss: 9.6690 time: 0.0529s


100%|██████████| 79/79 [00:29<00:00,  2.68it/s]


Epoch: 0052 lr: 0.0025 loss: 10.4051 time: 0.0512s


100%|██████████| 79/79 [00:29<00:00,  2.72it/s]


Epoch: 0053 lr: 0.0025 loss: 10.1725 time: 0.0629s


100%|██████████| 79/79 [00:29<00:00,  2.68it/s]


Epoch: 0054 lr: 0.0025 loss: 10.1264 time: 0.0568s


100%|██████████| 79/79 [00:35<00:00,  2.25it/s]


Epoch: 0055 lr: 0.0025 loss: 12.8499 time: 0.0629s


100%|██████████| 79/79 [00:33<00:00,  2.37it/s]


Epoch: 0056 lr: 0.0025 loss: 18.3601 time: 0.0613s


100%|██████████| 79/79 [00:32<00:00,  2.44it/s]


Epoch: 0057 lr: 0.0025 loss: 16.9956 time: 0.0609s


100%|██████████| 79/79 [00:32<00:00,  2.40it/s]


Epoch: 0058 lr: 0.0025 loss: 245.6144 time: 0.0592s


100%|██████████| 79/79 [00:32<00:00,  2.42it/s]


Epoch: 0059 lr: 0.0025 loss: 277.4164 time: 0.0627s


100%|██████████| 79/79 [00:32<00:00,  2.42it/s]


Epoch: 0060 lr: 0.00125 loss: 429.1048 time: 0.0708s


100%|██████████| 79/79 [00:33<00:00,  2.36it/s]


Epoch: 0061 lr: 0.00125 loss: 19.8051 time: 0.0768s


100%|██████████| 79/79 [00:33<00:00,  2.33it/s]


Epoch: 0062 lr: 0.00125 loss: 7.4901 time: 0.0805s


100%|██████████| 79/79 [00:33<00:00,  2.37it/s]


Epoch: 0063 lr: 0.00125 loss: 7.3628 time: 0.0698s


100%|██████████| 79/79 [00:31<00:00,  2.54it/s]


Epoch: 0064 lr: 0.00125 loss: 7.1050 time: 0.0640s


100%|██████████| 79/79 [00:34<00:00,  2.29it/s]


Epoch: 0065 lr: 0.00125 loss: 6.8446 time: 0.0696s


100%|██████████| 79/79 [00:35<00:00,  2.23it/s]


Epoch: 0066 lr: 0.00125 loss: 6.7056 time: 0.0692s


100%|██████████| 79/79 [00:34<00:00,  2.30it/s]


Epoch: 0067 lr: 0.00125 loss: 6.7571 time: 0.0741s


100%|██████████| 79/79 [00:33<00:00,  2.33it/s]


Epoch: 0068 lr: 0.00125 loss: 6.9814 time: 0.0681s


100%|██████████| 79/79 [00:34<00:00,  2.31it/s]


Epoch: 0069 lr: 0.00125 loss: 6.9819 time: 0.0670s


100%|██████████| 79/79 [00:34<00:00,  2.30it/s]


Epoch: 0070 lr: 0.00125 loss: 7.0889 time: 0.0715s


100%|██████████| 79/79 [00:34<00:00,  2.29it/s]


Epoch: 0071 lr: 0.00125 loss: 7.2624 time: 0.0688s


100%|██████████| 79/79 [00:35<00:00,  2.25it/s]


Epoch: 0072 lr: 0.00125 loss: 7.7883 time: 0.0804s


100%|██████████| 79/79 [00:34<00:00,  2.32it/s]


Epoch: 0073 lr: 0.00125 loss: 7.8174 time: 0.0624s


100%|██████████| 79/79 [00:33<00:00,  2.36it/s]


Epoch: 0074 lr: 0.00125 loss: 7.3561 time: 0.0775s


100%|██████████| 79/79 [00:32<00:00,  2.45it/s]


Epoch: 0075 lr: 0.00125 loss: 8.0612 time: 0.0626s


100%|██████████| 79/79 [00:29<00:00,  2.64it/s]


Epoch: 0076 lr: 0.00125 loss: 12.2521 time: 0.0600s


100%|██████████| 79/79 [00:29<00:00,  2.67it/s]


Epoch: 0077 lr: 0.00125 loss: 12.2737 time: 0.0673s


100%|██████████| 79/79 [00:30<00:00,  2.62it/s]


Epoch: 0078 lr: 0.00125 loss: 16.2074 time: 0.0583s


100%|██████████| 79/79 [00:29<00:00,  2.67it/s]


Epoch: 0079 lr: 0.00125 loss: 122.8599 time: 0.0464s


100%|██████████| 79/79 [00:28<00:00,  2.73it/s]


Epoch: 0080 lr: 0.000625 loss: 308.0220 time: 0.0634s


100%|██████████| 79/79 [00:28<00:00,  2.73it/s]


Epoch: 0081 lr: 0.000625 loss: 15.7409 time: 0.0503s


100%|██████████| 79/79 [00:29<00:00,  2.72it/s]


Epoch: 0082 lr: 0.000625 loss: 5.4147 time: 0.0509s


100%|██████████| 79/79 [00:29<00:00,  2.70it/s]


Epoch: 0083 lr: 0.000625 loss: 5.4356 time: 0.0526s


100%|██████████| 79/79 [00:29<00:00,  2.67it/s]


Epoch: 0084 lr: 0.000625 loss: 5.4556 time: 0.0661s


100%|██████████| 79/79 [00:29<00:00,  2.68it/s]


Epoch: 0085 lr: 0.000625 loss: 5.4697 time: 0.0470s


100%|██████████| 79/79 [00:29<00:00,  2.70it/s]


Epoch: 0086 lr: 0.000625 loss: 5.2919 time: 0.0532s


100%|██████████| 79/79 [00:29<00:00,  2.66it/s]


Epoch: 0087 lr: 0.000625 loss: 5.1703 time: 0.0513s


100%|██████████| 79/79 [00:28<00:00,  2.73it/s]


Epoch: 0088 lr: 0.000625 loss: 5.1784 time: 0.0515s


100%|██████████| 79/79 [00:29<00:00,  2.70it/s]


Epoch: 0089 lr: 0.000625 loss: 5.3415 time: 0.0438s


100%|██████████| 79/79 [00:29<00:00,  2.66it/s]


Epoch: 0090 lr: 0.000625 loss: 5.4881 time: 0.0498s


100%|██████████| 79/79 [00:29<00:00,  2.63it/s]


Epoch: 0091 lr: 0.000625 loss: 5.3387 time: 0.0431s


100%|██████████| 79/79 [00:30<00:00,  2.56it/s]


Epoch: 0092 lr: 0.000625 loss: 5.5561 time: 0.0611s


100%|██████████| 79/79 [00:31<00:00,  2.50it/s]


Epoch: 0093 lr: 0.000625 loss: 5.7840 time: 0.0572s


100%|██████████| 79/79 [00:32<00:00,  2.42it/s]


Epoch: 0094 lr: 0.000625 loss: 6.2939 time: 0.0583s


 22%|██▏       | 17/79 [00:07<00:26,  2.37it/s]


KeyboardInterrupt: 

In [None]:
# from manifolds.hyperboloid import Hyperboloid
# import torch
# size = (2,5,10)
# a = 2*torch.ones(size)
# a = torch.concat([torch.zeros((2,5,1)),a],dim=2)
# print(a)
# hyper = Hyperboloid()
# c = 1
# a_H = hyper.expmap0(a,c)  #
# print(a_H)
# a_H_p = hyper.proj(a_H,c)  #对双曲空间的向量没有影响
# print(a_H_p)

# a_tan = hyper.logmap0(a_H,c)
# print(a_tan)
#
# a_tan_p = hyper.proj_tan0(a_tan,c)  #  对已经在o点切空间的向量没有影响
# print(a_tan_p)

# a_tan_a = hyper.logmap(a_H,a_H,c)  # 向量在其切空间是0
# print(a_tan_a)



