In [80]:
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import udf, col, lower, avg, round, count
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType
import pycountry

In [81]:
# Obtain the Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName("Solver problema 1")
    .getOrCreate()
)

In [82]:
# For this first part no explicit schema is supplied: the header row names the
# columns and that is enough to extract the data we need.
file = "../tests/vshort-users-details-2023.csv"

df_user_location = (
    spark.read
    .format("csv")
    .option("header", "true")
    .load(file)
)

# Preview the first five rows.
df_user_location.show(n=5)

+------+--------+------+--------------------+--------------------+--------------------+------------+----------+--------+---------+-------+-------+-------------+-------------+---------+----------------+
|Mal ID|Username|Gender|            Birthday|            Location|              Joined|Days Watched|Mean Score|Watching|Completed|On Hold|Dropped|Plan to Watch|Total Entries|Rewatched|Episodes Watched|
+------+--------+------+--------------------+--------------------+--------------------+------------+----------+--------+---------+-------+-------+-------------+-------------+---------+----------------+
|     1|   Xinil|  Male|1985-03-04T00:00:...|          California|2004-11-05T00:00:...|       142.3|      7.37|     1.0|    233.0|    8.0|   93.0|         64.0|        399.0|     60.0|          8458.0|
|     3| Aokaado|  Male|                NULL|        Oslo, Norway|2004-11-11T00:00:...|        68.6|      7.34|    23.0|    137.0|   99.0|   44.0|         40.0|        343.0|     15.0|        

In [83]:
# Register the DataFrame as a temporary view so it can be queried through spark.sql.
df_user_location.createOrReplaceTempView("df_user_location")


In [84]:
# DataFrame API: project the user id and the free-text location columns.
id_location = df_user_location.select(col("Mal ID"), col("Location"))

In [85]:
# SQL equivalent of the DataFrame-API projection. It is bound to its own `_sql`
# name, like the other SQL variants in this notebook, so it no longer overwrites
# the `id_location` frame produced by the DataFrame-API cell.
id_location_sql = spark.sql("SELECT `Mal ID`, Location FROM df_user_location")

