## Kernel to load: vax_inc_general 

In [1]:
import numpy as np
import pandas as pd
from statsmodels.tsa.holtwinters import ExponentialSmoothing
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.simplefilter("ignore", category=RuntimeWarning)
import pycountry


In [2]:
def _holt_winters_one_step(history, trend, seasonal):
    """Fit exponential smoothing to ``history`` and return the one-step-ahead forecast.

    The 'heuristic' initialisation is tried first; if statsmodels rejects it with a
    ValueError the model is rebuilt with 'legacy-heuristic'. A ValueError raised by
    the second attempt propagates to the caller.
    """
    try:
        fit = ExponentialSmoothing(
            history, trend=trend, seasonal=seasonal, initialization_method="heuristic"
        ).fit()
    except ValueError:
        fit = ExponentialSmoothing(
            history, trend=trend, seasonal=seasonal, initialization_method="legacy-heuristic"
        ).fit()
    return float(np.asarray(fit.forecast(steps=1))[0])


def hybrid_imputation_with_fallback_handling(
    series, trend='add', seasonal=None,
    original_flags=None, original_flag_descriptions=None
):
    """Fill every missing value of an ordered series and record how each was filled.

    The fill strategy depends on where a gap sits:
      * only one observed value     -> every gap takes that value          (flag 'N')
      * before the first observation -> copy the first observation        (flag 'N')
      * between two observations    -> linear interpolation               (flag 'L')
      * after the last observation  -> recursive one-step exponential-smoothing
        forecasts, refitting on the series so far for every step         (flag 'I');
        a negative or non-finite forecast falls back to the last value,
        clamped at 0                                                      (flag 'N')

    Args:
        series (pd.Series): Values in chronological order. The index is discarded;
            row position is treated as time.
        trend, seasonal: Passed unchanged to ``ExponentialSmoothing``.
        original_flags (sequence, optional): Starting per-position flags; positions
            that get filled are overwritten. Defaults to 'O' everywhere.
        original_flag_descriptions (sequence, optional): Starting per-position
            descriptions, overwritten like the flags. Defaults to 'Original value'.

    Returns:
        tuple: (filled ``pd.Series`` with a fresh RangeIndex, list of flags,
        list of flag descriptions).

    Raises:
        ValueError: If every value in ``series`` is missing, or if both
            exponential-smoothing initialisations fail.
    """
    series = series.reset_index(drop=True).copy()

    if series.isnull().all():
        raise ValueError("The entire series is missing; cannot perform imputation.")

    n_obs = len(series)
    flags = list(original_flags) if original_flags is not None else ['O'] * n_obs
    flag_descriptions = (
        list(original_flag_descriptions)
        if original_flag_descriptions is not None
        else ['Original value'] * n_obs
    )

    def _mark(idx, flag, description):
        flags[idx] = flag
        flag_descriptions[idx] = description

    is_missing = series.isnull()
    missing_indices = series.index[is_missing].tolist()
    observed_indices = series.index[~is_missing].tolist()
    first_obs, last_obs = observed_indices[0], observed_indices[-1]

    # A single observation cannot be interpolated or fitted: broadcast it.
    if len(observed_indices) == 1:
        series = series.fillna(series.iloc[first_obs])
        for idx in missing_indices:
            _mark(idx, 'N', 'Filled with nearest neighbor (single value available)')
        return series, flags, flag_descriptions

    # Leading gap: copy the first observation backwards.
    if first_obs > 0:
        series.iloc[:first_obs] = series.iloc[first_obs]
        for idx in range(first_obs):
            _mark(idx, 'N', 'Filled with nearest neighbor (start of series)')

    # Interior gaps: linear interpolation. limit_area='inside' keeps the trailing
    # gap NaN (the default would fill it with the last value) so that it is
    # visibly handled only by the forecasting step below.
    series = series.interpolate(method='linear', limit_area='inside')
    for idx in missing_indices:
        if first_obs < idx < last_obs:
            _mark(idx, 'L', 'Filled using linear interpolation')

    # Trailing gap: recursive one-step-ahead forecasts. missing_indices is
    # ascending, so each refit already sees the forecasts written before it.
    for idx in missing_indices:
        if idx <= last_obs:
            continue
        history = series.iloc[:idx].dropna().to_numpy()
        forecast = _holt_winters_one_step(history, trend, seasonal)
        # Fallback value: last known value, clamped to 0 if that value is negative.
        fallback_value = max(history[-1], 0)
        if not np.isfinite(forecast):
            series.iloc[idx] = fallback_value
            _mark(idx, 'N', 'Fallback to nearest neighbor (non-finite forecast)')
        elif forecast < 0:
            series.iloc[idx] = fallback_value
            _mark(idx, 'N', 'Fallback to nearest neighbor (negative forecast)')
        else:
            series.iloc[idx] = forecast
            _mark(idx, 'I', 'Imputed using exponential smoothing (forecasted)')

    return series, flags, flag_descriptions


