Commit

Merge branch 'develop'
amaiya committed May 12, 2020
2 parents 4624a3c + d65b204 commit 32ec685
Showing 18 changed files with 1,130 additions and 251 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ Most recent releases are shown at the top. Each release shows:
- **Fixed**: Bug fixes that don't change documented behaviour


## 0.15.0 (2020-05-13)

### New:
- Out-of-the-box support for image regression
- `vision.images_from_df` function to load image data from *pandas* DataFrames

### Changed:
- References to `fit_generator` and `predict_generator` converted to `fit` and `predict`

### Fixed:
- Resolved issue with multilabel detection returning `False` for valid multilabel problems when data is in the form of a generator
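
To make the multilabel fix above more concrete, here is a minimal sketch of how multilabel detection over generator-style data could work. It is not *ktrain*'s actual implementation: the function name `is_multilabel` and the batch layout (an iterable of batches, each a list of binary label rows) are assumptions made for illustration.

```python
from typing import Iterable, Sequence


def is_multilabel(label_batches: Iterable[Sequence[Sequence[int]]]) -> bool:
    """Return True if any label row has more than one active class.

    `label_batches` is a hypothetical stand-in for what a data generator
    yields: an iterable of batches, each batch a list of binary label rows.
    """
    for batch in label_batches:
        for row in batch:
            # A row with two or more 1s cannot be a single-label target.
            if sum(1 for value in row if value == 1) > 1:
                return True
    return False


def batches():
    # Generator yielding two batches of binary label rows.
    yield [[1, 0, 0], [0, 1, 0]]  # single-label rows only
    yield [[1, 0, 1], [0, 0, 1]]  # first row has two active labels


print(is_multilabel(batches()))
```

The point of the sketch is that every batch has to be inspected: looking only at the first batch would report `False` here even though the data is multilabel.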


## 0.14.7 (2020-05-10)

### New:
Expand Down
35 changes: 19 additions & 16 deletions README.md
Expand Up @@ -7,6 +7,10 @@


### News and Announcements
- **2020-05-13:**
- ***ktrain*** **v0.15.x is released** and includes support for:
- **image regression**: See the [example notebook on age prediction from photos](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/vision/utk_faces_age_prediction-resnet50.ipynb).
- **`tf.data.Datasets`**: See the [example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/vision/mnist-tf_workflow.ipynb) on using `tf.data.Datasets` in *ktrain* for custom models and data formats.
- **2020-04-15:**
- ***ktrain*** **v0.14.x is released** and now includes support for **open-domain question-answering**. See the [example QA notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/question_answering_with_bert.ipynb)
- **2020-04-09:**
Expand All @@ -20,21 +24,6 @@ ts = text.TransformerSummarizer()
ts.summarize(some_long_document)
```

- **2020-03-31:**
- ***ktrain*** **v0.12.x is released** and now includes BERT embeddings (i.e., BERT, DistilBert, and Albert) that can be used for downstream tasks like building sequence-taggers (i.e., NER)
for any language such as English, Chinese, Russian, Arabic, Dutch, etc. See [this English NER example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/CoNLL2003-BiLSTM.ipynb) or the [Dutch NER notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/CoNLL2002_Dutch-BiLSTM.ipynb) for examples on how to use this feature.
*ktrain* also supports NER with domain-specific embeddings from [community-uploaded Hugging Face models](https://huggingface.co/models) such as [BioBERT](https://arxiv.org/abs/1901.08746) for the biomedical domain:
```python
# NER with BioBERT embeddings
import ktrain
from ktrain import text as txt
x_train= [['IL-2', 'responsiveness', 'requires', 'three', 'distinct', 'elements', 'within', 'the', 'enhancer', '.'], ...]
y_train=[['B-protein', 'O', 'O', 'O', 'O', 'B-DNA', 'O', 'O', 'B-DNA', 'O'], ...]
(trn, val, preproc) = txt.entities_from_array(x_train, y_train)
model = txt.sequence_tagger('bilstm-bert', preproc, bert_model='monologg/biobert_v1.1_pubmed')
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=128)
learner.fit(0.01, 1, cycle_len=5)
```
----

### Overview
Expand All @@ -54,6 +43,7 @@ learner.fit(0.01, 1, cycle_len=5)
- **Open-Domain Question-Answering**: ask a large text corpus questions and receive exact answers <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/question_answering_with_bert.ipynb)]</sup></sub>
- `vision` data:
- **image classification** (e.g., [ResNet](https://arxiv.org/abs/1512.03385), [Wide ResNet](https://arxiv.org/abs/1605.07146), [Inception](https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf)) <sub><sup>[[example notebook](https://colab.research.google.com/drive/1WipQJUPL7zqyvLT10yekxf_HNMXDDtyR)]</sup></sub>
- **image regression** for predicting numerical targets from photos (e.g., age prediction) <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/vision/utk_faces_age_prediction-resnet50.ipynb)]</sup></sub>
- `graph` data:
- **node classification** with graph neural networks ([GraphSAGE](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)) <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/graphs/pubmed_node_classification-GraphSAGE.ipynb)]</sup></sub>
- **link prediction** with graph neural networks ([GraphSAGE](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)) <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/graphs/cora_link_prediction-GraphSAGE.ipynb)]</sup></sub>
Expand Down Expand Up @@ -251,13 +241,25 @@ learner.validate(class_names=t.get_classes()) # class_names must be string value
# weighted avg 0.96 0.96 0.96 1502
```

#### Example: NER With [BioBERT](https://arxiv.org/abs/1901.08746) Embeddings
```python
# NER with BioBERT embeddings
import ktrain
from ktrain import text as txt
x_train= [['IL-2', 'responsiveness', 'requires', 'three', 'distinct', 'elements', 'within', 'the', 'enhancer', '.'], ...]
y_train=[['B-protein', 'O', 'O', 'O', 'O', 'B-DNA', 'O', 'O', 'B-DNA', 'O'], ...]
(trn, val, preproc) = txt.entities_from_array(x_train, y_train)
model = txt.sequence_tagger('bilstm-bert', preproc, bert_model='monologg/biobert_v1.1_pubmed')
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=128)
learner.fit(0.01, 1, cycle_len=5)
```

Using *ktrain* on **Google Colab**? See these Colab examples:
- [a simple demo of Multiclass Text Classification with BERT](https://colab.research.google.com/drive/1AH3fkKiEqBpVpO5ua00scp7zcHs5IDLK)
- [a simple demo of Multiclass Text Classification with Hugging Face Transformers](https://colab.research.google.com/drive/1YxcceZxsNlvK35pRURgbwvkgejXwFxUt)
- [image classification with Cats vs. Dogs](https://colab.research.google.com/drive/1WipQJUPL7zqyvLT10yekxf_HNMXDDtyR)

**Additional examples can be found [here](https://github.com/amaiya/ktrain/tree/master/examples).**
#### Additional examples can be found [here](https://github.com/amaiya/ktrain/tree/master/examples).



Expand All @@ -271,6 +273,7 @@ While *ktrain* will probably work with other versions of TensorFlow 2.x, v2.1.0
2. Install *ktrain*: `pip3 install ktrain`

**Some things to note:**
- *ktrain* will automatically install TensorFlow 2 as a dependency.
- Since some *ktrain* dependencies have not yet been migrated to `tf.keras` in TensorFlow 2 (or may have other issues),
*ktrain* is temporarily using forked versions of some libraries. Specifically, *ktrain* uses forked versions of the `eli5` and `stellargraph` libraries. If not installed, *ktrain* will complain when a method or function needing
either of these libraries is invoked.
Expand Down
7 changes: 7 additions & 0 deletions examples/README.md
Expand Up @@ -13,6 +13,7 @@ This directory contains various example notebooks using *ktrain*. The directory
- [Open-Domain Question-Answering](#textqa): ask questions to a large text corpus and receive exact candidate answers
- `vision`:
- [image classification](#imageclass): image classification examples using various models and datasets
- [image regression](#imageregression): example of predicting numerical values purely from images/photos
- `graphs`:
- [node classification](#-graph-node-classification-datasets): node classification in graphs or networks
- [link prediction](#-graph-link-prediction-datasets): link prediction in graphs or networks
Expand Down Expand Up @@ -159,6 +160,12 @@ Image labels are in the form of a CSV containing paths to images.
- [planet-ResNet50.ipynb](https://github.com/amaiya/ktrain/tree/master/examples/vision): Using a pretrained ResNet50 model for multi-label classification.


### <a name="imageregression"></a> Image Regression

#### [Age Prediction](http://aicip.eecs.utk.edu/wiki/UTKFace): Image Regression
- [utk_faces_age_prediction-resnet50.ipynb](https://github.com/amaiya/ktrain/tree/master/examples/vision): ResNet50 pretrained on ImageNet for age prediction using UTK Face dataset
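
For image regression the target is a number rather than a class, so it has to be recovered from somewhere. The sketch below assumes UTKFace-style filenames of the form `age_gender_race_timestamp.jpg` and pulls the age out with a regular expression. The pattern, the helper name `age_from_filename`, and the sample path are illustrative assumptions, not code taken from the notebook.

```python
import re

# Hypothetical pattern: the digits before the first underscore are the age.
AGE_PATTERN = re.compile(r'(\d+)_\d+_\d+_\d+\.jpg$')


def age_from_filename(path: str) -> float:
    """Extract the age encoded at the start of a UTKFace-style filename."""
    match = AGE_PATTERN.search(path)
    if match is None:
        raise ValueError(f"filename does not match the expected layout: {path}")
    # Return a float because regression targets are continuous values.
    return float(match.group(1))


# Made-up example path following the assumed naming scheme.
print(age_from_filename('UTKFace/25_0_1_20170116174525125.jpg'))
```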


## Graph Data

### <a name="#nodeclass"></a> Graph Node Classification Datasets
Expand Down
611 changes: 611 additions & 0 deletions examples/vision/utk_faces_age_prediction-resnet50.ipynb

Large diffs are not rendered by default.

33 changes: 11 additions & 22 deletions ktrain/core.py
Expand Up @@ -170,10 +170,11 @@ def top_losses(self, n=4, val_data=None, preproc=None):
# compute loss
# this doesn't work in tf.keras 1.14
#losses = self.model.loss_functions[0](tf.convert_to_tensor(y_true), tf.convert_to_tensor(y_pred))
if U.is_tf_keras():
L = self.model.loss_functions[0].fn
else:
L = self.model.loss_functions[0]
#if U.is_tf_keras():
#L = self.model.loss_functions[0].fn
#else:
#L = self.model.loss_functions[0]
L = U.loss_fn_from_model(self.model)
losses = L(tf.convert_to_tensor(y_true), tf.convert_to_tensor(y_pred))
if DISABLE_V2_BEHAVIOR:
losses = tf.Session().run(losses)
Expand Down Expand Up @@ -889,8 +890,10 @@ def predict(self, val_data=None):
if U.is_iter(val):
if hasattr(val, 'reset'): val.reset()
steps = np.ceil(U.nsamples_from_data(val)/val.batch_size)
result = self.model.predict_generator(self._prepare(val, train=False),
steps=steps)
# *_generator methods are deprecated from TF 2.1.0
#result = self.model.predict_generator(self._prepare(val, train=False),
#steps=steps)
result = self.model.predict(self._prepare(val, train=False), steps=steps)
return result
else:
return self.model.predict(val[0], batch_size=self.eval_batch_size)
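
The `steps` value passed to `predict` above is the number of batches needed to cover every sample once. A minimal standalone sketch of that arithmetic follows; the helper names are hypothetical and it uses `math.ceil` instead of `np.ceil` so that it has no dependencies.

```python
import math
from typing import List


def prediction_steps(n_samples: int, batch_size: int) -> int:
    """Number of batches needed to see every sample exactly once."""
    return math.ceil(n_samples / batch_size)


def batch_sizes(n_samples: int, batch_size: int) -> List[int]:
    """Size of each batch a generator would yield over one full pass."""
    steps = prediction_steps(n_samples, batch_size)
    # The final batch is smaller when batch_size does not divide n_samples.
    return [min(batch_size, n_samples - i * batch_size) for i in range(steps)]


print(prediction_steps(10, 4))
print(batch_sizes(10, 4))
```

One difference from the hunk is worth noting: `np.ceil` returns a float, whereas `math.ceil` returns an `int`.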
Expand Down Expand Up @@ -1103,7 +1106,7 @@ def fit(self, lr, n_cycles, cycle_len=None, cycle_mult=1,
lr_decay=1.0, checkpoint_folder=None, early_stopping=None,
class_weight=None, callbacks=[], verbose=1):
"""
Trains the model. By default, fit is simply a wrapper for model.fit_generator.
Trains the model. By default, fit is simply a wrapper for model.fit (for generators/sequences).
When cycle_len parameter is supplied, an SGDR learning rate schedule is used.
lr (float): learning rate
Expand Down Expand Up @@ -1165,21 +1168,7 @@ def fit(self, lr, n_cycles, cycle_len=None, cycle_mult=1,
# train model
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='.*Check your callbacks.*')
# bug in TF2 causes fit_generator to be very slow
# https://github.com/tensorflow/tensorflow/issues/33024
if version.parse(tf.__version__) < version.parse('2.0'):
fit_fn = self.model.fit_generator
else:
# TF bug with using multiple inputs with utils.Sequence and model.fit
# TODO: check data and proceed accordingly
# potential patch is to have Sequence subclasses return tuple(batch_x), y
if U.is_nodeclass(model=self.model, data=self.train_data) or\
U.is_ner(model=self.model, data=self.train_data):
fit_fn = self.model.fit_generator
else:
fit_fn = self.model.fit
# fixed in 2.1.0
#fit_fn = self.model.fit
fit_fn = self.model.fit
hist = fit_fn(self._prepare(self.train_data),
steps_per_epoch = steps_per_epoch,
validation_steps = validation_steps,
Expand Down
12 changes: 9 additions & 3 deletions ktrain/graph/predictor.py
Expand Up @@ -35,7 +35,9 @@ def predict_transductive(self, node_ids, return_proba=False):
"""
gen = self.preproc.preprocess_valid(node_ids)
gen.batch_size = self.batch_size
preds = self.model.predict_generator(gen)
# *_generator methods are deprecated from TF 2.1.0
#preds = self.model.predict_generator(gen)
preds = self.model.predict(gen)
result = preds if return_proba else [self.c[np.argmax(pred)] for pred in preds]
return result
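
The `return_proba` branch above either returns the raw probabilities or maps each row to the class with the highest score. A dependency-free sketch of that mapping is shown here; `labels_from_probabilities` and the sample data are hypothetical, and plain lists stand in for the NumPy arrays used in the hunk.

```python
from typing import List, Sequence


def labels_from_probabilities(probs: Sequence[Sequence[float]],
                              class_names: Sequence[str]) -> List[str]:
    """Map each probability row to the name of its highest-scoring class."""
    labels = []
    for row in probs:
        # Index of the largest value; ties resolve to the first index,
        # which matches the behaviour of np.argmax.
        best = max(range(len(row)), key=row.__getitem__)
        labels.append(class_names[best])
    return labels


preds = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]]
print(labels_from_probabilities(preds, ['a', 'b', 'c']))
```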

Expand All @@ -48,7 +50,9 @@ def predict_inductive(self, df, G, return_proba=False):

gen = self.preproc.preprocess(df, G)
gen.batch_size = self.batch_size
preds = self.model.predict_generator(gen)
# *_generator methods are deprecated from TF 2.1.0
#preds = self.model.predict_generator(gen)
preds = self.model.predict(gen)
result = preds if return_proba else [self.c[np.argmax(pred)] for pred in preds]
return result

Expand Down Expand Up @@ -81,7 +85,9 @@ def predict(self, G, edge_ids, return_proba=False):
"""
gen = self.preproc.preprocess(G, edge_ids)
gen.batch_size = self.batch_size
preds = self.model.predict_generator(gen)
# *_generator methods are deprecated from TF 2.1.0
#preds = self.model.predict_generator(gen)
preds = self.model.predict(gen)
preds = np.squeeze(preds)
if return_proba:
return [[1-pred, pred] for pred in preds]
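
For link prediction the model emits a single score per edge, and the `return_proba` branch above expands it into a two-column `[negative, positive]` row. A small sketch of that expansion, with a hypothetical helper name and sample scores chosen to be exactly representable in floating point:

```python
from typing import List, Sequence


def edge_probabilities(positive_scores: Sequence[float]) -> List[List[float]]:
    """Expand per-edge positive-class scores into [P(no link), P(link)] rows."""
    return [[1 - score, score] for score in positive_scores]


print(edge_probabilities([0.25, 0.5]))
```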
Expand Down
12 changes: 5 additions & 7 deletions ktrain/lroptimize/lrfinder.py
Expand Up @@ -105,11 +105,8 @@ def find(self, train_data, steps_per_epoch, use_gen=False,


if use_gen:
if version.parse(tf.__version__) < version.parse('2.0'):
fit_fn = self.model.fit_generator
else:
fit_fn = self.model.fit

# *_generator methods are deprecated from TF 2.1.0
fit_fn = self.model.fit
fit_fn(train_data, steps_per_epoch=steps_per_epoch,
epochs=epochs,
workers=workers, use_multiprocessing=use_multiprocessing,
Expand Down Expand Up @@ -151,16 +148,17 @@ def plot_loss(self, n_skip_beginning=10, n_skip_end=1, suggest=False):
# this code was adapted from fastai: https://github.com/fastai/fastai
try:
ml = np.argmin(self.losses)
mg = (np.gradient(np.array(self.losses[10:ml]))).argmin()
mg = (np.gradient(np.array(self.losses[32:ml]))).argmin()
except:
print("Failed to compute the gradients, there might not be enough points.\n" +\
"Plot displayed without suggestion.")
return
else:
print('Two possible suggestions for LR from plot:')
print(f"\tMin numerical gradient: {self.lrs[mg]:.2E}")
ax.plot(self.lrs[mg],self.losses[mg], markersize=10,marker='o',color='red')
print(f"\tMin loss divided by 10: {self.lrs[ml]/10:.2E}")
print(mg)
ax.plot(self.lrs[mg],self.losses[mg], markersize=10,marker='o',color='red')
return
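
The two suggestions printed in this hunk can be reproduced without NumPy. The sketch below is an illustration under stated assumptions, not the library's code: `numerical_gradient` mimics `np.gradient` for unit spacing, and `suggest_lrs` adds the `skip` offset back onto the gradient index before looking up the learning rate. That offset is an assumption of this sketch; the hunk indexes `self.lrs[mg]` with the slice-relative index directly, which may or may not be intended.

```python
from typing import List, Sequence, Tuple


def numerical_gradient(values: Sequence[float]) -> List[float]:
    """Central differences inside, one-sided differences at the two ends."""
    n = len(values)
    if n < 2:
        raise ValueError("need at least two points to compute a gradient")
    grad = []
    for i in range(n):
        if i == 0:
            grad.append(values[1] - values[0])
        elif i == n - 1:
            grad.append(values[-1] - values[-2])
        else:
            grad.append((values[i + 1] - values[i - 1]) / 2.0)
    return grad


def suggest_lrs(lrs: Sequence[float], losses: Sequence[float],
                skip: int = 0) -> Tuple[float, float]:
    """Return (LR at steepest loss descent, LR at minimum loss divided by 10)."""
    ml = min(range(len(losses)), key=losses.__getitem__)
    # Only look at the region between the skipped warm-up and the minimum.
    grad = numerical_gradient(losses[skip:ml])
    # Shift back by `skip` so the index refers to the full lrs/losses lists.
    mg = skip + min(range(len(grad)), key=grad.__getitem__)
    return lrs[mg], lrs[ml] / 10


# Made-up learning-rate sweep for illustration.
lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]
losses = [2.0, 1.9, 1.5, 0.8, 0.6, 3.0]
by_gradient, by_min_loss = suggest_lrs(lrs, losses, skip=1)
print(f"Min numerical gradient: {by_gradient:.2E}")
print(f"Min loss divided by 10: {by_min_loss:.2E}")
```

If too few points lie between `skip` and the minimum, `numerical_gradient` raises `ValueError`, which corresponds to the "not enough points" fallback in the hunk.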


Expand Down
19 changes: 16 additions & 3 deletions ktrain/tests/test_dataloading.py
Expand Up @@ -86,6 +86,15 @@ def images_from_csv():
return (trn, val, preproc)


def images_from_fname():
trn, val, preproc = vis.images_from_fname(
'./image_data/image_folder/all',
pattern=r'([^/]+)\.\d+.jpg$',
val_pct=0.25, random_state=42,
data_aug=vis.get_data_aug(horizontal_flip=True))
return (trn, val, preproc)
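
The `pattern` argument in this test helper is a regular expression whose first capture group is the class label. The sketch below applies the same pattern with the standard `re` module to show what it extracts. The filenames (`cat.12.jpg`, `dog.3.jpg`) are assumed examples, and `class_from_filename` is a hypothetical helper rather than part of *ktrain*.

```python
import re
from typing import Optional

# Same pattern as in the test helper above.
PATTERN = re.compile(r'([^/]+)\.\d+.jpg$')


def class_from_filename(path: str) -> Optional[str]:
    """Return the label captured from a path such as '.../cat.12.jpg'."""
    match = PATTERN.search(path)
    return match.group(1) if match else None


for path in ['./image_data/image_folder/all/cat.12.jpg',
             './image_data/image_folder/all/dog.3.jpg']:
    print(path, '->', class_from_filename(path))
```

Note that the `.` before `jpg` in the pattern is unescaped, so it matches any single character and not only a literal dot. Writing `\.jpg$` would be stricter.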



class TestTextData(TestCase):

Expand Down Expand Up @@ -202,14 +211,18 @@ def test_images_from_csv(self):
(trn, val, preproc) = images_from_csv()
self.__test_images(trn, val, preproc)

def test_images_from_fname(self):
(trn, val, preproc) = images_from_fname()
self.__test_images(trn, val, preproc, nsamples=20)


def __test_images(self, trn, val, preproc):
def __test_images(self, trn, val, preproc, nsamples=16):
self.assertTrue(U.is_iter(trn))
self.assertEqual(U.shape_from_data(trn), (224, 224, 3))
self.assertTrue(U.ondisk(trn))
self.assertEqual(U.nsamples_from_data(trn), 16)
self.assertEqual(U.nsamples_from_data(trn), nsamples)
self.assertEqual(U.nclasses_from_data(trn), 2)
self.assertEqual(U.y_from_data(trn).shape, (16,2))
self.assertEqual(U.y_from_data(trn).shape, (nsamples,2))
self.assertFalse(U.bert_data_tuple(trn))
self.assertEqual(preproc.get_classes(), ['cat', 'dog'])
(gen, steps) = preproc.preprocess('./image_data/image_folder/all')
Expand Down
26 changes: 17 additions & 9 deletions ktrain/tests/test_imageclassification.py
Expand Up @@ -43,6 +43,14 @@ def test_folder(self):
model = vis.image_classifier('pretrained_resnet50', trn, val)
learner = ktrain.get_learner(model=model, train_data=trn, val_data=val, batch_size=1)
learner.freeze()


# test weight decay
self.assertEqual(learner.get_weight_decay(), None)
learner.set_weight_decay(1e-2)
self.assertAlmostEqual(learner.get_weight_decay(), 1e-2)

# train
hist = learner.autofit(1e-3, monitor=VAL_ACC_NAME)

# test train
Expand All @@ -59,10 +67,6 @@ def test_folder(self):
else:
self.assertEqual(max(hist.history[VAL_ACC_NAME]), 1)

# test weight decay
self.assertEqual(learner.get_weight_decay(), None)
learner.set_weight_decay(1e-2)
self.assertAlmostEqual(learner.get_weight_decay(), 1e-2)

# test load and save model
learner.save_model('/tmp/test_model')
Expand Down Expand Up @@ -108,6 +112,14 @@ def test_csv(self):
model = vis.image_classifier('pretrained_resnet50', trn, val)
learner = ktrain.get_learner(model=model, train_data=trn, val_data=val, batch_size=4)
learner.freeze()

# test weight decay
self.assertEqual(learner.get_weight_decay(), None)
learner.set_weight_decay(1e-2)
self.assertAlmostEqual(learner.get_weight_decay(), 1e-2)


# train
hist = learner.fit_onecycle(lr, 3)

# test train
Expand All @@ -124,10 +136,6 @@ def test_csv(self):
else:
self.assertEqual(max(hist.history[VAL_ACC_NAME]), 1)

# test weight decay
self.assertEqual(learner.get_weight_decay(), None)
learner.set_weight_decay(1e-2)
self.assertAlmostEqual(learner.get_weight_decay(), 1e-2)

# test load and save model
learner.save_model('/tmp/test_model')
Expand Down Expand Up @@ -185,7 +193,7 @@ def test_array(self):
(trn, val, preproc) = vis.images_from_array(x_train, y_train,
validation_data=(x_test, y_test),
data_aug=data_aug,
classes=classes)
class_names=classes)

model = vis.image_classifier('default_cnn', trn, val)
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=128)
Expand Down