In [86]:
# Manual map from lower-cased location tokens (abbreviations) to country names.
# get_country consults this map before any pycountry lookup, so an entry here
# always wins over pycountry's own codes and names.
# NOTE(review): "ca" resolves to "United States" here, so a bare "ca" token can
# never resolve to Canada — confirm that this is the intended reading.
# NOTE(review): the display names are hand-written and may not match the names
# pycountry returns for the same country (e.g. "Russia") — verify that one
# country cannot end up under two different labels after grouping.
ABBREVIATION_MAP = {
    # Province / state abbreviations, resolved to their country.
    "ab": "Canada",
    "bc": "Canada",
    "on": "Canada",
    "qc": "Canada",
    "ny": "United States",
    "ca": "United States",
    "tx": "United States",
    # Common country shorthands.
    "usa": "United States",
    "uk": "United Kingdom",
    "uae": "United Arab Emirates",
    "prc": "China",
    # Two-letter country codes (presumably ISO 3166-1 alpha-2 — TODO confirm).
    "us": "United States",
    "gb": "United Kingdom",
    "ru": "Russia",
    "au": "Australia",
    "nz": "New Zealand",
    "de": "Germany",
    "fr": "France",
    "es": "Spain",
    "it": "Italy",
    "jp": "Japan",
    "cn": "China",
    "kr": "South Korea",
    "br": "Brazil",
    "mx": "Mexico",
    "in": "India",
    "za": "South Africa",
    "se": "Sweden",
    "no": "Norway",
    "fi": "Finland",
    "dk": "Denmark",
    "ch": "Switzerland",
    "nl": "Netherlands",
    "be": "Belgium",
    "at": "Austria",
    "pl": "Poland",
    "ar": "Argentina",
    "cl": "Chile",
    "co": "Colombia",
    "ve": "Venezuela",
    "pe": "Peru",
    "ph": "Philippines",
    "id": "Indonesia",
    "my": "Malaysia",
    "sg": "Singapore",
    "th": "Thailand",
    "vn": "Vietnam",
    "sa": "Saudi Arabia",
    "eg": "Egypt",
    "tr": "Turkey",
    "gr": "Greece",
    "pt": "Portugal",
    "cz": "Czech Republic",
    "sk": "Slovakia",
    "hu": "Hungary",
    "ro": "Romania",
    "bg": "Bulgaria",
    "rs": "Serbia",
    "hr": "Croatia",
    "si": "Slovenia",
    "ua": "Ukraine",
    "by": "Belarus",
    "lt": "Lithuania",
    "lv": "Latvia",
    "ee": "Estonia",
    "is": "Iceland",
    "ie": "Ireland",
    "pk": "Pakistan",
    "bd": "Bangladesh",
    "lk": "Sri Lanka",
    "np": "Nepal",
    "af": "Afghanistan",
    "ir": "Iran",
    "iq": "Iraq",
    "sy": "Syria",
    "jo": "Jordan",
    "lb": "Lebanon",
    "il": "Israel",
    "ps": "Palestine",
    "kw": "Kuwait",
    "qa": "Qatar",
    "bh": "Bahrain",
    "om": "Oman",
    "ye": "Yemen",
    "dz": "Algeria",
    "ma": "Morocco",
    "tn": "Tunisia",
    "ly": "Libya",
    "sd": "Sudan",
    "ng": "Nigeria",
    "gh": "Ghana",
    "ke": "Kenya",
    "tz": "Tanzania",
    "ug": "Uganda",
    "zm": "Zambia",
    "zw": "Zimbabwe",
    "bw": "Botswana",
    "na": "Namibia",
    "ao": "Angola",
    "mz": "Mozambique",
    "mg": "Madagascar",
    "et": "Ethiopia",
    "sn": "Senegal",
    "ci": "Ivory Coast",
    "cm": "Cameroon",
    "cd": "Congo",
    "cg": "Republic Of The Congo",
    "ga": "Gabon",
    "gq": "Equatorial Guinea",
    "cv": "Cape Verde",
    "st": "Sao Tome And Principe",
    "sc": "Seychelles",
    "mu": "Mauritius",
    "km": "Comoros",
    "dj": "Djibouti",
    "er": "Eritrea",
    "so": "Somalia",
    "rw": "Rwanda",
    "bi": "Burundi",
    "mw": "Malawi",
    "ls": "Lesotho",
    "sz": "Eswatini",
    "sl": "Sierra Leone",
    "lr": "Liberia",
    "gm": "Gambia",
    "gw": "Guinea-Bissau",
    "gn": "Guinea",
    "ml": "Mali",
    "bf": "Burkina Faso",
    "ne": "Niger",
    "td": "Chad",
    "mr": "Mauritania",
    "bj": "Benin",
    "tg": "Togo",
    "cf": "Central African Republic",
    "ss": "South Sudan",
    "bt": "Bhutan",
    "mv": "Maldives",
    "kh": "Cambodia",
    "la": "Laos",
    "mm": "Myanmar",
    "bn": "Brunei",
    "tl": "Timor-Leste",
    "pg": "Papua New Guinea",
    "fj": "Fiji",
    "ws": "Samoa",
    "to": "Tonga",
    "vu": "Vanuatu",
    "sb": "Solomon Islands",
    "ki": "Kiribati",
    "tv": "Tuvalu",
    "nr": "Nauru",
    "pw": "Palau",
    "mh": "Marshall Islands",
    "fm": "Micronesia",
    "as": "American Samoa",
    "gu": "Guam",
    "mp": "Northern Mariana Islands",
    "ck": "Cook Islands",
    "nu": "Niue",
    "wf": "Wallis And Futuna",
    "pf": "French Polynesia",
    "nc": "New Caledonia",
    "tk": "Tokelau",
    "pn": "Pitcairn Islands",
    "gs": "South Georgia And The South Sandwich Islands",
    "sh": "Saint Helena",
    "fk": "Falkland Islands",
    "ai": "Anguilla",
    "bm": "Bermuda",
    "vg": "British Virgin Islands",
    "ky": "Cayman Islands",
    "ms": "Montserrat",
    "tc": "Turks And Caicos Islands",
    "vi": "US Virgin Islands",
    "pr": "Puerto Rico",
    "um": "United States Minor Outlying Islands",
    "hk": "Hong Kong",
    "mo": "Macau",
    "tw": "Taiwan",
    "fo": "Faroe Islands",
    "gl": "Greenland",
    "ax": "Aland Islands",
    "je": "Jersey",
    "gg": "Guernsey",
    "im": "Isle Of Man",
    "yt": "Mayotte",
    "re": "Reunion",
    "mq": "Martinique",
    "gp": "Guadeloupe",
    "bl": "Saint Barthelemy",
    "mf": "Saint Martin",
    "pm": "Saint Pierre And Miquelon",
    "tf": "French Southern Territories",
    "bv": "Bouvet Island",
    "hm": "Heard Island And Mcdonald Islands",
    "aq": "Antarctica",
    "cw": "Curacao",
    "sx": "Sint Maarten",
    "bq": "Caribbean Netherlands",
    "xk": "Kosovo"
}

# Lazily-filled cache of the pycountry lookup tables (see _pycountry_index).
_PYCOUNTRY_INDEX = {}


