In [1]:
import importlib
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import Seizures
import Aggregation
import networkx as nx



In [6]:
# Read the IDS dataset
# The local Seizures module is reloaded first, so edits made to it since the
# kernel started are picked up without a restart
importlib.reload(Seizures)
df_ids = Seizures.read_xlsx()

In [7]:
# Obtain all countries, sub-regions and regions in the IDS dataset for cocaine and heroin
# Each call is unpacked into a list of countries plus two dicts; the dicts are
# later indexed by country name (see get_drug_prices)
cocaine_countries_list, cocaine_sub_region_dict, cocaine_region_dict = Seizures.get_ids_locations(df_ids, drug_list = ['Cocaine'])
heroin_countries_list, heroin_sub_region_dict, heroin_region_dict = Seizures.get_ids_locations(df_ids, drug_list = ['Heroin'])

## Drug prices

In [8]:
def get_drug_prices(file = '/Users/mateicosa/Bocconi/BIDSA/Network_Science/data/Drug_prices.xlsx', drug_list = ['Cocaine', 'Heroin'], 
                    countries_list = None, sub_region_dict = None, region_dict = None, 
                    start_year = 2006, end_year = 2017):
    '''
    Build one table of wholesale drug prices (USD per kilogram) for every year of the period.

    For each country, drug and year the price is the mean of the reported
    'Typical_USD' values. When a country has no report for a given year, the
    price is imputed with, in order of preference: the sub-regional average,
    the regional average, or the global average (mean of the regional
    averages) of that drug in that year.

    Parameters
    ----------
    file : str, optional
        Path of the Excel file with the price reports. The columns read are
        'Unit', 'LevelOfSale', 'Drug', 'Year', 'Country/Territory',
        'SubRegion', 'Region', 'Typical_USD', 'Minimum_USD' and 'Maximum_USD'.
    drug_list : list of str, optional
        Drugs to extract. The rows of all the drugs end up in the same yearly
        table, which has no drug column, so pass a single drug to keep the
        rows distinguishable.
    countries_list : list of str, optional 
        Countries to include. Defaults to every country found in the file.
    sub_region_dict : dict, optional
        Maps a country to its sub-region. If omitted, the sub-region is read
        from the 'SubRegion' column of the file.
    region_dict : dict, optional
        Maps a country to its region. If omitted, the region is read from the
        'Region' column of the file.
    start_year : int, optional
        First year of the period (inclusive).
    end_year : int, optional
        Last year of the period (inclusive).

    Returns
    -------
    price_df : dict of pd.DataFrames
        One dataframe per year in [start_year, end_year], keyed by the year,
        with the columns 'Country', 'Sub_Region', 'Region' and 'Price(USD)'.

    Raises
    ------
    Exception
        If the years are not integers or start_year is after end_year.
    KeyError
        If sub_region_dict / region_dict is given but lacks a country.
    '''
    
    # Input validation: the type check comes first, otherwise comparing e.g. a
    # str with an int raises a TypeError before the intended error is reached
    if not (isinstance(start_year, int) and isinstance(end_year, int)):
        raise Exception('Invalid years!')
    if start_year > end_year:
        raise Exception('Invalid years!')
    
    years = range(start_year, end_year + 1)
    output_columns = ['Country', 'Sub_Region', 'Region', 'Price(USD)']
    
    # Auxiliary function to deal with missing vals
    def _remove_missing_vals(df_aux):
        # A missing typical value is imputed with the mean of the minimum and
        # maximum values; the row-wise mean skips NaNs, so a single available
        # bound is used as is. Rows with no value at all are removed.
        fallback = df_aux[['Minimum_USD', 'Maximum_USD']].mean(axis = 1)
        df_aux = df_aux.assign(Typical_USD = df_aux['Typical_USD'].fillna(fallback))
        return df_aux[df_aux['Typical_USD'].notna()]
    
    # Auxiliary function to compute the yearly averages of one drug, grouped by
    # the given column ('SubRegion' or 'Region'): {year: {group: average}}
    def _get_group_avgs(df_drug, level):
        return {year: df_drug[df_drug['Year'] == year].groupby(level)['Typical_USD'].mean().to_dict()
                for year in years}
    
    # Auxiliary function to compute global vals: {year: average}
    def _get_global_avgs(region_avgs):
        # The yearly global average is the mean across all available regions
        year_total = {year: (sum(vals.values()) / len(vals) if len(vals) > 0 else None)
                      for year, vals in region_avgs.items()}
        
        # Years without any data get the average of the years that have data
        # (NaN when there is no data at all, instead of dividing by zero)
        available = [val for val in year_total.values() if val is not None]
        period_average = sum(available) / len(available) if len(available) > 0 else np.nan
        return {year: (period_average if val is None else val) for year, val in year_total.items()}
    
    # Auxiliary function to read the sub-region / region of a location from the
    # data; the first reported value is used, None if the location is absent
    def _lookup_location_attribute(location, df_aux, column):
        matches = df_aux.loc[df_aux['Country/Territory'] == location, column]
        return matches.iloc[0] if len(matches) > 0 else None
    
    # Read the data from the file
    df_price = pd.read_excel(file)
    
    # Take only entries measured in USD/kg at wholesale level
    # (copied, so that the edits below do not write into a slice of the raw data)
    wholesale_kg = (df_price['Unit'] == 'Kilogram') & (df_price["LevelOfSale"] == 'Wholesale')
    df_price = df_price[wholesale_kg].copy()
    
    # Rename drug type for convenience
    df_price['Drug'] = df_price['Drug'].replace({'Cocaine salts': 'Cocaine', 
                                                 'Cocaine hydrochloride': 'Cocaine'})
    
    # Remove missing values
    df_price = _remove_missing_vals(df_price)
    
    # Get the list of countries (in order of first appearance in the file)
    if countries_list is None:
        locations = df_price['Country/Territory'].drop_duplicates().tolist()
    else:
        locations = countries_list
    
    # The averages used for imputation depend only on the drug and the year, so
    # they are computed once per drug rather than once per country
    drug_data = dict()
    for drug in drug_list:
        df_drug = df_price[df_price['Drug'] == drug]
        sub_region_vals = _get_group_avgs(df_drug, 'SubRegion')
        region_vals = _get_group_avgs(df_drug, 'Region')
        global_vals = _get_global_avgs(region_vals)
        drug_data[drug] = (df_drug, sub_region_vals, region_vals, global_vals)
    
    # Collect the rows of every yearly table; the dataframes are built at the end
    records = {year: [] for year in years}
    
    # Iterate over the list of countries
    for location in locations:
        
        # If provided, we use the sub-region and region data
        if sub_region_dict is None:
            sub_region = _lookup_location_attribute(location, df_price, 'SubRegion')
        else:
            sub_region = sub_region_dict[location]

        if region_dict is None:
            region = _lookup_location_attribute(location, df_price, 'Region')
        else:
            region = region_dict[location]
        
        # Iterate over all drug types
        for drug in drug_list:
            df_drug, sub_region_vals, region_vals, global_vals = drug_data[drug]
            
            # Mean reported price of the location for every year that has data
            df_location = df_drug[df_drug['Country/Territory'] == location]
            reported = df_location.groupby('Year')['Typical_USD'].mean().to_dict()
            
            # Iterate over the given period of time
            for year in years:
                if year in reported:
                    price = reported[year]
                # If sub-regional values are available, we use those
                elif sub_region in sub_region_vals[year]:
                    price = sub_region_vals[year][sub_region]
                # Else if regional values are available, we use those
                elif region in region_vals[year]:
                    price = region_vals[year][region]
                else:
                    price = global_vals[year]
                
                records[year].append({'Country': location, 
                                      'Sub_Region': sub_region, 
                                      'Region': region,
                                      'Price(USD)': price})
    
    # Create the output data structure: dict of dataframes
    output_df = {year: pd.DataFrame(records[year], columns = output_columns) for year in years}
    
    # Return output_df
    return output_df

