In [None]:
import numpy as np
import pandas as pd

import sys, os
sys.path.append("../..")
sys.path.append("..")
sys.path.append(os.getcwd())
sys.path.append("../../..")


import matplotlib.pyplot as plt
from sklearn.linear_model import Lasso 
from sklearn.decomposition import PCA #USE PCA FOR PCR REGRESSION
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
import datetime
import math
from tslib.src import tsUtils
from tslib.src.synthcontrol.syntheticControl import RobustSyntheticControl
from tslib.tests import testdata

import statsmodels.api as sm
from sklearn.metrics import r2_score
from statsmodels.stats.outliers_influence import variance_inflation_factor

In [None]:
# Link for NFL Attendance Data: https://docs.google.com/spreadsheets/d/1DRMB5FLC3tdngeurDwps1CS8-6smEZvmq5R-ghjEv5k/edit#gid=0

# NFL attendance data. The sheet contains many NaN rows, so keep only rows
# that actually name a team.
attendance_df = pd.read_csv('NFL_Data.csv')
attendance_df = attendance_df.loc[~attendance_df['Team'].isnull()]

# NFL stadium county data, along with the neutral counties. The data is a
# little messy: the usable header is on the second line of the file.
stadium_county_df = pd.read_csv('Stadium_County.csv', header=1)

# Forward-fill the team name down the rows that leave it blank, then strip the
# trailing space from the column name. The filled column is assigned back
# instead of calling fillna(..., inplace=True) on the column selection: an
# in-place fill on a selected column is not guaranteed to reach the parent
# frame (it is a no-op under pandas Copy-on-Write), and fillna(method=...) is
# deprecated in favour of ffill().
stadium_county_df['Team '] = stadium_county_df['Team '].ffill()
stadium_county_df = stadium_county_df.rename(columns={'Team ': 'Team'})

# County Covid rates provided from the NYTimes GitHub database,
# https://github.com/nytimes/covid-19-data/blob/master/us-counties-2020.csv
county_covid = pd.read_csv('County_Covid_Data.csv')
county_covid['date'] = pd.to_datetime(county_covid['date'])

# Drop rows with no usable county ('Unknown') or no case count. The explicit
# copy makes the column assignment below operate on an independent frame.
county_covid = county_covid.loc[
    (county_covid['county'] != 'Unknown') & county_covid['cases'].notnull()
].copy()

# County names are compared case-insensitively everywhere else, so normalise
# them to lower case here (non-string values are left untouched).
county_covid['county'] = county_covid['county'].apply(lambda x: x.lower() if isinstance(x, str) else x)



In [None]:
#We want to fill in state/county Data for cities that don't have stadium counties listed.
no_fans_list = [
    'Buffalo', 'Chicago', 'Detroit', 'Green Bay', 'Las Vegas',
    'LA Chargers', 'LA Rams', 'Minnesota', 'New England', 'New Orleans',
    'NY Giants', 'NY Jets', 'San Francisco', 'Seattle', 'Washington',
]

# Manual overrides, team -> (State, County(s)).
# NOTE(review): 'Minnesota' is mapped to Ramsey county - confirm this is the
# intended stadium county.
# NOTE(review): 'Washington' uses the full name 'District of Columbia' for the
# state while every other entry uses a two-letter code - confirm downstream
# lookups expect that.
_manual_locations = {
    'Arizona': ('AZ', 'Maricopa'),
    'Chicago': ('IL', 'Cook'),
    'Detroit': ('MI', 'Wayne'),
    'Las Vegas': ('NV', 'Clark'),
    'LA Chargers': ('CA', 'Los Angeles'),
    'LA Rams': ('CA', 'Los Angeles'),
    'Minnesota': ('MN', 'Ramsey'),
    'New England': ('MA', 'Norfolk'),
    'New Orleans': ('LA', 'Orleans'),
    'NY Giants': ('NJ', 'Bergen'),
    'NY Jets': ('NJ', 'Bergen'),
    'San Francisco': ('CA', 'Santa Clara'),
    'Seattle': ('WA', 'King'),
    'Washington': ('District of Columbia', 'District of Columbia'),
}

for _team, (_state, _county) in _manual_locations.items():
    _is_team = stadium_county_df['Team'] == _team
    stadium_county_df.loc[_is_team, 'State'] = _state
    stadium_county_df.loc[_is_team, 'County(s)'] = _county

# Lower-case both county columns so they can be matched against the
# lower-cased county names in the Covid data; non-strings pass through.
for _column in ('County(s)', 'Counties'):
    stadium_county_df[_column] = stadium_county_df[_column].apply(
        lambda value: value.lower() if isinstance(value, str) else value
    )



In [None]:
# Two-letter postal code -> full state/territory name. The county Covid data
# identifies states by full name, the stadium sheet by abbreviation.
home_state_dict = {
    'AK': 'Alaska', 'AL': 'Alabama', 'AR': 'Arkansas', 'AS': 'American Samoa',
    'AZ': 'Arizona', 'CA': 'California', 'CO': 'Colorado', 'CT': 'Connecticut',
    'DC': 'District of Columbia', 'DE': 'Delaware', 'FL': 'Florida', 'GA': 'Georgia',
    'GU': 'Guam', 'HI': 'Hawaii', 'IA': 'Iowa', 'ID': 'Idaho',
    'IL': 'Illinois', 'IN': 'Indiana', 'KS': 'Kansas', 'KY': 'Kentucky',
    'LA': 'Louisiana', 'MA': 'Massachusetts', 'MD': 'Maryland', 'ME': 'Maine',
    'MI': 'Michigan', 'MN': 'Minnesota', 'MO': 'Missouri', 'MP': 'Northern Mariana Islands',
    'MS': 'Mississippi', 'MT': 'Montana', 'NA': 'National', 'NC': 'North Carolina',
    'ND': 'North Dakota', 'NE': 'Nebraska', 'NH': 'New Hampshire', 'NJ': 'New Jersey',
    'NM': 'New Mexico', 'NV': 'Nevada', 'NY': 'New York', 'OH': 'Ohio',
    'OK': 'Oklahoma', 'OR': 'Oregon', 'PA': 'Pennsylvania', 'PR': 'Puerto Rico',
    'RI': 'Rhode Island', 'SC': 'South Carolina', 'SD': 'South Dakota', 'TN': 'Tennessee',
    'TX': 'Texas', 'UT': 'Utah', 'VA': 'Virginia', 'VI': 'Virgin Islands',
    'VT': 'Vermont', 'WA': 'Washington', 'WI': 'Wisconsin', 'WV': 'West Virginia',
    'WY': 'Wyoming',
}


In [None]:

#We want to associate NFL teams to their respective Counties. Add Column that contains 'HOME COUNTIES', and columns that contains 'NEUTRAL COUNTIES'

#attendance_df['Team City'] = attendance_df.Team.str.split().str[:-1].str.join(sep=' ') #To connect with county information


#Change exceptions with 2 name cities



def find_stadium_counties(team_city_str):
    """Return the stadium counties recorded for a team in its home state.

    Reads the module-level ``stadium_county_df`` and ``no_fans_list``.

    Parameters
    ----------
    team_city_str : str
        Value to match against the 'Team' column.

    Returns
    -------
    list or str
        De-duplicated 'County(s)' values (unordered) from the team's rows whose
        'State.1' equals the team's home state, or the empty string ``""`` when
        the team is in ``no_fans_list``.

    Raises
    ------
    IndexError
        If no row matches ``team_city_str``.
    """
    rows_for_team = stadium_county_df[stadium_county_df['Team'] == team_city_str]

    # The home state is taken from the first matching row; only rows whose
    # 'State.1' is that same state are kept.
    state_of_team = rows_for_team['State'].iloc[0]
    in_home_state = rows_for_team.loc[rows_for_team['State.1'] == state_of_team]

    # Teams that played without fans have no stadium county at all, since
    # they are not affected by the stadium opening.
    if team_city_str not in no_fans_list:
        return list(set(in_home_state['County(s)']))
    return ""

