In [1]:
from itertools import product
from dynalearn.util import load_experiments
from dynalearn.util.display import *

# Export switches for the final figure; all formats are disabled by default.
save_svg = save_png = save_pdf = False

  self[key] = other[key]


## Loading data

In [2]:
# `pickle` is not otherwise imported explicitly in this notebook; import it here
# so the cell does not depend on what a star import happens to re-export.
import pickle

COVID_DATA_PATH = "../../data/spain_mobility/processed_data/covid_province_data.dat"

# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with open(COVID_DATA_PATH, "rb") as f:
    covid_data = pickle.load(f)

Did not find file `exp-sis-ba.zip`, kept proceding.
Did not find file `exp-plancksis-ba.zip`, kept proceding.


## Additional functions

In [None]:
def ridgeline(data, x=None, ax=None, xticks=None, overlap=0, yshift=0., color="blue", zorder=0, alpha=0.5, withline=True, witharea=True, index=None):
    """Draw a ridgeline plot: one peak-normalized curve per key, stacked vertically.

    Curve ``data[k]`` is rescaled so that its maximum equals ``1 + overlap``.
    It is drawn on a baseline at height ``i``, where ``i`` is its position in
    ``index``. With ``overlap > 0`` a curve rises into the row above it.

    Parameters
    ----------
    data : mapping
        Label -> 1-D array-like supporting ``.max()`` and arithmetic (e.g. a
        numpy array or pandas Series).
    x : array-like, optional
        Shared abscissa. Defaults to ``np.arange(len(curve))`` of the first
        curve drawn, and that default is reused for every later curve.
    ax : matplotlib Axes, optional
        Target axes. Defaults to ``plt.gca()``.
    xticks : dict, optional
        Tick label -> x position. When given, the x limits are also set to the
        span of the positions.
    overlap : float
        Extra peak height, as a fraction of the row spacing.
    yshift : float
        Vertical offset applied to every curve (both the area and the outline).
    color, alpha :
        Fill colour and transparency of the areas.
    zorder : int
        Base z-order added to the per-row z-order of both areas and outlines.
    withline, witharea : bool
        Whether to draw the black outline and/or the filled area.
    index : iterable, optional
        Keys of ``data`` in bottom-to-top order. Defaults to ``data``'s key order.
    """
    if ax is None:
        ax = plt.gca()
    if index is None:
        index = list(data.keys())
    n_rows = len(data)
    ys = []
    labels = []
    # Leave room below the first baseline when curves are shifted downwards.
    ymin, ymax = min(0., yshift), -np.inf
    for i, k in enumerate(index):
        v = data[k]
        labels.append(k)
        vmax = v.max()
        if vmax == 0:
            # An all-zero curve would give v / 0 -> NaN (plus a RuntimeWarning);
            # draw it flat instead. `v * 0.` keeps NaN gaps and the container type.
            d = v * 0.
        else:
            d = v / vmax * (1. + overlap)
        if x is None:
            x = np.arange(len(d))
        y = i
        ys.append(y)
        # Higher rows get a lower z-order, so ridges nearer the bottom are drawn
        # over the ones above them.
        row_z = zorder + n_rows - i + 1
        if witharea:
            ax.fill_between(
                x,
                np.ones(len(d)) * y + yshift,
                d + y + yshift,
                zorder=row_z,
                color=color,
                alpha=alpha
            )
        if withline:
            # Same shift and z-order as the area so outline and fill stay aligned.
            ax.plot(x, d + y + yshift, c='k', zorder=row_z, linewidth=0.5)
        ymax = max(np.max(d + y + yshift), ymax)
    ax.set_yticks(ys)
    ax.set_yticklabels(labels)
    if xticks is not None:
        assert isinstance(xticks, dict)
        ax.set_xticks(list(xticks.values()))
        ax.set_xticklabels(list(xticks.keys()), rotation=45, fontsize=14)
        xmin, xmax = min(list(xticks.values())), max(list(xticks.values()))
        ax.set_xlim([xmin, xmax])
    # With no curves ymax stays -inf, which is not a valid axis limit.
    if ys:
        ax.set_ylim([ymin, ymax])

    ax.spines['left'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

# Named anchor positions for `label_plot`: (x, y, verticalalignment, horizontalalignment),
# with x/y in axes coordinates.
locations = {
    "center center": (0.5, 0.5, "center", "center"),
    "upper right": (0.95, 0.95, "top", "right"),
    "lower right": (0.95, 0.05, "bottom", "right"),
    "upper left": (0.05, 0.95, "top", "left"),
    "lower left": (0.05, 0.05, "bottom", "left"),
}
def label_plot(ax, label, loc="center center", fontsize=16):
    """Write a text label (e.g. a panel tag like "(a)") inside ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to annotate. The position is interpreted in its axes coordinates.
    label : str
        Text to draw.
    loc : str or tuple/list of 4
        Either a key of ``locations``, or an explicit
        ``(x, y, verticalalignment, horizontalalignment)``.
    fontsize : float
        Font size of the label.

    Raises
    ------
    KeyError
        If ``loc`` is a string that is not a key of ``locations``.
    TypeError
        If ``loc`` is neither a string nor a tuple/list.
    """
    if isinstance(loc, (tuple, list)):
        h, v, va, ha = loc
    elif isinstance(loc, str):
        h, v, va, ha = locations[loc]
    else:
        # Previously fell through and failed later with an UnboundLocalError.
        raise TypeError(
            f"loc must be a location name or a 4-tuple, got {type(loc).__name__}"
        )
    ax.text(h, v, label, color="k", transform=ax.transAxes,
        verticalalignment=va, horizontalalignment=ha, fontsize=fontsize,
    )

def add_subplot_axes(ax, rect):
    """Create an inset axes inside ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Host axes. The inset is added to the figure that owns ``ax``.
    rect : sequence of 4 floats
        ``[x, y, width, height]`` of the inset, in ``ax``'s axes coordinates
        (fractions of the host axes, origin at its lower-left corner).

    Returns
    -------
    matplotlib Axes
        The new inset axes. Its tick-label sizes are scaled by the square root
        of the relative width/height so they stay proportionate.
    """
    # Use the host's own figure rather than plt.gcf(): the current figure may be
    # a different one, which would put the inset on the wrong canvas.
    fig = ax.figure
    box = ax.get_position()
    # Lower-left corner: axes coords -> display coords -> figure coords.
    corner_display = ax.transAxes.transform(rect[0:2])
    x, y = fig.transFigure.inverted().transform(corner_display)
    width = box.width * rect[2]
    height = box.height * rect[3]
    subax = fig.add_axes([x, y, width, height])
    xlabels = subax.get_xticklabels()
    ylabels = subax.get_yticklabels()
    # Guard against axes without tick labels instead of raising IndexError.
    if xlabels:
        subax.xaxis.set_tick_params(labelsize=xlabels[0].get_size() * rect[2]**0.5)
    if ylabels:
        subax.yaxis.set_tick_params(labelsize=ylabels[0].get_size() * rect[3]**0.5)
    return subax

# Number of days in each month covered by the data (2020 is a leap year).
times = {
    "2020-01": 31,
    "2020-02": 29,
    "2020-03": 31,
    "2020-04": 30,
    "2020-05": 31,
    "2020-06": 30,
    "2020-07": 31,
    "2020-08": 31,
    "2020-09": 30,
    "2020-10": 31,
    "2020-11": 30,
    "2020-12": 31,
    "2021-01": 31,
    "2021-02": 28,
    "2021-03": 31,
}

# Day offset of the first day of each month, counted from the first month (offset 0).
cumul_times = {}
days_elapsed = 0
for month, n_days in times.items():
    cumul_times[month] = days_elapsed
    days_elapsed += n_days

## Making the plot

In [None]:
# Figure geometry: the map panel uses w_map/h_map and the trajectory panel
# w_traj/h_traj grid columns per row unit; everything is scaled by `factor`.
factor = 8
h_map, w_map = 2, 4
h_traj, w_traj = 1, 3
fs = factor * 24 / 7  # font size for panel tags and axis labels

aspect_units = w_map / h_map + w_traj / h_traj
width = factor * int(np.ceil(aspect_units))
height = factor * 1
print(width, height)

fig = plt.figure(figsize=(width, height))
gs = fig.add_gridspec(height, width)
gs.update(wspace=0.025, hspace=0.05)  # tight spacing between grid cells
map_cols = factor * w_map // h_map  # grid columns given to the map panel
ax_map = fig.add_subplot(gs[:height, :map_cols])
ax_traj = fig.add_subplot(gs[:height, map_cols:])

# Panel (a): map panel. The image itself is currently not drawn; to restore it:
#   im = plt.imread("png/spain_mobility.png")
#   ax_map.imshow(im)
ax_map.axis("off")
label_plot(ax_map, r"(a)", loc=(0.1, 1.0, "bottom", "right"), fontsize=fs)

# Panel (b): one peak-normalized incidence curve per province. Keys are
# reverse-sorted because ridgeline stacks index[0] at the bottom, so provinces
# read alphabetically from top to bottom.
index = sorted(covid_data["cases"].keys(), reverse=True)
ridgeline(covid_data["cases"], ax=ax_traj, xticks=cumul_times, index=index, color=color_dark["blue"], alpha=1., withline=True)
ax_traj.set_ylabel(r"Incidence", fontsize=fs)
label_plot(ax_traj, r"(b)", loc=(0., 1., "bottom", "left"), fontsize=fs)
# Shade the period up to the start of 2020-12 (day offset 335).
ax_traj.axvspan(0, cumul_times["2020-12"], alpha=0.2, color=color_pale["grey"])
ax_traj.tick_params(axis="y", labelsize=12)
ax_traj.tick_params(axis="x", labelsize=20)
# Overrides the [first, last] month-start limits that ridgeline derives from
# `cumul_times` (0-425), so part of 2021-03 is shown.
# NOTE(review): 445 is not a month boundary; confirm it matches the data length.
ax_traj.set_xlim([0, 445])

import os

# Base name (no extension) for the exported figure files.
filename = "manuscript-figure4"
# The export code refers to `figname`; previously only `filename` was defined,
# so saving raised NameError.
figname = filename
if save_svg:
    os.makedirs("svg", exist_ok=True)  # savefig does not create missing directories
    fig.savefig(os.path.join("svg", f"{figname}.svg"))