In [9]:
# Build the yearly cocaine price tables for the countries found in the IDS data,
# using the IDS sub-region / region assignment of each country
df = get_drug_prices(drug_list = ['Cocaine'], 
                    countries_list = cocaine_countries_list, 
                    sub_region_dict = cocaine_sub_region_dict, 
                    region_dict = cocaine_region_dict)

In [10]:
# Display the cocaine price table of 2006
df[2006]

Unnamed: 0,Country,Sub_Region,Region,Price(USD)
0,Afghanistan,Near and Middle East /South-West Asia,Asia,53620.000000
1,Albania,East Europe,Europe,50298.423177
2,Algeria,North Africa,Africa,21036.666667
3,Andorra,West & Central Europe,Europe,50298.423177
4,Angola,Southern Africa,Africa,21036.666667
...,...,...,...,...
169,"Venezuela, Bolivarian Republic of",South America,Americas,4190.000000
170,Viet Nam,East and South-East Asia,Asia,53620.000000
171,Yemen,Near and Middle East /South-West Asia,Asia,53620.000000
172,Zambia,Southern Africa,Africa,21036.666667


## GDP per capita

In [11]:
def get_gdp_per_capita(countries_list, file = "/Users/mateicosa/Bocconi/BIDSA/Network_Science/data/GDP_per_capita.xlsx", 
                       start_year = 2006, end_year = 2017):
    '''
    Build one table with the GDP per capita of the given countries for every year of the period.

    Parameters
    ----------
    countries_list : list of str
        Countries to include, spelled as in the 'Country Name' column of the file.
    file : str, optional
        Path of the Excel file. It must have a 'Country Name' column and one
        column per year named '<year> [YR<year>]'.
    start_year : int, optional
        First year of the period (inclusive).
    end_year : int, optional
        Last year of the period (inclusive).

    Returns
    -------
    output_df : dict of pd.DataFrames
        One dataframe per year, keyed by the year, with the columns 'Country'
        and 'GDP/capita' and one row per entry of countries_list. The value is
        NaN when the country is not in the file or its entry is not numeric.

    Raises
    ------
    Exception
        If the years are not integers or start_year is after end_year.
    KeyError
        If the file has no column for one of the years.
    '''
    
    # Input validation: the type check comes first, otherwise comparing e.g. a
    # str with an int raises a TypeError before the intended error is reached
    if not (isinstance(start_year, int) and isinstance(end_year, int)):
        raise Exception('Invalid years!')
    if start_year > end_year:
        raise Exception('Invalid years!')
    
    # Read the data from the file
    df_gdp = pd.read_excel(file)
    
    # Index the data by country so that every yearly column can be looked up at
    # once; if a country appears more than once, its first row is used
    gdp_by_country = df_gdp.drop_duplicates(subset = 'Country Name').set_index('Country Name')
    countries = list(countries_list)
    
    # Create output data structure: dict of dataframes
    output_df = dict()
    for year in range(start_year, end_year + 1):
        year_format = f"{year} [YR{year}]"
        # reindex keeps the order of countries_list and yields NaN for
        # countries missing from the file; non-numeric entries become NaN too
        values = pd.to_numeric(gdp_by_country[year_format].reindex(countries), errors = 'coerce')
        output_df[year] = pd.DataFrame({'Country': countries, 
                                        'GDP/capita': values.to_numpy(dtype = float)})
        
    # Return output_df
    return output_df
    

