Open
Description
Solved, but this could benefit from an improved error message. After adding a transform that reduces the number of tasks, .predict_on_dataset() raises a ValueError about mismatched shapes when return_df=True. Returning an array works fine, as does .predict_on_seqs().
transform = Specificity(
    on_tasks=[name],
    on_aggfunc="min",
    off_tasks=[x for x in cell_types if x != name],
    off_aggfunc="max",
    model=model,
)
model.add_transform(transform)
seqs_ds = grelu.data.dataset.SeqDataset(seqs)
probs = model.predict_on_dataset(seqs_ds, devices=0, num_workers=7, return_df=True)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[39], line 1
----> 1 probs = model.predict_on_dataset(seqs_ds, devices=0, num_workers=7, return_df=True)
File ~/scratch/conda/envs/grelu_stable/lib/python3.11/site-packages/grelu/lightning/__init__.py:782, in LightningModel.predict_on_dataset(self, dataset, devices, num_workers, batch_size, augment_aggfunc, return_df, precision)
780 if return_df:
781 if (preds.ndim == 3) and (preds.shape[-1] == 1):
--> 782 preds = pd.DataFrame(
783 preds.squeeze(-1), columns=self.data_params["tasks"]["name"]
784 )
785 else:
786 warnings.warn(
787 "Cannot produce dataframe output."
788 + "Either output length > 1 or augmented sequences are not aggregated."
789 )
File ~/scratch/conda/envs/grelu_stable/lib/python3.11/site-packages/pandas/core/frame.py:831, in DataFrame.__init__(self, data, index, columns, dtype, copy)
820 mgr = dict_to_mgr(
821 # error: Item "ndarray" of "Union[ndarray, Series, Index]" has no
822 # attribute "name"
(...) 828 copy=_copy,
829 )
830 else:
--> 831 mgr = ndarray_to_mgr(
832 data,
833 index,
834 columns,
835 dtype=dtype,
836 copy=copy,
837 typ=manager,
838 )
840 # For data is list-like, or Iterable (will consume into list)
841 elif is_list_like(data):
File ~/scratch/conda/envs/grelu_stable/lib/python3.11/site-packages/pandas/core/internals/construction.py:336, in ndarray_to_mgr(values, index, columns, dtype, copy, typ)
331 # _prep_ndarraylike ensures that values.ndim == 2 at this point
332 index, columns = _get_axes(
333 values.shape[0], values.shape[1], index=index, columns=columns
334 )
--> 336 _check_values_indices_shape_match(values, index, columns)
338 if typ == "array":
339 if issubclass(values.dtype.type, str):
File ~/scratch/conda/envs/grelu_stable/lib/python3.11/site-packages/pandas/core/internals/construction.py:420, in _check_values_indices_shape_match(values, index, columns)
418 passed = values.shape
419 implied = (len(index), len(columns))
--> 420 raise ValueError(f"Shape of passed values is {passed}, indices imply {implied}")
ValueError: Shape of passed values is (5078, 1), indices imply (5078, 8)
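As a rough illustration of what a clearer message could look like, here is a minimal sketch of a check that could run before the dataframe is built. It assumes the prediction array has shape (sequences, tasks, 1), as implied by the `preds.ndim == 3` branch in the traceback, and that the column names come from `self.data_params["tasks"]["name"]`. The helper name `check_task_columns` and the placeholder task names are hypothetical and are not part of gReLU.

```python
def check_task_columns(preds_shape, task_names):
    """Raise a descriptive error when predictions and task names disagree.

    preds_shape: shape of the prediction array, e.g. (5078, 1, 1)
    task_names: column names intended for the output dataframe
    """
    n_pred_tasks = preds_shape[1]
    n_names = len(task_names)
    if n_pred_tasks != n_names:
        raise ValueError(
            f"Predictions contain {n_pred_tasks} task(s) but {n_names} task "
            "name(s) are available. A transform may have changed the number "
            "of tasks; use return_df=False to get an array instead."
        )


# Shapes taken from the error above: 1 task after the transform, 8 task names.
task_names = [f"task_{i}" for i in range(8)]  # placeholder names (assumption)
try:
    check_task_columns((5078, 1, 1), task_names)
except ValueError as err:
    message = str(err)
    print(message)
```

With a check like this, the failure would point at the transform rather than at pandas internals. Whether to raise, warn, or fall back to generic column names is a design choice for the maintainers.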