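# OTDD Dataset Comparison

Compares two annotated book-club chat datasets with the Optimal Transport Dataset Distance (OTDD). Each message is embedded with the `all-MiniLM-L6-v2` sentence-transformer, and the `discussion_type` and `uptake` annotations are used as class labels, one at a time.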

### Imports

In [1]:
import pandas as pd
import numpy as np
from word2number import w2n

from otdd.pytorch.distance import DatasetDistance
from sentence_transformers import SentenceTransformer

import torch
from torch.utils.data import DataLoader, TensorDataset

import warnings
warnings.filterwarnings('ignore')

  from tqdm.autonotebook import tqdm


                  variable OMP_PATH to the location of the header before importing keopscore or pykeops,
                  e.g. using os.environ: import os; os.environ['OMP_PATH'] = '/path/to/omp/header'


ot.gpu not found - coupling computation will be in cpu


### Preprocessing

In [2]:
# Raw exports of the two annotated chat datasets that will be compared.
data1, data2 = (
    pd.read_csv(csv_path)
    for csv_path in ('./data/cleaned_data.csv', './data/clean_data2.csv')
)

In [3]:
# Rename data1's annotation columns to the snake_case names that data2 already
# uses (dialogic_spell, discussion_type, pivot, question, uptake).
columns_to_rename = {
    'Dialogic Spell': 'dialogic_spell',
    'Discussion': 'discussion_type',
    'Pivot': 'pivot',
    'Question': 'question',
    'Uptake': 'uptake',
}
data1 = data1.rename(columns=columns_to_rename)

In [4]:
# Column names of data1, alphabetised for easy comparison with data2.
sorted(data1.columns.tolist())

['book_id',
 'bookclub',
 'chat_crew',
 'course',
 'dialogic_spell',
 'discussion_type',
 'is_answer',
 'message',
 'page',
 'pivot',
 'pseudonym',
 'question',
 'response_number',
 'time',
 'topic',
 'uptake']

In [5]:
# Column names of data2, alphabetised for easy comparison with data1.
sorted(data2.columns.tolist())

['book_id',
 'bookclub',
 'dialogic_spell',
 'discussion_type',
 'is_answer',
 'message',
 'page',
 'pivot',
 'pseudonym',
 'question',
 'time',
 'topic',
 'uptake']

In [6]:
# Columns that exist only in data1 have no counterpart in data2 and are dropped.
# Index.difference returns a sorted list. Iterating a set of strings would give
# an order that changes between runs because of hash randomisation.
columns_to_remove = data1.columns.difference(data2.columns).tolist()
print('Columns will be removed: {}'.format(columns_to_remove))

Columns will be removed: ['course', 'chat_crew', 'response_number']


In [7]:
# Drop the data1-only columns so both frames share one schema.
data1 = data1.drop(columns=columns_to_remove)

In [8]:
# Both frames must expose exactly the same set of columns before data2 is
# re-ordered to data1's column order. Check both directions, not just data1 ⊆ data2.
assert set(data1.columns) == set(data2.columns), (
    'column mismatch: '
    f'only in data1={sorted(set(data1.columns) - set(data2.columns))}, '
    f'only in data2={sorted(set(data2.columns) - set(data1.columns))}'
)

In [9]:
# Preview the first rows of data1.
data1.head(n=5)

Unnamed: 0,book_id,topic,bookclub,pseudonym,message,time,is_answer,page,discussion_type,dialogic_spell,uptake,question,pivot
0,260,Part 1: What happens next? What is behind the ...,1,pim-01,Hello.,2020-10-20 17:06:00,False,10.0,Social,1.0,,,
1,260,Part 1: What happens next? What is behind the ...,1,pim-01,My assumption is,2020-10-20 17:06:00,False,10.0,Seminar,1.0,,,from Social/Procedure/UX to Seminar
2,260,Part 1: What happens next? What is behind the ...,1,pim-01,that the emphasis on barbarism implies that sh...,2020-10-20 17:06:00,False,10.0,Seminar,1.0,,,
3,260,Part 1: What happens next? What is behind the ...,1,pim-03,I agree with Cassandra's noticing,2020-10-27 17:58:00,False,10.0,Seminar,1.0,Affirm,,
4,260,Part 1: What happens next? What is behind the ...,1,pim-03,of the author's word choice of barbarism.,2020-10-27 17:58:00,False,10.0,Seminar,1.0,Affirm,,


In [10]:
# Preview the first rows of data2.
data2.head(n=5)

Unnamed: 0,book_id,topic,bookclub,message,time,is_answer,page,question,pivot,dialogic_spell,discussion_type,uptake,pseudonym
0,306,"Using the chat discussion to the right, discus...",Book Club One,hello,2022-03-01 14:41:05,No,8.0,,,,Social,Filler,430.0 (Ava)
1,306,"Using the chat discussion to the right, discus...",Book Club One,yoooo wasssupppp,2022-03-01 14:41:21,No,8.0,,,,Social,Filler,407.0 (Samiran)
2,306,"Using the chat discussion to the right, discus...",Book Club One,hola,2022-03-01 14:41:42,No,6.0,,,,Social,Filler,416.0 (Nicholas)
3,306,"Using the chat discussion to the right, discus...",Book Club One,yoooo wasssupppp yooo,2022-03-01 14:42:04,No,8.0,,,,Social,Filler,407.0 (Samiran)
4,306,"Using the chat discussion to the right, discus...",Book Club One,so lets start out,2022-03-01 14:42:54,No,8.0,,,,Deliberation,Prompt,430.0 (Ava)


In [11]:
# Messages per book in data1.
data1['book_id'].value_counts()

book_id
260    427
306    288
261    190
Name: count, dtype: int64

In [12]:
# Messages per book in data2.
data2['book_id'].value_counts()

book_id
306                                        288
'transport make 10% of emissions' pg 10      1
where was this quote found?                  1
Name: count, dtype: int64

In [13]:
# Book 306 appears in both files (288 rows each in the value counts), so it
# is kept only in data2 and the two datasets cover different books.
data1 = data1.loc[data1['book_id'].ne(306)]

In [14]:
# data2.book_id was read as strings because it contains stray free-text
# values, so compare against the string '306' rather than the integer.
data2 = data2.loc[data2['book_id'].eq('306')]

In [15]:
# Keep only rows whose bookclub label starts with 'Book' (e.g. 'Book Club One')
# and map the trailing number word to an int: 'Book Club One' -> 1, etc.
# The new column is built with .loc/.assign on a fresh frame. Assigning a column
# into a filtered slice triggers SettingWithCopyWarning, which the global
# warnings filter would otherwise hide.
is_book_club = data2['bookclub'].fillna('').astype(str).str.startswith('Book')
data2 = data2.loc[is_book_club].assign(
    bookclub=lambda df: (
        df['bookclub'].astype(str).str.split(' ').str[-1].map(w2n.word_to_num)
    )
)
data2['bookclub'].value_counts()

bookclub
1    82
7    50
3    48
2    44
6    38
5    26
Name: count, dtype: int64

In [16]:
# Missing (NaN) or whitespace-only is_answer cells become True; any other text
# (e.g. 'No') becomes False.
# NOTE(review): a 'Yes' value would also map to False here. Confirm the raw
# column only contains 'No' or blanks.
is_blank = data2['is_answer'].str.isspace()
data2['is_answer'] = data2['is_answer'].isna() | is_blank

In [17]:
# Only the string '306' remains after the book filter, so the cast to int is safe
# and makes the dtype match data1.book_id.
data2['book_id'] = data2['book_id'].astype(int)

In [18]:
# Re-order data2's columns to match data1's column order.
columns = data1.columns.tolist()
data2 = data2.loc[:, columns]

In [19]:
# DataFrame.info() prints its summary itself and returns None. Call it directly,
# because wrapping it in display() only adds a stray "None" to the output.
data1.info()
print('-' * 50)
data2.info()

<class 'pandas.core.frame.DataFrame'>
Index: 617 entries, 0 to 616
Data columns (total 13 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   book_id          617 non-null    int64  
 1   topic            609 non-null    object 
 2   bookclub         617 non-null    int64  
 3   pseudonym        617 non-null    object 
 4   message          617 non-null    object 
 5   time             617 non-null    object 
 6   is_answer        617 non-null    bool   
 7   page             515 non-null    float64
 8   discussion_type  617 non-null    object 
 9   dialogic_spell   467 non-null    float64
 10  uptake           374 non-null    object 
 11  question         84 non-null     object 
 12  pivot            47 non-null     object 
dtypes: bool(1), float64(2), int64(2), object(8)
memory usage: 63.3+ KB


None

--------------------------------------------------
<class 'pandas.core.frame.DataFrame'>
Index: 288 entries, 0 to 289
Data columns (total 13 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   book_id          288 non-null    int64  
 1   topic            288 non-null    object 
 2   bookclub         288 non-null    int64  
 3   pseudonym        288 non-null    object 
 4   message          288 non-null    object 
 5   time             287 non-null    object 
 6   is_answer        288 non-null    bool   
 7   page             287 non-null    float64
 8   discussion_type  288 non-null    object 
 9   dialogic_spell   176 non-null    float64
 10  uptake           279 non-null    object 
 11  question         0 non-null      float64
 12  pivot            0 non-null      float64
dtypes: bool(1), float64(4), int64(2), object(6)
memory usage: 29.5+ KB


None

In [20]:
# Annotation columns used, one at a time, as the OTDD class labels.
target_labels = [
    'discussion_type',
    'uptake',
]

In [24]:
# Cast the label columns to str so every value, including missing ones, is a
# plain class name.
# NOTE(review): astype(str) turns NaN into the literal 'nan', which then becomes
# its own class. data1.info() reports only 374 of 617 uptake values as
# non-null, so this class is large. Confirm this is intended.
for frame in (data1, data2):
    for label in target_labels:
        frame[label] = frame[label].astype(str)

### OTDD

In [25]:
# sentence-transformers checkpoint used to embed each message.
model_path = "all-MiniLM-L6-v2"

In [26]:
# Load (or download on first use) the sentence-embedding model.
model = SentenceTransformer(model_name_or_path=model_path)

In [None]:
# Embed every message of each dataset. create_data_loader assumes row i of each
# embedding matrix corresponds to row i of the matching DataFrame.
emb1, emb2 = (model.encode(frame.message.tolist()) for frame in (data1, data2))

In [32]:
def create_data_loader(df, emb, target_label, batch_size=32, shuffle=True):
    """Build a DataLoader of (embedding, class-index) pairs for OTDD.

    Each distinct value of ``df[target_label]`` is assigned an integer class
    index in order of first appearance. The mapping is built per call, so the
    same label string may get different indices for two different frames.

    Args:
        df: DataFrame with one row per sample; only ``target_label`` is read.
        emb: Array-like of embeddings with one row per row of ``df``, in the
            same order (anything ``torch.tensor`` accepts).
        target_label: Column of ``df`` holding the class labels.
        batch_size: Mini-batch size of the returned loader.
        shuffle: Whether the loader reshuffles samples each epoch (default True).

    Returns:
        A ``DataLoader`` over ``TensorDataset(embeddings, labels)``, where
        ``labels`` is a 1-D ``torch.long`` tensor of class indices.

    Raises:
        ValueError: if ``emb`` and ``df`` have a different number of rows.
    """
    # Class index per distinct label, in order of first appearance.
    classes = df[target_label].unique()
    class_to_idx = {c: i for i, c in enumerate(classes)}
    # Explicit long dtype: torch.tensor([]) would otherwise default to float32.
    label_tensor = torch.tensor(
        [class_to_idx[c] for c in df[target_label].tolist()], dtype=torch.long
    )

    embeddings = torch.tensor(emb)
    if len(embeddings) != len(label_tensor):
        raise ValueError(
            f'emb has {len(embeddings)} rows but df has {len(label_tensor)}; '
            'they must be aligned row by row'
        )

    dataset = TensorDataset(embeddings, label_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [None]:
# Compute the OTDD distance between data1 and data2 once per label column.
# Successful results are kept in `distances` (label -> distance).
distances = {}
for label in target_labels:
    d1 = create_data_loader(data1, emb1, label)
    d2 = create_data_loader(data2, emb2, label)

    try:
        dist = DatasetDistance(d1, d2,
                               inner_ot_method='means_only',
                               debiased_loss=True,
                               p=2, entreg=1e-1,
                               device='cpu')
        # maxsamples presumably caps the samples drawn per dataset (see the OTDD
        # docs); both datasets here have fewer than 1000 rows.
        distances[label] = dist.distance(maxsamples=10000)
    except Exception as e:
        # Skip this label. Printing a distance after a failure would raise a
        # NameError on the first label or reuse a stale `dist` from the
        # previous one.
        print(f'Error computing distance for label {label}:', e)
        continue

    print(f'Distance for label {label}: {distances[label]}')

- Distance for label discussion_type: 0.5095193386077881
- Distance for label uptake: 0.5094767212867737