In [1]:
! pip install sentence_transformers feedparser

Collecting sentence_transformers
  Using cached sentence_transformers-2.2.2-py3-none-any.whl
Collecting feedparser
  Using cached feedparser-6.0.10-py3-none-any.whl (81 kB)
Collecting transformers<5.0.0,>=4.6.0
  Using cached transformers-4.24.0-py3-none-any.whl (5.5 MB)
Collecting sentencepiece
  Using cached sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
Collecting nltk
  Using cached nltk-3.7-py3-none-any.whl (1.5 MB)
Collecting huggingface-hub>=0.4.0
  Using cached huggingface_hub-0.11.0-py3-none-any.whl (182 kB)
Collecting sgmllib3k
  Using cached sgmllib3k-1.0.0-py3-none-any.whl
Collecting regex!=2019.12.17
  Using cached regex-2022.10.31-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (769 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Using cached tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
Installing collected packages: tokenizers, sgmllib3k, sentencepiece, regex, feedparser, nltk, hug

In [2]:
from pathlib import Path
from tqdm.auto import tqdm
import pickle
import json
from sentence_transformers import SentenceTransformer

In [3]:
def prepare_segments(segments, window=7, stride=1):
    data = []
    for j in range(0, len(segments), stride):
        j_end = min(j+window, len(segments)-1)
        text = ''.join([x["text"] for x in segments[j:j_end]])
        start = segments[j]['start']
        end = segments[j_end]['end']
        row_id = f"{path.stem}-t{segments[j]['start']}"
        meta = {
                    "id": row_id,
                    "text": text.strip(),
                    "start": start,
                    "end": end
        }
        data.append(meta)
    return data


def batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]
        
        
from functools import lru_cache

@lru_cache(100)
def load_meta(feedtitle):
    FEED_META = list(Path("feedmeta").glob("*.pickle"))
    with open([x for x in FEED_META if x.stem.startswith(feedtitle)][0], "rb") as f:
        return pickle.load(f)
    
    
def filter_audio_link(link):
    if link["type"].startswith("audio"):
        return True
    
def get_episode_meta(feedtitle, episode_link):
    meta = load_meta(feedtitle)
    episode_meta = [
        x for x in meta["entries"] 
        if list(filter(filter_audio_link, x["links"]))[0]["href"].split("/")[-1].startswith(episode_link)
    ][0]
    episode_meta["feed_meta"] = meta["feed"]
    return episode_meta


In [4]:
model_id = "multi-qa-mpnet-base-dot-v1"
model = SentenceTransformer(model_id)

In [5]:
SEGMENTS_PATH = Path("segments")
SEGMENTS_PATH.mkdir(exist_ok=True)

In [6]:
paths = list(Path("transcriptions").glob("**/*.pickle"))
batch_size = 10

for path in tqdm(paths):
    output_folder = SEGMENTS_PATH / path.parts[1] 
    output_folder.mkdir(exist_ok=True)
    output_path = output_folder / path.name
    if not output_path.exists():
        embeddings = []
        with open(path, "rb") as f:
            segments = pickle.load(f)["segments"]

        meta_data = get_episode_meta(
            feedtitle=path.parts[1], 
            episode_link=path.stem
        )
        prepared_segments = prepare_segments(segments, window=10, stride=3)
        embeddings = []
        for batch_segments in tqdm(batch(prepared_segments, batch_size), total=len(prepared_segments)/batch_size, leave=False):
            embeddings.append(model.encode([x["text"] for x in batch_segments]))
        for idx, embedding in enumerate([x for y in embeddings for x in y]):
            prepared_segments[idx]["embedding"] = embedding
            prepared_segments[idx]["meta"] = meta_data

        with open(output_path, "wb") as f:
            pickle.dump(prepared_segments, f)

  0%|          | 0/1121 [00:00<?, ?it/s]

IndexError: list index out of range