In [None]:
import json
import os
import pickle
from collections import defaultdict
from json import JSONDecodeError

import numpy as np
import torch
from tqdm import tqdm

from util.constants import Topic

In [None]:
# Models
from sentence_transformers import SentenceTransformer

# Prefer the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Sentence-BERT encoder used by `model` below, moved onto the selected device.
sbert_model = SentenceTransformer('bert-base-nli-mean-tokens')
sbert_model = sbert_model.to(device)

def model(titles):
    """Encode titles with the shared Sentence-BERT model.

    Args:
        titles: Titles to embed; handed unchanged to ``sbert_model.encode``.

    Returns:
        Whatever ``sbert_model.encode`` returns for ``titles``.
    """
    latents = sbert_model.encode(titles)
    return latents

Get latents for all videos

In [None]:
# Read data
# Gather every video entry from the per-topic info files into one flat list.
videos = []
info_dir = os.path.join("..", "data", "info_videos")
for cat in Topic._member_names_:
    info_path = os.path.join(info_dir, F"videos-info_{cat}.json")
    with open(info_path, "r") as f:
        videos_info = json.load(f)
    # Each file is a mapping whose values are lists of video entries;
    # append them in file order.
    for channel_vids in videos_info.values():
        videos.extend(channel_vids)

In [None]:
RESULTS_DIR = os.path.join("..", "data", "title-latents")

# NOTE(review): unlike RESULTS_DIR, the per-video latents are rooted two levels up
# under "DATA" rather than under "../data" — confirm this split is intentional.
video_results_dir = os.path.join("..","..","DATA","title-latents","videos")
channel_results_dir = os.path.join(RESULTS_DIR, "channels")

# Create both output directories up front. exist_ok=True keeps the cell safe to
# re-run, and creating the channel directory here means the per-channel writes and
# directory listings further down do not fail on a fresh checkout.
for results_dir in (video_results_dir, channel_results_dir):
    os.makedirs(results_dir, exist_ok=True)

def get_done_list(dir):
    """Return the names of the results already stored in ``dir``.

    Every directory entry is returned with a trailing ``.json`` or ``.pt``
    extension removed. Only the final extension is stripped, so a name that
    merely contains one of those substrings is returned intact. Entries with
    any other extension (or none) are returned unchanged.

    Args:
        dir: Path of the directory to list.

    Returns:
        A list of entry names, in ``os.listdir`` order.
    """
    done = []
    for name in os.listdir(dir):
        stem, ext = os.path.splitext(name)
        done.append(stem if ext in (".json", ".pt") else name)
    return done

In [None]:
# Run the code in batches
# Encode the title of every video that has no latent file yet and save one
# `<video id>.pt` file per video in video_results_dir.
done_list = get_done_list(video_results_dir)
vid2title = {v["id"]:v["title"] for v in videos}
all_ids = [vid["id"] for vid in videos]
# The set removes duplicate ids; ids that already have a result file are dropped.
todo_ids = list(set(all_ids).difference(done_list))

batch_size = 512
# Ceiling division so the final, partially filled batch is processed as well.
# (`n // b` and `int(n / b)` are both floor division for non-negative ints, so
# comparing them cannot detect a remainder.)
batch_num = (len(todo_ids) + batch_size - 1) // batch_size

for batch in tqdm(range(batch_num)):
    ids = todo_ids[batch*batch_size:(batch+1)*batch_size]
    titles = [vid2title[vid_id] for vid_id in ids]

    batch_latents = model(titles)

    # One file per video, so an interrupted run can resume from done_list.
    for vid_id, result in zip(ids, batch_latents):
        path = os.path.join(video_results_dir, f"{vid_id}.pt")
        torch.save(result, path)

    torch.cuda.empty_cache()

Channel stats

In [None]:
# Read data
# Merge the per-topic info files into a single channel -> videos mapping.
channel_videos_dict = {}
for cat in Topic._member_names_:
    info_path = os.path.join("..", "data", "info_videos", F"videos-info_{cat}.json")
    with open(info_path, "r") as f:
        topic_channels = json.load(f)
    channel_videos_dict.update(topic_channels)

In [None]:
# Calculate channel results
# For every channel, load the saved latent of each of its videos and write
#   <channel_results_dir>/<channel>.json    -> {"std": ..., "len": ...}
#   <RESULTS_DIR>/channels_mean/<channel>.pt -> mean latent over the channel's videos
channel_mean_dir = os.path.join(RESULTS_DIR, "channels_mean")
# Both output directories must exist before the first write below.
os.makedirs(channel_results_dir, exist_ok=True)
os.makedirs(channel_mean_dir, exist_ok=True)