## Geographical coordinates

In [12]:
def get_coordinates(countries_list, file = '/Users/mateicosa/Bocconi/BIDSA/Network_Science/data/countries.csv'):
    '''
    Return the ISO code and the geographical coordinates of the given countries.

    Parameters
    ----------
    countries_list : list of str
        Countries to include, spelled as the names produced by the renaming
        table below.
    file : str, optional
        Path of the CSV file; the columns read are 'country' (ISO code),
        'latitude', 'longitude' and 'name'.

    Returns
    -------
    df_coord : pd.DataFrame
        One row per country of countries_list found in the file, with
        'Country' as the first column, followed by 'ISO', 'Latitude',
        'Longitude' (and any other column of the file).
    '''
    
    # Read the data from the file
    df_coord = pd.read_csv(file)
    
    # Replace countries names for consistency
    # (the keys containing '?' are matched literally against the file;
    # presumably they are mis-encoded accented names - verify against the CSV)
    names_to_replace = {
    'Moldova': 'Moldova, Republic of',
    'R?union': 'Réunion',
    'Bolivia': 'Bolivia, Plurinational State of',
    'Taiwan': 'Taiwan, Province of China',
    'Iran': 'Iran, Islamic Republic of',
    'S?o Tom? and Pr?ncipe': 'Sao Tome and Principe',
    'Laos': "Lao People's Democratic Republic",
    'Macedonia [FYROM]': 'North Macedonia',
    'North Korea': "Korea, Democratic People's Republic of",
    'Swaziland': 'Eswatini',
    'Saint Lucia': 'St. Lucia',
    'Venezuela': 'Venezuela, Bolivarian Republic of',
    'Tanzania': 'Tanzania, United Republic of',
    'Vietnam': 'Viet Nam',
    "C?te d'Ivoire": "Côte d'Ivoire",
    'Kosovo': 'Kosovo under UNSCR 1244',
    'Syria': 'Syrian Arab Republic',
    'Libya': 'Libyan Arab Jamahiriya',
    'Myanmar [Burma]': 'Myanmar',
    'Russia': 'Russian Federation',
    'South Korea': 'Korea, Republic of',
    'Congo [Republic]': 'Congo',
    'Congo [DRC]': 'Congo, the Democratic Republic of the'
    }
    df_coord['name'] = df_coord['name'].replace(names_to_replace)
    
    # Fill in missing values: Curaçao gets the coordinates of the Netherlands
    # Antilles. The row is only added when it is not there yet (a duplicate
    # would multiply rows in later joins) and when the reference row exists.
    antilles = df_coord[df_coord['name'] == 'Netherlands Antilles']
    if len(antilles) > 0 and not (df_coord['name'] == 'Curaçao').any():
        curacao = pd.DataFrame([{
                'country': 'CW', 
                'latitude': float(antilles['latitude'].iloc[0]),
                'longitude': float(antilles['longitude'].iloc[0]),
                'name': 'Curaçao'
                }])
        df_coord = pd.concat([df_coord, curacao], ignore_index = True)
    
    # Rename columns for consistency
    df_coord = df_coord.rename(columns = {'country': 'ISO', 'latitude': 'Latitude', 'longitude': 'Longitude', 'name': 'Country'})
    
    # Pick only the countries in countries_list
    df_coord = df_coord[df_coord['Country'].isin(list(countries_list))].reset_index(drop = True)
    
    # Missing ISO for Namibia: its code 'NA' is parsed as a missing value by
    # pd.read_csv. The code is set through a mask, so that no Namibia row is
    # created when Namibia is not among the selected countries.
    df_coord.loc[df_coord['Country'] == 'Namibia', 'ISO'] = 'NA'
    
    # Put the country name first
    df_coord = df_coord[['Country'] + [col for col in df_coord.columns if col != 'Country']]
    
    # Return the output
    return df_coord

