In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import wandb
from pytorch_lightning.loggers import WandbLogger
import torch
import pytorch_lightning as pl
from Models.EGformer import EGformer
import torch.nn as nn
import torch.optim as optim
from ocpmodels.datasets import LmdbDataset
from torch.utils.data import random_split
import torch_geometric.loader as geom_loader
import torch_geometric.data as data
from pytorch_lightning.callbacks import LearningRateMonitor,ModelCheckpoint
from typing import Any
from pytorch_lightning.utilities.types import STEP_OUTPUT

  from .autonotebook import tqdm as notebook_tqdm
ERROR:root:Invalid setup for SCN. Either the e3nn library or Jd.pt is missing.


In [2]:
import yaml
import torch.nn.init as init
import torch.nn as nn

# Model initialize

In [3]:
# Save the model hyperparameters to a YAML file
model_hparams ={"num_atoms":0,
                "bond_feat_dim":0,
                "num_targets":0,
                "num_heads":4,
                }

with open('params/model_hparams.yml', 'w') as file:
    yaml.dump(model_hparams, file)

In [4]:
class GeoTransformer_Traniner(pl.LightningModule):
    ''''pytorch lightning'''
    def __init__(self,model,y_mean,y_std,optimizer_name,optimizer_hparams,**model_kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.model=model        
        # self.optimizer_name=optimizer_name
        # self.optimizer_hparams=optimizer_hparams
        
        self.loss_module=nn.MSELoss()
        self.y_mean=y_mean
        self.y_std=y_std

    def forward(self,data):
        # x,edge_index,batch_idx=data.latent,data.edge_index,data.batch
        
        x=self.model(data)
        preds=x.squeeze()        
        label=data.y_relaxed/data.natoms
        label=(label-self.y_mean)/self.y_std
        label=label.squeeze()
        loss=self.loss_module(label,preds)
        acc=abs(label-preds)

        return loss,acc
    
    def configure_optimizers(self) -> Any:

        if self.hparams.optimizer_name == "Adam":
            optimizer=optim.AdamW(
                self.parameters(),**self.hparams.optimizer_hparams
            )
        scheduler=optim.lr_scheduler.MultiStepLR(
            optimizer,milestones=[5,10],gamma=0.1
        )    
        # return super().configure_optimizers()
        return [optimizer],[scheduler]


    
    def training_step(self,data,batch_idx):
        loss,acc=self.forward(data)
        self.log("train_loss",loss.mean())
        self.log("train_mae",acc.mean())
        return loss
        
    def validation_step(self,data,batch_idx):
        _,acc=self.forward(data)
        self.log("val_mae",acc.mean())

    def test_step(self,data,batch_idx):
        _,acc=self.forward(data)
        self.log("test_mae",acc.mean())

In [5]:
with open('params/model_hparams.yml', 'r') as file:
    loaded_model_hparams = yaml.load(file, Loader=yaml.FullLoader)

# Create the model using the loaded hyperparameters
model = EGformer(**loaded_model_hparams)
checkpoint_path='params/gemnet_oc_base_oc20_oc22.pt'
pretrained_state_dict = torch.load(checkpoint_path)['state_dict']
new_model_state_dict = model.state_dict()
filtered_pretrained_state_dict = {k.strip('module.module.'): v for k, v in pretrained_state_dict.items() if k.strip('module.module.') in new_model_state_dict}
new_model_state_dict.update(filtered_pretrained_state_dict)
model.load_state_dict(new_model_state_dict)

for param_name, param in model.named_parameters():

    if param_name in filtered_pretrained_state_dict.keys():
        
        param.requires_grad = False





In [6]:
y_mean=-2
y_std=2
model=GeoTransformer_Traniner(model=model,y_mean=y_mean,y_std=y_std,optimizer_name="Adam",optimizer_hparams={"lr":1e-3,"weight_decay":1e-4})

  rank_zero_warn(


In [7]:
f_paras,t_paras=0,0
for param_name,param in model.named_parameters():
    if param.requires_grad is False:
        f_paras+=1
    else:
        t_paras+=1
print('Freeze params is',f_paras)
print('Need optimiz params is',t_paras)

Freeze params is 238
Need optimiz params is 131


In [8]:
CHECKPOINT_PATH="./checkpoints"
# DEVICE='cuda'
dataset=LmdbDataset({"src":"Data/eoh.lmdb"})

train_length = int(0.8 * len(dataset))
val_length = len(dataset) - train_length

# Split the dataset into train and validation
train_dataset, val_dataset =random_split(dataset, [train_length, val_length])


In [9]:
# train_dataset=ReverseDataset(train_dataset,train_length)
# val_dataset=ReverseDataset(val_dataset,val_length)
train_loader = geom_loader.DataLoader(train_dataset, batch_size=1)
val_loader = geom_loader.DataLoader(val_dataset, batch_size=1)
wandb.init()
wandb_logger = WandbLogger()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmoxx799[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(


In [10]:
data = next(iter(train_loader))
batched_data = data.batch
print(batched_data.shape)

torch.Size([98])


In [11]:
def train_model(model,save_name=None,**kwargs):
    pl.seed_everything(42)

    if save_name is None:
        # raise TypeError('need a save name')
        save_name='exmodel'
    
    trainer=pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH,save_name),
                       accelerator='cuda',
                       devices=-1,
                       max_epochs=25,
                       callbacks=[ModelCheckpoint(save_weights_only=True,mode="min",monitor="val_mae"),
                                  LearningRateMonitor("epoch")],
                       enable_progress_bar=True,
                       logger=wandb_logger)
    
    trainer.logger._log_graph=True
    trainer.logger._default_hp_metric=None
    pretrained_filename=os.path.join(CHECKPOINT_PATH,save_name+".ckpt")
    if os.path.isfile(pretrained_filename):
        model=GeoTransformer_Traniner.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)
        trainer.fit(model,train_loader,val_loader)
        model=GeoTransformer_Traniner.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    val_result=trainer.test(model,val_loader,verbose=False)
    # test_result=trainer.test(model,test_loader,verbose=False)
    # result={"test":test_result[0]["test_acc"],"val":val_result[0]["test_acc"]}

    return model,val_result

In [12]:
# wandb.init()
# wandb_logger = WandbLogger()

In [13]:
# %%capture out

gemformer_model,gemformer_results=train_model(model)

Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 42
  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type     | Params
-----------------------------------------
0 | model       | EGformer | 27.6 M
1 | loss_module | MSELoss  | 0     
-----------------------------------------
5.9 M     Trainable params
21.7 M    Non-trainable params
27.6 M    Total params
110.378   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

IndexError: index 2 is out of bounds for dimension 0 with size 2

# Unfinished

In [None]:
from tqdm import tqdm

In [None]:
def train_fn(data_loader,model,optimizer):
    model.train()
    total_loss=0.0
    for images,masks in tqdm(data_loader):
        images=images.to(DEVICE)
        masks=masks.to(DEVICE)
        optimizer.zero_grad()
        preds=model(images)
        loss=
        loss.backward()
        optimizer.step()

        total_loss+=loss.item()
        

    return total_loss/len(data_loader)

In [None]:
def eval_fn(data_loader,model):

    model.eval()
    total_loss=0.0
    total_acc=0
    
    with torch.no_grad():

        for images,masks in tqdm(data_loader):
            images=images.to(DEVICE)
            masks=masks.to(DEVICE)        
            logits,loss=model(images,masks)

            total_loss+=loss.item() 

            mask=torch.sigmoid(logits[0]).cpu().squeeze().flatten()        
            
            mask_true=masks[0].cpu().squeeze(0).flatten()
            
            
            precision=roc_auc_score(mask_true,mask)
            total_acc+=precision
                     
          
            

    return total_loss/len(data_loader),total_acc/len(data_loader)

In [None]:
optimizer=torch.optim.Adam(model.parameters(),lr=0.005)
best_valid_loss=np.Inf
animator=d2l.Animator(xlabel='epoch',ylabel='loss',yscale='log',xlim=[1,epochs],ylim=[0.1,1.5],legend=['train','valid','acc'])

for i in range(epochs):
    train_loss=train_fn(trainloader,model,optimizer)
    valid_loss,acc=eval_fn(validloader,model)
    print(acc)
    animator.add(i+1,(train_loss,valid_loss,acc))
    if valid_loss< best_valid_loss:
        torch.save(model.state_dict(),'best_model_n.pt')
        print('saved-model')
        best_valid_loss=valid_loss

    print(f'epoch:{i+1} Train_loss:{train_loss} Valid_loss:{valid_loss} Valid acc:{acc}')