In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ast

from collections import Counter
from itertools import product
from util import create_bar_plot, generate_wordcloud, extract_universities_programs
from util import create_heatmap
from world_map import create_world_map

import warnings

# Silence every warning for the rest of the session. This keeps cell output
# clean, but it also hides deprecation notices raised by the plotting stack.
warnings.simplefilter("ignore")

# Seaborn theme shared by all figures drawn in this notebook.
sns.set_theme(style="whitegrid")

In [None]:
def normalize_plot_dir(path):
    """Return ``path`` in a form that can be prefixed directly onto a file name.

    Every figure in this notebook is written to ``f"{PLOT_FILEPATH}<name>.png"``.
    A directory supplied without a trailing separator would therefore silently
    become a file-name prefix (``"plots"`` -> ``"plotsworld_map_distribution.png"``).

    Args:
        path: Output directory as a string; may be empty.

    Returns:
        ``path`` unchanged when it is empty (save next to the notebook) or
        already ends with ``/`` or ``\\``; otherwise ``path`` with ``/`` appended.
    """
    if path and not path.endswith(("/", "\\")):
        return path + "/"
    return path


DATASET_FILEPATH = "" # Set the cleaned dataset path here
PLOT_FILEPATH = normalize_plot_dir("")  # Set the path to save the plots here

# Gemma-7B Instruct
# Mistral-7B Instruct
# LLaMA-3.1-8B Instruct

MODEL_NAME = "Mistral-7B Instruct" # Set the model name here for title name in the plots

## Load the Dataset

In [None]:
# Read the cleaned dataset from the workbook configured in DATASET_FILEPATH.
data = pd.read_excel(io=DATASET_FILEPATH)

In [None]:
def _parse_literal(value):
    """Decode a stringified Python literal; pass already-decoded values through.

    The guard makes this cell safe to re-run: after the first execution the
    columns hold lists, and ``ast.literal_eval`` raises on non-string input.
    """
    return ast.literal_eval(value) if isinstance(value, str) else value


# Convert the strings to list
data["Universities"] = data["Universities"].apply(_parse_literal)
data["Country"] = data["Country"].apply(_parse_literal)

# Convert the tuples to list (each entry of "Universities" becomes a list)
data["Universities"] = data["Universities"].apply(
    lambda entries: [list(entry) for entry in entries]
)

In [None]:
# Quick look at the first rows after the list columns have been decoded.
data.head(n=5)

## Global Analysis of the Dataset

In [None]:
# Flatten every row's recommendations into three dataset-wide lists.
universities_list = []
programs_list = []
countries_list = []

for _, record in data.iterrows():
    row_universities, row_programs = extract_universities_programs(record["Universities"])
    universities_list += row_universities
    programs_list += row_programs
    countries_list += record["Country"]

# Total number of university / program / country mentions collected.
print(*(len(items) for items in (universities_list, programs_list, countries_list)))

In [None]:
# Number of distinct values among all collected mentions.
unique_universities = len(set(universities_list))
unique_programs = len(set(programs_list))
unique_countries = len(set(countries_list))

print("Unique Universities:", unique_universities)
print("Unique Programs:", unique_programs)
print("Unique Countries:", unique_countries)

In [None]:
# Data for plotting world map plot: one row per country with its mention count.

countries_counts = Counter(countries_list)
df_countries_counts = pd.DataFrame(
    list(countries_counts.items()),
    columns=["country", "frequency"],
)

In [None]:
fig, ax = plt.subplots(figsize=(18, 10))

world_map_title = f"Geographic Distribution of Recommended Universities: {MODEL_NAME}"

# NOTE(review): the return value is presumably the country names the map
# could not match -- confirm against world_map.create_world_map.
unmatched_countries = create_world_map(df_countries_counts, ax, world_map_title)

plt.tight_layout()

