Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework of Sklearn Metrics #1327

Merged
merged 44 commits into from Jun 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
99a94dc
Create utils.py
justusschock Apr 1, 2020
4f546b5
Create __init__.py
justusschock Apr 1, 2020
ae19aa8
redo sklearn metrics
justusschock Apr 1, 2020
fec66b4
add some more metrics
justusschock Apr 3, 2020
08ad7b0
add sklearn metrics
justusschock Apr 13, 2020
2722c08
New metric classes (#1326)
justusschock Apr 3, 2020
d7bf19a
Create __init__.py
justusschock Apr 1, 2020
ba2c6f7
redo sklearn metrics
justusschock Apr 1, 2020
6595de8
add sklearn metrics
justusschock Apr 13, 2020
729690e
start adding sklearn tests
justusschock Apr 26, 2020
429dab6
fix typo
justusschock Apr 27, 2020
387b9b2
fix typo
justusschock Apr 27, 2020
74ab62b
fix typo
justusschock Apr 27, 2020
8e5f1d6
fix typo
justusschock Apr 27, 2020
5082268
return x and y only for curves
justusschock Apr 27, 2020
10cde37
fix typo
justusschock Apr 27, 2020
1a1762d
add missing tests for sklearn funcs
justusschock Apr 27, 2020
a698282
imports
justusschock Apr 27, 2020
debd245
__all__
justusschock Apr 27, 2020
04adf4b
imports
justusschock Apr 27, 2020
c72bda9
fix sklearn arguments
justusschock May 25, 2020
ca471d6
fix imports
justusschock May 25, 2020
e9f5faf
update requirements
justusschock May 25, 2020
13f205a
Update CHANGELOG.md
Borda May 25, 2020
b5dbdb8
Update test_sklearn_metrics.py
Borda May 25, 2020
a7e3e4f
formatting
Borda May 25, 2020
cc9b1b3
formatting
Borda May 25, 2020
c9908a1
format
Borda May 25, 2020
41e5971
fix all warnings and formatting problems
awaelchli May 25, 2020
82781e5
Update environment.yml
justusschock May 26, 2020
01c3e57
Update requirements-extra.txt
justusschock May 26, 2020
6a674b6
Update environment.yml
justusschock May 26, 2020
c96e3d7
Update requirements-extra.txt
justusschock May 26, 2020
b779e37
Update CHANGELOG.md
Borda Jun 8, 2020
f6e6cec
docs
Borda Jun 8, 2020
533bb7c
inherit
Borda Jun 9, 2020
ca117e4
fix all warnings and formatting problems
awaelchli May 25, 2020
36fa04c
docs inherit.
Borda Jun 9, 2020
0a3a31a
docs
Borda Jun 9, 2020
c8d6cac
Apply suggestions from code review
Borda Jun 9, 2020
eeede87
docs
Borda Jun 9, 2020
db8e724
req
Borda Jun 9, 2020
d49298a
min
Borda Jun 9, 2020
363bd64
Apply suggestions from code review
Borda Jun 9, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 7 additions & 3 deletions .circleci/config.yml
Expand Up @@ -64,8 +64,12 @@ references:
name: Make Documentation
command: |
# First run the same pipeline as Read-The-Docs
sudo apt-get update && sudo apt-get install -y cmake
sudo pip install -r docs/requirements.txt
# apt-get update && apt-get install -y cmake
# using: https://hub.docker.com/r/readthedocs/build
# we need to use py3.7 ot higher becase of an issue with metaclass inheritence
pyenv global 3.7.3
python --version
pip install -r docs/requirements.txt
cd docs; make clean; make html --debug --jobs 2 SPHINXOPTS="-W"
test_docs: &test_docs
Expand All @@ -81,7 +85,7 @@ jobs:

Build-Docs:
docker:
- image: circleci/python:3.7
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
- image: readthedocs/build:latest
steps:
- checkout
- *make_docs
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/ci-testing.yml
Expand Up @@ -68,9 +68,9 @@ jobs:
- name: Set min. dependencies
if: matrix.requires == 'minimal'
run: |
python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"
python -c "req = open('requirements-extra.txt').read().replace('>', '=') ; open('requirements-extra.txt', 'w').write(req)"
python -c "req = open('tests/requirements-devel.txt').read().replace('>', '=') ; open('tests/requirements-devel.txt', 'w').write(req)"
python -c "req = open('requirements.txt').read().replace('>=', '==') ; open('requirements.txt', 'w').write(req)"
python -c "req = open('requirements-extra.txt').read().replace('>=', '==') ; open('requirements-extra.txt', 'w').write(req)"
python -c "req = open('tests/requirements-devel.txt').read().replace('>=', '==') ; open('tests/requirements-devel.txt', 'w').write(req)"

# Note: This uses an internal pip API and may not always work
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
Expand Down
4 changes: 2 additions & 2 deletions CHANGELOG.md
Expand Up @@ -4,7 +4,6 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [unreleased] - YYYY-MM-DD

### Added
Expand All @@ -23,7 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723))
- Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907))
- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Expand Up @@ -90,6 +90,7 @@
'sphinx.ext.linkcode',
'sphinx.ext.autosummary',
'sphinx.ext.napoleon',
'sphinx.ext.imgmath',
'recommonmark',
'sphinx.ext.autosectionlabel',
# 'm2r',
Expand Down
4 changes: 4 additions & 0 deletions environment.yml
Expand Up @@ -26,6 +26,10 @@ dependencies:
- autopep8
- check-manifest
- twine==1.13.0
- pillow<7.0.0
- scipy>=0.13.3
- scikit-learn>=0.20.0


