In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
%cd gdrive/My Drive/UncertaintyModelling/

/content/gdrive/My Drive/UncertaintyModelling


In [3]:
!pip install pgmpy

Collecting pgmpy
  Downloading pgmpy-0.1.18-py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 4.4 MB/s 
Installing collected packages: pgmpy
Successfully installed pgmpy-0.1.18


In [4]:
import pandas as pd
import numpy as np
import pickle
import pgmpy
import sklearn
from sklearn.metrics import f1_score, accuracy_score, recall_score

In [5]:
with open("models/bayesian_18countries_learnt_bn_treesearch.pickle", "rb") as handle:
  model = pickle.load(handle)

In [6]:
train_df = pd.read_csv("data/train_18_countries.csv")
print(train_df.shape)
train_df.head()

(495, 307)


Unnamed: 0,deaths_per_mil_cat_india,facial_covering_new_zealand,cancel_public_events_argentina,debt_relief_india,income_support_singapore,school_closures_argentina,vaccine_policy_china,restriction_internal_movement_indonesia,workplace_closures_finland,internation_travel_control_finland,...,cases_per_mil_cat_finland,cases_per_mil_cat_hong_kong,cases_per_mil_cat_indonesia,cases_per_mil_cat_india,cases_per_mil_cat_israel,cases_per_mil_cat_malaysia,cases_per_mil_cat_new_zealand,cases_per_mil_cat_singapore,cases_per_mil_cat_united_states,cases_per_mil_cat_south_africa
0,0,3,2,1.0,2,1,5,2,1,2.0,...,1,0,1,1,5,3,0,1,6,2
1,0,2,2,0.0,2,3,4,2,2,3.0,...,1,0,0,1,4,1,0,1,4,1
2,0,3,1,1.0,2,0,5,2,1,2.0,...,4,0,1,1,6,4,0,3,6,3
3,0,3,2,1.0,2,1,5,2,1,2.0,...,1,0,1,1,5,3,0,1,5,2
4,0,2,2,0.0,2,1,4,2,2,3.0,...,1,0,0,0,4,1,0,1,4,1


In [7]:
test_df = pd.read_csv("data/test_18_countries.csv")
print(test_df.shape)
test_df.head()

(165, 307)


Unnamed: 0,deaths_per_mil_cat_india,facial_covering_new_zealand,cancel_public_events_argentina,debt_relief_india,income_support_singapore,school_closures_argentina,vaccine_policy_china,restriction_internal_movement_indonesia,workplace_closures_finland,internation_travel_control_finland,...,cases_per_mil_cat_finland,cases_per_mil_cat_hong_kong,cases_per_mil_cat_indonesia,cases_per_mil_cat_india,cases_per_mil_cat_israel,cases_per_mil_cat_malaysia,cases_per_mil_cat_new_zealand,cases_per_mil_cat_singapore,cases_per_mil_cat_united_states,cases_per_mil_cat_south_africa
0,0,3,1,1.0,2,0,5,2,1,2.0,...,4,0,1,1,6,4,0,3,6,3
1,0,3,2,1.0,2,1,5,2,1,2.0,...,1,0,1,1,6,3,0,1,6,2
2,0,2,2,2.0,2,3,0,2,1,3.0,...,0,0,0,0,1,0,0,1,1,1
3,0,3,2,1.0,2,1,5,2,1,2.0,...,1,0,1,1,5,3,0,1,5,2
4,0,0,2,2.0,2,3,0,1,1,3.0,...,0,0,0,0,0,0,0,0,1,0


In [8]:
test_df = test_df.astype(int)

**Split by features**

In [9]:
policies = [
            "school_closures",
            "workplace_closures",
            "cancel_public_events",
            "restrict_public_gathering",
            "closure_public_transport",
            "shn_requirement",
            "restriction_internal_movement",
            "internation_travel_control",
            "pi_campaign",
            "testing_policy",
            "contact_tracing",
            "facial_covering",
            "vaccine_policy",
            "income_support",
            "debt_relief"
]

In [10]:
dfs_policy = {}

In [11]:
import regex as re
for policy in policies:
  pattern = "^policy"
  my_regex = r"^" + re.escape(policy)
  dfs_policy[policy] = test_df.filter(regex=(my_regex))

In [12]:
from pgmpy.inference import VariableElimination
ve = VariableElimination(model)

  import pandas.util.testing as tm


In [None]:
import networkx as nx
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1,1, figsize=(80,80))
nx.draw(model, with_labels=True, ax=ax)
plt.tight_layout()

**Split by countries**

In [13]:
countries = [
    "singapore",
    "china",
    "malaysia",
    "indonesia",
    "hong_kong",
    "australia",
    "new_zealand",
    "united_states",
    "canada",
    "argentina",
    "brazil",
    "south_africa",
    "egypt",
    "germany",
    "finland",
    "switzerland",
    "israel",
    "india"
]

