In [1]:
# Set IPython's Completer.use_jedi option to False (no Jedi-based tab completion).
%config Completer.use_jedi=False

In [2]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
from sklearn.model_selection import train_test_split

from ruslan_nn.schnet import SchNet
import pickle
import wandb

from torch_geometric.data import DataLoader
import torch
# Pin every random number generator used in this notebook to a single seed.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# cuDNN flags: autotuner benchmarking off, deterministic mode on.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Property table indexed by the first CSV column; the 'bandgap' column is
# additionally exposed under the name 'band_gap'.
targets = (
    pd.read_csv('ruslan_nn/properties16k.csv', index_col=0)
    .assign(band_gap=lambda frame: frame['bandgap'])
)
# NOTE(review): `folds` is not referenced by any other visible cell; the
# cross-validation loop builds its splits with KFold instead — confirm whether
# this file is still needed.
folds = pd.read_csv('folds.csv')

In [4]:
# Load the pickled structures object; construct_dataset later indexes it with
# ``str(_id)`` for each id of the targets table.
# NOTE(review): pickle.load can execute arbitrary code while deserializing —
# only open files from a trusted source.
with open('ruslan_nn/structures16k.pickle', 'rb') as file:
    structures = pickle.load(file)

In [5]:
from torch_geometric.data import Data
import torch
import ase
from pymatgen.io.ase import AseAtomsAdaptor

def construct_dataset(structures, targets, property_):
    """Build one torch_geometric ``Data`` object per row of ``targets``.

    Parameters
    ----------
    structures : mapping
        Indexed with ``str(_id)`` for every id in ``targets.index``; each value
        is converted with ``AseAtomsAdaptor.get_atoms``.
    targets : pandas.DataFrame
        Its index supplies the ids and its ``property_`` column supplies the
        value stored on each graph as ``y``.
    property_ : str
        Name of the column of ``targets`` to use as the target.

    Returns
    -------
    pandas.DataFrame
        One column named ``"data"`` holding the ``Data`` objects, indexed like
        ``targets``.
    """
    labels = targets[property_]
    graphs = []
    for _id in tqdm(targets.index):
        ase_atoms = AseAtomsAdaptor.get_atoms(structures[str(_id)])
        # NOTE(review): torch.Tensor(...) builds tensors of the default float
        # dtype, so the atomic numbers in `z` are stored as floats — confirm
        # that this is what the SchNet implementation expects.
        graph = Data(
            pos=torch.Tensor(ase_atoms.get_positions()),
            z=torch.Tensor(ase_atoms.get_atomic_numbers()),
        )
        # Attach the regression target for this structure.
        graph.y = labels[_id]
        graphs.append(graph)
    return pd.DataFrame({"data": graphs}, index=targets.index)

In [6]:
from sklearn.model_selection import KFold
# Shuffled 8-way K-fold splitter with a fixed random_state; used by the
# training loop to produce train/test index arrays.
kf = KFold(n_splits=8, shuffle=True, random_state=42)

In [7]:
# Training device. The second GPU is used when the machine has one (the
# original hard-coded choice); otherwise fall back to the first GPU, then to
# the CPU, so the notebook still runs on smaller machines.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Number of passes over the training loader per model; also used to size the
# OneCycleLR schedule in the training cell.
epochs = 50

In [8]:
def train_model(model, optimizer, scheduler, train_loader, test_loader,
                criterion=None, num_epochs=None, target_device=None):
    """Train ``model`` and evaluate it on ``test_loader`` after every epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(batch)``; its output is flattened with ``view(-1)``.
    optimizer : torch.optim.Optimizer
        Stepped once per training batch.
    scheduler : learning-rate scheduler
        Stepped once per training batch, right after the optimizer.
    train_loader, test_loader : iterable of batches
        Every batch must support ``.to(device)`` and expose a tensor ``y``.
    criterion : callable, optional
        Loss ``criterion(prediction, target)``. Defaults to the notebook-level
        ``loss_func``. Assumed to use mean reduction, as the notebook's
        ``torch.nn.L1Loss()`` does; the per-epoch averages rely on that.
    num_epochs : int, optional
        Number of epochs. Defaults to the notebook-level ``epochs``.
    target_device : str or torch.device, optional
        Where batches are moved. Defaults to the notebook-level ``device``.

    Side effects
    ------------
    Prints the average train and test loss per epoch and logs them to the
    active wandb run under ``"train_mae"`` and ``"test_mae"``.
    """
    # Fall back to the notebook globals so existing five-argument calls keep
    # behaving exactly as before.
    if criterion is None:
        criterion = loss_func
    if num_epochs is None:
        num_epochs = epochs
    if target_device is None:
        target_device = device

    for epoch in range(num_epochs):
        model.train()
        train_loss, train_count = 0.0, 0
        for d in tqdm(train_loader):
            data = d.to(target_device)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out.view(-1), data.y.view(-1))
            loss.backward()
            optimizer.step()
            scheduler.step()
            # Weight the batch mean by its size so that a short final batch
            # does not skew the epoch average.
            batch_size = data.y.numel()
            train_loss += loss.item() * batch_size
            train_count += batch_size

        model.eval()
        valid_loss, valid_count = 0.0, 0
        with torch.no_grad():
            for d in tqdm(test_loader):
                data = d.to(target_device)
                target = model(data)
                loss = criterion(target.view(-1), data.y.view(-1))
                batch_size = data.y.numel()
                valid_loss += loss.item() * batch_size
                valid_count += batch_size

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        train_avg = train_loss / max(train_count, 1)
        valid_avg = valid_loss / max(valid_count, 1)

        # The two lines used to carry the same label and could not be told apart.
        print('Epoch: {:03d}, Average train loss: {:.5f}'.format(epoch, train_avg))
        print('Epoch: {:03d}, Average test loss: {:.5f}'.format(epoch, valid_avg))
        wandb.log({
            "train_mae": train_avg,
            "test_mae": valid_avg,
        })

In [None]:
# Train one fresh SchNet per (property, fold) pair and log each run to wandb.
for property_ in ['homo', 'lumo']:
    data_atoms = construct_dataset(structures, targets, property_)
    for fold_idx, (train_index, test_index) in enumerate(kf.split(data_atoms)):
        # Give the loaders plain lists: DataLoader indexes its dataset with
        # positional integers, which on a pandas Series is either label-based
        # or relies on a deprecated positional fallback.
        train = data_atoms['data'].iloc[train_index].tolist()
        test = data_atoms['data'].iloc[test_index].tolist()
        train_loader = DataLoader(train, batch_size=32, shuffle=True)
        test_loader = DataLoader(test, batch_size=32)

        model = SchNet()
        model = model.to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        # train_model reads `loss_func` from the notebook namespace.
        loss_func = torch.nn.L1Loss()
        # The schedule length is epochs * batches because train_model steps
        # the scheduler once per training batch.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, epochs=epochs,
                                                        steps_per_epoch=len(train_loader),
                                                        max_lr=1e-3)
        wandb.init(
            project="schnet_dichalcogenides", entity="inno-materials-ai",
            save_code=True, name=f'schnet_{property_}_fold{fold_idx}'
        )
        try:
            train_model(model, optimizer, scheduler, train_loader, test_loader)
        finally:
            # Close this fold's run explicitly, even when training fails, so
            # the next wandb.init starts from a clean state.
            wandb.finish()

100%|██████████| 15355/15355 [00:27<00:00, 560.75it/s]
[34m[1mwandb[0m: Currently logged in as: [33mimplausible_denyability[0m (use `wandb login --relogin` to force relogin)


 49%|████▉     | 205/420 [00:16<00:17, 12.27it/s]