In [10]:
# Vocabulary of smoking-related terms. Several entries are ambiguous outside
# a smoking context (e.g. "fire", "pot", "draft", "butt"), so matches on them
# may need disambiguation by the caller.
smoking_words = (
    "smoke cigar fire puff tobacco cigarette fume smog pot "
    "vapor vape cannabis weed inhale whiff draft butt fag"
).split()

In [1]:
from typing import Dict, List
import json

import requests


URL = "https://model-apis.semanticscholar.org/specter/v1/invoke"
# Maximum number of papers sent per request; presumably the endpoint's
# per-request limit — confirm against the service documentation.
MAX_BATCH_SIZE = 16


def chunks(lst, chunk_size: int = MAX_BATCH_SIZE):
    """Yield consecutive slices of *lst*, each at most *chunk_size* items long.

    Splits a longer sequence so each request respects the batch size. The
    final slice may be shorter than *chunk_size*; an empty *lst* yields
    nothing. *lst* must support ``len()`` and slicing (list, tuple, str...).

    Raises:
        ValueError: if *chunk_size* is not positive. A negative step would
            otherwise make ``range`` silently produce no slices (dropping
            all input), and zero would fail with an unrelated ``range``
            error. Raised on first iteration, since this is a generator.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
    for i in range(0, len(lst), chunk_size):
        yield lst[i : i + chunk_size]


# Demo input for embed(): each record carries a paper_id (used to key the
# returned embeddings), a title, and an abstract (truncated placeholders).
SAMPLE_PAPERS = [
    dict(
        paper_id="A",
        title="Angiotensin-converting enzyme 2 is a functional receptor for the SARS coronavirus",
        abstract="Spike (S) proteins of coronaviruses ...",
    ),
    dict(
        paper_id="B",
        title="Hospital outbreak of Middle East respiratory syndrome coronavirus",
        abstract="Between April 1 and May 23, 2013, a total of 23 cases of MERS-CoV ...",
    ),
]


def embed(papers, *, timeout: float = 60.0) -> Dict[str, List[float]]:
    """Return SPECTER embeddings for *papers*, keyed by ``paper_id``.

    *papers* is split into batches of at most ``MAX_BATCH_SIZE`` records and
    each batch is POSTed as JSON to ``URL``. Every entry in the response's
    ``"preds"`` list contributes one ``paper_id -> embedding`` mapping.

    Args:
        papers: sliceable sequence of JSON-serialisable paper records; each
            batch is sent as-is (the sample records carry ``paper_id``,
            ``title`` and ``abstract``).
        timeout: seconds to wait on each request, passed straight to
            ``requests.post``. Without a timeout a stalled server would
            block this call indefinitely.

    Returns:
        Dict mapping each returned ``paper_id`` to its embedding vector.
        Empty when *papers* is empty (no request is made).

    Raises:
        RuntimeError: if any batch request responds with a status other
            than 200.
        requests.exceptions.RequestException: connection errors and
            timeouts propagate unchanged from ``requests``.
    """
    embeddings_by_paper_id: Dict[str, List[float]] = {}

    for chunk in chunks(papers):
        # Let requests serialise the batch to JSON (and set the JSON
        # Content-Type header) instead of calling json.dumps manually.
        response = requests.post(URL, json=chunk, timeout=timeout)

        if response.status_code != 200:
            # Keep the original wording as a prefix, but surface the status
            # code so failures can actually be diagnosed.
            raise RuntimeError(
                "Sorry, something went wrong, please try later!"
                f" (HTTP {response.status_code})"
            )

        for paper in response.json()["preds"]:
            embeddings_by_paper_id[paper["paper_id"]] = paper["embedding"]

    return embeddings_by_paper_id


if __name__ == "__main__":
    # Demo run: embed the two sample papers and dump the id -> vector map,
    # e.g. {'A': [4.089589595794678, ...], 'B': [-0.15814849734306335, ...]}
    all_embeddings = embed(papers=SAMPLE_PAPERS)
    print(all_embeddings)

{'A': [4.089589595794678, -6.363525390625, 1.0412890911102295, 3.612004280090332, 2.1087265014648438, -2.6228532791137695, -1.7152503728866577, -0.00752413272857666, 4.46497917175293, -4.820253372192383, 4.729174613952637, -5.887621879577637, 0.9146002531051636, 1.5539332628250122, 1.1271227598190308, -3.1765334606170654, -6.33390474319458, -0.6067850589752197, -5.112493515014648, 0.20254170894622803, 0.7709828615188599, 2.2189793586730957, -3.718961715698242, 2.1455492973327637, 5.062032699584961, 2.823756694793701, 1.4857819080352783, 3.3496930599212646, 3.181285858154297, -3.4674243927001953, -3.667738437652588, 2.369105339050293, -3.153042793273926, -2.2808587551116943, 6.329648971557617, 0.778890073299408, -2.9996585845947266, 4.419855117797852, -1.4674656391143799, -2.5231809616088867, -2.8538272380828857, -5.1504292488098145, -0.4575399160385132, 3.058708667755127, 4.537525177001953, -2.9432787895202637, 4.250761032104492, -2.799833297729492, -3.8617684841156006, 2.3819820880889