# Preparación de datos con Spark

In [1]:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.storagelevel import StorageLevel



# Local Spark session for the anime recommender pipeline.
# spark.local.dir sends shuffle/spill files to a large local disk; Spark warns
# that a cluster manager (SPARK_LOCAL_DIRS / LOCAL_DIRS) would override it.
spark_settings = {
    "spark.local.dir": "/mnt/HDD1TB/spark_temp",
    "spark.sql.shuffle.partitions": "48",
    "spark.driver.memory": "4g",
    "spark.executor.memory": "4g",
    "spark.executor.cores": "2",
    "spark.memory.fraction": "0.6",
    "spark.memory.storageFraction": "0.3",
}

builder = SparkSession.builder.appName("AnimeRecommender")
for setting_key, setting_value in spark_settings.items():
    builder = builder.config(setting_key, setting_value)
spark = builder.getOrCreate()


26/02/02 13:02:02 WARN Utils: Your hostname, brian-IA resolves to a loopback address: 127.0.1.1; using 192.168.1.52 instead (on interface enp37s0)
26/02/02 13:02:02 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/02/02 13:02:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
26/02/02 13:02:02 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [None]:
# Raw user -> anime ratings and the anime catalogue
# (only the catalogue was preprocessed beforehand).
rating, anime = (
    spark.read.parquet("../data_anime/rating.parquet"),
    spark.read.parquet("../data_anime/anime.parquet"),
)


In [76]:
# Peek at both inputs. DataFrame.show() prints the table itself and returns
# None, so wrapping it in display() only added a stray "None" to the output.
rating.show(5)
anime.show(5)

+------+-------+------+
|userID|animeID|rating|
+------+-------+------+
|     1|      1|    10|
|     1|      2|    10|
|     1|      3|     7|
|     1|      4|    10|
|     1|      5|    10|
+------+-------+------+
only showing top 5 rows



None

+-------+--------------------+-----+----+-----+--------+--------------------+
|animeID|               title| type|year|score|episodes|        genres_union|
+-------+--------------------+-----+----+-----+--------+--------------------+
|      1|Howl's Moving Castle|MOVIE|2004| 8.41|       1|[folklore, maids,...|
|      2|          Death Note|   TV|2006| 8.63|      37|[united states, g...|
|      3|Problem Children ...|   TV|2013| 7.42|      10|[maid, genius, mu...|
|      4|             BTOOOM!|   TV|2012| 7.34|      12|[suspense, tv cen...|
|      5|    Sword Art Online|   TV|2012|  7.5|      25|[mutilation, fair...|
+-------+--------------------+-----+----+-----+--------+--------------------+
only showing top 5 rows



None

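Filtro usuarios por cuantiles de la cantidad de ratings: descarto el 20 % inferior (usuarios que calificaron muy pocas veces; con estos datos, menos de 8 ratings) y el 20 % superior (más de 117 ratings). Además descarto usuarios cuyo rango de calificaciones (máximo − mínimo) sea menor a 3, es decir, usuarios que califican prácticamente todo igual.

El primer corte es obvio; el segundo es para no sesgar el futuro modelo hacia un grupo de usuarios muy hiperactivos.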

In [77]:
# Per-user activity statistics: number of ratings and spread of the ratings
# (max - min).
n_ratings_expr = F.count("*").alias("n_ratings")
rating_range_expr = (F.max("rating") - F.min("rating")).alias("rating_range")
user_stats = rating.groupBy("userID").agg(n_ratings_expr, rating_range_expr)

# Approximate 20th / 80th percentiles of ratings-per-user.
# Signature: DataFrame.approxQuantile(col, probabilities, relativeError)
quantiles = user_stats.approxQuantile("n_ratings", [0.2, 0.8], 0.01)
low, high = quantiles
print(low, high)

# Keep users whose ratings span at least 3 points and whose activity lies
# inside the inclusive band [low, high].
keep_user = (F.col("rating_range") >= 3) & F.col("n_ratings").between(low, high)
valid_users = user_stats.where(keep_user).select("userID")

# Interactions of the valid users only; persisted because the following cells
# reuse rating_filt several times.
rating_filt = rating.join(valid_users, "userID")
rating_filt.persist(StorageLevel.MEMORY_AND_DISK)




8.0 117.0


25/12/26 15:36:19 WARN CacheManager: Asked to cache already cached data.        


DataFrame[userID: bigint, animeID: bigint, rating: bigint]

In [79]:
# Peek at the filtered interactions.
rating_filt.show(n=10)

+------+-------+------+
|userID|animeID|rating|
+------+-------+------+
|    28|    836|    10|
|    28|     24|    10|
|    28|    512|    10|
|    28|     99|    10|
|    28|      6|    10|
|    28|      7|    10|
|    28|      1|    10|
|    28|    107|    10|
|    28|     33|    10|
|    28|    472|    10|
+------+-------+------+
only showing top 10 rows



OK, ahora vamos con el set de anime/ítems, que ya fue preprocesado antes de cargarlo.

¡Expando los géneros! (una fila por anime y género)

In [None]:
# One row per (animeID, genre): explode the genre list and normalise each tag
# (trim + lower-case). The tags were already cleaned upstream; this is just a
# cheap safety net.
exploded_genres = anime.select("animeID", F.explode("genres_union").alias("genre"))
item_genres = exploded_genres.withColumn("genre", F.lower(F.trim(F.col("genre"))))


In [81]:
# TODO: filter out rare / noisy genres.
item_genres.show(n=10)

+-------+--------------+
|animeID|         genre|
+-------+--------------+
|      1|      folklore|
|      1|         maids|
|      1|     steampunk|
|      1|boy meets girl|
|      1|    disability|
|      1|          mina|
|      1|         witch|
|      1|  happy ending|
|      1|     adventure|
|      1|       romance|
+-------+--------------+
only showing top 10 rows



In [85]:
# Number of (animeID, genre) rows per genre, i.e. how many anime carry each tag
# (assuming genres_union holds no duplicate tags per anime).
genre_freq = item_genres.groupBy("genre").count()

# show() prints the table and returns None; printing its result only added a
# stray "None" line.
genre_freq.orderBy("count", ascending=False).show()
print(f"genre_freq.count(): {genre_freq.count()}")

+-------------------+-----+
|              genre|count|
+-------------------+-----+
|             comedy| 6711|
|             action| 5223|
|            fantasy| 4513|
|          adventure| 3579|
|             sci-fi| 3153|
|              drama| 2767|
|            romance| 2162|
|             hentai| 1599|
|       supernatural| 1504|
|      slice of life| 1395|
|            mystery|  957|
|              ecchi|  829|
|        avant garde|  787|
|             sports|  724|
|             horror|  564|
|           suspense|  443|
|             summer|  313|
|          dystopian|  313|
|alternative present|  309|
|   character driven|  309|
+-------------------+-----+
only showing top 20 rows

None
genre_freq.count(): 320


In [86]:
# Total number of (animeID, genre) assignments in the catalogue, i.e. the sum
# of the per-genre counts. This is NOT the number of distinct genres (that is
# genre_freq.count(), 320 in the last run).
total = genre_freq.agg(F.sum("count")).first()[0]
print(f"Total de asignaciones (anime, género): {total}")

# TODO: keep only genres between the 20th and 80th frequency percentiles, e.g.
# via genre_freq.approxQuantile("count", [0.2, 0.8], 0.01). Not applied yet.

Total de géneros: 101330


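Igual que antes, la idea es quedarme con los géneros entre los percentiles 20 y 80 de frecuencia. **TODO:** este filtro todavía no está aplicado (el código de cuantiles de la celda anterior está comentado), así que por ahora se usan todos los géneros.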

In [88]:
# Attach every genre of each rated anime to the rating, giving one row per
# (animeID, userID, rating, genre). This multiplies the row count by the number
# of genres per anime.
uig = rating_filt.join(item_genres, on="animeID")

uig.show()
n_uig_rows = uig.count()
print(f"Number of rows: {n_uig_rows}")

+-------+------+------+--------------------+
|animeID|userID|rating|               genre|
+-------+------+------+--------------------+
|    836|    28|    10|  japanese mythology|
|    836|    28|    10|       body and host|
|    836|    28|    10|       slice of life|
|    836|    28|    10|              hanami|
|    836|    28|    10|            adoption|
|    836|    28|    10|             orphans|
|    836|    28|    10|            folklore|
|    836|    28|    10|     secret identity|
|    836|    28|    10|        supernatural|
|     24|    28|    10|               curse|
|     24|    28|    10|  japanese mythology|
|     24|    28|    10|dissociative iden...|
|     24|    28|    10|   elementary school|
|     24|    28|    10|           afterlife|
|     24|    28|    10|       body and host|
|     24|    28|    10|       slice of life|
|     24|    28|    10|  outdoor activities|
|     24|    28|    10|              hanami|
|     24|    28|    10|            adoption|
|     24| 



Number of rows: 622132390


                                                                                

In [89]:
# Per-(user, genre) profile:
#   genre_pref: mean rating the user gave to anime carrying this genre
#   n_items:    number of the user's ratings on anime carrying this genre
profile_aggregations = [
    F.avg("rating").alias("genre_pref"),
    F.count("*").alias("n_items"),
]
user_genre_profile = uig.groupBy("userID", "genre").agg(*profile_aggregations)

user_genre_profile.show(30)

[Stage 431:>                                                        (0 + 1) / 1]

+------+--------------------+-----------------+-------+
|userID|               genre|       genre_pref|n_items|
+------+--------------------+-----------------+-------+
|    28|  japanese mythology|             10.0|     11|
|    28|       body and host|             10.0|      7|
|    28|       slice of life|             10.0|      5|
|    28|              hanami|             10.0|      5|
|    28|            adoption|             10.0|      8|
|    28|             orphans|             10.0|      7|
|    28|            folklore|             10.0|     10|
|    28|     secret identity|             10.0|      5|
|    28|        supernatural|             10.0|     13|
|    28|               curse|             10.0|      9|
|    28|dissociative iden...|             10.0|      3|
|    28|   elementary school|             10.0|      3|
|    28|           afterlife|             10.0|      4|
|    28|  outdoor activities|             10.0|      5|
|    28|             archery|             10.0| 

                                                                                



* Para puntuar un género (g) para un usuario (u) usás una función **local, estable y acotada**.
  $$
  s_{u,g} = p_{u,g}\cdot \log(1+n_{u,g})
  $$

* Lo transformás con una **sigmoid escalada** para fijar un máximo interpretable:

  $$
  {genre\_score}_{u,g}=\frac{10}{1+e^{-\alpha s_{u,g}}}
  $$

* Propiedades clave:

  * **Nunca supera 10** (no requiere clipping).
  * **Monótona**: más preferencia o más evidencia ⇒ mayor score (para $p_{u,g} > 0$).
  * **Saturación**: evita explosiones en usuarios muy activos.
  * **Escala a grandes volúmenes**: sin ventanas, sin dependencias globales, cálculo por fila.
  * **$\alpha$**: parámetro que controla la saturación ($\alpha < 1$ ⇒ sigmoide más plana).
  * **Rango efectivo**: como los ratings van de 0 a 10, $p_{u,g} \ge 0$ y por lo tanto $s_{u,g} \ge 0$; el score queda en $[5, 10)$ y la mitad inferior de la sigmoide no se usa.

> Nota1: la otra forma es hacer un promedio ponderado de la probabilidad de género; esto requería funciones más complicadas y el uso de windows, lo que en Spark se volvía complejo y generaba muchos errores y warnings.

> Nota2: la función sigmoide saturaba rápidamente debido a la magnitud de los scores obtenidos, por eso se la aplana con $\alpha$ (algo habitual en métodos de calibración). El $\alpha$ elegido es empírico, obtenido de una muestra de datos: puede no ser adecuado para todos los datos (insuficiente o excesivo para algunos).

In [93]:
# Bounded genre score per (user, genre):
#   s           = genre_pref * log(1 + n_items)
#   genre_score = 10 / (1 + exp(-alpha * s))
# alpha < 1 flattens the sigmoid; it was chosen empirically on a sample because
# the unscaled sigmoid saturated quickly.
# With ratings in 0..10 (as in the observed distribution), s >= 0, so
# genre_score lies in [5, 10).
# (The previous version first built an un-squashed score into the same variable
# and immediately overwrote it; that dead definition was removed.)
alpha = 0.1

evidence_weighted_pref = F.col("genre_pref") * F.log1p(F.col("n_items"))
user_genre_profile_norm = (
    user_genre_profile
    .withColumn("genre_score", 10 / (1 + F.exp(-alpha * evidence_weighted_pref)))
    .select("userID", "genre", "genre_score")
)


En este punto me gustaría pivotear `genre` y `genre_score`: serían 320 columnas nuevas (una por género, en lugar de las 2 columnas `genre`/`genre_score`). Sin embargo, no todos los usuarios tienen todos los géneros, así que quedaría una matriz muy sparse... un desperdicio de recursos.

Entonces lo dejamos en formato largo, como está, con 171.693.371 filas (un máximo de 320 filas por usuario).

In [94]:
user_genre_profile_norm.show(30)
n_profile_rows = user_genre_profile_norm.count()
print(f"user_genre_profile_norm count: {n_profile_rows}")
# Last full run: 171_693_371 rows.

[Stage 451:>                                                        (0 + 1) / 1]

+------+--------------------+-----------------+
|userID|               genre|      genre_score|
+------+--------------------+-----------------+
|    28|  japanese mythology|9.230769230769232|
|    28|       body and host| 8.88888888888889|
|    28|       slice of life|8.571428571428571|
|    28|              hanami|8.571428571428571|
|    28|            adoption|              9.0|
|    28|             orphans| 8.88888888888889|
|    28|            folklore|9.166666666666668|
|    28|     secret identity|8.571428571428571|
|    28|        supernatural|9.333333333333334|
|    28|               curse|9.090909090909092|
|    28|dissociative iden...|              8.0|
|    28|   elementary school|              8.0|
|    28|           afterlife|8.333333333333334|
|    28|  outdoor activities|8.571428571428571|
|    28|             archery|7.499999999999999|
|    28|              dragon|7.499999999999999|
|    28|cute boys doing c...|6.666666666666667|
|    28|                trap|6.666666666

                                                                                

In [96]:
# Final feature tables:
#   user_features: (userID, genre, genre_score), the heaviest table
#   item_features: per-anime metadata; genres_union still needs a multi-hot
#                  encoding or an embedding downstream
#   interactions:  (userID, animeID, rating), source of the train/test split;
#                  each interaction is enriched by looking up item_features
#                  (by animeID) and user_features (by userID)
user_features = user_genre_profile_norm
item_columns = ["animeID", "type", "year", "score", "episodes", "genres_union"]
item_features = anime.select(*item_columns)
interactions = rating_filt.select("userID", "animeID", "rating")


In [None]:
# Quick look at the three final tables. show() prints the table itself and
# returns None, so it is not wrapped in display() (that only added "None").
for feature_table in (user_features, item_features, interactions):
    feature_table.show(30)

[Stage 503:>                                                        (0 + 1) / 1]

+------+--------------------+-----------------+
|userID|               genre|      genre_score|
+------+--------------------+-----------------+
|    28|  japanese mythology|9.230769230769232|
|    28|       body and host| 8.88888888888889|
|    28|       slice of life|8.571428571428571|
|    28|              hanami|8.571428571428571|
|    28|            adoption|              9.0|
|    28|             orphans| 8.88888888888889|
|    28|            folklore|9.166666666666668|
|    28|     secret identity|8.571428571428571|
|    28|        supernatural|9.333333333333334|
|    28|               curse|9.090909090909092|
|    28|dissociative iden...|              8.0|
|    28|   elementary school|              8.0|
|    28|           afterlife|8.333333333333334|
|    28|  outdoor activities|8.571428571428571|
|    28|             archery|7.499999999999999|
|    28|              dragon|7.499999999999999|
|    28|cute boys doing c...|6.666666666666667|
|    28|                trap|6.666666666

                                                                                

None

+-------+-----+----+-----+--------+--------------------+
|animeID| type|year|score|episodes|        genres_union|
+-------+-----+----+-----+--------+--------------------+
|      1|MOVIE|2004| 8.41|       1|[folklore, maids,...|
|      2|   TV|2006| 8.63|      37|[united states, g...|
|      3|   TV|2013| 7.42|      10|[maid, genius, mu...|
|      4|   TV|2012| 7.34|      12|[suspense, tv cen...|
|      5|   TV|2012|  7.5|      25|[mutilation, fair...|
|      6|MOVIE|2001| 8.64|       1|[adventure, paral...|
|      7|MOVIE|1997| 8.59|       1|[folklore, stand-...|
|      8|   TV|2012|  8.0|      25|[magic weapons, s...|
|      9|   TV|2012| 7.31|      24|[tv censoring, mu...|
|     10|   TV|2009| 7.81|      11|[united states, s...|
|     11|MOVIE|2010| 7.95|       1|[supernatural, su...|
|     12|   TV|1998| 8.66|      26|[cyborg, crime fi...|
|     13|   TV|2003| 8.29|      51|[gangs, promise, ...|
|     14|   TV|1992| 8.23|     112|[supernatural, we...|
|     15|   TV|2008| 7.88|     

None

+------+-------+------+
|userID|animeID|rating|
+------+-------+------+
|    28|    836|    10|
|    28|     24|    10|
|    28|    512|    10|
|    28|     99|    10|
|    28|      6|    10|
|    28|      7|    10|
|    28|      1|    10|
|    28|    107|    10|
|    28|     33|    10|
|    28|    472|    10|
|    28|    108|    10|
|    28|   1222|    10|
|    28|    155|    10|
|    28|     42|    10|
|    28|    917|    10|
|    28|    144|    10|
|    28|    481|    10|
|    28|      2|    10|
|    28|    478|    10|
|    28|     62|    10|
|    28|     31|    10|
|    28|    739|    10|
|    28|   1476|    10|
|    28|    335|    10|
|    28|    117|    10|
|    28|    248|    10|
|    28|     39|    10|
|    28|    518|    10|
|    28|   1010|    10|
|    28|    368|    10|
+------+-------+------+
only showing top 30 rows



None

In [98]:
#print(f"user_features.count(): {user_features.count()}")
#print(f"item_features.count(): {item_features.count()}")
#print(f"interactions.count(): {interactions.count()}")

#user_features.count(): 171_693_371
#item_features.count(): 19_598
#interactions.count(): 38_644_585

In [99]:

# Persist the three feature tables as parquet for the modelling stage,
# overwriting any previous run.
exports = [
    (user_features, "../data_anime/prod/user_features"),
    (item_features, "../data_anime/prod/item_features"),
    (interactions, "../data_anime/prod/interactions"),
]
for frame, out_path in exports:
    frame.write.mode("overwrite").parquet(out_path)

25/12/26 16:10:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/12/26 16:10:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/12/26 16:10:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/12/26 16:10:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/12/26 16:10:22 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/12/26 16:10:22 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/12/26 16:10:22 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/12/26 16:10:22 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/12/26 16:10:22 WARN RowBasedKeyValueBatch: Calling spill() on

Chequeamos si se pueden leer.

In [7]:
# List the shared data directory to locate the backed-up parquet tables.
!ls ../../data/

anime	    glove.6B.100d.txt  glove.6B.300d.txt  glove.6B.zip
data_anime  glove.6B.200d.txt  glove.6B.50d.txt   movies.csv


In [3]:
# Reload the feature tables from the backup location and check their schemas.
# NOTE(review): the writing cell used ../data_anime/prod/, but this reads
# ../../data/data_anime/prodBKP/. Confirm both hold the same data before
# treating this as a read-back check of what was just written.
user_features2 = spark.read.parquet("../../data/data_anime/prodBKP/user_features")
item_features2 = spark.read.parquet("../../data/data_anime/prodBKP/item_features")
interactions2 = spark.read.parquet("../../data/data_anime/prodBKP/interactions")

for reloaded_df in (user_features2, item_features2, interactions2):
    reloaded_df.printSchema()


root
 |-- userID: long (nullable = true)
 |-- genre: string (nullable = true)
 |-- genre_score: double (nullable = true)

root
 |-- animeID: long (nullable = true)
 |-- type: string (nullable = true)
 |-- year: long (nullable = true)
 |-- score: double (nullable = true)
 |-- episodes: long (nullable = true)
 |-- genres_union: array (nullable = true)
 |    |-- element: string (containsNull = true)

root
 |-- userID: long (nullable = true)
 |-- animeID: long (nullable = true)
 |-- rating: long (nullable = true)



In [3]:
# Sanity check of the reloaded tables: distinct users, distinct items and
# total interactions.
n_users = user_features2.select("userID").distinct().count()
n_items = item_features2.select("animeID").distinct().count()
n_interactions = interactions2.count()
print("users:", n_users)
print("items:", n_items)
print("interactions:", n_interactions)


                                                                                

users: 934941
items: 19598
interactions: 38644585


# Ultimo paso, preparar datos para el modelo.

tengo esto:

- *(userID, genre, genre_score)*

Necesito:

- *userID → vector[320]*

voy a tener que crear un vector sparse

**Opción INDUSTRIAL: vector sparse → dense solo en PyTorch**

En Spark:
La idea es reemplazar los géneros por índices enteros (volverlos numéricos); con eso los leo desde PyTorch y ahí los vuelvo densos.

In [None]:
# Genre vocabulary: one row per distinct genre with an integer id `genre_idx`.
# Ids come from row_number() over the alphabetically sorted genres, so they are
# contiguous (0 .. n_genres-1) and identical on every recomputation.
# monotonically_increasing_id() was used before. It only guarantees unique,
# increasing ids: they jump between partitions (partition id in the upper bits)
# and can change when the plan is recomputed. That breaks dense
# vector[n_genres] indexing in PyTorch and can make vocab_genres.csv disagree
# with user_genre_idx, since each action recomputes `genres`.
# The single-partition window is cheap here (320 genres in the last run).
genres = (
    user_features2
    .select("genre")
    .distinct()
    .withColumn(
        "genre_idx",
        # 0-based, and cast to long to keep the previous schema.
        (F.row_number().over(Window.orderBy("genre")) - 1).cast("long"),
    )
    .cache()  # reused by the join and CSV export below, and by item_genre_idx
)

# Same (userID, genre, genre_score) data with the genre string replaced by its
# numeric id.
user_genre_idx = (
    user_features2
    .join(genres, "genre")
    .select("userID", "genre_idx", "genre_score")
)

user_genre_idx.write.mode("overwrite").parquet("../data_anime/prod/user_genre_idx")
# Vocabulary {genre: genre_idx}: small, but needed to interpret the ids.
genres.write.mode("overwrite").csv("../data_anime/prod/vocab_genres.csv")


**Item features → vector numérico + multi-hot**

se tiene:

- *animeID, type, year, score, episodes, genres_union*

Transformación recomendada:

In [None]:
# Item features for the model. Built from the reloaded item_features2, like the
# genre vocabulary (from user_features2), so both come from the same snapshot.

# Numeric features stay numeric, cast to double for the tensor side.
item_base = item_features2.select(
    "animeID",
    F.col("score").cast("double"),
    F.col("episodes").cast("double"),
    F.col("year").cast("double"),
)

# Multi-hot source: one row per (animeID, genre_idx), using the same vocabulary
# as user_genre_idx. Tags are normalised (trim + lower-case) exactly like
# item_genres was when the user profiles were built, so the join keys match.
# Genres absent from the vocabulary (presumably tags that only appear on anime
# no valid user rated) are dropped by the inner join.
item_genre_idx = (
    item_features2
    .select("animeID", F.explode("genres_union").alias("genre"))
    .withColumn("genre", F.lower(F.trim(F.col("genre"))))
    .join(genres, "genre")
    .select("animeID", "genre_idx")
)

item_base.write.mode("overwrite").parquet(
    "../data_anime/prod/item_base"
)

item_genre_idx.write.mode("overwrite").parquet(
    "../data_anime/prod/item_genre_idx"
)

                                                                                

**Crear interacciones con NEGATIVE SAMPLING (Spark)**

Se considera lo siguiente

| Rating | Interpretación        |
| -----: | --------------------- |
|   9–10 | fan / fuerte afinidad |
|    7–8 | le gustó              |
|    5–6 | neutral               |
|    1–4 | rechazo               |


*Positivos*

In [19]:
# Column names of the reloaded interactions table.
interactions2.schema.names

['userID', 'animeID', 'rating']

Umbral `rating_threshold = 7`: los ratings ≥ 7 se consideran positivos y los < 7, negativos. (Nota: el split de la celda siguiente todavía no aplica este umbral; solo estratifica por rating.)

In [4]:
# Definir las fracciones para el muestreo estratificado (80% entrenamiento, 20% prueba)
fractions = {rating: 0.8 for rating in range(1, 11)}  # Asumiendo ratings del 1 al 10

# Crear el conjunto de entrenamiento con muestreo estratificado
train = interactions2.sampleBy("rating", fractions, seed=777)

# Crear el conjunto de prueba con un anti-join
test = interactions2.join(
    train.select("userID", "animeID"),
    on=["userID", "animeID"],
    how="left_anti"
)

# Verificar las proporciones
print("Distribución en el conjunto original:")
interactions2.groupBy("rating").count().orderBy("rating").show()

print("\nDistribución en entrenamiento:")
train.groupBy("rating").count().orderBy("rating").show()

print("\nDistribución en prueba:")
test.groupBy("rating").count().orderBy("rating").show()

# Mostrar totales
print(f"\nTotal de interacciones: {interactions2.count()}")
print(f"Entrenamiento: {train.count()} (80%)")
print(f"Prueba: {test.count()} (20%)")

Distribución en el conjunto original:


                                                                                

+------+--------+
|rating|   count|
+------+--------+
|     0|   28064|
|     1|  612643|
|     2|  238678|
|     3|  278731|
|     4| 1205351|
|     5| 1063496|
|     6| 2864471|
|     7| 6230667|
|     8| 8919432|
|     9| 6055601|
|    10|11147451|
+------+--------+


Distribución en entrenamiento:


                                                                                

+------+-------+
|rating|  count|
+------+-------+
|     1| 489745|
|     2| 191072|
|     3| 223463|
|     4| 963636|
|     5| 851174|
|     6|2291582|
|     7|4983404|
|     8|7134531|
|     9|4844355|
|    10|8917369|
+------+-------+


Distribución en prueba:


26/02/02 13:25:31 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/02/02 13:25:31 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/02/02 13:25:31 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/02/02 13:25:31 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/02/02 13:25:31 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/02/02 13:25:32 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/02/02 13:25:33 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/02/02 13:25:33 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/02/02 13:25:34 WARN RowBasedKeyValueBatch: Calling spill() on

+------+-------+
|rating|  count|
+------+-------+
|     0|  28007|
|     1| 122361|
|     2|  47213|
|     3|  54740|
|     4| 240599|
|     5| 209884|
|     6| 567503|
|     7|1236149|
|     8|1767300|
|     9|1196170|
|    10|2210482|
+------+-------+


Total de interacciones: 38644585
Entrenamiento: 30890331 (80%)




Prueba: 7680408 (20%)


                                                                                

Asegurar que TODOS los usuarios queden en train

Detectar usuarios sin train

In [None]:
#     # Ids del train
#     users_in_train = train_pos.select("userID").distinct()
#     
#     users_without_train = (
#         positives.select("userID").distinct()
#         .join(users_in_train, "userID", "left_anti") #resta de conjuntos!==left_anti
#     )
#     
#     #Fuerzo al menos una interaccion e train.
#     fallback_train = (
#         positives
#         .join(users_without_train, "userID")
#         .sample(fraction=1.0)  # todos
#         .limit(users_without_train.count())
#     )
#     
#     train_pos = train_pos.unionByName(fallback_train).distinct()
#     
#     
#     #usuarios sin test no es grave, luego se filtran.
#     users_in_test = test_pos.select("userID").distinct()
#     test_pos = test_pos.join(users_in_test, "userID")
#     
#     train_pos.show()

26/01/19 23:01:20 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/19 23:01:20 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
                                                                                

+------+-------+------+
|userID|animeID|rating|
+------+-------+------+
|   128|   2230|    10|
|   378|    152|    10|
|   378|    394|    10|
|   421|     19|    10|
|   462|    159|     7|
|   515|    584|    10|
|   559|   2946|     8|
|   730|      7|    10|
|   730|      2|     8|
|   813|    816|    10|
|   922|     16|    10|
|   949|    877|    10|
|  1050|      5|    10|
|  1233|    131|    10|
|  1375|   4688|     8|
|  1516|    484|     9|
|  1551|   4133|    10|
|  1551|   2899|     7|
|  1895|    216|     8|
|  2000|    919|    10|
+------+-------+------+
only showing top 20 rows



*Negativos (con sampleo controlado)*

En producción NO se usa 1% global cuando el catálogo es grande.

Se usa fixed negative sampling por usuario.

Estrategia estándar (YouTube / Meta / Spotify)

k negativos por positivo, típicamente:

k = 5

k = 10

k = 20

In [None]:
#    #item_features.cols |animeID| type|year|score|episodes|        genres_union|
#    all_items = item_features2.select("animeID").distinct()                              #todos los animeIDs
#    pos_counts = train_pos.groupBy("userID").count()
#    
#    user_items = train_pos.groupBy("userID").agg(
#        F.collect_set("animeID").alias("seen_items")
#    )
#    
#    # +------+--------------------+
#    # |userID|          seen_items|
#    # +------+--------------------+
#    # |    28|[335, 1, 31, 836,...|
#    # |    67|[147, 503, 901, 2...|
#    
#    
#    #sampleo 7 negativos por positivo (AGREGA MUCHOS DATOS!!! CONSIDERA TODA INTERACCION POSITIVA!!!)
#    # negatives = (
#    #     user_items
#    #     .join(pos_counts, "userID")
#    #     .join(all_items)
#    #     .where(~F.array_contains(F.col("seen_items"), F.col("animeID")))
#    #     .withColumn("rand", F.rand(seed=777))
#    #     .where(F.col("rand") < (F.col("count") * 7 / 10000))  
#    #     # ajustá el denominador según tamaño catálogo
#    #     .select("userID", "animeID")
#    #     .withColumn("label", F.lit(-1))
#    # )
#    
#    #train_interactions = train_pos.unionByName(negatives)
#    train_interactions = (
#        train_pos.select("userID", "animeID")
#        .withColumn("label", F.lit(1))
#        .unionByName(negatives)
#    )
#    
#    train_interactions.count() #38_644_585
#                               #380_171_460

+------+-------+------+-----+
|userID|animeID|rating|label|
+------+-------+------+-----+
|    17|     36|     8|    1|
|    17|      7|    10|    1|
|    17|    275|     8|    1|
|    17|      1|     8|    1|
|    17|    335|    10|    1|
|    17|    409|    10|    1|
|    17|    419|     8|    1|
|    17|     32|     8|    1|
|    17|    476|    10|    1|
|    17|    484|     8|    1|
|    17|    115|    10|    1|
|    17|     86|    10|    1|
|    17|    567|     8|    1|
|    17|     54|     8|    1|
|    17|    583|     8|    1|
|    17|    691|    10|    1|
|    17|     82|    10|    1|
|    17|     70|    10|    1|
|    17|     79|    10|    1|
|    17|    823|     8|    1|
|    17|      5|     8|    1|
|    17|      9|     8|    1|
|    17|   1010|     8|    1|
|    17|   1207|     8|    1|
|    17|   1012|     8|    1|
|    17|    844|    10|    1|
|    17|   1208|     8|    1|
|    17|   1020|     8|    1|
|    17|   1209|     8|    1|
|    17|    923|     8|    1|
|    17|  

In [None]:
#train_interactions.show()
#
# [Stage 867:>                                                        (0 + 1) / 1]
# +------+-------+-----+
# |userID|animeID|label|
# +------+-------+-----+
# |    28|    107|    1|
# |   151|    735|    1|
# |   221|     12|    1|
# |   221|   1100|    1|
# |   325|     39|    1|
# |   325|    199|    1|
# |   347|   1483|    1|
# |   347|    827|    1|
# |   347|    702|    1|
# |   416|    509|    1|
# |   416|     16|    1|
# |   519|    466|    1|
# |   566|    113|    1|
# |   566|    802|    1|
# |   584|      6|    1|
# |   584|   1585|    1|
# |   590|     12|    1|
# |   590|     44|    1|
# |   667|    835|    1|
# |   835|    918|    1|
# +------+-------+-----+
# only showing top 20 rows

# test_pos.show()
# 
# [Stage 998:>                                                        (0 + 1) / 1]
# +------+-------+
# |userID|animeID|
# +------+-------+
# |    28|     99|
# |    28|      7|
# |    28|      1|
# |    28|     33|
# |    28|    108|
# |    28|     62|
# |    28|   1476|
# |    28|    248|
# |    28|    518|
# |    28|     85|
# |    67|    163|
# |    67|     20|
# |    67|    144|
# |    67|    918|
# |    67|   2270|
# |   151|    583|
# |   151|     71|
# |   151|   1023|
# |   151|   1010|
# |   151|     16|
# +------+-------+
# only showing top 20 rows

In [153]:
# Quiero ver la proporción de valores positivos y negativos
#train_interactions.groupBy("label").count().show()
#
# +-----+---------+
# |label|    count|
# +-----+---------+
# |    1| 25752438|
# |   -1|182799727|
# +-----+---------+

Dataset final

In [None]:
# Report the size of each split: total interaction rows first, then distinct users.
for split_name, split_df in (("train", train_interactions), ("test", test)):
    print(f"{split_name}:", split_df.count())

for split_name, split_df in (("train", train_interactions), ("test", test)):
    print(f"{split_name} users:", split_df.select("userID").distinct().count())


26/01/12 13:10:09 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:10:09 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:10:12 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:10:12 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:10:18 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:10:18 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:10:18 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:10:18 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:10:18 WARN RowBasedKeyValueBatch: Calling spill() on

train: 378206284


                                                                                

test: 6474813


26/01/12 13:13:30 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:13:30 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:13:32 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:13:32 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:13:37 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:13:37 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:13:37 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:13:37 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/12 13:13:37 WARN RowBasedKeyValueBatch: Calling spill() on

train users: 934090




test users: 895558


                                                                                

In [18]:
# List the artifacts currently persisted under the prod data directory.
!ls ../../data/data_anime/prod/

item_base	test_data	    user_final
item_final	train_final	    user_genre_idx
item_genre_idx	train_interactions  vocab_genres.csv


In [5]:
path = "../../data/data_anime/prod/"

# Persist each split as parquet, replacing any previous version of the folder.
# NOTE(review): this writes `train`, while the count cell above reports on
# `train_interactions` — confirm `train` is the DataFrame intended for this folder.
for split_df, folder in ((train, "train_interactions"), (test, "test_data")):
    split_df.write.mode("overwrite").parquet(path + folder)


                                                                                

# Obtengo la frecuencia de ratings y la exporto a CSV

In [12]:
# Reload the persisted splits and export the normalized rating distribution of each to CSV.
# NOTE(review): this requires a `rating` column. A later schema print of the same
# "train_interactions" path shows (userID, animeID, label) — confirm which version is current.
path = "../../data/data_anime/prod/"

train_interactions = spark.read.parquet(path + "train_interactions")
test_interactions  = spark.read.parquet(path + "test_data")


def rating_frequencies(df, total):
    """Return (rating, count, normalized_freq) rows of `df`, ordered by rating.

    `normalized_freq` is `count / total`. `total` is passed in (the row count of `df`)
    so each split is fully scanned for its size only once.
    """
    return (df
            .groupBy("rating")
            .count()
            .withColumn("normalized_freq", F.col("count") / total)
            .orderBy("rating"))


# Count each split once; the totals are reused for normalization and for the printout.
total_count_train = train_interactions.count()
total_count_test  = test_interactions.count()

rating_freq_train = rating_frequencies(train_interactions, total_count_train)
rating_freq_test  = rating_frequencies(test_interactions, total_count_test)

print(f"Total de calificaciones (test): {total_count_test}")
rating_freq_test.show(truncate=False)

print(f"Total de calificaciones (train): {total_count_train}")
rating_freq_train.show(truncate=False)

# mode("overwrite") so re-running the notebook does not fail with "path already exists"
# (Spark's default save mode is "errorifexists"). No header row is written.
rating_freq_test.write.mode("overwrite").csv(path + "rating_freq_test.csv")
rating_freq_train.write.mode("overwrite").csv(path + "rating_freq_train.csv")


Total de calificaciones: 7657904
+------+-------+---------------------+
|rating|count  |normalized_freq      |
+------+-------+---------------------+
|0     |5510   |7.19518029998809E-4  |
|1     |122358 |0.015978001291214933 |
|2     |47213  |0.006165264020024278 |
|3     |54740  |0.0071481700475743756|
|4     |240599 |0.03141838811246524  |
|5     |209884 |0.027407499493333945 |
|6     |567503 |0.07410683132094631  |
|7     |1236148|0.16142119305752592  |
|8     |1767299|0.23078103355696283  |
|9     |1196170|0.15620070452698284  |
|10    |2210480|0.2886533965429705   |
+------+-------+---------------------+

Total de calificaciones: 30912806
+------+-------+---------------------+
|rating|count  |normalized_freq      |
+------+-------+---------------------+
|0     |22475  |7.270449664129487E-4 |
|1     |489745 |0.015842786966670058 |
|2     |191072 |0.0061809982568389295|
|3     |223463 |0.007228816432904861 |
|4     |963636 |0.03117271204691027  |
|5     |851174 |0.02753467284723360

In [None]:
#"../data_anime/prod/vocab_genres.csv")

# "../data_anime/prod/item_base"
# "../data_anime/prod/item_genre_idx"

#"../data_anime/prod/train_interactions"
# "../data_anime/prod/test_data"
#
#"../data_anime/prod/user_genre_idx")

# Test, veamos si podemos dejarlo servido para el modelo.

In [19]:
# vocab_genres.csv is a directory written by Spark; list its part file(s) and _SUCCESS marker.
!ls "../../data/data_anime/prod/vocab_genres.csv"

part-00000-6d9ec3df-038b-4f10-b5fd-de90776b2ebc-c000.csv  _SUCCESS


In [None]:
import glob

path = "../../data/data_anime/prod/"

# Item features
item_base      = spark.read.parquet(path + "item_base")
item_genre_idx = spark.read.parquet(path + "item_genre_idx")

# User features
user_genre_idx = spark.read.parquet(path + "user_genre_idx")

# Interactions
train_interactions = spark.read.parquet(path + "train_interactions")
test_pos           = spark.read.parquet(path + "test_data")

# Spark wrote vocab_genres.csv as a directory whose part-file name carries a random UUID
# that changes on every write, so resolve it instead of hard-coding it. Require exactly
# one part: the row order of this file is used as the genre index downstream.
vocab_parts = sorted(glob.glob(path + "vocab_genres.csv/part-*.csv"))
assert len(vocab_parts) == 1, f"expected exactly one vocab_genres part file, found: {vocab_parts}"
vocab_genres = pd.read_csv(vocab_parts[0], header=None)


item_base.printSchema()
item_genre_idx.printSchema()

user_genre_idx.printSchema()

train_interactions.printSchema()
test_pos.printSchema()

root
 |-- animeID: long (nullable = true)
 |-- score: double (nullable = true)
 |-- episodes: double (nullable = true)
 |-- year: double (nullable = true)

root
 |-- animeID: long (nullable = true)
 |-- genre_idx: long (nullable = true)

root
 |-- userID: long (nullable = true)
 |-- genre_idx: long (nullable = true)
 |-- genre_score: double (nullable = true)

root
 |-- userID: long (nullable = true)
 |-- animeID: long (nullable = true)
 |-- label: integer (nullable = true)

root
 |-- userID: long (nullable = true)
 |-- animeID: long (nullable = true)



Primero vamos por el set de ítems. `item_base` ya es numérico y `item_genre_idx` también, pero a este último aún quiero hacerle un pivot y luego un join con la base; ese va a ser mi set de ítems.

In [None]:
# Map each vocabulary row position to its genre name (first CSV column; the file has no header).
vocab_dict = vocab_genres.iloc[:, 0].to_dict()
len_genres = len(vocab_dict)
print(len_genres)

320


320

In [None]:
# Genre indices to pivot on: the keys of vocab_dict (vocabulary row positions).
# NOTE(review): assumes item/user genre_idx values are these row positions — confirm.
genresId = list(vocab_dict.keys())
# Pivot output columns are named after the pivot values ("0", "1", ...).
genre_cols = [str(g) for g in genresId]

# Multi-hot encode genres per anime: 1 where the (animeID, genre_idx) pair exists, NULL otherwise.
# Use an explicit aggregate; a bare literal is not an aggregate function.
item_pivot = (
    item_genre_idx
    .groupBy("animeID")
    .pivot("genre_idx", genresId)
    .agg(F.max(F.lit(1)))
)

# The LEFT join keeps every anime in item_base. Animes with no genre rows come back with
# NULL in every genre column, so the 0-fill must happen after the join. It is restricted
# to the genre columns so item_base's own columns are left as they are.
item_final = (
    item_base
    .join(item_pivot, on="animeID", how="left")
    .fillna(0, subset=genre_cols)
)

item_final.write.mode("overwrite").parquet(path + "item_final")

                                                                                

Vamos con el de usuarios. Este ya es un solo archivo y quiero hacerle el pivot; lo malo es que son MUCHOS usuarios... va a costar.

```bash
root
 |-- userID: long (nullable = true)
 |-- genre_idx: long (nullable = true)
 |-- genre_score: double (nullable = true)
```

In [50]:
# One row per user, one float column per genre index holding that user's genre score
# (0.0 where the user has no score for the genre).
# NOTE(review): F.first picks an arbitrary value if a (userID, genre_idx) pair repeats —
# presumably pairs are unique; confirm.
user_final = (
    user_genre_idx
    .withColumn("genre_score", user_genre_idx["genre_score"].cast("float"))
    .groupBy("userID")
    .pivot("genre_idx", values=genresId)
    .agg(F.first(F.col("genre_score")))
    .fillna(value=0.0)
)

user_final.write.mode("overwrite").parquet(path + "user_final")


26/01/01 23:57:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/01 23:57:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/01 23:57:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/01 23:57:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/01 23:57:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/01 23:57:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/01 23:57:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/01 23:57:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
26/01/01 23:57:54 WARN RowBasedKeyValueBatch: Calling spill() on

¡MUY BIEN, hasta ahora nada se rompió!

Se necesita generar el set de entrenamiento, para ello se necesita:

Uno o dos joins con item_final y user_final, usando la tabla de interacciones; eso implica que se van a repetir las filas de usuarios.

La idea era tener tres sets (user, item e interactions) con la misma cantidad de filas, y a eso vamos, pero es mejor dejarlo en un solo set: total, se tiene que acceder a los tres.

In [52]:
# Peek at the first rows of the training interactions.
train_interactions.show(n=5)

+------+-------+-----+
|userID|animeID|label|
+------+-------+-----+
|    17|   1177|   -1|
|    17|   1240|   -1|
|    17|   1767|   -1|
|    17|   2301|   -1|
|    17|   3579|   -1|
+------+-------+-----+
only showing top 5 rows



In [None]:
# path = "../../data/data_anime/prod/"
# 
# user_final          = spark.read.parquet(path + "user_final")
# item_final          = spark.read.parquet(path + "item_final")
# train_interactions  = spark.read.parquet(path + "train_interactions")
# 
# 
# user_final = user_final.select(*[F.col(c).alias(f"u_{c}") if c.isdigit() else F.col(c) for c in user_final.columns])
# item_final = item_final.select(*[F.col(c).alias(f"i_{c}") if c.isdigit() else F.col(c) for c in item_final.columns])
# 
# 
# train_ui = (
#     train_interactions
#     .join(user_final, on="userID", how="inner")
# )
# 
# train_final = (
#     train_ui
#     .join(item_final, on="animeID", how="inner")
# )
# 
# 
# train_final.write.mode("overwrite").parquet(path+"train_final")