In [2]:
%config Completer.use_jedi=False

In [3]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
from sklearn.model_selection import train_test_split

from ruslan_nn.schnet import SchNet
import pickle
import wandb

In [4]:
targets = pd.read_csv('ruslan_nn/properties16k.csv')
label=targets['energy_per_atom']

In [5]:
with open('ruslan_nn/structures16k.pickle', 'rb') as file:
    structures = pickle.load(file)

In [6]:
from torch_geometric.data import Data
import torch
import ase
from pymatgen.io.ase import AseAtomsAdaptor

i=0
data_atoms = []
for _id in tqdm(targets._id):
    atoms=AseAtomsAdaptor.get_atoms(structures[str(_id)])
    # set the atomic numbers, positions, and cell
    atom = torch.Tensor(atoms.get_atomic_numbers())
    positions = torch.Tensor(atoms.get_positions())
    natoms = positions.shape[0]
    
    # put the minimum data in torch geometric data object
    data = Data(
        pos=positions,
        z= atom,
       # natoms=natoms,
    )
    
    # calculate energy
    data.y = label[i]
    i=i+1
    data_atoms.append(data)

100%|██████████| 14718/14718 [00:25<00:00, 584.43it/s]


In [7]:
from torch_geometric.data import DataLoader
train_dataset, test_dataset = train_test_split(data_atoms, test_size=0.2)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

In [13]:
wandb.init(
    project="schnet_dichalcogenides", entity="inno-materials-ai", save_code=True, name='baseline'
)

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

In [14]:
model=SchNet()
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
epochs = 200

loss_func = torch.nn.L1Loss() #define loss
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, epochs=epochs,
                                                steps_per_epoch=len(train_loader),
                                                max_lr=1e-3)

In [None]:
for epoch in range(epochs):
    model.train()
    valid_loss=0
    train_loss=0

     #shuffle the training data each epoch
    for d in tqdm(train_loader): #go over each training point
        data = d.to(device)#send data to device
        out = model(data) 
        optimizer.zero_grad() #zero gradients
        #evaluate data point
        loss = loss_func(out.view(-1), data.y.view(-1)) #L1 error loss
         #add loss value to aggregate loss
        loss.backward() #compute gradients
        optimizer.step() #apply optimization
        scheduler.step()
        train_loss += loss.item()
    with torch.no_grad():
        model.eval()     # Optional when not using Model Specific layer
        for d in tqdm(test_loader):
            data = d.to(device)
            target = model(data)
            loss = loss_func(target.view(-1), data.y.view(-1))
            valid_loss += loss.item()
        
    print('Epoch: {:03d}, Average loss: {:.5f}'.format(epoch, train_loss/len(train_loader)))
    print('Epoch: {:03d}, Average loss: {:.5f}'.format(epoch, valid_loss/len(test_loader)))
    wandb.log({
        "train_mae": train_loss/len(train_loader),
        "test_mae": valid_loss/len(test_loader),
    })    
    

100%|██████████| 368/368 [01:31<00:00,  4.04it/s]
100%|██████████| 92/92 [00:12<00:00,  7.50it/s]


Epoch: 000, Average loss: 0.30561
Epoch: 000, Average loss: 0.04736


100%|██████████| 368/368 [01:33<00:00,  3.94it/s]
100%|██████████| 92/92 [00:12<00:00,  7.17it/s]


Epoch: 001, Average loss: 0.04068
Epoch: 001, Average loss: 0.03534


100%|██████████| 368/368 [01:35<00:00,  3.85it/s]
100%|██████████| 92/92 [00:12<00:00,  7.30it/s]


Epoch: 002, Average loss: 0.03528
Epoch: 002, Average loss: 0.03318


100%|██████████| 368/368 [01:35<00:00,  3.85it/s]
100%|██████████| 92/92 [00:12<00:00,  7.09it/s]


Epoch: 003, Average loss: 0.03064
Epoch: 003, Average loss: 0.03057


100%|██████████| 368/368 [01:16<00:00,  4.82it/s]
100%|██████████| 92/92 [00:03<00:00, 29.13it/s]


Epoch: 004, Average loss: 0.03046
Epoch: 004, Average loss: 0.02780


