# The courtroom scenario
This notebook models a courtroom case study from Vaud, Switzerland, in October 2020. A detailed description of the case can be found in Vernez, D., Schwarz, S., Sauvain, J.-J., Petignat, C. and Suarez, G. 2021. Probable aerosol transmission of SARS-CoV-2 in a poorly ventilated courtroom. *Indoor Air*, **31**, pp. 1776–1785, https://doi.org/10.1111/ina.12866.
## Import the libraries

In [None]:
from topologicpy.Vertex import Vertex
from topologicpy.Edge import Edge
from topologicpy.Wire import Wire
from topologicpy.Face import Face
from topologicpy.Shell import Shell
from topologicpy.Cell import Cell
from topologicpy.CellComplex import CellComplex
from topologicpy.Cluster import Cluster
from topologicpy.Topology import Topology
from topologicpy.Graph import Graph
from topologicpy.Dictionary import Dictionary
from topologicpy.Vector import Vector
from topologicpy.Helper import Helper
from topologicpy.Plotly import Plotly
import numpy as np
nan = np.nan
import os
from tqdm.auto import tqdm
from skfem import *
from skfem.helpers import dot, grad
from funcs import *

## Define parameters

In [None]:
# physics parameters
# First-order loss rates, presumably in 1/s (the ACH=3 alternative below is 3 air changes / 3600 s).
# Presumably inactivation (beta), deposition (sigma) and ventilation (lambda_) -- confirm.
beta = 1.7e-4
sigma = 1.1e-4
lambda_ = 6.4e-5
# lambda_ = 2.78e-4*3         # ACH=3
rho = beta+sigma+lambda_                  # total first-order removal rate used in the FEM solve
room_volume = 8.9*5.6*3                   # m^3: 8.9 m x 5.6 m floor, 3 m high
D = 0.8*lambda_*room_volume**(2/3)        # effective diffusion coefficient of the FEM solve
# person parameters
mu = 0.065*5                              # emission scale of the infectious person
dm = 100                                  # dose giving a 50 % infection risk
eta = 0                                   # mask efficiency; 0 means no mask is worn
# simulation parameters
start_time = 0*3600 #seconds
end_time = 3*3600 #seconds
time_step = 60 #seconds
time_steps = list(range(start_time, end_time, time_step))
# Equal to time_steps[1]-time_steps[0], but also valid when there is a single step.
dt = time_step
animation_time_step = 6
# input/output settings
output_folder = "./simulation"
mesh_name = "floor_mesh.msh"
fig_width = 1920
fig_height = 1080

## Geometry

In [None]:
interior_wall_thickness = 0.1
exterior_wall_thickness = 0.3
# Courtroom footprint and wall height in metres (the same dimensions as room_volume).
room_width = 8.9
room_length = 5.6
wall_height = 3
room_origin = Vertex.ByCoordinates(0,0,0)
room = Cell.Prism(origin=room_origin, width=room_width, length=room_length, height=wall_height, placement="lowerleft")
# Extract the floor outline of the room prism.
dic = Cell.Decompose(room)
bhf = dic['bottomHorizontalFaces']
shell = Shell.ByFaces(bhf)
horiz_slab = Face.ByShell(shell)
wire = Wire.RemoveCollinearEdges(Face.Wire(horiz_slab), angTolerance=2)
# The ring between the offset outline and the room outline is the exterior-wall footprint.
ext_offset = Wire.ByOffset(wire, offset=exterior_wall_thickness)
face = Face.ByWires(ext_offset, [wire])
# Flip the ring if its normal points down (presumably so thickening extrudes upward).
normal = Face.Normal(face)
if normal[2] < 0:
    face = Face.Invert(face)
walls = Cell.ByThickenedFace(face, thickness=wall_height, bothSides=False, tolerance=0.001)
# Floor face used for the navigation graph (and later for the FEM mesh).
face = Face.Rectangle(origin=room_origin, width=room_width, length=room_length, placement='lowerleft')
# Seat positions of P1..P10 (metres, z = 0).
locations = [[5.7,0.6,0],[4.2,0.6,0],[2.7,0.6,0],[7.2,0.6,0],[6.9,3.7,0],[8.3,4.9,0],[8.3,4.3,0],[2.1,3.9,0],[3.7,5.3,0],[4.8,3.9,0]]
location_vertices = [Vertex.ByCoordinates(loc) for loc in locations]
g = Graph.NavigationGraph(face,sources=location_vertices,destinations=location_vertices)
Topology.Show(Cluster.ByTopologies([walls,face,Graph.Topology(g)]))