def find_donor_counties(team_city_str, min_cases=200): #All counties within the state that is NOT in buffer counties.
    """Return the donor counties for a team's synthetic control.

    Donor counties are the counties of the team's home state that appear in
    ``county_covid``, are neither buffer ('Counties') nor stadium ('County(s)')
    counties for that state in ``stadium_county_df``, and have at least one
    row with more than ``min_cases`` cases.

    Reads the module-level ``stadium_county_df``, ``county_covid`` and
    ``home_state_dict``.

    Parameters
    ----------
    team_city_str : str
        Value to match against the 'Team' column.
    min_cases : number, optional
        A county qualifies only if some row has ``cases > min_cases``.
        Defaults to 200.

    Returns
    -------
    list
        De-duplicated donor county names, in no particular order.

    Raises
    ------
    KeyError
        If no home state can be determined for the team, or the home state is
        neither a key nor a value of ``home_state_dict``.
    """
    # Home states for teams whose rows carry no buffer-county information.
    # (The original if/elif chain spelled 'Kansas City' as 'Kansas Chity', so
    # that entry could never match, and listed 'Chicago' twice.)
    fallback_home_states = {
        'Arizona': 'AZ',
        'Chicago': 'IL',
        'Detroit': 'MI',
        'Las Vegas': 'NV',
        'LA Chargers': 'CA',
        'Kansas City': 'MO',
        'LA Rams': 'CA',
        'Minnesota': 'MN',
        'New England': 'MA',
        'New Orleans': 'LA',
        'NY Giants': 'NJ',
        'NY Jets': 'NJ',
        'San Francisco': 'CA',
        'Seattle': 'WA',
        'Washington': 'MD',
    }

    team_city_data = stadium_county_df.loc[stadium_county_df['Team'] == team_city_str]
    team_city_data = team_city_data[team_city_data['Counties'].notnull()]

    if len(team_city_data) > 0:
        # Home state as recorded on the team's first row with a buffer county.
        home_state = team_city_data['State'].iloc[0]
    elif team_city_str in fallback_home_states:
        home_state = fallback_home_states[team_city_str]
    else:
        raise KeyError("No home state known for team %r" % (team_city_str,))

    # county_covid uses full state names rather than abbreviations. Accept a
    # value that is already a full name (the manual fill-ins write
    # 'District of Columbia' into the 'State' column) and fail loudly otherwise.
    if home_state in home_state_dict:
        home_state_name = home_state_dict[home_state]
    elif home_state in home_state_dict.values():
        home_state_name = home_state
    else:
        raise KeyError("Unrecognised home state %r for team %r" % (home_state, team_city_str))

    # Neutral (buffer) counties located in the home state.
    buffer_counties = list(set(
        stadium_county_df.loc[stadium_county_df['State.1'] == home_state, 'Counties']
    ))
    # Stadium counties of every team whose home state this is.
    stadium_counties = list(set(
        stadium_county_df.loc[stadium_county_df['State'] == home_state, 'County(s)']
    ))

    state_rows = county_covid.loc[county_covid['state'] == home_state_name]
    eligible_rows = state_rows.loc[
        ~state_rows['county'].isin(buffer_counties)
        & ~state_rows['county'].isin(stadium_counties)
    ]
    eligible_rows = eligible_rows.loc[eligible_rows['cases'] > min_cases]

    donor_counties = list(set(eligible_rows['county']))

    #Sanity Check
    if team_city_str == "Washington":
        print(home_state_name)
        print(donor_counties)
        print(stadium_counties)

    return donor_counties

# Attach, per row, the stadium counties and the donor counties of that row's team.
for _column, _finder in (
    ('Stadium_Counties', find_stadium_counties),
    ('Donor_Counties', find_donor_counties),
):
    stadium_county_df[_column] = stadium_county_df['Team'].apply(_finder)

# Normalise every county name in both list columns to lower case. An empty
# string (teams without fans) iterates to an empty list here.
for _column in ('Stadium_Counties', 'Donor_Counties'):
    stadium_county_df[_column] = stadium_county_df[_column].apply(
        lambda names: [name.lower() if isinstance(name, str) else name for name in names]
    )


In [None]:
# Display the stadium/county table, now including the derived
# 'Stadium_Counties' and 'Donor_Counties' columns.
stadium_county_df

In [None]:
#Make earlier process into a function to generalize to other Stadiums

# Module-level accumulators filled by create_synthetic_graph: one entry of
# post-intervention relative differences per call, split by whether the team
# is in no_fans_list.
Total_prediction_data_fans, Total_prediction_data_no_fans = [], []

def create_synthetic_graph(team_name_str, stadium_county_str, state_str, intervention_date, show_plot, week):
    """Fit a robust synthetic control for a team's stadium county and plot it.

    The summed case counts of the stadium counties form the target series and
    the team's 'Donor_Counties' form the donor pool. The model is trained on
    the dates before the intervention and then used to predict the whole
    period. The relative difference between actual and predicted cases over
    the first 21 post-intervention dates is appended to
    ``Total_prediction_data_no_fans`` or ``Total_prediction_data_fans``,
    depending on whether the team is in ``no_fans_list``.

    Reads the module-level ``stadium_county_df``, ``county_covid``,
    ``home_state_dict`` and ``no_fans_list``.

    Parameters
    ----------
    team_name_str : str
        Team name as it appears in the 'Team' column.
    stadium_county_str : list of str
        Stadium county names; empty strings and repeated names are ignored.
    state_str : list of str
        State abbreviations or full names; empty strings are ignored.
    intervention_date : list of str
        Candidate date strings; the first entry containing a digit is used.
    show_plot : bool
        If true, draw a stand-alone figure and show it. Otherwise draw only
        the synthetic line onto the current figure and leave showing to the
        caller.
    week : int
        Only used when ``show_plot`` is false: number of weeks added back to
        the intervention date to place the vertical marker.

    Raises
    ------
    ValueError
        If no entry of ``intervention_date`` contains a digit, or there is no
        pre-intervention data to train on.
    """
    print(team_name_str)

    state_str = [x for x in state_str if x != ""]
    intervention_date = [x for x in intervention_date if x != ""]

    # Lower-case to match the lower-cased county names in county_covid, and
    # drop repeated names (order preserved). The grouped per-team lists can
    # repeat a county once per sheet row, which would otherwise be summed
    # several times into the target series below.
    stadium_county_str = list(dict.fromkeys(s.lower() for s in stadium_county_str if s != ""))

    # We want the FIRST entry that has numbers, since some entries are text only.
    dated_entries = [s for s in intervention_date if any(c.isdigit() for c in s)]
    if not dated_entries:
        raise ValueError("No intervention date found for team %r" % (team_name_str,))
    intervention_date = pd.to_datetime(dated_entries[0])

    # Convert state acronyms to full names; values that are not a known
    # acronym are passed through unchanged.
    state_str = [home_state_dict.get(state, state) for state in state_str]

    # Donor pool computed earlier for this team.
    synthetic_counties = list(stadium_county_df.loc[stadium_county_df['Team'] == team_name_str]['Donor_Counties'])[0].copy()
    synthetic_counties = [s.lower() for s in synthetic_counties]

    print(stadium_county_str)

    # Select the rows for the target and donor counties.
    is_stadium_county = county_covid['county'].isin(stadium_county_str)
    is_donor_county = county_covid['county'].isin(synthetic_counties)
    if team_name_str == 'Washington':
        # Special case: donors are restricted to Maryland while the stadium
        # county is matched in any state.
        stadium_county_data = county_covid.loc[
            is_stadium_county | (is_donor_county & (county_covid['state'] == 'Maryland'))
        ]
    else:
        stadium_county_data = county_covid.loc[
            (is_stadium_county | is_donor_county) & county_covid['state'].isin(state_str)
        ]

    stadium_county_data = stadium_county_data.bfill()
    stadium_county_data['date'] = pd.to_datetime(stadium_county_data['date'])

    # Start training from the date on the first stadium-county row.
    # NOTE(review): this is the first row in file order - presumably the
    # earliest date; confirm county_covid is sorted chronologically.
    earliest_date = list(stadium_county_data.loc[stadium_county_data['county'].isin(stadium_county_str)]['date'])[0]
    stadium_county_data = stadium_county_data.loc[stadium_county_data['date'] >= earliest_date]

    # One row per date, one column per county.
    total_pivot = stadium_county_data.pivot_table(columns='county', values='cases', index='date').reset_index()
    total_pivot = total_pivot.loc[total_pivot['date'] >= earliest_date].copy()

    # Sum up the stadium counties into the single series we predict.
    total_pivot['Stadium_County'] = total_pivot[stadium_county_str].sum(axis=1)
    total_pivot = total_pivot.drop(stadium_county_str, axis=1)
    total_pivot = total_pivot.fillna(0)

    # training_pivot is the same table restricted to dates before the intervention.
    training_pivot = total_pivot.loc[total_pivot['date'] < intervention_date]
    total_dates = total_pivot['date']
    training_pivot = training_pivot.drop(['date'], axis=1)
    total_pivot = total_pivot.drop(['date'], axis=1)

    X_train = training_pivot.loc[:, ~training_pivot.columns.isin(['Stadium_County'])]
    Y_train = training_pivot['Stadium_County']
    num_pre_dates = X_train.shape[0]

    # Keep total X and Y data for the prediction and the plots.
    total_X = total_pivot.loc[:, ~total_pivot.columns.isin(['Stadium_County'])]
    total_Y = total_pivot['Stadium_County']
    assert total_X.shape[0] == total_Y.shape[0]

    if X_train.empty:
        raise ValueError("No pre-intervention donor data available for team %r" % (team_name_str,))

    stadium_key = 'Stadium_County'
    donor_key = X_train.columns

    def fit_model(n_singular_values):
        # Fit a robust synthetic control on the pre-intervention table keeping
        # the given number of singular values.
        model = RobustSyntheticControl(stadium_key, n_singular_values, len(training_pivot), probObservation=1.0, modelType='svd', svdMethod='numpy', otherSeriesKeysArray=donor_key)
        model.fit(training_pivot)
        return model

    # Only the number of singular values is needed, so skip computing U and V.
    max_rank = len(np.linalg.svd(X_train, full_matrices=False, compute_uv=False))

    # Pick the smallest number of singular values whose relative
    # pre-intervention fit error is below 1%. If none reaches 1%, fall back to
    # the first one below 1.5%, then 2%, and otherwise keep them all.
    target_error = 0.01
    fallback_errors = (0.015, 0.02)
    aggregate_errors = []
    svalue = max_rank
    for rank in range(1, max_rank + 1):
        fitted = np.dot(X_train, fit_model(rank).model.weights)
        assert fitted.shape[0] == num_pre_dates
        assert len(Y_train) == len(fitted)

        aggregate_error = np.linalg.norm(Y_train - fitted) / np.linalg.norm(Y_train)
        aggregate_errors.append(aggregate_error)

        if aggregate_error < target_error:
            svalue = rank
            print(svalue)
            break
    else:
        print("Uses all singular values initially")
        for tolerance in fallback_errors:
            below = [idx for idx, err in enumerate(aggregate_errors) if err < tolerance]
            if below:
                svalue = below[0] + 1
                break
        print("Our final Kept singular value")
        print(svalue)

    rscModel = fit_model(svalue)
    predictions = np.dot(total_X, rscModel.model.weights)

    # Record the relative difference (actual - synthetic) / synthetic for the
    # days right after the intervention; used for the histogram/boxplot
    # results T+days after the intervention date.
    n_post_days = 21
    post_window = slice(num_pre_dates, num_pre_dates + n_post_days)
    post_predictions = predictions[post_window]
    relative_delta = (-post_predictions + list(total_Y)[post_window]) / post_predictions
    if team_name_str in no_fans_list:
        Total_prediction_data_no_fans.append(relative_delta)
    else:
        Total_prediction_data_fans.append(relative_delta)

    assert predictions.shape[0] == total_Y.shape[0]

    if show_plot:
        # Actual stadium-county cases against the synthetic control, with the
        # intervention date marked.
        fig, ax = plt.subplots()

        # Use one string as the legend label; a list is not a valid label for
        # a single line.
        plt.plot(total_dates, total_Y, label=", ".join(stadium_county_str))
        plt.plot(total_dates, predictions, label='Synthetic RBSC')

        plt.xticks(rotation=45)
        plt.title(team_name_str + " " + stadium_county_str[0] + " County Covid Cases ")

        plt.tick_params(axis='x', which='major')
        plt.axvline(x=intervention_date, ymin = 0, ymax = 1, color='grey')
        fig.autofmt_xdate()

        ax.legend()
        ax.xaxis_date()
        plt.show()
    else:
        # Synthetic-lines mode: only the synthetic series is drawn, onto the
        # current figure, so several shifted fits end up on the same plot.
        true_intervention = intervention_date + datetime.timedelta(weeks=week)

        line_color = 'blue' if week == 0 else 'grey'
        plt.plot(total_dates, predictions, label='Synthetic Week Before', color=line_color)

        plt.xticks(rotation=45)
        plt.title(team_name_str + " " + stadium_county_str[0] + " County Covid Cases ")

        plt.tick_params(axis='x', which='major')
        plt.axvline(x=true_intervention, ymin = 0, ymax = 1, color='grey')