In [14]:
dfs = {}

In [15]:
import regex as re
for country in countries:
  pattern = "country$"
  my_regex = re.escape(country) + r"$"
  dfs[country] = test_df.filter(regex=(my_regex))

In [16]:
from pgmpy.inference import VariableElimination
ve = VariableElimination(model)

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1,1, figsize=(80,80))
nx.draw(model, with_labels=True, ax=ax)
plt.tight_layout()

In [17]:
pred_list = {}
for country in countries:
  preds = []
  deaths_per_mil_cat_country = 'deaths_per_mil_cat_'+country
  cases_per_mil_cat_country = 'cases_per_mil_cat_'+country

  for idx, row in dfs[country].drop([deaths_per_mil_cat_country, cases_per_mil_cat_country], axis = 1).iterrows():
    evidence = row.to_dict()
    res = ve.query(
          [cases_per_mil_cat_country], evidence=evidence, show_progress=False
      )
    ans_dict = {}
    for key, val in zip(res.state_names[res.variables[0]], res.values):
      ans_dict[key] = val
    max_val = max(ans_dict, key=ans_dict.get)
    preds.append(max_val)
  pred_list[country] = preds

  phi.values = phi.values / phi.values.sum()


In [18]:
predictions_list = {}
for country in countries:
  predictions = []
  for t in pred_list[country]:
    predictions.append(t)
  cases_per_mil_cat_country = 'cases_per_mil_cat_'+country
  preds_df = pd.DataFrame(predictions, columns = [cases_per_mil_cat_country])
  predictions_list[country] = preds_df

In [20]:
from sklearn.metrics import f1_score, accuracy_score, recall_score

In [22]:
scores = {}
for country in countries:
  preds_df = predictions_list[country]
  cases_per_mil_cat_country = 'cases_per_mil_cat_'+country
  actual_df = test_df[cases_per_mil_cat_country]
  f1_s = f1_score(actual_df, preds_df, average='micro')
  accuracy_s = accuracy_score(actual_df, preds_df)
  recall_s = recall_score(actual_df, preds_df, average='weighted')
  scores[country] = {
        "accuracy": accuracy_s,
        "f1": f1_s,
        "recall": recall_s,
    }

In [23]:
scores

{'argentina': {'accuracy': 0.8424242424242424,
  'f1': 0.8424242424242424,
  'recall': 0.8424242424242424},
 'australia': {'accuracy': 0.9393939393939394,
  'f1': 0.9393939393939394,
  'recall': 0.9393939393939394},
 'brazil': {'accuracy': 0.8, 'f1': 0.8000000000000002, 'recall': 0.8},
 'canada': {'accuracy': 0.8848484848484849,
  'f1': 0.8848484848484849,
  'recall': 0.8848484848484849},
 'china': {'accuracy': 1.0, 'f1': 1.0, 'recall': 1.0},
 'egypt': {'accuracy': 1.0, 'f1': 1.0, 'recall': 1.0},
 'finland': {'accuracy': 0.9333333333333333,
  'f1': 0.9333333333333333,
  'recall': 0.9333333333333333},
 'germany': {'accuracy': 0.9090909090909091,
  'f1': 0.9090909090909091,
  'recall': 0.9090909090909091},
 'hong_kong': {'accuracy': 1.0, 'f1': 1.0, 'recall': 1.0},
 'india': {'accuracy': 0.9515151515151515,
  'f1': 0.9515151515151515,
  'recall': 0.9515151515151515},
 'indonesia': {'accuracy': 0.9212121212121213,
  'f1': 0.9212121212121213,
  'recall': 0.9212121212121213},
 'israel': {'ac

In [12]:
nodes_and_children = [(node, model.get_children(node)) for node in model.nodes]
node_to_children_map = {
    node: children
    for node, children in sorted(
        nodes_and_children, key=lambda el: len(el[1]), reverse=True
    )
}

In [None]:
for idx, (node, children) in enumerate(node_to_children_map.items()):
    if node == "date":
        continue

    print(f"Node - {node}")
    print(f"Number of children - {len(children)}")
    policy = node[: node.rindex("_")]
    print(
        f"Number of related policies in other countries - {[child for child in children if policy in child]}"
    )
    linked_countries = {
        child[child.rindex("_") + 1 :] for child in children if child != "date"
    }
    print(f"Number of linked countries - {len(linked_countries)} ({linked_countries})")
    print()

    if idx == 50:
        break

In [27]:
[edge for edge in model.edges if "cases_per_mil_cat_india" in edge[0]]

[]

New Zealand had some of the most strict measures when it came to covid control. And this was rightly inferred from the bayesian network model that we built as we got f1 scores close to 1 when we tried to infer number of covid cases from the policies enforced there since the consistent strict measures kept the covid cases consistently low.