In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [10]:
# Number of agents in the toy simulation; used later in this cell to size per-agent tensors.
num_agents = 100

def init_stages(num_agents, initial_infections_percentage):
    """Sample an initial infection indicator for every agent.

    Args:
        num_agents: how many agents to sample for.
        initial_infections_percentage: infection chance given as a percentage;
            it is divided by 100 before being used as a probability.

    Returns:
        A tensor of shape (num_agents,) containing the first column of a hard
        Gumbel-softmax sample: (approximately) 1 where the "infected" category
        was drawn and 0 otherwise.
    """
    infection_prob = (initial_infections_percentage / 100) * torch.ones((num_agents, 1))
    healthy_prob = 1 - infection_prob
    # Two-column categorical distribution: column 0 = infected, column 1 = not.
    category_probs = torch.cat((infection_prob, healthy_prob), dim=1)
    # The small epsilon keeps log() finite when a probability is exactly 0.
    logits = torch.log(category_probs + 1e-9)
    one_hot_draw = F.gumbel_softmax(logits=logits, tau=1, hard=True, dim=1)
    return one_hot_draw[:, 0]

def transmission(current_stages, R):
    """Sample which susceptible agents become newly exposed in one step.

    Every agent gets the same infection rate ``lam = 0.5 * R`` and therefore
    the same infection probability ``1 - exp(-lam)``. A hard Gumbel-softmax
    draw decides per agent whether infection "would" happen; the draw is then
    masked so that only agents currently at stage 0 can be newly exposed.

    Args:
        current_stages: 1-D tensor with one stage value per agent.
        R: scalar (or broadcastable) tensor scaling the infection rate.

    Returns:
        Tensor shaped like ``current_stages``: (approximately) 1 for agents
        newly exposed in this step and 0 for everyone else.
    """
    # Size the per-agent tensors from the input rather than from a
    # notebook-level `num_agents`, so the function does not rely on hidden
    # kernel state and works for any population size.
    n_agents = current_stages.shape[0]
    lam = R * torch.ones(size=(n_agents, 1), device=current_stages.device) * 0.5

    prob_infected = 1 - torch.exp(-lam)
    # Two-column categorical: column 0 = infected, column 1 = not infected.
    p = torch.hstack((prob_infected, 1 - prob_infected))
    # Epsilon keeps log() finite when a probability is exactly 0.
    cat_logits = torch.log(p + 1e-9)

    potentially_exposed_today = F.gumbel_softmax(logits=cat_logits, tau=1, hard=True, dim=1)[:, 0]
    # Only agents still at stage 0 can be exposed.
    newly_exposed_today = (current_stages == 0) * potentially_exposed_today

    return newly_exposed_today

def progression(current_stages, newly_exposed_today):
    """Build next-step stages from two masked contributions.

    Agents currently at stage 1 contribute 2, agents flagged in
    `newly_exposed_today` contribute 1, and the two contributions are summed.
    Every other agent (including one already at stage 2) evaluates to 0,
    because `current_stages` itself is not carried over.
    """
    from_stage_one = (current_stages == 1) * 2
    from_new_exposure = newly_exposed_today * 1
    return from_stage_one + from_new_exposure

def progression2(current_stages, newly_exposed_today):
    """Advance stages incrementally on top of the current values.

    Each agent keeps its current stage, gains 1 if it is currently at stage 1,
    and gains 1 if it is flagged in `newly_exposed_today`.
    """
    step_from_stage_one = (current_stages == 1) * 1
    step_from_new_exposure = newly_exposed_today * 1
    return current_stages + step_from_stage_one + step_from_new_exposure
      
# Leaf parameters; requires_grad=True so that backward() fills in their .grad.
initial_infections_percentage = torch.tensor(30.0, requires_grad=True)
R = torch.tensor(3.0, requires_grad=True)
m_rate = torch.tensor(0.7, requires_grad=True)

# Sample the initial per-agent stage indicator (random: no seed is set in this cell).
current_stages = init_stages(num_agents, initial_infections_percentage)
# Boolean mask of agents with stage > 0; not used again in this cell.
agents_infected_index = (current_stages > 0)

# Two simulation steps: sample new exposures, then advance every agent's stage.
for i in range(2):
    newly_exposed = transmission(current_stages, R)
    current_stages = progression2(current_stages, newly_exposed)

# (current_stages==2)*current_stages/2 is 1 for agents at stage 2 and 0 otherwise,
# while keeping current_stages in the expression so autograd can flow through it.
# deaths is m_rate times that count, hence m_rate.grad equals the count itself.
deaths = m_rate*((current_stages==2)*current_stages/2).sum()
deaths.backward()
# Gradients of deaths with respect to the three leaf parameters.
print(initial_infections_percentage.grad, R.grad, m_rate.grad)

tensor(0.2026) tensor(1.6954) tensor(82.)


In [None]:
class SEIRMProgression(DiseaseProgression):
    '''SEIRM for COVID-19.

    Disease stages are encoded as numbers (S=0, E=1, I=2, R=3, M=4). Several
    operations exist in two flavours: the plain method rebuilds values with
    boolean-mask assignment / absolute stage values, while the ``*_new``
    method expresses the same update as arithmetic on top of the incoming
    tensor (masks multiplied by offsets).
    '''

    def __init__(self, params):
        '''Store stage encodings, default transition times and sizes.

        Args:
            params: mapping that provides 'num_steps' and 'num_agents'.
        '''
        # NOTE(review): passing DiseaseProgression (not this class) as the first
        # argument makes super() start the MRO lookup *after* DiseaseProgression,
        # so DiseaseProgression.__init__ itself is skipped. Left unchanged because
        # the base class is not visible here - confirm this is intended.
        super(DiseaseProgression, self).__init__()
        # encoding of stages
        self.SUSCEPTIBLE_VAR = 0
        self.EXPOSED_VAR = 1  # exposed state
        self.INFECTED_VAR = 2
        self.RECOVERED_VAR = 3
        self.MORTALITY_VAR = 4
        # default times (only for initialization, later they are learned)
        self.EXPOSED_TO_INFECTED_TIME = 3
        self.INFECTED_TO_RECOVERED_TIME = 5

        # increments used by the additive (``*_new``) stage update
        self.STAGE_UPDATE_VAR = 1
        self.STAGE_SAME_VAR = 0

        # "never" sentinel: one step past the end of the simulation
        self.INFINITY_TIME = params['num_steps'] + 1
        self.num_agents = params['num_agents']

    def initialize_variables(self, agents_infected_time, agents_stages,
                             agents_next_stage_times):
        ''' Initialize the timing tensors for agents starting exposed/infected.

        Both timing tensors are modified in place via boolean-mask assignment
        and are also returned.

        Returns:
            (agents_infected_time, agents_next_stage_times)
        '''
        starts_exposed = agents_stages == self.EXPOSED_VAR
        starts_infected = agents_stages == self.INFECTED_VAR

        agents_infected_time[starts_exposed] = -1
        agents_infected_time[starts_infected] = -1 * self.EXPOSED_TO_INFECTED_TIME

        agents_next_stage_times[starts_exposed] = self.EXPOSED_TO_INFECTED_TIME
        agents_next_stage_times[starts_infected] = self.INFECTED_TO_RECOVERED_TIME

        return agents_infected_time, agents_next_stage_times

    def update_initial_times_new(self, learnable_params, agents_stages, agent_next_stage_times):
        ''' Additive variant of `update_initial_times`.

        Instead of overwriting entries, shifts the next-stage time of exposed /
        infected agents by the difference between the learned time and the
        default time used at initialization. Returns a new tensor; the input
        is not modified.
        '''
        infected_to_recovered_time = learnable_params[
            'infected_to_recovered_time']
        exposed_to_infected_time = learnable_params['exposed_to_infected_time']

        exposed_shift = (agents_stages == self.EXPOSED_VAR) * (
            exposed_to_infected_time - self.EXPOSED_TO_INFECTED_TIME)
        infected_shift = (agents_stages == self.INFECTED_VAR) * (
            infected_to_recovered_time - self.INFECTED_TO_RECOVERED_TIME)
        agent_next_stage_times = agent_next_stage_times + exposed_shift + infected_shift

        return agent_next_stage_times

    def update_initial_times(self, learnable_params, agents_stages,
                             agents_next_stage_times):
        ''' this is for the abm constructor

        Overwrites (in place) the next-stage time of exposed / infected agents
        with the learned transition times and returns the same tensor.
        '''
        infected_to_recovered_time = learnable_params['infected_to_recovered_time']
        exposed_to_infected_time = learnable_params['exposed_to_infected_time']

        agents_next_stage_times[agents_stages == self.EXPOSED_VAR] = exposed_to_infected_time
        agents_next_stage_times[agents_stages == self.INFECTED_VAR] = infected_to_recovered_time

        return agents_next_stage_times

    def get_newly_exposed(self, current_stages, potentially_exposed_today):
        ''' Keep only the potential exposures that hit susceptible agents. '''
        # we now get the ones that are new to exposure
        newly_exposed_today = (current_stages == self.SUSCEPTIBLE_VAR) * potentially_exposed_today
        return newly_exposed_today

    def update_new_stage_times_new(self, learnable_params, newly_exposed_today, current_stages, agents_next_stage_times, t):
        ''' Additive variant of `update_next_stage_times`.

        Adds an offset to each agent's existing next-stage time:
          * newly exposed agents get `time_progression_new_exposure`;
          * other agents whose transition time equals `t` get INFINITY_TIME
            (currently infected) or `infected_to_recovered_time` (currently
            exposed); everyone else gets 0.
        Returns a new tensor; the input is not modified.
        '''
        exposed_to_infected_time = learnable_params['exposed_to_infected_time']
        infected_to_recovered_time = learnable_params['infected_to_recovered_time']

        new_transition_times = torch.clone(agents_next_stage_times)
        curr_stages = torch.clone(current_stages)

        due_now = agents_next_stage_times == t
        time_progression_old_exposure = (curr_stages == self.INFECTED_VAR) * due_now * self.INFINITY_TIME + \
                    (curr_stages == self.EXPOSED_VAR) * due_now * (infected_to_recovered_time)

        # Same value as the original `... + 1 - num_steps - 1`, but expressed via
        # INFINITY_TIME (= num_steps + 1): `self.params` is never assigned in
        # __init__, so reading it here raised AttributeError.
        # TODO: verify the last part once!
        # NOTE(review): update_next_stage_times uses `t + 1 + exposed_to_infected_time`
        # for new exposures; this offset has no `t` term - confirm whether it should.
        time_progression_new_exposure = exposed_to_infected_time + 1 - self.INFINITY_TIME

        new_transition_times = new_transition_times + newly_exposed_today * time_progression_new_exposure + (1 - newly_exposed_today) * time_progression_old_exposure

        return new_transition_times

    def update_next_stage_times(self, learnable_params, newly_exposed_today,
                                current_stages, agents_next_stage_times, t):
        ''' update time

        Works on a clone of `agents_next_stage_times`:
          * infected agents whose time equals `t` are set to INFINITY_TIME;
          * exposed agents whose time equals `t` are set to
            `t + infected_to_recovered_time`.
        Newly exposed agents then receive `t + 1 + exposed_to_infected_time`;
        all others keep the (possibly updated) cloned value.
        '''
        exposed_to_infected_time = learnable_params['exposed_to_infected_time']
        infected_to_recovered_time = learnable_params[
            'infected_to_recovered_time']
        # for non-exposed
        # if S, R, M -> set to default value; if E/I -> update time if your transition time arrived in the current time
        new_transition_times = torch.clone(agents_next_stage_times)
        curr_stages = torch.clone(current_stages).long()
        due_now = agents_next_stage_times == t
        new_transition_times[(curr_stages == self.INFECTED_VAR) * due_now] = self.INFINITY_TIME
        new_transition_times[(curr_stages == self.EXPOSED_VAR) * due_now] = t + infected_to_recovered_time
        return newly_exposed_today * (t + 1 + exposed_to_infected_time) + (
            1 - newly_exposed_today) * new_transition_times

    def update_current_stage_new(self, newly_exposed_today, current_stages, agents_next_stage_times, t):
        ''' Additive variant of `update_current_stage`.

        Adds STAGE_UPDATE_VAR to exposed / infected agents whose transition
        time is <= t and to newly exposed agents; all other agents keep their
        current stage.
        '''
        transit_agents = (agents_next_stage_times > t) * self.STAGE_SAME_VAR + \
            (agents_next_stage_times <= t) * self.STAGE_UPDATE_VAR

        stage_transition = (current_stages == self.EXPOSED_VAR) * transit_agents + \
            (current_stages == self.INFECTED_VAR) * transit_agents

        # Use the method argument: the previous `newly_exposed` was not defined
        # in this scope (in a notebook it would silently pick up a stale global).
        new_stages = current_stages + stage_transition + newly_exposed_today * self.STAGE_UPDATE_VAR

        return new_stages

    def update_current_stage(self, newly_exposed_today, current_stages,
                             agents_next_stage_times, t):
        ''' progress disease: move agents to different disease stage

        Rebuilds every agent's stage as an absolute value:
          * S stays S unless flagged in `newly_exposed_today` (then E);
          * E becomes I once its transition time is <= t, otherwise stays E;
          * I becomes R once its transition time is <= t, otherwise stays I;
          * R and M are unchanged.
        The result is assembled from masks only; `current_stages` is not an
        additive term of the output (unlike `update_current_stage_new`).
        '''
        time_reached = agents_next_stage_times <= t
        still_waiting = agents_next_stage_times > t
        transition_to_infected = self.INFECTED_VAR * time_reached + self.EXPOSED_VAR * still_waiting
        transition_to_mortality_or_recovered = self.RECOVERED_VAR * time_reached + \
            self.INFECTED_VAR * still_waiting  # can be stochastic --> recovered or mortality

        # Stage progression for agents NOT newly exposed today
        # if S -> stay S; if E/I -> see if time to transition has arrived; if R/M -> stay R/M
        stage_progression = (current_stages == self.SUSCEPTIBLE_VAR) * self.SUSCEPTIBLE_VAR \
            + (current_stages == self.RECOVERED_VAR) * self.RECOVERED_VAR \
            + (current_stages == self.MORTALITY_VAR) * self.MORTALITY_VAR \
            + (current_stages == self.EXPOSED_VAR) * transition_to_infected \
            + (current_stages == self.INFECTED_VAR) * transition_to_mortality_or_recovered

        # update curr stage - if exposed at current step t or not.
        # stage_progression already holds absolute stage values, so it must not be
        # added on top of current_stages again (that would double count E/I agents).
        current_stages = newly_exposed_today * self.EXPOSED_VAR + stage_progression
        return current_stages

    def get_target_variables(self, params, learnable_params,
                             newly_exposed_today, current_stages,
                             agents_next_stage_times, t):
        ''' get recovered (no longer infectious) + targets

        `params` is accepted but not read here.

        Returns:
            recovered_dead_now: per-agent indicator that is 1 for infected
                agents whose transition time is <= t and 0 otherwise.
            NEW_INFECTIONS_TODAY: sum of `newly_exposed_today`.
            NEW_DEATHS_TODAY: `mortality_rate` times the sum of the indicator.
        '''
        mortality_rate = learnable_params['mortality_rate']
        # current_stages stays inside the product (and is divided back out by
        # INFECTED_VAR) so the indicator remains a function of current_stages.
        new_death_recovered_today = (
            current_stages * (current_stages == self.INFECTED_VAR) *
            (agents_next_stage_times
             <= t)) / self.INFECTED_VAR  # agents when stage changes
        # update for newly recovered agents {recovered now}
        recovered_dead_now = new_death_recovered_today  # binary bit vector
        NEW_DEATHS_TODAY = mortality_rate * new_death_recovered_today.sum()
        NEW_INFECTIONS_TODAY = newly_exposed_today.sum()

        return recovered_dead_now, NEW_INFECTIONS_TODAY, NEW_DEATHS_TODAY

    def init_stages(self, learnable_params, device):
        ''' Sample the initial 0/1 stage indicator for every agent.

        `learnable_params['initial_infections_percentage']` is interpreted as a
        percentage: it is divided by 100 before being used as the per-agent
        infection probability. Returns the first column of a hard
        Gumbel-softmax draw, shape (num_agents,).
        '''
        initial_infections_percentage = learnable_params[
            'initial_infections_percentage']
        prob_infected = (initial_infections_percentage / 100) * torch.ones(
            (self.num_agents, 1)).to(device)
        # Two-column categorical: column 0 = infected, column 1 = not infected.
        p = torch.hstack((prob_infected, 1 - prob_infected))
        # Epsilon keeps log() finite when a probability is exactly 0.
        cat_logits = torch.log(p + 1e-9)
        agents_stages = F.gumbel_softmax(logits=cat_logits,
                                         tau=1,
                                         hard=True,
                                         dim=1)[:, 0]
        return agents_stages

In [None]:
# Stage encoding used by this scratch cell only (1-based).
SUSCEPTIBLE_VAR = 1
EXPOSED_VAR = 2
INFECTED_VAR = 3

# Increment applied when an agent moves to the next stage.
SUSCEPTIBLE_TO_EXPOSED_UPDATE = 1
EXPOSED_TO_INFECTED_UPDATE = 1

# Both masks are evaluated on the stages *before* the update, so an agent moves
# at most one stage. The increment is a plain tensor expression: wrapping it in a
# list made `tensor + [tensor]` fail, and the exposed->infected term has to be
# masked on EXPOSED_VAR rather than SUSCEPTIBLE_VAR.
current_stages = current_stages + (
    (current_stages == SUSCEPTIBLE_VAR) * SUSCEPTIBLE_TO_EXPOSED_UPDATE
    + (current_stages == EXPOSED_VAR) * EXPOSED_TO_INFECTED_UPDATE
)

## Checking gradient propagation for vector masks

In [7]:
# Compare two ways of overwriting the "eligible" entries of a tensor that depends on R.
R = torch.tensor([3.0], requires_grad=True)
request_times = R*torch.tensor([1, 2, 3, 4])
request_2 = torch.clone(request_times)
eligible = torch.tensor([1, 0, 0, 1])
t = 10
print(request_times, request_2)

# Variant 1: in-place assignment through a boolean index.
request_times[eligible==1] = t
# Variant 2: arithmetic mask - keep non-eligible entries, substitute t for eligible ones.
request_2 = request_2*(1 - eligible) + eligible*t
print(request_times, request_2)

# Only the arithmetic-mask variant is backpropagated here. Entries 0 and 3 were
# replaced by the constant t, so only the factors 2 and 3 contribute: R.grad = 5.
request_2.sum().backward()
print(R.grad)

tensor([ 3.,  6.,  9., 12.], grad_fn=<MulBackward0>) tensor([ 3.,  6.,  9., 12.], grad_fn=<CloneBackward0>)
tensor([10.,  6.,  9., 10.], grad_fn=<IndexPutBackward0>) tensor([10.,  6.,  9., 10.], grad_fn=<AddBackward0>)
tensor([5.])


## Gradient Horizon for Runner Trajectory

- Trajectory variables: $x_1, \dots, x_{100}$
- Ground truth: $g_{100}$
- Parameters: $w$
- Horizon length: $40$
- Gradient horizon: $x_{61}, \dots, x_{100}$ (the last 40 states of the trajectory)

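$$x_{t+1} = f(x_t, w)$$

$$\text{loss} = l(x_{100}, g_{100})$$

$$\frac{dl}{dw} = \frac{\partial l}{\partial x_{100}} \, \frac{d x_{100}}{d w}$$

where $\frac{d x_{100}}{d w}$ is unrolled only through the gradient horizon $x_{61}, \dots, x_{100}$; the earlier states $x_1, \dots, x_{60}$ are treated as constants.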

In [None]:
# Number of trailing trajectory entries that should stay differentiable.
gradient_horizon = 30
trajectory = runner.trajectory

# The ground truth sits at the end of the trajectory, so everything before the
# last `gradient_horizon` entries is detached from the autograd graph.
# NOTE(review): detaching the prefix only blocks gradients that flow directly
# through those entries; if the tail entries were computed from the prefix, their
# own graph may still reach back through it - confirm against the runner.
detach_gradient_length = len(trajectory) - gradient_horizon
no_gradient_trajectory = trajectory[:detach_gradient_length].detach()
trajectory = torch.cat((no_gradient_trajectory, trajectory[detach_gradient_length:]))


class OurOptimizer(nn.Module):
    """Optimizer wrapper that limits backpropagation to the tail of a trajectory.

    Args:
        base_opt: object exposing ``step()`` (e.g. a ``torch.optim`` optimizer)
            that applies the parameter update.
        loss_fn: callable that maps the (partially detached) trajectory to a
            scalar loss tensor. NOTE(review): the earlier sketch called
            ``.backward()`` on ``loss_fn`` itself; this assumes any ground truth
            is already bound into the callable - confirm the intended signature.
        gradient_horizon: number of trailing trajectory entries that remain
            attached to the autograd graph.
    """

    def __init__(self, base_opt, loss_fn, gradient_horizon):
        super().__init__()
        self.loss = loss_fn
        self.base_opt = base_opt
        self.gradient_horizon = gradient_horizon

    def forward(self, trajectory):
        """Detach the trajectory prefix, backpropagate the loss and step.

        Gradients are not zeroed here; the caller decides when to clear them
        on ``base_opt``.

        Args:
            trajectory: tensor whose first dimension indexes time steps.

        Returns:
            The loss tensor computed on the truncated trajectory.
        """
        # Clamp at 0 so a horizon longer than the trajectory keeps every entry
        # attached instead of producing a negative slice index.
        detach_gradient_length = max(len(trajectory) - self.gradient_horizon, 0)
        no_gradient_trajectory = trajectory[:detach_gradient_length].detach()
        trajectory = torch.cat((no_gradient_trajectory, trajectory[detach_gradient_length:]))

        # Evaluate the loss on the truncated trajectory; backward() lives on the
        # resulting tensor, not on the loss function.
        loss = self.loss(trajectory)
        loss.backward()
        self.base_opt.step()
        return loss
