In [None]:
#Stochastic gradient descent is the workhorse of machine learning. It is, however, not at all clear why it should work on
#a non-convex function. Stochastic analysis offers a convincing explanation: the invariant law of the SGD equation is
#an exponential density concentrating at the infimum of the non-convex function. Therefore, SGD sample paths converge to the
#infimum (in probability), and this explains why gradient descent performs so well with batching: batching supplies the noise
#required by the SDE, hence the name "stochastic" gradient descent. This script implements an animation for an SDE in a
#non-convex potential and plots many sample paths as they wiggle around on their way to the infimum. The corresponding
#Fokker-Planck equation can be solved alongside with the script "Convection diffusion PDE solver" to reveal the marginal laws
#as the system evolves.

%matplotlib inline
#!pip install matplotlib==3.3.0
import autograd.numpy as np  # Thinly-wrapped numpy
from autograd import grad
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Create one large square figure holding a single 3D axes (ax1); the sample
# paths and the potential surface are all drawn on ax1.
fig=plt.figure(figsize=(13,13),dpi=108)
ax1=fig.add_subplot(111,projection='3d')
# Use a 1:1:1 box aspect for the 3D axes.
# NOTE(review): presumably the reason matplotlib 3.3.0 is mentioned in the
# commented-out pip line above -- confirm the minimum version required.
ax1.set_box_aspect((1,1,1))
#ax2=fig.add_subplot(122,projection='3d')
from google.colab import drive
from google.colab import files
import os

!pip install -U -q PyDrive
#pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate the Colab user, then build a PyDrive client that uses the
# application-default Google credentials.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
# NOTE: this rebinds the name `drive`, shadowing the `google.colab.drive`
# module imported above; from here on `drive` is the PyDrive client object.
drive = GoogleDrive(gauth)

#drive.mount("/content/drive")

def phi(X,mu,sig):
  """Return exp(-(X-mu)^T (X-mu) / (2*sig)) / sqrt(2*pi).

  X and mu must be array-likes accepted by np.matmul (scalars are rejected
  by matmul); sig divides the squared distance directly.
  NOTE(review): the prefactor is the 1-D standard-normal constant and sig
  is used like a variance -- confirm this is the intended normalisation.
  """
  offset = X-mu
  squared_distance = np.matmul(np.transpose(offset),offset)
  prefactor = 1/np.sqrt(2*np.pi)
  exponent_scale = -1/(2*sig)
  return prefactor*np.exp(exponent_scale*squared_distance)

def f(x,y):
  """Non-convex potential evaluated at (x, y); works elementwise on arrays.

  f(x, y) = (y + sin(2y) + x^4 - x^2 - 0.4x + 0.28 sin(7x)) / 5
  """
  # Alternative potential kept for reference (resembles the Goldstein-Price
  # test function; parentheses in this draft are unbalanced):
  #return (1+np.power(x+y+1,2)*(19-14*x+3*np.power(x,2)-14*y+6*x*y+3*np.power(y,2))*
  #        (30+np.power(2*x-3*y,2)*(18-32*x+12*np.power(x,2)+48*y-36*x*y+27*np.power(y,2))))
  total = y+np.sin(2*y)
  total = total+np.power(x,4)-np.power(x,2)
  total = total-0.4*x+0.28*np.sin(7*x)
  return 1/5*total

def s(x):
  """Return 2/(1+exp(x)) - 1, a decreasing sigmoid with values in (-1, 1)."""
  denominator = 1+np.exp(x)
  return 2/denominator-1
 
def g(x,y):
  """Alternative potential: a weighted sum of three s(...) terms.

  The terms are s applied to the negated squared distance from (1.1, 1.5)
  (weight 1/1.6), to -(x+y)^2 (weight 1), and to twice the negated squared
  distance from (2, 0.2) (weight 1/1.8). Not called by the visible part of
  the script, which simulates on f.
  """
  arg_first = -np.power(x-1.1,2)-np.power(y-1.5,2)
  arg_diagonal = -np.power(x+y,2)
  arg_second = (-np.power(x-2,2)-np.power(y-0.2,2))*2
  return s(arg_first)/1.6  + s(arg_diagonal) + s(arg_second)/1.8

