# Import packages

In [1]:
# Python util
from collections import OrderedDict
import pandas as pd
import numpy as np
import pprint
pp = pprint.PrettyPrinter(indent=4)


# Models and statistics libraries
from sklearn.preprocessing import scale
import sklearn.linear_model as skl_lm
from sklearn.metrics import mean_squared_error, r2_score
import statsmodels.api as sm
import statsmodels.formula.api as smf
import statistics


# Plotting
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
import seaborn as sns
%matplotlib inline
plt.style.use('seaborn-white')

# Helper Functions

In [2]:
def calculate_vif(df):
    """Calculate variance inflation factor (VIF).

    Args:
        df: Pandas Dataframe with indepedent variable columns.

    Returns:
        Pandas Dataframe with VIF calculations.
    """
    from statsmodels.stats.outliers_influence import variance_inflation_factor
    # Calculating VIF
    vif = pd.DataFrame()
    vif["variables"] = df.columns
    vif["VIF"] = [variance_inflation_factor(df.values, i) for i in range(df.shape[1])]
    return vif

def create_month_bins(TIME_INTERVAL):
    """Initialize dictionary with months over time interval as keys, empty lists as values.
    
    Args:
        TIME_INTERVAL: A set of string months to use as keys for dictionary.

    Returns:
        Dictionary with months(str) as keys, empty lists as values. 
    """
    return {x: list() for x in sorted(list(TIME_INTERVAL))}

def time_interval(start, end):
    """Define time interval, by months. Format: 'year-month' eg. '2020-05'
    
    Args:
        start: Start of time interval as an int. eg. '2020-02' => 202002
        end: End of time interval as an int. eg. '2023-02' => 202302

    Returns:
        A set of all months (str) to include in time interval. 
    """
    TIME_INTERVAL = set()
    years = ["2020", "2021", "2022", "2023"]
    months = ["01", "02", "03", "04", "05",
              "06", "07", "08", "09", "10",
              "11", "12"]
    for y in years:
        for m in months:
            if int(y + m) < start or int(y + m) > end:
                #print(y + "-" + m)
                continue
            TIME_INTERVAL.add(y + "-" + m)
    return TIME_INTERVAL


def load_month_diversity_data(filename, TIME_INTERVAL):
    """Load and parse data for genetic diversity scores by month.
    
    Args:
        filename: Path(str) to genetic diversity score data file.
        TIME_INTERVAL: A set of all months (str) to include in time interval.

    Returns:
        A set of all months (str) to include in time interval. 
    """
    fp = open(filename, 'r')
    next(fp)
    d = dict()
    for line in fp:
        splitline = line.split(',')
        month = splitline[0]
        if month not in TIME_INTERVAL:
            continue
        score = float(splitline[1].strip('\n'))
        d[month] = score
    return d

def count_rivet_events_by_month(filename, TIME_INTERVAL):
    """Load file with RIVET inferred recombinants each month, 
        and count the number of recombination events each month.
    
    Args:
        filename: Path(str) to file containing RIVET inferred recombinants, one per line.
        TIME_INTERVAL: A set of all months (str) to include in time interval.

    Returns:
        A dictionary with each month(str) as the keys, and the number(int) of RIVET inferred
        recombinants that month as the values.
    """
    fp = open(filename, 'r')
    next(fp)
    d = dict()
    for line in fp:
        splitline = line.split(',')
        month = splitline[0]
        if month not in TIME_INTERVAL:
            continue
        if month not in d.keys():
            d[month] = 1
        else:
            d[month]+=1
    return d

def load_case_data(csvFilename):
    """Load and parse file containing the number of infections per month.
    
    Args:
        csvFilename: Path(str) to file containing number of infections per month.
    
    Returns:
        A dictionary with each month(str) as the keys, and the number(int) of infections
        for that month as the values.
    """
    infection_data = dict()
    # Load case data
    case_fp = open(csvFilename, 'r')
    # Skip over header (Month,Cases)
    next(case_fp)
    # Extract month and count data
    for line in case_fp:
        splitline = line.split(',')
        month = splitline[0]
        count = int(splitline[1].strip('\n'))
        infection_data[month] = count
    return infection_data

