In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import tree
import graphviz
import pydot

In [2]:
def partition(x):
    partitions = {}
    i = 0
    for x_i in x:
        if isinstance(x_i, np.ndarray) or isinstance(x_i, list):
            x_i = tuple(x_i)
        if x_i in partitions:
            partitions[x_i].append(i)
        else:
            partitions[x_i] = [i]
        i += 1
    return partitions

In [3]:
def entropy(y):
    entropy=0
    
    y_part=partition(y)

    for key in y_part.keys():
        fraction = len(y_part[key])/len(y)
        entropy += -(fraction)*np.log2(fraction)

    return entropy

In [4]:
def mutual_information(x, y):
    h_y = entropy(y)
    p_x = partition(x)

    h_y_x = 0
    for v_i in p_x:
        # selecting new vector from y for each unique value of x
        vec_vxi = y[[i for i in p_x[v_i]]]
        h_y_x += len(p_x[v_i])/len(x)*entropy(vec_vxi)

    # Compute the mutual information
    I_xy = h_y - h_y_x

    return I_xy

In [5]:
def id3(x, y, attribute_value_pairs=None, depth=0, max_depth=5):
    
    dtree = {}

    if attribute_value_pairs is None:
        attribute_value_pairs = []
        for index in range (len(x[0])):
            for value in np.unique(x[index]):
            # for val in np.unique(np.array([item[idx] for item in x])):
                attribute_value_pairs.append((index, value))

    attribute_value_pairs = np.array(attribute_value_pairs)

    # check for pure splits
    unique_values_of_y, count_y = np.unique(y, return_counts=True)
    if len(unique_values_of_y) == 1:
        return unique_values_of_y[0]

    if len(attribute_value_pairs) == 0 or depth == max_depth:
        return unique_values_of_y[np.argmax(count_y)]

    info_gain = []

    for feat, val in attribute_value_pairs:
        info_gain.append(mutual_information(np.array((x[:, feat] == val).astype(int)), y))

    info_gain = np.array(info_gain)
    (feat, val) = attribute_value_pairs[np.argmax(info_gain)]

    partitions = partition(np.array((x[:, feat] == val).astype(int)))

    attribute_value_pairs = np.delete(attribute_value_pairs, np.argmax(info_gain), 0)
    for value, indices in partitions.items():
        x_new = x.take(np.array(indices), axis=0)
        y_new = y.take(np.array(indices), axis=0)
        output = bool(value)

        dtree[(feat, val, output)] = id3(x_new, y_new, attribute_value_pairs=attribute_value_pairs, depth=depth+1, max_depth=max_depth)

    return dtree


In [6]:
def predict_example(x, tree):
    for decision_node, child_tree in tree.items():
        index = decision_node[0]
        value = decision_node[1]
        decision = decision_node[2]

        if decision == (x[index] == value):
            if type(child_tree) is not dict:
                class_label = child_tree
            else:
                class_label = predict_example(x, child_tree)

            return class_label


def compute_error(y_true, y_pred):
    return (1/len(y_true)) * sum(y_true != y_pred)


In [7]:
def accuracy(y_true, y_pred):
    return np.mean(y_true == y_pred) 


In [8]:
def visualize(tree, depth=0):
    if depth == 0:
        print('TREE')

    for inx, split_criterion in enumerate(tree):
        sub_trees = tree[split_criterion]

        # Print the current node: split criterion
        print('|\t' * depth, end='')
        print('+-- [SPLIT: x{0} = {1}]'.format(split_criterion[0], split_criterion[1]))

        # Print the children
        if type(sub_trees) is dict:
            visualize(sub_trees, depth + 1)
        else:
            print('|\t' * (depth + 1), end='')
            print('+-- [LABEL = {0}]'.format(sub_trees))

In [9]:
def find_depth(tree, tree_depth=1):
    for key in tree:
        if isinstance(tree[key], dict):
            tree_depth = find_depth(tree[key], tree_depth + 1)
    return tree_depth

In [10]:
def confusion_matrix(y_pred, y_true):
    confusion_matrix = [[0, 0],[0, 0]]
    for label_indx in range(len(y_true)):
        if y_pred[label_indx]==y_true[label_indx]:
            if y_pred[label_indx]==1:
                confusion_matrix[0][0] += 1
            else:
                confusion_matrix[1][1] += 1
        else:
            if y_pred[label_indx]==1:
                confusion_matrix[1][0] += 1
            else:
                confusion_matrix[0][1] += 1

    for row in confusion_matrix:
        print(row)

