In [None]:
from transformers import pipeline
from nltk import sent_tokenize
import nltk
import torch
from glob import glob
import pandas as pd
import numpy as np

In [None]:
nltk.download('punkt')

In [None]:
model_name = "facebook/bart-large-mnli"
device = 0 if torch.cuda.is_available() else 'cpu'

In [None]:
def load_model(device):
    """Build the zero-shot classification pipeline used to score themes.

    Args:
        device: Forwarded unchanged to ``transformers.pipeline`` (the notebook
            passes either a GPU index or ``'cpu'``).

    Returns:
        A ``"zero-shot-classification"`` pipeline for the global ``model_name``.
    """
    return pipeline(
        "zero-shot-classification",
        model=model_name,
        device=device,
    )

In [None]:
theme_classifier = load_model(device)

In [None]:
theme_list = ["friendship","hope","sacrifice","battle","self development","betrayal","love","dialogue"]

In [None]:
theme_classifier(
    "I gave him a right hook then a left jab",
    theme_list,
    multi_label=True
)

In [None]:
files = glob('../data/Subtitles/*.ass')

In [None]:
files[:5]

In [None]:
with open(files[0],'r') as file:
    lines = file.readlines()
    lines = lines[27:]
    lines =  [ ",".join(line.split(',')[9:])  for line in lines ]

In [None]:
lines[:2]

In [None]:
lines = [ line.replace('\\N',' ') for line in lines]

In [None]:
lines[:2]

In [None]:
" ".join(lines[:10])

In [None]:
int(files[0].split('-')[-1].split('.')[0].strip())

In [None]:
import os
import pandas as pd
from glob import glob

def _parse_ass_dialogue(lines, skip_lines=27):
    """Join the dialogue text of an .ass file's lines into one script string.

    The first ``skip_lines`` lines are dropped (assumed header/metadata). From each
    remaining line, everything after the 9th comma is kept as the dialogue text
    (text may contain commas). '\\N' hard breaks become spaces, surrounding
    whitespace/newlines are stripped, and lines with no text are skipped.
    """
    texts = []
    for raw_line in lines[skip_lines:]:
        text = ",".join(raw_line.split(',')[9:])
        text = text.replace('\\N', ' ').strip()
        if text:
            texts.append(text)
    return " ".join(texts)


def _episode_number(path):
    """Extract the episode number: text after the last '-' and before the next '.' of the filename."""
    filename = os.path.basename(path)
    return int(filename.split('-')[-1].split('.')[0].strip())


def load_subtitles_dataset(dataset_path, skip_lines=27):
    """Load every .ass subtitle file in ``dataset_path`` into a DataFrame.

    Args:
        dataset_path: Directory containing the ``*.ass`` files.
        skip_lines: Number of leading lines treated as header and skipped
            (default 27, the value the files were inspected with).

    Returns:
        DataFrame with columns ``episode`` (int) and ``script`` (str), one row per
        successfully parsed file, sorted by episode number with a fresh 0..n-1
        index so row order is deterministic. Files that fail to read or whose
        name has no parseable episode number are reported and skipped.
    """
    # Sorted for a deterministic processing order (glob order is filesystem-dependent).
    subtitles_paths = sorted(glob(os.path.join(dataset_path, '*.ass')))

    scripts = []
    episode_num = []

    for path in subtitles_paths:
        try:
            # errors='ignore' drops undecodable bytes rather than aborting the file.
            with open(path, 'r', encoding='utf-8', errors='ignore') as file:
                lines = file.readlines()
            script = _parse_ass_dialogue(lines, skip_lines)
            episode = _episode_number(path)
        except Exception as e:
            # Best-effort loading: report the bad file and keep going.
            print(f"Error processing file {path}: {e}")
            continue

        scripts.append(script)
        episode_num.append(episode)

    df = pd.DataFrame({"episode": episode_num, "script": scripts})
    return df.sort_values("episode", kind="stable").reset_index(drop=True)


