In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import wandb
from pytorch_lightning.loggers import WandbLogger
import torch
import pytorch_lightning as pl
from Models.EGformer import EGformer
import torch.nn as nn
import torch.optim as optim
from ocpmodels.datasets import LmdbDataset
from torch.utils.data import random_split
import torch_geometric.loader as geom_loader
import torch_geometric.data as data
from pytorch_lightning.callbacks import LearningRateMonitor,ModelCheckpoint
from typing import Any
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch_geometric.loader as geom_data


In [None]:
# Registry mapping a model name to its constructor; entries are added elsewhere in
# the notebook before any model is built by name.
model_dict = {}


def create_model(model_name, model_hparams):
    """Instantiate a registered model by name.

    Args:
        model_name: Key into ``model_dict``.
        model_hparams: Keyword arguments passed to the model constructor.

    Returns:
        A new instance of ``model_dict[model_name]``.

    Raises:
        ValueError: If ``model_name`` has not been registered.
    """
    try:
        model_cls = model_dict[model_name]
    except KeyError:
        # An explicit raise (unlike ``assert False``) still fires under ``python -O``.
        # Only the registry lookup is guarded, so KeyErrors from the constructor propagate.
        raise ValueError(
            f"Unknown model name \"{model_name}\". Available models are: {list(model_dict)}"
        ) from None
    return model_cls(**model_hparams)

In [None]:
def config_model(model, checkpoint_path='params/gemnet_oc_base_oc20_oc22.pt'):
    """Load matching pretrained weights into ``model`` and freeze them.

    Only checkpoint entries whose key exists in ``model.state_dict()`` AND whose
    shape matches are copied. Every other parameter keeps its current
    initialisation and stays trainable.

    Args:
        model: The ``nn.Module`` to initialise (modified in place).
        checkpoint_path: Path to a ``torch.save`` file containing a ``"state_dict"`` entry.

    Returns:
        The same ``model`` object, so calls can be chained.
    """
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    pretrained_state = checkpoint["state_dict"]
    model_state = model.state_dict()

    # A same-named tensor with a different shape would make load_state_dict raise,
    # so such entries are skipped instead of aborting the whole load.
    # NOTE(review): keys saved from a DataParallel/DDP wrapper carry a "module." prefix
    # and would not match here — confirm the checkpoint's key naming.
    transferable = {
        name: tensor
        for name, tensor in pretrained_state.items()
        if name in model_state and tensor.shape == model_state[name].shape
    }
    model_state.update(transferable)
    model.load_state_dict(model_state)

    # Freeze exactly the parameters that received pretrained values.
    for name, param in model.named_parameters():
        if name in transferable:
            param.requires_grad = False
    return model

In [None]:
class GeoTransformer_Traniner(pl.LightningModule):
    """PyTorch Lightning wrapper that trains a registered model to regress ``y_relaxed``.

    The backbone is built with ``create_model`` and then passed through
    ``config_model``, which loads matching pretrained weights and freezes them.
    Training minimises MSE; validation and test log mean absolute error.

    Args:
        model_name: Key into ``model_dict`` selecting the backbone class.
        model_hparams: Keyword arguments forwarded to the backbone constructor.
        optimizer_name: ``"Adam"`` (builds AdamW, kept for backward compatibility),
            ``"AdamW"`` or ``"SGD"``.
        optimizer_hparams: Keyword arguments forwarded to the optimizer constructor.
        **model_kwargs: Extra keyword arguments; not used by this class.
    """

    # MultiStepLR decays the learning rate by LR_GAMMA at each of these epochs.
    # NOTE(review): with a 25-epoch run the 30-epoch milestone is never reached.
    LR_MILESTONES = (20, 30)
    LR_GAMMA = 0.1

    # "Adam" historically mapped to AdamW; that mapping is preserved.
    _OPTIMIZERS = {"Adam": optim.AdamW, "AdamW": optim.AdamW, "SGD": optim.SGD}

    def __init__(self, model_name, model_hparams, optimizer_name, optimizer_hparams, **model_kwargs):
        super().__init__()
        # Stores the __init__ arguments in self.hparams and in checkpoints, which is
        # what load_from_checkpoint uses to rebuild the module.
        self.save_hyperparameters()
        self.optimizer_name = optimizer_name
        self.optimizer_hparams = optimizer_hparams
        # Build the backbone, then load and freeze whatever pretrained weights match it.
        self.model = config_model(create_model(model_name, model_hparams))
        self.loss_module = nn.MSELoss()

    def forward(self, data, mode="train"):
        """Run the backbone on a batch and score it against ``data.y_relaxed``.

        Args:
            data: Batch object passed straight to the backbone; must expose ``y_relaxed``.
            mode: Unused; kept so existing ``mode=...`` call sites keep working.

        Returns:
            (loss, abs_err): the MSE loss and the elementwise ``|y_relaxed - preds|``.
        """
        preds = self.model(data).squeeze(dim=-1).float()
        target = data.y_relaxed
        # Use the same float predictions for the loss and the error metric;
        # (input, target) is MSELoss's documented argument order.
        loss = self.loss_module(preds, target)
        abs_err = (target - preds).abs()
        return loss, abs_err

    def configure_optimizers(self) -> Any:
        """Build the optimizer named by ``optimizer_name`` plus a MultiStepLR schedule.

        Raises:
            ValueError: If ``optimizer_name`` is not one of the supported optimizers.
        """
        try:
            optimizer_cls = self._OPTIMIZERS[self.optimizer_name]
        except KeyError:
            raise ValueError(
                f"Unknown optimizer \"{self.optimizer_name}\". "
                f"Available optimizers are: {sorted(self._OPTIMIZERS)}"
            ) from None
        optimizer = optimizer_cls(self.parameters(), **self.hparams.optimizer_hparams)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=list(self.LR_MILESTONES), gamma=self.LR_GAMMA
        )
        return [optimizer], [scheduler]

    @staticmethod
    def _batch_size(data):
        # PyG batches expose ``num_graphs``; passing it explicitly means Lightning does
        # not have to infer a batch size from the batch contents. None keeps its default.
        return getattr(data, "num_graphs", None)

    def training_step(self, data, batch_idx):
        loss, abs_err = self.forward(data, mode="train")
        batch_size = self._batch_size(data)
        self.log("train_loss", loss.mean(), batch_size=batch_size)
        self.log("train_mae", abs_err.mean(), batch_size=batch_size)
        return loss

    def validation_step(self, data, batch_idx):
        _, abs_err = self.forward(data, mode="val")
        # "val_mae" is the metric ModelCheckpoint monitors.
        self.log("val_mae", abs_err.mean(), batch_size=self._batch_size(data))

    def test_step(self, data, batch_idx):
        _, abs_err = self.forward(data, mode="test")
        self.log("test_mae", abs_err.mean(), batch_size=self._batch_size(data))


