In [2]:
import torch
import yaml
import ast
#from dig.threedgraph.dataset import QM93D
from dig.threedgraph.dataset import MD17
from dig.threedgraph.dataset.PygTobermorite import Tobermorite
from dig.threedgraph.method import DimeNetPP,SchNet,SphereNet #SchNet, DimeNetPP, ComENet
from dig.threedgraph.method import run
from dig.threedgraph.evaluation import ThreeDEvaluator

# Load configs

In [3]:
with open('config.yaml', 'r') as c:
    config = yaml.safe_load(c)
    
# For strings that yaml doesn't parse (e.g. None)
for key, val in config.items():
    if type(val) is str:
        try:
            config[key] = ast.literal_eval(val)
        except (ValueError, SyntaxError):
            pass

In [4]:
name = config['name']
n_train = config['n_train']
n_val = config['n_val']
seed = config['seed']
energy_and_force = config['energy_and_force']
cutoff = config['cutoff']
num_layers = config['num_layers']
hidden_channels = config['hidden_channels']
out_channels = config['out_channels']
int_emb_size = config['int_emb_size']
basis_emb_size = config['basis_emb_size']
out_emb_channels = config['out_emb_channels']
num_spherical = config['num_spherical']
num_radial = config['num_radial']
envelope_exponent = config['envelope_exponent']
num_before_skip = config['num_before_skip']
num_after_skip = config['num_after_skip']
num_output_layers = config['num_output_layers']

epochs = config['epochs']
batch_size = config['batch_size']
vt_batch_size = config['vt_batch_size']
lr = config['lr']
lr_decay_factor = config['lr_decay_factor']
lr_decay_step_size = config['lr_decay_step_size']

In [5]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

# Dataset

In [6]:
dataset = Tobermorite(root='dataset/', name=name)

split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=n_train, valid_size=n_val, seed=seed)

train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]
print('train, validaion, test:', len(train_dataset), len(valid_dataset), len(test_dataset))

train, validaion, test: 950 50 6000




# Model

In [6]:
model_dime = DimeNetPP(energy_and_force=energy_and_force, cutoff=cutoff, num_layers=num_layers, 
        hidden_channels=hidden_channels, out_channels=out_channels, int_emb_size=int_emb_size, 
        basis_emb_size=basis_emb_size, out_emb_channels=out_emb_channels, 
        num_spherical=num_spherical, num_radial=num_radial, envelope_exponent=envelope_exponent, 
        num_before_skip=num_before_skip, num_after_skip=num_after_skip, num_output_layers=num_output_layers 
        )

model_sphere = SphereNet(energy_and_force=energy_and_force, cutoff=cutoff, num_layers=num_layers, 
        hidden_channels=hidden_channels, out_channels=out_channels, int_emb_size=int_emb_size, 
        basis_emb_size_dist=8, basis_emb_size_angle=8,
                  basis_emb_size_torsion=8, out_emb_channels=out_emb_channels, 
        num_spherical=num_spherical, num_radial=num_radial, envelope_exponent=envelope_exponent, 
        num_before_skip=num_before_skip, num_after_skip=num_after_skip, num_output_layers=num_output_layers 
        )

model_schnet = SchNet(energy_and_force=energy_and_force, cutoff=cutoff, num_layers=6, 
        hidden_channels=128, out_channels=1,
        num_filters=128, num_gaussians=20 
        )

loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

# FLOPs

## tob14

In [10]:
from torch_geometric.data import DataLoader
from tqdm import tqdm
from thop import profile

dataset = Tobermorite(root='dataset/', name='tob14')

split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=950, valid_size=50, seed=seed)

train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]
print('train, validaion, test:', len(train_dataset), len(valid_dataset), len(test_dataset))

train_loader = DataLoader(train_dataset, 4, shuffle=True)
for step, batch_data in enumerate(tqdm(train_loader)):
    print(device)
    batch_data = batch_data.to(device)
    inputs = batch_data
    model = model_schnet.to(device)
    flops,params=profile(model,(inputs, ))
    print(flops,params)
    break

train, validaion, test: 950 50 6000


  0%|                                                             | 0/238 [00:00<?, ?it/s]

cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
674660352.0 419969.0





## tob11

In [21]:
from torch_geometric.data import DataLoader
from tqdm import tqdm
from thop import profile

dataset = Tobermorite(root='dataset/', name='tob11')