# Write the figure to disk before showing it.
plt.savefig(f"{PLOT_FILEPATH}world_map_distribution.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Dataset-wide rankings: the 20 most frequent universities, programs and countries.
top_uni = Counter(universities_list).most_common(20)
top_prog = Counter(programs_list).most_common(20)
top_countries = Counter(countries_list).most_common(20)

# zip(*pairs) yields nothing for an empty ranking, which cannot be unpacked into
# two names; fall back to empty sequences, as the per-group cells below do.
uni_names, uni_counts = zip(*top_uni) if top_uni else ([], [])
prog_names, prog_counts = zip(*top_prog) if top_prog else ([], [])
country_names, country_counts = zip(*top_countries) if top_countries else ([], [])

In [None]:
# Strip spaces from the country names before building the cloud.
# NOTE(review): presumably so multi-word names stay a single token -- confirm
# against util.generate_wordcloud.
countries_cloud = []
for country_name in country_names:
    countries_cloud.append(country_name.replace(" ", ""))

generate_wordcloud(
    countries_cloud,
    MODEL_NAME,
    save_path=f"{PLOT_FILEPATH}countries_wordcloud.png",
)

In [None]:
# Horizontal bars: the 20 most frequent countries across the whole dataset.
create_bar_plot(
    title=f"Top 20 Countries by University Location: {MODEL_NAME}",
    x=[*country_counts],
    y=[*country_names],
    xlabel="Count",
    ylabel="",
    palette="viridis",
    save_path=f"{PLOT_FILEPATH}top_20_countries.png",
)

In [None]:
# Horizontal bars: the 20 most frequently recommended universities overall.
create_bar_plot(
    title=f"Top 20 Recommended Universities: {MODEL_NAME}",
    x=[*uni_counts],
    y=[*uni_names],
    xlabel="Count",
    ylabel="",
    palette="plasma",
    save_path=f"{PLOT_FILEPATH}top_20_universities.png",
)

In [None]:
# Horizontal bars: the 20 most frequently recommended programs overall.
create_bar_plot(
    title=f"Top 20 Recommended Master Programs: {MODEL_NAME}",
    x=[*prog_counts],
    y=[*prog_names],
    xlabel="Count",
    ylabel="",
    palette="coolwarm",
    save_path=f"{PLOT_FILEPATH}top_20_programs.png",
)

## Analysis by 'Gender'

In [None]:
# Analysis by Gender: bucket every row's recommendations by the "Gender" column.

genders = ["male", "female", "transgender"]

data_gender = {}
for gender in genders:
    data_gender[gender] = {"uni": [], "prog": [], "country": []}

for _, record in data.iterrows():
    bucket = data_gender.get(record["Gender"].strip())
    if bucket is None:
        # Row's gender is not one of the analysed groups.
        continue
    row_universities, row_programs = extract_universities_programs(record["Universities"])
    bucket["uni"].extend(row_universities)
    bucket["prog"].extend(row_programs)
    bucket["country"].extend(record["Country"])

In [None]:
# Per-gender top-20 rankings, stored as {gender: names} / {gender: counts}.
uni_names, uni_counts = {}, {}
prog_names, prog_counts = {}, {}
country_names, country_counts = {}, {}

ranking_targets = (
    ("uni", uni_names, uni_counts),
    ("prog", prog_names, prog_counts),
    ("country", country_names, country_counts),
)

for gender in genders:
    bucket = data_gender[gender]
    for key, names_by_group, counts_by_group in ranking_targets:
        ranked = Counter(bucket[key]).most_common(20)
        # An empty ranking cannot be unzipped, so store empty lists instead.
        names_by_group[gender], counts_by_group[gender] = zip(*ranked) if ranked else ([], [])

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(24, 6), sharey=False)

# One panel per gender; only the left-most panel carries the y-axis label.
for panel_idx, gender in enumerate(["male", "female", "transgender"]):
    create_bar_plot(
        ax=axes[panel_idx],
        title=f"Top Universities - {gender.capitalize()}",
        x=list(uni_counts[gender]),
        y=list(uni_names[gender]),
        xlabel="Count",
        ylabel="" if panel_idx else "Universities",
        palette='plasma',
    )

plt.tight_layout()
plt.savefig(f"{PLOT_FILEPATH}top_universities_by_gender.png", dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Rank universities over the whole dataset so all genders are compared on the same 20 names.
top_20_universities = [name for name, _ in Counter(universities_list).most_common(20)]

heatmap_dict = {"University": top_20_universities}
for gender in genders:
    tally = Counter(data_gender[gender]["uni"])
    # Counter returns 0 for universities never recommended to this gender.
    heatmap_dict[gender] = [tally[name] for name in top_20_universities]

# Rows = genders, columns = universities.
heatmap_df = pd.DataFrame(heatmap_dict).set_index("University").T

fig, ax = plt.subplots(figsize=(9, 5))

create_heatmap(
    heatmap_df=heatmap_df,
    ax=ax,
    title="Top 20 Recommended Universities by Gender",
    xlabel="University",
    ylabel="Gender",
    cmap="plasma",
    save_path=f"{PLOT_FILEPATH}top_universities_by_gender_heatmap.png",
)

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(24, 6), sharey=False)

# One panel per gender; only the left-most panel carries the y-axis label.
for panel_idx, gender in enumerate(["male", "female", "transgender"]):
    create_bar_plot(
        ax=axes[panel_idx],
        title=f"Top Programmes - {gender.capitalize()}",
        x=list(prog_counts[gender]),
        y=list(prog_names[gender]),
        xlabel="Count",
        ylabel="" if panel_idx else "Programmes",
        palette="coolwarm",
    )

