In [7]:
import os
import json
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pymongo
from tqdm import tqdm
from scipy import stats
import math
import numpy as np

In [8]:
# Connect to the cleaned IMDb collection. The server URI can be overridden via
# the MONGODB_URI environment variable (e.g. a remote or authenticated server),
# so credentials never need to be written into the notebook.
MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/")
myclient = pymongo.MongoClient(MONGODB_URI)
db = myclient["imdb"]
imdb = db['imdb_cleaned']

In [9]:
# Number of movie documents in the cleaned collection (empty filter = all).
num_movies = imdb.count_documents(filter={})
num_movies

8181

In [13]:
# Collect the cleaned budget and revenue of every movie.
# Only the two needed fields are requested (projection), instead of
# transferring every full movie document from MongoDB.
budget_arr = []
revenue_arr = []

regression_fields = {"cleaned_Budget": 1, "cleaned_Revenue": 1, "_id": 0}
for movie in tqdm(imdb.find({}, regression_fields), desc="loading budget/revenue"):
    # Both values are appended in the same iteration so the two lists stay
    # index-aligned: entry i of each list belongs to the same movie.
    budget_arr.append(movie['cleaned_Budget'])
    revenue_arr.append(movie['cleaned_Revenue'])

In [14]:
def _scale_to_max(values):
    """Return ``values`` as a NumPy array divided by its largest element."""
    as_array = np.array(values)
    return as_array / as_array.max()


# Express budget and revenue as fractions of their respective maxima.
# With a positive maximum this only changes the units of the regression
# coefficients; ranks (Spearman) and R-squared are unaffected.
budget_arr = _scale_to_max(budget_arr)
revenue_arr = _scale_to_max(revenue_arr)

# Linear Regression

In [15]:
from scipy.stats import spearmanr

In [16]:
# Rank correlation of budget vs. revenue. A p-value displayed as 0.0 means it is
# smaller than floating point can represent, not that it is exactly zero.
spearmanr(a=budget_arr, b=revenue_arr)

SpearmanrResult(correlation=0.7434469021284774, pvalue=0.0)

In [17]:
import statsmodels.api as sm
from statsmodels.sandbox.regression.predstd import wls_prediction_std

In [18]:
# Design matrix: a leading intercept column followed by the normalised budget.
budget_x = sm.add_constant(budget_arr, prepend=True)

# Ordinary least squares of revenue on budget.
# NOTE(review): the summary below reports strongly skewed, heavy-tailed
# residuals (skew ~5.6, kurtosis ~71.6); consider fit(cov_type="HC3") for
# heteroskedasticity-robust standard errors.
model = sm.OLS(endog=revenue_arr, exog=budget_x)
results = model.fit()

In [19]:
# Regression table with named variables instead of the default "y" / "x1".
results.summary(yname="revenue (normalised)", xname=["const", "budget (normalised)"])

0,1,2,3
Dep. Variable:,y,R-squared:,0.44
Model:,OLS,Adj. R-squared:,0.44
Method:,Least Squares,F-statistic:,6435.0
Date:,"Fri, 21 May 2021",Prob (F-statistic):,0.0
Time:,21:32:15,Log-Likelihood:,13887.0
No. Observations:,8181,AIC:,-27770.0
Df Residuals:,8179,BIC:,-27760.0
Df Model:,1,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
const,-0.0036,0.001,-5.666,0.000,-0.005,-0.002
x1,0.3235,0.004,80.221,0.000,0.316,0.331

0,1,2,3
Omnibus:,9241.878,Durbin-Watson:,1.726
Prob(Omnibus):,0.0,Jarque-Bera (JB):,1647354.689
Skew:,5.605,Prob(JB):,0.0
Kurtosis:,71.608,Cond. No.,8.31


In [20]:
from statsmodels.stats.outliers_influence import summary_table

In [21]:
def plot_regression_line(results, xrr, yrr, alpha=0.05, title="OLS regression fit",
                         xaxis_title="x", yaxis_title="y", legend_title=None):
    """Plot the data, the fitted OLS line, and its confidence/prediction bands.

    Parameters
    ----------
    results : statsmodels regression results
        Fitted OLS results (one regressor plus a constant); assumed to have been
        fitted on observations in the same order as ``xrr`` / ``yrr``.
    xrr : array-like
        Regressor values (x axis), one per observation.
    yrr : array-like
        Observed response values (y axis), one per observation.
    alpha : float, optional
        Significance level of both bands; 0.05 draws 95% bands.
    title, xaxis_title, yaxis_title : str, optional
        Figure title and axis labels.
    legend_title : str or None, optional
        Legend title; ``None`` leaves the legend untitled.

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure (call ``.show()`` to render it).
    """
    # summary_table returns (table, data, column_names). Columns of ``data``
    # used here: 2 = fitted value, 4:6 = confidence interval of the mean
    # prediction, 6:8 = prediction interval for a new observation.
    _, data, _ = summary_table(results, alpha=alpha)

    # Line traces connect points in array order, so sort by x; otherwise the
    # curved bands are joined in raw data order and zig-zag across the plot.
    x = np.asarray(xrr)
    order = np.argsort(x, kind="stable")
    x_sorted = x[order]
    fitted = data[order, 2]
    mean_ci_low, mean_ci_upp = data[order, 4:6].T
    pred_ci_low, pred_ci_upp = data[order, 6:8].T

    level = f"{100 * (1 - alpha):g}%"
    pred_band_line = dict(color='rgba(153, 0, 51, .5)', width=1, dash='dot')
    mean_band_line = dict(color='rgba(0, 153, 51, .5)', width=1, dash='dot')

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=x, y=yrr, name="Data", mode='markers',
        marker=dict(color='rgba(153, 153, 255, .6)'),
    ))
    fig.add_trace(go.Scatter(
        x=x_sorted, y=fitted, name="regression line", mode='lines',
        line=dict(color='green', width=3),
    ))

    # Prediction interval: where a single new observation is expected to fall.
    # Both bounds share a legend group so one legend click toggles the pair.
    fig.add_trace(go.Scatter(
        x=x_sorted, y=pred_ci_low, name=f"{level} prediction interval",
        mode='lines', line=pred_band_line, legendgroup="prediction",
    ))
    fig.add_trace(go.Scatter(
        x=x_sorted, y=pred_ci_upp, mode='lines', line=pred_band_line,
        legendgroup="prediction", showlegend=False,
    ))

    # Confidence interval of the mean prediction (uncertainty of the line itself).
    fig.add_trace(go.Scatter(
        x=x_sorted, y=mean_ci_low, name=f"{level} confidence interval (mean)",
        mode='lines', line=mean_band_line, legendgroup="mean",
    ))
    fig.add_trace(go.Scatter(
        x=x_sorted, y=mean_ci_upp, mode='lines', line=mean_band_line,
        legendgroup="mean", showlegend=False,
    ))

    fig.update_layout(
        height=600,
        width=800,
        title=title,
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    if legend_title is not None:
        fig.update_layout(legend_title_text=legend_title)

    return fig

In [22]:
# Normalised budget vs. revenue with the fitted OLS line and its uncertainty bands.
fig = plot_regression_line(results, xrr=budget_arr, yrr=revenue_arr)
fig.show()