In [None]:
# Single-team example: Kansas City, with stadium counties looked up in both
# Kansas and Missouri.
create_synthetic_graph(
    "Kansas City",
    ["Johnson", "Jackson"],
    ["KS", "MO"],
    ["09/21/2020"],
    True,
    0,
)

In [None]:
#Compute graphs for all combinations 

no_fans_list = ['Buffalo', 'Chicago', 'Detroit', 'Green Bay', 'Las Vegas', 'LA Chargers','LA Rams', 'Minnesota', 'New England', 'New Orleans', 'NY Giants','NY Jets', 'San Francisco', 'Seattle', 'Washington']

# Start from empty accumulators so re-running this cell does not add the same
# teams twice.
Total_prediction_data_fans = []
Total_prediction_data_no_fans = []

_open_date_col = 'First date home stadium open to fans'

# One row per team, every other column collapsed into a list of that team's
# values (missing values become "").
dropped_stadium_df = stadium_county_df[['Team', _open_date_col, 'County(s)', 'State']].copy()
dropped_stadium_df = dropped_stadium_df.fillna("")
grouped_df = dropped_stadium_df.groupby('Team').agg(list)
grouped_df = grouped_df.reset_index()

#We want to manually fill in Dates for stadiums not open to fans, easier this way. 
# Tennessee uses 10/13/2020, the same date as the otherwise identical table in
# the synthetic-lines cell; this cell previously had 9/13/2020.
_manual_open_dates = {
    'Chicago': '9/20/2020',
    'Detroit': '9/13/2020',
    'Las Vegas': '9/21/2020',
    'LA Chargers': '9/13/2020',
    'LA Rams': '9/13/2020',
    'Minnesota': '9/13/2020',
    'New England': '9/13/2020',
    'New Orleans': '9/13/2020',
    'NY Giants': '9/14/2020',
    'NY Jets': '9/14/2020',
    'San Francisco': '9/13/2020',
    'Seattle': '9/20/2020',
    'Tennessee': '10/13/2020',
    'Washington': '9/13/2020',
    'Pittsburgh': '10/11/2020',
    'Kansas City': '9/10/2020',  #Error in data set
}
for _team, _open_date in _manual_open_dates.items():
    grouped_df.loc[grouped_df['Team'] == _team, _open_date_col] = [[_open_date]]

zipped_input = zip(grouped_df['Team'], grouped_df['County(s)'], grouped_df['State'], grouped_df[_open_date_col])

for team, county, state, date in zipped_input:
    if team != 'Washington':
        create_synthetic_graph(team, county, state, date, True, 0)
    else:
        # NOTE(review): Washington is not modelled here; a Missouri-only
        # Kansas City run is made in its place, which adds a second Kansas
        # City entry to Total_prediction_data_fans. Confirm this is intended.
        create_synthetic_graph('Kansas City', ["Johnson", "Jackson"], ["MO"], ['9/10/2020'], True, 0)


In [None]:
#Make earlier function to be able to graph weeks 1-6 before the intervention. 
# Teams treated as having played without fans.
no_fans_list = [
    'Buffalo',
    'Chicago',
    'Detroit',
    'Green Bay',
    'Las Vegas',
    'LA Chargers',
    'LA Rams',
    'Minnesota',
    'New England',
    'New Orleans',
    'NY Giants',
    'NY Jets',
    'San Francisco',
    'Seattle',
    'Washington',
]

def create_synthetic_lines(team_name_str, stadium_county_str, state_str, intervention_date):
    """Overlay synthetic controls fitted 0 to 6 weeks before the intervention.

    Calls ``create_synthetic_graph`` seven times in its non-showing mode, each
    time with the intervention date moved one more week back, so all seven
    synthetic lines land on the same figure, then shows that figure.

    These shifted fits are for plotting only. ``create_synthetic_graph``
    appends to the module-level ``Total_prediction_data_fans`` and
    ``Total_prediction_data_no_fans`` on every call, so whatever it appends
    during this function is removed again before returning. The accumulators
    are left exactly as they were found.

    Parameters
    ----------
    team_name_str : str
        Team name, forwarded to ``create_synthetic_graph``.
    stadium_county_str : list of str
        Stadium county names, forwarded unchanged.
    state_str : list of str
        State names or abbreviations, forwarded unchanged.
    intervention_date : list of str
        Candidate date strings; the first entry containing a digit is used.

    Raises
    ------
    ValueError
        If no entry of ``intervention_date`` contains a digit.
    """
    # First entry that has numbers, since some entries are text only.
    dated_entries = [s for s in intervention_date if s != "" and any(c.isdigit() for c in s)]
    if not dated_entries:
        raise ValueError("No intervention date found for team %r" % (team_name_str,))
    intervention_date = pd.to_datetime(dated_entries[0])

    print(team_name_str)

    # Remember how long the accumulators are so the placebo entries can be
    # dropped afterwards (also when a fit raises part-way through).
    n_fans_before = len(Total_prediction_data_fans)
    n_no_fans_before = len(Total_prediction_data_no_fans)
    try:
        for weeks_back in range(0, 7):
            shifted_date = intervention_date - datetime.timedelta(weeks=weeks_back)
            # create_synthetic_graph expects a list of candidate strings.
            date_candidates = [shifted_date.strftime('%m/%d/%Y'), ""]
            create_synthetic_graph(team_name_str, stadium_county_str, state_str, date_candidates, False, weeks_back)
    finally:
        del Total_prediction_data_fans[n_fans_before:]
        del Total_prediction_data_no_fans[n_no_fans_before:]

    plt.show()


