In [1]:
import torch
from schnetpack.datasets import QM9
import schnetpack as spk
import os
from my_config import config_args
from Model.HGDM import HyperbolicAE,HyperbolicDiffusion
import optimizers
import numpy as np
import logging
import time
from tqdm import tqdm

# 获得训练数据

In [2]:
qm9data = QM9('./data/qm9.db', download=True,load_only=[QM9.U0])
qm9split = './data/qm9split'
print(len(qm9data))

133885


In [3]:
train, val, test = spk.train_test_split(
        data=qm9data,
        num_train=30000,
        num_val=10000,
        split_file=os.path.join(qm9split, "split30000-10000.npz"),
    )
print(len(train),len(val),len(test))

30000 10000 93885


In [4]:
train_loader = spk.AtomsLoader(train, batch_size=256, shuffle=False)
val_loader = spk.AtomsLoader(val, batch_size=256)

In [5]:

import json
class obj(object):
    def __init__(self, dict_):
        self.__dict__.update(dict_)
args = json.loads(json.dumps(config_args), object_hook=obj)

model = HyperbolicAE(args)

optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(), lr=args.lr,
                                                    weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_reduce_freq,
    gamma=float(args.gamma)
)
tot_params = sum([np.prod(p.size()) for p in model.parameters()])
logging.info(f"Total number of parameters: {tot_params}")
device = torch.device('cuda')
# Train model
t_total = time.time()

model = model.to(device)


In [6]:
step = 0
torch.set_printoptions(profile="full")
for epoch in range(args.epochs):
    model.train()
    loss_sum,n,t = 0,0,0.0
    counter = 0
    KL_sum = 0
    for input in (train_loader):
        # counter+=1
        # if counter>10:
        #     sys.exit(0)
        for key in input:
            input[key] = input[key].to(torch.device('cuda'))
        t = time.time()
        optimizer.zero_grad()
        loss,KL = model(input)
        step+=1
        print('step',step,' loss:',loss,' lr: ',lr_scheduler.get_last_lr())
        # print('KL:',KL)
        # curvatures = list(model.get_submodule('encoder.curvatures'))
        # print('encoder:',curvatures)
        # curvatures = list(model.get_submodule('decoder.curvatures'))
        # print('decoder:',curvatures)

        loss+=KL
        loss.backward()
        loss_sum += loss
        KL_sum+=KL
        n += 1
        # if args.grad_clip is not None:
        #     max_norm = float(args.grad_clip)
        #     all_params = list(model.parameters())
        #     for param in all_params:
        #         torch.nn.utils.clip_grad_norm_(param, max_norm)

        if args.grad_clip is not None:
            grad_clip = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_value_(param, grad_clip)
        optimizer.step()
        en_curvatures = model.get_submodule('encoder.curvatures')
        for p in en_curvatures.parameters():
            p.data.clamp_(1e-8)
        de_curvatures = model.get_submodule('decoder.curvatures')
        for p in de_curvatures.parameters():
            p.data.clamp_(1e-8)
        lr_scheduler.step()
    if (epoch + 1) % args.log_freq == 0:
        str = " ".join(['Epoch: {:04d}'.format(epoch + 1),
                               'lr: {}'.format(lr_scheduler.get_last_lr()[0]),
                               'loss: {:.4f}'.format(loss_sum/n),
                               'KL:{:.4f}'.format(KL_sum/n),
                               'time: {:.4f}s'.format(time.time() - t)
                               ])
        print(str)
        # logging.info(str)
        # curvatures = list(model.get_submodule('encoder.curvatures'))
        # print('encoder:',curvatures)
        # curvatures = list(model.get_submodule('decoder.curvatures'))
        # print('decoder:',curvatures)

    # model.eval()
    # with torch.no_grad():
    #     loss_sum,n = 0.0,0
    #     for input in tqdm(val_loader):
    #         for key in input:
    #             input[key] = input[key].to(torch.device('cuda'))
    #         t = time.time()
    #         model.train()
    #         optimizer.zero_grad()
    #         loss,KL = model(input)
    #         loss+=KL
    #         n += 1
    #         loss_sum += loss
    #
    #     print('val_loss:',loss_sum.item()/n)

"""


"""

step 1  loss: tensor(123.1349, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 2  loss: tensor(96.9694, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 3  loss: tensor(84.8839, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 4  loss: tensor(72.3122, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 5  loss: tensor(66.6071, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 6  loss: tensor(66.0879, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 7  loss: tensor(59.4141, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 8  loss: tensor(60.4246, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 9  loss: tensor(50.9489, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 10  loss: tensor(52.1512, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 11  loss: tensor(55.9561, device='cuda:0', grad_fn=<DivBackward0>)  lr:  [0.001]
step 12  loss: tensor(54.5654, device='cuda:0', grad_fn=<DivBa

KeyboardInterrupt: 

In [None]:

"""
hyperbolid
val_loss: tensor(0.0009, device='cuda:0') encoder结束后投影到欧氏空间
val_loss: tensor(1.5983e-05, device='cuda:0') encoder结束后保持在双曲空间
欧氏空间：
val_loss: tensor(0.1687, device='cuda:0')

val_loss: tensor(0.2672, device='cuda:0') 坐标在双曲空间 encoder结束后保持在双曲空间
val_loss: tensor(0.0075, device='cuda:0') 坐标在双曲空间 encoder结束后投影到欧氏空间
"""

In [None]:
torch.save(model.encoder.state_dict(), './saved_model/'+args.model+'-encoder_kl.pt')
torch.save(model.decoder.state_dict(), './saved_model/'+args.model+'-decoder_kl.pt')