def _pycountry_index():
    """Return ``(countries, subdivisions)`` lookup tables, building them on first use.

    ``countries`` maps every lower-cased country name, alpha-2 code and alpha-3
    code to ``(rank, country_name)``. ``subdivisions`` maps every lower-cased
    subdivision name to ``(rank, country_code)``. ``rank`` is the record's
    position in pycountry's iteration order and only the first record seen for a
    key is stored, so picking the lowest rank among several hits gives the same
    answer as scanning the whole database and stopping at the first match.
    """
    if not _PYCOUNTRY_INDEX:
        countries = {}
        for rank, country in enumerate(pycountry.countries):
            for key in (country.name, country.alpha_2, country.alpha_3):
                countries.setdefault(key.lower(), (rank, country.name))

        subdivisions = {}
        for rank, subdivision in enumerate(pycountry.subdivisions):
            subdivisions.setdefault(
                subdivision.name.lower(), (rank, subdivision.country_code)
            )

        _PYCOUNTRY_INDEX["countries"] = countries
        _PYCOUNTRY_INDEX["subdivisions"] = subdivisions
    return _PYCOUNTRY_INDEX["countries"], _PYCOUNTRY_INDEX["subdivisions"]


# Identify the country a free-text location refers to.
def get_country(location):
    """Resolve a free-text location string to a country name.

    The text is lower-cased and split on commas; each trimmed piece is matched
    as a whole token, in this order of precedence:

    1. ``ABBREVIATION_MAP`` (first piece, left to right, that is a key);
    2. a pycountry country name, alpha-2 code or alpha-3 code;
    3. a pycountry subdivision name, mapped to its parent country.

    Args:
        location: Location text, or ``None``.

    Returns:
        The matched country name, or ``"UNKNOWN"`` when the input is empty or
        blank, nothing matches, or a matched subdivision's country code has no
        country record.
    """
    if not location or location.strip() == "":
        return "UNKNOWN"

    # Split on commas and trim each piece.
    location_parts = [part.strip() for part in location.lower().strip().split(",")]

    # 1. Manual abbreviations take precedence over everything else.
    for part in location_parts:
        if part in ABBREVIATION_MAP:
            return ABBREVIATION_MAP[part]

    # The tables are built once and reused, instead of rescanning every
    # pycountry record for every call.
    countries, subdivisions = _pycountry_index()

    # 2. Country already named (or coded) in the location. The lowest rank is
    #    the country a linear scan of pycountry.countries would reach first.
    country_hits = [countries[part] for part in location_parts if part in countries]
    if country_hits:
        return min(country_hits)[1]

    # 3. Subdivisions (states, provinces, ...), resolved to their country.
    subdivision_hits = [
        subdivisions[part] for part in location_parts if part in subdivisions
    ]
    if subdivision_hits:
        country = pycountry.countries.get(alpha_2=min(subdivision_hits)[1])
        # The lookup may yield no record; fall through instead of failing on .name.
        if country is not None:
            return country.name

    # Nothing matched.
    return "UNKNOWN"

# Expose get_country to Spark as a UDF that produces a string column.
get_country_udf = udf(f=get_country, returnType=StringType())

# Add a "Country" column computed from the lower-cased Location text.
id_location_with_country = id_location.withColumn(
    "Country",
    get_country_udf(lower(col("Location"))),
)

# Keep only the user id and its resolved country.
id_location_with_country_final = id_location_with_country.select(
    col("Mal ID"),
    col("Country"),
)

# Preview the first 20 rows without truncating cell values.
id_location_with_country_final.show(n=20, truncate=False)

[Stage 8:>                                                          (0 + 1) / 1]

+------+-------------+
|Mal ID|Country      |
+------+-------------+
|1     |United States|
|3     |Norway       |
|4     |Australia    |
|9     |UNKNOWN      |
|18    |UNKNOWN      |
|20    |Norway       |
|23    |Canada       |
|36    |UNKNOWN      |
|44    |UNKNOWN      |
|47    |UNKNOWN      |
|53    |UNKNOWN      |
|66    |Canada       |
|70    |UNKNOWN      |
|71    |UNKNOWN      |
|77    |UNKNOWN      |
|80    |UNKNOWN      |
|82    |France       |
|83    |United States|
|90    |UNKNOWN      |
|91    |Canada       |
+------+-------------+
only showing top 20 rows



                                                                                

In [87]:
# Per-user anime scores; the header row supplies the column names.
file = "../tests/vshort-users-score-2023.csv"

df_user_rating = (
    spark.read
    .format("csv")
    .option("header", "true")
    .load(file)
)

