In [None]:
# Change the working directory two levels up from wherever the kernel currently is.
# NOTE(review): presumably this lands on the repository root so the `emv` package
# can be imported — confirm. The path is relative, so re-running this cell moves
# up another two levels; restart the kernel before running it again.
%cd ../..
# Load IPython's autoreload extension.
%load_ext autoreload

# Mode 2: reload imported modules before each cell execution, so edits made to
# project modules are picked up without restarting the kernel.
%autoreload 2

In [None]:
from emv.db.dao import DataAccessObject
from emv.db.queries import get_features_by_type_paginated, count_features_by_type
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from ast import literal_eval
import numpy as np
from tqdm import tqdm
from emv.api.models import Feature
from emv.api.models import Projection, MapProjectionFeatureCreate
from emv.db.queries import create_projection, create_map_projection_feature, create_feature
from emv.io.media import create_square_atlases
from umap import UMAP
import numba
import cv2
from PIL import Image
from sqlalchemy.sql import text
from datetime import datetime
import textwrap as tw

from emv.db.queries import get_all_media_by_library_id, get_library_id_from_name, get_library_from_name, check_media_exists, get_media_by_id, delete_feature_by_type
from emv.storage.storage import get_storage_client
from emv.features.image import embed_images

from transformers import pipeline

# Load data

In [None]:
# Ask the database how many "transcript+ner" features exist, restricted by the
# short_clips_only flag, and report the count.
TRANSCRIPT_FEATURE_TYPE = "transcript+ner"
total_features = count_features_by_type(
    TRANSCRIPT_FEATURE_TYPE,
    short_clips_only=True,
)
print(f"Total features: {total_features}")

In [None]:
# Download every "transcript+ner" feature page by page. Each follow-up request
# resumes after the feature_id of the last row already received.
PAGE_SIZE = 10000
MAX_FEATURES = total_features + 1  # bounds the number of follow-up pages requested
data = get_features_by_type_paginated("transcript+ner", page_size=PAGE_SIZE, short_clips_only=True)