In [None]:
CHECKPOINT_PATH = "./checkpoints"
DATA_PATH = "Data/eoh.lmdb"
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
BATCH_SIZE = 2

dataset = LmdbDataset({"src": DATA_PATH})

train_length = int(TRAIN_FRACTION * len(dataset))
val_length = len(dataset) - train_length

# Split into train/validation with an explicitly seeded generator: pl.seed_everything
# only runs later (inside train_model), so an unseeded split would change on every
# kernel restart and leak samples between train and val across runs.
train_dataset, val_dataset = random_split(
    dataset,
    [train_length, val_length],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Reshuffle training batches every epoch; validation order does not affect the metric.
train_loader = geom_data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = geom_data.DataLoader(val_dataset, batch_size=BATCH_SIZE)


In [None]:
def train_model(model_name, save_name=None, logger=None, max_epochs=25, **kwargs):
    """Fit (or reload) a GeoTransformer_Traniner and evaluate it on the validation split.

    Args:
        model_name: Key into ``model_dict`` for the backbone to train.
        save_name: Sub-directory / checkpoint stem under CHECKPOINT_PATH; defaults to model_name.
        logger: Lightning logger for the Trainer; defaults to the module-level ``wandb_logger``.
        max_epochs: Number of training epochs.
        **kwargs: Forwarded to GeoTransformer_Traniner (model_hparams, optimizer_name, ...).

    Returns:
        (model, val_result): the selected model and the list of metric dicts returned by
        ``trainer.test`` on ``val_loader`` (metrics appear under the ``test_`` prefix).
    """
    pl.seed_everything(42)

    if save_name is None:
        save_name = model_name
    if logger is None:
        # Fall back to the logger created in the W&B cell; that cell must run first.
        logger = wandb_logger

    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),
        accelerator='cuda',
        devices=-1,
        max_epochs=max_epochs,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_mae"),
            LearningRateMonitor("epoch"),
        ],
        enable_progress_bar=True,
        logger=logger,
    )

    if trainer.logger is not None:
        # NOTE(review): private TensorBoardLogger-style attributes; presumably no-ops for
        # WandbLogger — confirm or drop.
        trainer.logger._log_graph = True
        trainer.logger._default_hp_metric = None

    # NOTE(review): ModelCheckpoint does not appear to write to this path, so this cache
    # only hits if a checkpoint was placed here manually.
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        model = GeoTransformer_Traniner.load_from_checkpoint(pretrained_filename)
    else:
        # Re-seed right before building the model so its initial weights are reproducible.
        pl.seed_everything(42)
        model = GeoTransformer_Traniner(model_name=model_name, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        best_path = trainer.checkpoint_callback.best_model_path
        # An empty path means no checkpoint was written (e.g. val_mae never logged);
        # keep the in-memory model instead of crashing on load.
        if best_path:
            model = GeoTransformer_Traniner.load_from_checkpoint(best_path)

    val_result = trainer.test(model, val_loader, verbose=False)
    return model, val_result

In [None]:
# Start the W&B run, then create the Lightning logger; train_model() hands the
# module-level `wandb_logger` to its Trainer.
active_wandb_run = wandb.init()
wandb_logger = WandbLogger()

In [None]:
%%capture out
# Register the backbone so create_model() can build it by name.
model_dict['EGformer'] = EGformer

# NOTE(review): num_atoms / bond_feat_dim / num_targets are 0 — presumably ignored by
# EGformer; confirm against its constructor.
egformer_hparams = {
    "num_atoms": 0,
    "bond_feat_dim": 0,
    "num_targets": 0,
    "num_heads": 4,
}
# optimizer_name "Adam" selects torch.optim.AdamW in configure_optimizers.
adamw_hparams = {"lr": 1e-3, "weight_decay": 1e-4}

gemformer_model, gemformer_results = train_model(
    model_name="EGformer",
    model_hparams=egformer_hparams,
    optimizer_name="Adam",
    optimizer_hparams=adamw_hparams,
)

In [None]:
# DEVICE='cuda'
# from tqdm import tqdm
# def out_fn(dataloader,model):

#     model.eval()    
#     with torch.no_grad():
#            for i, batch in tqdm(
#                            enumerate(dataloader),
#                         total=len(dataloader)
#             ):   
#                 batch=batch.to(DEVICE)
#                 model=model.to(DEVICE) 
#                 output=model(batch)
#                 print(output.shape)
                
                


In [None]:
# out_fn(train_loader,myGemnet)