In [13]:
from pyspark.sql.types import StructType, StructField, StringType, ArrayType
from pyspark.sql.functions import udf, col

from itertools import combinations

In [11]:
# Input layout: one row per group -- the group letter followed by its four teams.
# The CSV file has no header line, so the column names come entirely from this schema.
schema = StructType(
    [StructField("group", StringType(), True)]
    + [StructField("team_{}".format(slot), StringType(), True) for slot in range(1, 5)]
)

data = spark.read.csv("../data/groups.csv", schema=schema, sep=",", header=False)
data.show()

+-----+------------+-------------------+------------+------------+
|group|      team_1|             team_2|      team_3|      team_4|
+-----+------------+-------------------+------------+------------+
|    A|      Brazil|            Croatia|      Mexico|    Cameroon|
|    B|       Spain|            Holland|       Chile|   Australia|
|    C|    Colombia|             Greece| Ivory Coast|       Japan|
|    D|     Uruguay|         Costa Rica|     England|       Italy|
|    E| Switzerland|            Ecuador|      France|    Honduras|
|    F|   Argentina| Bosnia-Herzegovina|        Iran|     Nigeria|
|    G|     Germany|           Portugal|       Ghana|         USA|
|    H|     Belgium|            Algeria|      Russia| South Korea|
+-----+------------+-------------------+------------+------------+



In [17]:
def _group_pairings(team_a, team_b, team_c, team_d):
    """Return every unordered pair of the four given teams (C(4, 2) = 6 pairs).

    Pairs come out in ``itertools.combinations`` order:
    (a, b), (a, c), (a, d), (b, c), (b, d), (c, d).
    """
    return list(combinations((team_a, team_b, team_c, team_d), 2))


# Spark UDF: four team columns in, an array of 2-element string arrays out.
udf_define_matches = udf(_group_pairings, ArrayType(ArrayType(StringType())))

In [18]:
# Attach a "matches" array to each group row, built from team_1..team_4 in order.
team_columns = [col("team_{}".format(slot)) for slot in range(1, 5)]
data = data.withColumn("matches", udf_define_matches(*team_columns))

In [41]:
def _prefix_group(group, matches):
    """Prepend the group label to every pair: [a, b] -> [group, a, b]."""
    return [[group] + match for match in matches]


udf_matches_group = udf(_prefix_group, ArrayType(ArrayType(StringType())))

# Bring each group's list of [group, team_1, team_2] rows back to the driver.
# The input has one row per group, so this collect is small.
group_rows = (
    data
    .withColumn("group_matches", udf_matches_group(col("group"), col("matches")))
    .select("group_matches")
    .collect()
)
all_matches = [row["group_matches"] for row in group_rows]

# Flatten the per-group lists into a single list of [group, team_1, team_2] rows.
matches_flattened_list = []
for group_match_list in all_matches:
    matches_flattened_list.extend(group_match_list)

In [44]:
# Schema for the flattened match rows: the group label plus the two opposing teams.
# NOTE: this rebinds `schema`, replacing the groups schema used by the load cell.
schema = StructType([
    StructField("group", StringType(), True),
    StructField("team_1", StringType(), True),
    StructField("team_2", StringType(), True),
])

# Each element of matches_flattened_list is a [group, team_1, team_2] list,
# matching the three columns above positionally.
matches = spark.createDataFrame(matches_flattened_list, schema=schema)

matches.show(5)

+-----+--------+---------+
|group|  team_1|   team_2|
+-----+--------+---------+
|    A|  Brazil|  Croatia|
|    A|  Brazil|   Mexico|
|    A|  Brazil| Cameroon|
|    A| Croatia|   Mexico|
|    A| Croatia| Cameroon|
+-----+--------+---------+
only showing top 5 rows



In [45]:
# Write all matches as a single CSV part file (coalesce(1)); no header row is written.
# mode="overwrite" keeps this cell re-runnable: the default save mode
# ("errorifexists") fails when the output directory exists from a previous run.
matches.coalesce(1).write.csv("../data/first_round_matches", mode="overwrite")