100%|██████████| 368/368 [00:31<00:00, 11.71it/s]
100%|██████████| 92/92 [00:03<00:00, 28.71it/s]


Epoch: 005, Average loss: 0.03144
Epoch: 005, Average loss: 0.04568


100%|██████████| 368/368 [00:31<00:00, 11.73it/s]
100%|██████████| 92/92 [00:03<00:00, 28.74it/s]


Epoch: 006, Average loss: 0.04480
Epoch: 006, Average loss: 0.05676


100%|██████████| 368/368 [01:25<00:00,  4.28it/s]
100%|██████████| 92/92 [00:11<00:00,  7.85it/s]


Epoch: 007, Average loss: 0.03729
Epoch: 007, Average loss: 0.07327


100%|██████████| 368/368 [01:29<00:00,  4.13it/s]
100%|██████████| 92/92 [00:09<00:00,  9.97it/s]


Epoch: 008, Average loss: 0.03390
Epoch: 008, Average loss: 0.03934


100%|██████████| 368/368 [01:28<00:00,  4.14it/s]
100%|██████████| 92/92 [00:12<00:00,  7.34it/s]


Epoch: 009, Average loss: 0.02750
Epoch: 009, Average loss: 0.03413


100%|██████████| 368/368 [01:26<00:00,  4.24it/s]
100%|██████████| 92/92 [00:13<00:00,  6.76it/s]


Epoch: 010, Average loss: 0.04253
Epoch: 010, Average loss: 0.05024


100%|██████████| 368/368 [01:27<00:00,  4.23it/s]
100%|██████████| 92/92 [00:12<00:00,  7.63it/s]


Epoch: 011, Average loss: 0.04575
Epoch: 011, Average loss: 0.08276


100%|██████████| 368/368 [01:31<00:00,  4.04it/s]
100%|██████████| 92/92 [00:09<00:00,  9.36it/s]


Epoch: 012, Average loss: 0.04401
Epoch: 012, Average loss: 0.09564


100%|██████████| 368/368 [01:26<00:00,  4.26it/s]
100%|██████████| 92/92 [00:11<00:00,  8.25it/s]


Epoch: 013, Average loss: 0.04873
Epoch: 013, Average loss: 0.03322


100%|██████████| 368/368 [01:19<00:00,  4.65it/s]
100%|██████████| 92/92 [00:11<00:00,  8.03it/s]


Epoch: 014, Average loss: 0.04340
Epoch: 014, Average loss: 0.05577


100%|██████████| 368/368 [01:13<00:00,  5.01it/s]
100%|██████████| 92/92 [00:09<00:00, 10.14it/s]


Epoch: 015, Average loss: 0.04247
Epoch: 015, Average loss: 0.05367


100%|██████████| 368/368 [01:02<00:00,  5.92it/s]
100%|██████████| 92/92 [00:08<00:00, 10.46it/s]


Epoch: 016, Average loss: 0.05453
Epoch: 016, Average loss: 0.03741


100%|██████████| 368/368 [01:20<00:00,  4.58it/s]
100%|██████████| 92/92 [00:10<00:00,  8.66it/s]


Epoch: 017, Average loss: 0.04007
Epoch: 017, Average loss: 0.06679


100%|██████████| 368/368 [00:53<00:00,  6.89it/s]
100%|██████████| 92/92 [00:03<00:00, 29.10it/s]


Epoch: 018, Average loss: 0.03953
Epoch: 018, Average loss: 0.01721


100%|██████████| 368/368 [00:31<00:00, 11.74it/s]
100%|██████████| 92/92 [00:03<00:00, 28.72it/s]


Epoch: 019, Average loss: 0.05278
Epoch: 019, Average loss: 0.02186


100%|██████████| 368/368 [00:45<00:00,  8.14it/s]
100%|██████████| 92/92 [00:06<00:00, 13.90it/s]


Epoch: 020, Average loss: 0.05141
Epoch: 020, Average loss: 0.01474


100%|██████████| 368/368 [00:53<00:00,  6.91it/s]
100%|██████████| 92/92 [00:08<00:00, 10.91it/s]


Epoch: 021, Average loss: 0.04289
Epoch: 021, Average loss: 0.06316


