# Recommender

In [None]:
import json
from typing import List

import numpy as np
import requests
import torch
from transformers import DistilBertModel, DistilBertTokenizer

In [None]:
# ID of the user whose ratings drive the recommendations below.
test_user_id: int = 1

In [None]:
# Encoder/tokenizer pair: both come from the same pretrained checkpoint so the
# vocabulary matches the model's embedding table.
pretrained_weights = 'distilbert-base-cased'
model_class = DistilBertModel
tokenizer_class = DistilBertTokenizer

tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
model = model_class.from_pretrained(pretrained_weights)

In [None]:
def embed(sentences: List[str]) -> np.ndarray:
    """Embed each sentence as the mean of DistilBERT's last hidden states.

    Every sentence is tokenized once (with special tokens) and right-padded
    with the tokenizer's pad token to the longest sentence in the batch. The
    batch is then run through ``model`` in a single forward pass. Padding is
    excluded both from attention (via ``attention_mask``) and from the mean.
    As a result, a sentence's embedding does not depend on the lengths of the
    other sentences in the batch.

    Args:
        sentences: Texts to embed. No truncation is applied here; over-long
            inputs are passed to the model as-is.

    Returns:
        A float array with one row per sentence and ``model.config.dim``
        columns. Empty input returns a ``(0, model.config.dim)`` array.
    """
    encoded = [tokenizer.encode(s, add_special_tokens=True) for s in sentences]
    if not encoded:
        # max() over an empty sequence would raise; return an empty batch instead.
        return np.empty((0, model.config.dim), dtype=np.float32)

    max_len = max(len(ids) for ids in encoded)
    pad_id = tokenizer.pad_token_id
    input_ids = torch.tensor(
        [ids + [pad_id] * (max_len - len(ids)) for ids in encoded]
    )
    # 1 for real tokens, 0 for padding.
    attention_mask = torch.tensor(
        [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in encoded]
    )

    with torch.no_grad():
        # First output element is the last hidden state: (batch, seq, hidden).
        last_hidden_states = model(input_ids, attention_mask=attention_mask)[0]
        mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
        summed = (last_hidden_states * mask).sum(dim=1)
        # Each row contains at least its special tokens, but clamp so an
        # all-padding row could never divide by zero.
        counts = mask.sum(dim=1).clamp(min=1.0)
        pooled = summed / counts
    return pooled.numpy()

In [None]:
# Fetch the test user's record. Fail loudly on an HTTP error status instead of
# parsing an error response as if it were the user record. The timeout keeps
# the cell from hanging forever if the service is unreachable.
with requests.get(f"http://sql_app:80/users/{test_user_id}", timeout=10) as r:
    r.raise_for_status()
    j = r.json()

In [None]:
# IDs of every article the test user gave a positive rating value.
articles_rated_pos = [
    rating["article_id"] for rating in j["ratings"] if rating["value"] > 0
]

In [None]:
def _article_text(title: str, summary: str) -> str:
    """Join an article's title and summary into one string for embedding."""
    title = title.strip()
    summary = summary.strip()
    # Terminate the title so the tokenizer sees a sentence boundary. The
    # emptiness check avoids the IndexError the old `title[-1]` raised.
    if title and not title.endswith(("!", "?", ".")):
        title += "."
    return " ".join(part for part in (title, summary) if part)


def get_article_texts(
    article_ids: List[str],
    base_url: str = "http://sql_app:80",
    timeout: float = 10.0,
) -> List[str]:
    """Fetch each article and return its "title. summary" text, in input order.

    Args:
        article_ids: IDs to fetch from ``{base_url}/articles/{id}``. Annotated
            as ``List[str]``, but the IDs passed in this notebook come straight
            from JSON responses (presumably ints); any value that formats into
            the URL works.
        base_url: Root URL of the articles service.
        timeout: Per-request timeout in seconds.

    Returns:
        One string per ID. Each string is the stripped title, ended with "."
        unless it already ends in "!", "?" or ".", then a space and the
        stripped summary.

    Raises:
        requests.HTTPError: If the service answers with an error status.
    """
    texts = []
    # One session reuses the connection across the per-article requests.
    with requests.Session() as session:
        for article_id in article_ids:
            url = f"{base_url}/articles/{article_id}"
            with session.get(url, timeout=timeout) as r:
                r.raise_for_status()
                article = r.json()
            texts.append(_article_text(article["title"], article["summary"]))
    return texts

In [None]:
# Fetch and embed every article the test user rated positively.
pos_texts = get_article_texts(article_ids=articles_rated_pos)
pos_texts_embedded = embed(sentences=pos_texts)

In [None]:
# Display the shape of the positive-article embedding matrix.
np.shape(pos_texts_embedded)

In [None]:
# Fetch one page of articles (skip=600, limit=100). Fail loudly on an HTTP
# error status; the timeout prevents an indefinite hang.
with requests.get(
    "http://sql_app:80/articles/",
    params={"skip": 600, "limit": 100},
    timeout=10,
) as r:
    r.raise_for_status()
    j = r.json()

# Keep the IDs of articles on this page that have no rating from the test user.
articles_not_rated = [
    article["id"]
    for article in j
    if not any(
        rating.get("user_id", "") == test_user_id for rating in article["ratings"]
    )
]

In [None]:
# Fetch and embed the candidate articles the test user has not rated yet.
unrated_texts = get_article_texts(article_ids=articles_not_rated)
unrated_texts_embedded = embed(sentences=unrated_texts)

In [None]:
# Display the shape of the unrated-article embedding matrix.
np.shape(unrated_texts_embedded)