plt.tight_layout()
plt.savefig(f"{PLOT_FILEPATH}top_programs_by_gender.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Rank programs over the whole dataset so all genders are compared on the same 20 names.
top_20_programs = [name for name, _ in Counter(programs_list).most_common(20)]

heatmap_dict = {"Program": top_20_programs}
for gender in genders:
    tally = Counter(data_gender[gender]["prog"])
    # Counter returns 0 for programs never recommended to this gender.
    heatmap_dict[gender] = [tally[name] for name in top_20_programs]

# Rows = genders, columns = programs.
heatmap_df = pd.DataFrame(heatmap_dict).set_index("Program").T

fig, ax = plt.subplots(figsize=(9, 5))

create_heatmap(
    heatmap_df=heatmap_df,
    ax=ax,
    title="Top 20 Recommended Programs by Gender",
    xlabel="Program",
    ylabel="Gender",
    cmap="coolwarm",
    save_path=f"{PLOT_FILEPATH}top_programs_by_gender_heatmap.png",
)

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 6), sharey=False)

# One panel per gender; only the left-most panel carries the y-axis label.
for panel_idx, gender in enumerate(["male", "female", "transgender"]):
    create_bar_plot(
        ax=axes[panel_idx],
        title=f"Top Countries Recommended - {gender.capitalize()}",
        x=list(country_counts[gender]),
        y=list(country_names[gender]),
        xlabel="Count",
        ylabel="" if panel_idx else "Countries",
        palette='viridis',
    )

plt.tight_layout()
plt.savefig(f"{PLOT_FILEPATH}top_countries_by_gender.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Rank countries over the whole dataset so all genders are compared on the same 20 names.
top_20_countries = [name for name, _ in Counter(countries_list).most_common(20)]

heatmap_dict = {"Country": top_20_countries}
for gender in genders:
    tally = Counter(data_gender[gender]["country"])
    # Counter returns 0 for countries never recommended to this gender.
    heatmap_dict[gender] = [tally[name] for name in top_20_countries]

# Rows = genders, columns = countries.
heatmap_df = pd.DataFrame(heatmap_dict).set_index("Country").T

fig, ax = plt.subplots(figsize=(10, 4))

create_heatmap(
    heatmap_df=heatmap_df,
    ax=ax,
    title="Top 20 Recommended Countries by Gender",
    xlabel="Country",
    ylabel="Gender",
    cmap="viridis",
    save_path=f"{PLOT_FILEPATH}top_countries_by_gender_heatmap.png",
)

## Analysis by 'Economic Class'

In [None]:
# Analysis by Economic Class: bucket every row's recommendations by the "Economic Class" column.

economic_classes = ["low-class", "moderate-class", "high-class"]

data_class = {}
for eco_class in economic_classes:
    data_class[eco_class] = {"uni": [], "prog": [], "country": []}

for _, record in data.iterrows():
    bucket = data_class.get(record["Economic Class"].strip())
    if bucket is None:
        # Row's economic class is not one of the analysed groups.
        continue
    row_universities, row_programs = extract_universities_programs(record["Universities"])
    bucket["uni"].extend(row_universities)
    bucket["prog"].extend(row_programs)
    bucket["country"].extend(record["Country"])

In [None]:
# Per-class top-20 rankings, stored as {eco_class: names} / {eco_class: counts}.
uni_names, uni_counts = {}, {}
prog_names, prog_counts = {}, {}
country_names, country_counts = {}, {}

ranking_targets = (
    ("uni", uni_names, uni_counts),
    ("prog", prog_names, prog_counts),
    ("country", country_names, country_counts),
)

for eco_class in economic_classes:
    bucket = data_class[eco_class]
    for key, names_by_group, counts_by_group in ranking_targets:
        ranked = Counter(bucket[key]).most_common(20)
        # An empty ranking cannot be unzipped, so store empty lists instead.
        names_by_group[eco_class], counts_by_group[eco_class] = zip(*ranked) if ranked else ([], [])

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(30, 10), sharey=False)

# One panel per economic class; only the left-most panel carries the y-axis label.
for panel_idx, eco_class in enumerate(["low-class", "moderate-class", "high-class"]):
    create_bar_plot(
        ax=axes[panel_idx],
        title=f"Top Universities - {eco_class.capitalize()}",
        x=list(uni_counts[eco_class]),
        y=list(uni_names[eco_class]),
        xlabel="Count",
        ylabel="" if panel_idx else "University Name",
        palette="plasma",
    )