100%|██████████| 368/368 [00:58<00:00,  6.25it/s]
100%|██████████| 92/92 [00:08<00:00, 11.28it/s]


Epoch: 022, Average loss: 0.05912
Epoch: 022, Average loss: 0.02154


100%|██████████| 368/368 [01:11<00:00,  5.16it/s]
100%|██████████| 92/92 [00:08<00:00, 10.30it/s]


Epoch: 023, Average loss: 0.03846
Epoch: 023, Average loss: 0.03565


100%|██████████| 368/368 [01:10<00:00,  5.25it/s]
100%|██████████| 92/92 [00:11<00:00,  8.05it/s]


Epoch: 024, Average loss: 0.04145
Epoch: 024, Average loss: 0.03513


100%|██████████| 368/368 [01:07<00:00,  5.42it/s]
100%|██████████| 92/92 [00:08<00:00, 11.32it/s]


Epoch: 025, Average loss: 0.03836
Epoch: 025, Average loss: 0.02621


100%|██████████| 368/368 [01:16<00:00,  4.83it/s]
100%|██████████| 92/92 [00:13<00:00,  7.03it/s]


Epoch: 026, Average loss: 0.04229
Epoch: 026, Average loss: 0.01197


100%|██████████| 368/368 [01:36<00:00,  3.81it/s]
100%|██████████| 92/92 [00:12<00:00,  7.14it/s]


Epoch: 027, Average loss: 0.04317
Epoch: 027, Average loss: 0.03716


100%|██████████| 368/368 [01:11<00:00,  5.18it/s]
100%|██████████| 92/92 [00:10<00:00,  8.66it/s]


Epoch: 028, Average loss: 0.04571
Epoch: 028, Average loss: 0.03113


100%|██████████| 368/368 [01:21<00:00,  4.54it/s]
100%|██████████| 92/92 [00:11<00:00,  8.26it/s]


Epoch: 029, Average loss: 0.04053
Epoch: 029, Average loss: 0.05846


100%|██████████| 368/368 [01:18<00:00,  4.67it/s]
100%|██████████| 92/92 [00:09<00:00,  9.64it/s]


Epoch: 030, Average loss: 0.04354
Epoch: 030, Average loss: 0.04699


100%|██████████| 368/368 [01:17<00:00,  4.73it/s]
100%|██████████| 92/92 [00:10<00:00,  8.66it/s]


Epoch: 031, Average loss: 0.03317
Epoch: 031, Average loss: 0.02657


100%|██████████| 368/368 [01:19<00:00,  4.60it/s]
100%|██████████| 92/92 [00:10<00:00,  9.16it/s]


Epoch: 032, Average loss: 0.02885
Epoch: 032, Average loss: 0.01759


100%|██████████| 368/368 [00:44<00:00,  8.31it/s]
100%|██████████| 92/92 [00:03<00:00, 29.92it/s]


Epoch: 033, Average loss: 0.02276
Epoch: 033, Average loss: 0.00951


100%|██████████| 368/368 [00:31<00:00, 11.79it/s]
100%|██████████| 92/92 [00:03<00:00, 29.46it/s]


Epoch: 034, Average loss: 0.03179
Epoch: 034, Average loss: 0.02460


100%|██████████| 368/368 [00:31<00:00, 11.71it/s]
100%|██████████| 92/92 [00:03<00:00, 29.51it/s]


Epoch: 035, Average loss: 0.02847
Epoch: 035, Average loss: 0.01344


100%|██████████| 368/368 [01:03<00:00,  5.76it/s]
100%|██████████| 92/92 [00:06<00:00, 13.35it/s]


Epoch: 036, Average loss: 0.02406
Epoch: 036, Average loss: 0.02528


100%|██████████| 368/368 [01:03<00:00,  5.75it/s]
100%|██████████| 92/92 [00:09<00:00,  9.78it/s]


Epoch: 037, Average loss: 0.03010
Epoch: 037, Average loss: 0.01892


100%|██████████| 368/368 [00:59<00:00,  6.22it/s]
100%|██████████| 92/92 [00:08<00:00, 10.62it/s]


Epoch: 038, Average loss: 0.02782
Epoch: 038, Average loss: 0.04536


