Skip to content

「Feedback Aligment를 이용한 신경망 학습 알고리즘 구현」에 대한 내용을 다루고 있습니다.

License

Notifications You must be signed in to change notification settings

YAGI0423/feedback_alignment

Repository files navigation

이 저장소(Repository)는 「Feedback Aligment를 이용한 신경망 학습 알고리즘 구현」에 대한 내용을 다루고 있습니다.



작성자: YAGI

최종 수정일: 2022-11-12

  • 2022.11.10: 코드 작성 완료(Task 1 ~ 3)
  • 2022.11.11: README.md 작성 완료
  • 2022.11.12: 코드 및 README 최종 정리



  • 프로젝트 기간: 2022-06-27 ~ 2022-11-12



  • 해당 프로젝트는 Timothy P. Lillicrap 외 3인의 「Random feedback weights support learning in deep neural networks」(2014)를 바탕으로 하고 있습니다.

Timothy P. Lillicrap, Daniel Cownden, Douglas B. Tweed, Colin J. Akerman. Random feedback weights support learning in deep neural networks. ArXiv, 1411.0247v1, 2014.




프로젝트 요약

   오차 역전파(Backpropagation of error)는 현재 가장 강력한 딥러닝 네트워크 학습 알고리즘이다. 하지만, 역전파는 뉴런이 기여하는 영향을 정확하게 계산하여 오류 신호를 하류의 뉴런에 할당하는데, 이는 생물학적으로 수용하기 어렵다. Timothy P. Lillicrap 외 3인은 역전파에서 사용하는 가중치의 전치 대신, '무작위 시냅스 가중치(random synaptic weights)'를 오류 신호와 곱하여 영향을 할당하는 Feedback Alignment 알고리즘(이하 FA)을 제시하였다. 나아가, 특정 작업에 대한 FA 알고리즘의 성능을 역전파 알고리즘과 비교하여 확인하였다. 성능 비교는 Task (1) 선형 함수 근사, Task (2) MNIST 데이터셋, Task (3) 비선형 함수 근사를 통해 이루어졌다. 세 Task 모두 손실함수로, $L = (1/2)e^Te$를 사용하며, $e = y^* - y$로, $e$는 예측과 실제 출력의 차이이다. 본 프로젝트는 앞선 세 Task를 구현하는 것을 목표로 한다.


Task (1) Linear function approximation

  30-20-10 선형 네트워크가 선형 함수, $T$를 근사하도록 학습한다. 입·출력 학습 쌍은 $x ~ N(μ=0, ∑=I)$으로 $y^* = Tx$를 통해 생성한다. 목표 선형 함수 $T$는 30차원 공간의 벡터를 10차원으로 매핑하였으며, $[-1, 1]$ 범위로부터 균일하게 추출하였다. 오차 역전파의 네트워크 가중치 $W_0$, $W$$[-0.01, 0.01]$에서 균일하게 추출하여 초기화 하였다. FA의 random feedback weight인 $B$는 균일(uniform) 분포 $[-0.5, 0.5]$에서 추출 한다. 각 알고리즘의 학습률, η는 학습 속도의 최적화를 위해 수동 탐색(manual search)을 통해 선택하였다. ...(Timothy P. Lillicrap et al.)

   figure 1은 네 알고리즘의 선형 함수에 대한 손실 변화를 제시한 것으로 'shallow' 학습(옅은 회색), 강화 학습(어두운 회색), 오차 역전파(검정), 그리고 피드백 정렬(초록)이다.


figure 1. Error on Test Set of Paper's Task (1) Linear function approximation(Timothy P. Lillicrap et al.)



   본 프로젝트에서는 학습률을 0.001, 배치 크기는 32로 설정하였으며, Epoch은 1,000회 수행하였다. 데이터셋의 경우 입·출력 데이터 모두 Min-Max 정규화 전처리를 진행하였다. figure 2는 학습 및 테스트 데이터셋에 대한 오차 역전파와 FA의 선형 함수 근사의 손실 변화를 시각화한 것으로 오차 역전파(검정), FA(초록)이다.


figure 2. Error of Project's Task (1) Linear function approximation