## Social and political indicators

In [13]:
def get_social_indicators(countries_list, file = '/Users/mateicosa/Bocconi/BIDSA/Network_Science/data/Governance_Indicators.xlsx', start_year = 2006, end_year = 2017):
    '''
    Build one table with five governance indicators of the given countries for every year of the period.

    Parameters
    ----------
    countries_list : list of str
        Countries to include, spelled as in the 'Country Name' column of the file.
    file : str, optional
        Path of the Excel file. It must have the columns 'Country Name' and
        'Series Name' and one column per year named '<year> [YR<year>]'.
    start_year : int, optional
        First year of the period (inclusive).
    end_year : int, optional
        Last year of the period (inclusive).

    Returns
    -------
    output_df : dict of pd.DataFrames
        One dataframe per year, keyed by the year, with the columns 'Country',
        'Control_of_Corruption', 'Gov_Effectiveness', 'Stability_No_Terrorism',
        'Regulatory_Quality' and 'Rule_of_Law'. A country for which any of the
        indicators is missing from the file (or is not a number) is left out
        of that year and its name is printed.

    Raises
    ------
    Exception
        If the years are not integers or start_year is after end_year.
    KeyError
        If the file lacks one of the required columns, e.g. a year column.
    '''
    
    # Input validation: the type check comes first, otherwise comparing e.g. a
    # str with an int raises a TypeError before the intended error is reached
    if not (isinstance(start_year, int) and isinstance(end_year, int)):
        raise Exception('Invalid years!')
    if start_year > end_year:
        raise Exception('Invalid years!')
    
    # Output column -> name of the series in the file
    indicator_series = {
        'Control_of_Corruption': 'Control of Corruption: Estimate',
        'Gov_Effectiveness': 'Government Effectiveness: Estimate',
        'Stability_No_Terrorism': 'Political Stability and Absence of Violence/Terrorism: Estimate',
        'Regulatory_Quality': 'Regulatory Quality: Estimate',
        'Rule_of_Law': 'Rule of Law: Estimate'
    }
    output_columns = ['Country'] + list(indicator_series.keys())
        
    # Read the data from the file
    df_gov = pd.read_excel(file)
    
    # Index the data by (country, series); if a pair appears more than once,
    # its first row is used
    indexed = df_gov.drop_duplicates(subset = ['Country Name', 'Series Name']).set_index(['Country Name', 'Series Name'])
    
    # Create output data structure: dict of dataframes
    output_df = dict()
    for year in range(start_year, end_year + 1):
        year_format = f"{year} [YR{year}]"
        # A missing year column is an error of the input file and is raised
        # here, rather than being reported as every country lacking data
        year_values = indexed[year_format].to_dict()
        
        rows = []
        for country in countries_list:
            try:
                row = {'Country': country}
                for column, series_name in indicator_series.items():
                    row[column] = float(year_values[(country, series_name)])
            except (KeyError, TypeError, ValueError):
                # Indicator absent or not numeric: report the country and skip it
                print(country)
                continue
            rows.append(row)
        output_df[year] = pd.DataFrame(rows, columns = output_columns)
    
    # Return the output
    return output_df

## Markets

