Skip to content
Permalink
Browse files

ND4Jでのベクトル生成サンプルコード、乱数も実装

  • Loading branch information...
takk
takk committed Feb 19, 2019
1 parent 561db59 commit c274c7e359b072f2f4553decf45f7e86734a1e9e
Showing with 65 additions and 7 deletions.
  1. +65 −7 PracticeJava1/src/jp/zenryoku/ml/nd4j/FirstNd4j.java
@@ -1,6 +1,7 @@
package jp.zenryoku.ml.nd4j;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cpu.nativecpu.NDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
@@ -10,6 +11,70 @@
*/
public class FirstNd4j {
public static void main(String[] args) {
// クラスのインスタンス化
FirstNd4j test = new FirstNd4j();

INDArray zeros5 = Nd4j.zeros(5);
System.out.print("Nd4j.zeros(5) -> ");
System.out.println(zeros5);
// 自分で作成したメソッドを使います。
INDArray zero5_3 = test.createZeroINDArray(5, 3);
System.out.print("Nd4j.zeros(5, 3) -> ");
System.out.println(zero5_3);

// 生成したNDArrayを3で埋める
INDArray add3 = zero5_3.add(3);
// 自作メソッドで表示します。
printArray("INDArray.add(3)", add3);

System.out.println("*** 0-1の間で乱数を生成する(2次元配列) ***");
printArray("Nd4j.rand(2, 3)", Nd4j.rand(2, 3));

System.out.println("*** 3次元以上 ***");
printArray("Nd4j.rand(new int[] {1, 2, 3}])", Nd4j.rand(new int[] {1, 2, 3}));

System.out.println("*** 平均ゼロ、標準偏差1のガウス乱数を生成する ***");
printArray("Nd4j.randn(2, 3)", Nd4j.randn(2, 3));
printArray("Nd4j.randn(new int[] {1, 2, 3}])", Nd4j.randn(new int[] {1, 2, 3}));
}

/**
* 制定メソッドです。
* このクラスを起動するときに他の読み込むクラスに同じメソッドは定義できません。
* 実行時エラーになります。
*
* @param formula 表示する計算式
* @param vector ベクトル(配列)
*/
public static void printArray(String formula, INDArray vector) {
System.out.print(formula + " -> ");
System.out.println(vector);
}
/**
* このクラスのインスタンスメソッドです。
* 0のベクトルを生成します。
*
* @param row 行の数
* @param col 列の数
* @return 生成された行列
*/
private INDArray createZeroINDArray(int row, int col) {
return Nd4j.zeros(row, col);
}

/**
* 1.初めてのND4Jです、ベクトル(数値配列)を生成します。
*/
private void helloNd4j() {
INDArray arr = Nd4j.zeros(3);
System.out.println(arr);
}

/**
* 2.NDArrayを作成する処理
* 配列の演算処理(簡単なもの)。
*/
public void createNDArray_1() {
/*
* 配列の計算をする
* x=(1.0, 2.0), y=(5.0, 7.0)を計算する
@@ -34,13 +99,6 @@ public static void main(String[] args) {
System.out.println("割り算の答えの0番目(INT)" + answer.getInt(0, 1));
System.out.println("割り算の答えの0番目(Double)" + answer.getDouble(1, 0));
System.out.println("割り算の答えの0番目(Double)" + answer.getDouble(1, 1));
}

/**
* 初めてのND4Jです、ベクトル(数値配列)を生成します。
*/
private void helloNd4j() {
INDArray arr = Nd4j.zeros(3);
System.out.println(arr);
}
}

0 comments on commit c274c7e

Please sign in to comment.
You can’t perform that action at this time.