In [35]:
import pandas as pd
import numpy as np
import os
import datetime
from HMM import unsupervised_HMM
from HMM import supervised_HMM
from HMM_helper import sample_sentence
import simplejson
from hmmlearn import hmm
from matplotlib import pyplot as plt
import warnings
warnings.filterwarnings("error")


In [36]:
import git
import sys
repo = git.Repo("./", search_parent_directories=True)
homedir = repo.working_dir

In [37]:
def makeHMMUnSupData(Input, colname, fipsname):
    #Takes input dataframe, and gives out HMM format of data, a list of lists 
    #of the colname value, each list in the set represents one fips code.
    Output = []
    for fips in Input[fipsname].unique():
        temp = list(Input[Input[fipsname] == fips][colname])
        Output.append(temp)
    return Output

In [38]:
def makeHMMmap(Output):
    #Takes in output of makeHMMUnSupData and transforms data into list from 0 to D-1, where D is the number of unique
    #values of the output
    #Unqiue values in the input
    UniqueVals = np.array(list(set(x for l in Output for x in l)))
    UniqueVals = np.sort(UniqueVals)
    HMMOutput = []
    templs = []
    Map = {}
    RMap = {}
    for x in range(len(UniqueVals)):
        Map[int(UniqueVals[x])] = x
        RMap[x] = int(UniqueVals[x])
    for ls in Output:
        for val in ls:
            templs.append(Map[val])
        HMMOutput.append(templs)
        templs = []
    return [Map,RMap,HMMOutput]

In [39]:
def makeHMMSupData(UnSupData):
    #Takes list of lists of time series data from makeHMMUnSupData and makes it into data with X and Y
    X = []
    Y = []
    tempX = []
    tempY = []
    for ls in UnSupData:
        lenls = len(ls)
        for n in range(lenls):            
            if n == 0:
                tempX.append(ls[n])
            elif n == lenls - 1:
                tempY.append(ls[n])
            else:
                tempX.append(ls[n])
                tempY.append(ls[n])
        if len(tempX) != 0 and len(tempY) != 0:
            X.append(tempX)
            Y.append(tempY)
        tempX = []
        tempY = []   
    return [X,Y]

In [40]:
def makeX(Data, DTW, cluster_col, cluster_num, fipsname='FIPS', deathsname='Deaths'):
    #Takes in the dataset, cluster column and number, and gives out the deaths info in this cluster
    #In the form able to be processed by hmmlearn's HMM modules    
    fips = list(DTW[DTW[cluster_col] == cluster_num]['FIPS'])
    Rows = Data[Data[fipsname].isin(fips)]
    RawData = makeHMMUnSupData(Rows, deathsname, fipsname)
    #RawData = [a[0] for a in RawData]
    temp = []
    lengths = []
    for i in RawData:
        temp.extend(i)
        lengths.append(len(i))
    temp = np.array(temp).reshape(-1,1)
    return [temp, lengths]

In [41]:
def makeHMM(X):
    #Takes in data from makeX, and uses the Elbow method to determine the optimal number of 
    #states needed in the HMM, and returns the HMM with that optimal number of states
    scores = []
    Flag = True
    val = 999
    for i in range(1,31):
        tempmodel = hmm.GaussianHMM(n_components=i, covariance_type="full")
        #Tries to make the model fit, can fail if data not diverse enough
        try:
            if Flag:
                tempmodel.fit(X[0],X[1])
                scores.append(tempmodel.score(X[0],X[1]))
                if i > 10:
                    if scores[-1] < scores[-2]:
                        Flag = False
                print(i)
        except:
            val = i - 1
            Flag = False
    #If the data only accepts less than 4 states to work, we chose the max number of states to describe it
    if val < 5:
        return hmm.GaussianHMM(n_components = val, covariance_type="full").fit(X[0],X[1])
    else:
    #We do an elbow method otherwise
        n = 0
        #finding number of negative entries
        for j in scores:
            if j < 0:
                n += 1
        #gettin index of best point by elbow method (using first derivative)
        print(scores)
        ind = np.argmax(np.diff(scores)[(n + 1):]/scores[(n + 2):])
        return hmm.GaussianHMM(n_components = ind + n + 3, covariance_type="full").fit(X[0],X[1])

