
Commit c1a7135: multi gpu example running
arnauq committed May 16, 2023 · 1 parent 5062731
Showing 4 changed files with 23 additions and 18 deletions.
7 changes: 4 additions & 3 deletions birds/calibrator.py
@@ -88,10 +88,11 @@ def _differentiate_loss(
         """
         # first we just differentiate the loss through reverse-diff
         regularisation_loss.backward()
-        # then we differentiate the parameters through the flows but also tkaing into account the jacobians of the simulator
-        to_diff = torch.zeros(1)
+        # then we differentiate the parameters through the flow, also taking into account the jacobians of the simulator
+        device = forecast_parameters.device
+        to_diff = torch.zeros(1, device=device)
         for i in range(len(forecast_jacobians)):
-            to_diff += torch.dot(forecast_jacobians[i], forecast_parameters[i, :])
+            to_diff += torch.dot(forecast_jacobians[i].to(device), forecast_parameters[i, :])
         to_diff.backward()

     def step(self):
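The hunk above also shows the trick _differentiate_loss relies on: each entry of forecast_jacobians is a detached gradient of the simulator loss with respect to one parameter sample, so summing dot(jacobian_i, theta_i) and calling backward() once deposits every jacobian into the flow parameters' gradients through the chain rule. Building to_diff on the parameters' device, and moving each jacobian across, keeps the whole reduction on a single GPU. A minimal self-contained sketch of the pattern, with made-up shapes (5 samples of dimension 3):

    import torch

    # Stand-ins: parameter samples that require grad, plus one detached
    # jacobian d(loss)/d(theta_i) per sample, as in _differentiate_loss.
    params = torch.randn(5, 3, requires_grad=True)
    jacobians = [torch.randn(3) for _ in range(5)]

    device = params.device
    to_diff = torch.zeros(1, device=device)
    for i in range(len(jacobians)):
        # d/d(theta_i) of dot(J_i, theta_i) is J_i, so a single backward()
        # writes each jacobian into params.grad via the chain rule
        to_diff += torch.dot(jacobians[i].to(device), params[i, :])
    to_diff.backward()

    assert torch.allclose(params.grad, torch.stack(jacobians))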
4 changes: 2 additions & 2 deletions birds/forecast.py
@@ -103,7 +103,7 @@ def loss_f(params)
         for i in range(mpi_rank, len(params_list_comm), mpi_size):
             params = torch.tensor(params_list_comm[i], device=device)
             jacobian, loss_i = jacobian_calculator(params)
-            if np.isnan(loss):
+            if torch.isnan(loss_i) or torch.isnan(jacobian).any():
                 continue
             loss += loss_i
             jacobians_per_rank.append(torch.tensor(jacobian.cpu().numpy()))
@@ -118,7 +118,7 @@ def loss_f(params)
     if mpi_comm is not None:
         losses = mpi_comm.gather(loss, root=0)
         if mpi_rank == 0:
-            loss = sum(losses)
+            loss = sum([l.cpu() for l in losses if l != 0])
     if mpi_rank == 0:
         jacobians = list(chain(*jacobians_per_rank))
         indices = list(chain(*indices_per_rank))
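For context, loss_f shards the parameter samples round-robin (rank r takes items r, r+size, r+2·size, ...), accumulates a per-rank loss, and gathers everything on rank 0, where the zero placeholders from ranks whose samples were all NaN are filtered out. A minimal sketch of that pattern, assuming mpi4py sits underneath birds.mpi_setup (run with mpirun -np N python script.py):

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    mpi_rank, mpi_size = comm.Get_rank(), comm.Get_size()

    work = [float(i) for i in range(100)]  # stand-in for params_list_comm
    loss = 0.0
    for i in range(mpi_rank, len(work), mpi_size):  # round-robin shard
        loss += work[i]

    losses = comm.gather(loss, root=0)  # rank 0 receives one value per rank
    if mpi_rank == 0:
        print(sum(l for l in losses if l != 0))  # drop idle-rank placeholders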
3 changes: 2 additions & 1 deletion birds/models/sir.py
@@ -48,6 +48,7 @@ def forward(self, params):
         """
         device = params.device
         self.mp = self.mp.to(device)
+        self.graph = self.graph.to(device)
         # Initialize the parameters
         params = soft_minimum(params, torch.tensor(0.0, device=device), 2)
         params = 10**params
@@ -67,7 +68,7 @@ def forward(self, params):
             susceptible -= new_infected

         infected_hist = infected.sum().reshape((1,))
-        recovered_hist = torch.zeros((1,))
+        recovered_hist = torch.zeros((1,), device=device)

         # Run the model forward
         for _ in range(self.n_timesteps):
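The added line matters because self.graph is a plain tensor attribute rather than a registered buffer, so Module.to() will not move it; re-assigning it at the top of forward() (a no-op once it already lives on the right device) lets the same SIR instance follow whichever device each rank's params arrive on. A minimal sketch of the pattern with a hypothetical module:

    import torch

    class FollowsInputDevice(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # plain attribute: unlike a registered buffer, .to() ignores it
            self.graph = torch.eye(3)

        def forward(self, params):
            device = params.device
            self.graph = self.graph.to(device)  # no-op if already there
            return self.graph @ params

    model = FollowsInputDevice()
    print(model(torch.ones(3)))  # same code runs unchanged on cuda:N inputs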
27 changes: 15 additions & 12 deletions docs/examples/gpu_parallelisation.py
@@ -13,15 +13,17 @@

 from birds.models.sir import SIR
 from birds.calibrator import Calibrator
+from birds.mpi_setup import mpi_rank


-def make_model(n_agents, n_timesteps, device):
+def make_model(n_agents, n_timesteps):
     graph = networkx.watts_strogatz_graph(n_agents, 10, 0.1)
-    return SIR(graph=graph, n_timesteps=n_timesteps, device=device)
+    return SIR(graph=graph, n_timesteps=n_timesteps)


-def make_flow():
+def make_flow(device):
     # Define flows
+    torch.manual_seed(0)
     K = 4
     latent_size = 3
     hidden_units = 64
@@ -41,13 +43,13 @@ def make_flow():

     # Construct flow model
     flow = nf.NormalizingFlow(q0=q0, flows=flows)
-    return flow
+    return flow.to(device)


-def train_flow(flow, model, true_data, n_epochs, n_samples_per_epoch):
-
+def train_flow(flow, model, true_data, n_epochs, n_samples_per_epoch, device):
+    torch.manual_seed(0)
     # Define a prior
-    prior = torch.distributions.MultivariateNormal(-2.0 * torch.ones(3), torch.eye(3))
+    prior = torch.distributions.MultivariateNormal(-2.0 * torch.ones(3, device=device), torch.eye(3, device=device))

     optimizer = torch.optim.AdamW(flow.parameters(), lr=1e-3)
@@ -64,10 +66,11 @@ def train_flow(flow, model, true_data, n_epochs, n_samples_per_epoch):
         optimizer=optimizer,
         w=w,
         n_samples_per_epoch=n_samples_per_epoch,
+        device=device,
     )

     # and we run for 500 epochs without early stopping.
     calibrator.run(n_epochs=n_epochs, max_epochs_without_improvement=np.inf)
@@ -78,16 +81,16 @@ def train_flow(flow, model, true_data, n_epochs, n_samples_per_epoch):
     parser.add_argument("--n_agents", type=int, default=1000)
     parser.add_argument("--n_timesteps", type=int, default=100)
     parser.add_argument("--n_samples_per_epoch", type=int, default=5)
-    parser.add_argument("--device_ids", type=list, default=["cpu"])
+    parser.add_argument("--device_ids", default=["cpu"], nargs="+")
     args = parser.parse_args()

+    # device of this rank
+    device = args.device_ids[mpi_rank]

-    model = make_model(args.n_agents, args.n_timesteps, device=device)
+    model = make_model(args.n_agents, args.n_timesteps)
     true_parameters = torch.tensor(
         [0.05, 0.05, 0.05], device=device
     ).log10()  # SIR takes log parameters
     true_data = model(true_parameters)
-    flow = make_flow()
-    train_flow(flow, model, true_data, args.n_epochs, args.n_samples_per_epoch)
+    flow = make_flow(device)
+    train_flow(flow, model, true_data, args.n_epochs, args.n_samples_per_epoch, device=device)
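The --device_ids fix is worth spelling out: argparse's type=list calls list() on the raw string, splitting "cuda:0" into single characters, whereas nargs="+" collects each space-separated token whole. Each rank then selects args.device_ids[mpi_rank], so a launch along the lines of mpirun -np 2 python docs/examples/gpu_parallelisation.py --device_ids cuda:0 cuda:1 (a hypothetical invocation, not shown in the commit) hands every rank its own GPU. A short sketch of the difference:

    import argparse

    fixed = argparse.ArgumentParser()
    fixed.add_argument("--device_ids", default=["cpu"], nargs="+")
    print(fixed.parse_args(["--device_ids", "cuda:0", "cuda:1"]).device_ids)
    # -> ['cuda:0', 'cuda:1']

    broken = argparse.ArgumentParser()
    broken.add_argument("--device_ids", type=list, default=["cpu"])
    print(broken.parse_args(["--device_ids", "cuda:0"]).device_ids)
    # -> ['c', 'u', 'd', 'a', ':', '0']  (the bug this commit removes)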