# The loop variable is deliberately not called `videos`, so the flat video list
# built earlier in the notebook is not overwritten by running this cell.
for channel, channel_vids in tqdm(channel_videos_dict.items()):
    result_list = []
    for vid in channel_vids:
        vid_id = vid["id"]
        filepath = os.path.join(video_results_dir, vid_id+".pt")
        try:
            result_list.append(torch.load(filepath))
        except FileNotFoundError:
            print(f"couldn't find {vid_id}")
        except (EOFError, pickle.UnpicklingError, RuntimeError) as err:
            # torch.load does not raise JSONDecodeError; these are the errors a
            # truncated or otherwise unreadable file is expected to produce.
            # NOTE(review): confirm the exact set for the installed torch version.
            # The file is left on disk: a load failure can also stem from a
            # version/format mismatch, and deleting would then destroy valid data.
            print(f"couldn't open {vid_id}; skipping ({err})")

    if not result_list:
        # Nothing could be loaded; mean/std over an empty array would only yield
        # NaNs that break the per-category aggregation later on.
        print(f"no latents for channel {channel}; skipping")
        continue

    result_list = np.array(result_list)
    channel_mean = result_list.mean(axis=0)
    channel_result = {
        # Standard deviation over all values of all loaded latents (no axis given).
        "std": float(result_list.std()),
        "len": len(result_list),
    }

    filepath = os.path.join(channel_results_dir, f"{channel}.json")
    with open(filepath, "w") as f:
        json.dump(channel_result, f)
    filepath = os.path.join(channel_mean_dir, f"{channel}.pt")
    torch.save(channel_mean, filepath)

Category stats

In [None]:
# Read data
# Load the lookup that is indexed by channel to obtain its category.
channel2cat_path = os.path.join("..", "data", "channel2category.json")
with open(channel2cat_path, "r") as f:
    channel2cat = json.load(f)

In [None]:
# Make list of results per channel for each category
# Each appended entry holds the channel's JSON stats ("std", "len") plus its mean
# latent under "mean". A channel is only added when BOTH files could be read, so
# the aggregation cell never sees an entry with missing keys.
category_results_list = defaultdict(list)
for channel in tqdm(get_done_list(channel_results_dir)):
    cat = channel2cat[channel]
    results = {}
    complete = True
    for folder,ext in [("channels",".json"), ("channels_mean",".pt")]:
        filepath = os.path.join(RESULTS_DIR, folder, channel+ext)
        try:
            if ext == ".json":
                with open(filepath, "r") as f:
                    results.update(json.load(f))
            elif ext == ".pt":
                results.update({"mean": torch.load(filepath)})
        except JSONDecodeError:
            # Only the JSON branch can raise this; drop the broken stats file so
            # it gets rewritten the next time the channel stats are computed.
            print(f"couldn't open {channel}; deleting file")
            os.remove(filepath)
            complete = False
        except FileNotFoundError:
            print(f"couldn't find {filepath}")
            complete = False
        except (EOFError, pickle.UnpicklingError, RuntimeError) as err:
            # Errors an unreadable .pt file is expected to produce.
            # NOTE(review): confirm the exact set for the installed torch version.
            # The file is kept, since the failure may be a format/version mismatch.
            print(f"couldn't open {filepath}; skipping ({err})")
            complete = False
    if complete:
        category_results_list[cat].append(results)

In [None]:
# Calculate category results
# For every category, write
#   <RESULTS_DIR>/categories/<cat>.json    -> {"std": ..., "len": ...}
#   <RESULTS_DIR>/categories_mean/<cat>.pt -> mean of the channel mean latents
category_stats_dir = os.path.join(RESULTS_DIR, "categories")
category_mean_dir = os.path.join(RESULTS_DIR, "categories_mean")
# Make sure both output directories exist before the first write below.
os.makedirs(category_stats_dir, exist_ok=True)
os.makedirs(category_mean_dir, exist_ok=True)

for cat,stats_list in category_results_list.items():
    mean_list = np.array([channel_stats["mean"] for channel_stats in stats_list])
    std_list = np.array([channel_stats["std"] for channel_stats in stats_list])
    category_mean = mean_list.mean(axis=0)
    category_result = {
        # Average of the per-channel "std" values, stored as a plain Python float
        # just like the per-channel "std".
        "std": float(std_list.mean()),
        # Number of channels that contributed to this category.
        "len": len(mean_list),
    }

    filepath = os.path.join(category_stats_dir, f"{cat}.json")
    with open(filepath, "w") as f:
        json.dump(category_result, f)
    filepath = os.path.join(category_mean_dir, f"{cat}.pt")
    torch.save(category_mean, filepath)