In [42]:
def makeHMMlist(Data, DTW, cluster_col):
    labels = np.sort(DTW[cluster_col].dropna().unique())
    HMM_list = [0] * len(labels)
    n = 0
    for i in labels:
        print(i)
        X = makeX(Data, DTW, cluster_col, i)
        HMM_list[n] = makeHMM(X)
        n += 1
    return [HMM_list, labels]

In [43]:
#Dataframes of deaths
NYT_F = pd.read_csv(f"{homedir}/models/HMM_Work/NYT_daily_Filled.csv", index_col=0)
NYT_W = pd.read_csv(f"{homedir}/models/HMM_Work/NYT_daily_Warp.csv", index_col=0)
NYT_F = NYT_F.rename(columns={'fips':'FIPS','deaths':'Deaths'})
NYT_W = NYT_W.rename(columns={'fips':'FIPS','deaths':'Deaths'})
JHU = pd.read_csv(f"{homedir}/models/HMM_Work/JHU_daily.csv", index_col=0)
#list of lists of deaths data
with open('NYT_daily_Warp_Death.txt') as f:
    NYT_daily_Warp_Death = simplejson.load(f)
with open('NYT_daily_Death_Filled.txt') as g:
    NYT_daily_Death_Filled = simplejson.load(g)
with open('JHU_daily_death.txt') as h:
    JHU_daily_death = simplejson.load(h)
#DTW Based Clusters
DTW_Clusters = pd.read_csv(f"{homedir}/models/HMM_Work/DTW_Clustering.csv", index_col=0)

In [44]:
DTW_Clusters.columns

Index(['FIPS', 'JHU_Orig', 'JHU_Z_T', 'JHU_Z_L', 'JHU_N_T', 'JHU_N_L',
       'NYT_F_Orig', 'NYT_F_Z_T', 'NYT_F_Z_L', 'NYT_F_N_T', 'NYT_F_N_L',
       'NYT_F_N_L_L', 'NYT_W_Orig', 'NYT_W_Z_T', 'NYT_W_Z_L', 'NYT_W_N_T',
       'NYT_W_N_L'],
      dtype='object')

In [45]:
JHU_Z_T_HMMs = makeHMMlist(JHU, DTW_Clusters, 'JHU_Z_T')
JHU_Z_L_HMMs = makeHMMlist(JHU, DTW_Clusters, 'JHU_Z_L')
JHU_N_T_HMMs = makeHMMlist(JHU, DTW_Clusters, 'JHU_N_T')
JHU_N_L_HMMs = makeHMMlist(JHU, DTW_Clusters, 'JHU_N_L')