In [None]:
#Compute graphs for all combinations 

_lines_date_col = 'First date home stadium open to fans'

# Collapse the sheet to one row per team; each remaining column becomes the
# list of that team's values, with missing values replaced by "".
dropped_stadium_df = stadium_county_df[['Team', _lines_date_col, 'County(s)', 'State']]
dropped_stadium_df = dropped_stadium_df.fillna("")
grouped_df = dropped_stadium_df.groupby('Team').agg(list)
grouped_df.reset_index(inplace=True)

#We want to manually fill in Dates for stadiums not open to fans, easier this way. 
_lines_manual_dates = (
    ('Chicago', '9/20/2020'),
    ('Detroit', '9/13/2020'),
    ('Las Vegas', '9/21/2020'),
    ('LA Chargers', '9/13/2020'),
    ('LA Rams', '9/13/2020'),
    ('Minnesota', '9/13/2020'),
    ('New England', '9/13/2020'),
    ('New Orleans', '9/13/2020'),
    ('NY Giants', '9/14/2020'),
    ('NY Jets', '9/14/2020'),
    ('San Francisco', '9/13/2020'),
    ('Seattle', '9/20/2020'),
    ('Tennessee', '10/13/2020'),
    ('Washington', '9/13/2020'),
    ('Pittsburgh', '10/11/2020'),
    ('Kansas City', '9/10/2020'),  #Error in data set
)
for _team_name, _date_text in _lines_manual_dates:
    grouped_df.loc[grouped_df['Team'] == _team_name, _lines_date_col] = [[_date_text]]

zipped_input = zip(grouped_df['Team'], grouped_df['County(s)'], grouped_df['State'], grouped_df[_lines_date_col])

# Draw the week-shifted synthetic lines for every team except Washington.
for i, (team, county, state, date) in enumerate(zipped_input):
    if team == "Washington":
        continue
    create_synthetic_lines(team, county, state, date)
    plt.show()


In [None]:
#create_synthetic_lines("Kansas City",["Johnson", "Jackson"], ["KS", "MO"], ["09/10/2020"])

In [None]:
#Create graphs for results after Game day. 

print(np.array(Total_prediction_data_no_fans).shape)

# Each accumulator should hold one row per team and one column per day past
# the intervention.
assert np.array(Total_prediction_data_no_fans).shape == (14, 21) #Number of teams, number of days past intervention
assert np.array(Total_prediction_data_fans).shape == (17, 21)

IQR_list = []

# Transposed copies: one row per day, one column per team.
Time_series_no_fans = np.array(Total_prediction_data_no_fans).T
Time_series_fans = np.array(Total_prediction_data_fans).T

# One boxplot figure per group: a box per day after the intervention showing
# the spread of the relative difference across teams.
for _series, _title in (
    (Time_series_no_fans, "Days after Intervention Delta for Stadium with No Fans"),
    (Time_series_fans, "Days after Intervention Delta for Stadium with Fans"),
):
    fig, ax = plt.subplots()

    plt.boxplot(_series.T, vert=True, showfliers = False)
    ax.set_ylim(bottom=-0.3, top=0.3)
    plt.xlabel("Days after Intervention")
    plt.ylabel("Relative Difference Delta")
    plt.title(_title)
    ax.axhline(0, color='grey', linestyle='--')

    # Show the plot
    plt.show()
    

In [None]:
#MCMC Bayesian factor Synthetic Control

#Donors and Target. Key idea in Synthetic control is to learn some model to predict the values in target based on donors.

#The original method proves that you can capture causal effects only if the model is linear. 

#What we're trying to show is that a general model applies. 

#Draft that is a mess right now. 

#Stat 348. 

#Knowing Z gives me all the information about Y. 

#Probabilistic version of PCA

#Find a factor model, poisson factor. MCMC bayesian inference. 

#familiarize with PYRO, probabilistic programming framework. 

#Want to compute posterior given what you observe. Usually, the denominator you don't know how to compute. 

#So you have to approximate that integral. MCMC samples from the posterior. 

#Bayesian factor model generative story for the data, then gain back sample from the posterior. 

#Search through space of models efficiently, specify generative model and then get back posterior sample. 

#Alternative to PYRO is implementation, find alternative to probabilistic PCA and bayesian poisson factorization. 

#Probabilistic PCA bayesian, and poisson factorization samples of the posterior, apply it to dataset. 

#And then we will see how to check the models. 

#You are going to implement a function that can generate samples of data. The function represents a hypothesis about how the dataset was generated:
#a probabilistic generative model for your data. Now give me a dataset and I will sample from the posterior distribution given that data. 

In [None]:
#Sample Cincinnati Data. 

import pyro
import pyro.distributions as dist
import torch
import numpy as np
import matplotlib.pyplot as plt
import pyro.infer.mcmc as mcmc

def get_training_data(team_name_str, stadium_county_str, state_str, intervention_date, show_plot, week):
    """Build the pre-intervention donor-county case matrix for one team.

    Reads the notebook-level objects ``home_state_dict``, ``stadium_county_df``
    and ``county_covid``.

    Parameters
    ----------
    team_name_str : str
        Team whose ``Donor_Counties`` entry is looked up in ``stadium_county_df``.
    stadium_county_str : list of str
        Stadium county name(s); empty strings are dropped, names are lower-cased.
    state_str : list of str
        State(s) used to filter ``county_covid``; entries that are keys of
        ``home_state_dict`` are replaced by the mapped value.
    intervention_date : list of str
        Candidate date strings. The FIRST entry containing a digit is parsed
        with ``pd.to_datetime`` (some entries are text only).
    show_plot, week
        Not used by this function; kept so existing calls keep working.

    Returns
    -------
    pandas.DataFrame
        One row per date strictly before the intervention date (starting at the
        first date present for the stadium counties), one column per donor
        county holding the pivoted ``cases`` values; missing cells are 0.

    Raises
    ------
    ValueError
        If no entry of ``intervention_date`` contains a digit.
    IndexError
        If the team or the stadium counties have no matching rows.
    KeyError
        If a stadium county is missing from the pivoted table.
    """
    # Drop empty placeholders, normalise county case, expand state acronyms.
    stadium_county_str = [s.lower() for s in stadium_county_str if s != ""]
    state_str = [home_state_dict[s] if s in home_state_dict else s
                 for s in state_str if s != ""]

    # Take the first candidate that actually looks like a date. The previous
    # loop had no `break`, so it silently kept the LAST such entry.
    first_dated = next(
        (d for d in intervention_date if d != "" and any(c.isdigit() for c in d)),
        None,
    )
    if first_dated is None:
        raise ValueError("intervention_date contains no entry with a date: %r" % (intervention_date,))
    intervention_date = pd.to_datetime(first_dated)

    # Donor ("synthetic") counties registered for this team.
    donor_counties = stadium_county_df.loc[stadium_county_df['Team'] == team_name_str, 'Donor_Counties'].iloc[0]
    synthetic_counties = [s.lower() for s in donor_counties]

    print(stadium_county_str)

    # Select stadium + donor rows. Washington is special-cased: its donors are
    # taken from Maryland and the stadium counties are not state-filtered.
    is_stadium = county_covid['county'].isin(stadium_county_str)
    is_donor = county_covid['county'].isin(synthetic_counties)
    if team_name_str == 'Washington':
        selected = is_stadium | (is_donor & (county_covid['state'] == 'Maryland'))
    else:
        selected = (is_stadium | is_donor) & county_covid['state'].isin(state_str)

    # .bfill() replaces the deprecated fillna(method='bfill') and returns a new frame.
    stadium_county_data = county_covid.loc[selected].bfill()
    stadium_county_data['date'] = pd.to_datetime(stadium_county_data['date'])

    # Start from the first date on which stadium-county data is available
    # (first matching row, which assumes the rows are in date order).
    stadium_rows = stadium_county_data['county'].isin(stadium_county_str)
    earliest_date = stadium_county_data.loc[stadium_rows, 'date'].iloc[0]
    stadium_county_data = stadium_county_data.loc[stadium_county_data['date'] >= earliest_date]

    # Dates x counties table of cases (pivot_table averages duplicate entries).
    total_pivot = stadium_county_data.pivot_table(columns='county', values='cases', index='date').reset_index()

    # Keep donors only: remove the stadium (target) columns, fill gaps with 0.
    total_pivot = total_pivot.drop(columns=stadium_county_str).fillna(0)

    # Training window = everything strictly before the intervention.
    training_pivot = total_pivot.loc[total_pivot['date'] < intervention_date]
    X_train = training_pivot.drop(columns=['date'])

    return X_train

