# All map comparison

In this notebook, we build a script that generates HTML pages comparing the attention maps of each data instance, given a folder.

## Folder setup

We define the folder layout as follows:

```
<root>
├── ProjectA
│   ├── A_map.json
│   ├── B_map.json
│   ├── C_map.json
│   └── ...
├── ProjectB
└── ...
```

We want to render the different heatmaps of ProjectA into `ProjectA/.html` (the folder used by the generation cell). Each output HTML file is named **<instance_id>.html**.

We assume that the annotation map (column `a_true`) is stored in the same `attention_map` file as the model's attention map.


## Setting up

In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import display, HTML
import sys
import os
from os import path

sys.path.append("./../src")

In [2]:
from modules.logger import init_logging
from modules.logger import log

init_logging(color=True)

In [3]:
!nvidia-smi

Fri Oct 13 14:06:12 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce GTX 108...  On   | 00000000:04:00.0 Off |                  N/A |
| 23%   19C    P8     7W / 250W |      1MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

## Parameters

In [13]:
import platform

# Select the root folder from the name of the machine running the notebook (local or server)
node = platform.node()
log.info(f'Current node: {node}')
ROOT = (
    '/Users/dunguyen/Developer/server_backup/historic/2023-06-05'
    if node == 'MAC-C02D80HRMD6'
    else '/home/dunguyen/RUNS'
)

# ==== Choose dataset ====
DATASET = 'yelp-hat'
# ========================
ROOT += '/qualitative_result'
PROJECT = f'benchmark_explainers_{DATASET}'
MODEL_NAME = 'lstm_attention.run=0_lstm=1'

# One entry per explanation map to compare:
#   file_suffix : suffix of the JSON file that stores the map
#   display     : human-readable name of the map
#   column      : column of that file that holds the per-token weights
MAPS = [
    dict(file_suffix='attention_map', display='Annotation Maps', column='a_true'),
    dict(file_suffix='attention_map', display='Attention Maps', column='a_hat'),
    dict(file_suffix='lime_map', display='LIME Maps', column='a_lime'),
    dict(file_suffix='grad_map', display='Gradient-based Maps', column='a_grad'),
    dict(file_suffix='shap_map', display='SHAP Maps', column='a_shap'),
]

# Attach the full path of each map file: <ROOT>/<PROJECT>/<MODEL_NAME>.<file_suffix>.json
for m in MAPS:
    m['fpath'] = path.join(ROOT, PROJECT, f"{MODEL_NAME}.{m['file_suffix']}.json")

