In [1]:
import torch
from schnetpack.datasets import QM9
import schnetpack as spk
import os
from my_config import config_args
from Model.HGDM import HyperbolicAE,HyperbolicDiffusion
import optimizers
import numpy as np
import logging
import time
from tqdm import tqdm

# 获得训练数据

In [2]:
qm9data = QM9('./data/qm9.db', download=True,load_only=[QM9.U0])
qm9split = './data/qm9split'
print(len(qm9data))

133885


In [3]:
train, val, test = spk.train_test_split(
        data=qm9data,
        num_train=30000,
        num_val=10000,
        split_file=os.path.join(qm9split, "split30000-10000.npz"),
    )
print(len(train),len(val),len(test))

30000 10000 93885


In [4]:
train_loader = spk.AtomsLoader(train, batch_size=256, shuffle=True)
val_loader = spk.AtomsLoader(val, batch_size=256)

In [5]:
import sys
import json
class obj(object):
    def __init__(self, dict_):
        self.__dict__.update(dict_)
args = json.loads(json.dumps(config_args), object_hook=obj)

model = HyperbolicAE(args)

optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(), lr=args.lr,
                                                    weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_reduce_freq,
    gamma=float(args.gamma)
)
tot_params = sum([np.prod(p.size()) for p in model.parameters()])
logging.info(f"Total number of parameters: {tot_params}")
device = torch.device('cuda')
# Train model
t_total = time.time()

model = model.to(device)
# loss_fun = torch.nn.MSELoss()
# torch.set_printoptions(threshold=np.inf)
model.train()

HyperbolicAE(
  (embedding): Embedding(100, 11, padding_idx=0)
  (encoder): HNN(
    (curvatures): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]
        (1): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]
        (2): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]
        (3): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]
        (4): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]
        (5): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]
        (6): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]
        (7): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]
        (8): Parameter containing: [torch.cuda.FloatTensor of size 1 (GPU 0)]
    )
    (layers): Sequential(
      (0): HNNLayer(
        (linear): HypLinear(
          in_features=15, out_features=20, c=Parameter containing:
          tensor([1.], device='cuda

In [6]:

for epoch in range(args.epochs):

    loss_sum,n,t = 0,0,0.0
    counter = 0
    for input in tqdm(train_loader):
        # counter+=1
        # if counter>10:
        #     sys.exit(0)
        for key in input:
            input[key] = input[key].to(torch.device('cuda'))
        t = time.time()
        model.train()
        optimizer.zero_grad()
        loss = model(input)
        # print('y_hat',y_hat[:2])
        # print('input[QM9.U0]',input[QM9.U0][:2])
        # loss = loss_fun(y_hat,input[QM9.U0])
        # curvatures = list(model.get_submodule('curvatures'))
        # print(curvatures)

        # print(loss)
        loss.backward()
        loss_sum += loss
        n += 1
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
    lr_scheduler.step()
    if (epoch + 1) % args.log_freq == 0:
        str = " ".join(['Epoch: {:04d}'.format(epoch + 1),
                               'lr: {}'.format(lr_scheduler.get_last_lr()[0]),
                               'loss: {:.4f}'.format(loss_sum/n),
                               'time: {:.4f}s'.format(time.time() - t)
                               ])
        print(str)
        # logging.info(str)
        # curvatures = list(model.get_submodule('encoder.curvatures'))
        # print(curvatures)

    model.eval()
    with torch.no_grad():
        loss_sum,n = 0.0,0
        for input in tqdm(val_loader):
            for key in input:
                input[key] = input[key].to(torch.device('cuda'))
            t = time.time()
            model.train()
            optimizer.zero_grad()
            loss = model(input)
            n += 1
            loss_sum += loss

        print('val_loss:',loss_sum.item()/n)

"""
Epoch: 0020 lr: 0.0025 loss: 24.7880 time: 0.0977s
Epoch: 0020 lr: 0.0025 loss: 13.2554 time: 0.0969s
"""

100%|██████████| 118/118 [00:47<00:00,  2.51it/s]


Epoch: 0001 lr: 0.01 loss: 101.8151 time: 0.1134s


100%|██████████| 40/40 [00:13<00:00,  2.93it/s]


val_loss: 80.836376953125


100%|██████████| 118/118 [00:45<00:00,  2.58it/s]


Epoch: 0002 lr: 0.01 loss: 76.4348 time: 0.1134s


 25%|██▌       | 10/40 [00:03<00:09,  3.04it/s]


KeyboardInterrupt: 

In [None]:

model.eval()
with torch.no_grad():
    loss_sum,n = 0.0,0
    for input in tqdm(val_loader):
        for key in input:
            input[key] = input[key].to(torch.device('cuda'))
        t = time.time()
        model.train()
        optimizer.zero_grad()
        loss = model(input)
        n += 1
        loss_sum += loss

    print('val_loss:',loss_sum/n)

    # if (epoch + 1) % args.eval_freq == 0:
    #     model.eval()
    #     embeddings = model.encode(data['features'], data['adj_train_norm'])
    #     val_metrics = model.compute_metrics(embeddings, data, 'val')
    #     if (epoch + 1) % args.log_freq == 0:
    #         logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1), format_metrics(val_metrics, 'val')]))
    #     if model.has_improved(best_val_metrics, val_metrics):
    #         best_test_metrics = model.compute_metrics(embeddings, data, 'test')
    #         best_emb = embeddings.cpu()
    #         if args.save:
    #             np.save(os.path.join(save_dir, 'embeddings.npy'), best_emb.detach().numpy())
    #         best_val_metrics = val_metrics
    #         counter = 0
    #     else:
    #         counter += 1
    #         if counter == args.patience and epoch > args.min_epochs:
    #             logging.info("Early stopping")
    #             break



"""
hyperbolid
val_loss: tensor(0.0009, device='cuda:0') encoder结束后投影到欧氏空间
val_loss: tensor(1.5983e-05, device='cuda:0') encoder结束后保持在双曲空间
欧氏空间：
val_loss: tensor(0.1687, device='cuda:0')

val_loss: tensor(0.2672, device='cuda:0') 坐标在双曲空间 encoder结束后保持在双曲空间
val_loss: tensor(0.0075, device='cuda:0') 坐标在双曲空间 encoder结束后投影到欧氏空间
"""

In [None]:
torch.save(model.encoder.state_dict(), './saved_model/HNN-encoder.pt')
encoder_path = './saved_model/HNN-encoder.pt'
torch.save(model.decoder.state_dict(), './saved_model/HNN-decoder.pt')
decoder_path = './saved_model/HNN-decoder.pt'

In [None]:
from Model import Encoders, Decoders

encoder = getattr(Encoders, args.model)(1, args)
encoder.load_state_dict(torch.load(encoder_path))
decoder = Decoders.model2decoder[args.model](encoder.curvatures, args)
decoder.load_state_dict(torch.load(decoder_path))

In [None]:
from Model.HGDM import HyperbolicDiffusion
model = HyperbolicDiffusion(args,encoder,decoder)
model = model.to(device)
model.eval()
with torch.no_grad():
    loss_sum,n = 0.0,0
    for input in tqdm(val_loader):
        for key in input:
            input[key] = input[key].to(torch.device('cuda'))
        t = time.time()
        model.train()
        optimizer.zero_grad()
        loss = model(input)
        n += 1
        loss_sum += loss

    print('val_loss:',loss_sum/n)