In [3]:
# Fill in years
def add_missing_years(df, start_year=1961, end_year=2024):
    """
    Pad a single Area/Item series so that every year has a row, sorted by year.

    Placeholder rows are created for
      * ``start_year`` .. (earliest year in ``df``) - 1,
      * every year strictly between the earliest and latest year that is absent,
      * (latest year in ``df``) + 1 .. ``end_year``.
    Each placeholder copies the static identifier columns from the first row of
    ``df`` and sets 'Year' and 'Year Code' to that year; 'Value', 'Flag',
    'Flag Description' and 'Note' are left missing (and are created if ``df``
    lacks them).

    Args:
        df (pd.DataFrame): Rows for one series; must contain 'Year' and the
            static identifier columns listed below.
        start_year (int): First year the padded series should cover.
        end_year (int): Last year (inclusive) the padded series should cover.

    Returns:
        pd.DataFrame: Original rows plus placeholders, sorted by 'Year', with a
        fresh RangeIndex. Sorting matters: downstream imputation treats row
        position as time, so interior gap years must not trail the later years.

    Raises:
        ValueError: If ``df`` is empty (no row to copy identifiers from).
    """
    if df.empty:
        raise ValueError("add_missing_years needs at least one row to copy identifiers from.")

    static_columns = [
        "Domain Code", "Domain", "Area Code (M49)", "Area",
        "Element Code", "Element", "Item Code (CPC)", "Item", "Unit", "ISO3"
    ]
    template = df.iloc[0][static_columns].to_dict()

    observed_years = set(df["Year"].tolist())  # O(1) membership for the gap scan
    earliest_year = df["Year"].min()
    latest_year = df["Year"].max()

    missing_years = (
        list(range(start_year, earliest_year))
        + [year for year in range(earliest_year, latest_year) if year not in observed_years]
        + list(range(latest_year + 1, end_year + 1))
    )

    if missing_years:
        placeholders = pd.DataFrame(
            [{**template, "Year Code": year, "Year": year} for year in missing_years]
        )
        padded = pd.concat([df, placeholders], ignore_index=True)
    else:
        # Avoid concatenating an empty frame (deprecated in recent pandas).
        padded = df.reset_index(drop=True)

    # After ignore_index concat, the original rows occupy the first len(df) labels.
    is_placeholder = padded.index >= len(df)
    for column in ("Value", "Flag", "Flag Description", "Note"):
        padded.loc[is_placeholder, column] = None

    return padded.sort_values("Year", kind="mergesort").reset_index(drop=True)

In [4]:
import pandas as pd

def find_gaps_in_years(df):
    """Report (Area, Item) series whose years are not contiguous.

    For each (Area, Item) group the set of years present is compared with the
    full range min..max. Only interior gaps are detected; years before the first
    or after the last row are not considered.

    Args:
        df (pd.DataFrame): Must contain 'Area', 'Item' and integer-valued 'Year'.

    Returns:
        pd.DataFrame: One row per group with at least one gap, with columns
        'Area', 'Item', 'Min Year', 'Max Year', 'Missing Years' (sorted list).
        When no group has gaps, an empty frame with those same columns.
    """
    columns = ['Area', 'Item', 'Min Year', 'Max Year', 'Missing Years']
    gaps = []

    for (area, item), years in df.groupby(['Area', 'Item'])['Year']:
        min_year, max_year = int(years.min()), int(years.max())
        missing_years = sorted(set(range(min_year, max_year + 1)) - set(years))
        if missing_years:
            gaps.append({
                'Area': area,
                'Item': item,
                'Min Year': min_year,
                'Max Year': max_year,
                'Missing Years': missing_years,
            })

    return pd.DataFrame(gaps, columns=columns)

In [5]:
# Cutoff years for countries that no longer exist in their listed form, keyed by
# the code get_iso3() returns. The processing loop clips every matching
# (ISO3, Item) series to [max(1960, start), min(2024, end)].
# NOTE(review): the original header said "FAO has data from 1960 to 2022"; the
# loop filters input to Year <= 2022, but add_missing_years defaults to
# start_year=1961 — confirm the intended first year.
# NOTE(review): no entry in build_countries_mapping produces 'SDN-PRE'
# ('Sudan (former)' maps to 'SDN'), so that key cannot match as written.
country_existence = {
    code: {"start": first_year, "end": last_year}
    for code, first_year, last_year in (
        # Yugoslavia, successors
        ("YUG", 1918, 1992),
        ("SCG", 1992, 2006),
        # Soviet Union
        ("SUN", 1922, 1991),
        # Germany before reunification
        ("DDR", 1949, 1990),
        # Czechoslovakia
        ("CSK", 1918, 1992),
        # Zaire
        ("ZAR", 1971, 1997),
        # Sudan before splitting
        ("SDN-PRE", 1956, 2011),
        # Yemen
        ("YMD", 1967, 1990),
        ("YAR", 1962, 1990),
        # Other
        ("TPT", 1976, 1999),
        ("NF", 1901, 1980),
        ("RHO", 1965, 1980),
        # Federation of the West Indies
        ("WIF", 1958, 1962),
    )
}


