Commit

Merge branch 'develop'
amaiya committed Jul 28, 2021
2 parents 73dd197 + 1aab650 commit 5edd86f
Showing 7 changed files with 79 additions and 6 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,19 @@ Most recent releases are shown at the top. Each release shows:
- **Changed**: Additional parameters, changes to inputs or outputs, etc
- **Fixed**: Bug fixes that don't change documented behaviour

## 0.27.2 (2021-07-28)

### New:
- N/A

### Changed:
- N/A

### Fixed:
- check for `logits` attribute when predicting with `transformers`
- change the raised exception to a warning for longer sequence lengths in `transformers`


## 0.27.1 (2021-07-20)

### New:
56 changes: 56 additions & 0 deletions FAQ.md
@@ -76,6 +76,8 @@

- [How do I speed up predictions?](#how-do-i-increase-batch-size-for-predictions)

- [How do I do cross validation with transformers?](#how-do-i-do-cross-validation-with-transformers)



---
@@ -1043,6 +1045,59 @@ The `get_learner` function accepts an `eval_batch_size` argument that will be used
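A minimal sketch of that usage, assuming `model`, `trn`, and `val` were built as in the transformer examples elsewhere in this FAQ (the batch-size values are arbitrary):

```python
import ktrain
# a larger eval_batch_size speeds up learner.validate() and learner.predict()
learner = ktrain.get_learner(model, train_data=trn, val_data=val,
                             batch_size=6, eval_batch_size=256)
```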



### How do I do cross validation with transformers?

Here is a quick self-contained example:

```python
from ktrain import text
import ktrain
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from sklearn.datasets import fetch_20newsgroups

# load text data (only the training split is needed for cross validation)
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
train_b = fetch_20newsgroups(subset='train', categories=categories, shuffle=True)
df = pd.DataFrame({'text': train_b.data,
                   'target': [train_b.target_names[y] for y in train_b.target]})

# CV with transformers
N_FOLDS = 2
EPOCHS = 3
LR = 5e-5

def transformer_cv(model_name):
    accs = []
    data = df[['text', 'target']]
    for train_index, val_index in KFold(N_FOLDS).split(data):
        # build a fresh preprocessor and classifier for each fold
        preproc = text.Transformer(model_name, maxlen=500)
        train, val = data.iloc[train_index], data.iloc[val_index]
        x_train, y_train = train.text.values, train.target.values
        x_val, y_val = val.text.values, val.target.values
        trn = preproc.preprocess_train(x_train, y_train)
        model = preproc.get_classifier()
        learner = ktrain.get_learner(model, train_data=trn, batch_size=16)
        learner.fit_onecycle(LR, EPOCHS)
        # score this fold on its held-out split
        predictor = ktrain.get_predictor(learner.model, preproc)
        pred = predictor.predict(x_val)
        acc = accuracy_score(y_val, pred)
        print('acc', acc)
        accs.append(acc)
    return accs

print(transformer_cv('distilbert-base-uncased'))
```
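Note that the `Transformer` preprocessor and classifier are instantiated inside the loop so that each fold trains a fresh model from scratch; reusing a model across folds would leak what was learned on earlier folds into later ones.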


[[Back to Top](#frequently-asked-questions-about-ktrain)]




### What kinds of applications have been built with *ktrain*?

@@ -1063,3 +1118,4 @@ Examples include:
[[Back to Top](#frequently-asked-questions-about-ktrain)]



4 changes: 2 additions & 2 deletions examples/tabular/causal_inference_example.ipynb
@@ -19,7 +19,7 @@
"\n",
"## What is the causal impact of having a PhD on making over 50K/year?\n",
"\n",
"As of v0.27.x, ktrain supports causal inference using [meta-learners](https://arxiv.org/abs/1706.03461). We will use the well-studied [Adults Census](https://raw.githubusercontent.com/amaiya/ktrain/master/ktrain/tests/tabular_data/adults.csv) dataset from the UCI ML repository, which is census data from the early to mid 1990s. The objective is to estimate how much earning a PhD increases the probability of making over $50K in salary. This dataset is simply being used as a simple demonstration example. Unlike conventional supervised machine learning, there is typically no ground truth for causal infernence models, unless you're using a simulated datasets. So, we will simply check our estimates to see if they agree with intuition for illustration purposes in addition to inspecting robustness.\n",
"As of v0.27.x, ktrain supports causal inference using [meta-learners](https://arxiv.org/abs/1706.03461). We will use the well-studied [Adults Census](https://raw.githubusercontent.com/amaiya/ktrain/master/ktrain/tests/tabular_data/adults.csv) dataset from the UCI ML repository, which is census data from the early to mid 1990s. The objective is to estimate how much earning a PhD increases the probability of making over $50K in salary. This dataset is simply being used as a simple demonstration example. Unlike conventional supervised machine learning, there is typically no ground truth for causal inference models, unless you're using a simulated dataset. So, we will simply check our estimates to see if they agree with intuition for illustration purposes in addition to inspecting robustness.\n",
"\n",
"Let's begin by loading the dataset and creating a binary treatment (1 for PhD and 0 for no PhD)."
]
@@ -237,7 +237,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's invoke the `causal_inference_model` function to create a `CausalInferenceModel` instance and invoke `fit` to estimate the individualized treatment effect for each row in this dataset. By default, a [T-Learner](https://arxiv.org/abs/1706.03461) metalearner is used with LightGBM models as base learners. These can be adjusted using the `method` and `learner` parameters. Since this example is simply being used for illustration purposes, we will ignore the `fnlwgt` column, which represents the number of people the census believes the entry represents. In practice, one might incorporate domain knowledge when choosing which variables to include and ignore. For instance, variables thought to be common effects of both the treatment and outcome might be excluded as [colliders](https://en.wikipedia.org/wiki/Collider_(statistics). Finally, we will also exclude the education-related columns as they are already captured in the treatment. "
"Next, let's invoke the `causal_inference_model` function to create a `CausalInferenceModel` instance and invoke `fit` to estimate the individualized treatment effect for each row in this dataset. By default, a [T-Learner](https://arxiv.org/abs/1706.03461) metalearner is used with LightGBM models as base learners. These can be adjusted using the `method` and `learner` parameters. Since this example is simply being used for illustration purposes, we will ignore the `fnlwgt` column, which represents the number of people the census believes the entry represents. In practice, one might incorporate domain knowledge when choosing which variables to include and ignore. For instance, variables thought to be common effects of both the treatment and outcome might be excluded as [colliders](https://en.wikipedia.org/wiki/Collider_(statistics)). Finally, we will also exclude the education-related columns as they are already captured in the treatment. "
]
},
{
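In condensed form, the workflow this notebook describes looks roughly like the sketch below. The column names and keyword arguments are assumptions inferred from the text above, not part of this commit:

```python
import pandas as pd
from ktrain.tabular import causal_inference_model

df = pd.read_csv('adults.csv')  # the UCI Adults Census data referenced above

# binary treatment: 1 for a PhD, 0 otherwise; binary outcome: earning over $50K
df['treatment'] = (df['education'].str.strip() == 'Doctorate').astype(int)
df['outcome'] = df['class'].str.contains('>50K').astype(int)

cm = causal_inference_model(df,
                            treatment_col='treatment',
                            outcome_col='outcome',
                            ignore_cols=['fnlwgt', 'class', 'education',
                                         'education-num']).fit()
print(cm.estimate_ate())  # overall average of the individualized effects
```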
4 changes: 3 additions & 1 deletion ktrain/text/learner.py
@@ -160,7 +160,9 @@ def predict(self, val_data=None):
if hasattr(val, 'reset'): val.reset()
classification, multilabel = U.is_classifier(self.model)
preds = self.model.predict(self._prepare(val, train=False))
- if type(preds).__name__ == 'TFSequenceClassifierOutput': # dep_fix: undocumented breaking change in transformers==4.0.0
+ if hasattr(preds, 'logits'): # dep_fix: breaking change in transformers==4.0.0 - also needed for Longformer
+ #if type(preds).__name__ == 'TFSequenceClassifierOutput': # dep_fix: undocumented breaking change in transformers==4.0.0
+ # REFERENCE: https://discuss.huggingface.co/t/new-model-output-types/195
preds = preds.logits

# dep_fix: transformers in TF 2.2.0 returns a tuple instead of NumPy array for some reason
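The fix hinges on newer `transformers` versions returning model-output objects rather than raw arrays from `predict`. A minimal sketch of that behavior, with an illustrative checkpoint and input:

```python
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
model = TFAutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased')
inputs = tokenizer(['a short example text'], return_tensors='tf')

preds = model(inputs)  # transformers>=4.0: a model-output object, not an array
if hasattr(preds, 'logits'):  # holds for TFSequenceClassifierOutput and for
    preds = preds.logits      # Longformer's differently named output class alike
```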
4 changes: 3 additions & 1 deletion ktrain/text/predictor.py
@@ -56,7 +56,9 @@ def predict(self, texts, return_proba=False):
tseq.batch_size = self.batch_size
tfd = tseq.to_tfdataset(train=False)
preds = self.model.predict(tfd)
- if type(preds).__name__ == 'TFSequenceClassifierOutput': # dep_fix: undocumented breaking change in transformers==4.0.0
+ if hasattr(preds, 'logits'): # dep_fix: breaking change - also needed for LongFormer
+ #if type(preds).__name__ == 'TFSequenceClassifierOutput': # dep_fix: undocumented breaking change in transformers==4.0.0
+ # REFERENCE: https://discuss.huggingface.co/t/new-model-output-types/195
preds = preds.logits

# dep_fix: transformers in TF 2.2.0 returns a tuple instead of NumPy array for some reason
2 changes: 1 addition & 1 deletion ktrain/text/preprocessor.py
@@ -859,7 +859,7 @@ def __init__(self, model_name,
lang='en', ngram_range=1, multilabel=None):
class_names = self.migrate_classes(class_names, classes)

- if maxlen > 512: raise ValueError('Transformer models only supports maxlen <= 512')
+ if maxlen > 512: warnings.warn('Transformer models typically only support maxlen <= 512, unless you are using certain models like the Longformer.')

super().__init__(maxlen, class_names, lang=lang, multilabel=multilabel)

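In practice, the relaxed check means long-input models can now be loaded with larger `maxlen` values. A minimal sketch, assuming a Longformer checkpoint (the model name and `maxlen` value are illustrative):

```python
from ktrain import text
# maxlen > 512 previously raised a ValueError; as of 0.27.2 it only warns,
# so long-input models such as Longformer can be used
preproc = text.Transformer('allenai/longformer-base-4096', maxlen=1024)
```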
2 changes: 1 addition & 1 deletion ktrain/version.py
@@ -1,2 +1,2 @@
__all__ = ['__version__']
- __version__ = '0.27.1'
+ __version__ = '0.27.2'
