In [None]:
import pandas as pd
import numpy as np
import pyreadr


def create_y_dataframe(df, index_col, column_col, value_col):
    """
    Create Y dataframe (entity x time matrix) from long-form data.
    
    Args:
        df: Long-form dataframe
        index_col: Column name for entities/rows
        column_col: Column name for time periods/columns
        value_col: Column name for values
    
    Returns:
        Y_df: Wide-form dataframe with entities as index and time as columns
    """
    df[column_col] = df[column_col].astype(int)

    Y_df = df.pivot(index=index_col, columns=column_col, values=value_col)
    return Y_df


def create_z_dataframe(Y_df, treated_entity, treatment_start_year):
    """
    Create Z dataframe (treatment indicator matrix).
    
    Args:
        Y_df: Entity x time matrix
        treated_entity: Name of the treated entity
        treatment_start_year: Year when treatment starts
    
    Returns:
        Z_df: Treatment indicator matrix (1 for treated entity after treatment, 0 otherwise)
    """
    Z_df = pd.DataFrame(0, index=Y_df.index, columns=Y_df.columns)
    treated_mask = Z_df.index == treated_entity
    post_treatment_mask = Z_df.columns >= treatment_start_year
    Z_df.loc[treated_mask, post_treatment_mask] = 1
    return Z_df


def create_x_dataframe(df, Y_df, index_col, time_col, covariate_cols, avg_start_year, avg_end_year, 
                       additional_cols=None):
    """
    Create X dataframe (covariates matrix) by averaging over pre-treatment period.
    
    Args:
        df: Long-form dataframe
        Y_df: Entity x time matrix
        index_col: Column name for entities
        time_col: Column name for time periods
        covariate_cols: List of covariate columns to average
        avg_start_year: Start year for averaging period
        avg_end_year: End year for averaging period
        additional_cols: Dict mapping new column names to years to add from Y_df
                        e.g., {'Smoking 1988': 1988, 'Smoking 1980': 1980}
    
    Returns:
        X_df: Covariates matrix with entities as rows
    """
    # Filter to averaging period
    mask = (df[time_col] >= avg_start_year) & (df[time_col] <= avg_end_year)
    
    # Average covariates over period
    X_df = (
        df.loc[mask, [index_col] + covariate_cols]
          .groupby(index_col, as_index=False)
          .mean()
    )
    
    # Add additional columns from Y_df if specified
    if additional_cols is not None:
        for col_name, year in additional_cols.items():
            if year in Y_df.columns:
                X_df[col_name] = Y_df[year].values
    
    return X_df

### Smoking Dataset

In [None]:
# Load the smoking.rda data (if not already loaded)
result = pyreadr.read_r('../src/causaltensor/datasets/raw/smoking.rda')
df = list(result.values())[0]

# State mapping from documentation
state_mapping = {
    1: 'Alabama', 2: 'Arkansas', 3: 'California', 4: 'Colorado',
    5: 'Connecticut', 6: 'Delaware', 7: 'Georgia', 8: 'Idaho',
    9: 'Illinois', 10: 'Indiana', 11: 'Iowa', 12: 'Kansas',
    13: 'Kentucky', 14: 'Louisiana', 15: 'Maine', 16: 'Minnesota',
    17: 'Mississippi', 18: 'Missouri', 19: 'Montana', 20: 'Nebraska',
    21: 'Nevada', 22: 'New Hampshire', 23: 'New Mexico', 24: 'North Carolina',
    25: 'North Dakota', 26: 'Ohio', 27: 'Oklahoma', 28: 'Pennsylvania',
    29: 'Rhode Island', 30: 'South Carolina', 31: 'South Dakota', 32: 'Tennessee',
    33: 'Texas', 34: 'Utah', 35: 'Vermont', 36: 'Virginia',
    37: 'West Virginia', 38: 'Wisconsin', 39: 'Wyoming'
}
df["state"] = df["state"].map(state_mapping)


# Create Y_df, Z_df, and X_df using helper functions
Y_df = create_y_dataframe(df, index_col="state", column_col="year", value_col="cigsale")
Y = Y_df.to_numpy()

# Create Z matrix: 1 for California after 1988, 0 otherwise
Z_df = create_z_dataframe(Y_df, treated_entity='California', treatment_start_year=1988)