## Imputation with exponential smoothing (best-performing approach)

In [6]:
# Input CSVs, grouped by species folder. Each processed table is written back
# into the same folder with the 'original_' prefix removed from the filename.
datapaths = [
    f"{folder}/{filename}"
    for folder, filenames in (
        ("poultry", (
            "original_poultry_pop_2024.csv",
            "original_killed_poultry_pop_2024.csv",
            "original_pop_egg_layers.csv",
        )),
        ("cattle", (
            "original_cattle_pop_2024.csv",
            "original_killed_cattle_pop_2024.csv",
            "original_pop_cattle_dairy_cattle.csv",
        )),
        ("swine", (
            "original_swine_pop_2024.csv",
            "original_killed_swine_pop_2024.csv",
        )),
    )
    for filename in filenames
]

In [7]:
def build_countries_mapping():
    """
    Return a dict from country/area name to ISO3 code.

    The base layer is every entry of ``pycountry.countries`` (``name`` ->
    ``alpha_3``). A hand-written alias table is then applied on top, covering
    alternative spellings and codes for dissolved states (e.g. 'USSR' -> 'SUN',
    'Yugoslav SFR' -> 'YUG'). Some historical names share a modern code
    (e.g. 'Sudan (former)' -> 'SDN', 'Ethiopia PDR' -> 'ETH'). Aliases win
    when a name appears in both layers.
    """
    name_overrides = {
        'USA': 'USA',
        'UK': 'GBR',
        'Taiwan': 'TWN',
        'South Korea': 'KOR',
        'Czech Republic': 'CZE',
        'Brunei': 'BRN',
        'Russia': 'RUS',
        'Iran': 'IRN',
        'United States of America': 'USA',
        'Venezuela': 'VEN',
        'China (Hong Kong SAR)': 'HKG',
        "Cote d'Ivoire": 'CIV',
        'DR Congo': 'COD',
        'Guinea Bissau': 'GNB',
        'Lao PDR': 'LAO',
        'Micronesia (Federated States of)': 'FSM',
        'North Korea': 'PRK',
        'Occupied Palestinian Territory': 'PSE',
        'Swaziland': 'SWZ',
        'Tanzania': 'TZA',
        'Bolivia': 'BOL',
        'Macedonia (TFYR)': 'MKD',
        'Moldova': 'MDA',
        'Bolivia (Plurinational State of)': 'BOL',
        'China, Hong Kong SAR': 'HKG',
        'China, Taiwan Province of': 'TWN',
        'China, mainland': 'CHN',
        'Czechoslovakia': 'CSK',
        "Democratic People's Republic of Korea": 'PRK',
        'Democratic Republic of the Congo': 'COD',
        'French Guyana': 'GUF',
        'Micronesia': 'FSM',
        'Palestine': 'PSE',
        'Polynesia': 'PYF',
        'Republic of Korea': 'KOR',
        'Serbia and Montenegro': 'SCG',
        'Sudan (former)': 'SDN',
        'Türkiye': 'TUR',
        'USSR': 'SUN',
        'Iran (Islamic Republic of)': 'IRN',
        'Republic of Moldova': 'MDA',
        'United Kingdom of Great Britain and Northern Ireland': 'GBR',
        'United Republic of Tanzania': 'TZA',
        'Venezuela (Bolivarian Republic of)': 'VEN',
        'Yugoslav SFR': 'YUG',
        'Ethiopia PDR': 'ETH',
        'Central African (Rep.)': 'CAF',
        "China (People's Rep. of)": 'CHN',
        'Chinese Taipei': 'TWN',
        'Congo (Dem. Rep. of the)': 'COD',
        'Congo (Rep. of the)': 'COG',
        "Cote D'Ivoire": 'CIV',
        'Dominican (Rep.)': 'DOM',
        "Korea (Dem People's Rep. of)": 'PRK',
        'Korea (Rep. of)': 'KOR',
        'Laos': 'LAO',
        'South Sudan (Rep. of)': 'SSD',
        'Syria': 'SYR',
        'St. Vincent and the Grenadines': 'VCT',
        'Vietnam': 'VNM',
        'Reunion': 'REU',
        'Guadaloupe': 'GLP',
        'China, Macao SAR': 'MAC',
        'Netherlands (Kingdom of the)': 'NLD',
        'Türkiye (Rep. of)': 'TUR',
        'Belgium-Luxembourg': 'BLX',
        'Faeroe Islands': 'FRO',
        'Cabo verde': 'CPV',
        'St. Helena': 'SHN',
    }

    mapping = {entry.name: entry.alpha_3 for entry in pycountry.countries}
    mapping.update(name_overrides)
    return mapping

