In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import wandb
from pytorch_lightning.loggers import WandbLogger
import torch
import pytorch_lightning as pl
from Models.Gemformer import Gemformer
import torch.nn as nn
import torch.optim as optim
from ocpmodels.datasets import LmdbDataset
from torch.utils.data import random_split
import torch_geometric.loader as geom_loader
import torch_geometric.data as data
from pytorch_lightning.callbacks import LearningRateMonitor,ModelCheckpoint
from typing import Any
from pytorch_lightning.utilities.types import STEP_OUTPUT


In [None]:
model_dict={}
def create_model(model_name,model_hparams):
    if model_name in model_dict:
        return model_dict[model_name](**model_hparams)
    else:
        assert False, f"Unknown model name \"{model_name}\".Available models are: {str(model_dict.keys())}"

In [None]:
class GeoTransformer_Traniner(pl.LightningModule):
    ''''pytorch lightning'''
    def __init__(self,model_name, model_hparams, optimizer_name, optimizer_hparams,**model_kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.optimizer_name=optimizer_name
        self.optimizer_hparams=optimizer_hparams
        self.model=create_model(model_name,model_hparams)
        self.loss_module=nn.MSELoss()

    def forward(self,data,mode="train"):
        x,edge_index,batch_idx=data.latent,data.edge_index,data.batch
        # print(data)
        x=self.model(x,batch_idx)
        x=x.squeeze(dim=-1)
        preds=x.float()
        loss=self.loss_module(data.y_relaxed,x)
        acc=abs(data.y_relaxed-preds)

        return loss,acc
    
    def configure_optimizers(self) -> Any:

        if self.optimizer_name == "Adam":
            optimizer=optim.AdamW(
                self.parameters(),**self.hparams.optimizer_hparams
            )
        scheduler=optim.lr_scheduler.MultiStepLR(
            optimizer,milestones=[20,30],gamma=0.1
        )    
        # return super().configure_optimizers()
        return [optimizer],[scheduler]


    
    def training_step(self,batch,batch_idx):
        loss,acc=self.forward(batch,mode="train")
        self.log("train_loss",loss)
        self.log("train_mae",acc)
        return loss
        
    def validation_step(self,batch,batch_idx):
        _,acc=self.forward(batch,mode="val")
        self.log("val_mae",acc)

    def test_step(self,batch,batch_idx):
        _,acc=self.forward(batch,mode="test")
        self.log("test_mae",acc)


In [None]:
# class ReverseDataset(data.Dataset):

#     def __init__(self, data,size):
#         super().__init__()
#         # self.size = data
#         self.data = data
#         self.size=size
  
#     def __len__(self):
#         return self.size

#     def __getitem__(self, idx):
#         inp_data = self.data[idx]
#         labels = inp_data.y_relaxed
#         return inp_data, labels

In [20]:
from torch.utils.data import DataLoader

In [21]:
CHECKPOINT_PATH="./checkpoints"

dataset=LmdbDataset({"src":"Data/eoh_t.lmdb"})

train_length = int(0.8 * len(dataset))
val_length = len(dataset) - train_length

# Split the dataset into train and validation
train_dataset, val_dataset =random_split(dataset, [train_length, val_length])
# train_dataset=ReverseDataset(train_dataset,train_length)
# val_dataset=ReverseDataset(val_dataset,val_length)
train_loader = DataLoader(train_dataset, batch_size=2)
val_loader = DataLoader(val_dataset, batch_size=2)


In [None]:
# inp_data, labels = train_loader
# print("Input data:", inp_data)
# print("Labels:    ", labels)

In [22]:
def train_model(model_name,save_name=None,**kwargs):
    pl.seed_everything(42)

    if save_name is None:
        save_name=model_name
    
    trainer=pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH,save_name),
                       accelerator='cuda',
                       devices=-1,
                       max_epochs=50,
                       callbacks=[ModelCheckpoint(save_weights_only=True,mode="min",monitor="val_mae"),
                                  LearningRateMonitor("epoch")],
                       enable_progress_bar=True,
                       logger=wandb_logger)
    
    trainer.logger._log_graph=True
    trainer.logger._default_hp_metric=None
    pretrained_filename=os.path.join(CHECKPOINT_PATH,save_name+".ckpt")
    if os.path.isfile(pretrained_filename):
        model=GeoTransformer_Traniner().load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)
        model=GeoTransformer_Traniner(model_name="Gemformer",**kwargs)
        trainer.fit(model,train_loader,val_loader)
        model=GeoTransformer_Traniner.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    val_result=trainer.test(model,val_loader,verbose=False)
    # test_result=trainer.test(model,test_loader,verbose=False)
    # result={"test":test_result[0]["test_acc"],"val":val_result[0]["test_acc"]}

    return model,val_result

