In [None]:
import os
os.chdir('../..')

import numpy as np
from matplotlib import pyplot as plt
import matplotlib.patches as patches
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

import scipy.io as scio

In [None]:
# Input data for figure 3. Paths are relative to the current working
# directory, which the first cell moves two levels up with os.chdir('../..').
peak_file, sim_file = 'data/figure_3_peaks.mat', 'data/figure_3_timeseries.mat'
# Loaded in the same order as before: peaks first, then the time series.
mat, sims = scio.loadmat(peak_file), scio.loadmat(sim_file)

def peak(x, y):
    """Return the entry of ``mat['peak']`` at row ``x``, column ``y``.

    In this notebook ``x`` is a number of working days and ``y`` a number of
    quarantine days; the returned value is plotted (divided by 1e5) as the
    peak number of infected.
    """
    peak_table = mat['peak']
    return peak_table[x, y]

def peakTime(x, y):
    """Return the entry of ``mat['peakTime']`` at row ``x``, column ``y``.

    ``x``/``y`` are working/quarantine day counts as for ``peak``; the value
    is plotted on an axis labelled 'Peak Time (days)'.
    """
    time_table = mat['peakTime']
    return time_table[x, y]

def get_sim(idx):
    """Unpack simulation number ``idx`` from ``sims['TimeSeries']``.

    Each record is indexed as four fields (0..3) — presumably one row of the
    MATLAB cell array loaded by scipy.io.loadmat; confirm against the .mat
    layout.

    Returns:
        (x, y, ts, infecteds) where ``x`` and ``y`` are the scalars stored in
        fields 0 and 1, ``ts`` is field 2 flattened to 1-D, and ``infecteds``
        is the row-wise sum of columns 2..6 of the 2-D state array in field 3.
    """
    record = sims['TimeSeries'][idx]
    x_days = record[0].reshape(-1)[0]
    y_days = record[1].reshape(-1)[0]
    times = record[2].reshape(-1)
    # Columns 2..6 (slice 2:7) are summed as the infected total — presumably
    # the infected compartments of the model; verify against the simulator.
    infected_total = record[3][:, 2:7].sum(axis=1)
    return x_days, y_days, times, infected_total

In [None]:
# Global styling for the whole figure: large fonts for a wide, multi-panel plot.
plt.rcParams['font.size'] = 24
# Duty cycles are mapped onto this colormap in every panel.
cmap = plt.get_cmap('inferno')

def rescale(x, min_x=0.07, max_x=0.5):
    """Linearly map a duty cycle onto [0, 1] for colormap lookup.

    The default range roughly spans the duty cycles plotted in this notebook
    (from about 1/14 up to 7/14). Values outside ``[min_x, max_x]`` map
    outside ``[0, 1]``; no clipping is done. Works element-wise on numpy
    arrays as well as on scalars.

    Args:
        x: duty cycle(s) to rescale.
        min_x: value mapped to 0.
        max_x: value mapped to 1.

    Returns:
        ``(x - min_x) / (max_x - min_x)``.

    Raises:
        ValueError: if ``max_x == min_x`` (the target range would be empty).
    """
    span = max_x - min_x
    if span == 0:
        raise ValueError('rescale: max_x must differ from min_x')
    return (x - min_x) / span

# Figure layout on a 2x3 grid: the outer columns each hold one tall panel
# spanning both rows; the middle column is split into two stacked panels.
fig = plt.figure(figsize=(30, 10))
spec = gridspec.GridSpec(nrows=2, ncols=3, figure=fig)

tl = fig.add_subplot(spec[:, 0])  # left column, both rows
tr = fig.add_subplot(spec[:, 2])  # right column, both rows
tm = [fig.add_subplot(spec[row, 1]) for row in (0, 1)]  # middle column: top, then bottom

# Left panel: peak infected vs. period length (log y-axis), one curve per
# duty cycle duty_int/base_period. Peak values are divided by 1e5 so the axis
# reads as a percentage of 10 million (matching the y-label).
ax = tl
xticks = [1, 2, 3, 4, 8, 12, 16]
ax.set_xticks(xticks)

