In [None]:
#4
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
import imageio

In [None]:
# Load the ATLAS extract for Africa into one DataFrame. low_memory=False makes
# pandas infer each column's dtype from the whole file instead of per chunk,
# which avoids mixed-type columns (and the DtypeWarning) on wide CSVs.
africa_data = pd.read_csv(
    'atlas_africa.csv',
    low_memory=False,
)

In [None]:
# Keep only records whose Status is 'Resistant'; every value plotted below is
# computed within this subset.
resistant_data = africa_data[africa_data['Status'] == 'Resistant']

species_list = resistant_data['Species'].unique()

# PNG frames written below, in the order they are stitched into the GIF.
image_paths = []

# Subplot grid per figure. A species with more antibiotic classes than grid
# cells is split across several figures ("pages").
rows, cols = 3, 3
max_subplots = rows * cols

# Path separators and characters reserved in Windows file names.
_UNSAFE_FILENAME_CHARS = '<>:"/\\|?*'


def _safe_filename_part(text):
    """Return ``str(text)`` with path-separator/reserved characters replaced by '_'.

    Species names come straight from the CSV; a '/' in one would otherwise be
    interpreted as a directory when the PNG is written. Other characters
    (including spaces) are kept unchanged.
    """
    return ''.join('_' if ch in _UNSAFE_FILENAME_CHARS else ch for ch in str(text))


def _class_share_by_infection(species_data):
    """Return one row per (infection type, antibiotics_class) with a 'Resistance Rate' column.

    'Resistance Rate' = rows for that (infection type, class) pair divided by
    rows for that infection type. Because the input has already been restricted
    to resistant records, this is the share of resistant records of an
    infection type that fall in each antibiotic class -- not a resistance rate
    over all tested isolates.
    """
    pair_counts = species_data.groupby(['infection type', 'antibiotics_class']).size()
    type_counts = species_data.groupby('infection type').size()
    # Broadcast the per-infection-type totals across the matching MultiIndex
    # level explicitly instead of relying on implicit index alignment.
    share = pair_counts.div(type_counts, level='infection type')
    return share.reset_index(name='Resistance Rate')


for species in species_list:
    species_data = resistant_data[resistant_data['Species'] == species]
    infection_resistance_df = _class_share_by_infection(species_data)

    unique_classes = infection_resistance_df['antibiotics_class'].unique()
    # Can be empty when every row lacks an infection type or antibiotic class
    # (groupby drops missing keys).
    if len(unique_classes) == 0:
        continue

    # One figure per page of up to max_subplots antibiotic classes.
    for i in range(0, len(unique_classes), max_subplots):
        batch_classes = unique_classes[i:i + max_subplots]

        fig = make_subplots(
            rows=rows,
            cols=cols,
            # Plain list of str: class labels are not guaranteed to be strings.
            subplot_titles=[str(c) for c in batch_classes],
            vertical_spacing=0.3,
        )

        for idx, antibiotic_class in enumerate(batch_classes):
            class_data = infection_resistance_df[
                infection_resistance_df['antibiotics_class'] == antibiotic_class
            ]
            class_data_sorted = class_data.sort_values('Resistance Rate', ascending=False)

            # Fill the grid row by row (1-based for plotly).
            row = idx // cols + 1
            col = idx % cols + 1

            palette = px.colors.qualitative.Plotly
            trace = go.Bar(
                x=class_data_sorted['infection type'],
                y=class_data_sorted['Resistance Rate'],
                name=antibiotic_class,
                marker=dict(color=palette[idx % len(palette)]),
            )
            fig.add_trace(trace, row=row, col=col)
            fig.update_xaxes(tickangle=45, row=row, col=col)

        fig.update_layout(
            title_text=f'Resistance by Infection Type for {species}',
            barmode='stack',
            showlegend=False,
            height=1000,
            width=1200,
        )

        page_number = i // max_subplots + 1
        image_path = f'resistance_{_safe_filename_part(species)}_{page_number}.png'
        fig.write_image(image_path)
        image_paths.append(image_path)

if image_paths:
    # NOTE(review): the unit of `duration` depends on the installed imageio
    # version (milliseconds with the Pillow-backed GIF writer in recent
    # releases, seconds in the legacy GIF-PIL plugin) -- confirm 3000 gives
    # the intended per-frame delay.
    with imageio.get_writer('resistance_by_species.gif', mode='I', duration=3000) as writer:
        for image_path in image_paths:
            writer.append_data(imageio.imread(image_path))

    # Static preview: imageio.imread returns a single image, not the animation.
    plt.imshow(imageio.imread('resistance_by_species.gif'))
    plt.axis('off')
    plt.show()
else:
    # Writing and then re-reading an empty GIF would fail; skip it instead.
    print('No resistant records to plot; resistance_by_species.gif not written.')