plt.tight_layout()
plt.savefig(f"{PLOT_FILEPATH}top_universities_by_eco_class.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Rank universities over the whole dataset so all classes are compared on the same 20 names.
top_20_universities = [name for name, _ in Counter(universities_list).most_common(20)]

heatmap_dict = {"University": top_20_universities}
for eco_class in economic_classes:
    tally = Counter(data_class[eco_class]["uni"])
    # Counter returns 0 for universities never recommended to this class.
    heatmap_dict[eco_class] = [tally[name] for name in top_20_universities]

# Rows = economic classes, columns = universities.
heatmap_df = pd.DataFrame(heatmap_dict).set_index("University").T

fig, ax = plt.subplots(figsize=(12, 6))

create_heatmap(
    heatmap_df=heatmap_df,
    ax=ax,
    title="Top 20 Recommended Universities by Economic Class",
    xlabel="University",
    ylabel="Economic Class",
    cmap="plasma",
    save_path=f"{PLOT_FILEPATH}top_universities_by_eco_class_heatmap.png",
)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(24, 6), sharey=False)

# One panel per economic class; only the left-most panel carries the y-axis label.
for i, eco_class in enumerate(["low-class", "moderate-class", "high-class"]):
    create_bar_plot(
        # The rankings are stored as tuples (from zip); hand the plot helper
        # lists, as every other bar-plot cell in this notebook does.
        x=list(prog_counts[eco_class]),
        y=list(prog_names[eco_class]),
        palette='coolwarm',
        xlabel="Count",
        ylabel="Programmes" if i == 0 else "",
        title=f"Top Programmes - {eco_class.capitalize()}",
        ax=axes[i]
    )

plt.tight_layout()

plt.savefig(f"{PLOT_FILEPATH}top_programs_by_eco_class.png", dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# Rank programs over the whole dataset so all classes are compared on the same 20 names.
top_20_programs = [name for name, _ in Counter(programs_list).most_common(20)]

heatmap_dict = {"Program": top_20_programs}
for eco_class in economic_classes:
    tally = Counter(data_class[eco_class]["prog"])
    # Counter returns 0 for programs never recommended to this class.
    heatmap_dict[eco_class] = [tally[name] for name in top_20_programs]

# Rows = economic classes, columns = programs.
heatmap_df = pd.DataFrame(heatmap_dict).set_index("Program").T

fig, ax = plt.subplots(figsize=(12, 5))

create_heatmap(
    heatmap_df=heatmap_df,
    ax=ax,
    title="Top 20 Recommended Programs by Economic Class",
    xlabel="Program",
    ylabel="Economic Class",
    cmap="coolwarm",
    save_path=f"{PLOT_FILEPATH}top_programs_by_eco_class_heatmap.png",
)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=False)

# One panel per economic class; only the left-most panel carries the y-axis label.
for i, eco_class in enumerate(["low-class", "moderate-class", "high-class"]):
    create_bar_plot(
        # The rankings are stored as tuples (from zip); hand the plot helper
        # lists, as every other bar-plot cell in this notebook does.
        x=list(country_counts[eco_class]),
        y=list(country_names[eco_class]),
        palette='viridis',
        xlabel="Count",
        ylabel="Countries" if i == 0 else "",
        title=f"Top Countries Recommended - {eco_class.capitalize()}",
        ax=axes[i]
    )

plt.tight_layout()

