In [None]:
import os
import json
import torch

from json import JSONDecodeError
import numpy as np
from tqdm import tqdm
from PIL import Image
# from thumbnail_repr_stats import ImageLatentRepresentationModel, load_model, get_latent_vectors
from transformers import ViTForImageClassification, ViTFeatureExtractor
from collections import defaultdict
from util.constants import Topic

In [None]:
# Models

class ImageLatentRepresentationModel(ViTForImageClassification):
    """
    ViT wrapper that returns an image embedding instead of class logits.

    The forward pass is adapted from HuggingFace's ViTForImageClassification
    (https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py):
    only the ViT encoder (``self.vit``) is run, the classification head is
    skipped, and the output vector at token position 0 is returned.
    """

    def __init__(self, config):
        # No extra state; everything is set up by the parent constructor.
        super().__init__(config)

    def forward(self, pixel_values):
        """
        Compute the latent representation of a batch of images.

        args:
            - pixel_values: preprocessed images as a PyTorch tensor with the
              batch dimension first (presumably [batch, 3, 224, 224] for the
              'google/vit-base-patch16-224' extractor — TODO confirm)

        returns:
            - latent_vec: the first element of the encoder output, sliced at
              token position 0 for every image in the batch
        """
        encoder_output = self.vit(pixel_values)
        hidden_states = encoder_output[0]
        # Token 0 is the [CLS] token; keep all batch entries and all features.
        return hidden_states[:, 0, :]

Compute thumbnail latent vectors for all videos

In [None]:
# Read data: collect every video entry from the per-topic info files
videos = []
for topic_name in Topic._member_names_:
    info_path = os.path.join("..", "data", "info_videos", f"videos-info_{topic_name}.json")
    with open(info_path, "r") as f:
        videos_info = json.load(f)
    # videos_info presumably maps channel -> list of video dicts; flatten it
    for channel_vids in videos_info.values():
        videos.extend(channel_vids)

In [None]:
RESULTS_DIR = os.path.join("..", "data", "thumbnail-latents")

video_results_dir = os.path.join(RESULTS_DIR, "videos")
channel_results_dir = os.path.join(RESULTS_DIR, "channels")
categories_results_dir = os.path.join(RESULTS_DIR, "categories")

# Create every output directory up front: per-video latents, channel stats and
# category stats are all written to disk later in the notebook.
# exist_ok=True keeps this cell safe to re-run.
for results_dir in (video_results_dir, channel_results_dir, categories_results_dir):
    os.makedirs(results_dir, exist_ok=True)

def get_done_list(dir):
    """
    List the ids that already have a saved ``<id>.json`` result in ``dir``.

    args:
        - dir: directory containing one JSON result file per item
    returns:
        - list of the ".json" file names with that suffix removed, in
          os.listdir order
    """
    suffix = ".json"
    # Only consider .json files (skip e.g. .DS_Store or checkpoint folders), and
    # strip just the trailing suffix; str.replace would also drop ".json"
    # occurring elsewhere in the name.
    return [nm[:-len(suffix)] for nm in os.listdir(dir) if nm.endswith(suffix)]

In [None]:
def load_model(device, install=True, model_name='google/vit-base-patch16-224'):
    """
    Loads in a pre-trained ViT model for latent image representation.

    args:
        - device: what device the model should be on
        - install: if True, first pip-install transformers from its GitHub main
          branch into the running interpreter's environment. A `transformers`
          module already imported in this kernel is not reloaded, so the newly
          installed version only takes effect after a kernel restart.
        - model_name: HuggingFace checkpoint used for both the model and the
          feature extractor
    returns:
        - model: pre-trained ViT model, in eval mode and moved to `device`
        - feature_extractor: pre-trained feature extractor for processing images
    raises:
        - subprocess.CalledProcessError: if install=True and pip fails
    """
    if install:
        import subprocess
        import sys

        print("Installing ViT architecture... This may take a couple of minutes.")
        # `sys.executable -m pip` installs into the kernel's own environment (a
        # bare `pip` on PATH may belong to another interpreter), and check_call
        # raises on a non-zero exit status instead of silently continuing.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "-q",
             "git+https://github.com/huggingface/transformers.git"]
        )
        print("Finished installing.")

    print("Loading pretrained ViT model...")
    model = ImageLatentRepresentationModel.from_pretrained(model_name)
    model.eval()
    model.to(device)
    print("Loaded model")

    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

    return model, feature_extractor

In [None]:
def get_latent_vectors(model, ft_extr, raw_imgs, device):
    """
    Function to get the latent representation of a batch of images.

    args:
        - model: The pretrained ViT model (any callable taking a pixel tensor)
        - ft_extr: The feature extractor, used to preprocess the images; it must
          return a mapping containing a 'pixel_values' tensor
        - raw_imgs: The raw thumbnail images
        - device: device the pixel tensor is moved to (should match the model's)
    returns:
        - latent_vecs: the model output for the encoded batch, computed with
          autograd disabled
    """
    # Encode the images using the feature extractor
    encodings = ft_extr(images=raw_imgs, return_tensors="pt")
    pixel_values = encodings['pixel_values'].to(device)

    # Pure inference: without no_grad, autograd keeps every intermediate
    # activation alive for a backward pass that never happens, inflating
    # (GPU) memory for each batch.
    with torch.no_grad():
        latent_vecs = model(pixel_values)
    return latent_vecs

