Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed Jan 22, 2020
2 parents bef14dd + b93823b commit 55e00d8
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 24 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ Most recent releases are shown at the top. Each release shows:
- **Fixed**: Bug fixes that don't change documented behaviour


## 0.8.3 (2020-01-22)

### New:
- N/A

### Changed:
- `MultiArrayDataset` accepts list of Numpy arrays

### Fixed:
- fixed incorrect activation in `TextPredictor` for multi-label Transformer models
- fixed `top_losses` for regression tasks


## 0.8.2 (2020-01-19)

### New:
Expand Down
19 changes: 14 additions & 5 deletions ktrain/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,13 @@ def top_losses(self, n=4, val_data=None, preproc=None):
y_true = self.ground_truth(val_data=val)
y_true = y_true.astype('float32')

# compute loss
# adjust y_true for regression problems
if not classification and len(y_true.shape) == 1 and\
(len(y_pred.shape) == 2 and y_pred.shape[1] == 1):
y_true = np.expand_dims(y_true, -1)


# compute loss
# this doesn't work in tf.keras 1.14
#losses = self.model.loss_functions[0](tf.convert_to_tensor(y_true), tf.convert_to_tensor(y_pred))
if U.is_tf_keras():
Expand All @@ -199,10 +204,14 @@ def top_losses(self, n=4, val_data=None, preproc=None):
else:
class_fcn = lambda x:class_names[x]

# don't put regression predictions in a list
if not classification and len(y_pred.shape) == 2 and y_pred.shape[1] == 1:
y_pred = np.squeeze(y_pred)
y_pred = np.around(y_pred, 2)
# regression output modifications
if not classification:
if len(y_pred.shape) == 2 and y_pred.shape[1] == 1:
y_pred = np.squeeze(y_pred)
y_pred = np.around(y_pred, 2)
if len(y_true.shape) == 2 and y_true.shape[1] == 1:
y_true = np.squeeze(y_true)
y_true = np.around(y_true, 2)

# sort by loss and prune correct classifications, if necessary
if classification and not multilabel:
Expand Down
41 changes: 26 additions & 15 deletions ktrain/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,28 @@ def nclasses(self):

class MultiArrayDataset(Dataset):
def __init__(self, x, y, batch_size=32):
if type(x) != np.ndarray or type(y) != np.ndarray:
raise ValueError('x and y must be numpy arrays')
if len(x.shape) != 3:
raise valueError('x must have 3 dimensions')
# error checks
err = False
if type(x) == np.ndarray and len(x.shape) != 2: err = True
elif type(x) == list:
for d in x:
if type(d) != np.ndarray or len(d.shape) != 2:
err = True
break
else: err = True
if err:
raise ValueError('x must be a 2d numpy array or a list of 2d numpy arrays')
if type(y) != np.ndarray:
raise ValueError('y must be a numpy array')
if type(x) == np.ndarray:
x = [x]

# set variables
super().__init__(batch_size=batch_size)
self.x, self.y = x, y
self.indices = np.arange(self.x[0].shape[0])
self.n_inputs = x.shape[0]
self.n_inputs = len(x)


def __len__(self):
return math.ceil(self.x[0].shape[0] / self.batch_size)
Expand All @@ -61,24 +75,21 @@ def __getitem__(self, idx):
batch_y = self.y[inds]
return tuple(batch_x), batch_y

def nsamples(self):
return self.x[0].shape[0]

def get_y(self):
return self.y

def on_epoch_end(self):
np.random.shuffle(self.indices)

def xshape(self):
return self.x.shape

def nsamples(self):
if self.n_inputs == 1:
return self.x.shape[0]
else:
return self.x.shape[1]
return self.x[0].shape

def nclasses(self):
return self.y.shape[1]

def get_y(self):
return self.y

def ondisk(self):
return False

9 changes: 7 additions & 2 deletions ktrain/text/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,13 @@ def predict(self, texts, return_proba=False):
preds = self.model.predict(texts)
if U.is_huggingface(model=self.model):
# convert logits to probabilities for Hugging Face models
preds = activations.softmax(tf.convert_to_tensor(preds)).numpy()
result = preds if return_proba else [self.c[np.argmax(pred)] for pred in preds]
if multilabel and self.c:
preds = activations.sigmoid(tf.convert_to_tensor(preds)).numpy()
elif self.c:
preds = activations.softmax(tf.convert_to_tensor(preds)).numpy()
else:
preds = np.squeeze(preds)
result = preds if return_proba or not self.c else [self.c[np.argmax(pred)] for pred in preds]
if multilabel:
result = [list(zip(self.c, r)) for r in result]
if is_str: return result[0]
Expand Down
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ['__version__']
__version__ = '0.8.2'
__version__ = '0.8.3'
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
'seqeval',
'packaging',
'tensorflow_datasets',
'transformers'
'transformers',
'ipython'
#'stellargraph>=0.8.2',
#'eli5 >= 0.10.0',
#'pillow'
Expand Down

0 comments on commit 55e00d8

Please sign in to comment.