# Preview the first five rows.
df_user_rating.show(n=5)

+-------+--------+--------+--------------------+------+
|user_id|Username|anime_id|         Anime Title|rating|
+-------+--------+--------+--------------------+------+
|      1|   Xinil|      21|           One Piece|     9|
|      1|   Xinil|      48|         .hack//Sign|     7|
|      1|   Xinil|     320|              A Kite|     5|
|      1|   Xinil|      49|    Aa! Megami-sama!|     8|
|      1|   Xinil|     304|Aa! Megami-sama! ...|     8|
+-------+--------+--------+--------------------+------+
only showing top 5 rows



In [88]:
# DataFrame API

# Mean rating per user, rounded to two decimals.
# (`round` and `avg` are the pyspark.sql.functions versions imported at the top.)
average_ratings = (
    df_user_rating
    .groupBy("user_id")
    .agg(round(avg(col("rating")), 2).alias("average_rating"))
)

# Preview up to 10 rows without truncating cell values.
average_ratings.show(n=10, truncate=False)

+-------+--------------+
|user_id|average_rating|
+-------+--------------+
|23     |7.46          |
|9      |7.71          |
|1      |7.44          |
|20     |8.06          |
|37     |7.0           |
|4      |6.52          |
|83     |5.0           |
+-------+--------------+



In [89]:
# SQL variant

# Register the ratings DataFrame as the temporary view "ratings".
df_user_rating.createOrReplaceTempView("ratings")

# Per-user mean rating rounded to two decimals, expressed as a SQL query
# over the "ratings" view.
average_ratings_sql = spark.sql("""
    SELECT 
        user_id, 
        ROUND(AVG(rating), 2) AS average_rating
    FROM ratings
    GROUP BY user_id
""")

# Preview up to 10 rows without truncating cell values.
average_ratings_sql.show(10, truncate=False)

+-------+--------------+
|user_id|average_rating|
+-------+--------------+
|23     |7.46          |
|9      |7.71          |
|1      |7.44          |
|20     |8.06          |
|37     |7.0           |
|4      |6.52          |
|83     |5.0           |
+-------+--------------+



In [90]:
# Rename "Mal ID" to "user_id" so this frame shares its key column name with the
# ratings frames it is joined against below.
id_location_with_country_final = id_location_with_country_final.withColumnRenamed("Mal ID", "user_id")

In [91]:
# DataFrame API

# Inner-join the per-user averages with each user's resolved country.
joined_df = average_ratings.join(
    id_location_with_country_final,
    on="user_id",
    how="inner",
)

# Per country: mean of the per-user averages (two decimals) and the number of
# joined user rows.
average_ratings_by_country = (
    joined_df
    .groupBy("Country")
    .agg(
        round(avg(col("average_rating")), 2).alias("average_rating_by_country"),
        count(col("user_id")).alias("user_count"),
    )
)

# Preview up to 20 rows without truncating cell values.
average_ratings_by_country.show(n=20, truncate=False)

                                                                                

+-------------+-------------------------+----------+
|Country      |average_rating_by_country|user_count|
+-------------+-------------------------+----------+
|United States|6.22                     |2         |
|Norway       |8.06                     |1         |
|UNKNOWN      |7.71                     |1         |
|Canada       |7.46                     |1         |
|Australia    |6.52                     |1         |
+-------------+-------------------------+----------+



In [92]:
# SQL variant

# Register both inputs as temporary views so the query below can reference them.
average_ratings.createOrReplaceTempView("average_ratings")
id_location_with_country_final.createOrReplaceTempView("id_location_with_country_final")

# Join the per-user averages to the user/country frame on user_id, then compute,
# per country, the mean of those averages (two decimals) and the user count.
average_ratings_by_country_sql = spark.sql("""
    SELECT 
        c.Country,
        ROUND(AVG(a.average_rating), 2) AS average_rating_by_country,
        COUNT(a.user_id) AS user_count
    FROM 
        average_ratings a
    INNER JOIN 
        id_location_with_country_final c
    ON 
        a.user_id = c.user_id
    GROUP BY 
        c.Country
""")

# Preview up to 20 rows without truncating cell values.
average_ratings_by_country_sql.show(20, truncate=False)

                                                                                

+-------------+-------------------------+----------+
|Country      |average_rating_by_country|user_count|
+-------------+-------------------------+----------+
|United States|6.22                     |2         |
|Norway       |8.06                     |1         |
|UNKNOWN      |7.71                     |1         |
|Canada       |7.46                     |1         |
|Australia    |6.52                     |1         |
+-------------+-------------------------+----------+



In [93]:
# Stop the Spark session created at the top of the notebook.
spark.stop()