# Create X dataframe
cols_to_avg = [col for col in df.columns if col not in ['state', 'year', 'cigsale']]
X_df = create_x_dataframe(
    df, Y_df, 
    index_col='state', 
    time_col='year',
    covariate_cols=cols_to_avg,
    avg_start_year=1980, 
    avg_end_year=1988,
    additional_cols={'Smoking 1988': 1988, 'Smoking 1980': 1980, 'Smoking 1975': 1975}
)




### Basque Dataset

In [None]:
# Load the basque.rda data 
result = pyreadr.read_r('../src/causaltensor/datasets/raw/basque.rda')
df = list(result.values())[0]


# Create Y_df, Z_df, and X_df using helper functions
Y_df = create_y_dataframe(df, index_col="regionname", column_col="year", value_col="gdpcap")

Z_df = create_z_dataframe(Y_df, treated_entity='Basque Country (Pais Vasco)', treatment_start_year=1975)

cols_to_avg = ['invest', 'secagriculture',
       'secenergy', 'secindustry', 'secconstruction', 'secservicesventa',
       'secservicesnonventa', 'schoolillit', 'schoolprim', 'schoolmed',
       'schoolhigh', 'schoolposthigh', 'popdens']

X_df = create_x_dataframe(
    df, Y_df,
    index_col='regionname',
    time_col='year',
    covariate_cols=cols_to_avg,
    avg_start_year=1964,
    avg_end_year=1969,
    additional_cols={'gdpcap1960': 1960, 'gdpcap1965': 1965, 'gdpcap1970': 1970}
)



### German Reunification Dataset

In [None]:
df = pd.read_csv('../src/causaltensor/datasets/raw/german_reunification.csv')


# Create Y_df, Z_df, and X_df using helper functions
Y_df = create_y_dataframe(df, index_col="country", column_col="year", value_col="gdp")
Y_df

Z_df = create_z_dataframe(Y_df, treated_entity='West Germany', treatment_start_year=1990)

cols_to_avg = ['infrate', 'trade', 'schooling', 'industry']

X_df = create_x_dataframe(
    df, Y_df,
    index_col='country',
    time_col='year',
    covariate_cols=cols_to_avg,
    avg_start_year=1980,
    avg_end_year=1985,
    additional_cols={'gdp1960': 1960, 'gdp1970': 1970, 'gdp1980': 1980, 'gdp1985': 1985}
)

# Add investment columns (these are pre-existing in the original data)
X_df['invest60'] = df[~df['invest60'].isna()]['invest60'].values
X_df['invest70'] = df[~df['invest70'].isna()]['invest70'].values
X_df['invest80'] = df[~df['invest80'].isna()]['invest80'].values


### Texas Prison

In [None]:
result = pyreadr.read_r('../src/causaltensor/datasets/raw/texas.rda')
df = list(result.values())[0]


fips_map = {
    1: "Alabama", 2: "Alaska", 4: "Arizona", 5: "Arkansas", 6: "California",
    8: "Colorado", 9: "Connecticut", 10: "Delaware", 11: "District of Columbia",
    12: "Florida", 13: "Georgia", 15: "Hawaii", 16: "Idaho", 17: "Illinois",
    18: "Indiana", 19: "Iowa", 20: "Kansas", 21: "Kentucky", 22: "Louisiana",
    23: "Maine", 24: "Maryland", 25: "Massachusetts", 26: "Michigan",
    27: "Minnesota", 28: "Mississippi", 29: "Missouri", 30: "Montana",
    31: "Nebraska", 32: "Nevada", 33: "New Hampshire", 34: "New Jersey",
    35: "New Mexico", 36: "New York", 37: "North Carolina", 38: "North Dakota",
    39: "Ohio", 40: "Oklahoma", 41: "Oregon", 42: "Pennsylvania", 44: "Rhode Island",
    45: "South Carolina", 46: "South Dakota", 47: "Tennessee", 48: "Texas",
    49: "Utah", 50: "Vermont", 51: "Virginia", 53: "Washington", 54: "West Virginia",
    55: "Wisconsin", 56: "Wyoming"
}

df["state"] = df["statefip"].map(fips_map)

# Create Y_df, Z_df, and X_df using helper functions
Y_df = create_y_dataframe(df, index_col="state", column_col="year", value_col="bmprate")

Z_df = create_z_dataframe(Y_df, treated_entity='Texas', treatment_start_year=1993)