100%|██████████| 368/368 [01:03<00:00,  5.77it/s]
100%|██████████| 92/92 [00:08<00:00, 11.47it/s]


Epoch: 039, Average loss: 0.02608
Epoch: 039, Average loss: 0.01066


100%|██████████| 368/368 [00:59<00:00,  6.18it/s]
100%|██████████| 92/92 [00:09<00:00, 10.00it/s]


Epoch: 040, Average loss: 0.02694
Epoch: 040, Average loss: 0.01492


100%|██████████| 368/368 [01:04<00:00,  5.73it/s]
100%|██████████| 92/92 [00:08<00:00, 10.31it/s]


Epoch: 041, Average loss: 0.02229
Epoch: 041, Average loss: 0.01839


100%|██████████| 368/368 [01:05<00:00,  5.64it/s]
100%|██████████| 92/92 [00:06<00:00, 14.07it/s]


Epoch: 042, Average loss: 0.02398
Epoch: 042, Average loss: 0.01110


100%|██████████| 368/368 [00:54<00:00,  6.73it/s]
100%|██████████| 92/92 [00:06<00:00, 14.02it/s]


Epoch: 043, Average loss: 0.02233
Epoch: 043, Average loss: 0.00824


100%|██████████| 368/368 [01:01<00:00,  5.98it/s]
100%|██████████| 92/92 [00:10<00:00,  9.04it/s]


Epoch: 044, Average loss: 0.02413
Epoch: 044, Average loss: 0.00571


100%|██████████| 368/368 [01:03<00:00,  5.79it/s]
100%|██████████| 92/92 [00:07<00:00, 12.30it/s]


Epoch: 045, Average loss: 0.02329
Epoch: 045, Average loss: 0.01025


100%|██████████| 368/368 [01:03<00:00,  5.82it/s]
100%|██████████| 92/92 [00:11<00:00,  8.19it/s]


Epoch: 046, Average loss: 0.02345
Epoch: 046, Average loss: 0.00818


100%|██████████| 368/368 [01:05<00:00,  5.60it/s]
100%|██████████| 92/92 [00:11<00:00,  7.98it/s]


Epoch: 047, Average loss: 0.02067
Epoch: 047, Average loss: 0.03958


100%|██████████| 368/368 [01:15<00:00,  4.88it/s]
100%|██████████| 92/92 [00:08<00:00, 10.77it/s]


Epoch: 048, Average loss: 0.02127
Epoch: 048, Average loss: 0.04844


100%|██████████| 368/368 [01:05<00:00,  5.65it/s]
100%|██████████| 92/92 [00:05<00:00, 18.19it/s]


Epoch: 049, Average loss: 0.02307
Epoch: 049, Average loss: 0.03100


100%|██████████| 368/368 [01:08<00:00,  5.37it/s]
100%|██████████| 92/92 [00:03<00:00, 29.49it/s]


Epoch: 050, Average loss: 0.02051
Epoch: 050, Average loss: 0.02548


100%|██████████| 368/368 [00:31<00:00, 11.82it/s]
100%|██████████| 92/92 [00:03<00:00, 29.39it/s]


Epoch: 051, Average loss: 0.02195
Epoch: 051, Average loss: 0.02478


100%|██████████| 368/368 [00:31<00:00, 11.75it/s]
100%|██████████| 92/92 [00:03<00:00, 29.45it/s]


Epoch: 052, Average loss: 0.02066
Epoch: 052, Average loss: 0.02153


100%|██████████| 368/368 [00:31<00:00, 11.73it/s]
100%|██████████| 92/92 [00:03<00:00, 29.33it/s]


Epoch: 053, Average loss: 0.02988
Epoch: 053, Average loss: 0.02476


100%|██████████| 368/368 [00:31<00:00, 11.74it/s]
100%|██████████| 92/92 [00:03<00:00, 29.43it/s]


Epoch: 054, Average loss: 0.01872
Epoch: 054, Average loss: 0.02237


100%|██████████| 368/368 [00:31<00:00, 11.72it/s]
100%|██████████| 92/92 [00:03<00:00, 29.25it/s]


Epoch: 055, Average loss: 0.01922
Epoch: 055, Average loss: 0.01578