split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=950, valid_size=50, seed=seed)

train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]
print('train, validaion, test:', len(train_dataset), len(valid_dataset), len(test_dataset))

loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

train_loader = DataLoader(train_dataset, 4, shuffle=True)
for step, batch_data in enumerate(tqdm(train_loader)):
    print(device)
    batch_data = batch_data.to(device)
    inputs = batch_data
    model = model_dime.to(device)
    flops,params=profile(model,(inputs, ))
    print(flops,params)
    break

train, validaion, test: 950 50 6000


  2%|██▊                                                                                                                                  | 5/238 [00:00<00:05, 43.80it/s]

cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3635036800.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3688009984.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3691398272.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3681520768.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3648440320.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3671485184.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3540500992.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3601305600.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3627036800.0 1874176.0


  6%|████████▎                                                                                                                           | 15/238 [00:00<00:04, 45.31it/s]

cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3646200064.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3630194816.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3608015872.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3599044992.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3638970880.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3682642560.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3687725952.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3689535872.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3646137088.0 1874176.0
cuda:0
[INFO] Register count_linear()

 11%|█████████████▊                                                                                                                      | 25/238 [00:00<00:04, 45.93it/s]

3664182528.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3643566208.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3644348160.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3693826816.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3604264832.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3574268928.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3568833920.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3641321344.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3554014720.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3623238528.0 1874176.0
cuda:0
[INFO] 

 15%|███████████████████▍                                                                                                                | 35/238 [00:00<00:04, 46.34it/s]

3641220992.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3639223424.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3585054336.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3640135936.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3667016448.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3594114432.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3631469440.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3617693952.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3586155776.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3621513216.0 1874176.0
cuda:0
[INFO] 

 19%|████████████████████████▉                                                                                                           | 45/238 [00:00<00:04, 46.54it/s]

3712838528.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3701412864.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3750623872.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3680056576.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3675744640.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3653753344.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3673704448.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3593752320.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3685932416.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3660841472.0 1874176.0
cuda:0
[INFO] 

 23%|██████████████████████████████▌                                                                                                     | 55/238 [00:01<00:03, 46.39it/s]

3654619904.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3676447232.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3617541760.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3627356928.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3591937792.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3666564480.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3537354112.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3669082880.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3669942912.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3687128320.0 1874176.0
cuda:0
[INFO] 

 27%|████████████████████████████████████                                                                                                | 65/238 [00:01<00:03, 46.66it/s]

3700620416.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3556873600.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3758749824.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3653088128.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3736496768.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3601306240.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3654929536.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3651324800.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3595100416.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3679060736.0 1874176.0
cuda:0
[INFO] 

 32%|█████████████████████████████████████████▌                                                                                          | 75/238 [00:01<00:03, 46.70it/s]

3667618688.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3635251968.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3617279360.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3616277632.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3651283456.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3622053120.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3658590720.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3604568576.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3629161600.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3587703296.0 1874176.0
cuda:0
[INFO] 

 36%|███████████████████████████████████████████████▏                                                                                    | 85/238 [00:01<00:03, 46.76it/s]

3651283456.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3658910848.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3690150528.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3626669440.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3592729600.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3557822848.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3632944128.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3594670720.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3634082304.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3659870592.0 1874176.0
cuda:0
[INFO] 

 40%|████████████████████████████████████████████████████▋                                                                               | 95/238 [00:02<00:03, 46.71it/s]

3653738240.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3656565632.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3651020416.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3688272384.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3685381376.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3629455488.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3597014656.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3630798336.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3613769728.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3653061888.0 1874176.0
cuda:0
[INFO] 

 44%|█████████████████████████████████████████████████████████▊                                                                         | 105/238 [00:02<00:02, 46.12it/s]

3607816448.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3638032768.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3603247360.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3677087488.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3620888704.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3668683392.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3615060736.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3621135360.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3614751744.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3657520768.0 1874176.0
cuda:0
[INFO] 

 48%|███████████████████████████████████████████████████████████████▎                                                                   | 115/238 [00:02<00:02, 46.24it/s]

3629150464.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3585631616.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3609411200.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3626575616.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3631773824.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3532443904.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3656996608.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3597120256.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3675571456.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3637277696.0 1874176.0
cuda:0
[INFO] 

 53%|████████████████████████████████████████████████████████████████████▊                                                              | 125/238 [00:02<00:02, 46.55it/s]

