In [14]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.animation import FuncAnimation
from IPython.display import display, Image, clear_output
import ipywidgets as widgets
import os
import networkx as nx
import ndlib.models.ModelConfig as mc
import ndlib.models.epidemics as ep
from ndlib.viz.mpl.DiffusionTrend import DiffusionTrend
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.lines import Line2D  # Import Line2D for custom legend markers

# Directory that run_code writes the animated GIF into. exist_ok makes this a
# single atomic call (no check-then-create race) and a no-op on cell re-runs.
os.makedirs('gifs', exist_ok=True)

# Editable simulation code. The text must define
# ``get_iterations(g, time_steps)`` returning ``(iterations, model)``; the
# default builds an ndlib SIR model (beta=0.5, gamma=0.1, 20% initially
# infected). run_code executes this text whenever it changes.
# Textarea has no ``language`` trait (it is a plain text box), so no
# syntax-highlighting option is passed.
code_input = widgets.Textarea(
    value="""def get_iterations(g, time_steps):
    model = ep.SIRModel(g)
    cfg = mc.Configuration()
    cfg.add_model_parameter('beta', 0.5)
    cfg.add_model_parameter('gamma', 0.1)
    cfg.add_model_parameter("fraction_infected", 0.2)
    model.set_initial_status(cfg)
    
    iterations = model.iteration_bunch(time_steps)
    return iterations, model
""",
    placeholder='Type your code here',
    description='Code:',
    layout=widgets.Layout(width='100%', height='300px'),
)

# Widget that captures everything run_code prints or displays
# (trend plot, GIF, error messages).
output = widgets.Output()

# Number of simulation steps handed to the user's get_iterations.
time_steps_slider = widgets.IntSlider(
    value=20,
    min=1,
    max=100,
    step=1,
    description='Time Steps:',
    style={'description_width': 'initial'},
)

# The network every run simulates on: a random G(n, p) graph with n=10 and
# p=0.2. Its layout is computed once so node positions stay fixed across
# animations.
total_nodes = 10
g = nx.erdos_renyi_graph(total_nodes, 0.2)
pos = nx.spring_layout(g)

def create_legend(ax):
    """Add a legend mapping node colours to SIR states onto *ax*.

    The marker colours are the same RGBA values that ``get_node_colors``
    assigns, so each legend swatch matches the drawn nodes exactly. (The
    matplotlib name ``'green'`` is #008000, which is darker than the pure
    (0, 1, 0) used for recovered nodes.)

    Args:
        ax: matplotlib Axes to attach the legend to (upper-right corner).
    """
    # (label, RGBA) in status-code order 0, 1, 2, then the fallback colour.
    entries = [
        ('Susceptible', (0, 0, 1, 1)),
        ('Infected', (1, 0, 0, 1)),
        ('Recovered', (0, 1, 0, 1)),
        ('Unknown', (0.5, 0.5, 0.5, 1)),
    ]
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor=rgba,
               markersize=10, label=label)
        for label, rgba in entries
    ]
    ax.legend(handles=legend_elements, loc='upper right')

def get_node_colors(status):
    """Translate a node-status mapping into a list of RGBA colours.

    Args:
        status: dict of node -> status code (0 susceptible, 1 infected,
            2 recovered; anything else is treated as unknown).

    Returns:
        A new ``[r, g, b, a]`` list per node, in the dict's iteration order.
    """
    def rgba_for(code):
        if code == 0:
            return [0, 0, 1, 1]  # Susceptible - Blue
        if code == 1:
            return [1, 0, 0, 1]  # Infected - Red
        if code == 2:
            return [0, 1, 0, 1]  # Recovered - Green
        return [0.5, 0.5, 0.5, 1]  # Unknown status - Grey

    return [rgba_for(code) for code in status.values()]