In [11]:
df=pd.read_csv('movies_dataset_processed.csv')
df

Unnamed: 0.1,Unnamed: 0,IMDb-rating,appropriate_for,director,downloads,industry,language,posted_date,release_date,run_time,storyline,title,views,writer,days_to_post,bucket
0,0,4.8,R,John Swab,304,Holywood,English,2023-02-20,2023-01-28,105,Doc\r\n facilitates a fragile truce between th...,Little Dixie,2794,John Swab,23,6.0
1,1,6.4,TV-PG,Paul Ziller,73,Holywood,English,2023-02-20,2023-02-05,84,Caterer\r\n Goldy Berry reunites with detectiv...,Grilling Season: A Curious Caterer Mystery,1002,John Christian Plummer,15,6.0
2,2,5.2,R,Ben Wheatley,1427,Holywood,"English,Hindi",2021-04-20,2021-06-18,107,As the world searches for a cure to a disastro...,In the Earth,14419,Ben Wheatley,59,7.0
3,3,6.5,R,Benjamin Caron,1781,Holywood,English,2023-02-13,2023-02-17,116,"Motivations are suspect, and expectations are ...",Sharper,18225,"Brian Gatewood, Alessandro Tanaka",4,4.0
4,4,6.9,PG-13,Ravi Kapoor,458,Holywood,English,2023-02-18,2022-12-02,80,An\r\n unmotivated South Asian American rapper...,Four Samosas,6912,Ravi Kapoor,78,7.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9897,9897,7.1,Not Rated,Biren Nag,1932,Bolywood,Hindi,1970-01-01,1962-05-11,158,"After a lusty Thakur rapes a young girl, she k...",Bees Saal Baad,6076,"Dhruva Chatterjee, Dev Kishan",2792,9.0
9898,9898,7.0,G,Guy Hamilton,2544,Holywood,"English,German,Polish,French",1970-01-01,1969-09-17,132,Historical reenactment of the air war in the e...,Battle of Britain,9319,"James Kennaway, Wilfred Greatorex, Derek Dempster",106,8.0
9899,9899,5.6,R,Barbara Topsøe-Rothenborg,12284,Holywood,"Spanish,German,English",2016-05-26,1970-01-01,90,"LOVE AT FIRST HICCUP is a charming, innocent, ...",Love at First Hiccup,36022,"Barbara Topsøe-Rothenborg, Søren Frellesen, De...",16947,10.0
9900,9900,7.1,Not Rated,Biren Nag,1932,Bolywood,Hindi,1970-01-01,1962-05-11,158,"After a lusty Thakur rapes a young girl, she k...",Bees Saal Baad,6077,"Dhruva Chatterjee, Dev Kishan",2792,9.0


In [12]:
from sklearn import preprocessing 
label_encoder = preprocessing.LabelEncoder()
df['appropriate_for']= label_encoder.fit_transform(df['appropriate_for'])
df['IMDb-rating']= label_encoder.fit_transform(df['IMDb-rating'])
df['views'] = df['views'].str.replace(',', '').astype(int)
df['downloads'] = df['downloads'].str.replace(',', '').astype(int)
df['release_year'] = pd.to_datetime(df['release_date']).dt.year


In [13]:
test_size = 0.3
num_test_samples = int(test_size * df.shape[0])
num_train_samples = df.shape[0] - num_test_samples
df_train = df.sample(num_train_samples, random_state=42)
df_test = df.drop(df_train.index)

In [17]:
X_cols = ['IMDb-rating', 'appropriate_for', 'downloads','run_time', 'views', 'release_year']
Y_cols=['bucket']

X_train = df_train[X_cols].values
y_train = df_train[Y_cols].values.reshape(-1, 1)

X_test = df_test[X_cols].values
y_test = df_test[Y_cols].values.reshape(-1, 1)

In [19]:
depth = []
bestAcc = 0
bestDepth = 0

for d in range(1, 11):
    depth.append(d)
    decision_tree = id3(X_train, y_train,max_depth=d)
            
    y_pred = [predict_example(x, decision_tree) for x in X_train]
    trn_err = compute_error(y_train, y_pred)
    
    y_pred = [predict_example(x, decision_tree) for x in X_test]
    tst_err = compute_error(y_test, y_pred)
     
    
    print('Confusion matrix:')
    confusion_matrix(y_pred, y_test)

    acc = accuracy(y_test, y_pred)
    print("Accuracy:", acc)
    
    if(bestAcc < acc):
        bestAcc = acc
        bestDepth = d