#def g(x,y):
#  return s((-(x-1.1)**2-(y-1.5)**2))/1.6  + s(-(x+y)**2) + s((-(x-2)**2-(y-0.2)**2)*2)/1.8
#  return (2/(1+exp(((-(x-1.1)**2-(y-1.5)**2))))-1)/1.6  + (2/(1+exp(-(x+y)**2))-1) + (2/(1+exp((-(x-2)**2-(y-0.2)**2)*2))-1)/1.8

def jx(x,y):
  """Closed-form partial derivative of f with respect to x.

  df/dx = (4x^3 - 2x - 0.4 + 0.28*7*cos(7x)) / 5; y is accepted for a
  uniform signature but does not enter the result.
  """
  # x-derivative of the alternative potential g, kept for reference:
  #return -1.25*(2.2 - 2.0*x)*np.exp(-1.21*(0.909090909090909*x - 1)**2 - 2.25*(0.666666666666667*y - 1)**2)/(np.exp(-1.21*(0.909090909090909*x - 1)**2 - 2.25*(0.666666666666667*y - 1)**2) + 1)**2 - 1.11111111111111*(8 - 4*x)*np.exp(-2*(x - 2)**2 - 2*(y - 0.2)**2)/(np.exp(-2*(x - 2)**2 - 2*(y - 0.2)**2) + 1)**2 - 2*(-2*x - 2*y)*np.exp(-(x + y)**2)/(1 + np.exp(-(x + y)**2))**2
  polynomial_part = 4*np.power(x,3)-2*x-0.4
  oscillating_part = 0.28*7*np.cos(7*x)
  return 1/5*(polynomial_part+oscillating_part)

def jy(x,y):
  """Closed-form partial derivative of f with respect to y.

  df/dy = (1 + 2*cos(2y)) / 5; x is accepted for a uniform signature but
  does not enter the result.
  """
  # y-derivative of the alternative potential g, kept for reference:
  #return -1.11111111111111*(0.8 - 4*y)*np.exp(-2*(x - 2)**2 - 2*(y - 0.2)**2)/(np.exp(-2*(x - 2)**2 - 2*(y - 0.2)**2) + 1)**2 - 1.25*(3.0 - 2.0*y)*np.exp(-1.21*(0.909090909090909*x - 1)**2 - 2.25*(0.666666666666667*y - 1)**2)/(np.exp(-1.21*(0.909090909090909*x - 1)**2 - 2.25*(0.666666666666667*y - 1)**2) + 1)**2 - 2*(-2*x - 2*y)*np.exp(-(x + y)**2)/(1 + np.exp(-(x + y)**2))**2
  inner = 1+2*np.cos(2*y)
  return 1/5*inner

# Initial mesh over [-2,2]^2; printed once as a quick look at the grid.
x=np.linspace(-2,2,1000)
y=x
X,Y=np.meshgrid(x,y)
print(X)
#ax1.plot_surface(X,Y,f(X,Y),zorder=-1,linewidth=0,edgecolor="none",alpha=1,shade=1)
# Fixed viewing window; the surface mesh inside the loop is rebuilt from
# these limits.
ax1.set_xlim(-1.3,1.4)
ax1.set_ylim(-3,3)
#ax.plot_surface(X,Y,jy(X,Y))

