# Data Catalog and Lineage

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import os
import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

## Data Catalog and Lineage

This notebook handles the documentation and tracking of data assets, their origins, and transformations.
It establishes a system for maintaining metadata about datasets and tracking how data flows through the pipeline.

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import json
import datetime
import os
import uuid
from typing import Dict, List, Any, Optional, Union
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
# Define a class for data catalog management
class DataCatalog:
    """
    Manage dataset metadata, transformation records, and lineage lookups.

    The catalog is a dictionary with two sections, ``"datasets"`` and
    ``"transformations"``, each keyed by ID. It is stored as JSON at
    ``catalog_path`` and rewritten after every registration.
    """
    def __init__(self, catalog_path: str = "data_catalog.json"):
        """
        Initialize the data catalog.

        Args:
            catalog_path: Path to the catalog JSON file
        """
        self.catalog_path = catalog_path
        self.catalog = self._load_catalog()

    def _load_catalog(self) -> Dict:
        """
        Load the catalog from file if it exists, otherwise create a new one.

        Returns:
            The catalog dictionary, always containing both the "datasets" and
            "transformations" sections
        """
        catalog = {}
        if os.path.exists(self.catalog_path):
            with open(self.catalog_path, 'r', encoding="utf-8") as f:
                catalog = json.load(f)

        # A file holding only one section would otherwise surface later as a
        # KeyError in whichever method touches the missing section first.
        catalog.setdefault("datasets", {})
        catalog.setdefault("transformations", {})
        return catalog

    def save_catalog(self):
        """
        Save the catalog to file.

        Raises:
            TypeError: If the catalog holds a value that is not JSON
                serializable; the existing file is left untouched in that case
        """
        # Serialize before opening: mode 'w' truncates the file immediately,
        # so a failure while encoding must happen before the file is touched.
        payload = json.dumps(self.catalog, indent=2)
        with open(self.catalog_path, 'w', encoding="utf-8") as f:
            f.write(payload)

    def register_dataset(self,
                         dataset_id: str,
                         name: str,
                         description: str,
                         schema: Dict,
                         source: str,
                         owner: str,
                         tags: Optional[List[str]] = None,
                         quality_metrics: Optional[Dict] = None) -> str:
        """
        Register a dataset in the catalog, or update it if the ID exists.

        An update replaces every field except ``created_at``, which keeps the
        value recorded when the dataset was first registered.

        Args:
            dataset_id: Unique identifier for the dataset
            name: Name of the dataset
            description: Description of the dataset
            schema: Schema information
            source: Source of the dataset
            owner: Owner of the dataset
            tags: List of tags for the dataset
            quality_metrics: Quality metrics for the dataset

        Returns:
            dataset_id: The ID of the registered dataset
        """
        datasets = self.catalog["datasets"]
        # One timestamp so created_at and updated_at agree on first registration
        now = datetime.datetime.now().isoformat()
        created_at = now

        if dataset_id in datasets:
            print(f"Dataset {dataset_id} already exists. Updating...")
            created_at = datasets[dataset_id].get("created_at", now)

        datasets[dataset_id] = {
            "name": name,
            "description": description,
            "schema": schema,
            "source": source,
            "owner": owner,
            "tags": list(tags) if tags else [],
            "quality_metrics": dict(quality_metrics) if quality_metrics else {},
            "created_at": created_at,
            "updated_at": now
        }

        self.save_catalog()
        return dataset_id

    def register_transformation(self,
                                transformation_id: str,
                                name: str,
                                description: str,
                                input_datasets: List[str],
                                output_datasets: List[str],
                                transformation_code: Optional[str] = None,
                                parameters: Optional[Dict] = None) -> str:
        """
        Register a data transformation in the catalog, or update it if the ID exists.

        An update replaces every field except ``created_at``, which keeps the
        value recorded when the transformation was first registered.

        Args:
            transformation_id: Unique identifier for the transformation
            name: Name of the transformation
            description: Description of the transformation
            input_datasets: List of input dataset IDs
            output_datasets: List of output dataset IDs
            transformation_code: Code used for the transformation
            parameters: Parameters used in the transformation

        Returns:
            transformation_id: The ID of the registered transformation
        """
        transformations = self.catalog["transformations"]
        now = datetime.datetime.now().isoformat()
        created_at = now

        if transformation_id in transformations:
            print(f"Transformation {transformation_id} already exists. Updating...")
            created_at = transformations[transformation_id].get("created_at", now)

        transformations[transformation_id] = {
            "name": name,
            "description": description,
            "input_datasets": input_datasets,
            "output_datasets": output_datasets,
            "transformation_code": transformation_code,
            "parameters": parameters or {},
            "created_at": created_at,
            "updated_at": now
        }

        self.save_catalog()
        return transformation_id

    def get_dataset_lineage(self, dataset_id: str) -> Dict:
        """
        Get the direct lineage information for a dataset.

        Only transformations that list the dataset itself as an input or an
        output are reported; the lookup does not walk further up or down.

        Args:
            dataset_id: ID of the dataset

        Returns:
            Dict with an "upstream" list (transformations producing the
            dataset, with their inputs) and a "downstream" list
            (transformations consuming it, with their outputs)

        Raises:
            ValueError: If the dataset is not registered in the catalog
        """
        if dataset_id not in self.catalog["datasets"]:
            raise ValueError(f"Dataset {dataset_id} not found in catalog")

        lineage = {"upstream": [], "downstream": []}

        for trans_id, trans in self.catalog["transformations"].items():
            # The dataset is an output of this transformation (upstream)
            if dataset_id in trans["output_datasets"]:
                lineage["upstream"].append({
                    "transformation_id": trans_id,
                    "transformation_name": trans["name"],
                    "input_datasets": trans["input_datasets"]
                })
            # The dataset is an input of this transformation (downstream)
            if dataset_id in trans["input_datasets"]:
                lineage["downstream"].append({
                    "transformation_id": trans_id,
                    "transformation_name": trans["name"],
                    "output_datasets": trans["output_datasets"]
                })

        return lineage

    def visualize_lineage(self, dataset_id: str = None):
        """
        Visualize the lineage of datasets as a directed graph.

        Datasets and transformations are both nodes; edges run from input
        datasets to a transformation and from the transformation to its
        outputs. Prints an installation hint and returns if networkx is not
        available.

        Args:
            dataset_id: Optional ID of a specific dataset to visualize; the
                graph is then restricted to that node plus everything
                reachable upstream or downstream of it
        """
        # Guard only the import so that errors raised while building or
        # drawing the graph are not mistaken for a missing dependency.
        try:
            import networkx as nx
        except ImportError:
            print("Please install networkx to visualize lineage: pip install networkx")
            return

        G = nx.DiGraph()

        # Add all datasets as nodes
        for ds_id, ds in self.catalog["datasets"].items():
            G.add_node(ds_id, label=ds["name"], type="dataset")

        # Add all transformations as nodes, with their input and output edges
        for trans_id, trans in self.catalog["transformations"].items():
            G.add_node(trans_id, label=trans["name"], type="transformation")

            for input_ds in trans["input_datasets"]:
                G.add_edge(input_ds, trans_id)

            for output_ds in trans["output_datasets"]:
                G.add_edge(trans_id, output_ds)

        # If a specific dataset is provided, filter the graph
        if dataset_id:
            ancestors = nx.ancestors(G, dataset_id)
            descendants = nx.descendants(G, dataset_id)
            relevant_nodes = ancestors.union(descendants).union({dataset_id})
            G = G.subgraph(relevant_nodes)

        plt.figure(figsize=(12, 8))
        pos = nx.spring_layout(G)

        # Edges can reference dataset IDs that were never registered; such
        # nodes carry no "type", so every non-transformation node is drawn as
        # a dataset instead of appearing as a label with no marker.
        trans_nodes = [n for n, d in G.nodes(data=True) if d.get("type") == "transformation"]
        dataset_nodes = [n for n, d in G.nodes(data=True) if d.get("type") != "transformation"]

        nx.draw_networkx_nodes(G, pos, nodelist=dataset_nodes, node_color='skyblue', node_size=500)
        nx.draw_networkx_nodes(G, pos, nodelist=trans_nodes, node_color='lightgreen', node_size=300, node_shape='s')
        nx.draw_networkx_edges(G, pos, arrows=True)

        # Fall back to the node ID when no display name was recorded
        labels = {n: G.nodes[n].get("label", n) for n in G.nodes()}
        nx.draw_networkx_labels(G, pos, labels=labels, font_size=8)

        plt.title("Data Lineage Graph")
        plt.axis('off')
        plt.tight_layout()
        plt.show()



