Skip to content
This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit

Permalink
Fix coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
CyberZHG committed Jun 1, 2020
1 parent af9e495 commit 7d5420b
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 33 deletions.
15 changes: 8 additions & 7 deletions .travis.yml
@@ -1,12 +1,12 @@
dist: xenial
language: python
python:
- 3.6
python: "3.6"
env:
- KERAS_BACKEND=tensorflow TF_KERAS=1 TF_2=1
- KERAS_BACKEND=tensorflow TF_KERAS=1 TF_EAGER=1
- KERAS_BACKEND=tensorflow TF_KERAS=1
- KERAS_BACKEND=tensorflow
global:
- COVERALLS_PARALLEL=true
matrix:
- KERAS_BACKEND=tensorflow
- KERAS_BACKEND=tensorflow TF_KERAS=1
install:
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
- bash miniconda.sh -b -p $HOME/miniconda
Expand All @@ -20,9 +20,10 @@ install:
- pip install --upgrade pip
- pip install -r requirements.txt
- pip install -r requirements-dev.txt
- if [[ $TF_2 == "1" ]]; then pip install tensorflow==2.0.0-beta1; fi
- pip install coveralls
script:
- ./test.sh
after_success:
coveralls
notifications:
webhooks: https://coveralls.io/webhook
2 changes: 2 additions & 0 deletions keras_trans_mask/__init__.py
@@ -1 +1,3 @@
from .masks import *

__version__ = '0.4.0'
21 changes: 6 additions & 15 deletions keras_trans_mask/backend.py
@@ -1,25 +1,17 @@
import os
from distutils.util import strtobool

__all__ = [
'keras', 'utils', 'activations', 'applications', 'backend', 'datasets', 'engine',
'keras', 'utils', 'activations', 'applications', 'backend', 'datasets',
'layers', 'preprocessing', 'wrappers', 'callbacks', 'constraints', 'initializers',
'metrics', 'models', 'losses', 'optimizers', 'regularizers', 'TF_KERAS', 'EAGER_MODE'
'metrics', 'models', 'losses', 'optimizers', 'regularizers', 'TF_KERAS',
]

TF_KERAS = False
EAGER_MODE = False
TF_KERAS = strtobool(os.environ.get('TF_KERAS', '0'))

if os.environ.get('TF_KERAS', '0') != '0':
if TF_KERAS:
import tensorflow as tf
from tensorflow.python import keras
TF_KERAS = True
if os.environ.get('TF_EAGER', '0') != '0':
try:
tf.enable_eager_execution()
raise AttributeError()
except AttributeError as e:
pass
EAGER_MODE = tf.executing_eagerly()
keras = tf.keras
else:
import keras

Expand All @@ -28,7 +20,6 @@
applications = keras.applications
backend = keras.backend
datasets = keras.datasets
engine = keras.engine
layers = keras.layers
preprocessing = keras.preprocessing
wrappers = keras.wrappers
Expand Down
4 changes: 2 additions & 2 deletions keras_trans_mask/masks.py
Expand Up @@ -56,7 +56,7 @@ def compute_mask(self, inputs, mask=None):
return None

def call(self, inputs, **kwargs):
return K.identity(inputs)
return inputs + 0.0


class RestoreMask(keras.layers.Layer):
Expand All @@ -81,4 +81,4 @@ def compute_mask(self, inputs, mask=None):
return mask[1]

def call(self, inputs, **kwargs):
return K.identity(inputs[0])
return inputs[0] + 0.0
31 changes: 22 additions & 9 deletions setup.py
@@ -1,30 +1,43 @@
import os
import re
import codecs
from setuptools import setup, find_packages

current_path = os.path.abspath(os.path.dirname(__file__))

with codecs.open('README.md', 'r', 'utf8') as reader:
long_description = reader.read()

def read_file(*parts):
with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader:
return reader.read()

with codecs.open('requirements.txt', 'r', 'utf8') as reader:
install_requires = list(map(lambda x: x.strip(), reader.readlines()))

def get_requirements(*parts):
with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader:
return list(map(lambda x: x.strip(), reader.readlines()))


def find_version(*file_paths):
version_file = read_file(*file_paths)
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
if version_match:
return version_match.group(1)
raise RuntimeError('Unable to find version string.')


setup(
name='keras-trans-mask',
version='0.3.0',
version=find_version('keras_trans_mask', '__init__.py'),
packages=find_packages(),
url='https://github.com/CyberZHG/keras-trans-mask',
license='MIT',
author='CyberZHG',
author_email='CyberZHG@users.noreply.github.com',
description='Transfer masking in Keras',
long_description=long_description,
long_description=read_file('README.md'),
long_description_content_type='text/markdown',
install_requires=install_requires,
install_requires=get_requirements('requirements.txt'),
classifiers=(
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
),
Expand Down
9 changes: 9 additions & 0 deletions tests/test_masks.py
@@ -1,4 +1,6 @@
from unittest import TestCase
import os
import tempfile

import numpy as np

Expand Down Expand Up @@ -33,4 +35,11 @@ def test_over_fit(self):
[6, 7, 8, 9, 9, 9, 9, 9],
] * 1024)
y = np.array([[0], [1]] * 1024)
model_path = os.path.join(tempfile.gettempdir(), 'test_trans_mask_%f.h5' % np.random.random())
model.save(model_path)
model = keras.models.load_model(model_path, custom_objects={
'CreateMask': CreateMask,
'RemoveMask': RemoveMask,
'RestoreMask': RestoreMask,
})
model.fit(x, y, epochs=10)

0 comments on commit 7d5420b

Please sign in to comment.