Environment setup

In [1]:
# Reload edited modules (e.g. the project's src/ package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# Colab only: the project lives on Google Drive; `userdata` (Colab secrets) later supplies
# the Comet credentials.
from google.colab import drive, userdata
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Change the working directory to the project's root directory

In [3]:
import os

# The project (data/, src/, requirements.txt) sits on the mounted Drive; make it the
# working directory so the relative paths used further down resolve.
project_path = "/content/drive/MyDrive/causal-sermons"
os.chdir(project_path)

Add `src/` to the Python path so the project package (`causal_sermons`) can be imported

In [4]:

import sys
import os
from pathlib import Path

# Current working directory. After the os.chdir in the previous cell this is the
# project root, which is not necessarily where the notebook file itself lives.
current_dir = Path(os.getcwd())

# Make the project's src/ directory importable. Only add it once, so re-running
# this cell does not keep appending duplicate entries to sys.path.
src_dir = str(current_dir / "src")
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [5]:
!pip install -r requirements.txt

Ignoring cffi: markers 'os_name == "nt" and implementation_name != "pypy" and python_version >= "3.10" and python_version < "3.11"' don't match your environment
Ignoring colorama: markers 'python_version >= "3.10" and python_version < "3.11" and platform_system == "Windows"' don't match your environment
Ignoring pycparser: markers 'os_name == "nt" and implementation_name != "pypy" and python_version >= "3.10" and python_version < "3.11"' don't match your environment


# Experiment

In [6]:
import pandas as pd
import numpy as np
from scipy.stats import zscore
from sklearn.model_selection import train_test_split
from comet_ml import Experiment

## Reading the semi-synthetic data

In [7]:
# Semi-synthetic sermons dataset; the file name suggests a target ATE of 0.15 over 5000 rows.
semisynthetic_csv = './data/sermons/semisynthetic/dataset_summarized_small_ate0.15_5000.csv'
sermons = pd.read_csv(semisynthetic_csv)
sermons.shape

(5000, 32)