# Donor-county training matrix for Kansas City.
get_train = get_training_data('Kansas City', ["Johnson", "Jackson"], ["MO"], ['9/10/2020'], True, 0)

# Lift every display truncation so the whole matrix is shown.
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
np.set_printoptions(threshold=np.inf)

# Integer view of the matrix; the bare name on the last line displays it.
a = get_train.to_numpy().astype(int)
a


In [79]:
fake_data = [[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    1,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    1,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    1,    0],
       [   0,    0,    0,    3,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    1,    0],
       [   0,    0,    0,    3,    0,    0,    1,    4,    0,    0,    0,
           0,    0,    0,    1,    0],
       [   0,    0,    0,    5,    0,    0,    3,    5,    0,    0,    0,
           1,    0,    0,    1,    0],
       [   0,    0,    0,    5,    0,    0,    4,   10,    0,    0,    0,
           2,    0,    0,    1,    0],
       [   1,    0,    0,    9,    0,    0,    4,   15,    2,    0,    0,
           3,    0,    0,    1,    0],
       [   1,    0,    0,   10,    0,    0,    6,   17,    2,    0,    0,
           3,    0,    0,    3,    0],
       [   1,    0,    0,   11,    0,    0,    7,   17,    2,    0,    0,
           3,    0,    0,    3,    0],
       [   1,    0,    0,   13,    1,    0,    7,   18,    5,    0,    0,
           2,    1,    0,    5,    0],
       [   1,    0,    0,   16,    1,    0,   13,   30,    7,    0,    0,
           2,    1,    0,    6,    0],
       [   1,    0,    1,   21,    1,    0,   16,   38,   10,    0,    0,
           2,    2,    0,    6,    0],
       [   1,    0,    1,   23,    1,    0,   19,   47,   12,    0,    0,
           2,    2,    0,    6,    0],
       [   1,    0,    1,   26,    1,    0,   23,   62,   16,    1,    0,
           3,    2,    2,    7,    0],
       [   1,    0,    2,   32,    2,    0,   24,   69,   16,    1,    0,
           4,    2,    2,    7,    0],
       [   1,    0,    2,   33,    3,    0,   25,   71,   17,    1,    0,
           4,    3,    3,    7,    0],
       [   1,    0,    2,   37,    3,    0,   33,   81,   18,    2,    0,
           5,    3,    3,   10,    0],
       [   1,    0,    5,   37,    3,    0,   35,   92,   21,    3,    0,
           5,    4,    5,   10,    0],
       [   1,    0,    6,   37,    5,    0,   56,   95,   25,    3,    0,
           5,    4,    7,   11,    1],
       [   1,    0,    6,   38,    4,    0,   62,   98,   25,    3,    0,
           6,    4,   11,   13,    1],
       [   1,    0,   10,   40,    4,    0,   68,  107,   25,    3,    0,
           6,    4,   18,   13,    1],
       [   1,    0,   13,   46,    4,    0,   75,  115,   26,    3,    0,
           6,    4,   22,   14,    1],
       [   1,    0,   13,   48,    4,    0,   81,  115,   26,    3,    0,
           6,    4,   22,   15,    1],
       [   1,    0,   16,   51,    4,    0,   84,  129,   26,    4,    0,
           6,    4,   23,   24,    1],
       [   2,    0,   16,   51,    5,    0,   84,  138,   27,    5,    0,
           6,    4,   35,   25,    1],
       [   2,    0,   17,   52,    5,    0,   83,  147,   28,    5,    0,
           6,    4,   44,   26,    2],
       [   2,    0,   21,   52,    5,    0,   89,  155,   29,    5,    0,
           6,    4,   49,   41,    3],
       [   2,    0,   22,   55,    6,    0,   91,  179,   30,    5,    0,
           7,    4,   51,   42,    4],
       [   2,    0,   23,   56,    5,    0,   95,  183,   31,    5,    1,
           7,    4,   51,   49,    4],
       [   2,    0,   23,   57,    5,    0,   98,  214,   31,    5,    1,
           7,    4,   54,   51,    4],
       [   2,    0,   24,   59,    5,    0,   99,  222,   33,    5,    3,
           7,    4,   54,   50,    5],
       [   2,    0,   24,   61,    5,    0,  102,  230,   34,    5,    3,
           7,    5,   54,   52,    6],
       [   2,    0,   24,   61,    5,    0,  102,  238,   37,    5,    4,
           7,    5,   55,   59,    6],
       [   2,    0,   26,   62,    6,    0,  103,  240,   35,    5,    4,
           7,    5,   81,   59,    7],
       [   2,    0,   25,   63,    7,    0,  106,  251,   36,    5,    4,
           7,    5,  139,   59,    8],
       [   2,    0,   25,   65,    7,    0,  106,  251,   43,    5,    4,
           7,    5,  143,   60,    8],
       [   2,    0,   26,   67,    7,    0,  106,  250,   43,    5,    5,
           7,    5,  159,   61,    8],
       [   2,    0,   26,   68,    7,    0,  106,  249,   45,    5,    5,
           7,    5,  168,   61,    8],
       [   2,    0,   26,   69,    7,    0,  110,  255,   46,    5,    4,
           7,    5,  170,   62,    8],
       [   2,    0,   25,   69,    7,    0,  110,  256,   46,    5,    4,
           7,    5,  182,   63,    8],
       [   2,    0,   25,   70,    8,    0,  114,  260,   46,    5,    4,
           7,    5,  184,   64,    8],
       [   2,    0,   25,   77,    8,    0,  114,  261,   46,    5,    4,
           7,    5,  190,   70,    8],
       [   2,    0,   26,   77,    8,    0,  119,  265,   48,    5,    4,
           7,    5,  193,   76,    8],
       [   2,    0,   26,   77,    8,    0,  122,  271,   48,    5,    4,
           7,    5,  199,   76,    8],
       [   2,    0,   26,   81,    8,    0,  125,  279,   51,    5,    4,
           7,    5,  204,   76,    9],
       [   2,    0,   26,   86,    8,    0,  127,  278,   49,    5,    4,
           7,    5,  205,   75,    9],
       [   2,    0,   26,   88,    8,    0,  132,  286,   51,    5,    5,
           7,    5,  208,   76,    9],
       [   2,    0,   26,   91,    8,    0,  137,  289,   51,    5,    5,
           7,    5,  208,   77,    9],
       [   2,    0,   26,   96,    8,    0,  129,  289,   54,    5,    5,
           7,    5,  214,   78,    9],
       [   2,    0,   26,   96,    8,    0,  128,  292,   75,    5,    5,
           7,    5,  215,   79,    9],
       [   2,    0,   27,   96,    8,    0,  128,  292,   76,    5,    5,
           7,    5,  219,   80,    9],
       [   2,    0,   27,  101,    8,    0,  130,  309,   77,    5,    5,
           7,    5,  219,   80,    9],
       [   2,    0,   27,  102,    8,    0,  131,  310,   76,    5,    5,
           7,    5,  219,   82,    9],
       [   2,    0,   27,  106,    8,    0,  133,  314,   76,    5,    5,
           7,    5,  224,   83,    9],
       [   2,    0,   27,  111,    8,    0,  134,  316,   76,    5,    5,
           7,    5,  224,   84,    9],
       [   2,    0,   27,  113,    8,    0,  134,  318,   77,    5,    5,
           7,    5,  225,   85,    9],
       [   2,    0,   28,  116,    8,    0,  134,  326,   78,    5,    5,
           7,    5,  241,   86,    9],
       [   2,    0,   28,  114,    9,    0,  134,  328,   78,    5,    5,
           7,    5,  246,   87,    9],
       [   2,    0,   30,  120,    9,    0,  135,  330,   78,    5,    5,
           7,    5,  249,   87,    9],
       [   2,    0,   32,  123,    8,    0,  135,  331,   78,    5,    5,
           7,    5,  253,   86,    9],
       [   3,    0,   34,  125,    8,    0,  136,  331,   78,    5,    5,
           7,    5,  256,   88,    9],
       [   3,    0,   28,  127,    8,    0,  137,  333,   78,    5,    5,
           7,    5,  256,   89,   10],
       [   3,    0,   28,  133,    8,    0,  137,  334,   78,    5,    5,
           7,    5,  260,   91,   10],
       [   3,    0,   28,  141,    8,    0,  137,  337,   79,    5,    5,
           8,    5,  263,   91,   10],
       [   3,    0,   28,  149,    8,    0,  138,  337,   79,    6,    5,
           8,    5,  265,   93,   10],
       [   3,    0,   28,  148,    8,    0,  140,  337,   80,    6,    5,
           8,    5,  267,   93,   10],
       [   3,    0,   28,  152,    8,    0,  139,  344,   81,    6,    5,
           8,    5,  267,   94,   10],
       [   3,    0,   28,  158,    8,    0,  140,  345,   81,    6,    5,
           8,    5,  267,   94,   10],
       [   3,    0,   29,  162,    8,    0,  140,  348,   81,    6,    5,
           9,    5,  267,   98,   10],
       [   3,    0,   32,  169,    8,    0,  139,  349,   81,    6,    5,
           9,    5,  267,  100,   11],
       [   3,    0,   32,  171,    8,    0,  139,  349,   81,    9,    5,
           9,    5,  267,  108,   11],
       [   3,    0,   33,  173,    8,    1,  139,  349,   82,    9,    5,
           9,    5,  267,  110,   11],
       [   3,    0,   34,  179,    8,    1,  141,  363,   84,    9,    5,
           9,    5,  268,  108,   12],
       [   3,    0,   35,  184,    8,    2,  142,  390,   85,    9,    5,
           9,    5,  268,  109,   12],
       [   3,    0,   35,  189,    8,    2,  143,  395,   84,   10,    5,
           9,    5,  268,  113,   12],
       [   3,    0,   35,  192,    8,    2,  149,  396,   85,   10,    5,
          10,    5,  268,  117,   12],
       [   3,    0,   37,  228,    8,    2,  149,  398,   86,   10,    5,
          10,    5,  268,  119,   12],
       [   3,    0,   37,  228,    8,    2,  150,  398,   86,   10,    5,
          11,    5,  268,  121,   13],
       [   3,    0,   37,  233,    8,    2,  150,  398,   87,   10,    5,
          11,    5,  269,  121,   13],
       [   3,    0,   39,  243,    8,    3,  153,  407,   87,   10,    5,
          13,    6,  270,  123,   13],
       [   3,    0,   41,  255,    8,    3,  156,  411,   87,   10,    6,
          12,    6,  270,  127,   13],
       [   3,    0,   39,  263,    8,    3,  155,  411,   87,   11,    6,
          12,    5,  270,  128,   13],
       [   3,    1,   39,  291,    8,    3,  164,  427,   87,   11,    7,
          12,    5,  270,  132,   13],
       [   3,    1,   40,  295,    8,    3,  166,  428,   88,   11,    7,
          12,    5,  270,  134,   13],
       [   3,    1,   41,  295,    8,    3,  166,  428,   87,   11,    8,
          12,    5,  271,  139,   13],
       [   3,    1,   41,  295,    8,    3,  166,  428,   87,   11,    8,
          12,    5,  271,  139,   13],
       [   3,    7,   46,  298,    9,    3,  168,  432,   87,   11,    8,
          12,    5,  270,  145,   13],
       [   3,    9,   48,  299,    9,    3,  172,  439,   87,   11,    8,
          12,    5,  272,  145,   13],
       [   3,   11,   53,  305,    9,    3,  173,  445,   90,   12,    9,
          12,    5,  274,  149,   13],
       [   3,   10,   58,  319,   10,    3,  178,  456,   91,   13,    9,
          12,    5,  277,  149,   14],
       [   3,   15,   59,  334,   10,    3,  180,  461,   91,   13,   10,
          13,    5,  278,  155,   14],
       [   3,   15,   62,  342,   10,    3,  188,  461,   91,   13,   11,
          13,    5,  278,  157,   14],
       [   3,   15,   62,  342,   10,    3,  188,  461,   91,   13,   11,
          13,    5,  278,  159,   14],
       [   3,   15,   62,  348,   10,    3,  191,  465,   92,   13,   11,
          13,    5,  278,  160,   14],
       [   3,   16,   67,  367,   10,    3,  193,  473,   94,   14,   13,
          15,    6,  278,  159,   15],
       [   3,   17,   68,  375,   10,    3,  192,  501,   95,   14,   13,
          15,    6,  278,  159,   17],
       [   3,   21,   71,  386,   10,    3,  201,  506,   96,   14,   14,
          16,    6,  282,  159,   19],
       [   3,   21,   75,  406,   11,    3,  202,  519,   95,   14,   14,
          16,    6,  282,  161,   21],
       [   3,   21,   75,  414,   11,    3,  204,  519,   96,   15,   15,
          16,    7,  282,  161,   21],
       [   3,   22,   75,  431,   11,    3,  206,  519,   96,   16,   16,
          16,    7,  283,  162,   23],
       [   3,   24,   99,  443,   12,    4,  207,  535,   97,   16,   16,
          16,    7,  283,  164,   23],
       [   3,   25,  101,  454,   12,    4,  206,  539,   96,   16,   17,
          17,    7,  283,  168,   23],
       [   3,   26,  106,  460,   12,    5,  208,  540,   96,   16,   19,
          20,    7,  283,  169,   23],
       [   4,   26,  108,  472,   13,    6,  213,  551,   99,   16,   20,
          20,    7,  283,  170,   23],
       [   4,   26,  108,  478,   15,    6,  217,  551,   98,   17,   25,
          20,    7,  284,  172,   25],
       [   5,   28,  108,  493,   16,    6,  217,  551,   98,   18,   25,
          20,    9,  284,  173,   25],
       [   5,   28,  108,  500,   17,    6,  217,  551,   98,   18,   25,
          20,    9,  284,  181,   25],
       [   5,   29,  108,  504,   17,    6,  227,  582,  102,   18,   25,
          21,    9,  287,  185,   25],
       [   6,   30,  111,  510,   18,    6,  231,  604,  105,   19,   26,
          21,   10,  290,  188,   25],
       [   6,   31,  111,  526,   19,    6,  235,  631,  106,   19,   27,
          21,   11,  295,  191,   26],
       [   6,   31,  116,  539,   19,    8,  248,  666,  110,   20,   29,
          23,   11,  299,  192,   26],
       [   6,   31,  116,  550,   21,   11,  264,  666,  110,   20,   29,
          23,   12,  307,  198,   26],
       [   6,   33,  117,  557,   21,   12,  273,  666,  112,   20,   31,
          22,   13,  309,  206,   26],
       [   6,   34,  117,  558,   22,   12,  279,  743,  119,   20,   31,
          22,   13,  310,  207,   26],
       [   6,   34,  118,  565,   22,   13,  283,  769,  120,   20,   33,
          23,   15,  314,  209,   27],
       [   6,   37,  122,  579,   23,   16,  305,  814,  122,   22,   36,
          22,   15,  319,  219,   27],
       [   6,   41,  122,  589,   24,   17,  317,  834,  127,   22,   38,
          23,   15,  324,  235,   28],
       [   6,   41,  125,  597,   26,   21,  324,  857,  128,   22,   39,
          23,   15,  324,  237,   38],
       [   6,   42,  133,  607,   27,   22,  337,  873,  133,   21,   41,
          23,   15,  328,  240,   39],
       [   6,   45,  142,  617,   28,   22,  348,  873,  144,   21,   45,
          23,   17,  343,  242,   39],
       [   6,   48,  145,  625,   28,   26,  354,  873,  153,   22,   48,
          24,   17,  346,  244,   40],
       [   6,   51,  146,  630,   28,   28,  360,  969,  153,   23,   50,
          24,   22,  348,  244,   40],
       [   6,   51,  151,  646,   30,   31,  382, 1007,  160,   24,   53,
          24,   28,  349,  250,   41],
       [   7,   51,  154,  655,   31,   32,  394, 1018,  172,   25,   58,
          24,   27,  351,  266,   41],
       [   9,   52,  169,  685,   32,   36,  408, 1101,  185,   27,   68,
          26,   28,  359,  272,   42],
       [  11,   53,  177,  688,   32,   40,  424, 1122,  202,   27,   75,
          26,   31,  365,  282,   42],
       [  13,   52,  189,  705,   33,   43,  443, 1122,  212,   28,   82,
          27,   29,  368,  286,   46],
       [  13,   53,  192,  717,   34,   46,  447, 1155,  221,   28,   85,
          27,   32,  369,  285,   46],
       [  13,   55,  201,  744,   35,   47,  451, 1211,  228,   28,   92,
          29,   32,  371,  294,   46],
       [  11,   62,  210,  784,   33,   53,  469, 1248,  242,   28,  102,
          31,   32,  378,  311,   49],
       [  11,   62,  215,  826,   36,   56,  484, 1291,  258,   30,  111,
          31,   32,  389,  313,   50],
       [  11,   64,  227,  857,   40,   61,  490, 1370,  272,   30,  127,
          33,   32,  399,  330,   54],
       [  11,   65,  236,  882,   47,   68,  509, 1417,  288,   30,  139,
          33,   36,  407,  342,   55],
       [  11,   65,  239,  896,   52,   73,  516, 1417,  293,   30,  150,
          33,   37,  409,  341,   54],
       [  12,   65,  239,  900,   56,   73,  519, 1417,  301,   31,  150,
          36,   39,  410,  345,   56],
       [  12,   65,  247,  914,   57,   81,  542, 1559,  314,   31,  157,
          38,   42,  415,  352,   56],
       [  12,   66,  249,  944,   58,   83,  554, 1614,  331,   30,  160,
          40,   45,  416,  356,   60],
       [  14,   66,  253,  974,   62,   83,  566, 1646,  343,   31,  164,
          41,   46,  421,  358,   66],
       [  14,   66,  264,  984,   69,   84,  586, 1679,  352,   30,  171,
          41,   44,  421,  368,   67],
       [  14,   69,  275, 1011,   71,   83,  590, 1716,  360,   29,  181,
          41,   45,  424,  381,   69],
       [  14,   69,  275, 1011,   71,   83,  590, 1716,  360,   29,  181,
          41,   45,  424,  381,   69],
       [  18,   68,  281, 1037,   73,   82,  618, 1822,  374,   32,  193,
          41,   48,  428,  395,   77],
       [  18,   69,  283, 1053,   78,   87,  631, 1836,  382,   33,  201,
          41,   49,  435,  400,   82],
       [  19,   70,  286, 1064,   79,   89,  642, 1876,  398,   33,  204,
          41,   52,  441,  406,   87],
       [  18,   70,  289, 1087,   83,   89,  669, 1912,  412,   35,  223,
          42,   53,  459,  421,   94],
       [  19,   71,  297, 1098,   90,   89,  680, 1974,  416,   38,  236,
          43,   54,  472,  429,  105],
       [  19,   71,  302, 1126,   97,   89,  704, 2006,  421,   38,  256,
          44,   55,  478,  449,  124],
       [  20,   71,  311, 1133,  102,   92,  729, 2006,  430,   39,  264,
          44,   58,  481,  461,  126],
       [  20,   72,  317, 1143,  111,   94,  742, 2006,  439,   41,  269,
          44,   62,  483,  476,  133],
       [  20,   73,  321, 1158,  115,   94,  762, 2166,  449,   41,  272,
          45,   62,  483,  492,  143],
       [  21,   74,  322, 1167,  115,   94,  778, 2196,  454,   41,  278,
          47,   64,  483,  492,  149],
       [  19,   74,  326, 1180,  121,   95,  791, 2250,  461,   41,  293,
          47,   66,  496,  515,  156],
       [  21,   75,  327, 1194,  126,   97,  806, 2297,  468,   41,  307,
          47,   66,  502,  519,  169],
       [  21,   76,  337, 1210,  132,   97,  818, 2317,  474,   41,  315,
          48,   68,  503,  535,  181],
       [  22,   78,  345, 1245,  137,   97,  844, 2317,  488,   43,  325,
          48,   70,  517,  544,  184],
       [  22,   79,  351, 1261,  141,   97,  863, 2412,  497,   43,  335,
          48,   70,  522,  544,  191],
       [  20,   83,  354, 1269,  146,   98,  877, 2471,  505,   41,  349,
          49,   70,  525,  548,  199],
       [  20,   84,  354, 1285,  148,   98,  889, 2518,  510,   44,  360,
          49,   71,  525,  551,  200],
       [  22,   88,  358, 1309,  159,  101,  922, 2566,  514,   46,  365,
          51,   73,  530,  565,  212],
       [  23,   92,  363, 1323,  164,  102,  943, 2605,  525,   46,  378,
          53,   75,  531,  581,  224],
       [  25,   94,  366, 1346,  175,  106,  981, 2663,  537,   46,  425,
          55,   75,  534,  582,  227],
       [  28,  102,  373, 1355,  178,  109, 1002, 2663,  551,   47,  453,
          56,   78,  543,  606,  228],
       [  28,  102,  373, 1363,  179,  110, 1018, 2776,  559,   47,  463,
          56,   78,  543,  619,  237],
       [  28,  103,  376, 1381,  186,  112, 1025, 2813,  564,   47,  471,
          56,   80,  546,  625,  238],
       [  29,  105,  379, 1403,  193,  112, 1040, 2857,  574,   49,  483,
          56,   82,  547,  632,  239],
       [  33,  108,  383, 1412,  198,  114, 1055, 2891,  586,   49,  489,
          55,   85,  551,  652,  246],
       [  33,  110,  386, 1443,  209,  119, 1092, 2963,  598,   44,  497,
          57,   87,  557,  655,  251],
       [  34,  111,  398, 1455,  215,  117, 1110, 3025,  602,   49,  505,
          56,   88,  569,  662,  253],
       [  36,  114,  410, 1481,  219,  117, 1133, 3025,  614,   51,  531,
          58,   91,  581,  679,  262],
       [  37,  115,  417, 1493,  217,  118, 1157, 3064,  621,   51,  545,
          60,   92,  585,  684,  268],
       [  37,  116,  425, 1506,  220,  119, 1170, 3109,  625,   48,  558,
          58,   93,  589,  690,  267],
       [  37,  117,  430, 1512,  218,  120, 1178, 3248,  626,   49,  563,
          61,   94,  591,  692,  269],
       [  37,  122,  430, 1530,  221,  124, 1203, 3297,  630,   49,  571,
          63,   96,  593,  708,  273]]