# Computed a single time when this cell runs, so each lookup is a plain dict read.
COUNTRIES_MAPPING = build_countries_mapping()

def get_iso3(country):
    """
    Look up the ISO3 code for a country or area name.

    The lookup is an exact key match against COUNTRIES_MAPPING (no case folding
    or accent normalisation).

    Args:
        country (str): Name as it appears in the data.

    Returns:
        str or None: The mapped ISO3 code, or None when the name is not listed.
    """
    iso3_code = COUNTRIES_MAPPING.get(country)
    return iso3_code


In [8]:
# For every input file: impute each (ISO3, Item) series over its valid year range
# and write the completed table into the same folder, dropping the 'original_'
# prefix from the filename.
for datapath in datapaths:
    folder, filename = datapath.rsplit('/', 1)
    output_path = folder + '/' + filename.replace('original_', '')

    data = pd.read_csv(datapath)
    # .copy() so the column assignments below act on an owned frame, not a view.
    data = data[data['Year'] <= 2022].copy()
    data.loc[data['Flag'] == 'M', 'Value'] = None  # Set 'Value' to NaN where Flag is 'M' (Missing)

    data['ISO3'] = data['Area'].apply(get_iso3)
    unmapped_areas = sorted(data.loc[data['ISO3'].isnull(), 'Area'].unique())
    if unmapped_areas:
        # groupby() drops NaN keys by default, so these areas would be silently lost.
        print("FIX THIS: Missing ISO3 codes found.", unmapped_areas)

    data_frames = []

    # Process each (ISO3, animal) group separately
    for (iso3, animal), group in data.groupby(['ISO3', 'Item']):
        # Restrict countries that did not exist for the whole period to their lifespan.
        if iso3 in country_existence:
            print(f"{iso3} has historical data.")
            valid_start = max(1960, country_existence[iso3]['start'])
            valid_end = min(2024, country_existence[iso3]['end'])
        else:
            valid_start, valid_end = 1960, 2024

        group = add_missing_years(group.sort_values('Year'))
        # Sort after padding: the imputation treats row position as time order.
        group = (
            group[(group['Year'] >= valid_start) & (group['Year'] <= valid_end)]
            .sort_values('Year', kind='mergesort')
            .copy()
        )

        # Skip processing if the group is empty after filtering
        if group.empty:
            print(f"No valid data for {iso3} ({animal}) in the range {valid_start}-{valid_end}.")
            continue

        try:
            vals, flags, flag_desc = hybrid_imputation_with_fallback_handling(
                group['Value'],
                original_flags=group['Flag'],
                original_flag_descriptions=group['Flag Description'],
            )
        except Exception as e:  # report the failing group and keep processing the rest
            print(f"Error processing group ({iso3}, {animal}): {e}")
            continue

        # Positional assignment: the returned sequences follow the row order of `group`.
        group['Value'] = list(vals)
        group['Flag'] = list(flags)
        group['Flag Description'] = list(flag_desc)
        data_frames.append(group)

    if not data_frames:
        # pd.concat([]) would raise and abort the remaining files.
        print(f"No groups could be imputed for {datapath}; nothing written.")
        continue

    full_data = pd.concat(data_frames)
    print('# Missing years:', len(find_gaps_in_years(full_data)))

    full_data.to_csv(output_path, index=False)

    if (full_data['Value'] < 0).any():
        print("-------Error: negative prediction. Debug this -------")

    print('*****Finished:', output_path, '****')

Error processing group (BDI, Turkeys): The entire series is missing; cannot perform imputation.
Error processing group (CIV, Ducks): The entire series is missing; cannot perform imputation.
CSK has historical data.
CSK has historical data.
CSK has historical data.
CSK has historical data.
Error processing group (FSM, Ducks): The entire series is missing; cannot perform imputation.
Error processing group (OMN, Ducks): The entire series is missing; cannot perform imputation.
SCG has historical data.
SCG has historical data.
SCG has historical data.
SCG has historical data.
SUN has historical data.
SUN has historical data.
YUG has historical data.
YUG has historical data.
YUG has historical data.
YUG has historical data.
# Missing years: 2
*****Finished: poultry/poultry_pop_2024.csv ****
Error processing group (CMR, Meat of ducks, fresh or chilled): The entire series is missing; cannot perform imputation.
CSK has historical data.
CSK has historical data.
CSK has historical data.
CSK has h