Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Neural-ODE] Neural Ordinary Differential Equations #21

Open
Yagami360 opened this issue Sep 28, 2019 · 0 comments
Open

[Neural-ODE] Neural Ordinary Differential Equations #21

Yagami360 opened this issue Sep 28, 2019 · 0 comments

Comments

@Yagami360
Copy link
Owner

Yagami360 commented Sep 28, 2019

0. 論文情報・リンク

1. どんなもの?

  • ニューラルネットワークを層の ”深さ方向に” 連続化して、ニューラルネットワークを時間(=層の深さに対応)に対する微分方程式でモデル化し、逆伝搬処理を adjoint sensitivity method で解くことにより、誤差逆伝播法におけるメモリ効率向上やネットワークのパラメータ数を抑えることなどを実現したニューラルネットワーク。NIPS2018 best paper

2. 先行研究と比べてどこがすごいの?

  • 従来の離散的なニューラルネットワークでは、誤差逆伝播法での学習時に隠れ層での状態を都度メモリに保持しておく必要があった。本手法では、連続化して微分方程式でモデル化したネットワークを、 adjoint sensitivity method で逆方向に解けば良くなるので、隠れ層での状態をメモリに保持しておく必要がなく、よりメモリ効率のよいモデルになっている。
  • 従来の離散的なニューラルネットワークでは、層を追加する度にネットワークのパラメータ数が増加していたが、本手法では、層を追加しようがしまいが微分方程式での固有のパラメータ数のみで計算可能であるため、パラメータ数を抑えることが出来る。
  • また副作用として、フローベースの生成モデルにおいて本手法を適用することで、(フローベースの生成モデルの目的である観測データの)対数尤度最大化の計算コストを抑えることが出来る。

3. 技術や手法の"キモ"はどこにある?

  • ニューラルネットワークの深さ方向への連続化と微分方程式:
    ResNet においては、ネットワークの出力 z ではなく残差関数 f を学習対象とし、その更新式は以下のようになる。

    1回の更新幅 Δt=(t+1)−t=1 を微小化してネットワークを深さ方向に連続化すると、この更新式は以下のようになる。

    この式は、時間に関する微分方程式となっており、Neural-ODE [ODE-Net] という。
    この微分方程式は、オイラー法などの ODE-Solver(微分方程式の数値計算法)を用いて解いていくことになる。

    以下の図は、ネットワークを深さ方向に連続化した場合の各層からの出力値を図示したものである。従来の離散的なネットワークでは出力値も離散的であったのに対して、Neural-ODE では出力が連続的に滑らかなものになっていることが見て取れる。

    • 誤差逆伝播時のメモリ効率の向上:
      従来の離散的なニューラルネットワークでは、誤差逆伝播法での学習時に隠れ層での状態を都度メモリに保持しておく必要があったが、本手法では、後述で説明する adjoint sensitivity method での逆伝搬の微分方程式での解法により、単純に微分方程式を逆方向に解けばよくなるので、隠れ層での状態をメモリに保持しておく必要がなく、結果としてよりメモリ効率のよいモデルになっている。
      ※ 単純にニューラルネットワークを連続化して微分方程式でモデル化しただけでは、この誤差逆伝播時のメモリ効率向上のメリットは享受できないことに注意。後述の adjoint sensitivity method での微分方程式の解法の工夫があって初めてこのメリットを享受できる。

    • アプリケーションに応じた ODE-Solver の柔軟な選択
      オイラー法やルンゲクッタ法のなどの ODE-Solver には、それぞれ計算コストや近似精度、安定性などのメリットデメリットが存在する。本手法では、解きたい問題やアプリケーションに応じて ODE-Solver を柔軟に選択出来るので、それらのトレードオフをアプリケーションに応じて柔軟に取捨選択できる。

    • ネットワークのパラメータ数を抑える。
      従来の離散的なニューラルネットワークでは、層を追加する度にネットワークのパラメータ数が増加していたが、本手法では、層を追加しようがしまいが微分方程式での固有のパラメータ数のみで計算可能であるたm、パラメータ数を抑えることが出来る。
      ※ ネットワークのパラメータ数を抑えるの意味でのメモリ効率の向上は、(adjoint sensitivity method の工夫がなくとも、)ネットワークを連続化して微分方程式でモデル化しただけでも享受できる。

    • フローベースの生成モデルにおける normalizing flows
      また副作用として、フローベースの生成モデルにおいて、本手法を適用することで、フローベースの生成モデルでボトルネックとなっているヤコビアンの計算を行わずとも、フローベースの生成モデルの目的である対数尤度最大化の計算が可能となり、計算コストを抑えることが出来る。

  • Neural-ODE の解法:
    Neural-ODE を ODE-Solver で解く上で困難であるのは、ニューラルネットワーク特有の誤差逆伝搬処理となる。本手法では、この逆伝搬処理を既存の adjoint sensitivity method で解いている。これにより、逆伝搬時に勾配値を保存しておく必要がなくなるので、メモリ効率を向上させることが出来る。

    1. 順伝搬処理
      まず、誤差逆伝搬を行うためには、損失関数 L を求める必要があるが、これは Neural-ODE においては、この損失関数は以下の(順伝搬処理での)連続化された式で求まる。

    2. 逆伝搬処理
      誤差逆伝播法では、この損失関数が最小化するように最適化しながら学習を進めていくのであるが、そのためには、損失関数 L(z,θ,t_0,t_1) の各独立変数 z,θ, t_0, t_1 に関する勾配計算(=偏微分計算)が必要となる。

      この adjoint sensitivity method では、このような勾配計算を行うために、まず adjoint と呼ばれる以下のような隠れ層の状態 z での損失関数 L の勾配を表す量を定義する。
      ※ 同様にして、その他の独立変数 θ, t_0, t_1 に関する勾配も定義するが、ここでは z のみの導出を示す。

      すると、この adjoint に関して、以下のような別の微分方程式が成り立つ。(※この式の導出は省略)

      この adjoint に関する微分方程式は逆伝搬での微分方程式 [reverse-mode derivative of a ODE] であるので、adjoint の初期値は t_0 ではなく t_N であり、以下の図のように右側の初期値 a(t_N )=∂L/∂z(t_N ) から順に逆方向に求めていくことになる。

      以下のアルゴリズムは、順伝搬処理で損失関数 L を求めたあとの、逆伝搬処理での具体的な解法を示している。
      image

    この際に、a(t)=∂L/∂z(t) の定義より、この微分方程式を解くには、任意の t における z の値が求まってなくてはならないが、以下の図のように、この z の値は、adjoint とともに求まるので問題ない?

  • フローベースの生成モデルへの適用
    フローベースの生成モデルの目的は、観測データ z を生成する確率分布 p(z) の対数尤度の最大化であるが、この対数尤度の計算には、ヤコビアンの計算コストの大きさががボトルネックとなっている。
    ※ フローベースの生成モデルの詳細については、「生成モデル|フローベースの生成モデル」 を参照。

    フローベースの生成モデルにおいて、本手法を適用しネットワークを連続化することで、以下の定理のように、このヤコビアンの計算を行わずとも、対数尤度の計算が可能となる。

