In [1]:
import os
import pandas as pd
import numpy as np
import sys
import re
import logging
from Modules.Loader_wrangler import *
import random
import torch
import torch.nn as nn
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler

In [2]:
# Configure basic logging
logging.basicConfig(level=logging.INFO, force=True, format='%(levelname)s: %(message)s')

In [None]:
play = loader(output_file_name="merged_df2017.pkl", chunksize=100000, sample_size=100000, survey_year=2017)

In [4]:
play = pd.read_pickle("/home/trapfishscott/Cambridge24.25/D200_ML_econ/ProblemSets/Project/data/merged_df2017.pkl")

### Obtaining only relevant variables and making into a time series

In [26]:
#temporal_vars = ["TWSMonth", "TravelYear", "TravelWeekDay_B01ID"]
#individual_vars =["PSUGOR_B02ID", "IndIncome2002_B02ID", "HHoldNumChildren", "DVLALengthBand_B01ID"]

numerical_outcome_vars = ["TripStart", "TripEnd", "TripDisExSW"]
categorical_outcome_vars = ["TripPurpose_B01ID"]


extra_vars = ["IndividualID_x", "JourSeq"]


features_one_hot = ["PSUGOR_B02ID"]
features_numerical = ["TravelYear", "HHoldNumChildren", "IndIncome2002_B02ID", "DVLALengthBand_B01ID"]
features_cyclical = ["TWSMonth", "TravelWeekDay_B01ID"]

features = features_one_hot + features_numerical + features_cyclical
outcomes = numerical_outcome_vars + categorical_outcome_vars

In [6]:
ts_df = play[extra_vars +  features + outcomes]

In [7]:
ts_df = ts_df.sort_values(["IndividualID_x", "TravelWeekDay_B01ID", "JourSeq"])

## Data Manipulation pipeline

1. One-hot encode categorical features + any small cleaning steps
2. Add days of the week with no car travel
3. Make data frame into wide format
4. Convert to tensor

* Includes JourSeq gaps if trips were made by non-car inbetween

In [None]:
### small cleaning steps and one hot encoding

### Cyclical encoder

In [8]:
def apply_cyclical_encoding(column, type_, max_val):

    if type_ == "cos":
        return np.cos(2 * np.pi * column/ max_val)
    else:
        return np.sin(2 * np.pi * column/ max_val)



### Imputing missing travel days

In [63]:
def impute_missing_travel_week_for_i(i_df, i_id, full_week_encoding, features=features, outcomes=outcomes):
        
    break_flag = False

    # Travel days with travel 
    included_travel_day = i_df["TravelWeekDay_B01ID"].to_list()

    # Travel days with no travel
    travel_day_no_drive = list(set(full_week_encoding) - set(included_travel_day))

    # These values will repeat for empty-travel travel days
    imputed_travel_df = pd.DataFrame({
        "TravelWeekDay_B01ID": travel_day_no_drive,
        "IndividualID_x": [i_id]*len(travel_day_no_drive),
        "JourSeq": [1]*len(travel_day_no_drive)
    })

    # Looping through all the columns in the original df
    for col in i_df.columns:

        # For days with no travel all outcomes vars will take 0
        if col in outcomes:
            imputed_travel_df[col] = [0]*len(travel_day_no_drive)

        else:
        
            if col not in extra_vars + ["TravelWeekDay_B01ID"]:
                if len(i_df[col].unique()) != 1:
                    print(f"{col} is erroneous for {i_id}")
                    print(f"Unique vals: {i_df[col].unique()}")
                    break_flag = True
                    break
                else:
                    imputed_travel_df[col] = i_df[col].unique()[0]

    if break_flag:
        print("Continuing to next individual")
        return
    

    # display(imputed_travel_df)

    # Concatenating df to include empty travel days
    full_df = pd.concat([i_df, imputed_travel_df])

    # Apply cyclical encoding to cyclical column

    full_df["TWSMonth_cos"] = apply_cyclical_encoding(column=full_df["TWSMonth"], type_="cos", max_val=12)
    full_df["TWSMonth_sin"] =apply_cyclical_encoding(column=full_df["TWSMonth"], type_="sin", max_val=12)

    full_df["TravelWeekDay_B01ID_cos"] = apply_cyclical_encoding(column=full_df["TravelWeekDay_B01ID"], type_="cos", max_val=7)
    full_df["TravelWeekDay_B01ID_sin"] =apply_cyclical_encoding(column=full_df["TravelWeekDay_B01ID"], type_="sin", max_val=7)

    full_df = full_df.sort_values(["TravelYear", "TWSMonth", "TravelWeekDay_B01ID", "JourSeq", "TripStart", "TripEnd"])

    #display(full_df)


    return full_df