In [14]:
# Reload the local modules so that edits made since the kernel started are
# picked up, then build the national cocaine market tables (indexed by year
# in the next cell)
importlib.reload(Aggregation)
importlib.reload(Seizures)
df_mark = Aggregation.get_national_markets_df(cocaine_countries_list, 
                                              cocaine_sub_region_dict,
                                              cocaine_region_dict,
                                              df_ids,
                                              drug_list = ['Cocaine'])

In [15]:
# Display the national cocaine market table of 2006
df_mark[2006]

Unnamed: 0,SubRegion,Country,Drug,Seizures(kg),Consumption(kg),Market(kg)
0,East Africa,Burundi,Cocaine,0.000000,124.573736,124.573736
1,East Africa,Eritrea,Cocaine,0.000000,48.053745,48.053745
2,East Africa,Ethiopia,Cocaine,0.000000,43.60554,43.60554
3,East Africa,Kenya,Cocaine,12.303475,21.453179,33.756654
4,East Africa,Madagascar,Cocaine,0.000000,319.028582,319.028582
...,...,...,...,...,...,...
169,Oceania,Australia,Cocaine,0.000000,429.694661,429.694661
170,Oceania,Cook Islands,Cocaine,0.000000,0.292864,0.292864
171,Oceania,Fiji,Cocaine,0.000000,17.624036,17.624036
172,Oceania,New Zealand,Cocaine,0.264625,113.783647,114.048272


## Join

In [16]:
def get_node_attributes(drug, df_ids, start_year = 2006, end_year = 2017):
    '''
    Join all the country-level data of one drug into one attribute table per year.

    Parameters
    ----------
    drug : str
        Drug whose network is considered, e.g. 'Cocaine' or 'Heroin'.
    df_ids : object
        The IDS data, passed on to Seizures.get_ids_locations and
        Aggregation.get_national_markets_df.
    start_year : int, optional
        First year of the period (inclusive).
    end_year : int, optional
        Last year of the period (inclusive).

    Returns
    -------
    output_df : dict of pd.DataFrames
        One dataframe per year, keyed by the year, with one row per row of the
        yearly price table and the columns 'Country', 'Sub_Region', 'Region',
        'ISO', 'Latitude', 'Longitude', 'Price(USD)', 'Seizures(kg)',
        'Consumption(kg)', 'Market(kg)', 'GDP/capita', 'Control_of_Corruption',
        'Gov_Effectiveness', 'Stability_No_Terrorism', 'Regulatory_Quality'
        and 'Rule_of_Law'. Values missing from a source are NaN.

    Raises
    ------
    Exception
        If the years are not integers or start_year is after end_year.
    '''
    
    # Input validation: the type check comes first, otherwise comparing e.g. a
    # str with an int raises a TypeError before the intended error is reached
    if not (isinstance(start_year, int) and isinstance(end_year, int)):
        raise Exception('Invalid years!')
    if start_year > end_year:
        raise Exception('Invalid years!')
    
    # Get the locations present in the network of the specific drug
    countries_list, sub_region_dict, region_dict = Seizures.get_ids_locations(df_ids, drug_list = [drug])
    
    # Get individual data structures
    df_price = get_drug_prices(countries_list = countries_list, sub_region_dict = sub_region_dict,
                           region_dict = region_dict, drug_list = [drug],
                           start_year = start_year, end_year = end_year)
    
    # NOTE(review): the period is not passed on here, so the market tables must
    # cover every year in [start_year, end_year] by default - confirm against
    # Aggregation.get_national_markets_df
    df_mark = Aggregation.get_national_markets_df(countries_list, 
                                              sub_region_dict,
                                              region_dict,
                                              df_ids,
                                              drug_list = [drug])
    
    df_gdp = get_gdp_per_capita(countries_list = countries_list, start_year = start_year, end_year = end_year)
    
    df_coord = get_coordinates(countries_list = countries_list)
    
    df_gov = get_social_indicators(countries_list = countries_list, start_year = start_year, end_year = end_year)
    
    # Final column order of every yearly table
    column_order = ['Country', 'Sub_Region', 'Region', 'ISO', 'Latitude', 'Longitude', 
                    'Price(USD)', 'Seizures(kg)', 'Consumption(kg)', 'Market(kg)',
                    'GDP/capita', 'Control_of_Corruption', 'Gov_Effectiveness', 
                    'Stability_No_Terrorism', 'Regulatory_Quality', 'Rule_of_Law']
    
    # The coordinates do not change with the year: index them once
    coord_by_country = df_coord.set_index('Country')
    
    # Create the output_df
    output_df = dict()
    
    # Iterate over the time period; the price table is the left side of every
    # join, so its rows (and their order) are the rows of the output
    for year in range(start_year, end_year + 1):
        joined = df_price[year].join(df_mark[year].drop(columns = ['Drug', 'SubRegion']).set_index('Country'), on = 'Country')
        joined = joined.join(df_gdp[year].set_index('Country'), on = 'Country')
        joined = joined.join(coord_by_country, on = 'Country')
        joined = joined.join(df_gov[year].set_index('Country'), on = 'Country')
    
        # Rearrange the columns
        output_df[year] = joined.reindex(columns = column_order)
    
    # Return the output
    return output_df

