Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,14 @@ dmypy.json
.DS_Store

# data
./data/*
data/**

# config
./config/*

# results
results/**

server/

main.py
Expand Down
8 changes: 2 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,8 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
# isort
- repo: https://github.com/asottile/seed-isort-config
rev: v2.2.0
hooks:
- id: seed-isort-config
- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.10.1
- repo: https://github.com/pycqa/isort
rev: 5.10.1
hooks:
- id: isort
args: ["--profile", "black"]
Expand Down
2 changes: 1 addition & 1 deletion configs/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ trainer: property

task:
# run_mode: train
name: "my_train_job"
identifier: "my_train_job"

reprocess: False

Expand Down
4 changes: 2 additions & 2 deletions configs/examples/DOS_STO.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
trainer: property

task:
name: "my_train_job"
identifier: "my_train_job"
reprocess: False
parallel: True
seed: 0
Expand Down Expand Up @@ -38,7 +38,7 @@ optim:
scheduler_args: {"mode":"min", "factor":0.8, "patience":40, "min_lr":0.00001, "threshold":0.0002}

dataset:
processed: False
processed: True
src: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/STO_DOS_data/raw/"
target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/STO_DOS_data/targets.csv"
pt_path: "/global/cfs/projectdirs/m3641/Sarah/datasets/processed/STO_DOS_data/"
Expand Down
27 changes: 27 additions & 0 deletions matdeeplearn/common/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Registry:
"model_name_mapping": {},
"logger_name_mapping": {},
"trainer_name_mapping": {},
"loss_name_mapping": {},
"state": {},
}

Expand Down Expand Up @@ -165,6 +166,28 @@ def wrap(func):

return wrap

@classmethod
def register_loss(cls, name):
    r"""Register a loss class to registry with key 'name'

    Args:
        name: Key with which the loss will be registered.

    Usage::

        from matdeeplearn.common.registry import registry

        @registry.register_loss("dos_loss")
        class DOSLoss():
            ...
    """

    def wrap(func):
        # Store the decorated class under `name` so it can later be
        # retrieved via Registry.get_loss_class(name).
        cls.mapping["loss_name_mapping"][name] = func
        return func

    return wrap

@classmethod
def register(cls, name, obj):
r"""Register an item to registry with key 'name'
Expand Down Expand Up @@ -248,6 +271,10 @@ def get_logger_class(cls, name):
def get_trainer_class(cls, name):
return cls.get_class(name, "trainer_name_mapping")

@classmethod
def get_loss_class(cls, name):
    """Return the loss class previously registered under ``name``."""
    mapping_key = "loss_name_mapping"
    return cls.get_class(name, mapping_key)

@classmethod
def get(cls, name, default=None, no_warning=False):
r"""Get an item from registry with key 'name'
Expand Down
10 changes: 7 additions & 3 deletions matdeeplearn/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import warnings
from abc import abstractmethod
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn
from torch_geometric.nn import radius_graph
from torch_geometric.utils import dense_to_sparse

from matdeeplearn.preprocessor.helpers import (
Expand All @@ -14,12 +13,17 @@
)


class BaseModel(nn.Module):
class BaseModel(nn.Module, metaclass=ABCMeta):
def __init__(self, edge_steps: int = 50, self_loop: bool = True) -> None:
super(BaseModel, self).__init__()
self.edge_steps = edge_steps
self.self_loop = self_loop

@property
@abstractmethod
def target_attr(self):
"""Specifies the target attribute property for writing output to file"""

def __str__(self):
# Prints model summary
str_representation = "\n"
Expand Down
9 changes: 8 additions & 1 deletion matdeeplearn/models/cgcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ def __init__(
self.gc_dim, self.post_fc_dim = dim1, dim1

# Determine output dimension length
self.output_dim = 1 if data[0].y.ndim == 0 else len(data[0].y[0])
if data[0][self.target_attr].ndim == 0:
self.output_dim = 1
else:
self.output_dim = len(data[0][self.target_attr][0])

# setup layers
self.pre_lin_list = self._setup_pre_gnn_layers()
Expand All @@ -75,6 +78,10 @@ def __init__(
# workaround for doubled dimension by set2set; if late pooling not recommended to use set2set
self.lin_out_2 = torch.nn.Linear(self.output_dim * 2, self.output_dim)

@property
def target_attr(self):
    # Attribute on each Data object holding this model's regression
    # target; "y" is the standard PyTorch Geometric label field.
    return "y"

def _setup_pre_gnn_layers(self):
"""Sets up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)."""
pre_lin_list = torch.nn.ModuleList()
Expand Down
9 changes: 8 additions & 1 deletion matdeeplearn/models/dos_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def __init__(
self.gc_dim, self.post_fc_dim = dim1, dim1

# Determine output dimension length
self.output_dim = 1 if data[0].scaled.ndim == 0 else len(data[0].scaled[0])
if data[0][self.target_attr].ndim == 0:
self.output_dim = 1
else:
self.output_dim = len(data[0][self.target_attr][0])

# setup layers
self.pre_lin_list = self._setup_pre_gnn_layers()
Expand All @@ -65,6 +68,10 @@ def __init__(
Linear(self.dim2, 1),
)

@property
def target_attr(self):
    # Attribute on each Data object holding this model's target.
    # Presumably "scaled" is a preprocessed/scaled DOS target set by
    # the data pipeline — confirm against the dataset transforms.
    return "scaled"

def _setup_pre_gnn_layers(self):
"""Sets up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)."""
pre_lin_list = torch.nn.ModuleList()
Expand Down
23 changes: 10 additions & 13 deletions matdeeplearn/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from torch import nn
from torch_geometric.data import Batch

from matdeeplearn.common.registry import registry


@registry.register_loss("DOSLoss")
class DOSLoss(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -47,34 +50,28 @@ def forward(self, predictions: tuple[torch.Tensor, torch.Tensor], target: Batch)
def get_dos_features(self, x, dos):
"""get dos features"""
dos = torch.abs(dos)
dos_sum = torch.sum(dos, axis=1)

center = torch.sum(x * dos, axis=1) / torch.sum(dos, axis=1)
center = torch.sum(x * dos, axis=1) / dos_sum
x_offset = (
torch.repeat_interleave(x[np.newaxis, :], dos.shape[0], axis=0)
- center[:, None]
)
width = torch.diagonal(torch.mm((x_offset**2), dos.T)) / torch.sum(
dos, axis=1
)
skew = (
torch.diagonal(torch.mm((x_offset**3), dos.T))
/ torch.sum(dos, axis=1)
/ width**1.5
)
width = torch.diagonal(torch.mm((x_offset**2), dos.T)) / dos_sum
skew = torch.diagonal(torch.mm((x_offset**3), dos.T)) / dos_sum / width**1.5
kurtosis = (
torch.diagonal(torch.mm((x_offset**4), dos.T))
/ torch.sum(dos, axis=1)
/ width**2
torch.diagonal(torch.mm((x_offset**4), dos.T)) / dos_sum / width**2
)

# find zero index (fermi leve)
# find zero index (fermi level)
zero_index = torch.abs(x - 0).argmin().long()
ef_states = torch.sum(dos[:, zero_index - 20 : zero_index + 20], axis=1) * abs(
x[0] - x[1]
)
return torch.stack((center, width, skew, kurtosis, ef_states), axis=1)


@registry.register_loss("TorchLossWrapper")
class TorchLossWrapper(nn.Module):
def __init__(self, loss_fn="l1_loss"):
super().__init__()
Expand Down
98 changes: 88 additions & 10 deletions matdeeplearn/trainers/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import copy
import csv
import logging
import os
from abc import ABC, abstractmethod
from datetime import datetime

import torch
import torch.optim as optim
Expand All @@ -17,7 +21,6 @@
from matdeeplearn.common.registry import registry
from matdeeplearn.models.base_model import BaseModel
from matdeeplearn.modules.evaluator import Evaluator
from matdeeplearn.modules.loss import *
from matdeeplearn.modules.scheduler import LRScheduler


Expand All @@ -35,6 +38,7 @@ def __init__(
test_loader: DataLoader,
loss: nn.Module,
max_epochs: int,
identifier: str = None,
verbosity: int = None,
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -56,9 +60,20 @@ def __init__(
self.step = 0
self.metrics = {}
self.epoch_time = None
self.best_val_metric = 1e10
self.best_model_state = None

self.evaluator = Evaluator()

self.run_dir = os.getcwd()

timestamp = torch.tensor(datetime.now().timestamp()).to(self.device)
self.timestamp_id = datetime.fromtimestamp(timestamp.int()).strftime(
"%Y-%m-%d-%H-%M-%S"
)
if identifier:
self.timestamp_id = f"{self.timestamp_id}-{identifier}"

if self.train_verbosity:
logging.info(
f"GPU is available: {torch.cuda.is_available()}, Quantity: {torch.cuda.device_count()}"
Expand Down Expand Up @@ -94,6 +109,7 @@ def from_config(cls, config):
loss = cls._load_loss(config["optim"]["loss"])

max_epochs = config["optim"]["max_epochs"]
identifier = config["task"].get("identifier", None)
verbosity = config["task"].get("verbosity", None)

return cls(
Expand All @@ -107,6 +123,7 @@ def from_config(cls, config):
test_loader=test_loader,
loss=loss,
max_epochs=max_epochs,
identifier=identifier,
verbosity=verbosity,
)

Expand Down Expand Up @@ -180,15 +197,12 @@ def _load_scheduler(scheduler_config, optimizer):
@staticmethod
def _load_loss(loss_config):
"""Loads the loss from either the TorchLossWrapper or custom loss functions in matdeeplearn"""
try:
loss_type = loss_config["loss_type"]
# if there are other params for loss type, include in call
if loss_config.get("loss_args"):
return eval(loss_type)(**loss_config["loss_args"])
else:
return eval(loss_type)()
except (AttributeError, NameError):
raise NotImplementedError(f"Unknown loss class name: {loss_type}")
loss_cls = registry.get_loss_class(loss_config["loss_type"])
# if there are other params for loss type, include in call
if loss_config.get("loss_args"):
return loss_cls(**loss_config["loss_args"])
else:
return loss_cls()

@abstractmethod
def _load_task(self):
Expand All @@ -205,3 +219,67 @@ def validate(self):
@abstractmethod
def predict(self):
"""Implemented by derived classes."""

def update_best_model(self, val_metrics):
    """Updates the best val metric and model, saves the best model, and saves the best model predictions"""
    loss_name = type(self.loss_fn).__name__
    self.best_val_metric = val_metrics[loss_name]["metric"]
    self.best_model_state = copy.deepcopy(self.model.state_dict())

    # Persist weights only (no optimizer/scheduler state) for the best model.
    self.save_model("best_checkpoint.pt", val_metrics, False)

    logging.debug(
        f"Saving prediction results for epoch {self.epoch} to: /results/{self.timestamp_id}/"
    )
    # Re-run prediction on every split so the saved results match the
    # newly saved best weights.
    for loader, split_name in (
        (self.train_loader, "train"),
        (self.val_loader, "val"),
        (self.test_loader, "test"),
    ):
        self.predict(loader, split_name)

def save_model(self, checkpoint_file, val_metrics=None, training_state=True):
    """Saves the model state dict.

    Args:
        checkpoint_file: file name to write inside
            results/<timestamp_id>/checkpoint/.
        val_metrics: validation metrics stored alongside the weights
            when ``training_state`` is False.
        training_state: when True, also persist optimizer/scheduler
            state and epoch counters so training can be resumed.

    Returns:
        Full path of the written checkpoint file.
    """
    if training_state:
        # Full resumable snapshot.
        payload = {
            "epoch": self.epoch,
            "step": self.step,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.scheduler.state_dict(),
            "best_val_metric": self.best_val_metric,
        }
    else:
        # Weights-only snapshot (e.g. best model) plus its metrics.
        payload = {
            "state_dict": self.model.state_dict(),
            "val_metrics": val_metrics,
        }

    out_dir = os.path.join(
        self.run_dir, "results", self.timestamp_id, "checkpoint"
    )
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, checkpoint_file)

    torch.save(payload, out_path)
    return out_path

def save_results(self, output, filename, node_level_predictions=False):
    """Write a targets-vs-predictions table to a CSV under results/<timestamp_id>/.

    Args:
        output: 2D array-like whose rows are
            [structure_id, (node_id,) target..., prediction...]; the
            non-id columns are assumed to split evenly between targets
            and predictions.
        filename: name of the CSV file to create in the results dir.
        node_level_predictions: if True, each row carries an extra
            node_id column after structure_id.

    Returns:
        Full path of the written CSV file.
    """
    results_path = os.path.join(self.run_dir, "results", self.timestamp_id)
    os.makedirs(results_path, exist_ok=True)
    filename = os.path.join(results_path, filename)
    shape = output.shape

    id_headers = ["structure_id"]
    if node_level_predictions:
        id_headers += ["node_id"]
    num_cols = (shape[1] - len(id_headers)) // 2
    headers = id_headers + ["target"] * num_cols + ["prediction"] * num_cols

    # newline="" is the documented way to open files for csv.writer
    # (avoids blank interleaved lines on Windows).
    with open(filename, "w", newline="") as f:
        csvwriter = csv.writer(f)
        csvwriter.writerow(headers)
        # BUGFIX: the previous loop wrote output[i - 1] for
        # i in range(len(output)), which silently dropped the final
        # row of `output`. Write every row.
        for row in output:
            csvwriter.writerow(row)
    return filename

def load_checkpoint(self):
    """Loads the model from a checkpoint.pt file"""
    # TODO: implement this method
    # NOTE(review): when implemented, this should mirror save_model():
    # read the dict saved under results/<timestamp_id>/checkpoint/ and
    # restore state_dict (and, for resumable checkpoints, optimizer,
    # scheduler, epoch, step, and best_val_metric).
    pass
Loading