In [5]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from sklearn.preprocessing import LabelEncoder
import os

# Load the data from the provided CSV file
file_path = '/Users/aditya/Downloads/wiki_movie_plots_processed.csv'

# Check if the file exists
if not os.path.exists(file_path):
    raise FileNotFoundError(f"The file {file_path} does not exist")

movies = pd.read_csv(file_path)

# Display the first few rows to understand the structure
print(movies.head())

# Filter data to include only American movies
american_movies = movies[movies['Origin/Ethnicity'] == 'American']

# Display the first few rows to understand the filtered data
print(american_movies.head())

# Replace various separators with a single separator
american_movies['GenreCorrected'] = american_movies['Genre'].str.replace(r'[\s/-]+','|', regex=True)

# Take the first genre as the primary genre for simplicity
american_movies['PrimaryGenre'] = american_movies['GenreCorrected'].str.split('|').str[0].str.strip().str.lower()

# Count the occurrences of each genre
genre_counts = american_movies['PrimaryGenre'].value_counts()

# Identify the top 10 genres
top_10_genres = genre_counts.head(10).index
print("Top 10 Genres:\n", top_10_genres)

# Filter the data to include only the top 10 genres
top_10_american_movies = american_movies[american_movies['PrimaryGenre'].isin(top_10_genres)]

# Encode the primary genres as labels
le = LabelEncoder()
top_10_american_movies['PrimaryGenreEncoded'] = le.fit_transform(top_10_american_movies['PrimaryGenre'])

# Split the data into features (plot summaries) and target (primary genre)
X = top_10_american_movies['Plot']
y = top_10_american_movies['PrimaryGenreEncoded']

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert plot summaries to numerical features using TF-IDF
vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2), stop_words='english')
X_train_tfidf = vectorizer.fit_transform(X_train)
X_test_tfidf = vectorizer.transform(X_test)

# Train a Multinomial Naive Bayes classifier
nb_model = MultinomialNB()
nb_model.fit(X_train_tfidf, y_train)

# Make predictions on the test set
y_pred = nb_model.predict(X_test_tfidf)

# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, average='weighted')
recall = recall_score(y_test, y_pred, average='weighted')
f1 = f1_score(y_test, y_pred, average='weighted')

# Print the evaluation metrics
print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

# Generate the classification report using the unique labels in y_test
unique_labels = sorted(set(y_test))
target_names = le.inverse_transform(unique_labels)

print("Classification Report:\n", classification_report(y_test, y_pred, labels=unique_labels, target_names=target_names))


   Release Year                             Title Origin/Ethnicity  \
0          1901            Kansas Saloon Smashers         American   
1          1901     Love by the Light of the Moon         American   
2          1901           The Martyred Presidents         American   
3          1901  Terrible Teddy, the Grizzly King         American   
4          1902            Jack and the Beanstalk         American   

                             Director Cast    Genre  \
0                             Unknown  NaN  unknown   
1                             Unknown  NaN  unknown   
2                             Unknown  NaN  unknown   
3                             Unknown  NaN  unknown   
4  George S. Fleming, Edwin S. Porter  NaN  unknown   

                                           Wiki Page  \
0  https://en.wikipedia.org/wiki/Kansas_Saloon_Sm...   
1  https://en.wikipedia.org/wiki/Love_by_the_Ligh...   
2  https://en.wikipedia.org/wiki/The_Martyred_Pre...   
3  https://en.wikipedia.

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  american_movies['GenreCorrected'] = american_movies['Genre'].str.replace(r'[\s/-]+','|', regex=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  american_movies['PrimaryGenre'] = american_movies['GenreCorrected'].str.split('|').str[0].str.strip().str.lower()
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus

Accuracy: 0.5058873002523129
Precision: 0.5117393494043431
Recall: 0.5058873002523129
F1 Score: 0.44682811122015753
Classification Report:
               precision    recall  f1-score   support

      action       0.76      0.13      0.22       124
   adventure       0.67      0.06      0.11       104
      comedy       0.49      0.68      0.57       647
     comedy,       0.00      0.00      0.00        82
       crime       0.65      0.09      0.16       139
       drama       0.45      0.77      0.56       718
      horror       0.88      0.39      0.54       174
     musical       0.00      0.00      0.00       110
    thriller       0.00      0.00      0.00        89
     western       0.89      0.57      0.70       191

    accuracy                           0.51      2378
   macro avg       0.48      0.27      0.29      2378
weighted avg       0.51      0.51      0.45      2378



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