In [17]:
# Build the yearly node attribute tables of the cocaine network (default period)
joint_df_cocaine = get_node_attributes('Cocaine', df_ids)

In [18]:
# Display the cocaine node attributes of 2006
joint_df_cocaine[2006]

Unnamed: 0,Country,Sub_Region,Region,ISO,Latitude,Longitude,Price(USD),Seizures(kg),Consumption(kg),Market(kg),GDP/capita,Control_of_Corruption,Gov_Effectiveness,Stability_No_Terrorism,Regulatory_Quality,Rule_of_Law
0,Afghanistan,Near and Middle East /South-West Asia,Asia,AF,33.939110,67.709953,53620.000000,0.0,384.07183,384.07183,274.000656,-1.446292,-1.473652,-2.219135,-1.689469,-1.879005
1,Albania,East Europe,Europe,AL,41.153332,20.168331,50298.423177,0.0,82.398561,82.398561,2972.743618,-0.790545,-0.580953,-0.508157,-0.148599,-0.703136
2,Algeria,North Africa,Africa,DZ,28.033886,1.659626,21036.666667,0.0,23.99524,23.99524,3500.134610,-0.564963,-0.438162,-1.126413,-0.500463,-0.775205
3,Andorra,West & Central Europe,Europe,AD,42.546245,1.601554,50298.423177,0.0,2.384238,2.384238,43084.292912,1.251313,1.538403,1.349654,1.349776,0.855686
4,Angola,Southern Africa,Africa,AO,-11.202692,17.873887,21036.666667,0.0,11.203391,11.203391,2597.963585,-1.241038,-1.381049,-0.539453,-1.156216,-1.311273
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
169,"Venezuela, Bolivarian Republic of",South America,Americas,VE,6.423750,-66.589730,4190.000000,0.0,1007.743276,1007.743276,6769.868414,-1.049627,-1.066524,-1.252868,-1.208009,-1.409444
170,Viet Nam,East and South-East Asia,Asia,VN,14.058324,108.277199,53620.000000,0.0,1750.526941,1750.526941,790.592516,-0.747617,-0.246990,0.405120,-0.625980,-0.518899
171,Yemen,Near and Middle East /South-West Asia,Asia,YE,15.552727,48.516388,53620.000000,0.0,349.927954,349.927954,867.782937,-0.789497,-0.920007,-1.345312,-0.773363,-1.039284
172,Zambia,Southern Africa,Africa,ZM,-13.133897,27.849332,21036.666667,0.0,6.540733,6.540733,1065.596417,-0.551736,-0.877599,0.363792,-0.655607,-0.549667


In [19]:
# Build the yearly node attribute tables of the heroin network (default period)
joint_df_heroin = get_node_attributes('Heroin', df_ids)

In [20]:
# Display the heroin node attributes of 2006
joint_df_heroin[2006]

Unnamed: 0,Country,Sub_Region,Region,ISO,Latitude,Longitude,Price(USD),Seizures(kg),Consumption(kg),Market(kg),GDP/capita,Control_of_Corruption,Gov_Effectiveness,Stability_No_Terrorism,Regulatory_Quality,Rule_of_Law
0,Afghanistan,Near and Middle East /South-West Asia,Asia,AF,33.939110,67.709953,67434.445135,2.060000,435.075365,437.135365,274.000656,-1.446292,-1.473652,-2.219135,-1.689469,-1.879005
1,Albania,East Europe,Europe,AL,41.153332,20.168331,39534.547363,0.000000,50.901081,50.901081,2972.743618,-0.790545,-0.580953,-0.508157,-0.148599,-0.703136
2,Algeria,North Africa,Africa,DZ,28.033886,1.659626,41983.333333,0.000000,638.495263,638.495263,3500.134610,-0.564963,-0.438162,-1.126413,-0.500463,-0.775205
3,Angola,Southern Africa,Africa,AO,-11.202692,17.873887,41983.333333,0.000000,176.158159,176.158159,2597.963585,-1.241038,-1.381049,-0.539453,-1.156216,-1.311273
4,Argentina,South America,Americas,AR,-38.416097,-63.616672,9646.000000,0.000000,162.296133,162.296133,5890.978002,-0.331239,-0.033916,0.001968,-0.614567,-0.577016
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
150,Uzbekistan,Central Asia and Transcaucasian countries,Asia,UZ,41.377491,64.585262,67434.445135,93.560050,890.312303,983.872353,654.283837,-1.016336,-1.185932,-1.784660,-1.628457,-1.453595
151,"Venezuela, Bolivarian Republic of",South America,Americas,VE,6.423750,-66.589730,9646.000000,0.000000,111.823561,111.823561,6769.868414,-1.049627,-1.066524,-1.252868,-1.208009,-1.409444
152,Viet Nam,East and South-East Asia,Asia,VN,14.058324,108.277199,116166.505324,0.000000,1457.853216,1457.853216,790.592516,-0.747617,-0.246990,0.405120,-0.625980,-0.518899
153,Zambia,Southern Africa,Africa,ZM,-13.133897,27.849332,41983.333333,1.671115,102.844173,104.515288,1065.596417,-0.551736,-0.877599,0.363792,-0.655607,-0.549667


