Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IGNITE-5217: Add Gradient Descent and QR-based trainers for Linear Regression #3308

Closed
wants to merge 20 commits into from

Conversation

dmitrievanthony
Copy link
Contributor

No description provided.

Copy link
Member

@zaleslaw zaleslaw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should think about OLS, QR, SGD model-trainer hierarchy

* @param model linear regression model
* @return formatted string representation
*/
private static String formatLinearRegressionModelPrettyPrint(LinearRegressionModel model) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to add into LinearRegressionModel. I'd like the pretty print. Also you could add constraint in 30 or 50 vars to print out this model.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added into toString of LinearRegressionModel.

System.out.println(">>> ---------------------------------");
System.out.println(">>> | Prediction\t| Ground Truth\t|");
System.out.println(">>> ---------------------------------");
for (double[] observation : data) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The best way here to use LabeledDatasets (but it can be a separate ticket)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, will be done in separate task.

* @return This gradient descent instance
*/
public GradientDescent withMaxIterations(int maxIterations) {
if (maxIterations < 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly we use asserts to check input arguments, Let's discuss it with @ybabak @YuriBabak

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced with asserts.

if (convergenceTol < 0)
throw new IllegalArgumentException("Convergence tolerance must be non-negative but got " + convergenceTol);
this.convergenceTol = convergenceTol;
return this;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly we don't use chain and use only setters. Let's discuss it with @ybabak @YuriBabak
I'd like chain

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed will keep this approach with chain methods.

*/
public Vector optimize(Matrix data, Vector initialWeights) {
Vector weights = initialWeights, oldWeights = null, oldGradient = null;
if (data instanceof SparseDistributedMatrix) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we refactor it to one method to avoid copy-paste?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, done.

long ts1 = System.currentTimeMillis();
Vector groundTruth = extractGroundTruth(data);
Matrix inputs = extractInputs(data);
long ts2 = System.currentTimeMillis();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget remove these timestamps

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

83.6427296424, 27.4571268153, 73.5881193584, 27.1465364511, 79.4095449062}, -5.14077007134);

/** */
public static class Dataset {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this class to TestDataset or something else. Or use LabeledDataset after update

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed.


// Check expected residuals from R
mdl.estimateResiduals();
LinearRegressionSGDTrainer trainer = new LinearRegressionSGDTrainer(100_000, 1e-12);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this was about OLS Regression and you use SGDTrainer. It's strange

Copy link
Contributor Author

@dmitrievanthony dmitrievanthony Dec 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally speaking SGDTrainer also makes OLS Regression just because it uses Least Square as a loss function. Anyway, I replaced it with QR trainer for better transparency.

20.6, 4.83567, 0.0, 18.1, 0.0, 0.583, 5.905, 53.2, 3.1523, 24.0, 666.0, 20.2, 388.22, 11.45
15.2, 0.15086, 0.0, 27.74, 0.0, 0.609, 5.454, 92.7, 1.8209, 4.0, 711.0, 20.1, 395.09, 18.06
7.0, 0.18337, 0.0, 27.74, 0.0, 0.609, 5.414, 98.3, 1.7554, 4.0, 711.0, 20.1, 344.05, 23.97
8.1, 0.20746, 0.0, 27.74, 0.0, 0.609, 5.093, 98.0, 1.8226, 4.0, 711.0, 20.1, 318.43, 29.68
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget add readme file with short description about attributes, licenses or link to the dataset repository

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

README.txt with dataset descriptions added.

SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(data);

System.out.println(">>> Create new linear regression trainer object.");
Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionQRTrainer();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OLS could be solved not by QR decomposition only.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added separate examples for QRTrainer and SGDTrainer.

@asfgit asfgit closed this in b206085 Dec 28, 2017
@dmitrievanthony dmitrievanthony deleted the ignite-5217 branch February 6, 2018 07:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants