Skip to content

Commit

Permalink
ND4Jでのベクトル生成サンプルコード、乱数も実装
Browse files Browse the repository at this point in the history
  • Loading branch information
takk committed Feb 19, 2019
1 parent 561db59 commit c274c7e
Showing 1 changed file with 65 additions and 7 deletions.
72 changes: 65 additions & 7 deletions PracticeJava1/src/jp/zenryoku/ml/nd4j/FirstNd4j.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jp.zenryoku.ml.nd4j;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cpu.nativecpu.NDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
Expand All @@ -10,6 +11,70 @@
*/
public class FirstNd4j {
public static void main(String[] args) {
// クラスのインスタンス化
FirstNd4j test = new FirstNd4j();

INDArray zeros5 = Nd4j.zeros(5);
System.out.print("Nd4j.zeros(5) -> ");
System.out.println(zeros5);
// 自分で作成したメソッドを使います。
INDArray zero5_3 = test.createZeroINDArray(5, 3);
System.out.print("Nd4j.zeros(5, 3) -> ");
System.out.println(zero5_3);

// 生成したNDArrayを3で埋める
INDArray add3 = zero5_3.add(3);
// 自作メソッドで表示します。
printArray("INDArray.add(3)", add3);

System.out.println("*** 0-1の間で乱数を生成する(2次元配列) ***");
printArray("Nd4j.rand(2, 3)", Nd4j.rand(2, 3));

System.out.println("*** 3次元以上 ***");
printArray("Nd4j.rand(new int[] {1, 2, 3}])", Nd4j.rand(new int[] {1, 2, 3}));

System.out.println("*** 平均ゼロ、標準偏差1のガウス乱数を生成する ***");
printArray("Nd4j.randn(2, 3)", Nd4j.randn(2, 3));
printArray("Nd4j.randn(new int[] {1, 2, 3}])", Nd4j.randn(new int[] {1, 2, 3}));
}

/**
* 制定メソッドです。
* このクラスを起動するときに他の読み込むクラスに同じメソッドは定義できません。
* 実行時エラーになります。
*
* @param formula 表示する計算式
* @param vector ベクトル(配列)
*/
public static void printArray(String formula, INDArray vector) {
System.out.print(formula + " -> ");
System.out.println(vector);
}
/**
* このクラスのインスタンスメソッドです。
* 0のベクトルを生成します。
*
* @param row 行の数
* @param col 列の数
* @return 生成された行列
*/
private INDArray createZeroINDArray(int row, int col) {
return Nd4j.zeros(row, col);
}

/**
* 1.初めてのND4Jです、ベクトル(数値配列)を生成します。
*/
private void helloNd4j() {
INDArray arr = Nd4j.zeros(3);
System.out.println(arr);
}

/**
* 2.NDArrayを作成する処理
* 配列の演算処理(簡単なもの)。
*/
public void createNDArray_1() {
/*
* 配列の計算をする
* x=(1.0, 2.0), y=(5.0, 7.0)を計算する
Expand All @@ -34,13 +99,6 @@ public static void main(String[] args) {
System.out.println("割り算の答えの0番目(INT)" + answer.getInt(0, 1));
System.out.println("割り算の答えの0番目(Double)" + answer.getDouble(1, 0));
System.out.println("割り算の答えの0番目(Double)" + answer.getDouble(1, 1));
}

/**
* 初めてのND4Jです、ベクトル(数値配列)を生成します。
*/
private void helloNd4j() {
INDArray arr = Nd4j.zeros(3);
System.out.println(arr);
}
}

0 comments on commit c274c7e

Please sign in to comment.