100%|██████████| 368/368 [00:31<00:00, 11.68it/s]
100%|██████████| 92/92 [00:03<00:00, 28.56it/s]


Epoch: 056, Average loss: 0.01807
Epoch: 056, Average loss: 0.03080


100%|██████████| 368/368 [00:31<00:00, 11.58it/s]
100%|██████████| 92/92 [00:03<00:00, 29.07it/s]


Epoch: 057, Average loss: 0.01799
Epoch: 057, Average loss: 0.01145


100%|██████████| 368/368 [00:31<00:00, 11.73it/s]
100%|██████████| 92/92 [00:03<00:00, 28.95it/s]


Epoch: 058, Average loss: 0.01900
Epoch: 058, Average loss: 0.02535


100%|██████████| 368/368 [00:31<00:00, 11.70it/s]
100%|██████████| 92/92 [00:03<00:00, 29.19it/s]


Epoch: 059, Average loss: 0.02457
Epoch: 059, Average loss: 0.00790


100%|██████████| 368/368 [01:05<00:00,  5.61it/s]
100%|██████████| 92/92 [00:11<00:00,  8.33it/s]


Epoch: 060, Average loss: 0.01854
Epoch: 060, Average loss: 0.01341


100%|██████████| 368/368 [01:01<00:00,  6.00it/s]
100%|██████████| 92/92 [00:10<00:00,  9.07it/s]


Epoch: 061, Average loss: 0.01756
Epoch: 061, Average loss: 0.03507


100%|██████████| 368/368 [00:54<00:00,  6.76it/s]
100%|██████████| 92/92 [00:05<00:00, 16.14it/s]


Epoch: 062, Average loss: 0.02102
Epoch: 062, Average loss: 0.01403


100%|██████████| 368/368 [00:57<00:00,  6.39it/s]
100%|██████████| 92/92 [00:04<00:00, 22.43it/s]


Epoch: 063, Average loss: 0.02017
Epoch: 063, Average loss: 0.02324


100%|██████████| 368/368 [01:01<00:00,  6.02it/s]
100%|██████████| 92/92 [00:08<00:00, 10.83it/s]


Epoch: 064, Average loss: 0.01498
Epoch: 064, Average loss: 0.02770


100%|██████████| 368/368 [01:05<00:00,  5.62it/s]
100%|██████████| 92/92 [00:08<00:00, 10.51it/s]


Epoch: 065, Average loss: 0.01834
Epoch: 065, Average loss: 0.03897


100%|██████████| 368/368 [00:44<00:00,  8.32it/s]
100%|██████████| 92/92 [00:03<00:00, 28.35it/s]


Epoch: 066, Average loss: 0.01770
Epoch: 066, Average loss: 0.01497


100%|██████████| 368/368 [00:32<00:00, 11.46it/s]
100%|██████████| 92/92 [00:03<00:00, 28.68it/s]


Epoch: 067, Average loss: 0.01461
Epoch: 067, Average loss: 0.00667


100%|██████████| 368/368 [00:32<00:00, 11.44it/s]
100%|██████████| 92/92 [00:03<00:00, 28.72it/s]


Epoch: 068, Average loss: 0.01441
Epoch: 068, Average loss: 0.00861


100%|██████████| 368/368 [00:32<00:00, 11.43it/s]
100%|██████████| 92/92 [00:04<00:00, 19.28it/s]


Epoch: 069, Average loss: 0.01792
Epoch: 069, Average loss: 0.01391


100%|██████████| 368/368 [00:48<00:00,  7.59it/s]
100%|██████████| 92/92 [00:09<00:00,  9.92it/s]


Epoch: 070, Average loss: 0.01603
Epoch: 070, Average loss: 0.01208


100%|██████████| 368/368 [01:12<00:00,  5.06it/s]
100%|██████████| 92/92 [00:10<00:00,  9.05it/s]


Epoch: 071, Average loss: 0.01762
Epoch: 071, Average loss: 0.01250


100%|██████████| 368/368 [01:12<00:00,  5.07it/s]
100%|██████████| 92/92 [00:09<00:00,  9.85it/s]


Epoch: 072, Average loss: 0.01471
Epoch: 072, Average loss: 0.02558


