added train step return
williamFalcon committed May 28, 2020
2 parents 9f9fab1 + cee9bdd commit 81e6c75
Showing 22 changed files with 449 additions and 89 deletions.
15 changes: 7 additions & 8 deletions .circleci/config.yml
@@ -53,14 +53,13 @@ references:
flake8
make_docs: &make_docs
run:
name: Make Documentation
command: |
# sudo apt-get install pandoc
pip install -r requirements.txt --user
sudo pip install -r docs/requirements.txt
# sphinx-apidoc -o ./docs/source ./pytorch_lightning **/test_* --force --follow-links
cd docs; make clean ; make html --debug --jobs 2 SPHINXOPTS="-W"
run:
name: Make Documentation
command: |
# First run the same pipeline as Read-The-Docs
sudo apt-get update && sudo apt-get install -y cmake
sudo pip install -r docs/requirements.txt
cd docs; make clean; make html --debug --jobs 2 SPHINXOPTS="-W"
jobs:

122 changes: 122 additions & 0 deletions .github/workflows/ci-testing.yml
@@ -0,0 +1,122 @@
name: CI testing

# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
on:
# Trigger the workflow on push or pull request,
# but only for the master branch
push:
branches:
- master
pull_request:
branches:
- master

jobs:
build:

runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
# max-parallel: 6
matrix:
# PyTorch 1.5 is failing on Win and bolts requires torchvision>=0.5
os: [ubuntu-18.04, macOS-10.15] # , windows-2019
python-version: [3.6, 3.7, 3.8]
requires: ['minimal', 'latest']
# exclude:
# # excludes PT 1.3 as it is missing on pypi
# - python-version: 3.8
# requires: 'minimal'

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 15

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}

# Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646
- name: Setup macOS
if: runner.os == 'macOS'
run: |
brew install libomp # https://github.com/pytorch/pytorch/issues/20030
#- name: Setup Windows
# if: runner.os == 'windows'
# run: |
# python -c "lines = open('requirements.txt').readlines() + ['torch<1.5\n']; open('requirements.txt', 'w').writelines(lines)"

- name: Set min. dependencies
if: matrix.requires == 'minimal'
run: |
python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)"
# Note: This uses an internal pip API and may not always work
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
- name: Get pip cache
id: pip-cache
run: |
python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)"
- name: Cache pip
uses: actions/cache@v1
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: |
${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-
- name: Install dependencies
run: |
# python -m pip install --upgrade --user pip
pip install -r requirements.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q
pip install -r ./tests/requirements.txt -q
# pip install tox coverage
python --version
pip --version
pip list
shell: bash

- name: Cache datasets
uses: actions/cache@v1
with:
path: tests/Datasets # relative to the repository checkout
# A fixed key works here: the cached MNIST download does not change between runs
key: mnist-dataset

- name: Tests
# env:
# TOXENV: py${{ matrix.python-version }}
run: |
# tox --sitepackages
coverage run --source pl_bolts -m py.test pl_bolts tests -v --doctest-modules --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
- name: Upload pytest test results
uses: actions/upload-artifact@master
with:
name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}
path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: always()

- name: Package Setup
run: |
check-manifest
python setup.py check --metadata --strict
python setup.py sdist
twine check dist/*
#- name: Try install package
# if: ! startsWith(matrix.os, 'windows')
# run: |
# virtualenv vEnv ; source vEnv/bin/activate
# pip install --editable . ; cd .. & python -c "import pytorch_lightning ; print(pytorch_lightning.__version__)"
# deactivate ; rm -rf vEnv

- name: Statistics
if: success()
run: |
coverage report
1 change: 1 addition & 0 deletions README.md
@@ -1,6 +1,7 @@
# PyTorchLightning Bolts

[![CircleCI](https://circleci.com/gh/PyTorchLightning/pytorch-lightning-bolts/tree/master.svg?style=svg)](https://circleci.com/gh/PyTorchLightning/pytorch-lightning-bolts/tree/master)
![CI testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20testing/badge.svg?branch=master)
[![codecov](https://codecov.io/gh/PyTorchLightning/pytorch-lightning-bolts/branch/master/graph/badge.svg)](https://codecov.io/gh/PyTorchLightning/pytorch-lightning-bolts)
[![CodeFactor](https://www.codefactor.io/repository/github/pytorchlightning/pytorch-lightning-bolts/badge)](https://www.codefactor.io/repository/github/pytorchlightning/pytorch-lightning-bolts)
[![Documentation Status](https://readthedocs.org/projects/pytorch-lightning-bolts/badge/?version=latest)](https://pytorch-lightning-bolts.readthedocs.io/en/latest/)
38 changes: 25 additions & 13 deletions docs/source/conf.py
@@ -27,6 +27,8 @@

builtins.__LIGHTNING_BOLT_SETUP__ = True

SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True))

import pl_bolts # noqa: E402

# -- Project information -----------------------------------------------------
@@ -130,8 +132,7 @@
exclude_patterns = [
'api/pl_bolts.rst',
'api/modules.rst',
'api/pl_bolts.datamodules.*',
'api/pl_bolts.metrics.*'
'api/pl_bolts.submit.rst',
]

# The name of the Pygments (syntax highlighting) style to use.
@@ -279,24 +280,35 @@ def setup(app):
path_ipynb2 = os.path.join(path_nbs, os.path.basename(path_ipynb))
shutil.copy(path_ipynb, path_ipynb2)


# Ignoring Third-party packages
# https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule
def package_list_from_file(file):
mocked_packages = []
with open(file, 'r') as fp:
for ln in fp.readlines():
found = [ln.index(ch) for ch in list(',=<>#') if ch in ln]
pkg = ln[:min(found)] if found else ln
if pkg.rstrip():
mocked_packages.append(pkg.rstrip())
return mocked_packages
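

# A quick illustration of the helper above (the requirements content is
# hypothetical): each line is truncated at the first of ',=<>#', so
# version specifiers and trailing comments are stripped.
#
#   import os
#   import tempfile
#
#   with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as fp:
#       fp.write('torch>=1.4\ntqdm  # progress bars\nscikit-learn==0.22\n')
#
#   print(package_list_from_file(fp.name))  # ['torch', 'tqdm', 'scikit-learn']
#   os.remove(fp.name)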


MOCK_PACKAGES = []
if SPHINX_MOCK_REQUIREMENTS:
# also mock the base packages on RTD, since they are not installed there
MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements.txt'))

MOCK_REQUIRE_PACKAGES = []
with open(os.path.join(PATH_ROOT, 'requirements.txt'), 'r') as fp:
for ln in fp.readlines():
found = [ln.index(ch) for ch in list(',=<>#') if ch in ln]
pkg = ln[:min(found)] if found else ln
if pkg.rstrip():
MOCK_REQUIRE_PACKAGES.append(pkg.rstrip())

# TODO: better parse from package since the import name and package name may differ
MOCK_MANUAL_PACKAGES = [
'torch',
'pytorch_lightning',
'numpy',
'torch',
'torchvision',
'sklearn',
'PIL',
'cv2',
]
autodoc_mock_imports = MOCK_REQUIRE_PACKAGES + MOCK_MANUAL_PACKAGES
autodoc_mock_imports = MOCK_PACKAGES + MOCK_MANUAL_PACKAGES
# for mod_name in MOCK_REQUIRE_PACKAGES:
# sys.modules[mod_name] = mock.Mock()
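# The TODO above exists because autodoc_mock_imports needs *import*
# names, while requirements.txt lists PyPI distribution names, and the
# two can differ. A few illustrative mismatches (matching the manual
# list above; not exhaustive):
#   PYPI_TO_IMPORT = {
#       'scikit-learn': 'sklearn',
#       'Pillow': 'PIL',
#       'opencv-python': 'cv2',
#   }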

5 changes: 5 additions & 0 deletions docs/source/index.rst
@@ -63,7 +63,12 @@ Indices and tables
:hidden:

readme
api/pl_bolts.callbacks
api/pl_bolts.datamodules
api/pl_bolts.metrics
api/pl_bolts.models
api/pl_bolts.losses
api/pl_bolts.loggers
api/pl_bolts.optimizers
api/pl_bolts.transforms
8 changes: 4 additions & 4 deletions pl_bolts/datamodules/imagenet_dataset.py
@@ -134,9 +134,9 @@ def generate_meta_bins(cls, devkit_dir):

def _verify_archive(root, file, md5):
if not _check_integrity(os.path.join(root, file), md5):
msg = ("The archive {} is not present in the root directory or is corrupted. "
"You need to download it externally and place it in {}.")
raise RuntimeError(msg.format(file, root))
raise RuntimeError(
f"The archive {file} is not present in the root directory or is corrupted."
f" You need to download it externally and place it in {root}.")


def _check_integrity(fpath, md5=None):
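# The body of _check_integrity is collapsed in this diff. A typical
# implementation of the pattern -- hash the file in chunks and compare
# against the expected md5, as torchvision does -- might look like this
# sketch (an assumption, not this repo's exact code):
#
#   import hashlib
#
#   def _md5_of(fpath, chunk_size=1024 * 1024):
#       # Stream the file so large archives need not fit in memory.
#       md5 = hashlib.md5()
#       with open(fpath, 'rb') as f:
#           for chunk in iter(lambda: f.read(chunk_size), b''):
#               md5.update(chunk)
#       return md5.hexdigest()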
@@ -240,7 +240,7 @@ def extract_archive(from_path, to_path=None, remove_finished=False):
with zipfile.ZipFile(from_path, 'r') as z:
z.extractall(to_path)
else:
raise ValueError("Extraction of {} not supported".format(from_path))
raise ValueError(f"Extraction of {from_path} not supported")

if remove_finished:
os.remove(from_path)
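
# A hypothetical call to extract_archive as defined above (the path is
# illustrative, and tarball support is assumed to live in the branches
# collapsed out of this diff):
#   extract_archive('/data/ILSVRC2012_devkit_t12.tar.gz',
#                   to_path='/data',
#                   remove_finished=True)  # deletes the archive afterwards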
34 changes: 25 additions & 9 deletions pl_bolts/datamodules/stl10_dataloaders.py
@@ -2,8 +2,8 @@
from torchvision import transforms as transform_lib
from torchvision.datasets import STL10

from pl_bolts.datamodules.concat_dataset import ConcatDataset
from pl_bolts.datamodules.bolts_dataloaders_base import BoltDataLoaders
from pl_bolts.datamodules.concat_dataset import ConcatDataset
from pl_bolts.transforms.dataset_normalizations import stl10_normalization


@@ -31,7 +31,9 @@ def train_dataloader(self, batch_size, transforms=None):

dataset = STL10(self.save_path, split='unlabeled', download=False, transform=transforms)
train_length = len(dataset)
dataset_train, _ = random_split(dataset, [train_length - self.unlabeled_val_split, self.unlabeled_val_split])
dataset_train, _ = random_split(dataset,
[train_length - self.unlabeled_val_split,
self.unlabeled_val_split])
loader = DataLoader(
dataset_train,
batch_size=batch_size,
@@ -46,14 +48,20 @@ def train_dataloader_mixed(self, batch_size, transforms=None):
if transforms is None:
transforms = self._default_transforms()

unlabeled_dataset = STL10(self.save_path, split='unlabeled', download=False, transform=transforms)
unlabeled_dataset = STL10(self.save_path,
split='unlabeled',
download=False,
transform=transforms)
unlabeled_length = len(unlabeled_dataset)
unlabeled_dataset, _ = random_split(unlabeled_dataset,
[unlabeled_length - self.unlabeled_val_split, self.unlabeled_val_split])
[unlabeled_length - self.unlabeled_val_split,
self.unlabeled_val_split])

labeled_dataset = STL10(self.save_path, split='train', download=False, transform=transforms)
labeled_length = len(labeled_dataset)
labeled_dataset, _ = random_split(labeled_dataset, [labeled_length - self.train_val_split, self.train_val_split])
labeled_dataset, _ = random_split(labeled_dataset,
[labeled_length - self.train_val_split,
self.train_val_split])

dataset = ConcatDataset(unlabeled_dataset, labeled_dataset)
loader = DataLoader(
@@ -72,7 +80,9 @@ def val_dataloader(self, batch_size, transforms=None):

dataset = STL10(self.save_path, split='unlabeled', download=False, transform=transforms)
train_length = len(dataset)
_, dataset_val = random_split(dataset, [train_length - self.unlabeled_val_split, self.unlabeled_val_split])
_, dataset_val = random_split(dataset,
[train_length - self.unlabeled_val_split,
self.unlabeled_val_split])
loader = DataLoader(
dataset_val,
batch_size=batch_size,
@@ -86,14 +96,20 @@ def val_dataloader_mixed(self, batch_size, transforms=None):
if transforms is None:
transforms = self._default_transforms()

unlabeled_dataset = STL10(self.save_path, split='unlabeled', download=False, transform=transforms)
unlabeled_dataset = STL10(self.save_path,
split='unlabeled',
download=False,
transform=transforms)
unlabeled_length = len(unlabeled_dataset)
_, unlabeled_dataset = random_split(unlabeled_dataset,
[unlabeled_length - self.unlabeled_val_split, self.unlabeled_val_split])
[unlabeled_length - self.unlabeled_val_split,
self.unlabeled_val_split])

labeled_dataset = STL10(self.save_path, split='train', download=False, transform=transforms)
labeled_length = len(labeled_dataset)
_, labeled_dataset = random_split(labeled_dataset, [labeled_length - self.train_val_split, self.train_val_split])
_, labeled_dataset = random_split(labeled_dataset,
[labeled_length - self.train_val_split,
self.train_val_split])

dataset = ConcatDataset(unlabeled_dataset, labeled_dataset)
loader = DataLoader(
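# Every loader in this file repeats the same split idiom, so a condensed
# sketch of the pattern may help (the helper name is hypothetical, not
# part of this diff):
#
#   from torch.utils.data import random_split
#
#   def _split_off_val(dataset, val_split):
#       # Carve `val_split` samples off for validation, keep the rest.
#       return random_split(dataset, [len(dataset) - val_split, val_split])
#
#   dataset_train, dataset_val = _split_off_val(dataset, self.unlabeled_val_split)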
2 changes: 1 addition & 1 deletion pl_bolts/losses/self_supervised_learning.py
@@ -1,6 +1,6 @@
import numpy as np
import torch
from torch import nn
import numpy as np

from pl_bolts.models.vision import PixelCNN

5 changes: 2 additions & 3 deletions pl_bolts/models/self_supervised/amdim/networks.py
@@ -19,7 +19,7 @@ def __init__(self, dummy_batch, num_channels=3, ndf=64, n_rkhs=512,
self.encoder_size = encoder_size

# encoding block for local features
print('Using a {}x{} encoder'.format(encoder_size, encoder_size))
print(f'Using a {encoder_size}x{encoder_size} encoder')
if encoder_size == 32:
self.layer_list = nn.ModuleList([
Conv3x3(num_channels, ndf, 3, 1, 0, False),
@@ -58,8 +58,7 @@ def __init__(self, dummy_batch, num_channels=3, ndf=64, n_rkhs=512,
MaybeBatchNorm2d(n_rkhs, True, True)
])
else:
raise RuntimeError("Could not build encoder."
"Encoder size {} is not supported".format(encoder_size))
raise RuntimeError(f"Could not build encoder. Encoder size {encoder_size} is not supported")
self._config_modules(
dummy_batch,
output_widths=[1, 5, 7],
6 changes: 5 additions & 1 deletion pl_bolts/models/self_supervised/cpc/cpc_module.py
@@ -14,10 +14,10 @@
from pl_bolts import metrics
from pl_bolts.datamodules import CIFAR10DataLoaders, STL10DataLoaders
from pl_bolts.datamodules.ssl_imagenet_dataloaders import SSLImagenetDataLoaders
from pl_bolts.losses.self_supervised_learning import InfoNCE
from pl_bolts.models.self_supervised.cpc import transforms as cpc_transforms
from pl_bolts.models.self_supervised.cpc.networks import CPCResNet101
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from pl_bolts.losses.self_supervised_learning import InfoNCE

__all__ = [
'CPCV2'
@@ -285,7 +285,11 @@ def add_model_specific_args(parent_parser):
'dataset': 'imagenet128',
'depth': 10,
'patch_size': 32,
'batch_size': 52,  # unresolved merge-conflict markers removed here; keeping the merged branch's 52 (HEAD had 48) is an editorial assumption
'nb_classes': 1000,
'patch_overlap': 32 // 2,
'lr_options': [
4 changes: 2 additions & 2 deletions pl_bolts/models/self_supervised/cpc/networks.py
@@ -23,8 +23,8 @@ def __init__(self, sample_batch, zero_init_residual=False,
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
raise ValueError("replace_stride_with_dilation should be None"
f" or a 3-element tuple, got {replace_stride_with_dilation}")
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