In [8]:
# Peek at the first five rows.
sermons.head(n=5)

Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,original_text,semisynthetic_text,treatment,x_orig_cites_verse,x_orig_num_tokens,x_orig_num_bible_names,x_orig_num_refs_to_earth,...,text_sum_orig_128,text_sum_synth_128,text_sum_orig_256,text_sum_synth_256,text_sum_orig_64_with_format,text_sum_synth_64_with_format,text_sum_orig_128_with_format,text_sum_synth_128_with_format,text_sum_orig_256_with_format,text_sum_synth_256_with_format
0,0,6156,6156,Did anyone see this story this week?\n\nA Texa...,Reading about nature became a sensory experien...,False,1.0,0.070198,0.027529,0.030612,...,Theresa Roemer said a burglar stole up to $1 m...,Theresa Roemer said a burglar stole up to $1 m...,Theresa Roemer said a burglar stole up to $1 m...,Theresa Roemer said a burglar stole up to $1 m...,Summary: Theresa Roemer said a burglar stole u...,Summary: Theresa Roemer said a burglar stole u...,Summary: Theresa Roemer said a burglar stole u...,Summary: Theresa Roemer said a burglar stole u...,Summary: Theresa Roemer said a burglar stole u...,Summary: Theresa Roemer said a burglar stole u...
1,1,23985,23985,Two young boys were spending the night at thei...,The words on the pages seemed to breathe life ...,False,1.0,0.113203,0.211975,0.020408,...,"Mary was young, poor, and from Nazareth, all c...","Mary was young, poor, and from Nazareth, all c...","Mary was young, poor, and from Nazareth, all c...","Mary was young, poor, and from Nazareth, all c...","Summary: Mary was young, poor, and from Nazare...","Summary: Mary was young, poor, and from Nazare...","Summary: Mary was young, poor, and from Nazare...","Summary: Mary was young, poor, and from Nazare...","Summary: Mary was young, poor, and from Nazare...","Summary: Mary was young, poor, and from Nazare..."
2,2,4858,4858,"A few weeks ago I visited my family physician,...",Reading about nature became a sensory experien...,True,1.0,0.042775,0.028906,0.0,...,Ephesians 2:1-3: We were dead in our transgres...,Paul says we were dead in our transgressions a...,Ephesians 2:1-3: We were dead in our transgres...,Paul says we were dead in our transgressions a...,Summary: Ephesians 2:1-3: We were dead in our ...,Summary: Paul says we were dead in our transgr...,Summary: Ephesians 2:1-3: We were dead in our ...,Summary: Paul says we were dead in our transgr...,Summary: Ephesians 2:1-3: We were dead in our ...,Summary: Paul says we were dead in our transgr...
3,3,15214,15214,Introduction:\n\nThis is actually the second p...,"The book served as a window into the past, pre...",False,1.0,0.047112,0.04267,0.010204,...,Paul: The Holy Spirit dwells in a believer at ...,"The book served as a window into the past, pre...",Paul: The Holy Spirit dwells in a believer at ...,"The book served as a window into the past, pre...",Summary: Paul: The Holy Spirit dwells in a bel...,Summary: The book served as a window into the ...,Summary: Paul: The Holy Spirit dwells in a bel...,Summary: The book served as a window into the ...,Summary: Paul: The Holy Spirit dwells in a bel...,Summary: The book served as a window into the ...
4,4,18445,18445,Construction\n\nEzra / Nehemiah\n\n.html\n\nLa...,The author's passion for conservation shone th...,False,1.0,0.066014,0.10117,0.0,...,"About 50,000 Jews went home in the first wave ...","About 50,000 Jews went home in the first wave ...","About 50,000 Jews went home in the first wave ...","About 50,000 Jews went home in the first wave ...","Summary: About 50,000 Jews went home in the fi...","Summary: About 50,000 Jews went home in the fi...","Summary: About 50,000 Jews went home in the fi...","Summary: About 50,000 Jews went home in the fi...","Summary: About 50,000 Jews went home in the fi...","Summary: About 50,000 Jews went home in the fi..."


In [9]:
# NOTE: disabled. The CSV already provides a `treatment` column (see sermons.head()),
# so random assignment is not needed. Kept for reference only.
# prompt: add a column treatment which is a random 0 or 1, seed is 42
# rng = np.random.default_rng(42)

# sermons['treatment'] = rng.choice([0, 1], size=sermons.shape[0], replace=True, p=[0.5, 0.5])

In [10]:
# The treatment flag is read as booleans; store it as 0.0 / 1.0 floats.
sermons['treatment'] = sermons['treatment'].astype(float)

# Observed outcome: outcome_1 for treated rows, outcome_0 for control rows.
is_treated = sermons['treatment'] != 0
sermons['outcome'] = sermons['outcome_1'].where(is_treated, sermons['outcome_0'])

Sanity check: ground-truth ATE versus the naive difference in means

In [11]:
# Ground-truth ATE: average of the per-row potential-outcome differences.
sermons['outcome_1'].sub(sermons['outcome_0']).mean()

0.1055962235334363

In [12]:
# Naive ATE: mean outcome of treated rows minus mean outcome of control rows.
treated_mean = sermons.loc[sermons['treatment'] == 1, 'outcome_1'].mean()
control_mean = sermons.loc[sermons['treatment'] == 0, 'outcome_0'].mean()
treated_mean - control_mean

-0.010345277718475598