In [80]:
#Import bayesian PCA

import pyro
import pyro.distributions as dist
import torch
import numpy as np
import matplotlib.pyplot as plt
import pyro.infer.mcmc as mcmc

def bayesian_pca(data, latent_dim, n_mask_rows=30, n_mask_cols=3):
    """Probabilistic PCA model ``X ~ Normal(Z @ W, sigma)`` with a held-out block.

    The lower-right corner of ``data`` (last ``n_mask_rows`` rows of the last
    ``n_mask_cols`` columns) is excluded from the likelihood so it can be used
    as test data afterwards.

    Parameters
    ----------
    data : torch.Tensor
        (n, p) data matrix.
    latent_dim : int
        Number of latent factors (columns of Z, rows of W).
    n_mask_rows, n_mask_cols : int, optional
        Size of the held-out lower-right block; defaults reproduce the
        previously hard-coded 30 x 3 block. A value of 0 holds nothing out.

    Sample sites
    ------------
    ``sigma`` : scalar noise variance, Uniform(0, 10) prior.
    ``Z``     : (n, latent_dim) latent scores, standard-normal prior.
    ``W``     : (latent_dim, p) loadings, standard-normal prior.
    ``X``     : observed site (held-out block masked out of the likelihood).

    Returns
    -------
    torch.Tensor
        The value of the observed site ``X``: ``data`` with the held-out block
        set to zero.
    """
    n, p = data.shape
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))

    # True where an entry is observed, False inside the held-out corner.
    observed = torch.ones(n, p, dtype=torch.bool)
    if n_mask_rows > 0 and n_mask_cols > 0:  # guard: a "-0:" slice would select everything
        observed[-n_mask_rows:, -n_mask_cols:] = False
    masked_data = data * observed.to(torch.get_default_dtype())

    # Standard-normal priors. These are the same distributions as the former
    # MultivariateNormal(0, I) sites; the pyro.param sites they were built from
    # were constants that got reset on every call, so they are gone.
    Z = pyro.sample("Z", dist.Normal(0., 1.).expand([n, latent_dim]).to_event(2))
    W = pyro.sample("W", dist.Normal(0., 1.).expand([latent_dim, p]).to_event(2))

    # The covariance was sigma * I, i.e. independent entries with variance
    # sigma, so an element-wise Normal with scale sqrt(sigma) is the same
    # likelihood on the observed entries. `.mask` removes the held-out entries
    # completely; zeroing both data and mean (the old approach) still left a
    # sigma-dependent normalising term for each of them in the log-density.
    X = pyro.sample(
        "X",
        dist.Normal(Z @ W, sigma.sqrt()).mask(observed),
        obs=masked_data,
    )

    return X