3606121984.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3649116672.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3561211776.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3626307328.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3672508544.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3639317248.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3622504448.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3608172672.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3740127104.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3593164544.0 1874176.0
cuda:0
[INFO] 

 57%|██████████████████████████████████████████████████████████████████████████▎                                                        | 135/238 [00:02<00:02, 46.58it/s]

3635881728.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3740420992.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3635865984.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3664492800.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3720591872.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3609049088.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3663322496.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3681719552.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3604658432.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3688796544.0 1874176.0
cuda:0
[INFO] 

 61%|███████████████████████████████████████████████████████████████████████████████▊                                                   | 145/238 [00:03<00:02, 46.31it/s]

3614053120.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3659341184.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3583947008.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3683408768.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3641452544.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3618937088.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3705063552.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3660211712.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3594785536.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3676525952.0 1874176.0
cuda:0
[INFO] 

 65%|█████████████████████████████████████████████████████████████████████████████████████▎                                             | 155/238 [00:03<00:01, 46.17it/s]

3596054912.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3646378496.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3599480576.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3657436800.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3614877696.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3631307392.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3702854784.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3678960384.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3654110848.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3701097984.0 1874176.0
cuda:0
[INFO] 

 69%|██████████████████████████████████████████████████████████████████████████████████████████▊                                        | 165/238 [00:03<00:01, 46.31it/s]

3679133568.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3607165696.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3629969152.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3626192512.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3677323648.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3648618112.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3627608832.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3600849664.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3599748224.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3636033920.0 1874176.0
cuda:0
[INFO] 

 74%|████████████████████████████████████████████████████████████████████████████████████████████████▎                                  | 175/238 [00:03<00:01, 46.36it/s]

3590053760.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3575711488.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3614856064.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3612102144.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3622777344.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3686687488.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3546387328.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3654158080.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3590431616.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3596301568.0 1874176.0
cuda:0
[INFO] 

 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████▊                             | 185/238 [00:03<00:01, 46.62it/s]

3545951744.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3668841472.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3599491072.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3558253824.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3636594816.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3599144704.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3668279936.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3646888192.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3721908480.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3652080512.0 1874176.0
cuda:0
[INFO] 

 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▎                       | 195/238 [00:04<00:00, 46.63it/s]

3605932416.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3688413440.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3704843136.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3614430976.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3601642112.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3583611776.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3622247296.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3624833920.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3613754624.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3600644992.0 1874176.0
cuda:0
[INFO] 

 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                  | 205/238 [00:04<00:00, 46.44it/s]

3709160960.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3675885696.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3616298624.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3663680000.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3698883968.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3697776640.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3666706176.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3637696896.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3657882880.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3597807104.0 1874176.0
cuda:0
[INFO] 

 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎            | 215/238 [00:04<00:00, 46.55it/s]

3658994816.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3600509184.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3677381376.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3620563968.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3587703296.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3731770880.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3664162176.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3718907904.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3647154560.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3576665984.0 1874176.0
cuda:0
[INFO] 

 95%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊       | 225/238 [00:04<00:00, 46.54it/s]

3683891584.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3537741824.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3704602368.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3573172096.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3559208320.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3632723712.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3630693376.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3609184896.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3683324800.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3604327808.0 1874176.0
cuda:0
[INFO] 

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 238/238 [00:05<00:00, 46.36it/s]

3574840960.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3620752256.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3638132480.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3600907392.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3598824576.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3706611072.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3645927168.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3604500992.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
3622425728.0 1874176.0
cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
1784372352.0 1874176.0





## tob9

In [28]:
from torch_geometric.data import DataLoader
from tqdm import tqdm
from thop import profile

dataset = Tobermorite(root='dataset/', name='tob14')

split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=950, valid_size=50, seed=seed)

train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]
print('train, validaion, test:', len(train_dataset), len(valid_dataset), len(test_dataset))

loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()

train_loader = DataLoader(train_dataset, 4, shuffle=True)
for step, batch_data in enumerate(tqdm(train_loader)):
    print(device)
    batch_data = batch_data.to(device)
    inputs = batch_data
    model = model_sphere.to(device)
    flops,params=profile(model,(inputs, ))
    print(flops,params)
    break

train, validaion, test: 950 50 6000


  0%|                                                                                                                                             | 0/238 [00:00<?, ?it/s]

cuda:0
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
4284177536.0 1877952.0



