Skip to content

Commit

Permalink
Add torch evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasZahradnik committed Aug 12, 2021
1 parent ed84665 commit 1792f55
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 9 deletions.
2 changes: 1 addition & 1 deletion neuralogic/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def get_parsed_template(self):
def build(self, backend: Backend, *, native_backend_models=False):
from neuralogic.nn import get_neuralogic_layer

if native_backend_models:
if backend == Backend.PYG:
return get_neuralogic_layer(backend, native_backend_models)(self.module_list)

with self.context():
Expand Down
17 changes: 10 additions & 7 deletions neuralogic/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,18 @@ def get_neuralogic_layer(backend: Backend, native_backend_models: bool = False):
from neuralogic.nn.dynet import NeuraLogic # type: ignore

return NeuraLogic
if backend == Backend.DGL:
from neuralogic.nn.dgl import NeuraLogicLayer # type: ignore

return NeuraLogicLayer
# if backend == Backend.DGL:
# from neuralogic.nn.dgl import NeuraLogicLayer # type: ignore
#
# return NeuraLogicLayer
if backend == Backend.JAVA:
from neuralogic.nn.java import NeuraLogic # type: ignore

return NeuraLogic
if backend == Backend.PYG:
if native_backend_models:
from neuralogic.nn.native.torch import NeuraLogic
from neuralogic.nn.native.torch import NeuraLogic

return NeuraLogic
return NeuraLogic
raise NotImplementedError


Expand All @@ -48,3 +47,7 @@ def get_evaluator(
from neuralogic.nn.evaluators.java import JavaEvaluator

return JavaEvaluator(template, settings)
if backend == Backend.PYG:
from neuralogic.nn.evaluators.torch import TorchEvaluator

return TorchEvaluator(template, settings)
4 changes: 3 additions & 1 deletion neuralogic/nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def __init__(self, backend: Backend, template: Template, settings: Settings):
self.backend = backend
self.dataset: Optional[BuiltDataset] = None
self.neuralogic_model = template.build(backend)
self.neuralogic_model.set_hooks(template.hooks)

if backend != Backend.PYG:
self.neuralogic_model.set_hooks(template.hooks)

def set_dataset(self, dataset: Union[Dataset, BuiltDataset]):
self.dataset = self.build_dataset(dataset)
Expand Down
88 changes: 88 additions & 0 deletions neuralogic/nn/evaluators/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import Optional, Dict, Union

import torch.nn.functional as F
import torch

from neuralogic.nn.base import AbstractEvaluator

from neuralogic.core import Template, BuiltDataset
from neuralogic.core.settings import Settings, Optimizer, ErrorFunction
from neuralogic.core.builder import Backend
from neuralogic.utils.data import Dataset


class TorchEvaluator(AbstractEvaluator):
trainers = {
Optimizer.SGD: lambda param, rate: torch.optim.SGD(param, lr=rate),
Optimizer.ADAM: lambda param, rate: torch.optim.Adam(param, lr=rate),
}

error_functions = {ErrorFunction.SQUARED_DIFF: F.mse_loss, ErrorFunction.CROSSENTROPY: F.cross_entropy}

def __init__(self, template: Template, settings: Settings):
super().__init__(Backend.PYG, template, settings)

def train(self, dataset: Optional[Union[Dataset, BuiltDataset]] = None, *, generator: bool = True):
# dataset = self.dataset if dataset is None else self.build_dataset(dataset)

epochs = self.settings.epochs
error_function = ErrorFunction[str(self.settings.error_function)]
optimizer = Optimizer[str(self.settings.optimizer)]

if optimizer not in TorchEvaluator.trainers:
raise NotImplementedError
if error_function not in TorchEvaluator.error_functions:
raise NotImplementedError

trainer = TorchEvaluator.trainers[optimizer](
self.neuralogic_model.module_list.parameters(),
self.settings.learning_rate,
)
error_function = TorchEvaluator.error_functions[error_function]

def _train():
for _ in range(epochs):
seen_instances = 0
total_loss = 0

for data in dataset.data:
self.neuralogic_model.train()
trainer.zero_grad()

out = self.neuralogic_model(x=data.x, edge_index=data.edge_index)
loss = F.nll_loss(out[data.y_mask], data.y[data.y_mask])
loss.backward()
trainer.step()

seen_instances += 1
total_loss += float(loss)
yield total_loss, seen_instances

if generator:
return _train()

stats = 0, 0
for stats in _train():
pass
return stats

def test(self, dataset: Optional[Union[Dataset, BuiltDataset]] = None, *, generator: bool = True):
self.neuralogic_model.train(mode=False)

# dataset = self.dataset if dataset is None else self.build_dataset(dataset)

def _test():
for data in dataset.data:
self.neuralogic_model.train(mode=False)
out = self.neuralogic_model(x=data.x, edge_index=data.edge_index)
results = (out[data.y_mask], data.y[data.y_mask])

pred = out[data.y_mask].max(1)[1]
acc = pred.eq(data.y[data.y_mask]).sum().item() / data.y_mask.sum().item()
# accs.append(acc)
print(acc)
yield results

if generator:
return _test()
return list(_test())

0 comments on commit 1792f55

Please sign in to comment.