In [None]:
# dummy_train is what we are currently using to test our Bayesian factor models.
# It is built from fake_data: 177 rows x 16 columns, of which the model holds
# out the last 30 rows of the last 3 columns.

latent_dim = 5

# Sampler settings; warmup is kept low so the cell finishes in reasonable time.
num_samples = 1000
warmup_steps = 500

# To run on the real donor matrix instead, build dummy_train from get_train.values.
dummy_train = np.array(fake_data)

# Seed the RNGs so the chain is reproducible on Restart & Run All.
pyro.set_rng_seed(0)

kernel = mcmc.NUTS(bayesian_pca)
mcmc_run = mcmc.MCMC(kernel, num_samples=num_samples, warmup_steps=warmup_steps)

# Apply MCMC to our data (as an explicit float tensor rather than int64).
mcmc_run.run(torch.tensor(dummy_train, dtype=torch.float32), latent_dim)

pyro.clear_param_store()
     

Warmup:  22%|██▏       | 329/1500 [14:58,  2.79s/it, step size=1.84e-05, acc. prob=0.743]

In [None]:
# Posterior draws from the PPCA run: W is the transformation, Z the weights,
# and sigma the noise term of X (assumed diagonal for PPCA, I believe).
posterior_samples = mcmc_run.get_samples()
W_samples, sigma_samples, Z_samples = (
    posterior_samples[site] for site in ("W", "sigma", "Z")
)