In [None]:
# Define a class for data transformation tracking
class DataTransformer:
    """
    Apply transformation functions and record what they consumed and produced.

    Each call to ``transform`` appends a record to ``transformation_history``
    and registers the transformation in the supplied catalog.
    """
    def __init__(self, catalog: "DataCatalog"):
        """
        Initialize the data transformer.

        Args:
            catalog: DataCatalog instance; only its ``register_transformation``
                method is used here
        """
        self.catalog = catalog
        self.transformation_history = []

    def transform(self,
                  transformation_func,
                  input_data: Dict[str, pd.DataFrame],
                  transformation_id: str,
                  name: str,
                  description: str,
                  parameters: Dict = None) -> Dict[str, pd.DataFrame]:
        """
        Apply a transformation function and track the lineage.

        Args:
            transformation_func: Callable invoked as
                ``transformation_func(input_data, parameters)``; it must return
                a dictionary keyed by output dataset ID
            input_data: Dictionary of input dataframes {dataset_id: dataframe}
            transformation_id: Unique identifier for the transformation
            name: Name of the transformation
            description: Description of the transformation
            parameters: Parameters for the transformation

        Returns:
            Dictionary of output dataframes {dataset_id: dataframe}
        """
        params = parameters or {}

        # Time only the transformation itself, not the bookkeeping below
        start_time = datetime.datetime.now()
        output_data = transformation_func(input_data, params)
        end_time = datetime.datetime.now()

        input_ids = list(input_data.keys())
        output_ids = list(output_data.keys())

        # Store copies of the parameters: keeping the caller's dict by
        # reference would let a later mutation rewrite recorded lineage.
        self.transformation_history.append({
            "transformation_id": transformation_id,
            "name": name,
            "description": description,
            "input_datasets": input_ids,
            "output_datasets": output_ids,
            "parameters": dict(params),
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "duration_seconds": (end_time - start_time).total_seconds()
        })

        # Register the transformation in the catalog
        self.catalog.register_transformation(
            transformation_id=transformation_id,
            name=name,
            description=description,
            input_datasets=list(input_ids),
            output_datasets=list(output_ids),
            parameters=dict(params)
        )

        return output_data

    def get_transformation_history(self) -> List[Dict]:
        """
        Get the history of transformations.

        Returns:
            List of transformation records, oldest first (the live list, not a copy)
        """
        return self.transformation_history



