In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Workbook holding one sheet per state; each sheet is loaded as-is (no header
# handling), which is why later code addresses columns as 'Unnamed: N'.
MIGRATION_XLSX = 'county-to-county-2011-2015-ins-outs-nets-gross.xlsx'

states = sorted(['Ohio', 'Pennsylvania', 'West Virginia', 'Virginia', 'Kentucky'])

# Passing a list of sheet names parses the workbook once and returns a
# {sheet_name: DataFrame} dict, instead of re-opening the file per state.
data = pd.read_excel(MIGRATION_XLSX, sheet_name=states)

# Resident population per state, used as the denominator of the migration rate.
# NOTE(review): these look like 2010 decennial census counts; Kentucky is
# commonly reported as 4,339,367 rather than 4,339,333 -- confirm the source.
population = {
    'Ohio': 11536504,
    'Pennsylvania': 12702379,
    'West Virginia': 1852994,
    'Virginia': 8001024,
    'Kentucky': 4339333,
}

In [3]:
def mig(data_A, A_population, B, name_col='Unnamed: 6', count_col='Unnamed: 8'):
    """Return the per-capita sum of ``count_col`` over the rows of ``data_A`` that name ``B``.

    Parameters
    ----------
    data_A : pandas.DataFrame
        Sheet for state A. Rows whose ``name_col`` value equals ``B`` are
        selected; all other rows (including any embedded header rows) are ignored.
    A_population : int or float
        Population of state A, used as the denominator.
    B : str
        Name to match in ``name_col``.
    name_col, count_col : str, optional
        Column labels to match on and to sum. The defaults are the positional
        labels pandas assigns when the sheet is read without a usable header.
        Presumably these hold the counterpart state name and the flow
        estimate -- TODO confirm against the workbook layout.

    Returns
    -------
    float
        ``sum(counts) / A_population``; ``0.0`` when no row matches.

    Notes
    -----
    Missing (NaN) counts on matching rows are skipped rather than turning the
    whole rate into NaN.
    """
    matches = data_A[name_col] == B
    # The column may be object-typed because of header text in other rows;
    # convert only the selected cells so the sum is numeric.
    counts = pd.to_numeric(data_A.loc[matches, count_col])
    return float(counts.sum()) / A_population

In [4]:
def calculate_avg_mig(A, B):
    """Average the two directional per-capita rates between states ``A`` and ``B``.

    Looks up each state's sheet in the module-level ``data`` dict and its
    population in ``population``, evaluates ``mig`` once from each side, and
    returns the arithmetic mean. The result is symmetric in ``A`` and ``B``.
    """
    rate_seen_from_a = mig(data[A], population[A], B)
    rate_seen_from_b = mig(data[B], population[B], A)
    return (rate_seen_from_a + rate_seen_from_b) / 2

In [5]:
# get_mig_rate
# Two-letter codes for the states handled by this notebook.
_STATE_ABBR = {
    'Kentucky': 'KY',
    'Ohio': 'OH',
    'Pennsylvania': 'PA',
    'Virginia': 'VA',
    'West Virginia': 'WV',
}


def abbr(A):
    """Return the two-letter code for the state name ``A``.

    Raises
    ------
    ValueError
        If ``A`` is not one of the known state names. An unknown name is
        rejected instead of being silently reported as ``'OH'``.
    """
    try:
        return _STATE_ABBR[A]
    except KeyError:
        raise ValueError(f'unknown state name: {A!r}') from None

# Pairwise averaged migration rate, keyed by state code on both levels:
# mig_rate['KY']['OH'] is the KY<->OH rate. A state has no entry for itself.
mig_rate = {}
for src in states:
    mig_rate[abbr(src)] = {
        abbr(dst): calculate_avg_mig(src, dst)
        for dst in states
        if src != dst
    }

# Same structure, with each state's row rescaled by the row total.
mig_rate_norm = {}
for origin, rates in mig_rate.items():
    row_total = sum(rates.values())
    mig_rate_norm[origin] = {
        dest: float(rate) / row_total for dest, rate in rates.items()
    }

print(mig_rate)
print(mig_rate_norm)

{'KY': {'OH': 0.002534966663552242, 'PA': 0.0003079073783994705, 'VA': 0.0006015156538899991, 'WV': 0.0007315160662077914}, 'OH': {'KY': 0.002534966663552242, 'PA': 0.0011564047005479653, 'VA': 0.0006840071357808726, 'WV': 0.0024447623637997924}, 'PA': {'KY': 0.0003079073783994705, 'OH': 0.0011564047005479653, 'VA': 0.0012598480473323469, 'WV': 0.0013957645604453207}, 'VA': {'KY': 0.0006015156538899991, 'OH': 0.0006840071357808726, 'PA': 0.0012598480473323469, 'WV': 0.0021018499803849735}, 'WV': {'KY': 0.0007315160662077914, 'OH': 0.0024447623637997924, 'PA': 0.0013957645604453207, 'VA': 0.0021018499803849735}}
{'KY': {'OH': 0.6070459459573866, 'PA': 0.07373427369882789, 'VA': 0.14404435544416588, 'WV': 0.17517542489961957}, 'OH': {'KY': 0.3716883146874044, 'PA': 0.1695573044108427, 'VA': 0.1002922299484163, 'WV': 0.3584621509533366}, 'PA': {'KY': 0.0747361667536267, 'OH': 0.2806858834759872, 'VA': 0.3057939508922894, 'WV': 0.33878399887809685}, 'VA': {'KY': 0.12943556536830542, 'OH': 

In [6]:
import json

# Read the feature file as UTF-8 explicitly: JSON text is UTF-8, while open()
# without an encoding falls back to the platform locale.
with open('x_test.txt', encoding='utf-8') as f:
    # NOTE: despite its name, `x_train` holds the raw text of x_test.txt.
    x_train = f.read()

# NOTE: this rebinds `data`, which earlier cells use for the per-state Excel
# sheets; calculate_avg_mig() cannot be re-run after this cell without
# reloading the workbook first.
# Layout as consumed downstream: {drug: {state_code: [[year, f1, f2, ...], ...]}}.
data = json.loads(x_train)

In [9]:
# Blend each state's yearly feature vector with a migration-weighted sum of the
# other states' vectors for the same drug and year:
#     updated = (1 - alpha) * own + alpha * sum_n(w_n * neighbor_n)
# where w_n comes from mig_rate_norm[state].
alpha = 0.1
updated_dict = dict()
for drug, drug_d in data.items():
    updated_dict[drug] = dict()
    for s, state_rows in drug_d.items():
        train_xs = []
        for row in state_rows:
            # Each row is [year, feature_1, ..., feature_k].
            train_x_year = row[0]
            train_x = row[1:]
            print("------------- Before update:")
            print(drug, s, train_x_year, train_x)

            # Accumulator sized from the row itself rather than a fixed width.
            neighbor_fv = [0] * len(train_x)
            for neighbor, w in mig_rate_norm[s].items():
                for neighbor_row in data[drug][neighbor]:
                    if train_x_year == neighbor_row[0]:
                        neighbor_fv = [
                            w * value + acc
                            for value, acc in zip(neighbor_row[1:], neighbor_fv)
                        ]
            # NOTE(review): a neighbor with no row for this year contributes
            # nothing, so the effective weights then sum to less than 1.

            train_x = [
                (1 - alpha) * own + alpha * neigh
                for own, neigh in zip(train_x, neighbor_fv)
            ]
            print("\n************* After update:")
            print(drug, s, train_x_year, train_x)
            print()
            train_xs.append([train_x_year] + train_x)
        updated_dict[drug][s] = train_xs

------------- Before update:
Heroin OH 2015 [1.0009632034106692, 0.989789887820435, 0.39285731621988773, 0.9407551022389452, 0.41888565201381633, 0.6758837859372302, 0.0086329446078292, 0.1365641618986133, 0.020827800172391914, 0.2588669843134454]

************* After update:
Heroin OH 2015 [0.9866077633104273, 0.9755998617139731, 0.3875804446783888, 0.9272926182580848, 0.41299713740392613, 0.6670343072870137, 0.008662106189799044, 0.1347405530679822, 0.021660694139998694, 0.24819056203342488]

------------- Before update:
Heroin OH 2016 [1.0043719483822828, 0.9930471137530053, 0.3954026280405225, 0.9440621699606744, 0.420458572198302, 0.680279398334192, 0.008478391720750064, 0.13520387112074853, 0.02019415933977919, 0.2572189113790452]

************* After update:
Heroin OH 2016 [0.9925849586203166, 0.9814073177988247, 0.39114264639303714, 0.9330468769850173, 0.4156313515033596, 0.6731033664479052, 0.008589479999059148, 0.1336257995414633, 0.02114376578555706, 0.24680197817602983]

--

In [10]:
# Persist the smoothed features as JSON; the blend factor is embedded in the
# file name so runs with different alpha values do not overwrite each other.
output_path = f'gnn_updated_x_test_{alpha}.txt'
with open(output_path, 'w') as out_file:
    serialized = json.dumps(updated_dict)
    out_file.write(serialized)