In [None]:
dataset_path = "../data/Subtitles"
df = load_subtitles_dataset(dataset_path)

In [None]:
df.head()

In [None]:
script = df.iloc[0]['script']

In [None]:
script

In [None]:
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')

In [None]:
from nltk.tokenize import sent_tokenize
script_sentences = sent_tokenize(script)
print(script_sentences[:3])


In [None]:
# Batch Sentence
sentence_batch_size=20
script_batches = []
for index in range(0,len(script_sentences),sentence_batch_size):
    sent = " ".join(script_sentences[index:index+sentence_batch_size])
    script_batches.append(sent)

In [None]:
script_batches[:2]

In [None]:
theme_output = theme_classifier(
    script_batches[:2],
    theme_list,
    multi_label=True
)

In [None]:
theme_output

In [None]:
# Wrangle Ouput
# battle: [0.51489498, 0.2156498]
themes = {}
for output in theme_output:
    for label,score in zip(output['labels'],output['scores']):
        if label not in themes:
            themes[label] = []
        themes[label].append(score)

In [None]:
themes = {key: np.mean(np.array(value)) for key,value in themes.items()}

In [None]:
themes

In [None]:
def get_themes_inference(script, sentence_batch_size=20, max_batches=None):
    """Score one episode script against every label in the global ``theme_list``.

    The script is split into sentences, consecutive sentences are joined into
    batches of ``sentence_batch_size``, each batch is classified by the global
    ``theme_classifier`` (multi-label), and each theme's scores are averaged
    over the classified batches.

    Args:
        script: Full episode text.
        sentence_batch_size: Sentences per classifier input (default 20).
        max_batches: Optional cap on how many batches are classified. None (the
            default) classifies the whole script; pass e.g. 2 for a quick, cheap
            run (classifying every batch is proportionally slower).

    Returns:
        dict mapping theme label -> mean score; ``{}`` if the script has no sentences.

    Raises:
        ValueError: if ``sentence_batch_size`` is not positive.
    """
    if sentence_batch_size <= 0:
        raise ValueError("sentence_batch_size must be a positive integer")

    script_sentences = sent_tokenize(script)

    # Batch sentences so each classifier input carries enough context.
    script_batches = [
        " ".join(script_sentences[start:start + sentence_batch_size])
        for start in range(0, len(script_sentences), sentence_batch_size)
    ]
    if max_batches is not None:
        script_batches = script_batches[:max_batches]
    if not script_batches:
        # Nothing to classify; avoid calling the model with an empty list.
        return {}

    # Run Model
    theme_output = theme_classifier(
        script_batches,
        theme_list,
        multi_label=True
    )
    # Defensive (NOTE(review): depends on transformers version): a single input
    # may come back as a bare dict rather than a one-element list.
    if isinstance(theme_output, dict):
        theme_output = [theme_output]

    # Wrangle Output: collect every score per label, then average.
    themes = {}
    for output in theme_output:
        for label, score in zip(output['labels'], output['scores']):
            themes.setdefault(label, []).append(score)

    return {key: np.mean(value) for key, value in themes.items()}

In [None]:
df = df.head(2)

In [None]:
df

In [None]:
output_themes = df['script'].apply(get_themes_inference)

In [None]:
output_themes

In [None]:
theme_df = pd.DataFrame(output_themes.tolist())

In [None]:
theme_df

In [None]:
df

In [None]:
df[theme_df.columns] = theme_df
df

In [None]:
df = df.drop('dialogue',axis=1)

In [None]:
theme_output = df.drop(['episode','script'],axis=1).sum().reset_index()
theme_output.columns = ['theme','score']
theme_output

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Bar chart of the per-theme totals; explicit Axes so title/labels are attached to this figure.
fig, ax = plt.subplots()
sns.barplot(data=theme_output, x="theme", y="score", ax=ax)
ax.set(title="Theme scores summed over analysed episodes", xlabel="Theme", ylabel="Total score")
# Right-align rotated tick labels so each one sits under its bar.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
plt.show()