In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import json
import matplotlib.pyplot as plt
import spacy
from tqdm import tqdm
# Request GPU execution when available; this is called before the pipeline
# is loaded so the model is created on the selected device.
spacy.prefer_gpu()

# Dataset root. NOTE(review): absolute, machine-specific path -- adjust it
# (or make it configurable) when running the notebook elsewhere.
root_dir = Path('/data/healthy-ml/gobi1/data/MM-IMDB/mmimdb')

# Load the transformer pipeline exactly once. The original cell first called
# spacy.load("en_core_web_trf") and then immediately rebound `nlp` with
# en_core_web_trf.load(), paying the (expensive) model load twice for the
# same pipeline. Only the load that determined the final `nlp` is kept.
import en_core_web_trf
nlp = en_core_web_trf.load()

In [None]:
def extract_nouns(x):
    """Collect the normalised text of every token tagged as a common noun.

    Args:
        x: An iterable of token-like objects exposing ``text`` and ``pos_``
            attributes (e.g. a parsed spaCy document).

    Returns:
        A list, in input order, of ``token.text`` lower-cased and stripped of
        surrounding whitespace, for each token whose ``pos_`` equals
        ``'NOUN'``. Duplicates are preserved.
    """
    nouns = []
    for token in x:
        # Only the exact 'NOUN' tag qualifies; other tags are skipped.
        if token.pos_ != 'NOUN':
            continue
        nouns.append(token.text.lower().strip())
    return nouns

In [None]:
# Build one record per metadata JSON file found under root_dir.
raw = []
# Only these metadata fields are copied; a field absent from a file is simply
# left out of that record (and becomes a null once the DataFrame is built).
keep_keys = ['genres', 'title', 'plot', 'plot outline']
# sorted() makes the row order independent of filesystem enumeration order.
for json_path in sorted(root_dir.glob('**/*.json')):
    # split.json holds the train/dev/test id lists, not per-item metadata.
    if json_path.name == 'split.json':
        continue
    # Use a context manager so each file handle is closed right away instead
    # of leaking one open handle per file; read as UTF-8 (the JSON standard
    # encoding) rather than whatever the locale default happens to be.
    with json_path.open('r', encoding='utf-8') as handle:
        data = json.load(handle)
    record = {k: data[k] for k in keep_keys if k in data}
    # Companion file: same location and base name, '.jpeg' extension.
    record['path'] = str(json_path.with_suffix('.jpeg'))
    # File name without the '.json' extension.
    record['id'] = json_path.stem
    raw.append(record)

In [None]:
# One row per collected record; keys missing from a record become nulls.
df = pd.DataFrame(raw)

In [None]:
# Sanity check: number of rows (records) loaded.
len(df)

In [None]:
# Load the split definition. A context manager closes the file handle
# deterministically (the original left it open), and the file is read as
# UTF-8 rather than with the locale-dependent default encoding.
with (root_dir / 'split.json').open('r', encoding='utf-8') as split_file:
    split = json.load(split_file)

In [None]:
# Tag every row with the split its id belongs to. Rows whose id appears in
# none of the lists keep None. The source key 'dev' is stored as 'val'.
# Order matters: a later entry overwrites an earlier one for shared ids.
df['split'] = None
for split_key, split_name in (('train', 'train'), ('dev', 'val'), ('test', 'test')):
    in_split = df['id'].isin(split[split_key])
    df.loc[in_split, 'split'] = split_name

In [None]:
# Per-column count of missing values.
df.isna().sum(axis=0)

In [None]:
# Display the assembled frame for a visual check.
df

In [None]:
def _join_plot(plot):
    """Return a plot entry as one newline-joined string.

    Strings are returned unchanged so that re-running this cell is harmless:
    the original unconditionally called '\\n'.join(...), which on a second run
    would insert a newline between every character of the already-joined text.
    A list/tuple of strings is joined with newlines. Anything else (e.g. the
    null left behind when a record had no 'plot' key) becomes '', mirroring
    how missing plot outlines are handled.
    """
    if isinstance(plot, str):
        return plot
    if isinstance(plot, (list, tuple)):
        return '\n'.join(plot)
    return ''


# Missing outlines become empty strings so they can be concatenated below.
df['plot outline'] = df['plot outline'].fillna('')
df['plot'] = df['plot'].apply(_join_plot)
# Vectorised concatenation: outline, a newline, then the joined plot text.
df['all_text'] = df['plot outline'] + '\n' + df['plot']

In [None]:
# Sorted array of every distinct genre across all rows.
labels = np.unique([
    genre
    for genre_list in df['genres'].values
    for genre in genre_list
])
# Genre -> integer id, i.e. the genre's position in the sorted `labels`.
label_mapping = dict(zip(labels, range(len(labels))))

In [None]:
# Encode each row's genre list as the list of matching integer ids.
df['cat_labels'] = df['genres'].apply(
    lambda genres: [label_mapping[genre] for genre in genres]
)

In [None]:
# Parse every plot outline with the spaCy pipeline (single process), with a
# progress bar sized to the number of rows.
df['spacy_doc'] = list(tqdm(nlp.pipe(df['plot outline'], n_process = 1), total = len(df)))
df['nouns'] = df['spacy_doc'].apply(extract_nouns)
# Sort the vocabulary before assigning ids. Iteration order of a set of
# strings varies between interpreter runs (string hash randomisation), so the
# original tuple(set(...)) gave each noun a different integer id on every
# run; sorting makes noun_vocab_mapping -- and the ids derived from it --
# reproducible.
noun_vocab = tuple(sorted({noun for nouns in df['nouns'] for noun in nouns}))
noun_vocab_mapping = {noun: index for index, noun in enumerate(noun_vocab)}

In [None]:
# Encode each row's noun list as the list of matching vocabulary ids.
df['nouns_int'] = df['nouns'].apply(
    lambda nouns: [noun_vocab_mapping[noun] for noun in nouns]
)

In [None]:
# Expose the plot outline under the column name 'sentence' as well.
# NOTE(review): presumably the name a downstream consumer expects -- confirm.
df['sentence'] = df['plot outline']

In [None]:
# Persist the processed frame next to the dataset, leaving out the parsed
# spaCy documents column.
(
    df
    .drop(columns=['spacy_doc'])
    .to_pickle(root_dir / 'multimodal_mislabel_split.pkl')
)