def fitness_by_month(filename, TIME_INTERVAL):
    """Load and parse file containing the fitness scores
        of all inferred recombinants for a given month.
    
    Args:
        filename: Path(str) to file containing the fitness scores for individual inferred recombinants.
        TIME_INTERVAL: A set of all months (str) to include in time interval.
        
    Returns:
        A dictionary with each month(str) as the keys, and the values being a list of all the fitness
        scores for the inferred recombinants that emerged during that given month.
    """
    fp = open(filename, 'r')
    # Skip over header
    next(fp)
    month_bins = create_month_bins(TIME_INTERVAL)
    for line in fp:
        splitline = line.split(',')
        month = splitline[0]
        score = float(splitline[1])
        # Only consider months within the pre-defined time-interval
        if month not in month_bins.keys():
            continue
        month_bins[month].append(score)
    fp.close()
    return month_bins

# Datasets

In [3]:
# Genetic diversity scores by month
GENETIC_DIVERSITY_SCORES_FILENAME = "data/genetic_diversity_by_month.csv"

### Fitness of recombinant events

# Rivet recombs fitness vs avg fitness of parents 
RECOMBS_FITNESS_VS_AVG_PARENTS = "data/recomb_fitness_avg_parents.csv"

# Rivet recombs fitness vs max fitness of parents
RECOMBS_FITNESS_VS_MAX_PARENTS = "data/recomb_fitness_max_parents.csv"

# Number of new infections per month during pandemic
CASE_DATA_FILENAME = "data/cases.csv"


# Define Time Interval during for experiments

In [4]:
# Set the time interval (in months) we want to consider
TIME_INTERVAL = time_interval(202002, 202302)
assert(len(TIME_INTERVAL) == 37)

# Transform Data from CSV to DataFrame

In [5]:
"""Create one large CSV file, to transform into a dataframe"""

# Number of months
NUM_ROWS = 37 

diversity_dict = load_month_diversity_data(GENETIC_DIVERSITY_SCORES_FILENAME, TIME_INTERVAL)
months = sorted(list(diversity_dict.keys()))
diversity_scores = list(diversity_dict.values())
assert(len(months) == NUM_ROWS)

rivet_counts_by_month_dict = count_rivet_events_by_month(RECOMBS_FITNESS_VS_AVG_PARENTS, TIME_INTERVAL)
rivet_counts_by_month_array = list(rivet_counts_by_month_dict.values())

infection_data = load_case_data(CASE_DATA_FILENAME)

# Write all this information out to one large CSV File

fp = open('data/linear-regression-data.csv','w')
fp.write('Month,Diversity,Infections,Recombinants\n')
for i in range(NUM_ROWS):
    line = months[i] + "," + str(diversity_dict[months[i]]) + "," + \
    str(infection_data[months[i]]) + "," +  \
    str(rivet_counts_by_month_dict[months[i]]) + "\n"
    fp.write(line)
    
fp.close()

In [6]:
df = pd.read_csv('data/linear-regression-data.csv')
df

Unnamed: 0,Month,Diversity,Infections,Recombinants
0,2020-02,33.1337,9927,1
1,2020-03,55.1048,76096,3
2,2020-04,61.9058,783348,7
3,2020-05,65.9993,2412716,13
4,2020-06,70.1913,2901229,12
5,2020-07,76.2169,4292072,10
6,2020-08,79.3156,7118656,23
7,2020-09,87.4234,7941296,30
8,2020-10,102.205,8498335,41
9,2020-11,112.679,12122070,45


## Correlation Matrix 

In [7]:
df.corr()

Unnamed: 0,Diversity,Infections,Recombinants
Diversity,1.0,0.112636,0.784402
Infections,0.112636,1.0,0.317813
Recombinants,0.784402,0.317813,1.0


# Regression Model Results

# 1a. Detectable recombination as a function of standing genetic diversity