4. どうやって有効だと検証した?

  • メモリ効率の比較実験

    上表は、MNIST での既存の手法との定量的比較結果を示している。
    ”同程度の識別率となるように調整した” ResNet と比較すると、パラメータ数がより少なくメモリ使用量も O(1) オーダーで済んでいることが見てとれる。

  • 時間発展系における RNN との近似能力の比較実験

    上図は、RNN との時間発展系の近似能力を比較した図である。
    RNN では、長期依存性はうまく近似できていないのに対し、本手法では、長期依存性もうまく近似出来ていることが見てとれる。

5. 議論はあるか?

  • ニューラルネットワークを連続化して微分方程式でモデル化するという考え自体は、割と誰でも思いつく発想。困難であるのは、ニューラルネットワーク特有の誤差逆伝播処理を微分方程式でどうモデル化するか?という点であるが、この論文では、adjoint sensitivity method で単純に微分方程式を逆方向に解けばよくしていており、これがこの論文の注目点になっている。
  • 単純にニューラルネットワークを連続化して微分方程式でモデル化するだけでは、誤差逆伝播時のメモリ効率向上のメリット等は享受できないことに注意。adjoint sensitivity method での微分方程式の解法の工夫があって初めてこのメリットを享受できる。(※ 但し、ネットワークのパラメータ数を抑えるの意味でのメモリ効率の向上は、ネットワークを連続化して微分方程式でモデル化しただけでも享受できる。)
  • 本手法では、ニューラルネットワークを深さ方向に連続化しているが、各層の幅方向にも連続化しているのか?
  • ニューラルネットワークの積分表現理論との絡みは?→ ニューラルネットワークの深さ方向のフロー表現(輸送解釈、輸送写像、最適輸送問題)

6. 次に読むべき論文はあるか?

  • xxx
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

No branches or pull requests

1 participant