In [None]:
def generate_repr_stats(model, feature_extractor, videos_path, device=None):
    """
    Compute latent vectors for a batch of thumbnail images on disk.

    args:
        - model: pre-trained ViT model
        - feature_extractor: pre-trained feature extractor for processing images
        - videos_path: list of paths to the thumbnail images to load
        - device: device to run on; if None, the device of the model's first
          parameter is used
    returns:
        - list with one latent vector (list of floats) per path, in input order
    """
    if device is None:
        device = next(model.parameters()).device

    video_thumbnails = []
    for vid_path in videos_path:
        # Image.open is lazy and keeps the file open; convert() loads the pixels
        # into a new image, so the handle can be closed right away. Forcing RGB
        # also gives grayscale/CMYK/RGBA thumbnails the 3 channels ViT expects.
        with Image.open(vid_path) as img:
            video_thumbnails.append(img.convert("RGB"))

    latents = get_latent_vectors(model, feature_extractor, video_thumbnails, device)
    return latents.detach().cpu().numpy().tolist()


In [None]:
# Run the code in batches
batch_size = 32
# Ceiling division so a trailing partial batch is processed too.
batch_num = -(-len(videos) // batch_size)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")

model, feature_extractor = load_model(device, install=False)

THUMBNAILS_DIR = os.path.join("..", "data", "thumbnails")

for batch in tqdm(range(batch_num)):
    vid_batch = videos[batch*batch_size:(batch+1)*batch_size]

    # Check each thumbnail once and keep ids and paths paired, so the latents
    # written below always line up with the right video id.
    ids, imgs_paths = [], []
    for vid in vid_batch:
        img_path = os.path.join(THUMBNAILS_DIR, vid["id"] + "_high.jpg")
        if os.path.isfile(img_path):
            ids.append(vid["id"])
            imgs_paths.append(img_path)

    # Prevent code from crashing when not having any thumbnails available
    if imgs_paths:
        batch_latents = generate_repr_stats(model, feature_extractor, imgs_paths)

        for vid_id, result in zip(ids, batch_latents):
            path = os.path.join(video_results_dir, f"{vid_id}.json")
            with open(path, "w") as f:
                json.dump(result, f)

    torch.cuda.empty_cache()

Channel stats: aggregate the per-video latents of each channel

In [None]:
# Read data: mapping from video id to its channel
vid2channel_path = os.path.join("..", "data", "vid2channel.json")
with open(vid2channel_path, "r") as f:
    vid2channel = json.load(f)

In [None]:
# Make list of results per video for each channel
channel_result_list = defaultdict(list)
for vid_id in tqdm(get_done_list(video_results_dir)):
    channel = vid2channel[vid_id]
    filepath = os.path.join(video_results_dir, f"{vid_id}.json")
    try:
        with open(filepath, "r") as f:
            result = json.load(f)
    except JSONDecodeError:
        print(f"couldn't open {vid_id}; deleting file")
        os.remove(filepath)
        # Skip this video: otherwise `result` would still hold the previous
        # video's latent (or be undefined) and be appended to this channel.
        continue
    channel_result_list[channel].append(result)

In [None]:
# Calculate channel results
channel_results = {}
for channel, result_list in channel_result_list.items():
    # One row per video, one column per latent dimension.
    latents = np.array(result_list)
    channel_results[channel] = {
        # Per-dimension centroid of the channel's thumbnails.
        "mean": latents.mean(axis=0).tolist(),
        # Spread of the thumbnails around that centroid: std of each dimension
        # across videos, averaged over dimensions. A std of the flattened array
        # would also absorb differences *between* dimensions and be non-zero
        # even for a channel whose thumbnails are all identical.
        "std": float(latents.std(axis=0).mean()),
        "len": len(latents),
    }

In [None]:
# Save channel stats
# The output directory may not exist yet on a fresh checkout; creating it is
# harmless when re-running.
os.makedirs(channel_results_dir, exist_ok=True)
for channel, results in channel_results.items():
    filepath = os.path.join(channel_results_dir, f"{channel}.json")
    with open(filepath, "w") as f:
        json.dump(results, f)

Category stats: aggregate the per-channel statistics of each category

In [None]:
# Read data: mapping from channel to its category
channel2cat_path = os.path.join("..", "data", "channel2category.json")
with open(channel2cat_path, "r") as f:
    channel2cat = json.load(f)

In [None]:
# Make list of results per channel for each category
category_results_list = defaultdict(list)
for channel in tqdm(get_done_list(channel_results_dir)):
    cat = channel2cat[channel]
    filepath = os.path.join(channel_results_dir, f"{channel}.json")
    try:
        with open(filepath, "r") as f:
            results = json.load(f)
    except JSONDecodeError:
        print(f"couldn't open {channel}; deleting file")
        os.remove(filepath)
        # Skip this channel: otherwise `results` would still hold the previous
        # channel's stats (or be undefined) and be appended to this category.
        continue
    category_results_list[cat].append(results)

In [None]:
# Calculate category results: unweighted average over the category's channels
# (every channel counts once, regardless of how many videos it has)
category_results = {}
for cat, stats_list in category_results_list.items():
    channel_means = np.array([s["mean"] for s in stats_list])
    channel_stds = np.array([s["std"] for s in stats_list])
    category_results[cat] = dict(
        mean=channel_means.mean(axis=0).tolist(),
        std=channel_stds.mean(),
        len=len(channel_means),
    )

In [None]:
# Save category results
# Reuse the shared directory variable instead of rebuilding the path by hand,
# and make sure it exists (harmless when re-running).
os.makedirs(categories_results_dir, exist_ok=True)
for cat, stats in category_results.items():
    filepath = os.path.join(categories_results_dir, f"{cat}.json")
    with open(filepath, "w") as f:
        json.dump(stats, f)