100%|██████████| 368/368 [01:25<00:00,  4.32it/s]
100%|██████████| 92/92 [00:11<00:00,  8.20it/s]


Epoch: 073, Average loss: 0.01729
Epoch: 073, Average loss: 0.00629


100%|██████████| 368/368 [01:26<00:00,  4.24it/s]
100%|██████████| 92/92 [00:12<00:00,  7.57it/s]


Epoch: 074, Average loss: 0.01714
Epoch: 074, Average loss: 0.01124


100%|██████████| 368/368 [01:26<00:00,  4.23it/s]
100%|██████████| 92/92 [00:03<00:00, 29.46it/s]


Epoch: 075, Average loss: 0.01302
Epoch: 075, Average loss: 0.00812


100%|██████████| 368/368 [00:31<00:00, 11.79it/s]
100%|██████████| 92/92 [00:03<00:00, 29.49it/s]


Epoch: 076, Average loss: 0.01784
Epoch: 076, Average loss: 0.00573


100%|██████████| 368/368 [00:31<00:00, 11.79it/s]
100%|██████████| 92/92 [00:03<00:00, 29.15it/s]


Epoch: 077, Average loss: 0.01383
Epoch: 077, Average loss: 0.01401


100%|██████████| 368/368 [01:02<00:00,  5.85it/s]
100%|██████████| 92/92 [00:09<00:00,  9.73it/s]


Epoch: 078, Average loss: 0.01912
Epoch: 078, Average loss: 0.00667


100%|██████████| 368/368 [01:08<00:00,  5.35it/s]
100%|██████████| 92/92 [00:09<00:00,  9.23it/s]


Epoch: 079, Average loss: 0.01242
Epoch: 079, Average loss: 0.00968


100%|██████████| 368/368 [01:15<00:00,  4.86it/s]
100%|██████████| 92/92 [00:10<00:00,  8.90it/s]


Epoch: 080, Average loss: 0.01424
Epoch: 080, Average loss: 0.01041


100%|██████████| 368/368 [01:06<00:00,  5.50it/s]
100%|██████████| 92/92 [00:06<00:00, 13.40it/s]


Epoch: 081, Average loss: 0.01443
Epoch: 081, Average loss: 0.02631


100%|██████████| 368/368 [01:08<00:00,  5.39it/s]
100%|██████████| 92/92 [00:08<00:00, 11.18it/s]


Epoch: 082, Average loss: 0.01072
Epoch: 082, Average loss: 0.00417


100%|██████████| 368/368 [00:55<00:00,  6.61it/s]
100%|██████████| 92/92 [00:09<00:00,  9.24it/s]


Epoch: 083, Average loss: 0.01346
Epoch: 083, Average loss: 0.00694


100%|██████████| 368/368 [01:06<00:00,  5.57it/s]
100%|██████████| 92/92 [00:09<00:00,  9.54it/s]


Epoch: 084, Average loss: 0.01214
Epoch: 084, Average loss: 0.01250


100%|██████████| 368/368 [01:05<00:00,  5.59it/s]
100%|██████████| 92/92 [00:06<00:00, 14.38it/s]


Epoch: 085, Average loss: 0.01088
Epoch: 085, Average loss: 0.00445


100%|██████████| 368/368 [01:04<00:00,  5.70it/s]
100%|██████████| 92/92 [00:08<00:00, 11.10it/s]


Epoch: 086, Average loss: 0.01156
Epoch: 086, Average loss: 0.00546


100%|██████████| 368/368 [01:06<00:00,  5.51it/s]
100%|██████████| 92/92 [00:11<00:00,  8.24it/s]


Epoch: 087, Average loss: 0.01458
Epoch: 087, Average loss: 0.00747


100%|██████████| 368/368 [01:11<00:00,  5.18it/s]
100%|██████████| 92/92 [00:10<00:00,  9.11it/s]


Epoch: 088, Average loss: 0.01359
Epoch: 088, Average loss: 0.00636


100%|██████████| 368/368 [01:10<00:00,  5.20it/s]
100%|██████████| 92/92 [00:10<00:00,  9.13it/s]


Epoch: 089, Average loss: 0.01318
Epoch: 089, Average loss: 0.01493