### Transforming to wide

In [130]:
def transform_to_wide_for_i(i_df, max_journey_seq, seq_length = 7, outcomes=outcomes, features=features, extra_vars=extra_vars):
    df = i_df.copy()

    expected_all = [f"{col}_{i}" for col in outcomes for i in range(1, max_journey_seq+1)]
    expected_categorical = [f"{col}_{i}" for col in categorical_outcome_vars for i in range(1, max_journey_seq+1)]

    df = df[df["JourSeq"]<=max_journey_seq]

    #

    df_wide = df.pivot(index="TravelWeekDay_B01ID",
                  columns = "JourSeq",
                  values = outcomes)
    
    df_wide.columns = [f"{col[0]}_{int(col[1])}" for col in df_wide.columns]

    for col in expected_all:
        if col not in df_wide.columns:
            df_wide[col] = 0
    
    # Ensure column order is consistent
    df_wide = df_wide[expected_all]
    
    df_wide = df_wide.fillna(0)

    df_wide.reset_index(inplace=True)

    # Dropping outcome columns
    df.drop(columns=outcomes + extra_vars, axis=1, inplace = True)
    df.drop_duplicates(subset=["TravelWeekDay_B01ID"], inplace=True)

    df_wide = df_wide.merge(df, on="TravelWeekDay_B01ID", how="left")

    top_row = df_wide.head(1).copy()

    for col in expected_all:
        top_row[col] = 0
        top_row["TravelWeekDay_B01ID"] = 0

    repeated_rows = pd.concat([top_row] * seq_length, ignore_index=True)

    df_wide = pd.concat([repeated_rows, df_wide], ignore_index=True)

    df_wide.drop(columns=features_one_hot + features_cyclical, inplace=True, axis=1)

    #df_wide.drop(columns=features_cyclical + features_one_hot, axis=1, inplace=True)

    targets_only = df_wide.drop(columns=features + extra_vars, axis=1, errors="ignore")

    targets_only = targets_only.iloc[seq_length:,:]

    targets_cont = targets_only[expected_all]
    targets_cont = targets_cont.copy()
    targets_cont.drop(columns=expected_categorical, axis=1, inplace=True)


    targets_cat = targets_only[expected_categorical]

    return df_wide, targets_cont, targets_cat

### Putting altogether for LSTM

