# KNN

## Imports

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from collections import defaultdict
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.ensemble
import sklearn.neighbors

sys.path.append('../')
from batchers import batcher, dataset_constants
from models.histograms import (
    get_per_image_histograms,
    plot_band_hists,
    plot_label_hist,
    split_nl_hist)
from models.linear_model import ridge_cv
from models.knn import knn_cv_opt
from utils.analysis import calc_score, evaluate
from utils.general import load_npz
from utils.plot import scatter_preds

In [3]:
# Hide all GPUs from CUDA-based libraries for this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# dataset / target selection
DATASET_NAME = '2009-17'
LABEL_NAME = 'wealthpooled'

# cross-validation layout
FOLDS = list('ABCDE')
SPLITS = ['train', 'val', 'test']
COUNTRIES = dataset_constants.DHS_COUNTRIES

# per-band mean / std-dev statistics for DATASET_NAME
# (presumably used for input normalization — not used in the cells shown here)
MEANS, STD_DEVS = (
    dataset_constants.MEANS_DICT[DATASET_NAME],
    dataset_constants.STD_DEVS_DICT[DATASET_NAME],
)

# root directory for saved model outputs
LOGS_ROOT_DIR = '../logs/'

In [4]:
file_path = '../data/dhs_image_hists.npz'
result = load_npz(file_path)

# Unpack the arrays stored in the npz. image_hists has 8 bands:
# ['BLUE', 'GREEN', 'RED', 'SWIR1', 'SWIR2', 'TEMP1', 'NIR', 'NIGHTLIGHTS']
image_hists, labels, locs, years, nls_center, nls_mean = (
    result[key]
    for key in ('image_hists', 'labels', 'locs', 'years', 'nls_center', 'nls_mean')
)

# Examples before 2012 are treated as DMSP nightlights, 2012 onwards as VIIRS.
dmsp_mask = years < 2012
viirs_mask = ~dmsp_mask

# Split the NIGHTLIGHTS band by sensor, giving 9 bands:
# ['BLUE', 'GREEN', 'RED', 'SWIR1', 'SWIR2', 'TEMP1', 'NIR', 'DMSP', 'VIIRS']
image_hists = split_nl_hist(image_hists, years)

image_hists: dtype=int64, shape=(19669, 8, 102)
labels: dtype=float32, shape=(19669,)
locs: dtype=float32, shape=(19669, 2)
years: dtype=int32, shape=(19669,)
nls_center: dtype=float32, shape=(19669,)
nls_mean: dtype=float32, shape=(19669,)


## Load folds

### Incountry folds + loc_dict

`loc_dict` has the format:
```python
{
    (lat, lon): {
        'cluster': 1,
        'country': 'malawi',
        'country_year': 'malawi_2012',  # surveyID
        'households': 25,
        'urban': False,
        'wealth': -0.513607621192932,
        'wealthpooled': -0.732255101203918,
        'year': 2012
    }, ...
}
```

NOTE: `year` and the year in `country_year` may disagree. `country_year` is the survey ID and encodes the year in which the survey started. Some DHS surveys cross a year boundary. For clusters surveyed after that boundary, `year` is the following year while `country_year` keeps the start year.

In [7]:
# NOTE: pickle.load can execute arbitrary code — only load these files from a trusted source.
with open('../data/dhs_incountry_folds.pkl', 'rb') as pkl_file:
    incountry_folds = pickle.load(pkl_file)

with open('../data/dhs_loc_dict.pkl', 'rb') as pkl_file:
    loc_dict = pickle.load(pkl_file)

# incountry_group_labels[j] = position in FOLDS of the fold whose 'test' split contains example j.
# NOTE(review): examples that appear in no fold's 'test' split keep the default 0;
# presumably the test splits partition the dataset — TODO confirm.
incountry_group_labels = np.zeros_like(labels, dtype=np.int32)
for fold_idx, fold in enumerate(FOLDS):
    incountry_group_labels[incountry_folds[fold]['test']] = fold_idx

### `country_indices` and `country_labels`

`country_indices` is a dictionary that maps each country name to a sorted `np.array` of the indices of that country's examples (positions in `locs` / `labels`):
```python
{ 'malawi': np.array([ 8530,  8531,  8532, ..., 10484, 10485, 10486]), ... }
```

`country_labels` is an integer `np.array` that gives, for each example, the country it belongs to:
```python
np.array([0, 0, 0, 0, ..., 22, 22, 22])
```
where countries are indexed by their position in `dataset_constants.DHS_COUNTRIES`

In [8]:
country_indices = defaultdict(list)  # country name => (after the 2nd loop) sorted np.array of example indices
country_labels = np.zeros(len(locs), dtype=np.int32)  # position in COUNTRIES of each example's country
households = np.zeros(len(locs), dtype=np.int32)  # household count of each example, from loc_dict

for i, loc in enumerate(locs):
    loc_info = loc_dict[tuple(loc)]  # look the location up once per example
    country_indices[loc_info['country']].append(i)
    households[i] = loc_info['households']

# Every example must belong to a country in COUNTRIES; otherwise it would silently
# keep the default label 0 (i.e. be attributed to COUNTRIES[0]).
unknown_countries = set(country_indices) - set(COUNTRIES)
assert not unknown_countries, f'countries not in COUNTRIES: {sorted(unknown_countries)}'

for i, country in enumerate(COUNTRIES):
    # Explicit integer dtype: np.asarray([]) would be float64, which numpy rejects
    # as an index and which would upcast any later np.concatenate of index arrays.
    indices = np.asarray(country_indices[country], dtype=np.intp)
    country_indices[country] = indices
    country_labels[indices] = i

### OOC folds

In [19]:
# ooc_folds[fold][split] is a sorted np.array of the example indices of the
# countries assigned to that split, e.g.
# 'A': {
#     'train': np.array([1, 10, 13, ...]),
#     ...
# }
ooc_folds = {
    fold: {split: [] for split in SPLITS}
    for fold in FOLDS
}

for fold in FOLDS:
    # Built from DATASET_NAME (not a hard-coded '2009-17') so the two stay in sync.
    surveys_dict = dataset_constants.SURVEY_NAMES[f'{DATASET_NAME}{fold}']
    for split, countries in surveys_dict.items():
        split_indices = np.concatenate([
            country_indices[country] for country in countries
        ])
        # Force an integer index dtype: an empty (float64) per-country array
        # would otherwise upcast the concatenated indices to float.
        ooc_folds[fold][split] = np.sort(split_indices).astype(np.intp)

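## KNN (OOC, leave-one-country-out)

`knn_ooc_wrapper` runs KNN twice:
- once on the examples with `year < 2012` (DMSP nightlights);
- once on those with `year >= 2012` (VIIRS).

Examples are grouped by country in both runs. Messages like "no examples corresponding to group ... were found" presumably mean that country has no examples in the era being processed.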

In [9]:
def knn_ooc_wrapper(img_hists, model_name, years=None):
    '''Runs leave-one-country-out KNN on `img_hists`, then saves and evaluates the test preds.

    `knn_cv_opt` is run separately on the examples with year < 2012 (DMSP)
    and year >= 2012 (VIIRS), grouping examples by `country_labels`; the
    predictions of both runs are combined into a single array.

    Uses the notebook globals `labels`, `country_labels`, `COUNTRIES` and
    `LOGS_ROOT_DIR`.

    Args
    - img_hists: np.array, first axis indexes examples and must have one
        entry per element of `labels`; all remaining axes are flattened into
        one feature vector (a 1-D array becomes a single scalar feature)
    - model_name: str, preds are saved to
        LOGS_ROOT_DIR/dhs_knn/<model_name>/test_preds.npz, overwriting any
        existing file
    - years: array-like of int, year of each example; required despite the
        None default (kept for backward compatibility of the signature)

    Raises
    - ValueError: if `years` is None, or if `img_hists` / `years` do not have
        one entry per element of `labels`
    '''
    if years is None:
        raise ValueError('years is required to split DMSP and VIIRS examples')
    years = np.asarray(years)

    num_examples = len(labels)
    if len(img_hists) != num_examples or len(years) != num_examples:
        raise ValueError(
            f'expected {num_examples} examples, got '
            f'img_hists={len(img_hists)}, years={len(years)}')

    savedir = os.path.join(LOGS_ROOT_DIR, 'dhs_knn', model_name)
    npz_path = os.path.join(savedir, 'test_preds.npz')
    features = img_hists.reshape(num_examples, -1)

    dmsp_mask = years < 2012
    viirs_mask = ~dmsp_mask
    # Sized from `labels` (not a global like `image_hists`): the preds are
    # saved alongside and evaluated against `labels`.
    test_preds = np.zeros(num_examples, dtype=np.float32)
    for mask in [dmsp_mask, viirs_mask]:
        if not mask.any():
            continue  # no examples in this nightlights era; nothing to predict
        test_preds[mask] = knn_cv_opt(
            features=features[mask], labels=labels[mask],
            group_labels=country_labels[mask], group_names=COUNTRIES)

    os.makedirs(savedir, exist_ok=True)
    save_dict = {
        'labels': labels,
        'test_preds': test_preds
    }
    print('saving test preds to:', npz_path)
    np.savez_compressed(npz_path, **save_dict)

    evaluate(labels=labels, preds=test_preds, do_print=True,
             title='DMSP and VIIRS test preds')

In [10]:
# KNN on a single scalar feature per example: nls_mean
# (presumably the mean nightlights value of each image).
model_name = 'nlmean_scalar'
knn_ooc_wrapper(img_hists=nls_mean, model_name=model_name, years=years)

Pre-computing distance matrix... took 0.18 seconds.
Group: angola
best val mse: 0.196, best k: 64, test mse: 0.354
Group: benin
no examples corresponding to group benin were found
Group: burkina_faso
best val mse: 0.218, best k: 256, test mse: 0.158
Group: cameroon
best val mse: 0.199, best k: 64, test mse: 0.315
Group: cote_d_ivoire
no examples corresponding to group cote_d_ivoire were found
Group: democratic_republic_of_congo
no examples corresponding to group democratic_republic_of_congo were found
Group: ethiopia
best val mse: 0.206, best k: 128, test mse: 0.219
Group: ghana
no examples corresponding to group ghana were found
Group: guinea
no examples corresponding to group guinea were found
Group: kenya
no examples corresponding to group kenya were found
Group: lesotho
best val mse: 0.208, best k: 256, test mse: 0.138
Group: malawi
best val mse: 0.207, best k: 256, test mse: 0.179
Group: mali
no examples corresponding to group mali were found
Group: mozambique
best val mse: 0.210,

In [12]:
# Reload the preds saved by knn_ooc_wrapper for `model_name`, building the path
# the same way the wrapper does instead of hard-coding it.
preds_path = os.path.join(LOGS_ROOT_DIR, 'dhs_knn', model_name, 'test_preds.npz')
results = load_npz(preds_path)

labels: dtype=float32, shape=(19669,)
test_preds: dtype=float32, shape=(19669,)


In [16]:
# Pearson correlation between true labels and KNN test preds:
# np.corrcoef returns a 2x2 matrix whose off-diagonal entry is r.
corr_matrix = np.corrcoef(results['labels'], results['test_preds'])
correlation_coefficient = corr_matrix[0, 1]

In [17]:
correlation_coefficient  # Pearson r between labels and KNN test preds

0.8100337075857674