In [1]:
import torch
from schnetpack.datasets import QM9
import schnetpack as spk
import os
from my_config import config_args
from MyModel.HGCN import RegModel
import optimizers
import numpy as np
import logging
import time
from tqdm import tqdm

# 获得训练数据

In [2]:
qm9data = QM9('./data/qm9.db', download=True,load_only=[QM9.U0])
qm9split = './data/qm9split'
print(len(qm9data))

133885


In [3]:
train, val, test = spk.train_test_split(
        data=qm9data,
        num_train=10000,
        num_val=10000,
        split_file=os.path.join(qm9split, "split10000.npz"),
    )
print(len(train),len(test),len(val))

10000 113885 10000


In [4]:
train_loader = spk.AtomsLoader(train, batch_size=2, shuffle=False)
val_loader = spk.AtomsLoader(val, batch_size=10)

In [5]:
import sys
import json
class obj(object):
    def __init__(self, dict_):
        self.__dict__.update(dict_)
args = json.loads(json.dumps(config_args), object_hook=obj)

model = RegModel(args)
optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(), lr=args.lr,
                                                    weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_reduce_freq,
    gamma=float(args.gamma)
)
tot_params = sum([np.prod(p.size()) for p in model.parameters()])
logging.info(f"Total number of parameters: {tot_params}")
device = torch.device('cuda')
# Train model
t_total = time.time()
model = model.to(device)
loss_fun = torch.nn.MSELoss(reduction='mean')
for epoch in range(args.epochs):
    loss_sum,n,t = 0,0,0.0
    counter = 0
    for input in tqdm(train_loader):
        counter+=1
        if counter>10:
            sys.exit(0)
        for key in input:
            input[key] = input[key].to(torch.device('cuda'))
        t = time.time()
        model.train()
        optimizer.zero_grad()
        y_hat = model(input)
        print('y_hat',y_hat[:5])
        print('input[QM9.U0]',input[QM9.U0][:5])
        loss = loss_fun(y_hat,input[QM9.U0])

        loss.backward()
        loss_sum += loss
        n += 1
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
    lr_scheduler.step()
    if (epoch + 1) % args.log_freq == 0:
        str = " ".join(['Epoch: {:04d}'.format(epoch + 1),
                               'lr: {}'.format(lr_scheduler.get_lr()[0]),
                               'loss: {:.4f}'.format(loss_sum/n),
                               'time: {:.4f}s'.format(time.time() - t)
                               ])
        print(str)
        logging.info(str)



    # if (epoch + 1) % args.eval_freq == 0:
    #     model.eval()
    #     embeddings = model.encode(data['features'], data['adj_train_norm'])
    #     val_metrics = model.compute_metrics(embeddings, data, 'val')
    #     if (epoch + 1) % args.log_freq == 0:
    #         logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1), format_metrics(val_metrics, 'val')]))
    #     if model.has_improved(best_val_metrics, val_metrics):
    #         best_test_metrics = model.compute_metrics(embeddings, data, 'test')
    #         best_emb = embeddings.cpu()
    #         if args.save:
    #             np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.detach().numpy())
    #         best_val_metrics = val_metrics
    #         counter = 0
    #     else:
    #         counter += 1
    #         if counter == args.patience and epoch > args.min_epochs:
    #             logging.info("Early stopping")
    #             break

  0%|          | 0/5000 [00:00<?, ?it/s]