In [131]:
def prepare_data_for_LSTM(long_df, impute_missing_travel_weeks=True, transform_to_wide=False, transform_to_tensor=False, debug=False):

    df = long_df.copy()
           

    #df = df[~df["DVLALengthBand_B01ID"].isin([-8, -10])]

    # All unique individual id's to loop over
    individual_ids = df["IndividualID_x"].unique()

    # Apply numerical encoding to numerical column
    num_scaler = MinMaxScaler()

    df[numerical_outcome_vars] = num_scaler.fit_transform(df[numerical_outcome_vars])
    df[features_numerical] = num_scaler.fit_transform(df[features_numerical])

    for col in features_one_hot:
        df[col] = df[col].astype(int)
    
    # Apply one-hot to categorical
    ohe = OneHotEncoder(sparse_output=False)

    ohe_array = ohe.fit_transform(df[features_one_hot])
    ohe_df = pd.DataFrame(ohe_array, columns=ohe.get_feature_names_out(features_one_hot))

    # Reset index to avoid misalignment
    df.reset_index(drop=True, inplace=True)
    ohe_df.reset_index(drop=True, inplace=True)

    df = pd.concat([df, ohe_df], axis=1)

    df_chunks = []

    full_week_encoding = list(range(1,8))

    if debug:
        random_index = random.randint(0, len(individual_ids))

        debug_df = df[df["IndividualID_x"] == individual_ids[random_index]]

        display(debug_df)

        debug_df = impute_missing_travel_week_for_i(debug_df, i_id=individual_ids[random_index], full_week_encoding=full_week_encoding)

        display(debug_df)

        debug_df, debug_targets_cont, debug_targets_cat = transform_to_wide_for_i(debug_df, max_journey_seq=10)

        print(debug_df.columns)

        display(debug_df)

        display(debug_targets_cont)

        display(debug_targets_cat)

        return
    
    if transform_to_tensor:
        individual_tensors = []
        target_cont_tensors = []
        target_cat_tensors = []
    
    if impute_missing_travel_weeks:

        for i, individual_id in enumerate(individual_ids[:]):

            i_df = df[df["IndividualID_x"] == individual_id]

            full_df = impute_missing_travel_week_for_i(i_df, i_id=individual_id, full_week_encoding=full_week_encoding)

            #display(full_df)

            if full_df is not None:
                if not transform_to_wide:
                    df_chunks.append(full_df)

                else:

                    full_df, targets_cont, targets_cat = transform_to_wide_for_i(full_df, max_journey_seq=10)
                    
                    if transform_to_tensor:

                    
                        full_arr = full_df.to_numpy()
                        full_arr = np.expand_dims(full_arr, axis=1)

                        targets_cont_arr = targets_cont.to_numpy()
                        targets_cat_arr = targets_cat.to_numpy()

                        full_i_tensor = torch.tensor(full_arr)
                        target_cont_i_tensor = torch.tensor(targets_cont_arr)
                        target_cat_i_tensor = torch.tensor(targets_cat_arr)

                        individual_tensors.append(full_i_tensor)
                        target_cont_tensors.append(target_cont_i_tensor)
                        target_cat_tensors.append(target_cat_i_tensor)


                    else:

                        #display(full_df)
                        print("")
                        #display(targets)
                        df_chunks.append(full_df)

            sys.stdout.write(f"\rIndividual {i+1} out of {len(individual_ids)} Complete!    ")
            sys.stdout.flush()

        if transform_to_tensor:
            individual_tensors = torch.stack(individual_tensors, dim=0)
            target_cont_tensors = torch.stack(target_cont_tensors, dim=0)
            target_cat_tensors = torch.stack(target_cat_tensors, dim=0)
            return individual_tensors, target_cont_tensors, target_cat_tensors
        
        else:

            df_to_return = pd.concat(df_chunks)

            return df_to_return



    else:
        return df


In [133]:
df = prepare_data_for_LSTM(long_df=ts_df, debug=True)


Unnamed: 0,IndividualID_x,JourSeq,PSUGOR_B02ID,IndIncome2002_B02ID,DVLALengthBand_B01ID,TravelYear,HHoldNumChildren,TWSMonth,TravelWeekDay_B01ID,TripStart,...,TripPurpose_B01ID,PSUGOR_B02ID_1,PSUGOR_B02ID_2,PSUGOR_B02ID_3,PSUGOR_B02ID_4,PSUGOR_B02ID_5,PSUGOR_B02ID_6,PSUGOR_B02ID_7,PSUGOR_B02ID_8,PSUGOR_B02ID_9
92864,2017016000.0,1.0,7,0.0,0.0,1.0,0.0,1.0,1.0,0.250174,...,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
92865,2017016000.0,2.0,7,0.0,0.0,1.0,0.0,1.0,1.0,0.66713,...,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
92866,2017016000.0,1.0,7,0.0,0.0,1.0,0.0,1.0,2.0,0.250174,...,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
92867,2017016000.0,2.0,7,0.0,0.0,1.0,0.0,1.0,2.0,0.66713,...,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
92868,2017016000.0,1.0,7,0.0,0.0,1.0,0.0,1.0,3.0,0.250174,...,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
92869,2017016000.0,2.0,7,0.0,0.0,1.0,0.0,1.0,3.0,0.66713,...,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
92870,2017016000.0,1.0,7,0.0,0.0,1.0,0.0,1.0,4.0,0.250174,...,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
92871,2017016000.0,2.0,7,0.0,0.0,1.0,0.0,1.0,4.0,0.66713,...,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
92872,2017016000.0,1.0,7,0.0,0.0,1.0,0.0,1.0,5.0,0.250174,...,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
92873,2017016000.0,2.0,7,0.0,0.0,1.0,0.0,1.0,5.0,0.66713,...,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0