1.0
1
2
3
4
5
6
7
8
9
10
11
[-32840.59978712644, 430872.00732721755, 431206.30130103225, 431219.9189569103, 431289.27876231726, 431894.05150728155, 431901.48184390436, 431907.47015674016, 431440.1141718713, 431445.0839904647, 431379.52538275137]
2.0
1
2
3
4
5
6
7
8
9
10
11
12
13
[-336193.884443771, -63857.9782391102, 390315.8945749949, 186772.01756561763, 393545.01466129714, 397117.01334775676, 399758.5052326718, 399970.4140727667, 400080.17767243873, 402245.13859351084, 439868.5869830984, 439993.95515313814, 403006.4647398737]
1.0
1
2
3
4
5
6
7
8
9
10
11
[-32840.59978712644, 430872.00732721755, 431206.432672503, 431219.75608098914, 431289.27876231726, 431894.05150728155, 431901.3346758999, 431914.1899491296, 431450.6564240503, 431440.74059651734, 431379.414054372]
2.0
1
2
3
4
5
6
7
8
9
10
11
[-14110.351051856438, -6971.322464968913, -5516.150646647813, 11906.738003332182, 12141.252966671627, 12323.675924151672, 12367.004548907205, 15953.941322145825, 15879.296526895781, 17289.16439466

In [46]:
NYT_F_Z_T_HMMs = makeHMMlist(NYT_F, DTW_Clusters, 'NYT_F_Z_T')
NYT_F_Z_L_HMMs = makeHMMlist(NYT_F, DTW_Clusters, 'NYT_F_Z_L')
NYT_F_N_T_HMMs = makeHMMlist(NYT_F, DTW_Clusters, 'NYT_F_N_T')
NYT_F_N_L_HMMs = makeHMMlist(NYT_F, DTW_Clusters, 'NYT_F_N_L')
NYT_F_N_L_L_HMMs = makeHMMlist(NYT_F, DTW_Clusters, 'NYT_F_N_L_L')


1.0
1
2
3
4


Some rows of transmat_ have zero sum because no transition from the state was ever observed.


2.0
1
2
3
4
5
6
7
8
9
10
11
12
[-643249.5653963943, 1116545.2532690472, 1151192.3140791268, 1157569.2302384214, 1157413.3184431556, 1158895.081870845, 1160610.7841786323, 1162229.2738577793, 1205575.5547017786, 1205971.925401311, 1206165.4971232847, 1163768.670971517]
1.0
1
2
3
4


Some rows of transmat_ have zero sum because no transition from the state was ever observed.


2.0
1
2
3
4
5
6
7
8
9
10
11
12
13
[1129.2454079700872, 129131.4137647617, 550848.6329431692, 555544.9363835234, 555549.5552453535, 555607.2735772423, 555835.6837004757, 555850.4977984767, 555858.2737736449, 555887.5629066653, 555932.9098583324, 555953.3009488548, 555959.282853974]
3.0
1
2
3
4
5
6
7
8
9
10
11
12
13
[-112442.8025336316, -39317.08819174419, -7017.794951463718, 83103.26044189204, 82986.51795292772, 83921.85146408435, 84788.40833065947, 85440.41338678393, 86143.59187908193, 86155.42524077944, 90127.6798701251, 97029.77182967591, 97026.24804834112]
4.0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
[-7208.75949191738, 26915.153143835905, 29057.3389611933, 29161.31260477689, 29208.040142358346, 29211.63610801607, 29625.907831656277, 29824.690720110153, 29829.490395945195, 29837.00607417128, 29844.256129626825, 29855.997410353655, 29866.952043718626, 29872.54856618442, 29876.909168650527]
5.0
1
2
3
4
5
6
7
8
9
10
11
[-138263.19851439225, 401284.34693589684, 406290.05737117864, 412659.034

Some rows of transmat_ have zero sum because no transition from the state was ever observed.


2.0
1
2
3
4
5
6
7
8
9
10
11
12
13
[-643249.5653963943, 1116545.2532690472, 1151192.3140791268, 1157569.2302384214, 1157418.9727026934, 1158787.7046545707, 1160856.3547408103, 1161601.9994728279, 1205553.4604700871, 1205534.9777207545, 1206248.3961914375, 1206854.1269277611, 1206525.3253440487]
0.0
1
1.0
1
2
3
4


Some rows of transmat_ have zero sum because no transition from the state was ever observed.


2.0
1
2
3
4
5
6
7
8
9
10
11
12
13
[1129.2454079700872, 129131.4137647617, 550848.6329431692, 555544.9363835234, 555549.5552453535, 555607.2735772423, 555835.6837004757, 555850.4977984767, 555858.3321834889, 555887.5629066653, 555932.9753211556, 555953.3009488548, 555959.282853974]
3.0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
[-112442.8025336316, -39317.08819174419, -7017.794951463795, 83114.62278765706, 82999.97897994451, 83921.8514640848, 84935.98541884897, 85393.797698383, 86158.85928089106, 86156.36748912158, 96751.59454486903, 97003.39979276761, 97055.2725433984, 97054.80682326427]
4.0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
[-7208.75949191738, 26915.153143835905, 27995.689292442617, 29160.554748855033, 29208.040142358346, 29211.63610801607, 29625.907831656343, 29824.69072010966, 29829.490395945497, 29837.00607417159, 29844.256129627574, 29855.997410354383, 29866.95204371824, 29872.548566184643, 29876.90916864964]
5.0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
[-138263.19851439225, 401284.34693589684, 4

Some rows of transmat_ have zero sum because no transition from the state was ever observed.


3.0
1
2
3
4
5
6
7
8
9
10
11
12
13
[1129.2454079700872, 129131.4137647617, 550848.6329431692, 555544.9363835234, 555549.5552453535, 555607.2735772423, 555835.6837004757, 555850.4977984767, 555858.2737736449, 555887.5629066653, 555932.9098583324, 555953.3009488548, 555959.282853974]
4.0
1
2
3
4
5
6
7
8
9
10
11
12
[-13532.017988658421, 17848.797703134296, 18969.791357488622, 19059.407141796084, 19245.681196202786, 19331.607132595702, 21250.06003772672, 21252.771043651115, 21238.656747982255, 21242.65707656148, 21288.170448939694, 21271.737344674017]
5.0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
[-90868.25672956358, -34556.33632127404, -14834.195339887598, 59351.25248620618, 59252.595178523414, 60058.05362512163, 60890.292794698085, 61424.59106236262, 61994.80474768313, 62005.59574446957, 70158.43581086013, 70443.30435111103, 70444.99153239492, 70368.51321853384]
6.0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
[-7208.75949191738, 26915.153143835905, 27995.689292442617, 29161.31260477689, 29208.040142358353

In [47]:
NYT_W_Z_T_HMMs = makeHMMlist(NYT_W, DTW_Clusters, 'NYT_W_Z_T')
NYT_W_Z_L_HMMs = makeHMMlist(NYT_W, DTW_Clusters, 'NYT_W_Z_L')
NYT_W_N_T_HMMs = makeHMMlist(NYT_W, DTW_Clusters, 'NYT_W_N_T')
NYT_W_N_L_HMMs = makeHMMlist(NYT_W, DTW_Clusters, 'NYT_W_N_L')


1.0
1
2
3
4
5
6
7
8
9
10
11
12
[-2452.428033385679, 399177.60553987784, 400973.58823451237, 402038.1858779473, 402096.96496141487, 402112.12582619156, 402544.4337715863, 402713.7366351885, 402790.90447621583, 402845.3126318618, 402883.963925679, 402888.182813921]
2.0
1
2
3
4
5
6
7
8
9
10
11
12
[-339388.54571535834, -63515.11260018592, 432590.51119133533, 438929.56157301343, 438646.5468898114, 440961.4436858139, 442780.4704137443, 443794.3433291237, 486367.0113377221, 486527.08287282323, 487144.4657981419, 486056.80324858986]
1.0
1
2
3
4
2.0
1
2
3
4
5
6
7
8
9
10
11
12
[-5606.694113993169, 16523.255573163544, 17099.42067486449, 19214.94558083832, 19260.094516698577, 19693.254877380947, 19702.591790746897, 19865.77945568011, 19865.214428954136, 19869.05750870027, 19908.419823151187, 19912.63570523207]
3.0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
[-7076.939525646441, -3830.756272075452, -2920.2497666996164, -2367.3159856399425, -247.2958661271614, 2204.894192656356, 2239.4833033075734, 2308.841196

Some rows of transmat_ have zero sum because no transition from the state was ever observed.


1.0
1
2
3
4
5
6
7
8
9
10
11
12
[-6108.009744635576, 19860.858327682596, 20491.545149080244, 22652.437402620555, 22687.957594756135, 22709.48740164542, 23134.144512346455, 23301.835833419194, 23301.48051905241, 23305.542765363763, 23344.12205638138, 23348.33964693049]
2.0
1
2
3
4
5
6
7
8
9
10
11
12
[-339388.54571535834, -63515.11260018592, 432590.51119133533, 438929.5615730135, 438599.6225289795, 440599.1839519624, 442792.1988646276, 443844.4007610078, 463851.1214043949, 486440.54412338464, 486915.2635040722, 445766.735191901]
0.0
1


Some rows of transmat_ have zero sum because no transition from the state was ever observed.


1.0
1
2
3
4
2.0
1
2
3
4
5
6
7
8
9
10
11
12
[-5606.694113993169, 16523.255573163544, 17145.885404898512, 19234.294508060208, 19260.094516698566, 19276.5718757175, 19702.59179074692, 19865.779455679978, 19865.21442895436, 19869.057508700655, 19908.419823150874, 19912.635705232344]
3.0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
[-7076.939525646441, -3830.756272075452, -2924.2766144159855, -2451.477452451045, -1400.7973341382349, 2218.161354990779, 2240.6124765861023, 2320.8422538358795, 2320.637996575055, 3355.8780920974523, 3388.7780553162056, 3398.4540613522945, 3406.6586970897856, 3410.078768112662, 3352.5795726403335]
4.0
1
2
3
4
5
6
7
8
9
10
11
12
13
[-61495.88231557796, -31492.103518578566, -26202.81463481085, -23787.676305214023, -24033.142241769714, -22191.144736246333, -20962.01226612796, -20007.377163292294, 132.87081520581535, 7633.6305331832955, 14810.491179510062, 15162.428854596224, 14946.474496772755]
5.0
1
2
3
4
5
6
7
8
9
10
11
12
13
[-28286.88119526608, 75272.41387951768, 78847.

In [53]:
fips = list(DTW_Clusters[DTW_Clusters['NYT_F_N_L_L'] == 9]['FIPS'])
Rows = NYT_F[NYT_F['FIPS'].isin(fips)]
Rows

Unnamed: 0,FIPS,date,id,cases,Deaths
0,1001,2020-01-21,"1001, 2020-01-21",0.0,0.0
1,1001,2020-01-22,"1001, 2020-01-22",0.0,0.0
2,1001,2020-01-23,"1001, 2020-01-23",0.0,0.0
3,1001,2020-01-24,"1001, 2020-01-24",0.0,0.0
4,1001,2020-01-25,"1001, 2020-01-25",0.0,0.0
5,1001,2020-01-26,"1001, 2020-01-26",0.0,0.0
6,1001,2020-01-27,"1001, 2020-01-27",0.0,0.0
7,1001,2020-01-28,"1001, 2020-01-28",0.0,0.0
8,1001,2020-01-29,"1001, 2020-01-29",0.0,0.0
9,1001,2020-01-30,"1001, 2020-01-30",0.0,0.0


In [54]:
X1 = [0.5, 1.0, -1.0, 0.42, 0.24]
X2 = [2.4, 4.2, 0.5, -0.24]
X = np.concatenate([X1, X2]).reshape(-1,1)
lengths = [len(X1), len(X2)]
hmm.GMMHMM(n_components=2,n_mix=2).fit(X, lengths)

Fitting a model with 12 free scalar parameters with only 9 data points will result in a degenerate solution.


GMMHMM(algorithm='viterbi', covariance_type='diag',
       covars_prior=array([[[-1.5],
        [-1.5]],

       [[-1.5],
        [-1.5]]]),
       covars_weight=array([[[0.],
        [0.]],

       [[0.],
        [0.]]]),
       init_params='stmcw',
       means_prior=array([[[0.],
        [0.]],

       [[0.],
        [0.]]]),
       means_weight=array([[0., 0.],
       [0., 0.]]), min_covar=0.001,
       n_components=2, n_iter=10, n_mix=2, params='stmcw', random_state=None,
       startprob_prior=1.0, tol=0.01, transmat_prior=1.0, verbose=False,
       weights_prior=array([[1., 1.],
       [1., 1.]]))

In [55]:
X

array([[ 0.5 ],
       [ 1.  ],
       [-1.  ],
       [ 0.42],
       [ 0.24],
       [ 2.4 ],
       [ 4.2 ],
       [ 0.5 ],
       [-0.24]])