Skip to content

Commit

Permalink
need to test multi gpu implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
arnauqb committed May 16, 2023
1 parent d06bf10 commit 5ed3aa6
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 85 deletions.
3 changes: 2 additions & 1 deletion birds/models/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def forward(self, p):
Arguments:
p (torch.Tensor): Probability of moving forward at each timestep.
"""
device = p.device
p = torch.clip(p, min=0.0, max=1.0) #torch.nn.functional.softmax(p[0])
probs = p * torch.ones(self.n_timesteps)
probs = p * torch.ones(self.n_timesteps, device)
logits = torch.vstack((probs, 1 - probs)).log()
steps = torch.nn.functional.gumbel_softmax(
logits, dim=0, tau=self.tau_softmax, hard=True
Expand Down
12 changes: 7 additions & 5 deletions birds/models/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,22 @@ def forward(self, params):
Arguments:
params (torch.Tensor) : a tensor of shape (3,) containing the **log10** of the fraction of infected, beta, and gamma
"""
device = params.device
self.mp = self.mp.to(device)
# Initialize the parameters
params = soft_minimum(params, torch.tensor(0.0), 2)
params = soft_minimum(params, torch.tensor(0.0, device=device), 2)
params = 10**params

initial_infected = params[0]
beta = params[1]
gamma = params[2]
n_agents = self.graph.num_nodes
# Initialize the state
infected = torch.zeros(n_agents)
susceptible = torch.ones(n_agents)
recovered = torch.zeros(n_agents)
infected = torch.zeros(n_agents, device=device)
susceptible = torch.ones(n_agents, device=device)
recovered = torch.zeros(n_agents, device=device)
# sample the initial infected nodes
probs = initial_infected * torch.ones(n_agents)
probs = initial_infected * torch.ones(n_agents, device=device)
new_infected = self.sample_bernoulli_gs(probs)
infected += new_infected
susceptible -= new_infected
Expand Down
Loading

0 comments on commit 5ed3aa6

Please sign in to comment.