Notebook for exploratory data analysis (EDA) of the Superconductor data set

In [None]:
# import modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Display settings so wide frames are shown in full inside Jupyter:
# no column/row limits, no wrapping of the frame repr, no clipping of cell text.
for option_name, option_value in (
    ('display.max_columns', None),
    ('display.expand_frame_repr', False),
    ('display.max_rows', None),
    ('max_colwidth', None),
):
    pd.set_option(option_name, option_value)

# Load both CSV inputs from the local data directory
df = pd.read_csv('./data/train.csv')
df_m = pd.read_csv('./data/unique_m.csv')


In [None]:
# Preview the first five rows of df_m
df_m.head(n=5)

In [None]:
# Preview the first five rows of df
df.head(n=5)

In [None]:
# Report (rows, columns) for each loaded frame
for label, frame in (('df_m shape: ', df_m), ('df shape: ', df)):
    print(label, frame.shape)

In [None]:
# Summary statistics for the numeric columns of df. Leaving the frame as the
# cell's last expression lets Jupyter render it as a table instead of the
# plain-text dump that print() produces.
df.describe()

In [None]:
# Summary statistics for the numeric columns of df_m. Leaving the frame as the
# cell's last expression lets Jupyter render it as a table instead of the
# plain-text dump that print() produces.
df_m.describe()

In [None]:
# Column names, dtypes and non-null counts for df.
# DataFrame.info() writes its report to stdout itself and returns None, so it
# is called bare; wrapping it in print() would append a stray "None" line.
df.info()

In [None]:
# Column names, dtypes and non-null counts for df_m.
# DataFrame.info() writes its report to stdout itself and returns None, so it
# is called bare; wrapping it in print() would append a stray "None" line.
df_m.info()

In [None]:
# Explicitly check for missing values: total count of null cells across all of df
total_missing = df.isna().sum().sum()
print("Missing values:\n", total_missing)

In [None]:
# Verify that the critical_temp column is identical in both data sets
# (Series.equals compares the two columns as a whole and yields a single bool).
target_train = df['critical_temp']
target_unique = df_m['critical_temp']
target_train.equals(target_unique)

In [None]:
# Histogram (30 bins) with a KDE overlay of the target column, critical_temp
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(df['critical_temp'], bins=30, kde=True, ax=ax)
ax.set_title('Distribution of Critical Temperature')
ax.set_xlabel('Critical Temperature')
ax.set_ylabel('Frequency')
plt.show()

In [None]:
# Keep only the per-element columns (everything except the target and the material label)
df_elements = df_m.drop(columns=['critical_temp', 'material'])

# For each element column, count the rows in which its value is nonzero
frequency = df_elements.ne(0).sum()

# Order from most to least frequent and keep the first 20 entries
frequency_sorted = frequency.sort_values(ascending=False)
top20 = frequency_sorted.iloc[:20]

# Turn the series into a two-column frame (Element, Frequency) for seaborn
df_freq = top20.rename_axis('Element').reset_index(name='Frequency')

# Bar chart with one bar per element, coloured from the 'Blues_d' palette via hue
fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(
    data=df_freq,
    x='Element',
    y='Frequency',
    hue='Element',
    palette='Blues_d',
    ax=ax,
)

plt.xticks(rotation=90)  # vertical tick labels so element names do not overlap
ax.set(
    xlabel='Element',
    ylabel='Frequency',
    title='Frequency of Elements in the Dataset (Top 20)',
)
fig.tight_layout()
plt.show()


In [None]:
# Pairwise correlations between all columns of df, drawn as an unannotated heatmap
corr_matrix = df.corr()

fig, ax = plt.subplots(figsize=(16, 14))
sns.heatmap(corr_matrix, cmap='viridis', annot=False, fmt=".2f", ax=ax)
ax.set_title('Feature Correlation Matrix')
plt.show()

In [None]:
# Compute the correlation matrix on a fixed-seed random sample of the rows.
# The sample is capped at the number of available rows because DataFrame.sample
# raises when n exceeds the frame length (sampling is without replacement).
sample_size = 4000
df_sample = df.sample(n=min(sample_size, len(df)), random_state=42)
corr_matrix = df_sample.corr()

# Keep only correlations whose magnitude exceeds the threshold; every other
# entry becomes NaN so it can be masked out of the heatmap.
threshold = 0.2
filtered_corr = corr_matrix.where(corr_matrix.abs() > threshold)

# Plot heatmap
plt.figure(figsize=(18, 16))  # set figure size
sns.heatmap(
    filtered_corr,
    annot=False,
    fmt=".1f",  # only takes effect if annot is switched on
    cmap="coolwarm",
    linewidths=0.2,
    mask=filtered_corr.isna(),  # hide the entries removed by the threshold
    annot_kws={"size": 4},  # small annotation font; only used if annot is switched on
)

# Rotate axis labels
plt.xticks(rotation=45, ha='right', fontsize=10)
plt.yticks(fontsize=10)

# Build the title from the threshold so the label cannot drift from the filter
plt.title(f"Filtered Correlation Matrix (|corr| > {threshold})", fontsize=14)
plt.show()
