In [1]:
import os
from glob import glob

import nltk
import numpy as np
import pandas as pd
import torch
from nltk import sent_tokenize
from transformers import pipeline

In [3]:
# Sentence-tokenizer data for nltk.sent_tokenize. Newer NLTK releases load the
# 'punkt_tab' resource instead of the pickled 'punkt' models, so fetch both to
# keep the notebook working across NLTK versions.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

# Load Model

In [6]:
# Zero-shot NLI checkpoint used for theme classification.
model_name = "facebook/bart-large-mnli"

# Run on the first CUDA GPU when one is available, otherwise on CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = 'cpu'

In [7]:
def load_model(device):
    """Build the Hugging Face zero-shot-classification pipeline.

    The checkpoint is taken from the module-level ``model_name``.

    Args:
        device: Forwarded to ``pipeline(device=...)``, e.g. ``0`` for the
            first GPU or ``'cpu'``.

    Returns:
        The zero-shot classification pipeline.
    """
    return pipeline("zero-shot-classification", model=model_name, device=device)

In [8]:
# Instantiate the classifier once; the cells below reuse this global.
theme_classifier = load_model(device)

In [9]:
# Candidate labels passed to the zero-shot classifier.
theme_list = [
    "friendship",
    "hope",
    "sacrifice",
    "battle",
    "self development",
    "betrayal",
    "love",
    "dialogue",
]

In [None]:
# Sanity check on one sentence. multi_label=True scores each theme on its own
# instead of normalizing the scores across all themes.
theme_classifier(
    "I gave him a right hook then a left jab",
    theme_list,
    multi_label=True
)

# Load Dataset

In [None]:
import os

# Check that the subtitle directory is reachable from the current working
# directory before globbing it.
if os.path.exists('./data/original/subtitlist'):
    print("OK")
else:
    print("NO")

# glob returns files in arbitrary, OS-dependent order; sort so that files[0]
# (used below) is the same file on every run.
# NOTE(review): the dataset is later loaded from "../data/Subtitles", a
# different directory — confirm which location is intended.
files = sorted(glob('data/original/subtitlist/*.ass'))
print(len(files))

In [None]:
# Preview the first few matched subtitle paths.
files[:5]

In [None]:
# Peek at the first subtitle file: skip the first 27 lines (assumed header),
# then keep everything after the 9th comma of each line — the free-text field,
# which may itself contain commas, so the tail is re-joined.
# Explicit encoding: the platform default (e.g. cp1252 on Windows) cannot
# decode many UTF-8 subtitle files.
with open(files[0], 'r', encoding='utf-8') as file:
    lines = file.readlines()[27:]
lines = [",".join(line.split(',')[9:]) for line in lines]

In [None]:
# Inspect the extracted text of the first two lines.
lines[:2]

In [18]:
# ASS marks hard line breaks with the two characters "\N"; turn them into spaces.
lines = list(map(lambda text: text.replace('\\N', ' '), lines))

In [None]:
# Same preview after replacing the "\N" line-break escapes.
lines[:2]

In [None]:
" ".join(lines[:10])

In [None]:
# Episode number: the text after the last '-' and before the next '.', as an int.
int(files[0].split('-')[-1].split('.')[0].strip())

In [29]:
def load_subtitles_dataset(dataset_path, header_lines=27):
    """Load every ``.ass`` subtitle file in ``dataset_path`` into a DataFrame.

    For each file the first ``header_lines`` lines are skipped. The default of
    27 assumes a fixed-size header — TODO confirm for every file. From each
    remaining line, everything after the 9th comma is kept as dialogue text.
    That is the last, free-text field, and it may contain commas, so the tail
    is re-joined. The ``\\N`` line-break escape is replaced by a space, and
    empty fragments are dropped.

    The episode number is parsed from the file name: the text after the last
    ``-`` and before the next ``.`` (e.g. ``"Show - 12.ass"`` -> ``12``).

    Args:
        dataset_path: Directory containing the ``.ass`` files.
        header_lines: Number of leading lines to skip in each file.

    Returns:
        DataFrame with columns ``episode`` and ``script``, one row per file,
        sorted by episode number with a fresh 0..n-1 index. Empty if no
        ``.ass`` files are found.
    """
    records = []
    for path in glob(os.path.join(dataset_path, '*.ass')):
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # cannot decode many UTF-8 subtitle files.
        with open(path, 'r', encoding='utf-8') as file:
            body_lines = file.readlines()[header_lines:]

        fragments = []
        for line in body_lines:
            text = ",".join(line.split(',')[9:]).replace('\\N', ' ').strip()
            if text:
                fragments.append(text)

        # Parse from the file name only, so '-' or '.' in directory names
        # cannot break the episode number.
        file_name = os.path.basename(path)
        episode = int(file_name.split('-')[-1].split('.')[0].strip())

        records.append({"episode": episode, "script": " ".join(fragments)})

    df = pd.DataFrame(records, columns=["episode", "script"])
    # glob order is arbitrary; sort so row order is reproducible.
    return df.sort_values("episode").reset_index(drop=True)

In [31]:
# NOTE(review): this differs from the directory checked earlier
# ('data/original/subtitlist'); confirm which location is intended.
dataset_path = "../data/Subtitles"
df = load_subtitles_dataset(dataset_path)
# Fail fast with a clear message instead of an IndexError at df.iloc[0] below.
assert not df.empty, f"No .ass subtitle files found under {dataset_path!r}"

