IGNITE-5217: Add Gradient Descent and QR-based trainers for Linear Regression #3308
datasets, add BarzilaiBorwein gradient descent updater.
We should think about the OLS, QR, SGD model-trainer hierarchy.
 * @param model linear regression model
 * @return formatted string representation
 */
private static String formatLinearRegressionModelPrettyPrint(LinearRegressionModel model) {
Better to move this into LinearRegressionModel. I like the pretty print. You could also cap the output at 30 or 50 variables when printing the model.
Added to toString of LinearRegressionModel.
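A capped pretty print along the lines suggested in this thread might look roughly like the sketch below. The class `ModelFormatter`, the method `format`, and the `maxVars` parameter are hypothetical names for illustration; they are not taken from the PR, and the real `toString` of LinearRegressionModel may be formatted differently.

```java
import java.util.Locale;

/** Hypothetical helper for illustration only; not the PR's actual code. */
final class ModelFormatter {
    /**
     * Formats a linear model as "b + w0*x0 - w1*x1 ...",
     * printing at most maxVars weights (the cap is an assumed parameter).
     */
    static String format(double intercept, double[] weights, int maxVars) {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format(Locale.ROOT, "%.4f", intercept));

        int shown = Math.min(weights.length, maxVars);
        for (int i = 0; i < shown; i++) {
            double w = weights[i];
            sb.append(w < 0 ? " - " : " + ")
              .append(String.format(Locale.ROOT, "%.4f", Math.abs(w)))
              .append("*x").append(i);
        }

        // Summarize whatever was cut off instead of printing every weight.
        if (weights.length > shown)
            sb.append(" + ... (").append(weights.length - shown).append(" more)");

        return sb.toString();
    }
}
```

Calling `ModelFormatter.format(intercept, weights, 50)` would keep wide models readable in logs while leaving small models fully printed.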
System.out.println(">>> ---------------------------------");
System.out.println(">>> | Prediction\t| Ground Truth\t|");
System.out.println(">>> ---------------------------------");
for (double[] observation : data) {
The best approach here is to use LabeledDatasets (but that can be a separate ticket).
Good idea, will be done in a separate task.
 * @return This gradient descent instance
 */
public GradientDescent withMaxIterations(int maxIterations) {
    if (maxIterations < 0)
We mostly use asserts to check input arguments. Let's discuss it with @ybabak @YuriBabak
Replaced with asserts.
if (convergenceTol < 0)
    throw new IllegalArgumentException("Convergence tolerance must be non-negative but got " + convergenceTol);
this.convergenceTol = convergenceTol;
return this;
We mostly don't use method chaining and use only setters. Let's discuss it with @ybabak @YuriBabak
Personally, I like the chaining.
As discussed, we will keep this approach with chained methods.
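For readers following the two decisions in this thread (asserts for argument checks, chained `with...` methods instead of plain setters), a minimal sketch of how they combine is shown below. The class name `GradientDescentSettings`, the default values, the getters, and `withConvergenceTol` are assumptions for illustration; only `withMaxIterations` and the `convergenceTol` field appear in the reviewed diff. Note that Java asserts are only evaluated when the JVM runs with `-ea`.

```java
/** Hypothetical settings holder; only withMaxIterations echoes the reviewed code. */
final class GradientDescentSettings {
    // Default values are assumptions for this sketch.
    private int maxIterations = 1000;
    private double convergenceTol = 1e-8;

    /** Sets the iteration cap and returns this instance so calls can be chained. */
    GradientDescentSettings withMaxIterations(int maxIterations) {
        // Checked only when assertions are enabled (java -ea).
        assert maxIterations >= 0 : "Max iterations must be non-negative but got " + maxIterations;
        this.maxIterations = maxIterations;
        return this;
    }

    /** Sets the convergence tolerance and returns this instance so calls can be chained. */
    GradientDescentSettings withConvergenceTol(double convergenceTol) {
        assert convergenceTol >= 0 : "Convergence tolerance must be non-negative but got " + convergenceTol;
        this.convergenceTol = convergenceTol;
        return this;
    }

    int maxIterations() { return maxIterations; }

    double convergenceTol() { return convergenceTol; }
}
```

Usage would then read as a single expression, for example `new GradientDescentSettings().withMaxIterations(100_000).withConvergenceTol(1e-12)`. The trade-off against `IllegalArgumentException` is that invalid arguments pass silently in production runs where assertions are disabled.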
 */
public Vector optimize(Matrix data, Vector initialWeights) {
    Vector weights = initialWeights, oldWeights = null, oldGradient = null;
    if (data instanceof SparseDistributedMatrix) {
Could we refactor it to one method to avoid copy-paste?
Good idea, done.
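As background on the `oldWeights` and `oldGradient` variables in `optimize`: a Barzilai-Borwein updater, as named in the PR description, derives its step size from the two most recent iterates and gradients, which is presumably why both are kept. The sketch below assumes the "long" Barzilai-Borwein step, step = (s·s)/(s·y) with s = w - oldW and y = g - oldG. The class name, the plain `double[]` arguments, and the `fallback` parameter are hypothetical; the PR's updater works on Ignite `Vector` objects and may use a different variant.

```java
/** Hypothetical sketch of a Barzilai-Borwein step size; not the PR's implementation. */
final class BarzilaiBorweinStep {
    /**
     * Computes the "long" BB step (s.s)/(s.y), where s is the change in weights
     * and y is the change in gradient between two consecutive iterations.
     * Returns the fallback when s.y is not positive, since the ratio is then unusable.
     */
    static double stepSize(double[] w, double[] oldW, double[] g, double[] oldG, double fallback) {
        double ss = 0.0;
        double sy = 0.0;
        for (int i = 0; i < w.length; i++) {
            double s = w[i] - oldW[i];
            double y = g[i] - oldG[i];
            ss += s * s;
            sy += s * y;
        }
        return sy > 0.0 ? ss / sy : fallback;
    }
}
```

On the first iteration there is no previous iterate, so a caller would need some initial step size before switching to this rule.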
long ts1 = System.currentTimeMillis();
Vector groundTruth = extractGroundTruth(data);
Matrix inputs = extractInputs(data);
long ts2 = System.currentTimeMillis();
Don't forget to remove these timestamps.
Removed.
83.6427296424, 27.4571268153, 73.5881193584, 27.1465364511, 79.4095449062}, -5.14077007134);

/** */
public static class Dataset {
Please rename this class to TestDataset or something similar, or use LabeledDataset after the update.
Renamed.
// Check expected residuals from R
mdl.estimateResiduals();
LinearRegressionSGDTrainer trainer = new LinearRegressionSGDTrainer(100_000, 1e-12);
But this was about OLS regression, and you use SGDTrainer. That's strange.
Generally speaking, SGDTrainer also performs OLS regression, simply because it uses least squares as its loss function. Anyway, I replaced it with the QR trainer for better transparency.
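The point that a gradient-based trainer with a least-squares loss arrives at the OLS solution can be checked on a toy problem. The sketch below is self-contained and does not use the PR's trainers: it fits a one-variable model with the closed-form OLS formulas and with full-batch gradient descent (not stochastic, to keep the result deterministic). All names, the learning rate, and the iteration count are assumptions for illustration.

```java
/** Toy comparison of closed-form OLS and gradient descent on squared error. */
final class OlsVsGradientDescent {
    /** Closed-form OLS for y = slope * x + intercept; returns {slope, intercept}. */
    static double[] closedForm(double[] x, double[] y) {
        int n = x.length;
        double meanX = 0.0, meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        double sxy = 0.0, sxx = 0.0;
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        double slope = sxy / sxx;
        return new double[] {slope, meanY - slope * meanX};
    }

    /** Full-batch gradient descent on the mean squared error; returns {slope, intercept}. */
    static double[] gradientDescent(double[] x, double[] y, double learningRate, int iterations) {
        int n = x.length;
        double slope = 0.0, intercept = 0.0;
        for (int it = 0; it < iterations; it++) {
            double gradSlope = 0.0, gradIntercept = 0.0;
            for (int i = 0; i < n; i++) {
                double err = slope * x[i] + intercept - y[i];
                gradSlope += err * x[i];
                gradIntercept += err;
            }
            // Step against the averaged gradient of the squared-error loss.
            slope -= learningRate * gradSlope / n;
            intercept -= learningRate * gradIntercept / n;
        }
        return new double[] {slope, intercept};
    }
}
```

For the points (1, 3), (2, 5), (3, 7), (4, 9), both methods should land on slope 2 and intercept 1 when gradient descent runs with a learning rate of 0.05 for 10,000 iterations. The two differ in cost and numerical behavior, not in the objective they minimize.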
20.6, 4.83567, 0.0, 18.1, 0.0, 0.583, 5.905, 53.2, 3.1523, 24.0, 666.0, 20.2, 388.22, 11.45
15.2, 0.15086, 0.0, 27.74, 0.0, 0.609, 5.454, 92.7, 1.8209, 4.0, 711.0, 20.1, 395.09, 18.06
7.0, 0.18337, 0.0, 27.74, 0.0, 0.609, 5.414, 98.3, 1.7554, 4.0, 711.0, 20.1, 344.05, 23.97
8.1, 0.20746, 0.0, 27.74, 0.0, 0.609, 5.093, 98.0, 1.8226, 4.0, 711.0, 20.1, 318.43, 29.68
Don't forget to add a README file with a short description of the attributes, the licenses, or a link to the dataset repository.
README.txt with dataset descriptions added.
SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(data);

System.out.println(">>> Create new linear regression trainer object.");
Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionQRTrainer();
OLS can be solved by methods other than QR decomposition.
Added separate examples for QRTrainer and SGDTrainer.
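For reference, the routes to the same least-squares solution discussed in this thread can be written side by side. This is standard linear algebra rather than a description of the PR's code. It assumes X has full column rank, uses the thin QR factorization, and absorbs the factor 2 of the gradient into the step size α.

```latex
\begin{aligned}
\hat{w} &= \arg\min_{w} \lVert Xw - y \rVert_2^2 \\
\text{Normal equations:}\quad & X^\top X \hat{w} = X^\top y \\
\text{QR, with } X = QR:\quad & R \hat{w} = Q^\top y \\
\text{Gradient descent:}\quad & w_{k+1} = w_k - \alpha X^\top (X w_k - y)
\end{aligned}
```

The QR route solves a triangular system by back substitution, while gradient descent only approaches the solution as the iterations proceed, which is why an example comparing trainers typically needs a tolerance.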