def _load_get_iterations(source):
    """Execute *source* and return the ``get_iterations`` callable it defines.

    NOTE(security): ``exec`` runs arbitrary code. This is acceptable only
    because *source* is typed by the notebook user into ``code_input``.
    """
    # One namespace seeded with the notebook globals (so ``ep``/``mc`` resolve).
    # With separate globals/locals dicts, functions in the snippet could not see
    # the snippet's own helper functions or imports.
    namespace = dict(globals())
    namespace.pop('get_iterations', None)  # must come from the snippet itself
    exec(source, namespace)
    get_iterations = namespace.get('get_iterations')
    if not callable(get_iterations):
        raise NameError("the code must define a function named 'get_iterations'")
    return get_iterations


def _accumulate_node_colors(iterations):
    """Return one node-colour list per iteration, carrying node status forward."""
    # Copy instead of aliasing iterations[0]['status']; merging later
    # iterations must not rewrite the first iteration's recorded status.
    current_state = dict(iterations[0]['status'])
    frames = [get_node_colors(current_state)]
    # Later iterations are merged on top, so they may report only the nodes
    # that changed (presumably how ndlib reports them -- confirm).
    for iteration in iterations[1:]:
        current_state.update(iteration['status'])
        frames.append(get_node_colors(current_state))
    return frames


def _save_network_gif(node_colors, gif_path):
    """Animate ``g`` with one frame per entry of *node_colors* and save it as a GIF."""
    fig, ax = plt.subplots()
    try:
        nodes = nx.draw_networkx_nodes(g, pos, node_color=node_colors[0], ax=ax)
        nx.draw_networkx_edges(g, pos, ax=ax)
        ax.set_axis_off()
        create_legend(ax)
        # A single text artist updated in place each frame.
        timestep_text = ax.text(
            0.05, 0.95, '', transform=ax.transAxes, fontsize=9,
            verticalalignment='top',
            bbox=dict(boxstyle="round", facecolor='wheat', alpha=0.5),
        )

        def update(ii):
            nodes.set_facecolor(node_colors[ii])
            timestep_text.set_text(f'Timestep: {ii}')
            # Return every artist that changed so blitting redraws both.
            return nodes, timestep_text

        # One frame per computed state, independent of the slider value, so a
        # snippet returning a different number of iterations cannot overrun.
        animation = FuncAnimation(fig, update, interval=50,
                                  frames=len(node_colors), blit=True)
        animation.save(gif_path, writer='pillow',
                       savefig_kwargs={'facecolor': 'white'}, fps=3)
    finally:
        # Close even when saving fails, so figures do not accumulate.
        plt.close(fig)


def run_code(change):
    """Observer for ``code_input``: run the user's simulation code and show results.

    Args:
        change: traitlets change dict; ``change['new']`` is the new source text.
            The text must define ``get_iterations(g, time_steps)`` returning
            ``(iterations, model)``.

    The output widget is cleared, then receives the DiffusionTrend plot and the
    animated network GIF, or an error message with the full traceback.
    """
    with output:
        clear_output(wait=True)
        try:
            get_iterations = _load_get_iterations(change['new'])
            iterations, model = get_iterations(g, time_steps_slider.value)
            if not iterations:
                print("get_iterations returned no iterations; nothing to plot.")
                return
            node_colors = _accumulate_node_colors(iterations)

            trends = model.build_trends(iterations)
            DiffusionTrend(model, trends).plot()

            gif_path = 'gifs/test.gif'
            os.makedirs(os.path.dirname(gif_path), exist_ok=True)
            _save_network_gif(node_colors, gif_path)

            if os.path.exists(gif_path):
                display(Image(gif_path))
            else:
                print("GIF was not created.")
        except Exception as e:
            # Boundary handler for a widget callback: report the error inside
            # the output widget instead of letting it escape.
            import traceback
            print("Error executing code:", e)
            traceback.print_exc()
# Re-run the simulation every time the code text changes.
# NOTE: the slider is intentionally not observed. run_code executes
# change['new'] as source text, and for the slider that value would be an int.
# After moving the slider, edit the code to trigger a new run.
code_input.observe(run_code, names='value')

# Render the controls and the output area, in that order.
display(time_steps_slider, code_input, output)

IntSlider(value=20, description='Time Steps:', min=1, style=SliderStyle(description_width='initial'))

Textarea(value='def get_iterations(g, time_steps):\n    model = ep.SIRModel(g)\n    cfg = mc.Configuration()\n…

Output()