100%|██████████| 368/368 [01:11<00:00,  5.15it/s]
100%|██████████| 92/92 [00:10<00:00,  8.49it/s]


Epoch: 090, Average loss: 0.01072
Epoch: 090, Average loss: 0.00804


100%|██████████| 368/368 [01:01<00:00,  5.94it/s]
100%|██████████| 92/92 [00:03<00:00, 29.43it/s]


Epoch: 091, Average loss: 0.01244
Epoch: 091, Average loss: 0.01116


100%|██████████| 368/368 [00:31<00:00, 11.75it/s]
100%|██████████| 92/92 [00:03<00:00, 29.58it/s]


Epoch: 092, Average loss: 0.01138
Epoch: 092, Average loss: 0.01112


100%|██████████| 368/368 [00:31<00:00, 11.80it/s]
100%|██████████| 92/92 [00:03<00:00, 29.08it/s]


Epoch: 093, Average loss: 0.01278
Epoch: 093, Average loss: 0.00703


100%|██████████| 368/368 [00:35<00:00, 10.42it/s]
100%|██████████| 92/92 [00:10<00:00,  9.19it/s]


Epoch: 094, Average loss: 0.01299
Epoch: 094, Average loss: 0.02090


100%|██████████| 368/368 [01:08<00:00,  5.38it/s]
100%|██████████| 92/92 [00:08<00:00, 11.06it/s]


Epoch: 095, Average loss: 0.01213
Epoch: 095, Average loss: 0.00703


100%|██████████| 368/368 [01:01<00:00,  5.96it/s]
100%|██████████| 92/92 [00:08<00:00, 10.25it/s]


Epoch: 096, Average loss: 0.00975
Epoch: 096, Average loss: 0.00864


100%|██████████| 368/368 [01:03<00:00,  5.76it/s]
100%|██████████| 92/92 [00:11<00:00,  8.16it/s]


Epoch: 097, Average loss: 0.00975
Epoch: 097, Average loss: 0.00663


100%|██████████| 368/368 [01:16<00:00,  4.84it/s]
100%|██████████| 92/92 [00:08<00:00, 11.48it/s]


Epoch: 098, Average loss: 0.00970
Epoch: 098, Average loss: 0.01037


100%|██████████| 368/368 [01:06<00:00,  5.55it/s]
100%|██████████| 92/92 [00:09<00:00,  9.90it/s]


Epoch: 099, Average loss: 0.00848
Epoch: 099, Average loss: 0.00527


100%|██████████| 368/368 [01:02<00:00,  5.91it/s]
100%|██████████| 92/92 [00:10<00:00,  9.14it/s]


Epoch: 100, Average loss: 0.00871
Epoch: 100, Average loss: 0.00828


100%|██████████| 368/368 [01:09<00:00,  5.27it/s]
100%|██████████| 92/92 [00:13<00:00,  6.83it/s]


Epoch: 101, Average loss: 0.00961
Epoch: 101, Average loss: 0.00963


100%|██████████| 368/368 [00:55<00:00,  6.61it/s]
100%|██████████| 92/92 [00:04<00:00, 18.72it/s]


Epoch: 102, Average loss: 0.00975
Epoch: 102, Average loss: 0.00711


100%|██████████| 368/368 [00:44<00:00,  8.33it/s]
100%|██████████| 92/92 [00:07<00:00, 12.98it/s]


Epoch: 103, Average loss: 0.00846
Epoch: 103, Average loss: 0.00465


100%|██████████| 368/368 [00:41<00:00,  8.77it/s]
100%|██████████| 92/92 [00:04<00:00, 20.97it/s]


Epoch: 104, Average loss: 0.00932
Epoch: 104, Average loss: 0.00632


100%|██████████| 368/368 [00:39<00:00,  9.37it/s]
100%|██████████| 92/92 [00:04<00:00, 20.48it/s]


Epoch: 105, Average loss: 0.00947
Epoch: 105, Average loss: 0.01288


 28%|██▊       | 102/368 [00:11<00:29,  9.07it/s]

In [12]:
print(model)

SchNet(hidden_channels=128, num_filters=128, num_interactions=4, num_gaussians=50, cutoff=10.0)