- 'recombinants' is the dependent variable: # of dectectable recombinants (each month) inferred by RIVET
- 'Diversity' is the independent variable: Standing genetic diversity scores for each month (Phylogenetic entropy)

- Regression Method used is Ordinary Least Squares

In [8]:
est = smf.ols('Recombinants ~ Diversity', df).fit()
est.summary()

0,1,2,3
Dep. Variable:,Recombinants,R-squared:,0.615
Model:,OLS,Adj. R-squared:,0.604
Method:,Least Squares,F-statistic:,55.98
Date:,"Fri, 12 Jan 2024",Prob (F-statistic):,9.23e-09
Time:,11:29:21,Log-Likelihood:,-166.67
No. Observations:,37,AIC:,337.3
Df Residuals:,35,BIC:,340.6
Df Model:,1,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,-44.5029,14.599,-3.048,0.004,-74.141,-14.865
Diversity,0.9924,0.133,7.482,0.000,0.723,1.262

0,1,2,3
Omnibus:,1.546,Durbin-Watson:,0.642
Prob(Omnibus):,0.462,Jarque-Bera (JB):,1.471
Skew:,-0.398,Prob(JB):,0.479
Kurtosis:,2.433,Cond. No.,434.0


# 1b. Number of detectable recombinants vs Number of Infections

In [9]:
est_infections_only = smf.ols('Recombinants ~ Infections', df).fit()
est_infections_only.summary()

0,1,2,3
Dep. Variable:,Recombinants,R-squared:,0.101
Model:,OLS,Adj. R-squared:,0.075
Method:,Least Squares,F-statistic:,3.932
Date:,"Fri, 12 Jan 2024",Prob (F-statistic):,0.0553
Time:,11:29:21,Log-Likelihood:,-182.38
No. Observations:,37,AIC:,368.8
Df Residuals:,35,BIC:,372.0
Df Model:,1,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,49.1186,8.298,5.919,0.000,32.272,65.965
Infections,6.644e-07,3.35e-07,1.983,0.055,-1.58e-08,1.34e-06

0,1,2,3
Omnibus:,4.239,Durbin-Watson:,0.261
Prob(Omnibus):,0.12,Jarque-Bera (JB):,2.147
Skew:,0.307,Prob(JB):,0.342
Kurtosis:,1.993,Cond. No.,36300000.0


# 1c. Number of detectable recombinants vs (Diversity and Infections)

In [10]:
est = smf.ols('Recombinants ~ Diversity + Infections', df).fit()
est.summary()

0,1,2,3
Dep. Variable:,Recombinants,R-squared:,0.669
Model:,OLS,Adj. R-squared:,0.649
Method:,Least Squares,F-statistic:,34.3
Date:,"Fri, 12 Jan 2024",Prob (F-statistic):,7.01e-09
Time:,11:29:21,Log-Likelihood:,-163.91
No. Observations:,37,AIC:,333.8
Df Residuals:,34,BIC:,338.7
Df Model:,2,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,-49.7838,13.932,-3.573,0.001,-78.096,-21.471
Diversity,0.9593,0.126,7.631,0.000,0.704,1.215
Infections,4.858e-07,2.08e-07,2.339,0.025,6.37e-08,9.08e-07

0,1,2,3
Omnibus:,0.831,Durbin-Watson:,0.839
Prob(Omnibus):,0.66,Jarque-Bera (JB):,0.744
Skew:,-0.324,Prob(JB):,0.689
Kurtosis:,2.751,Cond. No.,99100000.0


# Testing for Multicollinearity

## Correlation matrix between independent variables

In [11]:
independent_variables_df = df.iloc[:,1:-1]

independent_variables_df.corr()

Unnamed: 0,Diversity,Infections
Diversity,1.0,0.112636
Infections,0.112636,1.0


### Measuring Multicollinearity using Variance Inflation Factor (VIF)

In [12]:
calculate_vif(independent_variables_df)

Unnamed: 0,variables,VIF
0,Diversity,2.123985
1,Infections,2.123985