## Define the agent, schedule, event and activity classes

In [None]:
# Assumptions
# Metric system. All distances are in meters. Time is always measured in seconds.

class agent:
    """A person who follows a schedule of events inside the courtroom.

    State updated by ``get_current_position``: ``event``, ``moving`` and the cached
    walking route ``path``. The simulation loops append one entry per time step to
    ``trajectory`` and ``event_list``, and accumulate ``virus_load`` (inhaled dose)
    and ``risk`` (infection risk).
    """

    def __init__(self, name="Untitled", infectious=False, mu=0.065, dm=100, rep=None, speed=1.5, eta=0, schedule=None, event_list=None, trajectory=None):
        self.name = name
        self.infectious = infectious
        self.rep = rep                      # topology placed at the agent's position when drawing
        self.speed = speed                  # walking speed (m/s)
        self.schedule = schedule
        self.event_list = event_list if event_list is not None else []      # corresponding to time_steps
        self.trajectory = trajectory if trajectory is not None else []      # corresponding to time_steps
        self.location = None                # current location
        self.path = None                    # current path
        self.event = None                   # current event
        self.moving = False                 # True while walking between two events
        self.mu = mu
        self.dm = dm
        # Dose-response constant: risk = 1 - exp(-dose * I), so a dose of dm gives a 50 % risk.
        self.I = np.log(0.5)/(-dm)
        self.eta = eta                      # mask efficiency, 0 for not wearing a mask
        self.virus_load = np.array([0])
        self.risk = np.array([0])

    def get_current_event(self, current_time):
        """Return the event active at ``current_time``, or None if none is.

        An event is active on the half-open interval [start_time, end_time).
        """
        for scheduled in self.schedule.events:  # events are kept sorted by the schedule
            if scheduled.start_time <= current_time < scheduled.end_time:
                return scheduled
        return None

    def get_current_position(self, current_time, previous_time, graph):
        """Return the vertex of ``graph`` where the agent is at ``current_time``.

        The agent stays at its current event's location until it must leave to reach
        the next event's location on time. It then walks along the shortest path in
        ``graph`` at ``self.speed``.

        Side effects: sets ``self.event`` and ``self.moving``, and caches the walking
        route in ``self.path``.

        Returns:
            A vertex, or None when no event is active at ``current_time``.
        """
        self.event = self.get_current_event(current_time)
        # Stationary unless we end up walking below.
        self.moving = False
        if self.event is None:
            return None
        events = self.schedule.events
        index = events.index(self.event)
        if index == len(events) - 1:
            # Last event: there is no next location to walk to.
            return Graph.NearestVertex(graph, self.event.location)
        previous_event = self.get_current_event(previous_time)
        target_event = events[index + 1]
        c_vertex = Graph.NearestVertex(graph, self.event.location)
        # Recompute the route only when there is none yet or a new event has started.
        if self.path is None or self.event != previous_event:
            t_vertex = Graph.NearestVertex(graph, target_event.location)
            self.path = Graph.ShortestPath(graph, c_vertex, t_vertex, edgeKey="length")
        if self.path is None:
            return c_vertex
        distance = Wire.Length(self.path)
        if distance is None:
            distance = 0
        # The agent must arrive at target_event's start time.
        duration = distance/self.speed
        time_to_leave = target_event.start_time - duration
        if time_to_leave >= current_time or duration <= 0:
            # Not time to leave yet, or there is no distance to walk (avoids dividing by zero).
            return c_vertex
        # Fraction of the route already walked; clamped for overlapping events.
        self.moving = True
        u = min((current_time - time_to_leave)/duration, 1.0)
        return Wire.VertexByParameter(self.path, u)

class schedule:
    """A named, ordered collection of an agent's events.

    ``events`` is reordered with ``Helper.Sort`` using start time and then duration
    as sort keys (presumably ascending; confirm Helper.Sort semantics).
    """

    def __init__(self, name="Untitled", events=None):
        # A fresh list per call instead of a shared mutable default argument.
        if events is None:
            events = []
        self.name = name
        start_times = [e.start_time for e in events]
        durations = [e.duration() for e in events]
        self.events = Helper.Sort(events, start_times, durations)

