# Rocchio feedback

In [None]:
%pip install ipytest

In [None]:
from typing import List, Optional

import ipytest
import pytest

# Apply ipytest's default notebook configuration once, before any test cell runs.
ipytest.autoconfig()

Vocabulary

In [None]:
# Vocabulary terms. Position i is presumably the term behind column i of Q and
# of every DT_MATRIX row (all have six entries) — keep the order in sync.
VOCAB = ['news', 'about', 'presidential', 'campaign', 'food', 'text']

Query vector

In [None]:
# Original query as a six-entry weight vector: weight 1 at positions 0-3, weight 0 at 4-5.
# Presumably aligned position-by-position with VOCAB — confirm if VOCAB changes.
Q = [1, 1, 1, 1, 0, 0]

Document-term matrix (each row corresponds to a document vector)

In [None]:
# Row index = document ID (the values stored in D_POS / D_NEG index into this list);
# column index = term position, matching the positions of the query vector Q.
DT_MATRIX = [
    [1.5, 0.1, 0, 0, 0, 0],
    [1.5, 0.1, 0, 2, 2, 0],
    [1.5, 0, 3, 2, 0, 0],
    [1.5, 0, 4, 2, 0, 0],
    [1.5, 0, 0, 6, 2, 0]
]

Feedback: IDs (indices) of positive and negative documents

In [None]:
# Feedback judgments, expressed as row indices into DT_MATRIX.
# Together the two lists cover all five rows (0-4) with no overlap.
D_POS = [2, 3]
D_NEG = [0, 1, 4]

## Rocchio feedback

Compute the updated query according to:
$$\vec{q}_m = \alpha \vec{q} + \frac{\beta}{|D^+|}\sum_{d \in D^+}\vec{d} - \frac{\gamma}{|D^-|}\sum_{d \in D^-}\vec{d}$$

where
  - $\vec{q}$ is the original query vector
  - $D^+, D^-$ are the sets of relevant and non-relevant feedback documents
  - $\alpha, \beta, \gamma$ are parameters that control the movement of the original vector

**TODO** Complete the method below. (You may use the global variables `VOCAB` and `DT_MATRIX`.)

In [None]:
def get_updated_query(
    q: List[int], d_pos: List[int], d_neg: List[int],
    alpha: float, beta: float, gamma: float,
    dt_matrix: Optional[List[List[float]]] = None
) -> List[float]:
    """Computes an updated query model using Rocchio feedback.

    For every term position t the result is
    alpha * q[t] + beta * mean(d[t] for d in D+) - gamma * mean(d[t] for d in D-).
    An empty feedback set contributes nothing to the result (instead of
    raising ZeroDivisionError when dividing by its size).

    Args:
        q: Query vector.
        d_pos: List of positive feedback document IDs (row indices into the
            document-term matrix).
        d_neg: List of negative feedback document IDs (row indices into the
            document-term matrix).
        alpha: Feedback parameter alpha (weight of the original query).
        beta: Feedback parameter beta (weight of the positive centroid).
        gamma: Feedback parameter gamma (weight of the negative centroid).
        dt_matrix: Document-term matrix to read document vectors from.
            Defaults to the global `DT_MATRIX` when omitted.

    Returns:
        Updated query vector, with the same length as `q`.

    Raises:
        IndexError: If a document ID is not a valid row of the matrix, or a
            row has fewer entries than `q`.
    """
    if dt_matrix is None:
        dt_matrix = DT_MATRIX

    # Per-set scaling factors; an empty set gets factor 0 so it is ignored.
    pos_scale = beta / len(d_pos) if d_pos else 0.0
    neg_scale = gamma / len(d_neg) if d_neg else 0.0

    q_m = [alpha * q_t for q_t in q]

    for t in range(len(q)):
        q_m[t] += pos_scale * sum(dt_matrix[d][t] for d in d_pos)
        q_m[t] -= neg_scale * sum(dt_matrix[d][t] for d in d_neg)

    return q_m

Tests.

In [None]:
%%run_pytest[clean]

def test_no_expansion():
    q_m = get_updated_query(Q, D_POS, D_NEG, 1, 0, 0)
    assert q_m == Q

def test_expansion():
    q_m = get_updated_query(Q, D_POS, D_NEG, 0.6, 0.2, 0.2)
    assert q_m == pytest.approx([0.600, 0.587, 1.300, 0.467, -0.267, 0], rel=1e-2)