Unnamed: 0,IndividualID_x,JourSeq,PSUGOR_B02ID,IndIncome2002_B02ID,DVLALengthBand_B01ID,TravelYear,HHoldNumChildren,TWSMonth,TravelWeekDay_B01ID,TripStart,...,PSUGOR_B02ID_4,PSUGOR_B02ID_5,PSUGOR_B02ID_6,PSUGOR_B02ID_7,PSUGOR_B02ID_8,PSUGOR_B02ID_9,TWSMonth_cos,TWSMonth_sin,TravelWeekDay_B01ID_cos,TravelWeekDay_B01ID_sin
92864,2017016000.0,1.0,7,0.0,0.0,1.0,0.0,1.0,1.0,0.250174,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,0.62349,0.7818315
92865,2017016000.0,2.0,7,0.0,0.0,1.0,0.0,1.0,1.0,0.66713,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,0.62349,0.7818315
92866,2017016000.0,1.0,7,0.0,0.0,1.0,0.0,1.0,2.0,0.250174,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,-0.222521,0.9749279
92867,2017016000.0,2.0,7,0.0,0.0,1.0,0.0,1.0,2.0,0.66713,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,-0.222521,0.9749279
92868,2017016000.0,1.0,7,0.0,0.0,1.0,0.0,1.0,3.0,0.250174,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,-0.900969,0.4338837
92869,2017016000.0,2.0,7,0.0,0.0,1.0,0.0,1.0,3.0,0.66713,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,-0.900969,0.4338837
92870,2017016000.0,1.0,7,0.0,0.0,1.0,0.0,1.0,4.0,0.250174,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,-0.900969,-0.4338837
92871,2017016000.0,2.0,7,0.0,0.0,1.0,0.0,1.0,4.0,0.66713,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,-0.900969,-0.4338837
92872,2017016000.0,1.0,7,0.0,0.0,1.0,0.0,1.0,5.0,0.250174,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,-0.222521,-0.9749279
92873,2017016000.0,2.0,7,0.0,0.0,1.0,0.0,1.0,5.0,0.66713,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,-0.222521,-0.9749279