13-10-2023 14:10:31 | [34m    INFO[0m [1m [4m 4287269976.py:<cell line: 5>:5 [0m [34mCurrent node: grele-1.nancy.grid5000.fr[0m


In [14]:
import pandas as pd
import numpy as np

In [15]:
# Clean padding tokens in attention map files
# MAPS[1] is the 'Attention Maps' entry; MAPS[0] ('Annotation Maps') uses the same
# 'attention_map' file suffix, so this one file provides both a_hat and a_true.
df_attention = pd.read_json(MAPS[1]['fpath'])

def clean_padding(row):
    """Strip padded positions from the per-token vectors of one attention-map row.

    Args:
        row: mapping-like record (for instance the ``pandas.Series`` handed over by
            ``DataFrame.apply(..., axis=1)``) with sequences under ``'a_hat'``,
            ``'a_true'``, ``'heuristic'`` and ``'padding_mask'``. Truthy mask
            entries mark the positions to remove; the mask must have the same
            length as the three vectors.

    Returns:
        The same ``row`` object, where ``'a_hat'``, ``'a_true'`` and
        ``'heuristic'`` are replaced by plain lists restricted to the non-padded
        positions. ``'padding_mask'`` and all other entries are left untouched.
    """
    # Force a boolean dtype before negating: on an integer 0/1 mask `~` is a bitwise
    # NOT (giving -1/-2) and the selection silently becomes fancy indexing, while on
    # an empty (float64) array `~` raises a TypeError.
    keep = ~np.asarray(row['padding_mask'], dtype=bool)
    for key in ('a_hat', 'a_true', 'heuristic'):
        row[key] = np.asarray(row[key])[keep].tolist()
    # NOTE: 'tokens.form' is not filtered here (the corresponding line was already
    # disabled), so it is no longer read and does not need to be present.
    return row

# Strip padding once: the cleaned frame is written back without the mask column,
# so this branch is skipped the next time the file is loaded.
if 'padding_mask' in df_attention.columns:
    df_attention = df_attention.apply(clean_padding, axis=1)
    df_attention = df_attention.drop(columns=['padding_mask'])
    df_attention.to_json(MAPS[1]['fpath'])

# Replace label
if 'label_hat' not in df_attention.columns:
    label_itos = dict()
    if DATASET == 'hatexplain':
        from data.hatexplain.dataset import HateXPlain
        label_itos = HateXPlain.LABEL_ITOS
    # Accept the 'yelp-hat' spelling used for DATASET in the parameters cell as well
    # as the former 'yelphat' one; previously the configured value fell through to
    # the ValueError below.
    elif DATASET in ('yelp-hat', 'yelphat'):
        from data.yelp_hat.dataset import YelpHat
        label_itos = YelpHat.LABEL_ITOS
    elif DATASET == 'esnli':
        from data.esnli.dataset import ESNLI
        label_itos = ESNLI.LABEL_ITOS
    else:
        raise ValueError('Dataset not supported')

    # The index -> name mapping through label_itos is currently disabled; the label
    # columns store the raw class indices.
    #df_attention['label_hat'] = df_attention['y_hat'].apply(lambda x: label_itos[x])
    #df_attention['label_true'] = df_attention['y_true'].apply(lambda x: label_itos[x])
    df_attention['label_hat'] = df_attention['y_hat']
    # Fixed: the true label was previously copied from y_hat instead of y_true.
    df_attention['label_true'] = df_attention['y_true']
    df_attention.to_json(MAPS[1]['fpath'])

FileNotFoundError: File /home/dunguyen/RUNS/qualitative_result/benchmark_explainers_yelp-hat/lstm_attention.run=0_lstm=1.attention_map.json does not exist

In [6]:
# eSNLI-specific processing: clean both sides, then merge premise and hypothesis together
def clean_padding_nli(row):
    """Strip padded positions from the premise and hypothesis maps of one row.

    Args:
        row: mapping-like record (for instance the ``pandas.Series`` handed over by
            ``DataFrame.apply(..., axis=1)``) with, for each side in
            ``('premise', 'hypothesis')``, sequences under ``'a_true.<side>'``,
            ``'a_hat.<side>'`` and ``'padding_mask.<side>'``. Truthy mask entries
            mark the positions to remove; each mask must have the same length as
            the vectors of its side.

    Returns:
        The same ``row`` object, where the ``'a_true.<side>'`` and
        ``'a_hat.<side>'`` entries are replaced by plain lists restricted to the
        non-padded positions. The mask entries are left untouched.
    """
    for side in ('premise', 'hypothesis'):
        # Force a boolean dtype before negating: on an integer 0/1 mask `~` is a
        # bitwise NOT and the selection silently becomes fancy indexing, while on an
        # empty (float64) array `~` raises a TypeError.
        keep = ~np.asarray(row['padding_mask.' + side], dtype=bool)
        for prefix in ('a_true.', 'a_hat.'):
            row[prefix + side] = np.asarray(row[prefix + side])[keep].tolist()
    return row

if DATASET == 'esnli':

    # Strip the padded positions on both sides, drop the masks, then persist the result
    if 'padding_mask.premise' in df_attention.columns:
        log.debug('Cleaning padding tokens for eSNLI')
        df_attention = (
            df_attention
            .apply(clean_padding_nli, axis=1)
            .drop(columns=['padding_mask.premise', 'padding_mask.hypothesis'])
        )
        df_attention.to_json(MAPS[1]['fpath'])

    # Rescale both attention maps as soon as one premise map peaks below 1
    max_vector = df_attention['a_hat.premise'].apply(lambda x: max(x))
    if max_vector.lt(1).any():
        from modules.utils import rescale
        log.debug('Normalize attention map for eSNLI')
        for a_hat_col in ('a_hat.premise', 'a_hat.hypothesis'):
            df_attention[a_hat_col] = df_attention[a_hat_col].apply(lambda x: rescale(x).tolist())
        df_attention.to_json(MAPS[1]['fpath'])

    # Build one token sequence per instance: marker, premise tokens, marker, hypothesis tokens
    if 'tokens.form' not in df_attention.columns:
        log.debug('Concat tokens for eSNLI')
        # TODO: change back to tokens.form once this is fixed
        df_attention['tokens.form'] = df_attention.apply(
            lambda row: (
                ['<b>Premise</b>:'] + row['tokens.norm.premise']
                + ['<br/><b>Hypothesis</b>:'] + row['tokens.norm.hypothesis']
            ),
            axis=1,
        )
        df_attention = df_attention.drop(columns=['tokens.norm.premise', 'tokens.norm.hypothesis'])
        df_attention.to_json(MAPS[1]['fpath'])

df_attention

NameError: name 'df_attention' is not defined

In [10]:
# Load every map file and merge the requested columns into a single dataframe indexed by instance id
map_data = None
for m in MAPS:
    column = m['column']

    # load data from json file
    df = pd.read_json(m['fpath']).set_index('id')

    if DATASET == 'esnli':
        # eSNLI stores one vector per side: rescale if needed, then concatenate them
        from modules.utils import rescale
        premise_col = column + '.premise'
        hypothesis_col = column + '.hypothesis'
        if column != 'a_true' and df[premise_col].apply(lambda x: max(x)).ne(1).any():
            for side_col in (premise_col, hypothesis_col):
                df[side_col] = df[side_col].apply(lambda x: rescale(x).tolist())
        # A 0 weight is put in front of each side (presumably matching the two marker
        # tokens placed in front of the premise and hypothesis in 'tokens.form')
        df[column] = df.apply(lambda row: [0] + row[premise_col] + [0] + row[hypothesis_col], axis=1)
        df = df.drop(columns=[premise_col, hypothesis_col])

    elif column != 'a_true' and df[column].apply(lambda x: max(x)).ne(1).any():
        # other datasets: rescale the weights when some row does not peak at exactly 1
        from modules.utils import rescale
        df[column] = df[column].apply(lambda x: rescale(x).tolist())

    # the first file also provides the tokens, labels and predictions
    if map_data is None:
        map_data = df[['tokens.form', 'label_hat', 'label_true', 'y_hat', 'y_true']].copy()

    map_data = map_data.join(df[column])

# keep the instances whose prediction matches the true class and is not class 0
is_correct = map_data['y_hat'] == map_data['y_true']
is_not_class_zero = map_data['y_hat'] != 0
map_data = map_data[is_correct & is_not_class_zero]
map_data

Unnamed: 0_level_0,tokens.form,label_hat,label_true,y_hat,y_true,a_true,a_hat,a_lime,a_grad,a_shap
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
14570598_gab,"[i, have, read, about, this, it, typical, nigg...",hatespeech,hatespeech,1,1,"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...","[0.00024549695081077516, 0.0012934240512549877...","[0.1144649732, 0.032017954200000004, 0.0079438...","[0.3386931121, 0.191345498, 0.1753299385, 0.15...","[0.14126749530000002, 0.07320565350000001, 0.0..."
22341782_gab,"[was, <unk>, a, kike]",hatespeech,hatespeech,1,1,"[0, 0, 0, 1]","[0.0, 0.017309796065092087, 0.0190470293164253...","[0.0, 0.0, 0.0, 1.0]","[0.0210656822, 0.0, 0.0831173882, 1.0]","[0.0, 0.053090922900000004, 0.0232919261000000..."
1122915768600072193_twitter,"[<user>, they, just, some, hating, hoes, tho, ...",offensive,offensive,2,2,"[0, 0, 0, 0, 0, 1, 0, 0]","[0.16280816495418549, 0.00742510287091136, 0.0...","[0.0, 0.108021086, 0.1053259277, 0.0650216893,...","[0.1504608393, 0.0916632861, 0.1263468713, 0.2...","[0.12905111730000002, 0.1060431673, 0.09427566..."
16544694_gab,"[indeed, <unk>, football, is, as, corrupt, of,...",offensive,offensive,2,2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.013485023751854897, 0.07245766371488571, 0....","[0.36321024360000004, 0.0, 0.0410238744, 0.126...","[0.5256419182000001, 0.3982704282, 0.264808476...","[0.1801965483, 0.039558463700000004, 0.0227379..."
12976797_gab,"[we, need, common, sense, nigger, control]",hatespeech,hatespeech,1,1,"[0, 0, 0, 0, 1, 0]","[0.0, 0.005834614392369986, 0.0070640081539750...","[0.0291292242, 0.0210062664, 0.0, 0.0, 1.0, 0....","[0.0720760673, 0.0, 0.024250973000000002, 0.12...","[0.0636289999, 0.0743818118, 0.030693171600000..."
...,...,...,...,...,...,...,...,...,...,...
2456779_gab,"[<user>, you, can, take, the, family, out, of,...",offensive,offensive,2,2,"[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0.2856060862541199, 0.04960240051150322, 0.02...","[0.0, 0.1632703204, 0.0485266852, 0.0403184933...","[0.1890623122, 0.0710416287, 0.0231457576, 0.0...","[0.0595646216, 0.2216111359, 0.1637617995, 0.1..."
1123227111810945024_twitter,"[lmao, gay, haircut]",offensive,offensive,2,2,"[0, 1, 0]","[1.0, 0.0, 0.99031662940979]","[0.1784834981, 0.0, 1.0]","[1.0, 0.4529289305, 0.0]","[0.40109014460000003, 0.0, 1.0]"
16921321_gab,"[lol, good, nigger, <user>]",hatespeech,hatespeech,1,1,"[0, 0, 1, 0]","[0.0, 0.008781231939792633, 1.0, 0.23454540967...","[0.0164801053, 0.0, 1.0, 0.0]","[0.08586186920000001, 0.0830580294, 1.0, 0.0]","[0.05753767, 0.08117362830000001, 1.0, 0.0]"
17684585_gab,"[she, is, tired, of, liberal, faggots, <unk>, ...",offensive,offensive,2,2,"[0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.035775136202573776, 0.028469529002904892, 0...","[0.54890555, 0.2665352696, 0.1566093943, 0.024...","[0.3942760229, 0.2884926796, 0.3066299856, 0.3...","[0.3516586936, 0.2595858796, 0.1582157686, 0.1..."


In [11]:
def steep_sigmoid(x, s=10, p=2):
    """Pass weights through a sigmoid centred on 0.5 and raised to the power ``p``.

    Evaluates ``1 / (1 + exp(-s * (x - 0.5))) ** p`` element-wise.

    Args:
        x: a number or an array-like of numbers.
        s: multiplier applied to ``x - 0.5`` inside the exponential.
        p: exponent applied to the sigmoid denominator.

    Returns:
        The result converted with ``tolist()``: a list for a sequence input,
        a plain float for a scalar input.
    """
    values = np.array(x)
    denominator = (1 + np.exp(-s * (values - 0.5))) ** p
    return (1 / denominator).tolist()

In [12]:
from tqdm.notebook import tqdm
from modules.utils import highlight
import shutil

# Remove the folder left by a previous run so that stale pages do not linger
html_dir = path.join(ROOT, PROJECT, '.html')
if os.path.isdir(html_dir):
    log.info(f'Removing existing folder {html_dir}')
    shutil.rmtree(html_dir)

# Generate each comparison into a file: one HTML page per instance, named <instance id>.html
for idx, row in tqdm(map_data.iterrows(), total=len(map_data)):
    
    # ignore if label is 0
    if row['y_true'] == 0: continue
    
    # ignore if row contains any NaN (e.g. a map missing for this instance after the join)
    if row.isnull().sum() > 0: continue
    
    # The charset is declared so that it matches the UTF-8 encoding used to write the file
    html = """
    <html>
    <head><meta charset="utf-8"><style>
    table, th, td {
      border:solid black;
      border-collapse: collapse;
      padding: 0px 5px 0px 5px;
    }</style></head>
    <body>
    """
    html += '<table style="font-size:120%;" cellspacing=0>'
    html += f'<caption>Dataset: {DATASET} - Instance ID: {idx}</caption>'
    html += '<tr><th style="width:100px;">Explainer</th> <th style="width:500px;">Explanation</th> <th style="width:100px;">Predicted label</th> <th style="width:100px;">True label</th></tr>'
    
    # Display a row for each map
    for m in MAPS:
        html += '<tr>'
        
        # Display the explainer and its explanation
        c = m['column']
        map_name = m['display']
        # TODO check what if we change the value in gradient map:
        #if c == 'a_grad':
        #    row[c] = steep_sigmoid(row[c], s=5, p=2)
        # NOTE(review): highlight() is a project helper; it presumably returns the
        # HTML fragment colouring each token by its weight - confirm in modules.utils
        map_viz = highlight(row['tokens.form'], row[c], normalize_weight=False)
        html+= f'<td style="text-align:right;"> {map_name} </td><td> {map_viz} </td>'
        
        # For the first row, display the labels in cells spanning all the map rows
        if c == 'a_true':
            row_span = len(MAPS)
            html +=f'<td rowspan="{row_span}" style="text-align:center"> {row["label_hat"]} </td>'
            html +=f'<td rowspan="{row_span}" style="text-align:center"> {row["label_true"]} </td>'
            
        html += '</tr>\n'
        
    html += '</table>'
    html += '</body></html>'

    fpath_html = path.join(html_dir, f'{idx}.html')
    os.makedirs(html_dir, exist_ok=True)
    # Explicit encoding: the platform default may be unable to encode non-ASCII tokens
    with open(fpath_html, 'w', encoding='utf-8') as f:
        f.write(html)

  0%|          | 0/749 [00:00<?, ?it/s]

# Modify dataset columns

In [7]:
import pandas as pd

# fname = './../.cache/dataset/esnli/test.pretransformed.parquet'

# Pre-transformed e-SNLI test split. Later cells add token columns to `df` and write
# the frame back to this same `fname`.
fname = './../../RUNS/dataset/esnli/test.pretransformed.parquet'
df = pd.read_parquet(fname)
df.head(10)

Unnamed: 0,id,premise,hypothesis,label,explanation,highlight_premise,highlight_hypothesis,tokens.norm.premise,tokens.norm.hypothesis,rationale.premise,rationale.hypothesis,heuristic.premise,heuristic.hypothesis
0,2677109430.jpg#1r1n,This church choir sings to the masses as they ...,The church has cracks in the ceiling.,neutral,Not all churches have cracks in the ceiling,This church choir sings to the masses as they ...,The church has *cracks* *in* *the* *ceiling.*,"[this, church, choir, sing, to, the, masse, as...","[the, church, have, crack, in, the, ceiling, .]","[False, False, False, False, False, False, Fal...","[False, False, False, True, True, True, True, ...","[-1.0000000150474662e+30, 3.064525842666626, 1...","[-1.0000000150474662e+30, 7.628961086273193, -..."
1,2677109430.jpg#1r1e,This church choir sings to the masses as they ...,The church is filled with song.,entailment,"""Filled with song"" is a rephrasing of the ""cho...",This church *choir* *sings* *to* *the* *masses...,The church is *filled* *with* *song.*,"[this, church, choir, sing, to, the, masse, as...","[the, church, be, fill, with, song, .]","[False, False, True, True, True, True, True, F...","[False, False, False, True, True, True, False]","[-1.0000000150474662e+30, 2.79181170463562, 2....","[-1.0000000150474662e+30, 7.628961086273193, -..."
2,2677109430.jpg#1r1c,This church choir sings to the masses as they ...,A choir singing at a baseball game.,contradiction,A choir sing some other songs other than book ...,This church choir sings to the *masses* as the...,A choir *singing* at a *baseball* *game.*,"[this, church, choir, sing, to, the, masse, as...","[a, choir, singing, at, a, baseball, game, .]","[False, False, False, False, False, False, Tru...","[False, False, True, False, False, True, True,...","[-1.0000000150474662e+30, 2.5598974227905273, ...","[-1.0000000150474662e+30, 6.388305187225342, 6..."
3,6160193920.jpg#4r1n,"A woman with a green headscarf, blue shirt and...",The woman is young.,neutral,the woman could've been old rather than young,"A woman with a green headscarf, blue shirt and...",The woman is *young.*,"[a, woman, with, a, green, headscarf, ,, blue,...","[the, woman, be, young, .]","[False, False, False, False, False, False, Fal...","[False, False, False, True, False]","[-1.0000000150474662e+30, 2.597653388977051, -...","[-1.0000000150474662e+30, 5.648240089416504, -..."
4,6160193920.jpg#4r1e,"A woman with a green headscarf, blue shirt and...",The woman is very happy.,entailment,a grin suggests hapiness.,"A woman with a green headscarf, blue shirt and...",The woman is very *happy.*,"[a, woman, with, a, green, headscarf, ,, blue,...","[the, woman, be, very, happy, .]","[False, False, False, False, False, False, Fal...","[False, False, False, False, True, False]","[-1.0000000150474662e+30, 2.784580707550049, -...","[-1.0000000150474662e+30, 5.648240089416504, -..."
5,6160193920.jpg#4r1c,"A woman with a green headscarf, blue shirt and...",The woman has been shot.,contradiction,There can be either a woman with a very big gr...,"A woman with a *green* headscarf, blue shirt a...",The woman has been *shot.*,"[a, woman, with, a, green, headscarf, ,, blue,...","[the, woman, have, be, shoot, .]","[False, False, False, False, True, False, Fals...","[False, False, False, False, True, False]","[-1.0000000150474662e+30, 2.6564526557922363, ...","[-1.0000000150474662e+30, 5.648240089416504, -..."
6,4791890474.jpg#3r1e,An old man with a package poses in front of an...,A man poses in front of an ad.,entailment,"The word "" ad "" is short for the word "" advert...",An old man with a package poses in front of an...,A man poses in front of an *ad.*,"[an, old, man, with, a, package, pose, in, fro...","[a, man, pose, in, front, of, an, ad, .]","[False, False, False, False, False, False, Fal...","[False, False, False, False, False, False, Fal...","[-1.0000000150474662e+30, 2.9205048084259033, ...","[-1.0000000150474662e+30, 5.1345367431640625, ..."
7,4791890474.jpg#3r1n,An old man with a package poses in front of an...,A man poses in front of an ad for beer.,neutral,Not all advertisements are ad for beer.,An old man with a package poses in front of an...,A man poses in front of an ad for *beer.*,"[an, old, man, with, a, package, pose, in, fro...","[a, man, pose, in, front, of, an, ad, for, bee...","[False, False, False, False, False, False, Fal...","[False, False, False, False, False, False, Fal...","[-1.0000000150474662e+30, 3.5127861499786377, ...","[-1.0000000150474662e+30, 5.1345367431640625, ..."
8,4791890474.jpg#3r1c,An old man with a package poses in front of an...,A man walks by an ad.,contradiction,The man poses in front of the advertisement th...,An old *man* with a package *poses* *in* *fron...,A man *walks* *by* an *ad.*,"[an, old, man, with, a, package, pose, in, fro...","[a, man, walk, by, an, ad, .]","[False, False, True, False, False, False, True...","[False, False, True, True, False, True, False]","[-1.0000000150474662e+30, 2.2357261180877686, ...","[-1.0000000150474662e+30, 5.1345367431640625, ..."
9,6526219567.jpg#4r1n,A statue at a museum that no seems to be looki...,The statue is offensive and people are mad tha...,neutral,Not all statues are ignored because they are o...,A statue at a museum that no seems to be looki...,The statue is *offensive* and people are mad t...,"[a, statue, at, a, museum, that, no, seem, to,...","[the, statue, be, offensive, and, people, be, ...","[False, False, False, False, False, False, Fal...","[False, False, False, True, False, False, Fals...","[-1.0000000150474662e+30, 3.6017332077026367, ...","[-1.0000000150474662e+30, 3.75215744972229, -1..."


In [8]:
# List the columns currently available in `df`
df.columns

Index(['id', 'premise', 'hypothesis', 'label', 'explanation',
       'highlight_premise', 'highlight_hypothesis', 'tokens.norm.premise',
       'tokens.norm.hypothesis', 'rationale.premise', 'rationale.hypothesis',
       'heuristic.premise', 'heuristic.hypothesis'],
      dtype='object')

In [9]:
from data.transforms import SpacyTokenizerTransform

import spacy
# NOTE(review): SpacyTokenizerTransform is a project helper not shown here; it is
# presumably a tokenizer built on the loaded spaCy pipeline - confirm in data/transforms.
spacy_model = spacy.load('en_core_web_sm')
transform = SpacyTokenizerTransform(spacy_model)

# Store the transform output for the raw premise text in a new 'tokens.form.premise' column
df['tokens.form.premise'] = transform(df['premise'])

In [10]:
# Same transform applied to the raw hypothesis text, stored in 'tokens.form.hypothesis'
df['tokens.form.hypothesis'] = transform(df['hypothesis'])

In [11]:
# Write `df` (with the added columns) back to `fname`, replacing the file it was read from;
# index=False leaves the row index out of the file.
df.to_parquet(fname, index=False)

In [53]:
# NOTE(review): repeats the SpacyTokenizerTransform setup already done in an earlier cell
from data.transforms import SpacyTokenizerTransform

import spacy
spacy_model = spacy.load('en_core_web_sm')
transform = SpacyTokenizerTransform(spacy_model)

In [54]:
# NOTE(review): requires a `df` with a 'text' column. The e-SNLI frame shown earlier has
# none, while the yelp-hat frame has one, so this cell depends on which cell last
# assigned `df` - confirm the intended execution order.
df['tokens.form'] = transform(df['text'].tolist())

In [72]:
# Tokenize both sides from plain Python lists (the earlier cells passed the Series directly)
premise = df['premise'].tolist()
hypothesis = df['hypothesis'].tolist()
premise_toks = transform(premise)
hypothesis_toks = transform(hypothesis)

# Store the results as the 'tokens.form.<side>' columns
df['tokens.form.premise'] = premise_toks
df['tokens.form.hypothesis'] = hypothesis_toks

In [74]:
# Display the path currently held in `fname`
fname

'./../.cache/dataset/esnli/test.pretransformed.parquet'

In [7]:
import pandas as pd

# fname = './../.cache/dataset/esnli/test.pretransformed.parquet'

# Switch `fname` / `df` to the yelp-hat (yelp50) pre-tokenized parquet file.
# This rebinds the same `df` and `fname` names used by the e-SNLI cells.
fname = './../../RUNS/dataset_/yelp-hat/yelp50.pretokenized_lower_lemma.parquet'
df = pd.read_parquet(fname)
df

Unnamed: 0,text,label,ham_html_0,human_label_0,ham_html_1,human_label_1,ham_html_2,human_label_2,id,ham_0,ham_1,ham_2,tokens.norm,tokens.form,ham,cam,sam,heuristic
0,Out in Twinsburg for work and wasn't expecting...,1,<span>Out</span> <span>in</span> <span>Twinsbu...,yes,<span>Out</span> <span>in</span> <span>Twinsbu...,yes,<span>Out</span> <span>in</span> <span>Twinsbu...,yes,ham_part1(50words)_1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[out, in, twinsburg, for, work, and, be, not, ...","[Out, in, Twinsburg, for, work, and, was, n't,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.00042654533821490776, 0..."
1,Very slow. Never been in the drive at any othe...,0,"<span class=""active"">Very</span> <span class=""...",no,"<span>Very</span> <span class=""active"">slow.</...",no,"<span>Very</span> <span class=""active"">slow.</...",no,ham_part1(50words)_2,"[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[very, slow, ., never, be, in, the, drive, at,...","[Very, slow, ., Never, been, in, the, drive, a...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.0, 0.0023104539153307505, 0.0, 0.0, 0.0, 0...."
2,"Food is good, but service terrible. They have ...",0,<span>Food</span> <span>is</span> <span class=...,idk,"<span>Food</span> <span>is</span> <span>good,<...",no,"<span>Food</span> <span>is</span> <span>good,<...",no,ham_part1(50words)_3,"[0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[food, be, good, ,, but, service, terrible, .,...","[Food, is, good, ,, but, service, terrible, .,...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, ...","[0.023922084384886078, 0.0, 0.0225002665908363..."
3,Stopped by on a Sunday around 11am after a tri...,1,<span>Stopped</span> <span>by</span> <span>on<...,yes,<span>Stopped</span> <span>by</span> <span>on<...,yes,<span>Stopped</span> <span>by</span> <span>on<...,yes,ham_part1(50words)_4,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[stop, by, on, a, sunday, around, 11, am, afte...","[Stopped, by, on, a, Sunday, around, 11, am, a...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.00039099989336366543, 0.0, 0.0, 0.0, 0.0, 0..."
4,This place is horrible. They are very stingy w...,0,<span>This</span> <span>place</span> <span>is<...,no,<span>This</span> <span>place</span> <span>is<...,no,<span>This</span> <span>place</span> <span>is<...,no,ham_part1(50words)_5,"[0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, ...","[this, place, be, horrible, ., they, be, very,...","[This, place, is, horrible, ., They, are, very...","[0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, ...","[0.0, 0.012405360253083567, 0.0, 0.00312799914..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,Service and staff were very good. Topping sele...,1,"<span class=""active"">Service</span> <span>and<...",yes,<span>Service</span> <span>and</span> <span>st...,yes,<span>Service</span> <span>and</span> <span>st...,yes,ham_part1(50words)_296,"[1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, ...","[0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, ...","[service, and, staff, be, very, good, ., toppi...","[Service, and, staff, were, very, good, ., Top...","[0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, ...","[1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, ...","[0.019869903671844453, 0.0, 0.0042654533821490..."
296,Love it! Eaten here over 300 times in the last...,1,"<span class=""active"">Love</span> <span>it!</sp...",yes,"<span class=""active"">Love</span> <span class=""...",yes,"<span class=""active"">Love</span> <span class=""...",yes,ham_part1(50words)_297,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, ...","[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, ...","[love, it, !, eat, here, over, 300, time, in, ...","[Love, it, !, Eaten, here, over, 300, times, i...","[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, ...","[0.008495361319446913, 0.0, 0.0, 0.00252372658..."
297,"According to my friend, this local bar type pl...",1,<span>According</span> <span>to</span> <span>m...,yes,<span>According</span> <span>to</span> <span>m...,yes,<span>According</span> <span>to</span> <span>m...,yes,ham_part1(50words)_298,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, ...","[accord, to, my, friend, ,, this, local, bar, ...","[According, to, my, friend, ,, this, local, ba...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, ...","[0.0, 0.0, 0.0, 0.0003554544485124231, 0.0, 0...."
298,I went here to get a snack before I went on th...,0,<span>I</span> <span>went</span> <span>here</s...,no,<span>I</span> <span>went</span> <span>here</s...,no,<span>I</span> <span>went</span> <span>here</s...,no,ham_part1(50words)_299,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[i, go, here, to, get, a, snack, before, i, go...","[I, went, here, to, get, a, snack, before, I, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.5545444851242..."


In [None]:
# Folder passed as `cache_path` to the data module created in the next cell
DATA_CACHE = '/home/dunguyen/RUNS/dataset_'

In [28]:
from data_module.yelp_hat_module import YelpHat50DM
# NOTE(review): YelpHat50DM is a project class not shown here. The prepare_data()/setup()
# calls presumably follow the usual data-module convention (download/format, then build
# the splits) - confirm in data_module/yelp_hat_module.
yelphat_dm = YelpHat50DM(cache_path=DATA_CACHE, batch_size=16)
yelphat_dm.prepare_data()
yelphat_dm.setup()

04-10-2023 15:52:07 | [32;1m   DEBUG[0m [1m [4m dataset.py:download_format_dataset:82 [0m [32;1mCorrectly handle part7.csv[0m
04-10-2023 15:52:08 | [34m    INFO[0m [1m [4m dataset.py:download_format_dataset:110 [0m [34mSave yelp subset at: /home/dunguyen/RUNS/dataset_/yelp-hat/yelp200.parquet[0m
04-10-2023 15:52:08 | [34m    INFO[0m [1m [4m dataset.py:download_format_dataset:110 [0m [34mSave yelp subset at: /home/dunguyen/RUNS/dataset_/yelp-hat/yelp50.parquet[0m
04-10-2023 15:52:08 | [34m    INFO[0m [1m [4m dataset.py:download_format_dataset:110 [0m [34mSave yelp subset at: /home/dunguyen/RUNS/dataset_/yelp-hat/yelp100.parquet[0m
04-10-2023 15:52:08 | [34m    INFO[0m [1m [4m dataset.py:download_format_dataset:116 [0m [34mSave clean dataset at /home/dunguyen/RUNS/dataset_/yelp-hat/yelp.parquet[0m
04-10-2023 15:52:08 | [34m    INFO[0m [1m [4m dataset.py:download_format_dataset:123 [0m [34mSave training set at /home/dunguyen/RUNS/dataset_/yelp-hat/