# Euler-Maruyama scheme simulated below, for each of the m paths:
#   p_{k+1} = p_k - gamma*dt*grad f(p_k) + sqrt(eps)*N(0, dt*I)
gamma=2              # multiplier on the gradient (drift) term
dt=0.001             # time step
eps=0.05             # noise intensity (enters as sqrt(eps))
m=50                 # number of independent sample paths
n_steps=int(50/dt)   # total number of time steps
frame_every=10       # save/upload one frame every this many steps
# All paths start from the same point (-1, 2.9); p has shape (m, 2).
p=np.array([-1,2.9])
p=np.reshape(np.tile(p,m),(m,2))
# Z and c are not used by the loop below; Z is only referenced by the
# commented-out accumulation at the end of the loop body.
Z=0*f(X,Y)
c=["orange","blue","green","yellow","red","purple"]
sp=None  # handle of the surface currently on the axes
for i in range(n_steps):
  #p=np.array([0.5,1.5])
  for j in range(m):
    # One Euler-Maruyama step for path j, then draw the segment from the old
    # to the new position, lifted onto the surface z=f(x,y).
    z=p[j]-gamma*dt*np.array([jx(p[j,0],p[j,1]),jy(p[j,0],p[j,1])])+np.sqrt(eps)*np.random.normal(0,np.sqrt(dt),size=(2,))
    ax1.plot(np.array([p[j,0],z[0]]),np.array([p[j,1],z[1]]),np.array([f(p[j,0],p[j,1]),f(z[0],z[1])]),zorder=i+1,color="orange",linewidth=0.4,alpha=1)
    #len=0.8*np.max(np.linalg.norm(p-np.mean(p,axis=0),axis=1))
    #ax1.set_xlim(np.mean(p[:,0])-len,np.mean(p[:,0])+len)
    #ax1.set_ylim(np.mean(p[:,1])-len,np.mean(p[:,1])+len)
    p[j,:]=z
  # Progress in percent of the total run (previously divided by 10/dt while
  # the loop runs 50/dt steps, so it overshot 100%).
  print(100.*i/n_steps)
  if i%frame_every==0:
    x=np.linspace(ax1.get_xlim()[0],ax1.get_xlim()[1],1000)
    y=np.linspace(ax1.get_ylim()[0],ax1.get_ylim()[1],1000)
    X,Y=np.meshgrid(x,y)
    # Replace the previous surface instead of stacking a new one on top of
    # it every frame, so the number of surface artists stays at one.
    if sp is not None:
      sp.remove()
    sp=ax1.plot_surface(X,Y,f(X,Y),zorder=-1,linewidth=0,edgecolor="none",alpha=1,shade=1,color="blue")
    frame_name="sgd"+str(i//frame_every)+".png"
    plt.savefig(frame_name,bbox_inches='tight')
    ###files.download("sgd"+str(i/50)+".png")
    uploaded = drive.CreateFile({'title': frame_name})
    uploaded.SetContentFile(frame_name)
    uploaded.Upload()
    # Fade every path segment and drop the ones that became nearly invisible.
    # Iterate over a snapshot: removing from the live list while looping over
    # it skips elements, and Artist.remove() works across matplotlib versions
    # whereas mutating ax1.lines directly does not.
    for l in list(ax1.lines):
      l.set_alpha(l.get_alpha()*0.95)
      if(l.get_alpha()<0.1):
        l.remove()
  #Z+=np.exp(10*(-np.multiply(X-p[0],X-p[0])-np.multiply(Y-p[1],Y-p[1])))
  #print(Z)
#ax2.plot_surface(X,Y,Z)

plt.show()

[[-2.         -1.995996   -1.99199199 ...  1.99199199  1.995996
   2.        ]
 [-2.         -1.995996   -1.99199199 ...  1.99199199  1.995996
   2.        ]
 [-2.         -1.995996   -1.99199199 ...  1.99199199  1.995996
   2.        ]
 ...
 [-2.         -1.995996   -1.99199199 ...  1.99199199  1.995996
   2.        ]
 [-2.         -1.995996   -1.99199199 ...  1.99199199  1.995996
   2.        ]
 [-2.         -1.995996   -1.99199199 ...  1.99199199  1.995996
   2.        ]]
0.0
0.01
0.02
0.03
0.04
0.05
0.06
0.07
0.08
0.09
0.1
0.11
0.12
0.13
0.14
0.15
0.16
0.17
0.18
0.19
0.2
0.21
0.22
0.23
0.24
0.25
0.26
0.27
0.28
0.29
0.3
0.31
0.32
0.33
0.34
0.35
0.36
0.37
0.38
0.39
0.4
0.41
0.42
0.43
0.44
0.45
0.46
0.47
0.48
0.49
0.5
0.51
0.52
0.53
0.54
0.55
0.56
0.57
0.58
0.59
0.6
0.61
0.62
0.63
0.64
0.65
0.66
0.67
0.68
0.69
0.7
0.71
0.72
0.73
0.74
0.75
0.76
0.77
0.78
0.79
0.8
0.81
0.82
0.83
0.84
0.85
0.86
0.87
0.88
0.89
0.9
0.91
0.92
0.93
0.94
0.95
0.96
0.97
0.98
0.99
1.0
1.01
1.02
1.03
1.04
1.05
1