In [13]:
# Summary statistics of the numeric columns (quartiles spelled out explicitly).
sermons.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,treatment,x_orig_cites_verse,x_orig_num_tokens,x_orig_num_bible_names,x_orig_num_refs_to_earth,x_orig_num_refs_to_book,x_semisynth_cites_verse,x_semisynth_num_tokens,x_semisynth_num_bible_names,x_semisynth_num_refs_to_earth,x_semisynth_num_refs_to_book,outcome,counterfactual,outcome_0,outcome_1
count,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0,5000.0
mean,2499.5,12317.2138,12317.2138,0.3694,0.8968,0.052036,0.056907,0.017249,0.009664,0.8552,0.061007,0.058479,0.033527,0.073546,1.142427,0.3799,1.112786,1.218382
std,1443.520003,7195.77292,7195.77292,0.482691,0.30425,0.034678,0.041516,0.039063,0.021552,0.351934,0.034688,0.041377,0.039849,0.023319,0.434119,0.111733,0.409615,0.468026
min,0.0,8.0,8.0,0.0,0.0,0.000288,0.0,0.0,0.0,0.0,0.008616,0.000688,0.0,0.035714,-0.329607,-0.023933,-0.349246,-0.199365
25%,1249.75,6076.0,6076.0,0.0,1.0,0.029836,0.029594,0.0,0.0,1.0,0.038822,0.03097,0.020408,0.0625,1.148391,0.404274,1.128928,1.266089
50%,2499.5,12364.0,12364.0,0.0,1.0,0.047467,0.048864,0.010204,0.0,1.0,0.056544,0.050241,0.020408,0.071429,1.265312,0.429535,1.223965,1.37408
75%,3749.25,18569.25,18569.25,1.0,1.0,0.067434,0.074329,0.020408,0.008929,1.0,0.076554,0.076394,0.040816,0.080357,1.373176,0.440979,1.307965,1.463732
max,4999.0,24996.0,24996.0,1.0,1.0,0.991518,1.0,0.887755,0.357143,1.0,1.0,0.997247,0.897959,0.410714,2.276149,0.705918,2.276149,2.458165


In [14]:
# Class balance of the treatment flag.
sermons['treatment'].value_counts()

treatment
0.0    3153
1.0    1847
Name: count, dtype: int64

# Training the causal model

## Preprocessing

In [15]:
# Optional subsample for quick debugging runs; None keeps every row (e.g. set to 1000).
sample_size = None
if sample_size is not None:
    sermons = sermons.sample(n=sample_size, random_state=1)

# Keep rows whose original text is present and longer than 100 characters.
has_usable_text = (
    sermons['original_text'].notna()
    & sermons['original_text'].str.len().gt(100)
)
sermons = sermons.loc[has_usable_text]

sermons.shape

(4999, 32)

In [16]:
# NOTE(review): leftover filter from an earlier dataset. `num_sermons` / `portion_voted`
# do not appear among the columns displayed above; kept disabled.
#sermons = sermons.loc[lambda x: x.num_sermons>5].loc[lambda x: x.portion_voted.notnull()]

Add dummy confounders (the text filtering and cleaning happens in the preprocessing cell above)

In [17]:
# Constant placeholder confounders: every row gets the same value, so they carry no information.
sermons = sermons.assign(C_1=0.2, C_2=0.9)

## Training on the semi-synthetic data

In [18]:
# Comet credentials come from Colab secrets instead of being hard-coded.
comet_api_key = userdata.get('comet_key')
comet_workspace = userdata.get('comet_user')
experiment = Experiment(
    api_key=comet_api_key,
    project_name="causal-sermons-synth-v2",
    workspace=comet_workspace,
)

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/astenuz/causal-sermons-synth-v2/3c85b0cedccb468497638d0297a506ba



In [19]:
# params
# Text shown to the model: 'full' (raw texts), 'sum' (summaries) or
# 'sum_with_format' (summaries prefixed with "Summary: ").
text_version = 'sum_with_format'
# Summary length selecting the text_sum_* columns (the CSV has 64, 128 and 256 variants).
sum_length = 64
# Backbone: 'distilbert' or 'longformer'.
model_version = 'distilbert'
max_tokens_text = 256  # passed to the wrapper as max_length
batch_size = 32
num_epochs = 4
data_size = sermons.shape[0]  # rows remaining after preprocessing

In [20]:
# Record the run configuration alongside the metrics.
experiment.log_parameters(dict(
    text_version=text_version,
    model_version=model_version,
    max_tokens_text=max_tokens_text,
    sum_length=sum_length,
    batch_size=batch_size,
    data_size=data_size,
    num_epochs=num_epochs,
))