class event:
    """An activity performed at a fixed location during [start_time, end_time)."""

    def __init__(self, name, location, start_time, end_time, activity):
        self.name = name
        # The coordinate list is stored as a topologic Vertex.
        self.location = Vertex.ByCoordinates(location)
        self.start_time, self.end_time = start_time, end_time
        self.activity = activity

    def duration(self):
        """Return how long the event lasts (end_time - start_time)."""
        span = self.end_time - self.start_time
        return span

class activity:
    """A named activity level.

    ``p`` multiplies concentration * dt when accumulating inhaled dose and uptake;
    ``Rt`` scales the emission of an infectious agent doing this activity.
    """

    def __init__(self, name, p, Rt):
        self.name, self.p, self.Rt = name, p, Rt

## Create locations, events, and schedules

In [None]:
# presence: [start, end] intervals (seconds from simulation start) for P1..P10
agent_num = 10
p1 = [[0*60,5*60],[5*60,23*60],[23*60,30*60],[30*60,55*60],[55*60,70*60],[70*60,104*60],[104*60,110*60],[110*60,180*60]]
p2 = [[0*60,5*60],[5*60,23*60],[23*60,30*60],[30*60,55*60],[55*60,70*60],[70*60,104*60],[104*60,110*60],[110*60,180*60]]
p3 = [[0*60,5*60],[5*60,23*60],[23*60,30*60],[30*60,55*60],[55*60,70*60],[70*60,104*60],[104*60,110*60],[110*60,180*60]]
p4 = [[0*60,5*60],[5*60,23*60],[23*60,30*60],[30*60,55*60],[55*60,70*60],[70*60,104*60],[104*60,110*60],[110*60,180*60]]
p5 = [[5*60,23*60],[30*60,55*60],[70*60,104*60],[110*60,180*60]]
p6 = [[5*60,23*60],[30*60,55*60],[70*60,104*60],[110*60,180*60]]
p7 = [[5*60,23*60],[30*60,55*60],[70*60,104*60],[110*60,180*60]]
p8 = [[5*60,23*60],[30*60,55*60],[70*60,104*60],[110*60,180*60]]
p9 = [[5*60,23*60],[30*60,55*60],[70*60,104*60],[110*60,180*60]]
p10 = [[70*60,104*60]]
presence = [p1,p2,p3,p4,p5,p6,p7,p8,p9,p10]
# One generic name per event slot ('event_1', 'event_2', ...), sized to the longest presence list.
names = ['event_'+str(k+1) for k in range(max(map(len, presence), default=0))]

# activities (one per presence interval; extra entries, e.g. in p10a, are unused)
resting = activity('resting',1.8e-4,8)
talking = activity('talking',2.2e-4,40)
talking_loudly = activity('talking loudly',2.5e-4,80)
walking = activity('walking',1.1e-3,145)
p1a = [talking,talking,talking,talking,talking,talking,talking,talking]
p2a = [talking,talking_loudly,talking,talking_loudly,talking,talking_loudly,talking,talking_loudly]
p3a = [talking,resting,talking,resting,talking,resting,talking,resting]
p4a = [talking,talking,talking,talking,talking,talking,talking,talking]
p5a = [talking_loudly,talking_loudly,talking_loudly,talking_loudly]
p6a = [resting,resting,resting,resting]
p7a = [resting,resting,resting,resting]
p8a = [talking_loudly,talking_loudly,talking_loudly,talking_loudly]
p9a = [resting,resting,resting,resting]
p10a = [resting,resting,resting,resting]
activities = [p1a,p2a,p3a,p4a,p5a,p6a,p7a,p8a,p9a,p10a]

# schedules: every event of person i takes place at that person's seat, locations[i]
schedule_list = []
for i in range(agent_num):
    events = [
        event(names[j], locations[i], start, end, activities[i][j])
        for j, (start, end) in enumerate(presence[i])
    ]
    schedule_list.append(schedule(name="Schedule"+str(i+1), events=events))

## Create agents