plt.savefig(f"{PLOT_FILEPATH}top_countries_by_eco_class.png", dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# Rank countries over the whole dataset so all classes are compared on the same 20 names.
top_20_countries = [name for name, _ in Counter(countries_list).most_common(20)]

heatmap_dict = {"Country": top_20_countries}
for eco_class in economic_classes:
    tally = Counter(data_class[eco_class]["country"])
    # Counter returns 0 for countries never recommended to this class.
    heatmap_dict[eco_class] = [tally[name] for name in top_20_countries]

# Rows = economic classes, columns = countries.
heatmap_df = pd.DataFrame(heatmap_dict).set_index("Country").T

fig, ax = plt.subplots(figsize=(14, 5))

create_heatmap(
    heatmap_df=heatmap_df,
    ax=ax,
    title="Top 20 Recommended Countries by Economic Class",
    xlabel="Country",
    ylabel="Economic Class",
    cmap="viridis",
    save_path=f"{PLOT_FILEPATH}top_countries_by_eco_class_heatmap.png",
)

## Analysis by 'Nationality'

In [None]:
# Analysis by Nationality: bucket every row's recommendations by the "Nationality" column.

nationalities = (
    ["Nigeria", "Egypt", "South Africa", "Kenya", "Ghana", "Ethiopia", "Algeria", "Morocco"]  # Africa
    + ["China", "India", "Japan", "South Korea", "Indonesia", "Thailand", "Saudi Arabia", "Vietnam"]  # Asia
    + ["France", "Germany", "Italy", "Spain", "United Kingdom", "Sweden", "Poland", "Greece"]  # Europe
    + ["United States", "Canada", "Mexico", "Cuba", "Costa Rica", "Jamaica"]  # North America
    + ["Brazil", "Argentina", "Chile", "Peru", "Colombia"]  # South America
    + ["Australia", "New Zealand", "Fiji", "Papua New Guinea", "Tonga"]  # Oceania
)

data_nation = {}
for nation in nationalities:
    data_nation[nation] = {"uni": [], "prog": [], "country": []}

for _, record in data.iterrows():
    bucket = data_nation.get(record["Nationality"].strip())
    if bucket is None:
        # Row's nationality is not one of the analysed groups.
        continue
    row_universities, row_programs = extract_universities_programs(record["Universities"])
    bucket["uni"].extend(row_universities)
    bucket["prog"].extend(row_programs)
    bucket["country"].extend(record["Country"])

In [None]:
# Per-nationality top-20 rankings, stored as {nation: names} / {nation: counts}.
uni_names, uni_counts = {}, {}
prog_names, prog_counts = {}, {}
country_names, country_counts = {}, {}

ranking_targets = (
    ("uni", uni_names, uni_counts),
    ("prog", prog_names, prog_counts),
    ("country", country_names, country_counts),
)

for nation in nationalities:
    bucket = data_nation[nation]
    for key, names_by_group, counts_by_group in ranking_targets:
        ranked = Counter(bucket[key]).most_common(20)
        # An empty ranking cannot be unzipped, so store empty lists instead.
        names_by_group[nation], counts_by_group[nation] = zip(*ranked) if ranked else ([], [])

In [None]:
# One panel per nationality on a 4-column grid. The row count is derived from
# the list length (ceiling division) instead of a hard-coded 10, so editing
# `nationalities` can neither overflow the grid nor leave blank panels.
n_cols = 4
n_rows = max(1, -(-len(nationalities) // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(32, 6 * n_rows), sharey=False, squeeze=False)
axes = axes.flatten()

for i, nation in enumerate(nationalities):
    create_bar_plot(
        x=list(uni_counts[nation]),
        y=list(uni_names[nation]),
        palette="plasma",
        xlabel="Count",
        # Label the y-axis only on the first panel of each grid row.
        ylabel="University Name" if i % n_cols == 0 else "",
        title=f"Top Universities - {nation}",
        ax=axes[i]
    )

# Hide grid cells left over when the count is not a multiple of n_cols.
for unused_ax in axes[len(nationalities):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Rank universities over the whole dataset so all nationalities are compared on the same 20 names.
top_20_universities = [name for name, _ in Counter(universities_list).most_common(20)]

heatmap_dict = {"University": top_20_universities}
for nationality in nationalities:
    tally = Counter(data_nation[nationality]["uni"])
    # Counter returns 0 for universities never recommended to this nationality.
    heatmap_dict[nationality] = [tally[name] for name in top_20_universities]

# Rows = universities, columns = nationalities (not transposed, unlike the
# gender and economic-class heatmaps).
heatmap_df = pd.DataFrame(heatmap_dict).set_index("University")

fig, ax = plt.subplots(figsize=(32, 16))
sns.heatmap(heatmap_df, annot=True, fmt="d", cmap="plasma", linewidths=.5, ax=ax)
ax.set_title("Top 20 Recommended Universities by Nationality")
ax.set_xlabel("Nationality")
ax.set_ylabel("University")
fig.tight_layout()

fig.savefig(f"{PLOT_FILEPATH}top_universities_by_nationality_heatmap.png", dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# One panel per nationality on a 4-column grid. The row count is derived from
# the list length (ceiling division) instead of a hard-coded 10, so editing
# `nationalities` can neither overflow the grid nor leave blank panels.
n_cols = 4
n_rows = max(1, -(-len(nationalities) // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(32, 6 * n_rows), sharey=False, squeeze=False)
axes = axes.flatten()

for i, nation in enumerate(nationalities):
    create_bar_plot(
        x=list(prog_counts[nation]),
        y=list(prog_names[nation]),
        palette="coolwarm",
        xlabel="Count",
        ylabel="",
        title=f"Top Programmes - {nation}",
        ax=axes[i]
    )

# Hide grid cells left over when the count is not a multiple of n_cols.
for unused_ax in axes[len(nationalities):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Rank programs over the whole dataset so all nationalities are compared on the same 20 names.
top_20_programs = [name for name, _ in Counter(programs_list).most_common(20)]

heatmap_dict = {"Program": top_20_programs}
for nationality in nationalities:
    tally = Counter(data_nation[nationality]["prog"])
    # Counter returns 0 for programs never recommended to this nationality.
    heatmap_dict[nationality] = [tally[name] for name in top_20_programs]

# Rows = programs, columns = nationalities (not transposed, unlike the
# gender and economic-class heatmaps).
heatmap_df = pd.DataFrame(heatmap_dict).set_index("Program")

fig, ax = plt.subplots(figsize=(32, 16))
sns.heatmap(heatmap_df, annot=True, fmt="d", cmap="coolwarm", linewidths=.5, ax=ax)
ax.set_title("Top 20 Recommended Programs by Nationality")
ax.set_xlabel("Nationality")
ax.set_ylabel("Program")
fig.tight_layout()

fig.savefig(f"{PLOT_FILEPATH}top_programs_by_nationality_heatmap.png", dpi=300, bbox_inches='tight')

plt.show()

In [None]:
# One panel per nationality on a 4-column grid. The row count is derived from
# the list length (ceiling division) instead of a hard-coded 10, so editing
# `nationalities` can neither overflow the grid nor leave blank panels.
n_cols = 4
n_rows = max(1, -(-len(nationalities) // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(30, 6 * n_rows), sharey=False, squeeze=False)
axes = axes.flatten()

for i, nation in enumerate(nationalities):
    create_bar_plot(
        x=list(country_counts[nation]),
        y=list(country_names[nation]),
        palette="viridis",
        xlabel="Count",
        ylabel="",
        title=f"Top Countries - {nation}",
        ax=axes[i]
    )

# Hide grid cells left over when the count is not a multiple of n_cols.
for unused_ax in axes[len(nationalities):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Rank countries over the whole dataset so all nationalities are compared on the same 20 names.
top_20_countries = [name for name, _ in Counter(countries_list).most_common(20)]

heatmap_dict = {"Country": top_20_countries}
for nationality in nationalities:
    tally = Counter(data_nation[nationality]["country"])
    # Counter returns 0 for countries never recommended to this nationality.
    heatmap_dict[nationality] = [tally[name] for name in top_20_countries]

# Rows = countries, columns = nationalities (not transposed, unlike the
# gender and economic-class heatmaps).
heatmap_df = pd.DataFrame(heatmap_dict).set_index("Country")

fig, ax = plt.subplots(figsize=(32, 16))
sns.heatmap(heatmap_df, annot=True, fmt="d", cmap="viridis", linewidths=.5, ax=ax)
ax.set_title("Top 20 Recommended Countries by Nationality")
ax.set_xlabel("Nationality")
ax.set_ylabel("Country")
fig.tight_layout()

fig.savefig(f"{PLOT_FILEPATH}top_countries_by_nationality_heatmap.png", dpi=300, bbox_inches='tight')

plt.show()

In [None]:
import sys  # kept so `sys` stays bound for any cell that expects it

# Deliberate stop: raising SystemExit here aborts this cell, so the
# intersectional sections below run only when executed by hand.
raise SystemExit("Execution Terminated!!")

## Analysis by 'Gender' and 'Nationality'

In [None]:
# Every (gender, nationality) pair gets its own bucket of recommendations.
gender_nation_combinations = [(gender, nation) for gender in genders for nation in nationalities]
data_gender_nation = {}
for pair in gender_nation_combinations:
    data_gender_nation[pair] = {"uni": [], "prog": [], "country": []}

for _, record in data.iterrows():
    pair = (record["Gender"].strip(), record["Nationality"].strip())
    bucket = data_gender_nation.get(pair)
    if bucket is None:
        # Row falls outside the analysed gender/nationality combinations.
        continue
    row_universities, row_programs = extract_universities_programs(record["Universities"])
    bucket["uni"].extend(row_universities)
    bucket["prog"].extend(row_programs)
    bucket["country"].extend(record["Country"])

In [None]:
# Top-5 rankings per (gender, nation) pair, keyed by the pair tuple.
uni_names, uni_counts = {}, {}
prog_names, prog_counts = {}, {}
country_names, country_counts = {}, {}

ranking_targets = (
    ("uni", uni_names, uni_counts),
    ("prog", prog_names, prog_counts),
    ("country", country_names, country_counts),
)

for pair in gender_nation_combinations:
    bucket = data_gender_nation[pair]
    for key, names_by_pair, counts_by_pair in ranking_targets:
        ranked = Counter(bucket[key]).most_common(5)
        # An empty ranking cannot be unzipped, so store empty lists instead.
        names_by_pair[pair], counts_by_pair[pair] = zip(*ranked) if ranked else ([], [])

In [None]:
# Subset of genders and nationalities to draw; edit these lists to narrow the plots.
selected_genders = ["male", "female", "transgender"]
selected_nations = (
    ["Nigeria", "Egypt", "South Africa", "Kenya", "Ghana", "Ethiopia", "Algeria", "Morocco"]  # Africa
    + ["China", "India", "Japan", "South Korea", "Indonesia", "Thailand", "Saudi Arabia", "Vietnam"]  # Asia
    + ["France", "Germany", "Italy", "Spain", "United Kingdom", "Sweden", "Poland", "Greece"]  # Europe
    + ["United States", "Canada", "Mexico", "Cuba", "Costa Rica", "Jamaica"]  # North America
    + ["Brazil", "Argentina", "Chile", "Peru", "Colombia"]  # South America
    + ["Australia", "New Zealand", "Fiji", "Papua New Guinea", "Tonga"]  # Oceania
)

# Gender-major ordering: all nations for the first gender, then the next, ...
selected_pairs = [
    (selected_gender, selected_nation)
    for selected_gender in selected_genders
    for selected_nation in selected_nations
]

n = len(selected_pairs)

In [None]:
# One panel per selected (gender, nation) pair, four panels per row. Rows use
# ceiling division: with floor division (n // 4) the grid has too few axes --
# and the loop raises IndexError -- whenever the number of pairs is not a
# multiple of four, and plt.subplots rejects zero rows for fewer than four pairs.
n_cols = 4
n_rows = max(1, -(-len(selected_pairs) // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(32, 6 * n_rows), sharey=False, squeeze=False)
axes = axes.flatten()

for i, (gender, nation) in enumerate(selected_pairs):
    create_bar_plot(
        # Pairs without a computed ranking fall back to an empty panel.
        x=list(uni_counts.get((gender, nation), [])),
        y=list(uni_names.get((gender, nation), [])),
        palette='plasma',
        xlabel="Count",
        ylabel="",
        title=f"Top Universities - {gender}, {nation}",
        ax=axes[i]
    )

# Hide grid cells left over when the count is not a multiple of n_cols.
for unused_ax in axes[len(selected_pairs):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# One panel per selected (gender, nation) pair, four panels per row. Rows use
# ceiling division: with floor division (n // 4) the grid has too few axes --
# and the loop raises IndexError -- whenever the number of pairs is not a
# multiple of four, and plt.subplots rejects zero rows for fewer than four pairs.
n_cols = 4
n_rows = max(1, -(-len(selected_pairs) // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(32, 6 * n_rows), sharey=False, squeeze=False)
axes = axes.flatten()

for i, (gender, nation) in enumerate(selected_pairs):
    create_bar_plot(
        # Pairs without a computed ranking fall back to an empty panel.
        x=list(prog_counts.get((gender, nation), [])),
        y=list(prog_names.get((gender, nation), [])),
        palette="coolwarm",
        xlabel="Count",
        ylabel="",
        title=f"Top Programmes - {gender}, {nation}",
        ax=axes[i]
    )

# Hide grid cells left over when the count is not a multiple of n_cols.
for unused_ax in axes[len(selected_pairs):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()


In [None]:
# One panel per selected (gender, nation) pair, four panels per row. Rows use
# ceiling division: with floor division (n // 4) the grid has too few axes --
# and the loop raises IndexError -- whenever the number of pairs is not a
# multiple of four, and plt.subplots rejects zero rows for fewer than four pairs.
n_cols = 4
n_rows = max(1, -(-len(selected_pairs) // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(32, 6 * n_rows), sharey=False, squeeze=False)
axes = axes.flatten()

for i, (gender, nation) in enumerate(selected_pairs):
    create_bar_plot(
        # Pairs without a computed ranking fall back to an empty panel.
        x=list(country_counts.get((gender, nation), [])),
        y=list(country_names.get((gender, nation), [])),
        palette="viridis",
        xlabel="Count",
        ylabel="",
        title=f"Top Countries - {gender}, {nation}",
        ax=axes[i]
    )

# Hide grid cells left over when the count is not a multiple of n_cols.
for unused_ax in axes[len(selected_pairs):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()

## Analysis by 'Economic Class' and 'Nationality'

In [None]:
# Every (economic class, nationality) pair gets its own bucket of recommendations.
class_nation_combinations = [(eco_class, nation) for eco_class in economic_classes for nation in nationalities]
data_class_nation = {}
for pair in class_nation_combinations:
    data_class_nation[pair] = {"uni": [], "prog": [], "country": []}

for _, record in data.iterrows():
    pair = (record["Economic Class"].strip(), record["Nationality"].strip())
    bucket = data_class_nation.get(pair)
    if bucket is None:
        # Row falls outside the analysed class/nationality combinations.
        continue
    row_universities, row_programs = extract_universities_programs(record["Universities"])
    bucket["uni"].extend(row_universities)
    bucket["prog"].extend(row_programs)
    bucket["country"].extend(record["Country"])

In [None]:
# Top-5 rankings per (economic class, nation) pair, keyed by the pair tuple.
uni_names, uni_counts = {}, {}
prog_names, prog_counts = {}, {}
country_names, country_counts = {}, {}

ranking_targets = (
    ("uni", uni_names, uni_counts),
    ("prog", prog_names, prog_counts),
    ("country", country_names, country_counts),
)

for pair in class_nation_combinations:
    bucket = data_class_nation[pair]
    for key, names_by_pair, counts_by_pair in ranking_targets:
        ranked = Counter(bucket[key]).most_common(5)
        # An empty ranking cannot be unzipped, so store empty lists instead.
        names_by_pair[pair], counts_by_pair[pair] = zip(*ranked) if ranked else ([], [])

In [None]:
# Subset of economic classes and nationalities to draw ("moderate-class" is left out here).
selected_eco_classes = ["low-class", "high-class"]
selected_nations = (
    ["Nigeria", "Egypt", "South Africa", "Kenya", "Ghana", "Ethiopia", "Algeria", "Morocco"]  # Africa
    + ["China", "India", "Japan", "South Korea", "Indonesia", "Thailand", "Saudi Arabia", "Vietnam"]  # Asia
    + ["France", "Germany", "Italy", "Spain", "United Kingdom", "Sweden", "Poland", "Greece"]  # Europe
    + ["United States", "Canada", "Mexico", "Cuba", "Costa Rica", "Jamaica"]  # North America
    + ["Brazil", "Argentina", "Chile", "Peru", "Colombia"]  # South America
    + ["Australia", "New Zealand", "Fiji", "Papua New Guinea", "Tonga"]  # Oceania
)

# Class-major ordering: all nations for the first class, then the next.
selected_pairs = [
    (selected_class, selected_nation)
    for selected_class in selected_eco_classes
    for selected_nation in selected_nations
]
n = len(selected_pairs)

In [None]:
# One panel per selected (economic class, nation) pair, four panels per row.
# Rows use ceiling division: with floor division (n // 4) the grid has too few
# axes -- and the loop raises IndexError -- whenever the number of pairs is not
# a multiple of four, and plt.subplots rejects zero rows for fewer than four pairs.
n_cols = 4
n_rows = max(1, -(-len(selected_pairs) // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(32, 6 * n_rows), sharey=False, squeeze=False)
axes = axes.flatten()

for i, (eco_class, nation) in enumerate(selected_pairs):
    create_bar_plot(
        # Pairs without a computed ranking fall back to an empty panel.
        x=list(uni_counts.get((eco_class, nation), [])),
        y=list(uni_names.get((eco_class, nation), [])),
        palette="plasma",
        xlabel="Count",
        ylabel="",
        title=f"Top Universities - {eco_class}, {nation}",
        ax=axes[i]
    )

# Hide grid cells left over when the count is not a multiple of n_cols.
for unused_ax in axes[len(selected_pairs):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()


In [None]:
# One panel per selected (economic class, nation) pair, four panels per row.
# Rows use ceiling division: with floor division (n // 4) the grid has too few
# axes -- and the loop raises IndexError -- whenever the number of pairs is not
# a multiple of four, and plt.subplots rejects zero rows for fewer than four pairs.
n_cols = 4
n_rows = max(1, -(-len(selected_pairs) // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(32, 6 * n_rows), sharey=False, squeeze=False)
axes = axes.flatten()

for i, (eco_class, nation) in enumerate(selected_pairs):
    create_bar_plot(
        # Pairs without a computed ranking fall back to an empty panel.
        x=list(prog_counts.get((eco_class, nation), [])),
        y=list(prog_names.get((eco_class, nation), [])),
        palette="coolwarm",
        xlabel="Count",
        ylabel="",
        title=f"Top Programmes - {eco_class}, {nation}",
        ax=axes[i]
    )

# Hide grid cells left over when the count is not a multiple of n_cols.
for unused_ax in axes[len(selected_pairs):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()


In [None]:
# One panel per selected (economic class, nation) pair, four panels per row.
# Rows use ceiling division: with floor division (n // 4) the grid has too few
# axes -- and the loop raises IndexError -- whenever the number of pairs is not
# a multiple of four, and plt.subplots rejects zero rows for fewer than four pairs.
n_cols = 4
n_rows = max(1, -(-len(selected_pairs) // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(32, 6 * n_rows), sharey=False, squeeze=False)
axes = axes.flatten()

for i, (eco_class, nation) in enumerate(selected_pairs):
    create_bar_plot(
        # Pairs without a computed ranking fall back to an empty panel.
        x=list(country_counts.get((eco_class, nation), [])),
        y=list(country_names.get((eco_class, nation), [])),
        palette="viridis",
        xlabel="Count",
        ylabel="",
        title=f"Top Countries - {eco_class}, {nation}",
        ax=axes[i]
    )

# Hide grid cells left over when the count is not a multiple of n_cols.
for unused_ax in axes[len(selected_pairs):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()