In [38]:
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from scipy.sparse import hstack

In [39]:
df = pd.read_csv('tmdb_5000_movies.csv')

In [40]:
# Check for null values
df.isnull().sum()

budget                     0
genres                     0
homepage                3091
id                         0
keywords                   0
original_language          0
original_title             0
overview                   3
popularity                 0
production_companies       0
production_countries       0
release_date               1
revenue                    0
runtime                    2
spoken_languages           0
status                     0
tagline                  844
title                      0
vote_average               0
vote_count                 0
dtype: int64

In [41]:
# Drop rows with missing values in 'genres' column
df = df.dropna(subset=['genres'])

In [42]:
# Fill missing values in 'overview' column with an empty string
df['overview'] = df['overview'].fillna('')

In [43]:
# Convert genres from JSON string to list of genre names
import ast

def parse_genres(genre_str):
    try:
        genres_list = ast.literal_eval(genre_str)
        return [genre['name'] for genre in genres_list]
    except:
        return []

df['genres'] = df['genres'].apply(parse_genres)


In [44]:
# Create a list of unique genres
all_genres = set()
for genres in df['genres']:
    all_genres.update(genres)
all_genres = list(all_genres)

In [45]:
# Create binary columns for each genre
for genre in all_genres:
    df[genre] = df['genres'].apply(lambda x: 1 if genre in x else 0)

In [46]:
# Feature and target variables
X = df[['original_title', 'overview']]
y = df[all_genres]

In [47]:
# Text vectorization
vectorizer_title = TfidfVectorizer(stop_words='english')
vectorizer_overview = TfidfVectorizer(stop_words='english')

In [48]:
# Fit and transform text data
X_title = vectorizer_title.fit_transform(X['original_title'])
X_overview = vectorizer_overview.fit_transform(X['overview'])

In [49]:

# Combine title and overview features
X_combined = hstack([X_title, X_overview])

# Split data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X_combined, y, test_size=0.3, random_state=42)

In [50]:
# Train a separate classifier for each genre
classifiers = {}
for genre in all_genres:
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train[genre])
    classifiers[genre] = clf

In [51]:

# Make predictions
y_pred = pd.DataFrame()
for genre, clf in classifiers.items():
    y_pred[genre] = clf.predict(X_test)

In [52]:
# Evaluate the model
print("Classification Report:")
for genre in all_genres:
    print(f"\nGenre: {genre}")
    print(classification_report(y_test[genre], y_pred[genre]))

Classification Report:

Genre: Animation
              precision    recall  f1-score   support

           0       0.96      1.00      0.98      1384
           1       0.00      0.00      0.00        57

    accuracy                           0.96      1441
   macro avg       0.48      0.50      0.49      1441
weighted avg       0.92      0.96      0.94      1441


Genre: Documentary
              precision    recall  f1-score   support

           0       0.98      1.00      0.99      1407
           1       0.00      0.00      0.00        34

    accuracy                           0.98      1441
   macro avg       0.49      0.50      0.49      1441
weighted avg       0.95      0.98      0.96      1441


Genre: Foreign
              precision    recall  f1-score   support

           0       0.99      1.00      1.00      1430
           1       0.00      0.00      0.00        11

    accuracy                           0.99      1441
   macro avg       0.50      0.50      0.50      14

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

              precision    recall  f1-score   support

           0       0.90      1.00      0.95      1300
           1       0.00      0.00      0.00       141

    accuracy                           0.90      1441
   macro avg       0.45      0.50      0.47      1441
weighted avg       0.81      0.90      0.86      1441


Genre: Drama
              precision    recall  f1-score   support

           0       0.70      0.73      0.71       762
           1       0.68      0.65      0.66       679

    accuracy                           0.69      1441
   macro avg       0.69      0.69      0.69      1441
weighted avg       0.69      0.69      0.69      1441


Genre: Horror
              precision    recall  f1-score   support

           0       0.90      1.00      0.94      1287
           1       1.00      0.03      0.05       154

    accuracy                           0.90      1441
   macro avg       0.95      0.51      0.50      1441
weighted avg       0.91      0.90      0.85  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