In [21]:
def aggregate_yearly_features(df, start_year = 2006, end_year = 2017):
    '''
    Add a 'total' table holding the average of the yearly features over the period.

    Parameters
    ----------
    df : dict of pd.DataFrames
        Yearly tables keyed by the year. The tables are added row by row, so
        they are expected to share the same index and columns.
    start_year : int, optional
        First year of the period (inclusive).
    end_year : int, optional
        Last year of the period (inclusive).

    Returns
    -------
    df : dict of pd.DataFrames
        The same dict that was passed in, with the extra key 'total'. The
        columns 'Country', 'Sub_Region', 'Region', 'ISO', 'Latitude' and
        'Longitude' of that table are copied from start_year; every other
        column is the sum over the years divided by the number of years.
    '''
    
    # Columns that identify or locate the country: copied, not averaged
    fixed_cols = ('Country', 'Sub_Region', 'Region', 'ISO', 'Latitude', 'Longitude')
    n_years = end_year - start_year + 1
    
    # Seed the result with a copy of the first year, so the yearly table itself stays untouched
    total = df[start_year].copy()
    df['total'] = total
    avg_cols = [name for name in df[start_year].columns if name not in fixed_cols]
    
    # Accumulate the remaining years, then turn the sums into averages
    for year in range(start_year + 1, end_year + 1):
        total[avg_cols] = total[avg_cols].add(df[year][avg_cols])
    total[avg_cols] = total[avg_cols].div(n_years)
    return df

In [22]:
# Average the yearly cocaine node attributes over the period and preview the
# first rows; note that the call also adds the 'total' table to joint_df_cocaine
aggregate_yearly_features(joint_df_cocaine)['total'].head()

Unnamed: 0,Country,Sub_Region,Region,ISO,Latitude,Longitude,Price(USD),Seizures(kg),Consumption(kg),Market(kg),GDP/capita,Control_of_Corruption,Gov_Effectiveness,Stability_No_Terrorism,Regulatory_Quality,Rule_of_Law
0,Afghanistan,Near and Middle East /South-West Asia,Asia,AF,33.93911,67.709953,92422.410967,0.0,87.063256,87.063256,520.286184,-1.516444,-1.425955,-2.540436,-1.413413,-1.718438
1,Albania,East Europe,Europe,AL,41.153332,20.168331,50272.113193,0.0,150.220646,150.220646,4119.263805,-0.624496,-0.216519,0.018792,0.178123,-0.474197
2,Algeria,North Africa,Africa,DZ,28.033886,1.659626,56311.244432,0.080727,121.523203,121.60393,4602.706771,-0.58736,-0.471587,-1.168148,-1.038943,-0.826442
3,Andorra,West & Central Europe,Europe,AD,42.546245,1.601554,57001.472084,0.033732,2.605317,2.639049,45963.749644,1.241642,1.608962,1.325878,1.291107,1.335628
4,Angola,Southern Africa,Africa,AO,-11.202692,17.873887,35897.639058,18.667667,326.066924,344.73459,3595.820298,-1.367256,-1.08856,-0.4034,-1.027592,-1.216339


