In [None]:
import os
import json
from json import JSONDecodeError
import numpy as np
from tqdm import tqdm

from collections import defaultdict

from util.constants import Topic

In [None]:
from sentence_transformers import SentenceTransformer

# Pretrained Sentence-BERT checkpoint used to embed every video title below.
# NOTE(review): the sentence-transformers docs list this checkpoint as deprecated;
# switching models would require regenerating all cached title latents on disk.
sbert_model = SentenceTransformer(model_name_or_path='bert-base-nli-mean-tokens')

Compute title latents (SBERT encodings) for all videos

In [None]:
# Read data: flatten every per-topic info file into one list of video records.
# Each file's values are iterated as lists of videos (presumably keyed by channel
# — TODO confirm); later cells read vid["id"] and vid["title"] from each record.
videos = []
for cat in Topic._member_names_:
    info_path = os.path.join("..", "data", "info_videos", f"videos-info_{cat}.json")
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding
    # (titles may contain non-ASCII characters).
    with open(info_path, "r", encoding="utf-8") as f:
        videos_info = json.load(f)
    videos.extend(vid for channel_vids in videos_info.values() for vid in channel_vids)
# Display a summary instead of dumping every video record into the output.
len(videos)

In [None]:
# Output locations for the title latents (SBERT encodings of video titles).
title_latent_dir = os.path.join("..", "data", "title-latents")
title_enc_dir = os.path.join(title_latent_dir, "videos")        # "<vid_id>_title-enc.json" per video
channel_stats_dir = os.path.join(title_latent_dir, "channels")  # "<channel>.json" per channel

def get_done_list(dir):
    """Return the ids whose JSON output file already exists in ``dir``.

    File names are mapped back to ids by dropping the ``.json`` extension and an
    optional ``_title-enc`` suffix, so this works for both the per-video encoding
    files and the per-channel stats files. Entries that are not ``.json`` files
    (e.g. ``.DS_Store``, ``.ipynb_checkpoints``) are skipped so they never reach
    the id lookups downstream. A missing directory means nothing is done yet and
    yields ``[]``. Ids are returned sorted so iteration order is reproducible.
    """
    if not os.path.isdir(dir):
        return []
    done = []
    for nm in os.listdir(dir):
        if not nm.endswith(".json") or not os.path.isfile(os.path.join(dir, nm)):
            continue
        stem = nm[: -len(".json")]
        if stem.endswith("_title-enc"):
            stem = stem[: -len("_title-enc")]
        done.append(stem)
    return sorted(done)

In [None]:
# Calculate latents: encode the title of every video without a saved encoding.
# Saved encodings live in `title_enc_dir` (the videos/ subdirectory); the parent
# `title_latent_dir` only holds the videos/channels/categories subdirectories.
vid2title_enc = {}
done_ids = set(get_done_list(title_enc_dir))  # set: O(1) membership test per video
for vid in tqdm(videos):
    vid_id = vid["id"]  # not `id`, which would shadow the builtin for the whole notebook
    # Skip videos already saved on disk or already encoded in this run
    # (a video may presumably appear in more than one topic file).
    if vid_id in done_ids or vid_id in vid2title_enc:
        continue
    vid2title_enc[vid_id] = sbert_model.encode(vid["title"])