In [None]:
# Define a class for data profiling and quality metrics
class DataProfiler:
    """
    Generate data profiles and quality metrics for pandas DataFrames.

    Values that cannot be computed (statistics of an all-missing column,
    percentages of an empty frame) are reported as None.
    """
    @staticmethod
    def _optional_float(value) -> Optional[float]:
        """Convert a scalar statistic to float, mapping missing values to None."""
        return None if pd.isna(value) else float(value)

    @staticmethod
    def _percentage(part, whole) -> Optional[float]:
        """Return part / whole as a percentage, or None when whole is zero."""
        if whole == 0:
            return None
        return float(part / whole * 100)

    @staticmethod
    def profile_dataframe(df: pd.DataFrame) -> Dict:
        """
        Generate a profile of a dataframe.

        Args:
            df: Pandas DataFrame to profile

        Returns:
            Dictionary containing shape, column names, dtypes, missing-value
            counts and percentages, unique counts for object/category columns,
            memory usage in KB, and - only when numeric columns exist - a
            "numeric_stats" section with min/max/mean/median/std per column
        """
        profile = {
            "shape": df.shape,
            "columns": list(df.columns),
            "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
            "missing_values": {col: int(df[col].isna().sum()) for col in df.columns},
            "missing_percentage": {col: float(df[col].isna().mean() * 100) for col in df.columns},
            "unique_values": {col: int(df[col].nunique()) for col in df.columns if df[col].dtype == 'object' or df[col].dtype.name == 'category'},
            "memory_usage": {col: float(df[col].memory_usage(deep=True) / 1024) for col in df.columns},  # KB
            "total_memory_kb": float(df.memory_usage(deep=True).sum() / 1024)
        }

        # Add numeric column statistics
        numeric_cols = df.select_dtypes(include=['number']).columns
        if len(numeric_cols) > 0:
            to_float = DataProfiler._optional_float
            profile["numeric_stats"] = {}
            for col in numeric_cols:
                series = df[col]
                # Each statistic is computed once; pandas returns NaN for
                # all-missing columns, which is reported as None.
                profile["numeric_stats"][col] = {
                    "min": to_float(series.min()),
                    "max": to_float(series.max()),
                    "mean": to_float(series.mean()),
                    "median": to_float(series.median()),
                    "std": to_float(series.std())
                }

        return profile

    @staticmethod
    def calculate_quality_metrics(df: pd.DataFrame) -> Dict:
        """
        Calculate data quality metrics for a dataframe.

        Args:
            df: Pandas DataFrame to analyze

        Returns:
            Dictionary with "completeness" (score over all cells),
            "column_completeness" (score per column) and "uniqueness"
            (share of non-duplicated rows). Scores are percentages, or None
            when the frame has no rows or cells to measure
        """
        percentage = DataProfiler._percentage

        total_rows = len(df)
        total_cells = total_rows * len(df.columns)
        missing_cells = int(df.isna().sum().sum())

        metrics = {
            "completeness": {
                "score": percentage(total_cells - missing_cells, total_cells),
                "description": "Percentage of non-missing values"
            },
            "column_completeness": {
                col: percentage(total_rows - int(df[col].isna().sum()), total_rows)
                for col in df.columns
            }
        }

        # Check for duplicated rows
        duplicate_rows = int(df.duplicated().sum())
        metrics["uniqueness"] = {
            "score": percentage(total_rows - duplicate_rows, total_rows),
            "description": "Percentage of unique rows"
        }

        return metrics



