In [1]:
import torch
import pandas as pd

In [2]:
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader

# Load the full dataset, and download it if necessary
dataset = get_dataset(dataset="civilcomments", download=True)


In [3]:
# Get the test set
train_data = dataset.get_subset("train")
train_loader = get_train_loader("standard", train_data, batch_size=16)


In [4]:
from wilds.common.grouper import CombinatorialGrouper
grouper = CombinatorialGrouper(dataset, dataset.metadata_fields[:8])


In [5]:
grouper.n_groups

256

In [6]:
dataset.metadata_fields

['male',
 'female',
 'LGBTQ',
 'christian',
 'muslim',
 'other_religions',
 'black',
 'white',
 'identity_any',
 'severe_toxicity',
 'obscene',
 'threat',
 'insult',
 'identity_attack',
 'sexual_explicit',
 'y',
 'from_source_domain']

In [7]:
dataset.metadata_fields[:8]

['male',
 'female',
 'LGBTQ',
 'christian',
 'muslim',
 'other_religions',
 'black',
 'white']

In [8]:
x_ls = []
y_true_ls = []
z_ls = []
metadata_ls = []

for labeled_batch in train_loader:
    x, y_true, metadata = labeled_batch
    z = grouper.metadata_to_group(metadata[:8])
    x_ls.extend(x)
    y_true_ls.extend(y_true.numpy())
    z_ls.append(z)
    metadata_ls.append(metadata)

In [9]:
torch.concat(z_ls).unique()

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  78,  79,  80,  81,  82,  83,  84,
         85,  86,  87,  88,  89,  90,  96,  97,  98, 100, 101, 102, 104, 106,
        108, 112, 113, 114, 116, 118, 120, 122, 123, 124, 126, 128, 129, 130,
        131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 144, 145,
        146, 147, 148, 151, 152, 153, 154, 155, 156, 157, 158, 160, 161, 162,
        163, 164, 166, 168, 171, 172, 176, 179, 180, 184, 185, 186, 188, 192,
        193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 206, 207,
        208, 209, 210, 211, 212, 216, 217, 219, 220, 223, 224, 2

In [10]:
metadata_df = pd.DataFrame(torch.concat(metadata_ls).numpy())
metadata_df.columns = dataset.metadata_fields

In [11]:
for k in dataset.metadata_fields:
    print(metadata_df.groupby(k, dropna=False).size())

male
0    239228
1     29810
dtype: int64
female
0    232794
1     36244
dtype: int64
LGBTQ
0    260618
1      8420
dtype: int64
christian
0    242300
1     26738
dtype: int64
muslim
0    255084
1     13954
dtype: int64
other_religions
0    262494
1      6544
dtype: int64
black
0    259142
1      9896
dtype: int64
white
0    252340
1     16698
dtype: int64
identity_any
0    155568
1    113470
dtype: int64
severe_toxicity
0    269031
1         7
dtype: int64
obscene
0    267355
1      1683
dtype: int64
threat
0    268379
1       659
dtype: int64
insult
0    250565
1     18473
dtype: int64
identity_attack
0    261236
1      7802
dtype: int64
sexual_explicit
0    267691
1      1347
dtype: int64
y
0    238523
1     30515
dtype: int64
from_source_domain
1    269038
dtype: int64


In [12]:
metadata_df.iloc[:, :8].drop_duplicates()

Unnamed: 0,male,female,LGBTQ,christian,muslim,other_religions,black,white
0,0,0,0,0,0,0,0,0
1,0,0,1,0,0,0,0,0
4,0,1,0,0,0,0,0,0
5,0,1,1,1,0,0,0,0
9,1,1,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...
230127,1,0,1,0,0,1,1,1
238577,1,0,1,0,1,0,1,0
245035,1,0,1,1,1,0,1,0
246411,1,0,1,1,1,1,0,1


In [13]:
df = pd.DataFrame({"text": x_ls, "y_true":y_true_ls, })

df = pd.concat([df, metadata_df], axis=1)

In [14]:
df.groupby(['christian', 'y_true'], dropna=False).size()

christian  y_true
0          0         214231
           1          28069
1          0          24292
           1           2446
dtype: int64

In [68]:
df

Unnamed: 0,text,y_true,male,female,LGBTQ,christian,muslim,other_religions,black,white,identity_any,severe_toxicity,obscene,threat,insult,identity_attack,sexual_explicit,y,from_source_domain
0,The kids maybe but these days people rarely st...,0,1,0,0,0,0,0,0,0,1,0,0,0,1,0,0,1,1
1,lol\nWell scrambled...\nBy the master pretzel ...,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
2,"Aloha Danno. I hope you are doing well, also. ...",0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1
3,"Danes, You think its OK to use racial slurs? H...",1,1,0,0,0,0,0,0,1,1,0,0,0,0,0,0,1,1
4,"Baloney! \nThey attacked Christians, namely h...",0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
269033,THAT'S JUST WHAT FABER WAS POINTING OUT ABOUT ...,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1
269034,I have yet to find a dying Episcopal church. ...,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1
269035,Trump's a hero to his base. But the rest of u...,1,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,1
269036,"I was referring to ""bias"" not sources.",0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,0,1


In [50]:
for k in metadata_df.columns[:8]:
    print(df.groupby(k).apply(lambda x: sum(x['y_true'])/len(x)))

male
0    0.113340
1    0.114089
dtype: float64
female
0    0.113315
1    0.114115
dtype: float64
LGBTQ
0    0.113450
1    0.112589
dtype: float64
christian
0    0.113438
1    0.113284
dtype: float64
muslim
0    0.113159
1    0.118246
dtype: float64
other_religions
0    0.113465
1    0.111705
dtype: float64
black
0    0.113382
1    0.114491
dtype: float64
white
0    0.113276
1    0.115643
dtype: float64


In [56]:
_df = df.copy()

_df['gender'] = _df['male'] | _df['female']

print(_df.groupby('gender').apply(lambda x: sum(x['y_true'])/len(x)))

gender
0    0.113373
1    0.113624
dtype: float64


In [57]:
_df.groupby(['gender','y_true']).size()

gender  y_true
0       0         191257
        1          24456
1       0          47266
        1           6059
dtype: int64

In [60]:
6059/(6059+47266)

0.11362400375058603