In [21]:
# NOTE(review): disabled outcome definitions from an earlier (non-synthetic) setup; this
# notebook trains on the `outcome` column instead. Kept for reference.
# sermons['Y_0'] = (sermons['trump_minus_clinton'] > 0).astype(int)
# sermons['Y_1'] = sermons['trump_minus_clinton']
# sermons['Y_2'] = sermons['portion_voted']

In [22]:
# Text the model sees for each row: the original/control variant for untreated rows
# and the semi-synthetic/treated variant for treated rows.
# text_version -> (column for treatment == 0, column for treatment != 0)
text_columns_by_version = {
    'full': ('original_text', 'semisynthetic_text'),
    'sum': (f'text_sum_orig_{sum_length}', f'text_sum_synth_{sum_length}'),
    'sum_with_format': (f'text_sum_orig_{sum_length}_with_format',
                        f'text_sum_synth_{sum_length}_with_format'),
}
if text_version not in text_columns_by_version:
    raise ValueError(
        f'text_version not recognized: {text_version!r} '
        f'(expected one of {sorted(text_columns_by_version)})')

control_text_col, treated_text_col = text_columns_by_version[text_version]
sermons['text_input'] = np.where(
    sermons['treatment'] == 0, sermons[control_text_col], sermons[treated_text_col])

In [23]:
# 80/20 split over row labels with a fixed seed; both frames keep the original index.
split_seed = 42
train_indices, test_indices = train_test_split(sermons.index, test_size=0.2, random_state=split_seed)
sermons_train, sermons_test = sermons.loc[train_indices], sermons.loc[test_indices]

In [24]:
# Optional: uncomment to release cached GPU memory (e.g. between repeated training runs).
# import torch

# torch.cuda.empty_cache()

# import gc
# gc.collect()

In [25]:
from causal_sermons.causal_bert import (
    CausalModelWrapper,
    CausalDistilBert, DistilBertTokenizer,
    CausalLongformer, LongformerTokenizer)
from causal_sermons.ate import get_errors

In [26]:
# Model inputs taken from the training split.
texts, treatments = sermons_train['text_input'], sermons_train['treatment']
confounds = sermons_train[['C_1', 'C_2']]  # 2-D frame: one column per confounder
outcomes = sermons_train[['outcome']]      # 2-D frame: one column per outcome

In [27]:
# model_version -> (model class, tokenizer class, pretrained checkpoint name).
backbones = {
    'distilbert': (CausalDistilBert, DistilBertTokenizer, "distilbert-base-uncased"),
    'longformer': (CausalLongformer, LongformerTokenizer, "allenai/longformer-base-4096"),
}
if model_version not in backbones:
    raise ValueError('model_version must be either distilbert or longformer')
model_cls, tokenizer_cls, checkpoint = backbones[model_version]

# Head sizes are taken from the training data: outcome columns and confounder columns.
model = model_cls.from_pretrained(
    checkpoint,
    num_outcomes=outcomes.shape[1],
    num_confounders=confounds.shape[1],
    output_attentions=False,
    output_hidden_states=False)

# NOTE(review): the Longformer checkpoint name has no "uncased" suffix; confirm that
# do_lower_case has any effect for its tokenizer.
tokenizer = tokenizer_cls.from_pretrained(checkpoint, do_lower_case=True)

