In [None]:
import pandas as pd
import plotly.express as px
import os
from PIL import Image

In [None]:
# Load the isolate-level dataset; the plotting code in this file reads its
# 'Species', 'Status', 'Year' and 'antibiotics_class' columns.
africa_data = pd.read_csv(filepath_or_buffer='atlas_africa.csv')

In [None]:
# 1. Compute and plot the yearly resistance trend for each species, then combine the plots into a GIF.

def plot_resistance_trend(africa_data, species_list, output_folder='species_plots', image_width=1200, image_height=800):
    """Save one line plot per species showing the proportion of resistant isolates over time.

    For each species, rows are grouped by ``Year`` and ``antibiotics_class`` and
    the fraction of rows whose ``Status`` equals ``'Resistant'`` is plotted,
    with one line per antibiotic class.

    Parameters
    ----------
    africa_data : pandas.DataFrame
        Isolate-level data containing the columns ``Species``, ``Status``,
        ``Year`` and ``antibiotics_class``.
    species_list : iterable of str
        Values of the ``Species`` column to plot, one image per entry.
    output_folder : str
        Directory the PNG files are written to; created if it does not exist.
    image_width, image_height : int
        Size of each exported image in pixels.

    Notes
    -----
    Writes ``<output_folder>/<species>_resistance_trend.png`` for every species.
    Static export via ``fig.write_image`` needs plotly's image-export backend
    (presumably kaleido) installed in the environment — confirm.
    """
    # exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_folder, exist_ok=True)

    for species in species_list:
        species_data = africa_data[africa_data['Species'] == species]

        # The per-group mean of a boolean flag is the proportion resistant.
        # Dividing two groupby().size() results (the previous approach) yields
        # NaN for groups with no resistant isolates, because their key is absent
        # from the numerator; here such groups correctly get 0.0.
        is_resistant = species_data['Status'].eq('Resistant')
        resistant_prevalence = (
            is_resistant
            .groupby([species_data['Year'], species_data['antibiotics_class']])
            .mean()
            .reset_index(name='Proportion Resistant')
        )

        # Create line plot
        fig = px.line(
            resistant_prevalence,
            x='Year',
            y='Proportion Resistant',
            color='antibiotics_class',
            title=f'Resistance Trend for {species} in Africa',
            labels={'Proportion Resistant': 'Proportion of Resistant Isolates', 'antibiotics_class': 'Antibiotic Class'},
            markers=True
        )

        fig.update_layout(
            width=image_width,
            height=image_height,
            title_font_size=20,
            legend_title_text='Antibiotic Class'
        )

        # Save the plot as an image
        plot_filename = os.path.join(output_folder, f'{species}_resistance_trend.png')
        fig.write_image(plot_filename, width=image_width, height=image_height)
        

def create_gif_from_plots(output_folder, gif_filename='species_trend.gif', duration=3000):
    """Combine every ``.png`` file in *output_folder* into one animated GIF.

    Frames are ordered alphabetically by filename. Any ``.png`` in the folder is
    included, so stale images from earlier runs will also become frames.

    Parameters
    ----------
    output_folder : str
        Directory scanned (non-recursively) for ``.png`` files.
    gif_filename : str
        Path of the GIF to write.
    duration : int
        Display time of each frame, passed to Pillow's GIF writer
        (milliseconds per Pillow's documentation).

    Raises
    ------
    ValueError
        If *output_folder* contains no ``.png`` files.
    """
    from contextlib import ExitStack

    png_paths = [
        os.path.join(output_folder, filename)
        for filename in sorted(os.listdir(output_folder))
        if filename.endswith(".png")
    ]
    if not png_paths:
        # Previously this surfaced as an opaque IndexError on images[0].
        raise ValueError(f"No .png files found in {output_folder!r}; nothing to combine into a GIF.")

    # ExitStack closes every opened image (and its file handle) once the GIF is
    # written; the original never closed them.
    with ExitStack() as stack:
        frames = [stack.enter_context(Image.open(path)) for path in png_paths]
        # Save as GIF; loop=0 makes the animation repeat indefinitely.
        frames[0].save(gif_filename, save_all=True, append_images=frames[1:], duration=duration, loop=0)

# Species to plot; each entry must match a value of the 'Species' column.
species_list = [
    'Pseudomonas aeruginosa',
    'Klebsiella pneumoniae',
    'Staphylococcus aureus',
    'Acinetobacter baumannii',
    'Streptococcus pneumoniae',
    'Haemophilus influenzae',
    'Enterococcus faecium',
    'Neisseria gonorrhoeae',
]

# Write one PNG per species into 'species_plots', then turn that folder's PNGs into a GIF.
plot_resistance_trend(africa_data, species_list=species_list)
create_gif_from_plots(output_folder='species_plots')