In [None]:
# Load the sample dataset
def load_nz_industry_data(seed: Optional[int] = None) -> pd.DataFrame:
    """
    Simulate loading the New Zealand Industry Financial Dataset.

    In a real scenario, you would load from a file or database. Here one row
    is generated per year (2013-2023), industry and variable, with a random
    integer value stored as a string, as raw data would be before cleaning.

    Args:
        seed: Optional seed for reproducible values. When omitted, values are
            drawn from NumPy's global random state

    Returns:
        A pandas DataFrame containing the dataset
    """
    # A private generator keeps seeded runs from disturbing the global stream
    rng = np.random if seed is None else np.random.RandomState(seed)

    years = range(2013, 2024)
    industry_levels = ["Level 1", "Level 2", "Level 3", "Level 4"]
    industries = [
        ("99999", "All industries"),
        ("AA111", "Agriculture"),
        ("BB222", "Manufacturing"),
        ("CC333", "Construction"),
        ("DD444", "Services"),
    ]
    variables = [
        ("H01", "Total income"),
        ("H04", "Sales, government funding, grants and subsidies"),
        ("H05", "Interest, dividends and donations"),
        ("H07", "Non-operating income"),
        ("H08", "Total expenditure"),
    ]
    unit = "Dollars (millions)"
    variable_category = "Financial performance"
    anzsic_code = "ANZSIC06 divisions A-S (excluding classes K6330, L6711, O7552, O760, O771, O772, S9540, S9601, S9602, and S9603)"

    # There are more industries than levels; the extra ones reuse the last level
    last_level = len(industry_levels) - 1

    data = []
    for year in years:
        for position, (industry_code, industry_name) in enumerate(industries):
            aggregation = industry_levels[min(position, last_level)]
            for var_code, var_name in variables:
                value = rng.randint(10000, 1000000)

                data.append({
                    "Year": year,
                    "Industry_aggregation_NZSIOC": aggregation,
                    "Industry_code_NZSIOC": industry_code,
                    "Industry_name_NZSIOC": industry_name,
                    "Units": unit,
                    "Variable_code": var_code,
                    "Variable_name": var_name,
                    "Variable_category": variable_category,
                    "Value": str(value),
                    "Industry_code_ANZSIC06": anzsic_code
                })

    return pd.DataFrame(data)