for _ in tqdm(range(MAX_FEATURES // PAGE_SIZE)):
    # Nothing was returned at all: there is no last row to resume from.
    if not data:
        break
    last_seen_id = data[-1].get("feature_id", None)
    if last_seen_id is None:
        break
    next_page = get_features_by_type_paginated(
        "transcript+ner",
        page_size=PAGE_SIZE,
        last_seen_feature_id=last_seen_id,
        short_clips_only=True,
    )
    # An empty page means the listing is exhausted; asking again with the same
    # last_seen_id would only repeat the same request.
    if not next_page:
        break
    data.extend(next_page)

In [None]:
# Keep only the identifiers and the payload of each downloaded feature;
# every other field is dropped.
records = [
    {
        "feature_id": item["feature_id"],
        "media_id": item["media_id"],
        "data": item["data"],
    }
    for item in tqdm(data)
]

# Discard rows with a missing value, then renumber the index from zero.
df = pd.DataFrame(records).dropna().reset_index(drop=True)
print(f"Retrieved {len(df)} instances")

In [None]:
# Download the "locations" features page by page: one initial page plus at most
# MAX_FEATURES // PAGE_SIZE follow-up pages, each resuming after the last
# feature_id already received.
MAX_FEATURES = 100000
PAGE_SIZE = 10000
location_rows = get_features_by_type_paginated("locations", page_size=PAGE_SIZE)

for _ in tqdm(range(MAX_FEATURES // PAGE_SIZE)):
    # Nothing was returned at all: there is no last row to resume from.
    if not location_rows:
        break
    last_seen_id = location_rows[-1].get("feature_id", None)
    if last_seen_id is None:
        break
    next_page = get_features_by_type_paginated(
        "locations",
        page_size=PAGE_SIZE,
        last_seen_feature_id=last_seen_id,
    )
    # An empty page means the listing is exhausted; asking again with the same
    # last_seen_id would only repeat the same request.
    if not next_page:
        break
    location_rows.extend(next_page)

# `features` is only ever bound to the DataFrame, never to the raw list.
features = pd.DataFrame(location_rows)
print(f"Retrieved {len(features)} instances")

In [None]:
# Rename the transcript payload column so it stays recognisable after the join,
# then left-join it onto the location features by media_id. Rows of `features`
# without a match in `df` are kept.
df = df.rename(columns={"data": "transcript_data"})
features = pd.merge(features, df, on="media_id", how="left")

In [None]:
def _payload_field(payload, key):
    """Return payload.get(key), or None when payload has no usable .get method.

    The left merge that builds `features` leaves NaN (a float) in
    "transcript_data" for media without a transcript feature; calling .get on
    that value would raise AttributeError.
    """
    getter = getattr(payload, "get", None)
    return getter(key, None) if callable(getter) else None


# Pull the transcript text and the entities out of each transcript payload;
# rows without a payload get None in both columns.
features["transcript"] = features["transcript_data"].map(lambda x: _payload_field(x, "transcript"))
features["entities"] = features["transcript_data"].map(lambda x: _payload_field(x, "entities"))

In [None]:
# Candidate text-classification checkpoints; the last entry is the one loaded below.
models = [
    "tabularisai/multilingual-sentiment-analysis",
    "SamLowe/roberta-base-go_emotions",
    "nlptown/bert-base-multilingual-uncased-sentiment",
    "lxyuan/distilbert-base-multilingual-cased-sentiments-student",
]

# NOTE(review): `return_all_scores` appears to be deprecated in recent
# transformers releases in favour of `top_k=None` — confirm before upgrading.
pipe = pipeline(
    task="text-classification",
    model=models[-1],
    return_all_scores=True,
)

In [None]:
# Try the sentiment pipeline on the first ten transcripts: print the scores,
# then the transcript wrapped at 100 columns. Missing transcripts are skipped
# and any failure is reported without stopping the loop.
first_transcripts = features.transcript.values[:10]
for transcript in first_transcripts:
    if transcript is not None:
        try:
            scores = pipe(transcript)
            print(scores)
            print(tw.fill(transcript, 100))
            print()
        except Exception as err:
            print(f"Error processing transcript: {err}")

## Theme classification

In [None]:
# Zero-shot classification pipeline used to assign a theme to each transcript.
ZERO_SHOT_MODEL = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
classifier = pipeline(task="zero-shot-classification", model=ZERO_SHOT_MODEL)

In [None]:
# Candidate theme labels (in French) passed to the zero-shot classifier.
themes_possibles = [
    "Information & Actualité", "Débats & Talk Shows",
    "Société & Monde", "Culture & Connaissance",
    "Arts & Spectacles", "Musique",
    "Sport", "Fiction & Divertissement",
    "Jeunesse", "Religion & Spiritualité",
]

In [None]:
# Draw the subset of rows that will be theme-classified.
SAMPLE_SIZE = 10000
RANDOM_SEED = 42  # fixed so a re-run draws the same rows

# Only rows with a non-empty text transcript can be classified; a missing or
# blank transcript would make the classifier call fail.
has_transcript = (
    features["transcript"]
    .map(lambda t: isinstance(t, str) and bool(t.strip()))
    .astype(bool)
)
candidates = features[has_transcript]

# Cap n at the number of candidates: sampling more rows than exist (without
# replacement) raises ValueError. The copy gives `sample` its own data, so
# columns added to it later cannot affect `features`.
sample = candidates.sample(
    n=min(SAMPLE_SIZE, len(candidates)),
    random_state=RANDOM_SEED,
).copy()

In [None]:
# Classify every sampled transcript against the candidate themes and keep the
# classifier's raw result for each row.
transcripts = sample["transcript"]
sample["theme_llm"] = transcripts.apply(
    lambda transcript: classifier(transcript, themes_possibles)
)

In [None]:
# Take the first label and the first score from each classifier result.
# NOTE(review): presumably the pipeline orders labels by descending score, so
# index 0 is the best match — confirm.
llm_results = sample["theme_llm"]
sample["theme"] = llm_results.apply(lambda result: result["labels"][0])
sample["theme_score"] = llm_results.apply(lambda result: result["scores"][0])

In [None]:
# Number of sampled rows assigned to each theme.
sample["theme"].value_counts()

In [None]:
# Distribution of the first-label scores, in 20 bins.
sample["theme_score"].hist(bins=20)

In [None]:
# Theme counts restricted to rows whose first-label score exceeds 0.8.
sample.loc[sample["theme_score"] > 0.8, "theme"].value_counts()

In [None]:
# Preview the first five rows of the classified sample.
sample.head(n=5)