In [None]:
# One agent per schedule; all agents share a single cylinder as their drawing representation.
agent_name = ['P1','P2','P3','P4','P5','P6','P7','P8','P9','P10']
rep = Cell.Cylinder(radius=0.2, height=1.8, placement="bottom")
agents = [
    agent(
        name=agent_name[idx],
        infectious=False,
        schedule=schedule_list[idx],
        rep=rep,
        mu=mu,
        dm=dm,
        eta=eta,
        speed=1.5,
    )
    for idx in range(agent_num)
]
# P1 is the index case: the only infectious person.
agents[0].infectious = True

## Simulate agent trajectories on the navigation graph

In [None]:
# Placeholder events recorded when an agent is absent or walking between events.
no_event = event('no_event',[nan,nan],start_time,end_time,resting)
walking_event = event('walking',[nan,nan],start_time,end_time,walking)

# Start from a clean state so re-running this cell does not append to old results.
# `person` (not `agent`) is the loop variable so the `agent` class is not shadowed.
for person in agents:
    person.trajectory = []
    person.event_list = []
    person.path = None

for t in tqdm(time_steps):
    for person in agents:
        person.location = person.get_current_position(t, t-time_step, g)
        if person.event is not None:
            # The agent is in the courtroom (it has an active event).
            person.trajectory.append([Vertex.X(person.location),Vertex.Y(person.location)])
            person.event_list.append(walking_event if person.moving else person.event)
        else:
            person.event_list.append(no_event)
            person.trajectory.append([nan,nan])
print("DONE")

# save simulation results
# Rows: time_steps, then the x row and y row of every agent.
rows = [time_steps]
for person in agents:
    rows.append([xy[0] for xy in person.trajectory])
    rows.append([xy[1] for xy in person.trajectory])
agent_location_matrix = np.vstack(rows)

# One row of event names per agent (always 2D when there are agents).
agent_event_matrix = np.array([[ev.name for ev in person.event_list] for person in agents])

In [None]:
### Build the floor mesh and attach a concentration dictionary to each shell vertex.
# Room height = vertical extent of the room prism. The FEM step divides emissions by it,
# presumably to turn them into a per-floor-area source for the 2D floor model.
wall_vertices = Topology.Vertices(room)
wall_vertices_x = [Vertex.X(i) for i in wall_vertices]
wall_vertices_y = [Vertex.Y(i) for i in wall_vertices]
wall_vertices_z = [Vertex.Z(i) for i in wall_vertices]
room_height = max(wall_vertices_z)-min(wall_vertices_z)
# `face` is the rectangular floor face. Its vertices are re-created with z forced to 0 so
# the meshed face lies in the z = 0 plane.
wall_face_vertices = Face.Vertices(face)
new_face_vertices = []
for v in wall_face_vertices:
    x = Vertex.X(v)
    y = Vertex.Y(v)
    z = 0
    new_face_vertices.append(Vertex.ByCoordinates([x,y,z]))    
mesh_face = Topology.ReplaceVertices(face,wall_face_vertices,new_face_vertices)
# ByMeshFace comes from `funcs`. It presumably meshes the face with ~0.5 m elements and writes
# `mesh_name`, which MeshTri.load reads next -- TODO confirm against funcs.
shell = ByMeshFace(mesh_face,meshSize = 0.5, meshName = mesh_name)
Topology.Show(shell)

shell_vertices = Shell.Vertices(shell)
shell_vertex_num = len(shell_vertices)
keys = ["Virus Concentration"]
# A single Dictionary object (concentration 0) is attached to every shell vertex.
dose_value = Dictionary.ByKeysValues(keys,[0])
for v in shell_vertices:
    Topology.SetDictionary(v,dose_value)

## Load the mesh and map its vertex indices to the topologic Shell vertices

In [None]:
# Load the floor mesh and set up a piecewise-linear (P1) basis on it.
m = MeshTri.load(mesh_name)
e = ElementTriP1()  # or ElementQuad1
basis = Basis(m, e)
# DOF coordinates as a (2, n) array; for P1 elements the DOFs sit at the mesh nodes.
mesh_vertices = basis.doflocs
mesh_vertex_num = mesh_vertices.shape[1]

