In [None]:
import torch
from schnetpack.datasets import QM9
import schnetpack as spk
import os
from my_config import config_args
from Model.HGDM import HyperbolicAE,HyperbolicDiffusion
import optimizers
import numpy as np
import logging
import time
from tqdm import tqdm
from Model import Encoders, Decoders

In [None]:
# Split sizes and batch size used below; the split-file name is derived from
# the sizes so a change here produces a new split file rather than reusing an
# old one with different sizes.
N_TRAIN, N_VAL = 20000, 10000
BATCH_SIZE = 256

qm9split = './data/qm9split'
# Create ./data/qm9split (and ./data) up front: the split file is written into
# this directory, and presumably neither the dataset nor the split helper
# creates missing parent directories — TODO confirm against schnetpack version.
os.makedirs(qm9split, exist_ok=True)

# QM9 restricted to the U0 property; download=True fetches the database if needed.
qm9data = QM9('./data/qm9.db', download=True, load_only=[QM9.U0])
print(len(qm9data))

train, val, test = spk.train_test_split(
    data=qm9data,
    num_train=N_TRAIN,
    num_val=N_VAL,
    split_file=os.path.join(qm9split, f"split{N_TRAIN}-{N_VAL}.npz"),
)
print(len(train), len(val), len(test))

# Shuffle training batches so each epoch sees a different batch order;
# validation order does not matter.
train_loader = spk.AtomsLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = spk.AtomsLoader(val, batch_size=BATCH_SIZE)

In [None]:
import json


class obj(object):
    """Expose a dict's key/value pairs as attributes of a plain object."""

    def __init__(self, dict_):
        vars(self).update(dict_)


# Serialising config_args and decoding it again with object_hook=obj turns
# every nested dict into an `obj`, so settings are read as args.lr rather
# than args['lr']. (The JSON round-trip also turns any tuples into lists.)
_config_json = json.dumps(config_args)
args = json.loads(_config_json, object_hook=obj)

In [None]:
# Checkpoints of the pretrained encoder/decoder for the configured model type.
encoder_path = os.path.join('./saved_model', args.model + '-encoder_KL.pt')
decoder_path = os.path.join('./saved_model', args.model + '-decoder_KL.pt')

# map_location='cpu': deserialize onto the CPU regardless of the device the
# checkpoint was saved from. The freshly built modules live on the CPU anyway,
# and the whole model is moved to `device` once below.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
encoder = getattr(Encoders, args.model)(args)
encoder.load_state_dict(torch.load(encoder_path, map_location='cpu'))

# The MLP decoder is built without curvatures; every other decoder is given
# the encoder's `curvatures` (only read when the model is not 'MLP').
decoder_curvatures = None if args.model == 'MLP' else encoder.curvatures
decoder = Decoders.model2decoder[args.model](decoder_curvatures, args)
decoder.load_state_dict(torch.load(decoder_path, map_location='cpu'))

device = torch.device('cuda')
model = HyperbolicDiffusion(args, encoder, decoder).to(device)

In [None]:
# Optimizer class is chosen by name from the project's `optimizers` module.
optimizer = getattr(optimizers, args.optimizer)(
    params=model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
# StepLR multiplies the learning rate by `gamma` every `step_size` calls to
# lr_scheduler.step().
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_reduce_freq,
    gamma=float(args.gamma),
)

# numel() gives an exact int per tensor; np.prod(p.size()) returns the float
# 1.0 for 0-dim parameters, which would turn the whole count into a float.
tot_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {tot_params}")

# Train model — wall-clock start of training.
t_total = time.time()

In [None]:

# Each epoch: one optimisation pass over train_loader, optional logging, then a
# gradient-free pass over val_loader. Uses model, optimizer, lr_scheduler,
# device, train_loader, val_loader and args from earlier cells.
for epoch in range(args.epochs):
    # ---- training pass ----
    model.train()
    epoch_start = time.time()
    train_loss_sum, n_train = 0.0, 0
    train_bar = tqdm(train_loader, desc=f"epoch {epoch + 1} train")
    for batch in train_bar:
        # Move every tensor of the collated batch dict onto the model's device.
        batch = {key: value.to(device) for key, value in batch.items()}
        optimizer.zero_grad()
        # model(batch) is reduced with .sum(); presumably it returns per-sample
        # loss terms — TODO confirm against HyperbolicDiffusion.forward.
        loss = model(batch).sum()
        loss.backward()
        # Gradient clipping is disabled; to re-enable it:
        #   torch.nn.utils.clip_grad_value_(model.parameters(), float(args.grad_clip))
        optimizer.step()
        # NOTE(review): the scheduler is stepped once per *batch*, so StepLR's
        # step_size (args.lr_reduce_freq) counts batches, not epochs — confirm intended.
        lr_scheduler.step()

        # .item() yields a detached Python float; accumulating the tensor itself
        # would keep every batch's autograd graph alive until the epoch ends.
        batch_loss = loss.item()
        train_loss_sum += batch_loss
        n_train += 1
        train_bar.set_postfix(loss=batch_loss, lr=lr_scheduler.get_last_lr()[0])

    if (epoch + 1) % args.log_freq == 0:
        # max(..., 1) guards the average against an empty loader; time covers the
        # whole training pass of this epoch, not just its last batch.
        log_line = " ".join([
            'Epoch: {:04d}'.format(epoch + 1),
            'lr: {}'.format(lr_scheduler.get_last_lr()[0]),
            'loss: {:.4f}'.format(train_loss_sum / max(n_train, 1)),
            'time: {:.4f}s'.format(time.time() - epoch_start),
        ])
        print(log_line)

    # ---- validation pass: eval mode, no gradients, no optimizer interaction ----
    model.eval()
    val_loss_sum, n_val = 0.0, 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"epoch {epoch + 1} val"):
            batch = {key: value.to(device) for key, value in batch.items()}
            val_loss_sum += model(batch).sum().item()
            n_val += 1
    print('val_loss:', val_loss_sum / max(n_val, 1))