In [None]:
# Preview the loaded episodes.
df.head()

# Run Model

In [33]:
# Take the first row's script for a single-episode walkthrough.
script = df['script'].iloc[0]

In [None]:
# Display the raw script text (long output).
script

In [None]:
# Split the script into sentences (needs the NLTK tokenizer data downloaded earlier).
script_sentences = sent_tokenize(script)
script_sentences[:3]

In [39]:
# Batch sentences: join every `sentence_batch_size` consecutive sentences into
# one chunk so each classifier input carries more context.
sentence_batch_size = 20
script_batches = [
    " ".join(script_sentences[start:start + sentence_batch_size])
    for start in range(0, len(script_sentences), sentence_batch_size)
]

In [None]:
# Preview the first two sentence batches.
script_batches[:2]

In [None]:
# Score only the first two batches as a quick look at the output format.
theme_output = theme_classifier(
    script_batches[:2],
    theme_list,
    multi_label=True
)

In [None]:
# Raw pipeline output: one dict per batch, with parallel 'labels' and 'scores' lists.
theme_output

In [43]:
# Wrangle Output: gather every score per label across batches, e.g.
# battle: [0.51489498, 0.2156498]
themes = {}
for batch_result in theme_output:
    paired = zip(batch_result['labels'], batch_result['scores'])
    for theme_label, theme_score in paired:
        themes.setdefault(theme_label, []).append(theme_score)

In [49]:
# Collapse each label's per-batch scores into their mean.
themes = {theme_label: np.array(scores).mean() for theme_label, scores in themes.items()}

In [None]:
# Mean score per theme over the sampled batches.
themes

In [51]:
def get_themes_inference(script, batch_size=20, max_batches=None):
    """Score every label in ``theme_list`` for one episode script.

    The script is split into sentences with NLTK's ``sent_tokenize``. Each run
    of ``batch_size`` consecutive sentences is joined into one chunk. All
    chunks are scored by the global zero-shot ``theme_classifier`` with
    ``multi_label=True``, and each label's scores are averaged over the chunks.

    Args:
        script: Full episode text.
        batch_size: Number of sentences joined into one classifier input.
        max_batches: Optional cap on the number of chunks scored, for quick
            smoke tests (``2`` reproduces the earlier exploratory run).
            ``None`` scores the whole script.

    Returns:
        dict mapping each label to its mean score across chunks. Empty when
        the script yields no sentences.
    """
    sentences = sent_tokenize(script)

    # Join consecutive sentences so each classifier input carries more context.
    batches = [
        " ".join(sentences[start:start + batch_size])
        for start in range(0, len(sentences), batch_size)
    ]
    if max_batches is not None:
        batches = batches[:max_batches]
    if not batches:
        return {}

    outputs = theme_classifier(batches, theme_list, multi_label=True)
    # NOTE(review): some transformers versions return a bare dict (not a
    # one-element list) for a single input; normalize so the loop below
    # always iterates over per-batch dicts.
    if isinstance(outputs, dict):
        outputs = [outputs]

    scores_by_label = {}
    for output in outputs:
        for label, score in zip(output['labels'], output['scores']):
            scores_by_label.setdefault(label, []).append(score)

    return {label: np.mean(np.array(scores)) for label, scores in scores_by_label.items()}

In [52]:
# Limit inference to the first few episodes for a quick run; set to None to
# score every episode (slow: one classifier pass per episode).
N_EPISODES = 2
df = df.head(N_EPISODES) if N_EPISODES is not None else df

In [None]:
# Episodes that will be scored.
df

In [None]:
# Run theme inference per episode script (can be slow); each element is a {theme: mean score} dict.
output_themes = df['script'].apply(get_themes_inference)

In [None]:
# One {theme: mean score} dict per episode row.
output_themes

In [58]:
# One row per episode, one column per theme. Keep the episode rows' index so
# that assigning these columns into `df` aligns rows correctly even when `df`
# does not have a 0..n-1 index.
theme_df = pd.DataFrame(output_themes.tolist(), index=output_themes.index)

In [None]:
# Theme scores as a table: one row per episode, one column per theme.
theme_df

In [None]:
# Episodes before the theme columns are attached.
df

In [None]:
# Attach the theme scores as new columns; the assignment aligns on the row index.
df[theme_df.columns] = theme_df
df

# Visualize output

In [63]:
# Drop the 'dialogue' label column before plotting. errors='ignore' makes the
# cell safe to re-run; a second plain drop would raise KeyError.
df = df.drop(columns='dialogue', errors='ignore')

In [None]:
# Total each theme's score across episodes into a (theme, score) table.
theme_output = (
    df.drop(['episode', 'script'], axis=1)
    .sum()
    .rename_axis('theme')
    .reset_index(name='score')
)
theme_output

In [68]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Theme scores summed across the processed episodes. Use an explicit Axes so
# the figure carries its own title and labels.
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(data=theme_output, x="theme", y="score", ax=ax)
ax.set(title="Theme scores summed across episodes", xlabel="Theme", ylabel="Total score")
ax.tick_params(axis="x", labelrotation=45)
plt.show()