# vertex_index_conversion[i] = index in `shell_vertices` of the vertex at mesh DOF i.
# Shell coordinates are read once here, instead of calling topologicpy in the inner loop.
shell_xy = np.array([[Vertex.X(sv), Vertex.Y(sv)] for sv in shell_vertices], dtype=float).reshape(-1, 2)
vertex_index_conversion = []
for i in range(mesh_vertex_num):
    mx, my = mesh_vertices[:,i]
    matches = np.flatnonzero((np.abs(shell_xy[:, 0] - mx) < 1e-6) & (np.abs(shell_xy[:, 1] - my) < 1e-6))
    if matches.size == 0:
        # Skipping would misalign every later index, so fail loudly instead.
        raise ValueError(f"Mesh vertex {i} at ({mx}, {my}) has no matching shell vertex")
    vertex_index_conversion.append(int(matches[0]))  # first match, as before

## Main simulation

In [None]:
# Concentration field (P1 coefficients), starting from zero everywhere.
x = basis.zeros()
vertex_value = [x]   # one solution vector per time step

@BilinearForm
def a(u, v, w):
    # Implicit-Euler operator: first-order decay (rho), a sink term from w['prev'], and diffusion (D).
    return (1+dt*rho+dt*w['prev'])*u*v+D*dt*dot(grad(u), grad(v))

@LinearForm
def L(v,w):
    # Right-hand side: the previous concentration, with this step's emissions already added.
    return w['prev']* v

for i in range(len(time_steps)-1):
    # Sink field: one point source per person present, weighted by (1 - eta) * p.
    # Presumably this is the virus removed by inhalation.
    x0 = basis.zeros()
    for person in agents:
        px, py = person.trajectory[i+1]
        if np.isnan(px):
            continue  # this agent is not in the courtroom at step i+1
        src = basis.point_source(np.array([px, py]))
        current = person.event_list[i+1].activity
        x0 = x0+src*(1-person.eta)*current.p
        if person.infectious:
            # Emission from the infectious person, spread over the room height.
            x = x+src*person.mu*(1-person.eta)*current.Rt*dt/room_height
    x = solve(a.assemble(basis, prev=x0), L.assemble(basis, prev=x))
    vertex_value.append(x)

## Infection risk

In [None]:
# Reset the cumulative dose and risk. Without this, re-running the cell appends to old
# arrays and the final vstack fails on mismatched lengths.
for person in agents:
    person.virus_load = np.array([0])
    person.risk = np.array([0])

for i in range(len(time_steps)-1):
    for person in agents:
        px, py = person.trajectory[i+1]
        if not person.infectious and not np.isnan(px):
            # Interpolate the concentration field at the person's position.
            agent_probe = basis.probes(np.vstack([px, py]))
            agent_loc_concentration = agent_probe @ vertex_value[i+1]
            # Accumulate the inhaled dose (reduced by the mask), then apply the
            # exponential dose-response model.
            virus_load = person.virus_load[-1]+(1-person.eta)*person.event_list[i+1].activity.p*dt*agent_loc_concentration
            person.virus_load = np.append(person.virus_load, virus_load)
            person.risk = np.append(person.risk, 1-np.exp(-virus_load*person.I))
        else:
            # Infectious or absent: carry the previous values forward.
            person.virus_load = np.append(person.virus_load, person.virus_load[-1])
            person.risk = np.append(person.risk, person.risk[-1])

# save simulation results
# Rows: time_steps, then each agent's risk curve (NaN for the infectious agent).
infection_risk_matrix = np.vstack([time_steps] + [
    np.full(len(time_steps), nan) if person.infectious else person.risk
    for person in agents
])

