-
Notifications
You must be signed in to change notification settings - Fork 0
/
gradient-descent.ts
109 lines (92 loc) · 2.75 KB
/
gradient-descent.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Hypothesis being fitted: h(x) = theta0 + theta1 * x
/** A two-dimensional sample: an input value `x` paired with its observed output `y`. */
interface point {
    /** Input (independent) value. */
    x: number;
    /** Observed (dependent) value. */
    y: number;
}
/**
 * Fits the line h(x) = theta0 + theta1 * x to a set of points using batch
 * gradient descent on the halved mean squared error.
 *
 * The constructor generates a random training set (logging every point);
 * `start()` then runs the optimisation, logging the parameters after each
 * update. The fitted parameters can be read back through `intercept`,
 * `slope` and `converged`.
 */
class GradientDescent {
    /** Number of random training points generated by the constructor. */
    private sampleSize = 30;
    /** Step size applied to each gradient update. */
    private learningRate = 0.01;
    /** Maximum number of parameter updates `start()` performs. */
    private maxIteration = 1000;
    /** Points the line is fitted to. */
    private trainingData: point[] = [];
    /** True once an update changed the cost by less than `threshold`. */
    private isConverged = false;
    /** True once the cost after an update is no longer a finite number. */
    private hasDiverged = false;
    /** Intercept of the fitted line. */
    private theta0 = 0;
    /** Slope of the fitted line. */
    private theta1 = 0;
    /** Absolute cost change below which the fit counts as converged. */
    private threshold = 0.0001;

    /**
     * @param sampleSize Number of random training points to generate.
     * @param learningRate Step size applied to each gradient update.
     */
    constructor(sampleSize = 30, learningRate = 0.01) {
        this.sampleSize = sampleSize;
        this.learningRate = learningRate;
        this.init();
    }

    /** Intercept (theta0) of the current fit. */
    public get intercept(): number {
        return this.theta0;
    }

    /** Slope (theta1) of the current fit. */
    public get slope(): number {
        return this.theta1;
    }

    /** Whether the last update changed the cost by less than the threshold. */
    public get converged(): boolean {
        return this.isConverged;
    }

    /** Builds the training set used by `start()`. */
    private init(): void {
        this.prepareTrainingData();
    }

    /**
     * Runs gradient descent until the fit converges, the cost becomes
     * non-finite (divergence), or `maxIteration` updates have been made.
     * Does nothing when there is no training data.
     */
    public start(): void {
        // With no points every mean below would be 0 / 0 = NaN, so there is
        // nothing meaningful to fit.
        if (this.trainingData.length === 0) {
            return;
        }
        let iteration = 1;
        // `<=` so that exactly `maxIteration` updates are allowed.
        while (!this.isConverged && !this.hasDiverged && iteration <= this.maxIteration) {
            console.log('\niteration:' + iteration);
            this.iterate();
            iteration++;
        }
    }

    /** Performs one simultaneous update of both parameters and refreshes the stop flags. */
    private iterate(): void {
        const costBefore = this.J();
        // Both gradients are evaluated with the old parameters before either
        // parameter is overwritten.
        const nextTheta0 = this.theta0 - (this.learningRate * this.derivativeTheta0());
        const nextTheta1 = this.theta1 - (this.learningRate * this.derivativeTheta1());
        this.theta0 = nextTheta0;
        this.theta1 = nextTheta1;
        console.log("theta0:" + this.theta0);
        console.log("theta1:" + this.theta1);
        const costAfter = this.J();
        this.hasDiverged = !Number.isFinite(costAfter);
        // Compare the magnitude of the change: a rising cost (negative
        // difference) is divergence, not convergence. A NaN/Infinity
        // difference compares false and therefore never counts as converged.
        this.isConverged = Math.abs(costBefore - costAfter) < this.threshold;
    }

    /** Cost function: sum of squared prediction errors divided by twice the number of points. */
    private J(): number {
        const total = this.trainingData.reduce((sum, sample) => sum + this.squaredError(sample), 0);
        return total / (2 * this.trainingData.length);
    }

    /** Squared prediction error for one point. */
    private squaredError(point: point): number {
        return Math.pow(this.predictionError(point), 2);
    }

    /** Signed prediction error h(x) - y for one point. */
    private predictionError(point: point): number {
        return this.h(point.x) - point.y;
    }

    /** Hypothesis: value of the current line at `x`. */
    private h(x: number): number {
        return this.theta1 * x + this.theta0;
    }

    /** Partial derivative of the cost with respect to theta0 (mean prediction error). */
    private derivativeTheta0(): number {
        const total = this.trainingData.reduce((sum, sample) => sum + this.predictionError(sample), 0);
        return total / this.trainingData.length;
    }

    /** Partial derivative of the cost with respect to theta1 (mean of error times x). */
    private derivativeTheta1(): number {
        const total = this.trainingData.reduce(
            (sum, sample) => sum + this.predictionError(sample) * sample.x,
            0
        );
        return total / this.trainingData.length;
    }

    /** Appends `sampleSize` random points to the training set. */
    private prepareTrainingData(): void {
        for (let i = 0; i < this.sampleSize; i++) {
            this.trainingData.push(this.getRandomPoint(10));
        }
    }

    /**
     * Creates one noisy point near the line y = 0.5 * x + 2.5 and logs it.
     *
     * @param max Upper bound for x; x is a random integer in [0, max].
     * @returns The generated point.
     */
    private getRandomPoint(max: number): point {
        const slope = 0.5;
        const intercept = 2.5;
        // The noise is uniform in [0, noiseAmplitude), so it is never
        // negative and shifts y upwards on average.
        const noiseAmplitude = 0.9;
        const x = Math.round(Math.random() * max);
        const y = slope * x + intercept + Math.random() * noiseAmplitude;
        console.log('data points');
        console.log(x + " " + y);
        return { x, y };
    }
}
// Demo entry point: build a model on freshly generated random data and fit it.
{
    const model = new GradientDescent();
    model.start();
}