In [23]:
def aggregate_yearly_edges(G, start_year = 2006, end_year = 2017):
    '''
    Collect the edges that appear in at least one of the yearly networks.

    Parameters
    ----------
    G : dict
        Yearly networks keyed by the year; each value must expose its edges
        through an iterable `edges` attribute.
    start_year : int, optional
        First year of the period (inclusive).
    end_year : int, optional
        Last year of the period (inclusive).

    Returns
    -------
    edge_list : list
        Every distinct edge of the period, in order of first appearance.
    '''
    # A dict is used as an ordered set: duplicates are dropped while the result
    # keeps a reproducible order
    seen_edges = dict()
    # end_year is inclusive, as everywhere else in this notebook
    for year in range(start_year, end_year + 1):
        for edge in G[year].edges:
            seen_edges[edge] = None
    return list(seen_edges)

In [74]:
def get_network_data(drug, df_ids = None, df_ids_path = None, 
                     for_pyg = True, for_r = False,
                     write_to_file = True, base_file_path = '/Users/mateicosa/Bocconi/BIDSA/Network_Science/data/',
                     start_year = 2006, end_year = 2017):
    '''
    Build the aggregated network of one drug, with the averaged country features as node attributes.

    Parameters
    ----------
    drug : str
        Drug whose network is built, e.g. 'Cocaine' or 'Heroin'.
    df_ids : object, optional
        The IDS data. If omitted, it is read with Seizures.read_xlsx().
    df_ids_path : str, optional
        Currently unused.
    for_pyg : bool, optional
        Build the network in the PyTorch Geometric convention. This is the
        only implemented output.
    for_r : bool, optional
        Currently unused.
    write_to_file : bool, optional
        If True, the network is also written in gml format to
        base_file_path + 'pyg_aggregate.gml'. The file name does not depend on
        the drug, so successive calls overwrite the same file.
    base_file_path : str, optional
        Prefix of the output file path.
    start_year : int, optional
        First year of the period (inclusive).
    end_year : int, optional
        Last year of the period (inclusive).

    Returns
    -------
    output_network : nx.DiGraph
        One node per row of the averaged feature table, named after the
        country, with the attributes 'y' (the country name) and 'x' (the list
        of features: the averaged columns, the coordinates and the one-hot
        encoded sub-region and region). The edges are those appearing in any
        yearly network; an edge endpoint that has no feature row becomes a
        node without attributes.

    Raises
    ------
    NotImplementedError
        If for_pyg is False.
    '''
    
    # Only the PyTorch Geometric output exists; fail before doing any work
    if not for_pyg:
        raise NotImplementedError('Only the PyTorch Geometric output (for_pyg = True) is implemented')
    
    # If the IDS data is not loaded, read it
    # NOTE(review): df_ids_path is not forwarded because the signature of
    # Seizures.read_xlsx is not visible here - confirm and pass it on
    if df_ids is None:
        df_ids = Seizures.read_xlsx()
    
    # Get the node attributes
    df_yearly = get_node_attributes(drug, df_ids, start_year = start_year, end_year = end_year)
    df_aggregate = aggregate_yearly_features(df_yearly, start_year = start_year, end_year = end_year)
    
    # Get the edge data
    dict_of_nets = Seizures.get_drug_network_by_year(drug, df_ids, start_year = start_year, end_year= end_year)
    edge_list = aggregate_yearly_edges(dict_of_nets, start_year = start_year, end_year = end_year)
    
    # Add the features to the yearly networks. networkx expects a dict of dicts
    # keyed by node ({country: {attribute: value}}); a dataframe passed directly
    # is read column by column and silently sets nothing.
    for year in range(start_year, end_year + 1):
        yearly_attributes = (df_aggregate[year]
                             .drop_duplicates(subset = 'Country')
                             .set_index('Country')
                             .to_dict('index'))
        nx.set_node_attributes(dict_of_nets[year], yearly_attributes)

    # Construct a new network
    output_network = nx.DiGraph()
    # Preprocessing: one-hot encoding plus dropping 'ISO' column
    df_features = pd.get_dummies(df_aggregate['total'], columns = ['Sub_Region', 'Region'], dtype = float).drop(columns = ['ISO'])
    # Add the node with attributes in pytorch convention: y is the country name, x is a list of features
    node_list = [(row['Country'], {'y': row['Country'], 'x': list(row.drop('Country'))})
                 for _, row in df_features.iterrows()]
    output_network.add_nodes_from(node_list)
    output_network.add_edges_from(edge_list)

    # Write the output (optional: the network is returned in any case)
    if write_to_file:
        
        # Create the path
        write_file_path = base_file_path + 'pyg_aggregate' + '.gml'

        # Write to the given path in gml format
        nx.write_gml(output_network, write_file_path)
            
    return output_network

In [75]:
# Build the aggregated cocaine network; with the default arguments it is also
# written to 'pyg_aggregate.gml' under the default base_file_path
net = get_network_data('Cocaine', df_ids)