In [None]:
# Save latents: one JSON list per video, named "<vid_id>_title-enc.json" so the
# file name can be mapped back to the video id when listing finished work.
os.makedirs(title_enc_dir, exist_ok=True)  # the directory may not exist on a first run
for vid_id, enc in vid2title_enc.items():
    path = os.path.join(title_enc_dir, f"{vid_id}_title-enc.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(enc.tolist(), f)

Channel stats: aggregate the title latents of each channel's videos

In [None]:
# Read data: mapping from video id to its channel, used below to group the
# title encodings by channel. Read as UTF-8 (the JSON standard encoding) rather
# than the platform's locale-dependent default.
with open(os.path.join("..", "data", "vid2channel.json"), "r", encoding="utf-8") as f:
    vid2channel = json.load(f)

In [None]:
# Make list of encodings for each channel: channel -> list of decoded title encodings.
channel_enc_list = defaultdict(list)
# List the per-video encoding files in title_enc_dir; listing title_latent_dir
# would yield the subdirectory names ("videos", "channels", ...) instead of ids.
for vid_id in tqdm(get_done_list(title_enc_dir)):
    channel = vid2channel[vid_id]
    filepath = os.path.join(title_enc_dir, f"{vid_id}_title-enc.json")
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            enc = json.load(f)
    except JSONDecodeError:
        # Corrupt/truncated file: delete it so the latents step re-encodes this
        # video, and skip it here. Without the `continue`, `enc` would still hold
        # the previous video's encoding (or be undefined on the first iteration)
        # and be appended to this channel.
        print(f"couldn't open {vid_id}; deleting file")
        os.remove(filepath)
        continue
    channel_enc_list[channel].append(enc)

In [None]:
# Calculate channel stats: summarize each channel's title encodings.
#   mean: element-wise mean encoding over the channel's videos
#   std:  spread of the channel's encodings around that mean, i.e. the
#         per-dimension std across videos (axis=0) averaged to one number.
#         A flattened .std() would also mix in the variation *between embedding
#         dimensions*, giving even a single-video channel a large "std".
#   len:  number of videos in the channel
channel_stats = {}
for channel, enc_list in channel_enc_list.items():
    # assumes all encodings have the same length -> 2-D (n_videos, dim) array
    encs = np.asarray(enc_list, dtype=float)
    channel_stats[channel] = {
        "mean": encs.mean(axis=0).tolist(),
        "std": float(encs.std(axis=0).mean()),
        "len": len(encs),
    }

In [None]:
# Save channel stats: one "<channel>.json" per channel in channel_stats_dir,
# the same directory the category step reads back from.
os.makedirs(channel_stats_dir, exist_ok=True)  # the directory may not exist on a first run
for channel, stats in channel_stats.items():
    filepath = os.path.join(channel_stats_dir, f"{channel}.json")
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(stats, f)

Category stats: aggregate the channel stats of each category

In [None]:
# Read data: mapping from channel to its category, used below to group the
# channel stats by category. Read as UTF-8 (the JSON standard encoding) rather
# than the platform's locale-dependent default.
with open(os.path.join("..", "data", "channel2category.json"), "r", encoding="utf-8") as f:
    channel2cat = json.load(f)

In [None]:
# Collect the saved stats dict of every channel, grouped by the channel's category.
category_stats_list = defaultdict(list)
for channel in tqdm(get_done_list(channel_stats_dir)):
    cat = channel2cat[channel]
    filepath = os.path.join(channel_stats_dir, f"{channel}.json")
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            stats = json.load(f)
    except JSONDecodeError:
        # Corrupt/truncated file: delete it and skip this channel. Without the
        # `continue`, `stats` would still hold the previous channel's stats (or be
        # undefined on the first iteration) and be appended to this category.
        print(f"couldn't open {channel}; deleting file")
        os.remove(filepath)
        continue
    category_stats_list[cat].append(stats)

In [None]:
# Calculate category stats from the per-channel stats of each category:
#   mean: unweighted average of the channels' "mean" vectors (each channel counts
#         once, regardless of how many videos it has)
#   std:  average of the channels' "std" values
#   len:  number of channels in the category (not number of videos)
category_stats = {}
for cat, channel_summaries in category_stats_list.items():
    channel_means = np.array([summary["mean"] for summary in channel_summaries])
    channel_stds = np.array([summary["std"] for summary in channel_summaries])
    category_stats[cat] = dict(
        mean=channel_means.mean(axis=0).tolist(),
        std=channel_stds.mean(),
        len=len(channel_means),
    )

In [None]:
# Save category stats: one "<category>.json" per category, alongside the
# "videos" and "channels" subdirectories of title_latent_dir.
category_stats_dir = os.path.join(title_latent_dir, "categories")
os.makedirs(category_stats_dir, exist_ok=True)  # the directory may not exist on a first run
for cat, stats in category_stats.items():
    filepath = os.path.join(category_stats_dir, f"{cat}.json")
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(stats, f)