In [5]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import stats

# Ensure plots are displayed inline in Colab
%matplotlib inline

# Set up the style for better-looking plots
plt.style.use('seaborn')
sns.set_palette("deep")
sns.set_context("notebook", font_scale=1.2)

class DataUnderstanding:
    def __init__(self, df):
        self.df = df

    def get_summary_statistics(self):
        return self.df.describe()

    def get_missing_values(self):
        return self.df.isnull().sum()

    def get_info(self):
        return self.df.info()

    def get_dtypes(self):
        return self.df.dtypes

    def get_value_counts(self):
        return {column: self.df[column].value_counts() for column in self.df.columns}

# Load the dataset
df = pd.read_csv('train.csv')
du = DataUnderstanding(df)

# Data Preprocessing
df = df.drop('Cabin', axis=1)
df['Embarked'].fillna(df['Embarked'].mode()[0], inplace=True)
df.dropna(subset=['Age'], inplace=True)

# Function to remove outliers
def remove_outliers(data, cols, threshold=3):
    for col in cols:
        z_scores = np.abs(stats.zscore(data[col]))
        data = data[(z_scores < threshold)]
    return data

numerical_columns = ['PassengerId', 'Survived', 'Pclass', 'Age', 'SibSp', 'Parch', 'Fare']

# 1. Survival Rate Pie Chart
def plot_survival_rate(df):
    survival_counts = df['Survived'].value_counts()
    labels = ['Did not survive', 'Survived']
    colors = ['#ff9999', '#66b3ff']

    fig = go.Figure(data=[go.Pie(labels=labels,
                                 values=survival_counts,
                                 hole=.3,
                                 marker_colors=colors)])

    fig.update_layout(title_text='Survival Rate', title_x=0.5)
    fig.show()

plot_survival_rate(df)

# 2. Distribution of Age
fig = px.histogram(df, x='Age', nbins=30, marginal='box',
                   title='Distribution of Passenger Ages',
                   labels={'Age': 'Age (years)', 'count': 'Number of Passengers'},
                   color_discrete_sequence=['#66b3ff'])
fig.update_layout(bargap=0.1)
fig.show()

# 3. Passenger Class Distribution
pclass_counts = df['Pclass'].value_counts().sort_index().reset_index()
pclass_counts.columns = ['Pclass', 'Count']
fig = px.bar(pclass_counts, x='Pclass', y='Count',
             title='Passenger Class Distribution',
             labels={'Pclass': 'Passenger Class', 'Count': 'Number of Passengers'},
             color='Pclass', color_discrete_sequence=px.colors.qualitative.Set2)
fig.update_layout(showlegend=False)
fig.show()

# 4. Scatter Plot: Age vs. Fare by Survival
fig = px.scatter(df, x='Age', y='Fare', color='Survived', size='Fare',
                 title='Age vs. Fare by Survival',
                 labels={'Age': 'Age (years)', 'Fare': 'Fare ($)', 'Survived': 'Survived'},
                 color_discrete_sequence=['#ff9999', '#66b3ff'])
fig.show()

# 5. Survivors by Passenger Class
fig = px.histogram(df, x='Pclass', color='Survived', barmode='group',
                   title='Survivors by Passenger Class',
                   labels={'Pclass': 'Passenger Class', 'count': 'Number of Passengers', 'Survived': 'Survived'},
                   color_discrete_sequence=['#ff9999', '#66b3ff'])
fig.show()

# 6. Fare Distribution by Passenger Class
fig = px.box(df, x='Pclass', y='Fare', color='Pclass',
             title='Fare Distribution by Passenger Class',
             labels={'Pclass': 'Passenger Class', 'Fare': 'Fare ($)'},
             color_discrete_sequence=px.colors.qualitative.Set2)
fig.show()

# 7. Correlation Heatmap
corr_matrix = df[numerical_columns].corr()
fig = px.imshow(corr_matrix, text_auto=True, aspect="auto",
                title='Correlation Heatmap of Numerical Variables',
                color_continuous_scale='RdBu_r')
fig.show()

# 8. Pair Plot
fig = px.scatter_matrix(df[numerical_columns],
                        dimensions=numerical_columns,
                        color='Survived',
                        title='Pair Plot of Numerical Variables',
                        color_discrete_sequence=['#ff9999', '#66b3ff'])
fig.update_traces(diagonal_visible=False)
fig.show()

# 9. Age Distribution by Passenger Class and Survival
fig = px.violin(df, x='Pclass', y='Age', color='Survived', box=True, points="all",
                title='Age Distribution by Passenger Class and Survival',
                labels={'Pclass': 'Passenger Class', 'Age': 'Age (years)', 'Survived': 'Survived'},
                color_discrete_sequence=['#ff9999', '#66b3ff'])
fig.show()

# 10. Multivariate Parallel Coordinates Plot
fig = px.parallel_coordinates(df, dimensions=['Age', 'Fare', 'Pclass', 'Survived'],
                              color='Survived', color_continuous_scale=px.colors.diverging.RdYlBu,
                              title='Multivariate Parallel Coordinates Plot')
fig.show()


The seaborn styles shipped by Matplotlib are deprecated since 3.6, as they no longer correspond to the styles shipped by seaborn. However, they will remain available as 'seaborn-v0_8-<style>'. Alternatively, directly use the seaborn API instead.