for base_period in [7, 14]:
    # Periods are integer multiples of the base period, up to 112 days (16 weeks).
    multipliers = list(range(1, 112 // base_period + 1))
    periods = [base_period * m for m in multipliers]

    for duty_int in range(1, base_period // 2 + 1):
        # Skip duty cycles that reduce to a shorter base period (e.g. 2/14 == 1/7):
        # that curve was already drawn for the 7-day base.
        if base_period > 7 and duty_int % (base_period // 7) == 0:
            continue
        # Same idea for bases beyond 14 days; unreachable for [7, 14], kept so
        # the base-period list can be extended.
        if base_period > 14 and duty_int % (base_period // 14) == 0:
            continue

        xs = [duty_int * m for m in multipliers]                  # working days
        ys = [(base_period - duty_int) * m for m in multipliers]  # quarantine days
        peak_values = [peak(xd, yd) for xd, yd in zip(xs, ys)]
        ax.semilogy(np.array(periods) / 7, np.array(peak_values) / 1e5, 'o-',
                    c=cmap(rescale(duty_int / base_period)))

# Outline of the region x in [1.5, 4.5] weeks, y in [0.64, 0.88];
# presumably highlights a range discussed in the text.
rect = patches.Rectangle((1.5, 0.64), width=3, height=0.24, linewidth=2,
                         edgecolor='k', facecolor='none', fill=False)
ax.add_patch(rect)
# Axes.grid documents `visible` as a bool; the string 'on' only worked by truthiness.
ax.grid(True, axis='y')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_yticks([1, 2, 5, 10, 20, 50])
ax.set_yticklabels([1, 2, 5, 10, 20, 50])
ax.set_ylabel('Peak Infected (% of 10 million)')
ax.set_xlabel('Period Length (weeks)')

# Right panel: peak time vs. period length, one curve per duty cycle w/14.
ax = tr
xticks = [1, 2, 3, 4, 8, 12, 16]
ax.set_xticks(xticks)
working_days = [(i+1) for i in range(6)]

for w in working_days:
    duty_cycle = w / 14

    if w % 2 == 0:
        # An even w/14 reduces to (w/2)/7, so sample it on the 7-day base period.
        base_period, x = 7, w // 2
    else:
        base_period, x = 14, w
    # Periods are multiples of the base period up to 112 days (16 weeks):
    # 16 multiples for the 7-day base, 8 for the 14-day base.
    multipliers = list(range(1, 112 // base_period + 1))
    xs = [x * m for m in multipliers]                  # working days, scaled to period length
    ys = [(base_period - x) * m for m in multipliers]  # quarantine days, scaled to period length
    periods = [base_period * m for m in multipliers]   # period length

    peak_values = [peakTime(xd, yd) for xd, yd in zip(xs, ys)]
    # NOTE(review): this label starts with ': ' — possibly a missing prefix
    # such as a duty-cycle caption; left unchanged pending confirmation.
    ax.plot(np.array(periods) / 7, peak_values, 'o-', c=cmap(rescale(duty_cycle)),
            label=': {:.0f}%, e.g. (X={},Y={})'.format(duty_cycle*100, w, 14-w))

# Axes.grid documents `visible` as a bool; the string 'on' only worked by truthiness.
ax.grid(True, axis='y')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.legend()
ax.set_ylabel('Peak Time (days)')
ax.set_xlabel('Period Length (weeks)')

# Middle column: infected time series of selected simulations, one subplot per
# row of the middle column (bottom row drawn first), then save the figure.
ax = tm

xs = [i + 1 for i in range(6)]  # working-day counts X = 1..6
t_max = 1600                    # number of leading samples plotted per series
ylim = [0, 1.2]
fpsp_start = 50                 # x-position (days) of the dashed 'FPSP start' marker

for i in reversed(range(2)):
    # TimeSeries entries appear to be grouped in runs of len(xs) per subplot
    # -- NOTE(review): confirm against the .mat file layout.
    base_idx = i * len(xs)
    for j in xs[:4]:
        x, y, t, infecteds = get_sim(base_idx + j - 1)
        duty_cycle = x / (x + y)
        ax[i].plot(t[:t_max], infecteds[:t_max] / 1e5, c=cmap(rescale(duty_cycle)))

    ax[i].set_ylim(ylim)
    ax[i].plot([fpsp_start, fpsp_start], ylim, '--', c='black', label='FPSP start', alpha=0.25)
    ax[i].set_ylabel('Infected (%)')
    for side in ('top', 'right', 'bottom', 'left'):
        ax[i].spines[side].set_visible(False)
    # Explicit True: grid() with no `visible` and no kwargs is documented to toggle.
    ax[i].grid(True, axis='y')
    if i < 1:
        # Only the bottom subplot keeps x tick labels.
        ax[i].set_xticklabels([])
    # The title uses x, y of the last simulation plotted in this subplot.
    ax[i].set_title('Period Length: {} weeks'.format(int((x+y)/7)))

ax[-1].set_xlabel('Time (days)')
ax[0].legend(loc=0)

plt.tight_layout()
# savefig does not create missing directories; make sure the target exists.
os.makedirs('results', exist_ok=True)
plt.savefig('results/f3_3_column_v3.eps', dpi=1200)
plt.savefig('results/f3_3_column_v3.png', dpi=300)