# Information Retrieval for Documents Using PLDA
The main procedure is adapted from the PLDA MNIST demo [here](https://github.com/RaviSoji/plda/blob/master/mnist_demo/mnist_demo.ipynb).

**Note: This method performed poorly, so we switched to other approaches, which you can find in `DR.TEIT.ipynb` [here](https://github.com/sharif-multidoc2dial/Docalog-2022).**

In [None]:
!pip install https://github.com/sadrasabouri/plda/tarball/master

Collecting https://github.com/sadrasabouri/plda/tarball/master
  Downloading https://github.com/sadrasabouri/plda/tarball/master
Building wheels for collected packages: plda
  Building wheel for plda (setup.py) ... done
  Created wheel for plda: filename=plda-0.1.0-py3-none-any.whl size=13655 sha256=300c7c59b9f37d195280649631839567e2de9b92e9d4e127f74abcc832537d22
  Stored in directory: /tmp/pip-ephem-wheel-cache-91u08quk/wheels/17/57/2b/069666589a33ecf03d21ecebc97313b9fa09b7913577b61dd4
Successfully built plda
Installing collected packages: plda
Successfully installed plda-0.1.0


## Dataset
### Dataset Description

- **multidoc2dial_doc.json** contains the documents, indexed by the keys `domain` and `doc_id`. Each document instance includes the following:

  - `doc_id`: the ID of a document;
  - `title`: the title of the document;
  - `domain`: the domain of the document;
  - `doc_text`: the text content of the document (without HTML markups);
  - `doc_html_ts`: the document content with HTML markups and the annotated spans that are indicated by `text_id` attribute, which corresponds to `id_sp`.
  - `doc_html_raw`: the document content with HTML markups and without span annotations.
  - `spans`: key-value pairs of all spans in the document, with `id_sp` as key. Each span includes the following,
    - `id_sp`: the id of a span as noted by `text_id` in `doc_html_ts`;
    - `start_sp` / `end_sp`: the start/end position of the text span in `doc_text`;
    - `text_sp`: the text content of the span.
    - `id_sec`: the id of the (sub)section (e.g. `<p>`) or title (`<h2>`) that contains the span.
    - `start_sec` / `end_sec`: the start/end position of the (sub)section in `doc_text`.
    - `text_sec`: the text of the (sub)section.
    - `title`: the title of the (sub)section.
    - `parent_titles`: the parent titles of the `title`.

- **multidoc2dial_dial_train.json** and **multidoc2dial_dial_validation.json** contain the training and validation (dev) splits of the dialogue data, indexed by the key `domain`. Please note: **for the test split, only a dummy file is included in this version.**

  Each dialogue instance includes the following,

  - `dial_id`: the ID of a dialogue;
  - `turns`: a list of dialogue turns. Each turn includes,
    - `turn_id`: the time order of the turn;
    - `role`: either "agent" or "user";
    - `da`: dialogue act;
    - `references`: a list of spans with `id_sp`, `label` and `doc_id`. `references` is empty if the turn indicates that the previous user query is unanswerable or irrelevant to the document. **Note** that the labels "*precondition*"/"*solution*" are fuzzy annotations that indicate whether a span describes a conditional context or a solution.
    - `utterance`: the human-generated utterance based on the dialogue scene.
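
The span offsets can be checked against `doc_text`. The sketch below is a minimal illustration on a hypothetical toy document (`span_text` and `check_spans` are illustrative helpers, not part of the dataset tooling). It assumes `end_sp` is exclusive, as in a Python slice; this is consistent with the example span inspected further down, whose 61-character `text_sp` has `start_sp` 0 and `end_sp` 61.

```python
def span_text(doc, id_sp):
    """Return the text of span `id_sp` by slicing `doc_text`.

    Assumption: `end_sp` is exclusive (Python slice convention).
    """
    span = doc["spans"][id_sp]
    return doc["doc_text"][span["start_sp"]:span["end_sp"]]


def check_spans(doc):
    """Return the ids of spans whose slice of `doc_text` differs from `text_sp`."""
    return [id_sp for id_sp, span in doc["spans"].items()
            if span_text(doc, id_sp) != span["text_sp"]]


# Hypothetical toy document using the same keys as multidoc2dial_doc.json
toy_doc = {
    "doc_text": "\n\nTitle \nFirst sentence. Second sentence.",
    "spans": {
        "1": {"start_sp": 0, "end_sp": 9, "text_sp": "\n\nTitle \n"},
        "2": {"start_sp": 9, "end_sp": 24, "text_sp": "First sentence."},
    },
}
print(span_text(toy_doc, "2"))  # First sentence.
print(check_spans(toy_doc))     # []
```

On the real data, `check_spans(multidoc2dial_doc['doc_data'][domain][doc_id])` should return an empty list if the assumption holds.
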

Downloading the dataset:

In [None]:
!gdown --id 1Ln4pU93_ofAkbrz1uibsNABB0QsEaOXw

Downloading...
From: https://drive.google.com/uc?id=1Ln4pU93_ofAkbrz1uibsNABB0QsEaOXw
To: /content/multidoc2dial.zip
100% 6.45M/6.45M [00:00<00:00, 42.3MB/s]


Unzipping the dataset:

In [None]:
!unzip multidoc2dial.zip

Archive:  multidoc2dial.zip
   creating: multidoc2dial/
  inflating: multidoc2dial/multidoc2dial_dial_validation.json  
  inflating: multidoc2dial/multidoc2dial_dial_train.json  
  inflating: multidoc2dial/multidoc2dial_dial_test.json  
  inflating: multidoc2dial/multidoc2dial_doc.json  
  inflating: multidoc2dial/README.md  


## Preprocess the data
In this section we form the training samples for the PLDA-based document classifier as follows:

$$
(X_{ij}, y_i)
$$
where $X_{ij}$ is the embedding of the $j$-th span of the $i$-th document and $y_i$ is the label of the $i$-th document.

In [None]:
def clean_text(text):
    """
    Clean the given text.

    :param text: input text
    :type text: str
    :return: cleaned string
    """
    return text.strip()

In [None]:
import json
with open('multidoc2dial/multidoc2dial_doc.json', 'r') as f:
    multidoc2dial_doc = json.load(f)

In [None]:
multidoc2dial_doc['doc_data']['ssa']['Benefits Planner: Survivors | Planning For Your Survivors | Social Security Administration#1_0']['spans']['1']

{'end_sec': 61,
 'end_sp': 61,
 'id_sec': 't_0',
 'id_sp': '1',
 'parent_titles': [],
 'start_sec': 0,
 'start_sp': 0,
 'tag': 'h2',
 'text_sec': '\n\nBenefits Planner: Survivors | Planning For Your Survivors \n',
 'text_sp': '\n\nBenefits Planner: Survivors | Planning For Your Survivors \n',
 'title': 'Benefits Planner: Survivors | Planning For Your Survivors'}

In [None]:
doc_sentence_train = []
doc_label_train = []
for doc_idx1 in multidoc2dial_doc['doc_data']:
    for doc_idx2 in multidoc2dial_doc['doc_data'][doc_idx1]:
        for doc_idx3 in multidoc2dial_doc['doc_data'][doc_idx1]\
                                          [doc_idx2]['spans']:
            doc_sentence_train.append(clean_text(multidoc2dial_doc['doc_data']\
                                                 [doc_idx1][doc_idx2]['spans']\
                                                 [doc_idx3]['text_sp']))
            doc_label_train.append(doc_idx2)
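
The triple loop above can be written more compactly by iterating over `dict.values()` / `dict.items()` instead of re-indexing `multidoc2dial_doc['doc_data']` at every level. A minimal sketch on a hypothetical toy dictionary with the same nesting (`domain → doc_id → 'spans' → id_sp`); `collect_spans` is an illustrative name:

```python
def collect_spans(doc_data, clean=str.strip):
    """Flatten doc_data[domain][doc_id]['spans'][id_sp]['text_sp'] into two
    parallel lists: the cleaned span texts and the doc_id of each span."""
    sentences, doc_ids = [], []
    for domain_docs in doc_data.values():
        for doc_id, doc in domain_docs.items():
            for span in doc["spans"].values():
                sentences.append(clean(span["text_sp"]))
                doc_ids.append(doc_id)
    return sentences, doc_ids


# Hypothetical toy data with the same nesting as multidoc2dial_doc['doc_data']
toy = {"ssa": {"docA#1_0": {"spans": {"1": {"text_sp": " a "},
                                      "2": {"text_sp": "b\n"}}}},
       "va": {"docB#1_0": {"spans": {"1": {"text_sp": "c"}}}}}
sents, ids = collect_spans(toy)
print(sents)  # ['a', 'b', 'c']
print(ids)    # ['docA#1_0', 'docA#1_0', 'docB#1_0']
```

`collect_spans(multidoc2dial_doc['doc_data'], clean_text)` should return the same two lists, in the same order, as the loop above, since both iterate the dictionaries in insertion order.
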

In [None]:
len(doc_label_train)  # Number of total samples

35659

In [None]:
len(set(doc_label_train))   # Number of total docs

488

In [None]:
for i in [1, 100, 1000, 2000, 5000]:
    print(doc_sentence_train[i])
    print(doc_label_train[i])
    print('--' * 20)

As you plan for the future ,
Benefits Planner: Survivors | Planning For Your Survivors | Social Security Administration#1_0
----------------------------------------
you'll want to think about what your family would need if you should die now.
Benefits Planner: Survivors | Planning For Your Survivors | Social Security Administration#2_0
----------------------------------------
Religious record made before the age of 5 showing your date of birth ;
Learn what documents you will need to get a Social Security Card | Social Security Administration#10_0
----------------------------------------
What happens after I apply?
Disability Benefits | Social Security Administration#1_0
----------------------------------------
For more information about our disability claims process ,
Benefits Planner: Disability | How You Qualify | Social Security Administration#2_0
----------------------------------------


### Encoding the sentences
We use LaBSE (Language-agnostic BERT Sentence Encoder), a BERT-based model trained to produce sentence embeddings for 109 languages. Its pre-training combines masked language modeling with translation language modeling. The model is useful for obtaining multilingual sentence embeddings and for bi-text retrieval.

In [None]:
!pip install transformers



In [None]:
from transformers import AutoTokenizer, AutoModel, AutoConfig
import numpy as np
import torch
from torch.nn.functional import normalize

In [None]:
tokenizer_labse = AutoTokenizer.from_pretrained("setu4993/LaBSE")
model_labse = AutoModel.from_pretrained("setu4993/LaBSE")

#### `get_embeddings`
In this function we extract the **pooler output**: the last-layer hidden state of the first token of the sequence (the classification token), further processed by the layers used for the auxiliary pre-training task. For BERT-family models, this is the classification token passed through a linear layer and a tanh activation; the linear layer's weights are trained with the next-sentence-prediction (classification) objective during pre-training.
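
To make the pooling step concrete, here is a pure-Python sketch of the computation described above: keep only the first token's hidden state and apply a linear layer followed by tanh. The weights are made-up toy values and `bert_style_pooler` is an illustrative name; in the real model these weights come from the pretrained checkpoint and the result is returned as `pooler_output`.

```python
import math


def bert_style_pooler(hidden_states, weight, bias):
    """Pool a sequence BERT-style: take the first token's last-layer hidden
    state and apply tanh(W h + b).

    hidden_states: list of token vectors (seq_len x hidden)
    weight: hidden x hidden matrix as a list of rows
    bias: list of length hidden
    """
    cls = hidden_states[0]  # only the first (classification) token is used
    return [math.tanh(sum(w_ij * h_j for w_ij, h_j in zip(row, cls)) + b_i)
            for row, b_i in zip(weight, bias)]


# Toy 2-dimensional example (all values are made up)
hidden = [[1.0, 0.0], [5.0, 5.0], [9.0, -9.0]]  # 3 tokens
W = [[0.0, 0.0], [1.0, 0.0]]
b = [0.0, 0.0]
print(bert_style_pooler(hidden, W, b))  # [0.0, tanh(1.0)]
```

Changing any token other than the first leaves the pooled vector unchanged, which is why the pooler output depends on how well the first token summarizes the sentence.
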

In [None]:
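def get_embeddings(sentence):
    """
    Return the LaBSE pooler-output embeddings of the given sentence(s).

    :param sentence: input sentence(s)
    :type sentence: str or list of strs
    :return: embeddings as a NumPy array, shape (768,) for a single sentence
             or (n, 768) for a list of n > 1 sentences
    """
    tokenized = tokenizer_labse(sentence,
                                return_tensors="pt",
                                padding=True,
                                truncation=True)  # guard against over-long inputs
    with torch.no_grad():
        outputs = model_labse(**tokenized)

    # pooler_output has shape (batch, hidden); squeeze drops the batch axis for one sentence
    return np.squeeze(outputs.pooler_output.numpy())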

In [None]:
TRAIN_SIZE = len(doc_label_train)  # for final training

In [None]:
X = []
labels = sorted(set(doc_label_train))  # reproducible order: a set of strings may iterate differently across Python sessions
y = []
progress = 0
for sentence, label in zip(doc_sentence_train[:TRAIN_SIZE],
                           doc_label_train[:TRAIN_SIZE]):
    X.append(get_embeddings(sentence))
    y.append(labels.index(label))
    progress += 1
    if progress % 50 == 0:
        print('Progress Percent = {}%'.format(100 * progress / TRAIN_SIZE))

Progress Percent = 0.14021705600269216%
Progress Percent = 0.2804341120053843%
Progress Percent = 0.4206511680080765%
Progress Percent = 0.5608682240107686%
Progress Percent = 0.7010852800134608%
Progress Percent = 0.841302336016153%
Progress Percent = 0.9815193920188452%
Progress Percent = 1.1217364480215373%
Progress Percent = 1.2619535040242296%
Progress Percent = 1.4021705600269216%
Progress Percent = 1.542387616029614%
Progress Percent = 1.682604672032306%
Progress Percent = 1.822821728034998%
Progress Percent = 1.9630387840376904%
Progress Percent = 2.1032558400403825%
Progress Percent = 2.2434728960430745%
Progress Percent = 2.383689952045767%
Progress Percent = 2.523907008048459%
Progress Percent = 2.664124064051151%
Progress Percent = 2.8043411200538433%
Progress Percent = 2.9445581760565354%
Progress Percent = 3.084775232059228%
Progress Percent = 3.22499228806192%
Progress Percent = 3.365209344064612%
Progress Percent = 3.505426400067304%
Progress Percent = 3.645643456069996
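
Embedding one sentence per call is slow. Since `get_embeddings` also accepts a list of strings and pads the batch, the sentences could be embedded in batches instead. A minimal sketch of the two helpers this needs, with hypothetical names: `batched` splits a list into chunks, and a `label_to_id` dictionary replaces `labels.index(label)`, which scans the whole label list for every sample.

```python
from itertools import islice


def batched(items, batch_size):
    """Yield consecutive lists of at most `batch_size` items."""
    it = iter(items)
    while batch := list(islice(it, batch_size)):
        yield batch


# Map each document label to an integer id once (constant-time lookups)
toy_labels = sorted({"docB", "docA", "docC"})
label_to_id = {label: i for i, label in enumerate(toy_labels)}

toy_sentences = ["s1", "s2", "s3", "s4", "s5"]
print(list(batched(toy_sentences, 2)))             # [['s1', 's2'], ['s3', 's4'], ['s5']]
print([label_to_id[l] for l in ["docC", "docA"]])  # [2, 0]
```

Assuming `get_embeddings` as defined above, a batched embedding loop could look like `X = np.concatenate([np.atleast_2d(get_embeddings(batch)) for batch in batched(doc_sentence_train, 64)])` (the batch size of 64 is an arbitrary choice); `np.atleast_2d` handles a final batch of size one, which `np.squeeze` reduces to shape `(768,)`. Because padding tokens are masked by the attention mask, batched embeddings should match the one-at-a-time ones up to small numerical differences, but this is worth checking on a few samples.
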

### Saving the data

In [None]:
X = np.array(X)
y = np.array(y)
print(X.shape, y.shape)

(35659, 768) (35659,)


In [None]:
with open('doc_spans_LaBSE_Embedding.npy', 'wb') as f:
    np.save(f, X)
with open('doc_labels.npy', 'wb') as f:
    np.save(f, y)

In [None]:
for i in [1, 100, 1000, 2000, 5000]:
    print(doc_sentence_train[i])
    print(X[i])
    print(doc_label_train[i])
    print(y[i])
    print('--' * 20)

As you plan for the future ,
[-1.81475043e-01 -5.57618916e-01 -1.24171665e-02 -3.00217479e-01
 -5.32819152e-01 -1.57167733e-01 -2.91564651e-02  2.33065173e-01
 -3.90631616e-01  4.06808764e-01 -5.93124554e-02  6.02699742e-02
  3.24147969e-01 -7.00683370e-02  5.46456166e-02 -1.70996666e-01
  5.40538458e-03  3.65681857e-01 -4.62393641e-01  3.33274901e-01
 -2.09975570e-01  1.91097274e-01 -4.94019747e-01 -3.16805303e-01
 -3.73387456e-01  3.16426247e-01 -1.21236019e-01 -4.35519248e-01
 -6.76226020e-01 -3.96640718e-01 -3.65455709e-02 -6.61699653e-01
 -1.46537498e-01 -3.36690724e-01  4.28373516e-01 -5.37108481e-01
 -2.90278226e-01  4.67822343e-01 -1.73657313e-01  3.03228386e-02
 -2.60386884e-01 -5.60500085e-01  1.58685669e-01 -6.48382902e-01
 -1.71234593e-01 -9.78930667e-02 -6.56116128e-01  4.51449782e-01
  1.58998057e-01 -2.37927601e-01  3.71033043e-01  9.97790322e-02
  1.21435821e-01  5.63008189e-01 -5.87908983e-01  5.54587431e-02
  5.60545504e-01 -3.33180249e-01 -1.52503267e-01  1.38059795e

## PLDA Training

In [None]:
import plda

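### Fitting an overfit classifier
Fitting PLDA directly on the full 768-dimensional embeddings, as done below, overfits the training data. Alternatively, the data can first be reduced with PCA through the `n_principal_components` argument:

```python
better_classifier = plda.Classifier()
better_classifier.fit_model(X, y, n_principal_components=5)
```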

In [None]:
PLDA_classifier = plda.Classifier()
PLDA_classifier.fit_model(np.array(X),
                          np.array(y))

In [None]:
import pickle
with open('plda_clf.pkl', 'wb') as f:
    pickle.dump(PLDA_classifier, f)

NameError: ignored

## PLDA Testing
In this section we test the trained PLDA model on a few hand-written queries.

In [None]:
def predict_doc(query):
    """
    Predict which document is matched to the given query.

    :param query: input query
    :type query: str (or list of strs)
    :return: return the document name
    """
    query_embedding = get_embeddings(query)
    predictions, log_p_predictions = PLDA_classifier.predict(query_embedding)
    return labels[predictions]

In [None]:
test_queries = ["I'm looking for information regarding benefits planning, can you help me?",
                "I want to know about the benefits plan for survivors, can you give me more information about this?",
                "What are Social Security credits?",
                "Do you have any knowledge of Adult Disability Report? What if my spouse and I are no longer together?"]
test_labels = ["Benefits Planner: Survivors | Planning For Your Survivors | Social Security Administration#1_0",
               "Benefits Planner: Survivors | Planning For Your Survivors | Social Security Administration#1_0",
               "Benefits Planner: Survivors | Planning For Your Survivors | Social Security Administration#1_0",
               "Benefits Planner: Survivors | Planning For Your Survivors | Social Security Administration#1_0"]

In [None]:
for query in test_queries:
    print(predict_doc(query))

How To Apply For The GI Bill | Veterans Affairs#1_0
Benefits Planner: Disability | Social Security Administration#1_0
Benefit Verification Letter  | Social Security Administration#1_0
Benefits Planner: Disability | How You Qualify | Social Security Administration#1_0


## Final Test

In [None]:
!gdown --id 1ALs7qEOVzY8B-JF1BKuOJzGcS_1DcGbz

Downloading...
From: https://drive.google.com/uc?id=1ALs7qEOVzY8B-JF1BKuOJzGcS_1DcGbz
To: /content/plda_clf.pkl
  0% 0.00/16.6M [00:00<?, ?B/s]100% 16.6M/16.6M [00:00<00:00, 230MB/s]


In [None]:
doc_sentence_train = []
doc_label_train = []
for doc_idx1 in multidoc2dial_doc['doc_data']:
    for doc_idx2 in multidoc2dial_doc['doc_data'][doc_idx1]:
        for doc_idx3 in multidoc2dial_doc['doc_data'][doc_idx1]\
                                          [doc_idx2]['spans']:
            doc_sentence_train.append(clean_text(multidoc2dial_doc['doc_data']\
                                                 [doc_idx1][doc_idx2]['spans']\
                                                 [doc_idx3]['text_sp']))
            doc_label_train.append(doc_idx2)
labels = sorted(set(doc_label_train))  # reproducible order; must match the order used when the classifier was trained

In [None]:
import pickle
with open('plda_clf.pkl', 'rb') as f:
    PLDA_classifier = pickle.load(f)

In [None]:
def if_predicted(query, predicted):
    """Return True if `query` equals `predicted` (an integer) or is contained in it."""
    if isinstance(predicted, (int, np.integer)):
        return query == predicted
    return query in predicted

In [None]:
def predict_doc_at(query, k=1):
    """
    Predict the documents that best match the given query.

    :param query: input query
    :type query: str (or list of strs)
    :param k: number of returned docs
    :type k: int
    :return: (normalized probabilities, document names) of the top-k documents
    """
    query_embedding = get_embeddings(query)
    predictions, log_p_predictions = PLDA_classifier.predict(query_embedding,
                                                             n_best=3)
    predictions = predictions[:k]
    # Softmax over the log-probabilities: probability is exp(log_p), not exp(-log_p).
    # Subtracting the maximum keeps np.exp from underflowing to 0.
    # Assumes log_p_predictions holds one log-probability per class, indexed by class id.
    probs = np.exp(log_p_predictions - np.max(log_p_predictions))
    probs = probs / np.sum(probs)
    accuracy = list(probs[predictions])
    predictions = [labels[x] for x in predictions]
    return accuracy, predictions
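
The scores returned by `predict_doc_at` are meant to be normalized probabilities. Given log-probabilities, these come from a softmax over `exp(log_p)`; using `exp(-log_p)` instead would give the largest weight to the least likely class. A minimal pure-Python sketch of the numerically stable version (the log-sum-exp trick); `normalize_log_probs` is an illustrative name and the inputs are toy values:

```python
import math


def normalize_log_probs(log_probs):
    """Convert log-probabilities into probabilities that sum to 1.

    Subtracting the maximum before exponentiating (log-sum-exp trick)
    avoids underflow when all log-probabilities are very negative.
    """
    m = max(log_probs)
    exps = [math.exp(lp - m) for lp in log_probs]
    total = sum(exps)
    return [e / total for e in exps]


print(normalize_log_probs([math.log(1.0), math.log(3.0)]))  # ≈ [0.25, 0.75]
# Very negative inputs still work; a naive math.exp(-1000.0) would underflow to 0.0
print(normalize_log_probs([-1000.0, -1001.0]))              # ≈ [0.731, 0.269]
```
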

In [None]:
for query in test_queries:
    accs, preds = predict_doc_at(query, k=5)
    print(accs)
    print(preds)
    print('-' * 20)

TypeError: ignored

In [None]:
import json
with open('multidoc2dial/multidoc2dial_dial_train.json', 'r') as f:
    multidoc2dial_dial_train = json.load(f)

In [None]:
multidoc2dial_dial_train['dial_data']['dmv'][0]['turns'][0]['utterance']

In [None]:
multidoc2dial_dial_train['dial_data']['dmv'][0]['turns'][0]['references'][0]['doc_id']

In [None]:
doc_sentence_test = []
doc_label_test = []
for domain in multidoc2dial_dial_train['dial_data']:
    for dial in multidoc2dial_dial_train['dial_data'][domain]:
        for turn in dial['turns']:
            # keep user turns only; skip turns without a grounding reference
            if turn['role'] == "user" and turn['references']:
                doc_sentence_test.append(turn['utterance'])
                doc_label_test.append(turn['references'][0]['doc_id'])

In [None]:
for i in [1, 100, 1000, 2000, 5000]:
    print(doc_sentence_test[i])
    print(doc_label_test[i])
    print('--' * 20)

In [None]:
TEST_SIZE = len(doc_sentence_test)
TEST_SIZE

In [None]:
prec_at_500 = 0
prec_at_100 = 0
prec_at_50 = 0
prec_at_10 = 0
prec_at_5 = 0
prec_at_1 = 0
sample_till_now = 0
ranks = []
for query, act_doc in zip(doc_sentence_test[:TEST_SIZE],
                          doc_label_test[:TEST_SIZE]):
    accs, preds = predict_doc_at(query, k=500)
    ranks.append(1/ (preds.index(act_doc) + 1))
    # print(accs)
    # print(preds)
    # print(act_doc)
    # print('-' * 20)
    if act_doc == preds[0]:
        prec_at_1 += 1
    if act_doc in preds[:5]:
        prec_at_5 += 1
    if act_doc in preds[:10]:
        prec_at_10 += 1
    if act_doc in preds[:50]:
        prec_at_50 += 1
    if act_doc in preds[:100]:
        prec_at_100 += 1
    if act_doc in preds[:500]:
        prec_at_500 += 1
    sample_till_now += 1
    if sample_till_now % 10 == 0:
        print("MRR: mean={}, var={}".format(np.array(ranks).mean(), np.array(ranks).var()))
        print("Prec@(1) = {} | Prec@(5) = {} | Prec@(10) = {} | Prec@(50) = {} | Prec@(100) = {} | Prec@(500) = {} | NUMBER_OF_SAMPLES = {}".\
              format(prec_at_1 / sample_till_now, prec_at_5 / sample_till_now,
                     prec_at_10 / sample_till_now, prec_at_50 / sample_till_now,
                     prec_at_100 / sample_till_now, prec_at_500 / sample_till_now,
                     sample_till_now))

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


Prec@(1) = 0.0 | Prec@(5) = 0.0 | Prec@(10) = 0.0 | Prec@(50) = 0.2 | Prec@(100) = 0.3 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 10
Prec@(1) = 0.0 | Prec@(5) = 0.0 | Prec@(10) = 0.0 | Prec@(50) = 0.3 | Prec@(100) = 0.4 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 20
Prec@(1) = 0.0 | Prec@(5) = 0.0 | Prec@(10) = 0.0 | Prec@(50) = 0.3 | Prec@(100) = 0.36666666666666664 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 30
Prec@(1) = 0.0 | Prec@(5) = 0.0 | Prec@(10) = 0.0 | Prec@(50) = 0.275 | Prec@(100) = 0.4 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 40
Prec@(1) = 0.0 | Prec@(5) = 0.0 | Prec@(10) = 0.0 | Prec@(50) = 0.26 | Prec@(100) = 0.4 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 50
Prec@(1) = 0.0 | Prec@(5) = 0.0 | Prec@(10) = 0.0 | Prec@(50) = 0.21666666666666667 | Prec@(100) = 0.3333333333333333 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 60


KeyboardInterrupt: ignored
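
With exactly one relevant document per query, the "Prec@(k)" counters above record whether that document appears among the top k results, which is usually called hit rate (or success@k / recall@k); precision@k would additionally divide by k. The sketch below computes the hit rate at several cutoffs and MRR in one pass, and gives a reciprocal rank of 0 when the gold document is missing instead of raising `ValueError` as `preds.index(act_doc)` would. `retrieval_metrics` and the toy rankings are hypothetical:

```python
def retrieval_metrics(ranked_lists, gold_docs, ks=(1, 5, 10, 50, 100)):
    """Return ({'hit@k': rate}, MRR) for one ranked list and one gold doc per query.

    A query whose gold doc is missing from its ranked list counts as a miss
    at every cutoff and contributes 0 to the MRR.
    """
    hits = {k: 0 for k in ks}
    reciprocal_ranks = []
    for ranked, gold in zip(ranked_lists, gold_docs):
        rank = ranked.index(gold) + 1 if gold in ranked else None
        reciprocal_ranks.append(1.0 / rank if rank else 0.0)
        for k in ks:
            if rank is not None and rank <= k:
                hits[k] += 1
    n = len(reciprocal_ranks)
    return {f"hit@{k}": hits[k] / n for k in ks}, sum(reciprocal_ranks) / n


ranked = [["a", "b", "c"], ["b", "c", "a"], ["c", "a", "b"]]
gold = ["a", "a", "d"]  # ranks 1, 3 and "not found"
hit, mrr = retrieval_metrics(ranked, gold, ks=(1, 2))
print(hit)  # hit@1 = hit@2 = 1/3
print(mrr)  # (1 + 1/3 + 0) / 3 ≈ 0.444
```
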

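## Label-wise Test (Much Faster)
Instead of the PLDA classifier, we embed each document label (its `doc_id`, which contains the page title) with LaBSE and rank documents by the cosine similarity between the query embedding and each label embedding.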

In [None]:
label_embeddings = []
progress = 0
N_LABELS = len(labels)  # separate name so TRAIN_SIZE from the training cells is not overwritten
for label in labels:
    label_embeddings.append(get_embeddings(label))
    progress += 1
    if progress % 50 == 0:
        print('Progress Percent = {}%'.format(100 * progress / N_LABELS))

Progress Percent = 10.245901639344263%
Progress Percent = 20.491803278688526%
Progress Percent = 30.737704918032787%
Progress Percent = 40.98360655737705%
Progress Percent = 51.22950819672131%
Progress Percent = 61.47540983606557%
Progress Percent = 71.72131147540983%
Progress Percent = 81.9672131147541%
Progress Percent = 92.21311475409836%


In [None]:
def predict_labelwise_doc_at(query, k=1):
    """
    Predict the documents that best match the given query, using the cosine
    similarity between the query embedding and each label embedding.

    :param query: input query
    :type query: str
    :param k: number of returned docs
    :type k: int
    :return: (cosine similarities, document names) of the top-k documents
    """
    query_embedding = get_embeddings(query)
    label_matrix = np.array(label_embeddings)  # shape (n_labels, 768)
    similarities = (label_matrix @ query_embedding) / \
        (np.linalg.norm(label_matrix, axis=1) * np.linalg.norm(query_embedding))
    best_k_idx = similarities.argsort()[::-1][:k]
    predictions = [labels[x] for x in best_k_idx]
    accuracy = similarities[best_k_idx]  # cosine similarities, not accuracies
    return accuracy, predictions
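
The label-wise ranking is plain cosine-similarity retrieval. For reference, a dependency-free sketch of the same ranking; it selects the top k with `heapq.nlargest` rather than sorting all scores, which makes little difference for 488 labels. `top_k_by_cosine` and the 2-d toy vectors are illustrative only:

```python
import heapq
import math


def cosine(u, v):
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def top_k_by_cosine(query_vec, label_vecs, label_names, k=1):
    """Return the k (similarity, label) pairs with the highest cosine similarity."""
    scored = ((cosine(query_vec, vec), name) for vec, name in zip(label_vecs, label_names))
    return heapq.nlargest(k, scored)


print(top_k_by_cosine([1.0, 0.0],
                      [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                      ["x", "y", "z"], k=2))
# [(1.0, 'x'), (0.7071..., 'z')]
```
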

In [None]:
for query in test_queries:
    accs, preds = predict_labelwise_doc_at(query, k=5)
    print(accs)
    print(preds)
    print('-' * 20)

[0.4170898  0.4135436  0.41314107 0.41062284 0.40333804]
['Benefits Planner: Disability | How You Apply | Social Security Administration#2_0', 'Benefits Planner: Disability | Are You Working | Social Security Administration#2_0', 'Benefits Planner: Disability | How You Apply | Social Security Administration#1_0', 'Benefits Planner: Disability | Are You Working | Social Security Administration#1_0', 'Contracting Information | Federal Student Aid#1_0']
--------------------
[0.38753018 0.3701287  0.36652112 0.35735554 0.3544231 ]
['VA Education Benefits For Survivors And Dependents | Veterans Affairs#1_0', 'Benefits Planner: Survivors | Planning For Your Survivors | Social Security Administration#2_0', 'Benefits Planner: Survivors | Planning For Your Survivors | Social Security Administration#1_0', 'Benefits Planner: Survivors | If You Are The Survivor | Social Security Administration#2_0', 'Benefits Planner: Survivors | If You Are The Survivor | Social Security Administration#1_0']
-----

In [None]:
prec_at_500 = 0
prec_at_100 = 0
prec_at_50 = 0
prec_at_10 = 0
prec_at_5 = 0
prec_at_1 = 0
sample_till_now = 0
for query, act_doc in zip(doc_sentence_test[:TEST_SIZE],
                          doc_label_test[:TEST_SIZE]):
    accs, preds = predict_labelwise_doc_at(query, k=500)
    # print(accs)
    # print(preds)
    # print(act_doc)
    # print('-' * 20)
    if act_doc == preds[0]:
        prec_at_1 += 1
    if act_doc in preds[:5]:
        prec_at_5 += 1
    if act_doc in preds[:10]:
        prec_at_10 += 1
    if act_doc in preds[:50]:
        prec_at_50 += 1
    if act_doc in preds[:100]:
        prec_at_100 += 1
    if act_doc in preds[:500]:
        prec_at_500 += 1
    sample_till_now += 1
    if sample_till_now % 100 == 0:
        print("Prec@(1) = {} | Prec@(5) = {} | Prec@(10) = {} | Prec@(50) = {} | Prec@(100) = {} | Prec@(500) = {} | NUMBER_OF_SAMPLES = {}".\
              format(prec_at_1 / sample_till_now, prec_at_5 / sample_till_now,
                     prec_at_10 / sample_till_now, prec_at_50 / sample_till_now,
                     prec_at_100 / sample_till_now, prec_at_500 / sample_till_now,
                     sample_till_now))

Prec@(1) = 0.05 | Prec@(5) = 0.17 | Prec@(10) = 0.31 | Prec@(50) = 0.47 | Prec@(100) = 0.58 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 100
Prec@(1) = 0.09 | Prec@(5) = 0.19 | Prec@(10) = 0.295 | Prec@(50) = 0.49 | Prec@(100) = 0.64 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 200
Prec@(1) = 0.10666666666666667 | Prec@(5) = 0.22333333333333333 | Prec@(10) = 0.32666666666666666 | Prec@(50) = 0.49333333333333335 | Prec@(100) = 0.66 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 300
Prec@(1) = 0.1075 | Prec@(5) = 0.225 | Prec@(10) = 0.3375 | Prec@(50) = 0.53 | Prec@(100) = 0.6875 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 400
Prec@(1) = 0.124 | Prec@(5) = 0.244 | Prec@(10) = 0.35 | Prec@(50) = 0.556 | Prec@(100) = 0.706 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 500
Prec@(1) = 0.12333333333333334 | Prec@(5) = 0.24 | Prec@(10) = 0.33666666666666667 | Prec@(50) = 0.5416666666666666 | Prec@(100) = 0.6916666666666667 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 600
Prec@(1) = 0.12571428571428572 | Prec@(5) = 0.23

KeyboardInterrupt: ignored

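### History (last three sentences)
Here each query is the concatenation of the current user utterance and the two preceding entries of `doc_sentence_test`. Because `doc_sentence_test` is a flat list over all dialogues, these windows are not reset at dialogue boundaries, so the first turns of a dialogue are combined with utterances from the previous dialogue.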

In [None]:
prec_at_500 = 0
prec_at_100 = 0
prec_at_50 = 0
prec_at_10 = 0
prec_at_5 = 0
prec_at_1 = 0
sample_till_now = 0
for i in range(2, TEST_SIZE):
    query = '.'.join(doc_sentence_test[i-2:i+1])
    act_doc = doc_label_test[i]
    accs, preds = predict_labelwise_doc_at(query, k=500)
    # print(accs)
    # print(preds)
    # print(act_doc)
    # print('-' * 20)
    if act_doc == preds[0]:
        prec_at_1 += 1
    if act_doc in preds[:5]:
        prec_at_5 += 1
    if act_doc in preds[:10]:
        prec_at_10 += 1
    if act_doc in preds[:50]:
        prec_at_50 += 1
    if act_doc in preds[:100]:
        prec_at_100 += 1
    if act_doc in preds[:500]:
        prec_at_500 += 1
    sample_till_now += 1
    if sample_till_now % 100 == 0:
        print("Prec@(1) = {} | Prec@(5) = {} | Prec@(10) = {} | Prec@(50) = {} | Prec@(100) = {} | Prec@(500) = {} | NUMBER_OF_SAMPLES = {}".\
              format(prec_at_1 / sample_till_now, prec_at_5 / sample_till_now,
                     prec_at_10 / sample_till_now, prec_at_50 / sample_till_now,
                     prec_at_100 / sample_till_now, prec_at_500 / sample_till_now,
                     sample_till_now))

Prec@(1) = 0.05 | Prec@(5) = 0.13 | Prec@(10) = 0.22 | Prec@(50) = 0.46 | Prec@(100) = 0.7 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 100
Prec@(1) = 0.085 | Prec@(5) = 0.15 | Prec@(10) = 0.215 | Prec@(50) = 0.49 | Prec@(100) = 0.66 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 200
Prec@(1) = 0.09666666666666666 | Prec@(5) = 0.17 | Prec@(10) = 0.25333333333333335 | Prec@(50) = 0.5033333333333333 | Prec@(100) = 0.6833333333333333 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 300
Prec@(1) = 0.1075 | Prec@(5) = 0.22 | Prec@(10) = 0.3 | Prec@(50) = 0.55 | Prec@(100) = 0.7325 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 400
Prec@(1) = 0.094 | Prec@(5) = 0.242 | Prec@(10) = 0.316 | Prec@(50) = 0.56 | Prec@(100) = 0.742 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 500
Prec@(1) = 0.08666666666666667 | Prec@(5) = 0.23333333333333334 | Prec@(10) = 0.3 | Prec@(50) = 0.5366666666666666 | Prec@(100) = 0.73 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 600
Prec@(1) = 0.09571428571428571 | Prec@(5) = 0.2357142857142857 | Prec

KeyboardInterrupt: ignored
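
The history cells above slide a window over the flat `doc_sentence_test` list, so a window can mix utterances from two dialogues. One way to keep windows inside a single dialogue is to group the user utterances per dialogue first. A minimal sketch, assuming each dialogue is given as a list of `(utterance, doc_id)` pairs for its user turns (the test-set loop earlier could be adapted to produce this); `history_queries` is a hypothetical helper that joins utterances with a space and also yields shorter queries for the first turns of a dialogue:

```python
def history_queries(dialogues, window=3, sep=" "):
    """Build one query per user utterance from the last `window` user
    utterances of the same dialogue (fewer at the start of a dialogue).

    dialogues: list of dialogues, each a list of (utterance, doc_id) pairs
    Returns a list of (query, doc_id) pairs.
    """
    queries = []
    for turns in dialogues:
        for i, (_, doc_id) in enumerate(turns):
            context = [utt for utt, _ in turns[max(0, i - window + 1):i + 1]]
            queries.append((sep.join(context), doc_id))
    return queries


# Hypothetical toy dialogues
toy = [[("u1", "A"), ("u2", "A"), ("u3", "B"), ("u4", "B")],
       [("v1", "C")]]
for query in history_queries(toy):
    print(query)
# ('u1', 'A'), ('u1 u2', 'A'), ('u1 u2 u3', 'B'), ('u2 u3 u4', 'B'), ('v1', 'C')
```
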

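## Label-wise + TF-IDF
Each of the last three user utterances is compared separately with the label embeddings, and the three cosine-similarity vectors are averaged with weights given by each utterance's mean IDF score (only the IDF part of TF-IDF is used). The IDF values are computed over the document texts, tokenized with the LaBSE tokenizer.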

In [None]:
doc_texts_train = []
for doc_idx1 in multidoc2dial_doc['doc_data']:
    for doc_idx2 in multidoc2dial_doc['doc_data'][doc_idx1]:
        doc_texts_train.append(multidoc2dial_doc['doc_data'][doc_idx1]\
                                          [doc_idx2]['doc_text'].strip())

In [None]:
doc_texts_train[0]

"Benefits Planner: Survivors | Planning For Your Survivors \nAs you plan for the future , you'll want to think about what your family would need if you should die now. Social Security can help your family if you have earned enough Social Security credits through your work. You can earn up to four credits each year. In 2019 , for example , you earn one credit for each $1,360 of wages or self - employment income. When you have earned $5,440 , you have earned your four credits for the year. The number of credits needed to provide benefits for your survivors depends on your age when you die. No one needs more than 40 credits 10 years of work to be eligible for any Social Security benefit. But , the younger a person is , the fewer credits they must have for family members to receive survivors benefits. Benefits can be paid to your children and your spouse who is caring for the children even if you don't have the required number of credits. They can get benefits if you have credit for one an

In [None]:
words = set()
doc_texts_train_tokenized = []
for doc in doc_texts_train:
    tokenized_doc = [s.lower() for s in tokenizer_labse.tokenize(doc)]
    doc_texts_train_tokenized.append(tokenized_doc) 
    words = set(tokenized_doc).union(words)
len(words)

8446

In [None]:
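from collections import Counter

# Document frequency: the number of documents each token appears in.
# Using sets avoids scanning every tokenized document once per word.
doc_freq = Counter()
for doc in doc_texts_train_tokenized:
    doc_freq.update(set(doc))

words2IDF = {}
N_doc = len(doc_texts_train)
for i, word in enumerate(words):
    words2IDF[word] = np.log(N_doc / (doc_freq[word] + 1))  # smoothed IDF
    if i % 1000 == 0:
        print(word, words2IDF[word])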

nature 3.792420133054777
nobody 5.0917031171850375
potentially 4.804021044733257
examination 3.8877303128591016
##wal 3.5512580762378887
appear 2.5793974932089228
participate 3.245876426686707
##mig 3.6253660483916104
go 1.1730355690382228

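As a worked example of the smoothed IDF used above, $\mathrm{idf}(w) = \log\frac{N}{n_w + 1}$, where $N$ is the number of documents and $n_w$ the number of documents containing token $w$. Taking $N = 488$ (the number of distinct documents counted earlier), the printed values correspond to whole-number document frequencies, for example:

$$
\mathrm{idf}(\text{go}) = \log\frac{488}{150 + 1} \approx 1.1730,
\qquad
\mathrm{idf}(\text{nobody}) = \log\frac{488}{2 + 1} \approx 5.0917
$$

so "go" occurs in 150 documents and "nobody" in 2. A token that never occurs in the documents gets $\log 488 \approx 6.19$, the largest possible value.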

In [None]:
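def calc_idf_score(sentence):
    """Return the mean IDF of the sentence's lower-cased LaBSE tokens.

    Tokens unseen in the documents get log(N_doc), the IDF of a token
    that occurs in no document.
    """
    tokenized_sentence = [s.lower() for s in tokenizer_labse.tokenize(sentence)]
    if not tokenized_sentence:  # avoid division by zero on empty input
        return 0.0
    score = 0
    for token in tokenized_sentence:
        score += words2IDF.get(token, np.log(N_doc))
    return score / len(tokenized_sentence)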

In [None]:
def predict_labelwise_doc_at_history(query_h2, query_h1, query_h0, k=1):
    """
    Predict the documents that best match the current query and its history.

    The cosine similarities of each utterance with every label embedding are
    averaged, weighted by each utterance's mean IDF score.

    :param query_h2: user utterance two turns before the current one
    :param query_h1: user utterance one turn before the current one
    :param query_h0: current user utterance
    :param k: number of returned docs
    :type k: int
    :return: (weighted similarities, document names) of the top-k documents
    """
    label_matrix = np.array(label_embeddings)  # shape (n_labels, 768)
    label_norms = np.linalg.norm(label_matrix, axis=1)
    weighted_sum = 0
    weight_total = 0
    for query in (query_h2, query_h1, query_h0):
        query_embedding = get_embeddings(query)
        similarities = (label_matrix @ query_embedding) / \
            (label_norms * np.linalg.norm(query_embedding))
        idf_score = calc_idf_score(query)
        weighted_sum = weighted_sum + idf_score * similarities
        weight_total += idf_score
    similarities = weighted_sum / weight_total
    best_k_idx = similarities.argsort()[::-1][:k]
    predictions = [labels[x] for x in best_k_idx]
    accuracy = similarities[best_k_idx]  # weighted cosine similarities
    return accuracy, predictions
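
The combined score computed by `predict_labelwise_doc_at_history` can be written as follows, where $q_0$ is the current user utterance, $q_1$ and $q_2$ are the two preceding ones, $\mathbf{e}_x$ is the LaBSE embedding of $x$, $d$ ranges over the document labels, and $|q_t|$ is the number of LaBSE tokens in $q_t$ (the inner sum runs over those tokens):

$$
s(d) = \frac{\sum_{t=0}^{2} w_t \cos(\mathbf{e}_{q_t}, \mathbf{e}_d)}{\sum_{t=0}^{2} w_t},
\qquad
w_t = \frac{1}{|q_t|} \sum_{v \in q_t} \mathrm{idf}(v)
$$

For instance, with weights $w = (2, 1, 1)$ and similarities $(0.5, 0.2, 0.2)$ for some label, $s(d) = (2 \cdot 0.5 + 0.2 + 0.2) / 4 = 0.35$. Utterances made of rarer tokens therefore influence the ranking more strongly than generic ones.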

In [None]:
prec_at_500 = 0
prec_at_100 = 0
prec_at_50 = 0
prec_at_10 = 0
prec_at_5 = 0
prec_at_1 = 0
sample_till_now = 0
for i in range(2, TEST_SIZE):
    act_doc = doc_label_test[i]
    query_h2 = doc_sentence_test[i-2]
    query_h1 = doc_sentence_test[i-1]
    query_h0 = doc_sentence_test[i]
    accs, preds = predict_labelwise_doc_at_history(query_h2,
                                                   query_h1,
                                                   query_h0,
                                                   k=500)
    # print(accs)
    # print(preds)
    # print(act_doc)
    # print('-' * 20)
    if act_doc == preds[0]:
        prec_at_1 += 1
    if act_doc in preds[:5]:
        prec_at_5 += 1
    if act_doc in preds[:10]:
        prec_at_10 += 1
    if act_doc in preds[:50]:
        prec_at_50 += 1
    if act_doc in preds[:100]:
        prec_at_100 += 1
    if act_doc in preds[:500]:
        prec_at_500 += 1
    sample_till_now += 1
    if sample_till_now % 100 == 0:
        print("Prec@(1) = {} | Prec@(5) = {} | Prec@(10) = {} | Prec@(50) = {} | Prec@(100) = {} | Prec@(500) = {} | NUMBER_OF_SAMPLES = {}".\
              format(prec_at_1 / sample_till_now, prec_at_5 / sample_till_now,
                     prec_at_10 / sample_till_now, prec_at_50 / sample_till_now,
                     prec_at_100 / sample_till_now, prec_at_500 / sample_till_now,
                     sample_till_now))

Prec@(1) = 0.06 | Prec@(5) = 0.15 | Prec@(10) = 0.23 | Prec@(50) = 0.47 | Prec@(100) = 0.69 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 100
Prec@(1) = 0.045 | Prec@(5) = 0.155 | Prec@(10) = 0.25 | Prec@(50) = 0.5 | Prec@(100) = 0.715 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 200
Prec@(1) = 0.06 | Prec@(5) = 0.19 | Prec@(10) = 0.29 | Prec@(50) = 0.5366666666666666 | Prec@(100) = 0.76 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 300
Prec@(1) = 0.085 | Prec@(5) = 0.2325 | Prec@(10) = 0.335 | Prec@(50) = 0.59 | Prec@(100) = 0.795 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 400
Prec@(1) = 0.108 | Prec@(5) = 0.266 | Prec@(10) = 0.364 | Prec@(50) = 0.622 | Prec@(100) = 0.816 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 500
Prec@(1) = 0.105 | Prec@(5) = 0.25333333333333335 | Prec@(10) = 0.345 | Prec@(50) = 0.595 | Prec@(100) = 0.8083333333333333 | Prec@(500) = 1.0 | NUMBER_OF_SAMPLES = 600
Prec@(1) = 0.11714285714285715 | Prec@(5) = 0.2542857142857143 | Prec@(10) = 0.3457142857142857 | Prec@(50) = 0.6 | Prec

KeyboardInterrupt: ignored