Some weights of CausalDistilBert were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['distilbert.dragonheads.g_cls.bias', 'distilbert.bert.transformer.layer.3.attention.out_lin.weight', 'distilbert.bert.transformer.layer.4.ffn.lin2.bias', 'distilbert.bert.transformer.layer.4.attention.k_lin.weight', 'distilbert.bert.transformer.layer.0.ffn.lin1.weight', 'distilbert.bert.transformer.layer.2.ffn.lin1.bias', 'distilbert.bert.transformer.layer.4.ffn.lin1.weight', 'distilbert.bert.transformer.layer.5.attention.v_lin.bias', 'distilbert.dragonheads.Q_cls.1.0.bias', 'distilbert.bert.transformer.layer.0.attention.v_lin.weight', 'distilbert.vocab_projector.bias', 'distilbert.bert.transformer.layer.3.attention.q_lin.bias', 'distilbert.bert.transformer.layer.3.attention.out_lin.bias', 'distilbert.bert.transformer.layer.0.ffn.lin2.weight', 'distilbert.bert.transformer.layer.3.ffn.lin1.weight', 'distilbert.bert.transformer.layer.0.output_layer_norm.w

In [28]:
# Initialize the wrapper used for training and inference.
# g_weight / Q_weight / mlm_weight are presumably loss weights for the treatment (g),
# outcome (Q) and masked-LM heads; TODO confirm in causal_sermons.causal_bert.
# os.cpu_count() returns None when the CPU count cannot be determined, so fall back
# to a single worker instead of passing None through.
cb = CausalModelWrapper(
    model=model,
    tokenizer=tokenizer,
    g_weight=0.2, Q_weight=0.2, mlm_weight=0.5,
    batch_size=batch_size, max_length=max_tokens_text,
    num_workers=os.cpu_count() or 1)

In [29]:
# Fine-tune the causal model on the training split.
training_inputs = dict(texts=texts, confounds=confounds, treatments=treatments, outcomes=outcomes)
cb.train(**training_inputs, epochs=num_epochs)

  mask = (mask_class(W_len.shape).uniform_() * W_len.float()).long() + 1 # + 1 to avoid CLS
 61%|██████    | 76/125 [01:13<00:47,  1.04it/s]


KeyboardInterrupt: ignored

## ATE estimation

In [None]:
def estimation(cb, sermons):
  """Estimate the ATE on ``sermons`` with a trained wrapper and score it.

  Args:
    cb: trained wrapper exposing ``ATE(texts=..., confounds=..., treatments=..., outcomes=...)``.
    sermons: frame with ``text_input``, ``C_1``, ``C_2``, ``treatment``, ``outcome``,
      ``outcome_0`` and ``outcome_1`` columns.

  Returns:
    Tuple ``(ate_estimators, errors, gt)``: the value returned by ``cb.ATE``, the value
    returned by ``get_errors`` against the ground truth, and the ground-truth ATE
    (mean of ``outcome_1 - outcome_0``) as a length-1 array.
  """
  model_inputs = dict(
      texts=sermons['text_input'],
      confounds=sermons[['C_1', 'C_2']],
      treatments=sermons['treatment'],
      outcomes=sermons[['outcome']],
  )
  ate_estimators = cb.ATE(**model_inputs)

  # Column-wise mean over (n, 1) arrays, so gt stays a length-1 array.
  potential_1 = sermons[['outcome_1']].to_numpy()
  potential_0 = sermons[['outcome_0']].to_numpy()
  gt = (potential_1 - potential_0).mean(axis=0)

  errors = get_errors(ate_estimators, gt)
  return ate_estimators, errors, gt

In [None]:
# In-sample estimates on the training split.
train_results = estimation(cb, sermons_train)
ate_estimators, errors, gt = train_results
train_results

In [None]:
# Log the training-split estimates, errors and ground truth inside Comet's train context.
with experiment.train():
  for metric_group in (ate_estimators, errors):
    experiment.log_metrics(metric_group)
  experiment.log_metric('ground_truth', gt)

In [None]:
# Out-of-sample estimates on the held-out test split (reuses the same result names).
test_results = estimation(cb, sermons_test)
ate_estimators, errors, gt = test_results
test_results

Log the test-split ground truth, estimates and errors

In [None]:
# Log the test-split ground truth, estimates and errors inside Comet's test context.
with experiment.test():
  experiment.log_metric('ground_truth', gt)
  for metric_group in (ate_estimators, errors):
    experiment.log_metrics(metric_group)

In [None]:
# Finish the Comet experiment.
experiment.end()

In [None]:
# Optional (Colab): uncomment to release the runtime once the run has finished.
# from google.colab import runtime
# runtime.unassign()