Confusion matrix:
[0, 2258]
[0, 712]
Accuracy: 0.23973063973063974
Confusion matrix:
[24, 2234]
[0, 712]
Accuracy: 0.23843825459987075
Confusion matrix:
[24, 2249]
[0, 697]
Accuracy: 0.22688025031459375
Confusion matrix:
[24, 2246]
[0, 700]
Accuracy: 0.22656089514675373
Confusion matrix:
[24, 2246]
[0, 700]
Accuracy: 0.22605085648856693
Confusion matrix:
[24, 2246]
[0, 700]
Accuracy: 0.22605085648856693
Confusion matrix:
[24, 2246]
[0, 700]
Accuracy: 0.22605085648856693
Confusion matrix:
[24, 2246]
[0, 700]
Accuracy: 0.22605085648856693
Confusion matrix:
[24, 2246]
[0, 700]
Accuracy: 0.22605085648856693
Confusion matrix:
[24, 2246]
[0, 700]
Accuracy: 0.22605085648856693


In [20]:
visualize(decision_tree, depth=bestDepth)

|	+-- [SPLIT: x5 = 2022]
|	|	+-- [SPLIT: x2 = 283]
|	|	|	+-- [SPLIT: x1 = 6]
|	|	|	|	+-- [SPLIT: x3 = 100]
|	|	|	|	|	+-- [SPLIT: x0 = 53]
|	|	|	|	|	|	+-- [SPLIT: x0 = 7]
|	|	|	|	|	|	|	+-- [SPLIT: x0 = 130]
|	|	|	|	|	|	|	|	+-- [SPLIT: x0 = 2018]
|	|	|	|	|	|	|	|	|	+-- [SPLIT: x0 = 4330]
|	|	|	|	|	|	|	|	|	|	+-- [SPLIT: x0 = 29842]
|	|	|	|	|	|	|	|	|	|	|	+-- [LABEL = 3.0]
|	|	|	|	|	+-- [SPLIT: x0 = 53]
|	|	|	|	|	|	+-- [SPLIT: x4 = 29601]
|	|	|	|	|	|	|	+-- [LABEL = 1.0]
|	|	|	|	|	|	+-- [SPLIT: x4 = 29601]
|	|	|	|	|	|	|	+-- [SPLIT: x0 = 7]
|	|	|	|	|	|	|	|	+-- [SPLIT: x0 = 130]
|	|	|	|	|	|	|	|	|	+-- [SPLIT: x0 = 2018]
|	|	|	|	|	|	|	|	|	|	+-- [SPLIT: x0 = 4330]
|	|	|	|	|	|	|	|	|	|	|	+-- [LABEL = 3.0]
|	|	|	|	+-- [SPLIT: x3 = 100]
|	|	|	|	|	+-- [SPLIT: x0 = 53]
|	|	|	|	|	|	+-- [SPLIT: x0 = 7]
|	|	|	|	|	|	|	+-- [SPLIT: x0 = 130]
|	|	|	|	|	|	|	|	+-- [SPLIT: x0 = 2018]
|	|	|	|	|	|	|	|	|	+-- [SPLIT: x0 = 4330]
|	|	|	|	|	|	|	|	|	|	+-- [SPLIT: x0 = 29842]
|	|	|	|	|	|	|	|	|	|	|	+-- [LABEL = 3.0]
|	|	|	

In [21]:
print("Best Accuracy is ", bestAcc)
print("Best Depth is ", bestDepth)

Best Accuracy is  0.23973063973063974
Best Depth is  1


In [24]:
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error

decision_tree = id3(X_train, y_train,max_depth=bestDepth)
y_pred = [predict_example(x, decision_tree) for x in X_train]
mse = mean_squared_error(y_train,y_pred)
r2 = r2_score(y_train,y_pred)

y_pred = [predict_example(x, decision_tree) for x in X_test]
mse = mean_squared_error(y_test,y_pred)



print(f'R^2 score: {r2:.2f}')
print(f'Mean squared error: {mse}')

R^2 score: -0.81
Mean squared error: 11.985858585858585