In [None]:
# Main execution

In [None]:
# Initialize the data catalog
catalog = DataCatalog()


In [None]:
# Initialize the data transformer
transformer = DataTransformer(catalog)


In [None]:
# Load the raw data
raw_data = load_nz_industry_data()


In [None]:
# Profile the raw data
profiler = DataProfiler()
raw_profile = profiler.profile_dataframe(raw_data)
raw_quality = profiler.calculate_quality_metrics(raw_data)


In [None]:
# Register the raw dataset in the catalog
raw_dataset_id = "nz_industry_financial_raw"
catalog.register_dataset(
    dataset_id=raw_dataset_id,
    name="New Zealand Industry Financial Data (Raw)",
    description="Raw financial data for New Zealand industries from 2013 to 2023",
    schema={
        "columns": raw_profile["columns"],
        "dtypes": raw_profile["dtypes"]
    },
    source="Annual Enterprise Survey",
    owner="Data Science Team",
    tags=["financial", "industry", "new zealand", "raw"],
    quality_metrics=raw_quality
)


In [None]:
# Define a transformation function to clean the data
def clean_data_transformation(input_data, params):
    """
    Clean the raw data by converting types and trimming industry codes.

    Args:
        input_data: Dictionary with input DataFrames, keyed by dataset ID
        params: Parameters for the transformation; must contain
            "input_dataset_id" and "output_dataset_id"

    Returns:
        Dictionary with one cleaned DataFrame under the output dataset ID

    Raises:
        KeyError: If a required parameter or the input dataset is missing
    """
    # Work on a copy so the caller's raw frame keeps its original columns
    df = input_data[params["input_dataset_id"]].copy()

    # Convert Value column to numeric; unparseable entries become NaN
    df["Value"] = pd.to_numeric(df["Value"], errors="coerce")

    # Extract clean industry codes (surrounding whitespace removed)
    df["Clean_Industry_Code"] = df["Industry_code_NZSIOC"].str.strip()

    return {params["output_dataset_id"]: df}


In [None]:
# Apply the transformation
clean_params = {
    "input_dataset_id": raw_dataset_id,
    "output_dataset_id": "nz_industry_financial_clean"
}

output_data = transformer.transform(
    transformation_func=clean_data_transformation,
    input_data={raw_dataset_id: raw_data},
    transformation_id="clean_nz_industry_data",
    name="Clean NZ Industry Financial Data",
    description="Convert Value column to numeric and extract clean industry codes",
    parameters=clean_params
)


In [None]:
# Get the cleaned data
clean_data = output_data["nz_industry_financial_clean"]


In [None]:
# Profile the cleaned data
clean_profile = profiler.profile_dataframe(clean_data)
clean_quality = profiler.calculate_quality_metrics(clean_data)


In [None]:
# Register the cleaned dataset
clean_dataset_id = "nz_industry_financial_clean"
catalog.register_dataset(
    dataset_id=clean_dataset_id,
    name="New Zealand Industry Financial Data (Cleaned)",
    description="Cleaned financial data for New Zealand industries with proper data types",
    schema={
        "columns": clean_profile["columns"],
        "dtypes": clean_profile["dtypes"]
    },
    source="Derived from raw NZ Industry Financial Data",
    owner="Data Science Team",
    tags=["financial", "industry", "new zealand", "cleaned"],
    quality_metrics=clean_quality
)


In [None]:
# Define a transformation to create aggregated metrics
def create_aggregated_metrics(input_data, params):
    """
    Create aggregated financial metrics by year and industry.

    The long-format input is pivoted so that each variable name becomes a
    column holding the summed values. When both "Total income" and
    "Total expenditure" are present, "Profit" and "Profit_Margin" (percent of
    total income) are added.

    Args:
        input_data: Dictionary with input DataFrames, keyed by dataset ID
        params: Parameters for the transformation; must contain
            "input_dataset_id" and "output_dataset_id"

    Returns:
        Dictionary with the pivoted DataFrame under the output dataset ID

    Raises:
        KeyError: If a required parameter or the input dataset is missing
    """
    # pivot_table builds a new frame, so no defensive copy of the input is needed
    df = input_data[params["input_dataset_id"]]

    # Create pivot table for financial metrics by year and industry
    pivot_df = df.pivot_table(
        index=["Year", "Industry_name_NZSIOC"],
        columns="Variable_name",
        values="Value",
        aggfunc="sum"
    ).reset_index()

    # Calculate profit (if Total income and Total expenditure are available)
    if "Total income" in pivot_df.columns and "Total expenditure" in pivot_df.columns:
        income = pivot_df["Total income"]
        pivot_df["Profit"] = income - pivot_df["Total expenditure"]
        # A margin over zero income is undefined: mask the zero so the result
        # is NaN rather than +/-inf, which would distort later statistics.
        pivot_df["Profit_Margin"] = pivot_df["Profit"] / income.where(income != 0) * 100

    return {params["output_dataset_id"]: pivot_df}


In [None]:
# Apply the transformation
agg_params = {
    "input_dataset_id": clean_dataset_id,
    "output_dataset_id": "nz_industry_financial_aggregated"
}

output_data = transformer.transform(
    transformation_func=create_aggregated_metrics,
    input_data={clean_dataset_id: clean_data},
    transformation_id="aggregate_nz_industry_data",
    name="Aggregate NZ Industry Financial Data",
    description="Create aggregated financial metrics by year and industry",
    parameters=agg_params
)


In [None]:
# Get the aggregated data
agg_data = output_data["nz_industry_financial_aggregated"]


In [None]:
# Profile the aggregated data
agg_profile = profiler.profile_dataframe(agg_data)
agg_quality = profiler.calculate_quality_metrics(agg_data)


In [None]:
# Register the aggregated dataset
agg_dataset_id = "nz_industry_financial_aggregated"
catalog.register_dataset(
    dataset_id=agg_dataset_id,
    name="New Zealand Industry Financial Data (Aggregated)",
    description="Aggregated financial metrics by year and industry with calculated profit metrics",
    schema={
        "columns": agg_profile["columns"],
        "dtypes": agg_profile["dtypes"]
    },
    source="Derived from cleaned NZ Industry Financial Data",
    owner="Data Science Team",
    tags=["financial", "industry", "new zealand", "aggregated", "metrics"],
    quality_metrics=agg_quality
)


