In [None]:
!pip install seaborn

In [2]:
from google.colab import drive
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
#create a class to read a csv, excel or json file, then clean and visualise the titanic dataset
class data_science():
  """Load a tabular file into a pandas DataFrame, then clean and plot it.

  Only ``read_file`` is format-generic. ``clean_file`` and the plotting
  methods work on the Titanic passenger columns they reference by name
  (``age``, ``sex``, ``survived``, ``fare``, ``embarked``, ``home.dest``, ...).

  Intended order of use: ``read_file()`` -> ``clean_file()`` -> plotting methods.

  Attributes:
    file_path: location handed to the pandas reader.
    file_type: one of 'csv', 'excel' or 'json'.
    df: the current DataFrame; None until ``read_file`` has loaded one.
  """

  # Edges of the age buckets shared by every age-group plot. pd.cut builds
  # right-closed intervals (0, 9], (9, 19], ...; include_lowest=True also lets
  # age 0 into the first bucket. Ages outside the edges get no bucket.
  _AGE_BINS = [0, 9, 19, 29, 39, 49, 59, 69, 79, 89]
  _AGE_LABELS = ['0–9', '10–19', '20–29', '30–39', '40–49',
                 '50–59', '60–69', '70–79', '80–89']

  # Bar colours for the two survival labels produced by clean_file.
  _SURVIVED_PALETTE = {'Yes': 'blue', 'No': 'yellow'}

  # Columns removed by clean_file, and columns whose gaps it fills with the mode.
  _DROPPED_COLUMNS = ['cabin', 'boat', 'body']
  _MODE_FILLED_COLUMNS = ['fare', 'embarked', 'home.dest']

  def __init__(self, file_path, file_type):
    """Remember where the data lives; nothing is read until ``read_file``.

    Args:
      file_path: path (or other source accepted by the pandas reader).
      file_type: 'csv', 'excel' or 'json'.
    """
    self.file_path = file_path
    self.file_type = file_type
    self.df = None

  def read_file(self):
    """Read ``file_path`` with the pandas reader matching ``file_type``.

    For an unsupported ``file_type`` a message is printed and ``self.df`` is
    left as it was.

    Returns:
      ``self.df`` (None if nothing has been loaded).
    """
    readers = {
        'csv': pd.read_csv,
        'excel': pd.read_excel,
        'json': pd.read_json,
    }
    reader = readers.get(self.file_type)
    if reader is None:
      print("Unsupported file type")
    else:
      self.df = reader(self.file_path)
    return self.df

  @staticmethod
  def _fill_with_mode(values, fallback=None):
    """Fill missing entries of ``values`` with its most frequent value.

    When ``values`` has no non-null entry (so no mode exists) ``fallback`` is
    used instead; if that is None too, ``values`` is returned unchanged.
    """
    modes = values.mode()
    fill_value = fallback if modes.empty else modes.iloc[0]
    if fill_value is None:
      return values
    return values.fillna(fill_value)

  def clean_file(self):
    """Clean the Titanic data held in ``self.df`` and return the result.

    Steps: drop the cabin/boat/body columns, drop duplicate rows, round age
    to a nullable integer, recode sex to 'M'/'F' and survived to 'Yes'/'No',
    fill missing ages with the most frequent age of the same sex, and fill
    missing fare/embarked/home.dest with each column's most frequent value.

    The work is done on a local frame and assigned to ``self.df`` only at the
    end. Running the method again on already-cleaned data does not raise:
    columns that were already dropped are ignored.

    Returns:
      The cleaned DataFrame (also stored in ``self.df``).
    """
    frame = (
        self.df
        # errors='ignore' keeps a re-run from failing on already-dropped columns.
        .drop(columns=self._DROPPED_COLUMNS, errors='ignore')
        .drop_duplicates()
        # Explicit copy so the column assignments below act on an independent frame.
        .copy()
    )

    frame['age'] = frame['age'].round().astype('Int64')
    frame['sex'] = frame['sex'].replace({'male': 'M', 'female': 'F'})
    frame['survived'] = frame['survived'].replace({1: 'Yes', 0: 'No'})

    # A sex group whose ages are all missing has no mode of its own; fall back
    # to the most frequent age overall instead of failing on an empty mode().
    age_modes = frame['age'].mode()
    overall_age = None if age_modes.empty else age_modes.iloc[0]
    frame['age'] = frame.groupby('sex')['age'].transform(
        lambda ages: self._fill_with_mode(ages, overall_age))

    for column in self._MODE_FILLED_COLUMNS:
      frame[column] = self._fill_with_mode(frame[column])

    self.df = frame
    return self.df

  def _add_age_group(self):
    """Write each row's age bucket into ``self.df['age_group']``."""
    self.df['age_group'] = pd.cut(
        self.df['age'],
        bins=self._AGE_BINS,
        labels=self._AGE_LABELS,
        include_lowest=True,
    )

  def age_v_fare_sc(self):
    """Draw a scatter plot of fare against age."""
    sns.relplot(data=self.df, x="age", y="fare")
    plt.title('Scatter Plot for Fare by Age')
    plt.xlabel('Age')
    plt.ylabel('Fare')

  def age_v_fare_bc(self):
    """Draw a box plot of fare for each age group.

    Side effect: adds (or overwrites) the ``age_group`` column on ``self.df``.
    """
    self._add_age_group()
    sns.boxplot(data=self.df, x='age_group', y='fare')
    plt.title('Fare by Age Group')
    plt.xlabel('Age Group')
    plt.ylabel('Fare')

  def sex_v_fare(self):
    """Draw a bar plot of the mean fare for each sex."""
    mean_fare = self.df.groupby('sex')['fare'].mean().reset_index()
    sns.barplot(y='fare', x='sex', data=mean_fare)
    plt.title('Fare by Sex')
    plt.xlabel('Sex')
    plt.ylabel('Fare')

  def age_v_survived(self):
    """Draw grouped bars of passenger counts per age group and survival status.

    Side effect: adds (or overwrites) the ``age_group`` column on ``self.df``.
    """
    self._add_age_group()
    # age_group is categorical; observed=False is stated explicitly so that
    # empty age-group/survival combinations are kept with a count of 0.
    counts = (
        self.df
        .groupby(['age_group', 'survived'], observed=False)
        .size()
        .reset_index(name='sum')
    )
    sns.barplot(
        data=counts,
        x='age_group',
        y='sum',
        hue='survived',
        palette=self._SURVIVED_PALETTE,
    )
    plt.title('Survival Count by Age Group')
    plt.xlabel('Age Group')
    plt.ylabel('Survival Count')

  def sex_v_survived(self):
    """Draw grouped bars of passenger counts per sex and survival status."""
    counts = self.df.groupby(['sex', 'survived']).size().reset_index(name='sum')
    sns.barplot(
        data=counts,
        x='sex',
        y='sum',
        hue='survived',
        palette=self._SURVIVED_PALETTE,
    )
    plt.title('Survival Count by Sex')
    plt.xlabel('Sex')
    plt.ylabel('Survival Count')

  def corr_map(self):
    """Draw an annotated correlation heatmap of the data.

    The free-text columns and ``age_group`` are left out and sex, survived and
    embarked are dummy-encoded before correlating. This happens on a local
    copy, so ``self.df`` keeps all of its columns.
    """
    features = self.df.drop(
        columns=['name', 'home.dest', 'ticket', 'age_group'], errors='ignore')
    encoded = pd.get_dummies(
        features, columns=['sex', 'survived', 'embarked'], drop_first=True)
    sns.heatmap(encoded.corr(), annot=True, cmap='bwr')
    plt.title('Correlation Heatmap')

In [None]:
#create object using file path and file type
# The workbook is stored on Google Drive, so Drive must be mounted before the
# path below exists on a fresh Colab runtime.
drive.mount('/content/drive')
doc = '/content/drive/MyDrive/TechCrush Tasks/titanic3.xls'
file1 = data_science(doc, 'excel')
file1.read_file()

In [None]:
# Clean the loaded data; the returned DataFrame is shown as the cell output.
file1.clean_file()

In [None]:
# Scatter plot of fare against age.
file1.age_v_fare_sc()

In [None]:
# Box plot of fare for each age group.
file1.age_v_fare_bc()

In [None]:
# Bar plot of the mean fare for each sex.
file1.sex_v_fare()

In [None]:
# Grouped bar plot of passenger counts by age group and survival status.
file1.age_v_survived()

In [None]:
# Grouped bar plot of passenger counts by sex and survival status.
file1.sex_v_survived()

In [None]:
# Correlation heatmap after dummy-encoding sex, survived and embarked.
file1.corr_map()