## Plotting

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# ---- Global figure style ----------------------------------------------------
# font.size is fixed first so the derived sizes below are computed from it.
font_size = 8
matplotlib.rcParams.update({
    # Size and resolution: 3.5 in wide, 4:3 aspect ratio.
    'figure.figsize': (3.5, 3.5*3/4),
    'savefig.dpi': 400,
    'figure.dpi': 200,
    # Fonts. text.usetex stays False: enabling it needs LaTeX and dvipng installed.
    'font.family': 'sans-serif',
    'mathtext.fontset': 'cm',
    'text.usetex': False,
    'font.size': font_size,
    'axes.labelsize': font_size,
    'axes.titlesize': 1.2*font_size,
    'legend.fontsize': 0.8*font_size,
    'xtick.labelsize': 0.8*font_size,
    'ytick.labelsize': 0.8*font_size,
    # Subplot placement (matplotlib defaults: left .125, right .9, bottom .1, top .9, spaces .2).
    'figure.subplot.left': 0.18,
    'figure.subplot.right': 0.96,
    'figure.subplot.bottom': 0.13,
    'figure.subplot.top': 0.91,
    'figure.subplot.wspace': 0,
    'figure.subplot.hspace': 0,
    # Ticks.
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.major.pad': 4,
    'xtick.major.size': 3,
    'xtick.major.width': 1,
    'ytick.major.pad': 4,
    'ytick.major.size': 3,
    'ytick.major.width': 1,
    # Tick label formatting.
    'axes.formatter.limits': (-3, 3),
    'axes.formatter.use_mathtext': True,
    # Legend.
    'legend.frameon': False,
    'legend.loc': 'upper left',
    'legend.numpoints': 1,
    # Lines and markers.
    'axes.linewidth': 1,
    'lines.linewidth': 1,
    'lines.markersize': 2,
})

# Simulation time (s) -> clock hours; t = 0 is drawn at 2:00 pm.
time_steps_hour = [time/3600+2 for time in time_steps]
# P1's intervals in clock hours. The odd-indexed ones (shaded below) are the hearing
# sessions, when P5-P9 are also present.
event_hour = [[time/3600+2 for time in element] for element in p1]

# Risk curves: (agent index, style); all curves are black.
curve_styles = [
    (1, dict(linestyle='solid', label='P2')),
    (2, dict(linestyle='dashed', label='P3')),
    (3, dict(linestyle='dotted', label='P4')),
    (4, dict(label='P5', marker='v', markevery=10)),
    (5, dict(label='P6', marker='o', markevery=10)),
    (9, dict(linestyle='dashdot', label='P10')),
]
for idx, style in curve_styles:
    plt.plot(time_steps_hour, agents[idx].risk, color='black', **style)
plt.plot([2,5],[0.5,0.5],color='grey',alpha=0.5)   # 50 % risk reference line
for span in event_hour[1::2]:
    plt.fill_between(span, 0, 1,color='lightgrey',alpha=0.3)
plt.annotate('break',(3.1,0.85),xytext=(2.6,0.85),horizontalalignment='right',
     verticalalignment='center',arrowprops=dict(arrowstyle='->'))
plt.annotate('hearing',(2.85,0.75),xytext=(2.6,0.75),horizontalalignment='right',
     verticalalignment='center',arrowprops=dict(arrowstyle='->'))

plt.text(2.2,0.5,'50%', horizontalalignment='center', verticalalignment='bottom')
# Curve labels 10 steps before the end: (agent index, vertical offset, alignment, text).
label_specs = [
    (1, 0.01, 'bottom', r'$\mathrm{P_2}$'),
    (3, -0.02, 'top', r'$\mathrm{P_4}$'),
    (2, 0.01, 'bottom', r'$\mathrm{P_3}$'),
    (4, -0.02, 'top', r'$\mathrm{P_5}$'),
    (5, 0.01, 'bottom', r'$\mathrm{P_6}$'),
    (9, -0.005, 'top', r'$\mathrm{P_{10}}$'),
]
for idx, dy, valign, text in label_specs:
    plt.text(time_steps_hour[-10], agents[idx].risk[-10]+dy, text, horizontalalignment='center', verticalalignment=valign)

plt.xlim(2,5)
plt.xticks([2, 3, 4, 5], ['2:00pm', '3:00pm', '4:00pm', '5:00pm'])
plt.ylim(0,1)
plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], ['0', '20%', '40%', '60%', '80%', '100%'])
plt.ylabel('Infection risk')
plt.savefig('./courtroom_risk.pdf')
plt.show()

In [None]:
# Plotly traces for the exterior walls.
floor_data = Plotly.DataByTopology(walls)
Rmax = 300            # upper end of the concentration colour scale
# Time step whose positions are drawn (step 100 = 100 min, when all ten people are present).
# NOTE: the concentration field and the risk colours below use the final step (tt = -1).
snapshot_index = 100

t = time_steps[-1]
tt = -1