x tensor([[[ 0.0000, -0.5446,  0.9108,  0.2786, -0.4618,  1.1062,  1.2304,
          -0.4602,  0.5016, -0.7519,  2.0885, -0.5970],
         [ 0.0000,  0.2087,  1.1108, -0.2929,  1.3175, -0.7258,  0.5064,
           1.1288, -0.0540,  0.2809,  1.1179, -0.7407],
         [ 0.0000, -0.5446,  0.9108,  0.2786, -0.4618,  1.1062,  1.2304,
          -0.4602,  0.5016,  0.6855,  0.5458,  0.4604],
         [ 0.0000, -0.5446,  0.9108,  0.2786, -0.4618,  1.1062,  1.2304,
          -0.4602,  0.5016, -0.2145, -0.4859,  1.1365],
         [ 0.0000,  0.2087,  1.1108, -0.2929,  1.3175, -0.7258,  0.5064,
           1.1288, -0.0540, -1.4567, -0.7735,  0.5829],
         [ 0.0000, -0.5446,  0.9108,  0.2786, -0.4618,  1.1062,  1.2304,
          -0.4602,  0.5016,  0.9962, -0.9271,  0.3378],
         [ 0.0000,  0.2087,  1.1108, -0.2929,  1.3175, -0.7258,  0.5064,
           1.1288, -0.0540,  0.6843, -1.5223, -0.8910],
         [ 0.0000,  1.4676,  0.2958, -0.7246, -0.3839,  0.3165,  0.9067,
          -0.8489,  1.

  0%|          | 1/5000 [00:02<2:51:09,  2.05s/it]

tensor([[[ 3.1491e-03, -3.4243e-03, -1.7592e-03, -3.2450e-05, -1.0327e-05,
          -9.5726e-03,  1.9483e-04,  1.9529e-03,  3.1166e-03, -1.0110e-03,
           2.9548e-03,  1.6464e-03],
         [-7.1954e-04, -7.8980e-03,  5.3153e-03, -8.5577e-03,  7.5670e-03,
          -7.9333e-04, -7.3202e-03,  1.9306e-04, -1.1148e-03, -5.0315e-04,
          -4.7744e-03,  7.2798e-04],
         [ 4.7135e-03, -1.1809e-02,  2.8753e-03, -8.0823e-03,  6.0295e-03,
          -1.5919e-02, -6.1044e-03,  3.0517e-03,  3.9249e-03, -6.3651e-04,
          -2.1621e-04,  2.1409e-03],
         [ 4.4981e-03, -1.7827e-02,  8.1534e-03, -6.6716e-03,  4.7861e-03,
          -2.3365e-02, -6.6313e-03,  7.5228e-03,  1.2677e-02,  3.7326e-04,
           3.3340e-03,  4.4853e-03],
         [-3.7747e-03, -8.1467e-03,  9.7624e-03, -1.2779e-03,  1.2310e-03,
           2.6013e-03, -5.9442e-03,  3.4061e-03,  5.7506e-03,  2.4806e-03,
          -1.9957e-03,  4.1895e-03],
         [ 5.5089e-03, -1.3520e-02,  4.3858e-03, -1.0307e-02,  8.

  0%|          | 3/5000 [00:02<47:29,  1.75it/s]  

act tensor([[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
  

  0%|          | 4/5000 [00:02<33:27,  2.49it/s]

x tensor([[[ 0.0000e+00,         nan,         nan,         nan,         nan,
                  nan,         nan,         nan,         nan,  1.4299e+00,
           2.3265e+00, -8.9943e-01],
         [ 0.0000e+00,         nan,         nan,         nan,         nan,
                  nan,         nan,         nan,         nan,  1.3156e+00,
           9.5107e-01, -4.1068e-01],
         [ 0.0000e+00,         nan,         nan,         nan,         nan,
                  nan,         nan,         nan,         nan,  1.9411e+00,
           3.3492e-01, -9.1542e-01],
         [ 0.0000e+00,         nan,         nan,         nan,         nan,
                  nan,         nan,         nan,         nan,  9.2371e-01,
           1.0036e+00,  1.4578e+00],
         [ 0.0000e+00,         nan,         nan,         nan,         nan,
                  nan,         nan,         nan,         nan, -5.5339e-02,
           4.1235e-01, -4.6968e-01],
         [ 0.0000e+00,         nan,         nan,         nan,  

  0%|          | 6/5000 [00:02<20:01,  4.15it/s]

tensor([[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
      

  0%|          | 7/5000 [00:02<17:01,  4.89it/s]

x tensor([[[ 0.0000e+00,         nan,         nan,         nan,         nan,
                  nan,         nan,         nan,         nan,  3.8333e-01,
           2.6501e+00, -9.0880e-01],
         [ 0.0000e+00,         nan,         nan,         nan,         nan,
                  nan,         nan,         nan,         nan, -7.5508e-01,
           1.7772e+00, -3.5296e-01],
         [ 0.0000e+00,         nan,         nan,         nan,         nan,
                  nan,         nan,         nan,         nan, -2.6028e-01,
           6.0867e-01,  4.6301e-01],
         [ 0.0000e+00,         nan,         nan,         nan,         nan,
                  nan,         nan,         nan,         nan, -4.5582e-01,
           7.0239e-01,  1.7254e+00],
         [ 0.0000e+00,         nan,         nan,         nan,         nan,
                  nan,         nan,         nan,         nan,  5.1134e-02,
          -4.1711e-01,  2.4166e+00],
         [ 0.0000e+00,         nan,         nan,         nan,  

  0%|          | 9/5000 [00:03<13:15,  6.28it/s]

tensor([[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
      

  0%|          | 10/5000 [00:03<26:55,  3.09it/s]

tensor([[[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
      




SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# from manifolds.hyperboloid import Hyperboloid
# import torch
# size = (2,5,10)
# a = 2*torch.ones(size)
# a = torch.concat([torch.zeros((2,5,1)),a],dim=2)
# print(a)
# hyper = Hyperboloid()
# c = 1
# a_H = hyper.expmap0(a,c)  #
# print(a_H)
# a_H_p = hyper.proj(a_H,c)  #对双曲空间的向量没有影响
# print(a_H_p)

# a_tan = hyper.logmap0(a_H,c)
# print(a_tan)
#
# a_tan_p = hyper.proj_tan0(a_tan,c)  #  对已经在o点切空间的向量没有影响
# print(a_tan_p)

# a_tan_a = hyper.logmap(a_H,a_H,c)  # 向量在其切空间是0
# print(a_tan_a)