- pip:
- test-tube>=0.7.5
Expand Down
12 changes: 7 additions & 5 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Expand Up @@ -27,6 +27,8 @@
from tempfile import TemporaryDirectory
from typing import Optional, Generator, Union

from torch.nn import Module

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
Expand All @@ -47,7 +49,7 @@
# --- Utility functions ---


def _make_trainable(module: torch.nn.Module) -> None:
def _make_trainable(module: Module) -> None:
"""Unfreezes a given module.

Args:
Expand All @@ -58,7 +60,7 @@ def _make_trainable(module: torch.nn.Module) -> None:
module.train()


def _recursive_freeze(module: torch.nn.Module,
def _recursive_freeze(module: Module,
train_bn: bool = True) -> None:
"""Freezes the layers of a given module.

Expand All @@ -80,7 +82,7 @@ def _recursive_freeze(module: torch.nn.Module,
_recursive_freeze(module=child, train_bn=train_bn)


def freeze(module: torch.nn.Module,
def freeze(module: Module,
n: Optional[int] = None,
train_bn: bool = True) -> None:
"""Freezes the layers up to index n (if n is not None).
Expand All @@ -101,7 +103,7 @@ def freeze(module: torch.nn.Module,
_make_trainable(module=child)


def filter_params(module: torch.nn.Module,
def filter_params(module: Module,
train_bn: bool = True) -> Generator:
"""Yields the trainable parameters of a given module.

Expand All @@ -124,7 +126,7 @@ def filter_params(module: torch.nn.Module,
yield param


def _unfreeze_and_add_param_group(module: torch.nn.Module,
def _unfreeze_and_add_param_group(module: Module,
optimizer: Optimizer,
lr: Optional[float] = None,
train_bn: bool = True):
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/core/grads.py
Expand Up @@ -4,9 +4,10 @@
from typing import Dict, Union

import torch
from torch.nn import Module


class GradInformation(torch.nn.Module):
class GradInformation(Module):

def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
"""Compute each parameter's gradient's norm and their overall norm.
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/core/hooks.py
Expand Up @@ -2,6 +2,7 @@

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from pytorch_lightning.utilities import move_data_to_device

Expand All @@ -14,7 +15,7 @@
APEX_AVAILABLE = True


class ModelHooks(torch.nn.Module):
class ModelHooks(Module):

# TODO: remove in v0.9.0
def on_sanity_check_start(self):
Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/metrics/__init__.py
Expand Up @@ -22,3 +22,9 @@


"""

from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.sklearn import (
SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
5 changes: 3 additions & 2 deletions pytorch_lightning/metrics/metric.py
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Union
from typing import Any, Optional

import torch
import torch.distributed
from torch.nn import Module

from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric
from pytorch_lightning.utilities.apply_func import apply_to_collection
Expand All @@ -11,7 +12,7 @@
__all__ = ['Metric', 'TensorMetric', 'NumpyMetric']


class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
class Metric(ABC, DeviceDtypeModuleMixin, Module):
"""
Abstract base class for metric implementation.

Expand Down