cols_to_avg = ["income", "ur", "poverty", "black", "perc1519",
  "aidscapita", "crack", "alcohol", "parole",
  "probation", "capacity_operational"]

X_df = create_x_dataframe(
    df, Y_df,
    index_col='state',
    time_col='year',
    covariate_cols=cols_to_avg,
    avg_start_year=1985,
    avg_end_year=1993,
    additional_cols={'bmprate1985': 1985, 'bmprate1990': 1990, 'bmprate1993': 1993}
)




### PWT

In [None]:
df = pd.read_csv('../src/causaltensor/datasets/raw/PWT.csv')
df['openness'] = df['csh_x'] + df['csh_m']
df

Unnamed: 0,countrycode,country,currency_unit,year,rgdpe,rgdpo,pop,emp,avh,hc,...,csh_x,csh_m,csh_r,pl_c,pl_i,pl_g,pl_x,pl_m,pl_k,openness
0,ABW,Aruba,Aruban Guilder,1950,,,,,,,...,,,,,,,,,,
1,ABW,Aruba,Aruban Guilder,1951,,,,,,,...,,,,,,,,,,
2,ABW,Aruba,Aruban Guilder,1952,,,,,,,...,,,,,,,,,,
3,ABW,Aruba,Aruban Guilder,1953,,,,,,,...,,,,,,,,,,
4,ABW,Aruba,Aruban Guilder,1954,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11825,ZWE,Zimbabwe,US Dollar,2010,20652.718750,21053.855469,13.973897,6.298438,,2.372605,...,0.214657,-0.454497,0.014462,0.447170,0.543100,0.411316,0.701797,0.606324,1.015145,-0.239839
11826,ZWE,Zimbabwe,US Dollar,2011,20720.435547,21592.298828,14.255592,6.518841,,2.415823,...,0.219809,-0.625170,0.004390,0.531029,0.606065,0.440252,0.739989,0.637035,0.470333,-0.405361
11827,ZWE,Zimbabwe,US Dollar,2012,23708.654297,24360.527344,14.565482,6.248271,,2.459828,...,0.225631,-0.479897,-0.076998,0.474047,1.363167,0.458315,0.712036,0.634858,0.608320,-0.254266
11828,ZWE,Zimbabwe,US Dollar,2013,27011.988281,28157.886719,14.898092,6.287056,,2.504635,...,0.174443,-0.436145,-0.000005,0.498061,0.575870,0.465031,0.717884,0.630712,0.414526,-0.261702


##### Effect of Spain joining EU in 1986

In [199]:
df1 = df[(df['year'] >= 1970) & (df['year'] <= 2000)]
Y_df = create_y_dataframe(df1, index_col="country", column_col="year", value_col="rgdpe")

Z_df = create_z_dataframe(Y_df, treated_entity='Spain', treatment_start_year=1986)

cols_to_avg = ["hc","csh_i","csh_c","csh_g","openness","pl_gdpo", "pop"]