Index(['TripStart_1', 'TripStart_2', 'TripStart_3', 'TripStart_4',
       'TripStart_5', 'TripStart_6', 'TripStart_7', 'TripStart_8',
       'TripStart_9', 'TripStart_10', 'TripEnd_1', 'TripEnd_2', 'TripEnd_3',
       'TripEnd_4', 'TripEnd_5', 'TripEnd_6', 'TripEnd_7', 'TripEnd_8',
       'TripEnd_9', 'TripEnd_10', 'TripDisExSW_1', 'TripDisExSW_2',
       'TripDisExSW_3', 'TripDisExSW_4', 'TripDisExSW_5', 'TripDisExSW_6',
       'TripDisExSW_7', 'TripDisExSW_8', 'TripDisExSW_9', 'TripDisExSW_10',
       'TripPurpose_B01ID_1', 'TripPurpose_B01ID_2', 'TripPurpose_B01ID_3',
       'TripPurpose_B01ID_4', 'TripPurpose_B01ID_5', 'TripPurpose_B01ID_6',
       'TripPurpose_B01ID_7', 'TripPurpose_B01ID_8', 'TripPurpose_B01ID_9',
       'TripPurpose_B01ID_10', 'IndIncome2002_B02ID', 'DVLALengthBand_B01ID',
       'TravelYear', 'HHoldNumChildren', 'PSUGOR_B02ID_1', 'PSUGOR_B02ID_2',
       'PSUGOR_B02ID_3', 'PSUGOR_B02ID_4', 'PSUGOR_B02ID_5', 'PSUGOR_B02ID_6',
       'PSUGOR_B02ID_7', 'PSUGOR_B02

Unnamed: 0,TripStart_1,TripStart_2,TripStart_3,TripStart_4,TripStart_5,TripStart_6,TripStart_7,TripStart_8,TripStart_9,TripStart_10,...,PSUGOR_B02ID_4,PSUGOR_B02ID_5,PSUGOR_B02ID_6,PSUGOR_B02ID_7,PSUGOR_B02ID_8,PSUGOR_B02ID_9,TWSMonth_cos,TWSMonth_sin,TravelWeekDay_B01ID_cos,TravelWeekDay_B01ID_sin
0,0.0,0.0,0,0,0,0,0,0,0,0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,0.62349,0.7818315
1,0.0,0.0,0,0,0,0,0,0,0,0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,0.62349,0.7818315
2,0.0,0.0,0,0,0,0,0,0,0,0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,0.62349,0.7818315
3,0.0,0.0,0,0,0,0,0,0,0,0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,0.62349,0.7818315
4,0.0,0.0,0,0,0,0,0,0,0,0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,0.62349,0.7818315
5,0.0,0.0,0,0,0,0,0,0,0,0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,0.62349,0.7818315
6,0.0,0.0,0,0,0,0,0,0,0,0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,0.62349,0.7818315
7,0.250174,0.66713,0,0,0,0,0,0,0,0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,0.62349,0.7818315
8,0.250174,0.66713,0,0,0,0,0,0,0,0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,-0.222521,0.9749279
9,0.250174,0.66713,0,0,0,0,0,0,0,0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.866025,0.5,-0.900969,0.4338837


Unnamed: 0,TripStart_1,TripStart_2,TripStart_3,TripStart_4,TripStart_5,TripStart_6,TripStart_7,TripStart_8,TripStart_9,TripStart_10,...,TripDisExSW_1,TripDisExSW_2,TripDisExSW_3,TripDisExSW_4,TripDisExSW_5,TripDisExSW_6,TripDisExSW_7,TripDisExSW_8,TripDisExSW_9,TripDisExSW_10
7,0.250174,0.66713,0,0,0,0,0,0,0,0,...,0.072296,0.072296,0,0,0,0,0,0,0,0
8,0.250174,0.66713,0,0,0,0,0,0,0,0,...,0.072296,0.072296,0,0,0,0,0,0,0,0
9,0.250174,0.66713,0,0,0,0,0,0,0,0,...,0.072296,0.072296,0,0,0,0,0,0,0,0
10,0.250174,0.66713,0,0,0,0,0,0,0,0,...,0.072296,0.072296,0,0,0,0,0,0,0,0
11,0.250174,0.66713,0,0,0,0,0,0,0,0,...,0.072296,0.072296,0,0,0,0,0,0,0,0
12,0.271022,0.542043,0,0,0,0,0,0,0,0,...,0.054176,0.054176,0,0,0,0,0,0,0,0
13,0.0,0.0,0,0,0,0,0,0,0,0,...,0.0,0.0,0,0,0,0,0,0,0,0


Unnamed: 0,TripPurpose_B01ID_1,TripPurpose_B01ID_2,TripPurpose_B01ID_3,TripPurpose_B01ID_4,TripPurpose_B01ID_5,TripPurpose_B01ID_6,TripPurpose_B01ID_7,TripPurpose_B01ID_8,TripPurpose_B01ID_9,TripPurpose_B01ID_10
7,2.0,2.0,0,0,0,0,0,0,0,0
8,2.0,2.0,0,0,0,0,0,0,0,0
9,2.0,2.0,0,0,0,0,0,0,0,0
10,2.0,2.0,0,0,0,0,0,0,0,0
11,2.0,2.0,0,0,0,0,0,0,0,0
12,2.0,2.0,0,0,0,0,0,0,0,0
13,0.0,0.0,0,0,0,0,0,0,0,0


In [134]:
X, y_cont, y_cat = prepare_data_for_LSTM(long_df=ts_df, transform_to_wide=True, transform_to_tensor=True)

Individual 5949 out of 6838 Complete!    TravelYear is erroneous for 2017014397.0
Unique vals: [1. 0.]
Continuing to next individual
Individual 5950 out of 6838 Complete!    TravelYear is erroneous for 2017014398.0
Unique vals: [1. 0.]
Continuing to next individual
Individual 5998 out of 6838 Complete!    TravelYear is erroneous for 2017014552.0
Unique vals: [1. 0.]
Continuing to next individual
Individual 6057 out of 6838 Complete!    TravelYear is erroneous for 2017014714.0
Unique vals: [1. 0.]
Continuing to next individual
Individual 6058 out of 6838 Complete!    TravelYear is erroneous for 2017014715.0
Unique vals: [1. 0.]
Continuing to next individual
Individual 6086 out of 6838 Complete!    TravelYear is erroneous for 2017014773.0
Unique vals: [1. 0.]
Continuing to next individual
Individual 6164 out of 6838 Complete!    TravelYear is erroneous for 2017014964.0
Unique vals: [1. 0.]
Continuing to next individual
Individual 6165 out of 6838 Complete!    TravelYear is erroneous for 

In [246]:
X = X.to(torch.float32)
y_cont = y_cont.to(torch.float32)
y_cat = y_cat.to(torch.float32)

print(f"Input shape: {X.shape}")
print(f"Cont Output shape: {y_cont.shape}")
print(f"Cat Output shape: {y_cat.shape}")

Input shape: torch.Size([6775, 14, 1, 57])
Cont Output shape: torch.Size([6775, 7, 30])
Cat Output shape: torch.Size([6775, 7, 10])


In [165]:
y_cont[0,0,:]

tensor([0.4170, 0.4795, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.4309, 0.4934, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0071, 0.0071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000], dtype=torch.float64)

In [137]:
# Save tensors
with open("/home/trapfishscott/Cambridge24.25/D200_ML_econ/ProblemSets/Project/tensors/tensors.pkl", "wb") as f:
    pickle.dump((X, y_cont, y_cat), f)

### Creating the RNN

In [138]:
# Defining parameters
INPUT_SIZE = X.shape[3]
HIDDEN_SIZE = 3
NUM_LAYERS = 1
OUTPUT_SIZE_CONT = y_cont.shape[2]
OUTPUT_SIZE_CAT = y_cat.shape[2]

In [244]:
class RNNmodel(nn.Module):
    def __init__(self):
        super().__init__()

        # Define RNN layer

        self.rnn = nn.RNN(INPUT_SIZE, HIDDEN_SIZE)

        # Output layer

        self.output_cont = nn.Linear(HIDDEN_SIZE, OUTPUT_SIZE_CONT)
        self.output_cat = nn.Linear(HIDDEN_SIZE, OUTPUT_SIZE_CAT)


    def forward(self, X):

        out, hh = self.rnn(X)

        #print(f"out shape: {out.shape}")
        #print(f"hh shape: {hh.shape}")

        y_cont_hat_vector = self.output_cont(hh)
        y_cat_hat_vector = self.output_cat(hh)

        y_cat_hat = y_cat_hat_vector[0,0,:]
        y_cont_hat = y_cont_hat_vector[0,0,:]


        #print(y_hat_vector)

        '''
        y_cont_hat = {}

        for index in range(y_cont_hat_vector.shape[2]):
            y_cont_hat[index+1] = y_cont_hat_vector[:,:,index].detach()
        '''


        return y_cont_hat, y_cat_hat


In [141]:
# Taking one test draw

rnn_model = RNNmodel()

X0 = X[0,:,0,:].unsqueeze(1).to(torch.float32)
print(f"X1 shape: {X0.shape}")
print("")

y_cont_hat, y_cat_hat = rnn_model.forward(X0)

print(f"Categorical outputs:  {y_cat_hat}")
print(f"Ground truth categorical: {y_cat[0,0,:]}")
print("")
print(f"Continous outputs:  {y_cont_hat}")
print(f"Ground truth Continous: {y_cont[0,0,:]}")

loss_cat = nn.CrossEntropyLoss()  #(y_hat, y)
loss_cont = nn.MSELoss()

print(f"Categorical loss: {loss_cat(y_cat_hat, y_cat[0,0,:])}")
print(f"Continous loss: {loss_cont(y_cont_hat, y_cont[0,0,:])}")


X1 shape: torch.Size([14, 1, 57])

Categorical outputs:  tensor([ 0.0226,  0.9181, -0.2481,  0.3822,  0.2465, -0.1911,  0.1248, -0.1309,
        -0.6274,  0.2030])
Ground truth categorical: tensor([6., 6., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)

Continous outputs:  tensor([-0.0610,  0.4366, -0.8558,  0.5995,  0.7486, -0.0321, -0.3573,  0.2202,
         0.1094,  0.2829,  0.0116,  0.4470, -0.0714,  0.0055,  0.2632, -0.0935,
        -0.5309, -0.0201,  0.0790,  0.0084,  0.9121, -0.5325, -0.7127,  0.4314,
        -0.0205, -0.4988,  0.9897,  0.2369, -0.0822, -0.3129])
Ground truth Continous: tensor([0.4170, 0.4795, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.4309, 0.4934, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0071, 0.0071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000], dtype=torch.float64)
Categorical loss: 23.81736946105957
Continous loss: 0.19650199432470233


In [259]:
rnn_model = RNNmodel()

loss_cat = nn.CrossEntropyLoss()  #(y_hat, y)
loss_cont = nn.MSELoss()
optimizer = torch.optim.SGD(rnn_model.parameters(), lr=0.001, momentum=0.9)

epochs = 1

seq_length = 7

for epochi in range(epochs):
    #print(epochi)
    for individual_i in range(100):

        travel_diary = X[individual_i, :, 0, :].unsqueeze(1).to(torch.float32)

        for i in range(1, travel_diary.shape[0] - seq_length):
            sliding_input = travel_diary[i:seq_length+i,0,:].unsqueeze(1)

            #print(sliding_input.shape)

            y_cont_hat, y_cat_hat = rnn_model.forward(sliding_input)

            y_cont_i = y_cont[individual_i,i,:]
            y_cat_i = y_cat[individual_i,i,:]

            print(f"Pred_cont: {y_cont_hat}")
            print(f"True cont: {y_cont_i}")

            print(f"Pred_cat: {y_cat_hat}")
            print(f"True cat: {y_cat_i}")

            categorical_loss = loss_cat(y_cat_hat, y_cat_i)


            continuous_loss = loss_cont(y_cont_hat, y_cont_i)


            combined_loss = categorical_loss + continuous_loss
            
            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()
            
            print("")
            #print(f"\repoch: {epochi} | individual: {individual_i} | categorical loss: {categorical_loss:.2f} | continuous loss: {continuous_loss:2f}", end="", flush=True)

            



    


Pred_cont: tensor([-0.1867,  0.7533,  0.6751,  0.3286,  0.2847, -0.7440,  1.5087,  0.6880,
         0.5694, -0.0554,  0.5237, -0.2351,  0.1902,  1.0121, -0.5878,  1.5739,
         0.9583, -1.0863,  0.7430,  0.5786,  0.5330, -0.5584,  0.6188,  0.1306,
        -0.9419,  0.2115, -0.2175, -0.3655,  0.2113, -0.3633],
       grad_fn=<SliceBackward0>)
True cont: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.])
Pred_cat: tensor([ 8.7910e-01,  2.7775e-01, -9.3972e-01,  4.2874e-01,  3.5653e-02,
        -3.2972e-01, -4.5367e-01,  1.2401e+00,  3.9592e-04, -3.0683e-01],
       grad_fn=<SliceBackward0>)
True cat: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

Pred_cont: tensor([-0.0362,  0.8370,  0.3861,  0.0780,  0.0204, -0.3423,  0.9993,  0.2944,
         0.4949, -0.2250,  0.2888,  0.0052,  0.0029,  0.8711, -0.0765,  0.9717,
         0.6267, -0.9457,  0.4254,  0.2887,  0.4107, -0.1000,  0.3398,  0.2928,
    