In [23]:
wandb.init()
wandb_logger = WandbLogger()

  rank_zero_warn(


In [24]:
%%capture out
model_dict['Gemformer']=Gemformer
gemformer_model,gemformer_results=train_model(model_name="Gemformer",
                                              model_hparams={"num_heads":1,
                                                             "emb_size_in":256,
                                                             "emb_size_trans":64},
                                              optimizer_name="Adam",
                                              optimizer_hparams={"lr":1e-3,
                                                                 "weight_decay":1e-4})

Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type      | Params
------------------------------------------
0 | model       | Gemformer | 68.2 K
1 | loss_module | MSELoss   | 0     
------------------------------------------
68.2 K    Trainable params
0         Non-trainable params
68.2 K    Total params
0.273     Total estimated model params size (MB)


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'torch_geometric.data.batch.DataBatch'>

In [None]:
from pathlib import Path
from torch.utils.data import Dataset
import pickle
import lmdb
import bisect

class LmdbDataset(Dataset):
    r"""Dataset class to load from LMDB files containing relaxation
    trajectories or single point computations.

    Useful for Structure to Energy & Force (S2EF), Initial State to
    Relaxed State (IS2RS), and Initial State to Relaxed Energy (IS2RE) tasks.

    Args:
            config (dict): Dataset configuration
            transform (callable, optional): Data transform function.
                    (default: :obj:`None`)
    """

    def __init__(self, config, transform=None):
        super(LmdbDataset, self).__init__()
        self.config = config

        assert not self.config.get(
            "train_on_oc20_total_energies", False
        ), "For training on total energies set dataset=oc22_lmdb"

        self.path = Path(self.config["src"])
        if not self.path.is_file():
            db_paths = sorted(self.path.glob("*.lmdb"))
            assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'"

            self.metadata_path = self.path / "metadata.npz"

            self._keys, self.envs = [], []
            for db_path in db_paths:
                self.envs.append(self.connect_db(db_path))
                length = pickle.loads(
                    self.envs[-1].begin().get("length".encode("ascii"))
                )
                self._keys.append(list(range(length)))

            keylens = [len(k) for k in self._keys]
            self._keylen_cumulative = np.cumsum(keylens).tolist()
            self.num_samples = sum(keylens)
        else:
            self.metadata_path = self.path.parent / "metadata.npz"
            self.env = self.connect_db(self.path)
            self._keys = [
                f"{j}".encode("ascii")
                for j in range(self.env.stat()["entries"])
            ]
            self.num_samples = len(self._keys)

        # If specified, limit dataset to only a portion of the entire dataset
        # total_shards: defines total chunks to partition dataset
        # shard: defines dataset shard to make visible
        self.sharded = False
        if "shard" in self.config and "total_shards" in self.config:
            self.sharded = True
            self.indices = range(self.num_samples)
            # split all available indices into 'total_shards' bins
            self.shards = np.array_split(
                self.indices, self.config.get("total_shards", 1)
            )
            # limit each process to see a subset of data based off defined shard
            self.available_indices = self.shards[self.config.get("shard", 0)]
            self.num_samples = len(self.available_indices)

        self.transform = transform

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # if sharding, remap idx to appropriate idx of the sharded set
        if self.sharded:
            idx = self.available_indices[idx]
        if not self.path.is_file():
            # Figure out which db this should be indexed from.
            db_idx = bisect.bisect(self._keylen_cumulative, idx)
            # Extract index of element within that db.
            el_idx = idx
            if db_idx != 0:
                el_idx = idx - self._keylen_cumulative[db_idx - 1]
            assert el_idx >= 0

            # Return features.
            datapoint_pickled = (
                self.envs[db_idx]
                .begin()
                .get(f"{self._keys[db_idx][el_idx]}".encode("ascii"))
            )
            data_object = pyg2_data_transform(pickle.loads(datapoint_pickled))
            data_object.id = f"{db_idx}_{el_idx}"
        else:
            datapoint_pickled = self.env.begin().get(self._keys[idx])
            data_object = pyg2_data_transform(pickle.loads(datapoint_pickled))

        if self.transform is not None:
            data_object = self.transform(data_object)

        return data_object

    def connect_db(self, lmdb_path=None):
        env = lmdb.open(
            str(lmdb_path),
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=1,
        )
        return env

    def close_db(self):
        if not self.path.is_file():
            for env in self.envs:
                env.close()
        else:
            self.env.close()