# Group the agent glyphs by status: infectious (red), risk > 50 % (orange), others (green).
agent_infectious = []
agent_infected = []
agent_healthy = []
for person in agents:
    x_pos, y_pos = person.trajectory[snapshot_index]
    geom = Topology.Place(person.rep, originA=Vertex.ByCoordinates(0, 0, 0), originB=Vertex.ByCoordinates(x_pos, y_pos, 0))
    if person.infectious:
        agent_infectious.append(geom)
    elif person.risk[-1] > 0.5:
        agent_infected.append(geom)
    else:
        agent_healthy.append(geom)
data = floor_data
if agent_infectious:
    data = data + Plotly.DataByTopology(Cluster.ByTopologies(agent_infectious),faceColor='red',faceOpacity=1)
# Guarded by its own list (previously this checked agent_infectious, so an empty
# agent_infected list was still passed to Cluster.ByTopologies).
if agent_infected:
    data = data + Plotly.DataByTopology(Cluster.ByTopologies(agent_infected),faceColor='orange',faceOpacity=1)
if agent_healthy:
    data = data + Plotly.DataByTopology(Cluster.ByTopologies(agent_healthy),faceColor='#00FF00',faceOpacity=1)

# floor contours: copy the selected concentration field onto the shell vertices
for i in range(mesh_vertex_num):
    dic_value = Dictionary.ByKeysValues(keys, [vertex_value[tt][i]])
    Topology.SetDictionary(shell_vertices[vertex_index_conversion[i]],dic_value)
virus_data = Plotly.DataByTopology(shell, intensityKey = keys[0], colorScale='plasma', faceOpacity=1, intensities=[x*Rmax/200 for x in range(200)], showVertices=False, showEdges=False)
data = data+virus_data

figure = Plotly.FigureByData(data, backgroundColor = 'white')
# Name labels: P7 and P9 (indices 6 and 8) are placed below their glyph, all others above.
annotations = []
for i, person in enumerate(agents):
    x_pos, y_pos = person.trajectory[snapshot_index]
    y_offset = -0.4 if i in (6, 8) else 0.4
    annotations.append(dict(
        showarrow=False,
        x=x_pos,
        y=y_pos+y_offset,
        z=2.6,
        font=dict(size = 28, color='white', family='Arial'),
        text='P<sub>'+person.name[1:]+'</sub>'))
figure.update_layout(
scene=dict(annotations=annotations),
)
# Top-down orthographic view
figure = Plotly.SetCamera(figure, camera=[0, 0, 2], up=[0,1,0], center = [0,0,0], projection="orthographic")
Plotly.FigureExportToPNG(figure, path='./courtroom_temp.png', width=fig_width, height=fig_height, overwrite=True)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as plt_colors
import matplotlib.patches as patches
# import matplotlib.table as table
from matplotlib import cm
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from PIL import Image
import os
# Render mathtext in DejaVu Sans.
matplotlib.rcParams['mathtext.fontset'] = 'custom'
matplotlib.rcParams['mathtext.rm'] = 'DejaVu Sans'
matplotlib.rcParams['mathtext.it'] = 'DejaVu Sans:italic'
matplotlib.rcParams['mathtext.bf'] = 'DejaVu Sans:bold'
# Crop the exported 3D render; the context manager closes the PNG file handle.
with Image.open('./courtroom_temp.png') as rendered:
    figure = rendered.crop((520,250,1400,830))
fig = plt.figure(figsize=(16,9))
ax1 = plt.subplot2grid(shape=(1, 6), loc=(0, 0),colspan=5)   # render (5/6 of the width)
ax2 = plt.subplot2grid(shape=(1, 6), loc=(0, 5))             # colour-bar column
ax1.imshow(figure)
ax1.axis('off')
# adding a colorbar matching the contour range [0, Rmax]
normalize = plt_colors.Normalize(vmin=0,vmax=Rmax)
axins = inset_axes(ax2, width='20%', height='80%', loc='center')
cbar = fig.colorbar(cm.ScalarMappable(norm=normalize, cmap=cm.plasma), cax=axins, orientation='vertical')
# Raw string: '\m' is not a valid escape sequence (SyntaxWarning on Python 3.12+).
cbar.set_label(r'Infectious aerosol concentration ($\mathrm{particles/m^3}$)',fontsize=20)
cbar.ax.tick_params(labelsize=20) 
ax2.axis('off')
plt.savefig('./courtroom.png',bbox_inches='tight')