print(W_samples.size())
print(Z_samples.size())

# Posterior means, averaged over the draw dimension.
W_mean = torch.mean(W_samples, dim=0)
Z_mean = torch.mean(Z_samples, dim=0)



In [None]:
# Posterior predictive check on the held-out block (last 30 rows x last 3 columns).
# For every posterior draw we compare the log-likelihood of data sampled from
# the model (y_pred) with the log-likelihood of the real held-out data.
reconstructed_X = Z_samples @ W_samples  # one Z @ W reconstruction per posterior draw

# log_prob only accepts tensors, and dummy_train may still be a NumPy int array
# at this point, so convert the held-out block explicitly.
data_test = torch.as_tensor(dummy_train[-30:, -3:], dtype=reconstructed_X.dtype)

log_prob_list = []
test_prob_list = []
for x_hat, sig in zip(reconstructed_X, sigma_samples):
    # Distribution of the held-out block under this draw (sigma * I covariance).
    sample_dist = dist.MultivariateNormal(x_hat[-30:, -3:], covariance_matrix=(torch.eye(3) * sig))
    y_pred = sample_dist.sample()

    # Likelihood of the replicated sample vs. likelihood of the actual data.
    log_prob_list.append(sample_dist.log_prob(y_pred).sum())
    test_prob_list.append(sample_dist.log_prob(data_test).sum())

# Share of draws in which the replicated data is more likely than the real data.
count = sum(1 for rep, obs in zip(log_prob_list, test_prob_list) if rep > obs)
percent_likelihood = count / len(log_prob_list) if log_prob_list else float("nan")

print("Fraction of posterior draws where Y_Pred is more likely than the Test data: " + str(percent_likelihood))

In [None]:
# Peek at the first five entries of each likelihood list.
for likelihoods in (log_prob_list, test_prob_list):
    print(likelihoods[:5])

In [None]:
####POISSON FACTORIZATION EXAMPLE

def poisson_factorization(data, latent_dim, mask_rows, mask_cols):
    """Poisson factorization model ``X ~ Poisson(F @ G)`` with held-out entries.

    Parameters
    ----------
    data : torch.Tensor
        (n, p) matrix of non-negative counts.
    latent_dim : int
        Number of latent factors (columns of F, rows of G).
    mask_rows, mask_cols : index tensor / sequence of int, or None
        Rows and columns to hold out. Every entry lying in one of these rows
        OR one of these columns is excluded from the likelihood (same coverage
        as the previous implementation). ``None`` holds nothing out on that axis.

    Sample sites
    ------------
    ``F`` : (n, latent_dim), Gamma(1, 1) prior.
    ``G`` : (latent_dim, p), Gamma(1, 1) prior.
    ``X_observed`` : observed counts, Poisson(F @ G), held-out entries masked.

    Returns
    -------
    torch.Tensor
        The (n, p) Poisson rate matrix ``F @ G``.
    """
    n, p = data.shape

    # X is assumed to be Poisson distributed with rate F @ G.
    F = pyro.sample("F", dist.Gamma(1., 1.).expand([n, latent_dim]))
    G = pyro.sample("G", dist.Gamma(1., 1.).expand([latent_dim, p]))

    # Use the mask handed in by the caller. Previously the arguments were
    # overwritten with a fresh torch.randperm draw (sized by the globals N, P)
    # on every model evaluation, so the sampler saw a different likelihood at
    # each step and the evaluation cells used a different mask than the fit.
    held_out = torch.zeros(n, p, dtype=torch.bool)
    if mask_rows is not None:
        held_out[mask_rows, :] = True
    if mask_cols is not None:
        held_out[:, mask_cols] = True
    # NOTE(review): whole rows and whole columns are hidden here, while the
    # evaluation cells only score their intersection - confirm this is intended.

    mu = F @ G

    # `.mask` drops the held-out entries from the log-density. The earlier
    # trick of multiplying data and rate by 0.00001 produced non-integer
    # "counts", which are outside the support of a Poisson distribution.
    pyro.sample("X_observed", dist.Poisson(mu).mask(~held_out), obs=data.to(mu.dtype))

    return mu

# Define model

# latent_dim is fixed by hand here, i.e. treated as a hyperparameter.
# (Open question from the author: how should it be determined?)
latent_dim = 10
model = poisson_factorization

# Convert data to a PyTorch tensor. as_tensor keeps this cell re-runnable: it
# accepts the NumPy array as well as an already-converted tensor.
dummy_train = torch.as_tensor(dummy_train)

# Seed the RNGs so the held-out rows/columns and the chain are reproducible.
pyro.set_rng_seed(0)

# Random held-out times (rows) and counties (columns).
N, P = dummy_train.shape
mask_rows = torch.randperm(N)[:20]
mask_cols = torch.randperm(P)[:5]


# Run MCMC
num_samples = 1000
warmup_steps = 1500
kernel = mcmc.NUTS(model)
mcmc_run = mcmc.MCMC(kernel, num_samples=num_samples, warmup_steps=warmup_steps)

# Run the MCMC process on our data with the given latent dimension.
mcmc_run.run(dummy_train, latent_dim, mask_rows, mask_cols)

In [None]:
# Posterior draws of the two factor matrices from the Poisson run.
posterior_samples = mcmc_run.get_samples()
F_samples, G_samples = posterior_samples["F"], posterior_samples["G"]


# Rate matrix implied by the last posterior draw (displayed as the cell output).
torch.matmul(F_samples[-1], G_samples[-1])

In [None]:
# Posterior predictive check for the Poisson factorization, restricted to the
# held-out cells (rows in mask_rows, columns in mask_cols).
reconstructed_X = F_samples @ G_samples  # one F @ G rate matrix per posterior draw

# Actual held-out data; it does not depend on the draw, so build it once.
data_test = torch.as_tensor(dummy_train)[mask_rows, :][:, mask_cols].to(reconstructed_X.dtype)

# TODO: need to check if this way of masking values interferes with the joint
# probability distribution.
log_prob = []
test_prob_list = []
for rate in reconstructed_X:
    # Poisson distribution of the held-out cells under this draw.
    sample_dist = dist.Poisson(rate[mask_rows, :][:, mask_cols])

    # Replicated data for the held-out cells and its log-likelihood.
    y_pred = sample_dist.sample()
    y_pred_likelihood = sample_dist.log_prob(y_pred).sum()

    # Log-likelihood of the actual data under the SAME distribution. The
    # earlier code sliced the rates with [-40:, -4:], which does not line up
    # with the cells selected by mask_rows / mask_cols.
    test_prob = sample_dist.log_prob(data_test).sum()

    log_prob.append(y_pred_likelihood)
    test_prob_list.append(test_prob)

# Share of draws in which the replicated data is more likely than the real data.
count = sum(1 for rep, obs in zip(log_prob, test_prob_list) if rep > obs)
percent_likelihood = count / len(log_prob) if log_prob else float("nan")

print("Fraction of posterior draws where Y_Pred is more likely than the Test data: " + str(percent_likelihood))
        

In [None]:
# Display the `log_prob` list built by the Poisson posterior predictive loop.
log_prob

In [None]:
# Display the `test_prob_list` list built by the Poisson posterior predictive loop.
test_prob_list