Skip to content

Commit

Permalink
squeeze() all the tensors in to_table
Browse files Browse the repository at this point in the history
The to_table method now squeezes all the tensors before appending them to the astropy table
  • Loading branch information
sundarjhu committed Sep 1, 2023
1 parent dd5f5d9 commit 47bc847
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions pgmuvi/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,31 +2016,31 @@ def to_table(self):
if self.__FITTED_MCMC or self.__FITTED_MAP:
# These outputs can only be produced if a fit has been run.
periods, weights, scales = self.get_periods()
t['period'] = [np.asarray(periods)]
t['period'] = [np.asarray(periods).squeeze()]
try:
t['weights'] = [np.asarray(weights)]
t['weights'] = [np.asarray(weights).squeeze()]
except RuntimeError:
t['weights'] = [torch.as_tensor(weights).detach().numpy()]
t['weights'] = [torch.as_tensor(weights).detach().numpy().squeeze()]
try:
t['scales'] = [np.asarray(scales)]
t['scales'] = [np.asarray(scales).squeeze()]
except RuntimeError:
t['scales'] = [torch.as_tensor(scales).detach().numpy()]
t['scales'] = [torch.as_tensor(scales).detach().numpy().squeeze()]
for key, value in self.results.items():
try:
t[key] = [np.asarray(value)]
t[key] = [np.asarray(value).squeeze()]
except RuntimeError:
t[key] = [torch.as_tensor(value).detach().numpy()]
t[key] = [torch.as_tensor(value).detach().numpy().squeeze()]
if self.__FITTED_MAP:
# Loss isn't relevant for MCMC, I think
t['loss'] = [np.asarray(self.results['loss'])]
t['loss'] = [np.asarray(self.results['loss']).squeeze()]
# Now we want the model predictions for the input times:
if self.__FITTED_MAP:
self._eval()
with torch.no_grad():
observed_pred = self.likelihood(self.model(self._xdata_transformed))
t['y_pred_mean_obs'] = [np.asarray(observed_pred.mean)]
t['y_pred_lower_obs'] = [np.asarray(observed_pred.confidence_region()[0])] # noqa: E501
t['y_pred_upper_obs'] = [np.asarray(observed_pred.confidence_region()[1])] # noqa: E501
t['y_pred_mean_obs'] = [np.asarray(observed_pred.mean).squeeze()]
t['y_pred_lower_obs'] = [np.asarray(observed_pred.confidence_region()[0]).squeeze()] # noqa: E501
t['y_pred_upper_obs'] = [np.asarray(observed_pred.confidence_region()[1]).squeeze()] # noqa: E501

if self.ndim == 1:
x_raw = self.xdata
Expand All @@ -2057,10 +2057,10 @@ def to_table(self):

# Make predictions
observed_pred = self.likelihood(self.model(x_fine_transformed))
t['x_fine'] = [np.asarray(x_fine_raw)]
t['y_pred_mean'] = [np.asarray(observed_pred.mean)]
t['y_pred_lower'] = [np.asarray(observed_pred.confidence_region()[0])] # noqa: E501
t['y_pred_upper'] = [np.asarray(observed_pred.confidence_region()[1])] # noqa: E501
t['x_fine'] = [np.asarray(x_fine_raw).squeeze()]
t['y_pred_mean'] = [np.asarray(observed_pred.mean).squeeze()]
t['y_pred_lower'] = [np.asarray(observed_pred.confidence_region()[0]).squeeze()] # noqa: E501
t['y_pred_upper'] = [np.asarray(observed_pred.confidence_region()[1]).squeeze()] # noqa: E501
elif self.__FITTED_MCMC:
raise NotImplementedError("MCMC predictions not yet implemented")
# with torch.no_grad():
Expand Down

0 comments on commit 47bc847

Please sign in to comment.