Task (2) MNIST dataset

  표준 시그모이드 은닉과 출력 유닛(즉, $σ{(x)} = 1/{(1+exp(-x))}$)의 784-1000-10 네트워크는 0-9의 필기 숫자 이미지를 분류하도록 학습되었다. 네트워크는 기본 MNIST 데이터셋 60,000개 이미지로 학습되었으며, 성능 측정은 10,000개의 이미지 테스트 셋을 사용하였다. 학습률은 $η = 10^{-3}$ 그리고 weight decay$α = 10^{-6}$이 사용되었다. ...(Timothy P. Lillicrap et al.)

   figure 3은 10,000개의 MNIST 테스트 셋에 대한 오차 역전파(검정), FA(초록)의 손실 곡선을 제시한 것이다.


figure 3. Error on Test Set of Paper's Task (2) MNIST dataset(Timothy P. Lillicrap et al.)



   본 프로젝트에서는 배치 크기 32로 설정하고 Epoch은 20회 수행하였다. 입·출력 데이터 모두 Min-Max 정규화 전처리를 진행하였다. weight decay는 사용하지 않았다. 네트워크 가중치는 $[-0.01, 0.01]$ 범위에서 균일하게 추출하여 초기화하였다. figure 4는 MNIST 학습 및 테스트 데이터셋에 대한 오차 역전파와 FA의 선형 함수 근사의 손실 변화를 시각화한 것으로 오차 역전파(검정), FA(초록)이다.


figure 4. Error of Project's Task (2) MNIST dataset



Task (3) Nonlinear function approximation

  30-20-10 그리고 30-20-10-10 네트워크는 30-20-10-10의 목표(target) 네트워크의 출력을 근사하도록 학습한다. 세 개의 모든 네트워크는 $tanh(·)$의 은닉 유닛, 선형 출력 유닛을 가진다. 입·출력 학습 쌍은, $x ~ N(μ=0, ∑=I)$인, $y^* = W_2·tanh(W_1·tanh(W_0·x + b_0) + b_1) + b_2$으로 $y^*_i = T(x_i)$를 통해 생성되었다. 목표 네트워크 $T(·)$에 대한 매개변수는 무작위로 선택되었다.FA의 random feedback wieght, $B_1$$B_2$는 수동으로 선택한 매개변수 척도(scale)를 이용하여 균일 분포에서 추출하였다. ...(Timothy P. Lillicrap et al.)

   figure 5는 비선형 함수 근사 문제에 대한 각 평균 20회 이상 시도한 손실 곡선으로 세 층의 네트워크는 shallow 학습(회색), 오차 역전파(검정), 그리고 피드백 정렬(초록)이며, 네 층의 네트워크는 오차 역전파(마젠타) 그리고 피드백 정렬(파랑)으로 학습되었다.


figure 5. Error on Test Set of Paper's Task (3) Nonlinear Function approximation(Timothy P. Lillicrap et al.)



   본 프로젝트에서는 학습률을 0.001, 배치 크기 4로 설정하고 Epoch은 10회 수행하였다. 입·출력 데이터 모두 Min-Max 정규화 전처리를 진행하였다. 네트워크 가중치는 $[-0.01, 0.01]$ 범위에서 균일하게 추출하여 초기화하였다. figure 6은 비선형 함수 근사 문제에 대한 손실 곡선으로 세 층의 네트워크는 오차 역전파(검정), 그리고 피드백 정렬(초록)이며, 네 층의 네트워크는 오차 역전파(마젠타) 그리고 피드백 정렬(파랑)으로 학습되었다. 학습 셋에 대한 손실은 500 이동 평균을 시각화 한 것이다.


figure 6. Error of Project's Task (3) Nonlinear function approximation



Getting Start


각 Task는 [TASK NAME].py 파일을 실행하여 수행할 수 있다. 네트워크 학습 및 추론이 종료 되면 /plot/images/ 경로에 시각화 이미지가 저장된다.


Task (1) Linear function approximation

$ python task1_linearFunction.py

Task (2) MNIST dataset

MNIST 데이터셋은 첨부하지 않았으므로 /datasets/ 경로에 별도의 데이터셋을 위치시켜야 한다.

$ python task2_mnistDataset.py

Task (3) Nonlinear function approximation

$ python task3_nonlinearFunction.py



License

This project is licensed under the terms of the MIT license.

About

「Feedback Aligment를 이용한 신경망 학습 알고리즘 구현」에 대한 내용을 다루고 있습니다.

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages