In [16]:
import pandas as pd

# Load the three raw CSV extracts (plays, locations, tracks) in one pass
plays, locations, tracks = (
    pd.read_csv(csv_path)
    for csv_path in ('data/plays.csv', 'data/locations.csv', 'data/tracks.csv')
)

# Filter valid plays
def _is_valid_play(play_df):
    return play_df[play_df['played_ms'] >= 60000]

# Join data 
def _join_data(plays_df, locations_df, tracks_df):
    return (
        plays_df
        .merge(locations_df, left_on='location_id', right_on='id', how='left')
        .merge(tracks_df, left_on='track_id', right_on='id', how='left')
    )

def _market_share(plays_with_all_info):
    return (
        plays_with_all_info
        .groupby(['country', 'licensor'])
        .agg(play_count=('country', 'size'))  # Count the occurrences
        .reset_index()  # Convert back to df
        .assign(  # Add market share 
            market_share=lambda df: df['play_count'] /
            df.groupby('country')['play_count'].transform('sum')
        )
    )

# Run the pipeline stage by stage, keeping each intermediate frame by name
valid_plays = plays.pipe(_is_valid_play)
plays_with_all_info = valid_plays.pipe(_join_data, locations, tracks)
grouped = plays_with_all_info.pipe(_market_share)


def run_data_quality_tests(input_df, output_df):
    """Assert basic invariants on the aggregation input and output.

    Args:
        input_df: The frame that was aggregated; must have ``country`` and
            ``licensor`` columns, one row per play.
        output_df: The aggregated frame; must have ``country``,
            ``licensor``, ``play_count`` and ``market_share`` columns.

    Raises:
        AssertionError: If any check fails. The message names the check.
    """
    group_keys = ['country', 'licensor']

    # Input data tests
    assert not input_df[group_keys].isnull().any().any(), \
        "Input data contains null values."
    assert not input_df.duplicated().any(), "Input data contains duplicate rows."

    # Transformation tests
    # This comparison used to be a bare expression on the notebook globals,
    # so it never failed and ignored the arguments. Every input row belongs
    # to exactly one group, so the counts must add up to the input row count.
    assert output_df['play_count'].sum() == len(input_df), \
        "Total play_count does not match the number of input rows."

    # Output data tests
    assert not output_df.isnull().any().any(), "Output data contains null values."

    assert len(output_df) == input_df.groupby(group_keys).ngroups, \
        "Number of rows in the output doesn't match the expected group count."
    market_share_sums = output_df.groupby('country')['market_share'].sum()
    assert (abs(market_share_sums - 1) < 1e-6).all(), \
        "Market share does not sum to 1 for some countries."
    print("All data quality tests passed!")



# Output result
# A bare last expression lets Jupyter render the DataFrame as a table.
# NOTE(review): run_data_quality_tests is defined above but is not invoked in
# the visible cells, so none of its checks run before this result is shown.
grouped


Unnamed: 0,country,licensor,play_count,market_share
0,DE,Mere Lean Music,2,0.222222
1,DE,So New Inc.,2,0.222222
2,DE,Universe Muzak Corp.,2,0.222222
3,DE,Werner Licensing,3,0.333333
4,GB,Mere Lean Music,2,0.090909
5,GB,So New Inc.,5,0.227273
6,GB,Universe Muzak Corp.,10,0.454545
7,GB,Werner Licensing,5,0.227273
8,SE,So New Inc.,1,0.083333
9,SE,Universe Muzak Corp.,6,0.5


In [4]:
from pyspark.sql import SparkSession
# NOTE(review): the recorded output of this cell is a JAVA_GATEWAY_EXITED
# error ("Unable to locate a Java Runtime"), so the session was never
# created here. A Java runtime must be installed for this cell to succeed.
spark = SparkSession.builder.getOrCreate()

The operation couldn’t be completed. Unable to locate a Java Runtime.
Please visit http://www.java.com for information on installing Java.

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pyspark/bin/spark-class: line 97: CMD: bad array subscript
head: illegal line count -- -1


PySparkRuntimeError: [JAVA_GATEWAY_EXITED] Java gateway process exited before sending its port number.

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as spark_sum

# Initialize Spark session.
# The session is created outside the try block so that a start-up failure
# (e.g. no Java runtime available) does not reach the finally clause with
# `spark` undefined.
spark = SparkSession.builder.appName("MarketShareCalculation").getOrCreate()

try:
    # Load data
    plays = spark.read.csv('data/plays.csv', header=True, inferSchema=True)
    locations = spark.read.csv('data/locations.csv', header=True, inferSchema=True)
    tracks = spark.read.csv('data/tracks.csv', header=True, inferSchema=True)

    # Filter valid plays: keep plays with played_ms of at least 60000
    valid_plays = plays.filter(col('played_ms') >= 60000)

    # Join data: left joins keep every valid play even without a match
    plays_with_location = valid_plays.join(
        locations, valid_plays['location_id'] == locations['id'], 'left'
    )
    plays_with_all_info = plays_with_location.join(
        tracks, valid_plays['track_id'] == tracks['id'], 'left'
    )

    # Aggregate data: plays per (country, licensor), then share within country
    grouped = (
        plays_with_all_info
        .groupBy('country', 'licensor')
        .count()
        .withColumnRenamed('count', 'play_count')
    )
    totals = grouped.groupBy('country').agg(
        spark_sum('play_count').alias('total_play_count')
    )
    grouped = grouped.join(totals, 'country')
    grouped = grouped.withColumn(
        'market_share', col('play_count') / col('total_play_count')
    )

    # Calculate market share sum per country
    market_share_sum = grouped.groupBy('country').agg(
        spark_sum('market_share').alias('market_share_sum')
    )

    # Show result
    market_share_sum.show()

    # Write result to CSV.
    # The default save mode raises if the target path already exists, which
    # made this cell fail on every re-run; overwrite keeps it idempotent.
    market_share_sum.write.mode('overwrite').csv('market_share_sum.csv', header=True)
finally:
    # Stop Spark session even if one of the steps above raised
    spark.stop()

The operation couldn’t be completed. Unable to locate a Java Runtime.
Please visit http://www.java.com for information on installing Java.

/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/pyspark/bin/spark-class: line 97: CMD: bad array subscript
head: illegal line count -- -1


PySparkRuntimeError: [JAVA_GATEWAY_EXITED] Java gateway process exited before sending its port number.