In [None]:
# Define a transformation to create time series features
def create_time_series_features(input_data, params):
    """
    Create time series features from the aggregated data.

    For every numeric column except "Year", a "<column>_YoY_Growth" column is
    added holding the percentage change from the previous row of the same
    industry, with rows ordered by year. The first row of each industry has
    no predecessor and gets NaN.

    Args:
        input_data: Dictionary with input DataFrames, keyed by dataset ID
        params: Parameters for the transformation; must contain
            "input_dataset_id" and "output_dataset_id"

    Returns:
        Dictionary with the feature DataFrame under the output dataset ID,
        sorted by industry and year. Rows are never dropped: an empty input
        yields an empty frame, and rows without an industry name are kept
        with NaN growth

    Raises:
        KeyError: If a required parameter or the input dataset is missing
    """
    group_key = "Industry_name_NZSIOC"

    # sort_values returns a new frame, so the caller's input is never modified
    ts_df = input_data[params["input_dataset_id"]].sort_values([group_key, "Year"])

    # Year is the time axis, not a metric, so it gets no growth column
    value_cols = [
        col for col in ts_df.select_dtypes(include=['number']).columns
        if col != "Year"
    ]

    if ts_df.empty or not value_cols:
        # Nothing to group: still provide the (empty) growth columns
        growth = pd.DataFrame(index=ts_df.index, columns=value_cols, dtype="float64")
    else:
        # One grouped pass instead of a Python loop over industries.
        # fill_method=None pins the behavior: missing values are not
        # forward-filled before the change is computed, so a gap yields NaN
        # on every pandas version instead of depending on a deprecated default.
        growth = (
            ts_df.groupby(group_key, sort=False)[value_cols]
            .pct_change(fill_method=None)
            * 100
        )

    # Index-aligned assignment leaves NaN for any row the grouping skipped
    for col in value_cols:
        ts_df[f"{col}_YoY_Growth"] = growth[col]

    return {params["output_dataset_id"]: ts_df}


In [None]:
# Apply the transformation
ts_params = {
    "input_dataset_id": agg_dataset_id,
    "output_dataset_id": "nz_industry_financial_time_series"
}

output_data = transformer.transform(
    transformation_func=create_time_series_features,
    input_data={agg_dataset_id: agg_data},
    transformation_id="create_time_series_features",
    name="Create Time Series Features",
    description="Create year-over-year growth rates and other time series features",
    parameters=ts_params
)


In [None]:
# Get the time series data
ts_data = output_data["nz_industry_financial_time_series"]


In [None]:
# Profile the time series data
ts_profile = profiler.profile_dataframe(ts_data)
ts_quality = profiler.calculate_quality_metrics(ts_data)


In [None]:
# Register the time series dataset
ts_dataset_id = "nz_industry_financial_time_series"
catalog.register_dataset(
    dataset_id=ts_dataset_id,
    name="New Zealand Industry Financial Time Series",
    description="Time series features including year-over-year growth rates for financial metrics",
    schema={
        "columns": ts_profile["columns"],
        "dtypes": ts_profile["dtypes"]
    },
    source="Derived from aggregated NZ Industry Financial Data",
    owner="Data Science Team",
    tags=["financial", "industry", "new zealand", "time series", "growth rates"],
    quality_metrics=ts_quality
)


In [None]:
# Visualize the data lineage
catalog.visualize_lineage()


In [None]:
# Print the transformation history
print("\nTransformation History:")
for i, trans in enumerate(transformer.get_transformation_history()):
    print(f"\n{i+1}. {trans['name']}")
    print(f"   ID: {trans['transformation_id']}")
    print(f"   Description: {trans['description']}")
    print(f"   Input datasets: {trans['input_datasets']}")
    print(f"   Output datasets: {trans['output_datasets']}")
    print(f"   Duration: {trans['duration_seconds']:.2f} seconds")


In [None]:
# Get lineage for the final dataset
print("\nLineage for Time Series Dataset:")
lineage = catalog.get_dataset_lineage(ts_dataset_id)
print(json.dumps(lineage, indent=2))


In [None]:
# Save the final catalog
catalog.save_catalog()
print("\nData catalog saved to:", catalog.catalog_path)