In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np


In [None]:
# Raw Africa extract. Later cells read its 'Status', 'Species',
# 'Antibiotics' and 'antibiotics_class' columns.
ATLAS_AFRICA_CSV = 'atlas_africa.csv'
africa_data = pd.read_csv(ATLAS_AFRICA_CSV)

In [None]:
# 2. Percentage resistance per species and antibiotic

# Binary resistance flag: 1 when Status is exactly 'Resistant', 0 otherwise.
# Any other status, including a missing value, counts as not resistant.
# Vectorized comparison instead of a row-wise apply; the result is the same
# int64 0/1 column.
africa_data['Resistance'] = africa_data['Status'].eq('Resistant').astype('int64')

# The mean of a 0/1 flag is the resistant fraction; * 100 turns it into a
# percentage. Rows are species, columns are antibiotics. A cell is NaN when
# the species/antibiotic pair has no rows in africa_data.
pivot_data = africa_data.pivot_table(index='Species',
                                     columns='Antibiotics',
                                     values='Resistance',
                                     aggfunc='mean') * 100

# Heatmap column order: one entry per antibiotic, grouped by class.
# - dropna on 'Antibiotics': pivot_table drops NaN keys, so a NaN antibiotic
#   here would make the column selection below raise KeyError.
# - drop_duplicates on 'Antibiotics' alone, not on the (antibiotic, class)
#   pair. Otherwise an antibiotic listed under two classes would produce
#   duplicated heatmap columns and a non-unique index in antibiotic_colors.
#   The first class seen in the file is kept.
# - Sorting by class, then by name, makes the order deterministic. The default
#   sort is not stable, so same-class antibiotics could otherwise be shuffled.
antibiotics_class_data = (
    africa_data[['Antibiotics', 'antibiotics_class']]
    .dropna(subset=['Antibiotics'])
    .drop_duplicates(subset='Antibiotics')
    .sort_values(by=['antibiotics_class', 'Antibiotics'])
)
sorted_antibiotics = antibiotics_class_data['Antibiotics'].values

pivot_data = pivot_data[sorted_antibiotics]

# Distinct classes in display order. A NaN class, if present, sorts last and
# still gets a colour of its own.
unique_classes = antibiotics_class_data['antibiotics_class'].unique()

# NOTE: 'tab20' has 20 distinct colours. With more classes, seaborn cycles the
# palette and colours repeat.
color_palette = sns.color_palette('tab20', len(unique_classes))
class_color_map = dict(zip(unique_classes, color_palette))

# Antibiotic name -> colour of its class. The index is unique thanks to the
# dedupe above.
antibiotic_colors = (
    antibiotics_class_data
    .set_index('Antibiotics')['antibiotics_class']
    .map(class_color_map)
)

# Heatmap of percentage resistance (species x antibiotic). X tick labels are
# coloured by antibiotic class, and a legend maps colours back to classes.
fig, ax = plt.subplots(figsize=(16, 10))
sns.heatmap(pivot_data, cmap='coolwarm',
            cbar_kws={'label': 'Percentage Resistance (%)'},
            linewidths=0.5, ax=ax)

ax.set_title('MDR Pattern in Africa (Percentage Resistance by Species)')
ax.set_xlabel('Antibiotics')
ax.set_ylabel('Species')

# Colour each tick label by its antibiotic's class. The colour is looked up by
# the label's own text instead of zipping with pivot_data.columns: with the
# default xticklabels='auto', seaborn may label only every n-th column, and a
# positional zip would then colour the wrong antibiotics.
for xtick in ax.get_xticklabels():
    color = antibiotic_colors.get(xtick.get_text())
    # Palette colours are RGB tuples. Anything else (an unknown label, a
    # missing colour) is left at the default colour.
    if isinstance(color, tuple):
        xtick.set_color(color)

# Legend entries only for real classes. Handles and labels are filtered
# together so they stay aligned. pd.notna catches every NaN, unlike an
# identity check against np.nan.
legend_classes = [cls for cls in unique_classes if pd.notna(cls)]
legend_handles = [plt.Line2D([0], [0], color=class_color_map[cls], lw=4)
                  for cls in legend_classes]
ax.legend(legend_handles, legend_classes, title='Antibiotics Classes',
          bbox_to_anchor=(1.15, 1.1), loc='upper left', fontsize='small')

# Leave room on the right for the legend placed outside the axes.
fig.tight_layout(rect=[0, 0, 0.9, 1])
fig.savefig('mdr_pattern_africa.png', bbox_inches='tight')

plt.show()
