From c5d02941b3ebe6e3f5c8927521599e20f27f1954 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Jun 2021 15:22:44 +0200 Subject: [PATCH 1/7] shortened example GAN --- examples/backends/plot_wass2_gan_torch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index 8f5002228..996247df3 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -115,7 +115,7 @@ def forward(self, x): optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001) # number of iteration and size of the batches -n_iter = 500 +n_iter = 200 # set to 200 for doc buld but 1000 is better ;) size_batch = 500 # generate statis samples to see their trajectory along training @@ -167,7 +167,7 @@ def forward(self, x): pl.figure(3, (10, 10)) -ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499] +ivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199] for i in range(9): pl.subplot(3, 3, i + 1) @@ -183,7 +183,7 @@ def forward(self, x): # Generate and visualize data # --------------------------- -size_batch = 500 +size_batch = 200 xd = get_data(size_batch) xn = torch.randn(size_batch, 2) x = G(xn).detach().numpy() From 4c65a41a5d10e00132ded79c396c90bdab826ac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Jun 2021 15:24:40 +0200 Subject: [PATCH 2/7] pep8 and typo --- examples/backends/plot_wass2_gan_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index 996247df3..20fd61c3c 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -115,7 +115,7 @@ def forward(self, x): optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001) # number of iteration and size of the batches -n_iter = 200 # set to 200 for doc buld but 1000 is better ;) +n_iter = 200 # set to 200 for doc build but 1000 is better ;) size_batch = 500 # generate statis samples to see their trajectory along training From 0154d1c05f6a5e9d6a843becd3721927875cb84d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Jun 2021 16:46:29 +0200 Subject: [PATCH 3/7] awesome animation --- examples/backends/plot_wass2_gan_torch.py | 38 +++++++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index 20fd61c3c..f3017131d 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -50,6 +50,7 @@ import numpy as np import matplotlib.pyplot as pl +import matplotlib.animation as animation import torch from torch import nn import ot @@ -112,7 +113,7 @@ def forward(self, x): G = Generator() -optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001) +optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001, alpha=0.5) # number of iteration and size of the batches n_iter = 200 # set to 200 for doc build but 1000 is better ;) @@ -179,16 +180,47 @@ def forward(self, x): if i == 0: pl.legend() + # %% +# Animate trajectories of generated samples along iteration +# ------------------------------------------------------- + +pl.figure(5) + + +def _update_plot(i): + pl.clf() + pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) + pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) + pl.xticks(()) + pl.yticks(()) + pl.xlim((-1.5, 1.5)) + pl.ylim((-1.5, 1.5)) + pl.title('Iter. {}'.format(i)) + return 1 + + +i = 0 +pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) +pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) +pl.xticks(()) +pl.yticks(()) +pl.xlim((-1.5, 1.5)) +pl.ylim((-1.5, 1.5)) +pl.title('Iter. {}'.format(ivisu[i])) + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=10, repeat_delay=2000) + # %% # Generate and visualize data # --------------------------- -size_batch = 200 +size_batch = 500 xd = get_data(size_batch) xn = torch.randn(size_batch, 2) x = G(xn).detach().numpy() -pl.figure(4) +pl.figure(5) pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5) pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) pl.title('Sources and Target distributions') From 794e6d9b5b316f1f5accaebb8745a2cc914f3758 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Jun 2021 16:47:52 +0200 Subject: [PATCH 4/7] small eror pep8 --- examples/backends/plot_wass2_gan_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index f3017131d..01b511d82 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -180,7 +180,7 @@ def forward(self, x): if i == 0: pl.legend() - # %% +# %% # Animate trajectories of generated samples along iteration # ------------------------------------------------------- From 0c7427233b137a08632f04aed58d53585e51373a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Jun 2021 17:12:39 +0200 Subject: [PATCH 5/7] add animation to doc --- docs/source/conf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 3a11798da..9b5a71971 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -337,7 +337,8 @@ def __getattr__(cls, name): intersphinx_mapping = {'python': ('https://docs.python.org/3', None), 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None), - 'matplotlib': ('http://matplotlib.org/', None)} + 'matplotlib': ('http://matplotlib.org/', None), + 'torch': ('https://pytorch.org/docs/stable/', None)} sphinx_gallery_conf = { 'examples_dirs': ['../../examples', '../../examples/da'], @@ -345,6 +346,7 @@ def __getattr__(cls, name): 'backreferences_dir': 'gen_modules/backreferences', 'inspect_global_variables' : True, 'doc_module' : ('ot','numpy','scipy','pylab'), + 'matplotlib_animations': True, 'reference_url': { 'ot': None} } From 37282fe5612e65283b8439d5f75a5c6d56640a6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Jun 2021 17:27:57 +0200 Subject: [PATCH 6/7] better timing animation --- examples/backends/plot_wass2_gan_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index 01b511d82..74f747b69 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -209,7 +209,7 @@ def _update_plot(i): pl.title('Iter. {}'.format(ivisu[i])) -ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=10, repeat_delay=2000) +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000) # %% # Generate and visualize data From ee69652ff1f8aca5c3be884e3801a3565995a014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 8 Jun 2021 17:52:39 +0200 Subject: [PATCH 7/7] tune step --- examples/backends/plot_wass2_gan_torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index 74f747b69..ca5b3c96a 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -113,7 +113,7 @@ def forward(self, x): G = Generator() -optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001, alpha=0.5) +optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5) # number of iteration and size of the batches n_iter = 200 # set to 200 for doc build but 1000 is better ;) @@ -184,7 +184,7 @@ def forward(self, x): # Animate trajectories of generated samples along iteration # ------------------------------------------------------- -pl.figure(5) +pl.figure(4, (8, 8)) def _update_plot(i):