X_df = create_x_dataframe(
    df1, Y_df,
    index_col='country',
    time_col='year',
    covariate_cols=cols_to_avg,
    avg_start_year=1970,
    avg_end_year=1980,
    additional_cols={'rgdpe1970': 1970, 'rgdpe1980': 1980, 'rgdpe1985': 1985}
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[column_col] = df[column_col].astype(int)


##### Effect of Trade Liberalization in Chile 1976

In [None]:
df1 = df[(df['year'] >= 1960) & (df['year'] <= 1995)]


Y_df = create_y_dataframe(df1, index_col="country", column_col="year", value_col="rgdpo")


Z_df = create_z_dataframe(Y_df, treated_entity='Chile', treatment_start_year=1976)

cols_to_avg = ["hc","csh_i","csh_c","csh_g","openness","pl_gdpo", "rkna"]

X_df = create_x_dataframe(
    df1, Y_df,
    index_col='country',
    time_col='year',
    covariate_cols=cols_to_avg,
    avg_start_year=1970,
    avg_end_year=1975,
    additional_cols={'rgdpo1970': 1970, 'rgdpo1975': 1975}
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[column_col] = df[column_col].astype(int)


##### Effect of Democratization in Republic of Korea in 1988

In [202]:
df1 = df[(df['year'] >= 1970) & (df['year'] <= 2000)]


Y_df = create_y_dataframe(df1, index_col="country", column_col="year", value_col="rgdpe")


Z_df = create_z_dataframe(Y_df, treated_entity='Republic of Korea', treatment_start_year=1988)

cols_to_avg = ["hc","csh_i","csh_g","openness","ctfp"]

X_df = create_x_dataframe(
    df1, Y_df,
    index_col='country',
    time_col='year',
    covariate_cols=cols_to_avg,
    avg_start_year=1980,
    avg_end_year=1987,
    additional_cols={'rgdpe1980': 1980, 'rgdpe1988': 1985}
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[column_col] = df[column_col].astype(int)


##### Resource Discovery: Norway 1971

In [203]:
df1 = df[(df['year'] >= 1960) & (df['year'] <= 1980)]


Y_df = create_y_dataframe(df1, index_col="country", column_col="year", value_col="rgdpe")


Z_df = create_z_dataframe(Y_df, treated_entity='Norway', treatment_start_year=1971)

cols_to_avg = ["hc","csh_i","csh_g","openness","rkna","pl_gdpo"]

X_df = create_x_dataframe(
    df1, Y_df,
    index_col='country',
    time_col='year',
    covariate_cols=cols_to_avg,
    avg_start_year=1960,
    avg_end_year=1970,
    additional_cols={'rgdpe1965': 1965, 'rgdpe1970': 1970}
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[column_col] = df[column_col].astype(int)


### Retailrocket Recsys Dataset

In [None]:
df = pd.read_csv('../src/causaltensor/datasets/raw/retailrocket.csv', engine="python", sep=None)
df = df[df['event'] == 'view']
df['date'] = pd.to_datetime(df['timestamp'], unit='ms').dt.date
df['day'] = pd.to_numeric(df['date'].astype('category').cat.codes)
pop_items = df.groupby('itemid')['day'].nunique() > 20
items_to_keep = pop_items[pop_items].index
df = df[df['itemid'].isin(items_to_keep)]
df = df[['visitorid', 'itemid', 'day']]
df.to_csv('../src/causaltensor/datasets/raw/retailrocket_filtered.csv', index=False)
df = df.groupby(['itemid', 'day']).size().reset_index(name='count')
Y_df = create_y_dataframe(df, index_col="itemid", column_col="day", value_col="count").fillna(0)

### Dunnhumby Recsys Dataset

In [14]:
df = pd.read_csv('../src/causaltensor/datasets/raw/dunnhumby.csv', engine="python", sep=None)
df = df[df['STORE_ID'] == 367] # filter to most popular store
df = df.groupby(['PRODUCT_ID', 'DAY'])[['SALES_VALUE', 'RETAIL_DISC']].sum().reset_index()
df['PROMO'] = (df['RETAIL_DISC'] < 0).astype(int) # Use retail discount as treatment proxy
df.to_csv('../src/causaltensor/datasets/raw/dunnhumby_filtered.csv', index=False)
Y_df = create_y_dataframe(df, index_col="PRODUCT_ID", column_col="DAY", value_col="SALES_VALUE").fillna(0)
Z_df = create_y_dataframe(df, index_col="PRODUCT_ID", column_col="DAY", value_col="PROMO").fillna(0)

### Truus Recsys Dataset

In [None]:
df = pd.read_csv('../src/causaltensor/datasets/raw/truus.csv', engine="python", sep=None)
df = df.groupby(['sku_id', 'day']).size().reset_index(name='count')
Y_df = create_y_dataframe(df, index_col="sku_id", column_col="day", value_col="count").fillna(0)
Y_df

day,734473,734474,734475,734476,734477,734479,734481,734482,734483,734484,...,734938,734939,734940,734941,734942,734943,734944,734945,734946,734947
sku_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,2.0,0.0,1.0,0.0,1.0,1.0,1.0,1.0,1.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,1.0,1.0,1.0,1.0,3.0,3.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
390,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0
391,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
392,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0
393,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0


### Movielens Recsys Dataset

In [None]:
df = pd.read_csv('../src/causaltensor/datasets/raw/movielens.data', sep='\t', header=None, 
                 names=['user_id', 'movie_id', 'rating', 'timestamp'])
df['date'] = pd.to_datetime(df['timestamp'], unit='s').dt.date
df['day'] = pd.to_numeric(df['date'].astype('category').cat.codes)
df = df.groupby(['movie_id', 'day']).size().reset_index(name='count')
Y_df = create_y_dataframe(df, index_col="movie_id", column_col="day", value_col="count").fillna(0)