In [47]:
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, IntegerType, FloatType
from pyspark.sql.functions import udf, col

from itertools import combinations
import sys
sys.path.append("../python/")
from featurization_data import FeaturizationData

In [48]:
schema = StructType([
            StructField("team", StringType(), True),
            StructField("country", StringType(), True)])
teams = spark.read.csv("../data/common/en.teams.tsv", sep="\t", header=False, schema=schema)
teams.show(5)

+----+-----------+
|team|    country|
+----+-----------+
|  AN|       Aden|
|  AF|Afghanistan|
|  AL|    Albania|
|  DZ|    Algeria|
|  AD|    Andorra|
+----+-----------+
only showing top 5 rows



In [49]:
# +-----+----------+-----+--------+
# |label|prediction|group| matches|
# +-----+----------+-----+--------+

data = [(2.0, 2.0, "A", ["BR", "HR"]),
(0.0, 2.0, "A", ["BR", "MX"]),
(2.0, 2.0, "A", ["BR", "CM"]),
(2.0, 2.0, "A", ["HR", "CM"]),
(2.0, 2.0, "A", ["MX", "CM"])]

data = spark.createDataFrame(data, ["label", "prediction", "group", "matches"])
data.show()

+-----+----------+-----+--------+
|label|prediction|group| matches|
+-----+----------+-----+--------+
|  2.0|       2.0|    A|[BR, HR]|
|  0.0|       2.0|    A|[BR, MX]|
|  2.0|       2.0|    A|[BR, CM]|
|  2.0|       2.0|    A|[HR, CM]|
|  2.0|       2.0|    A|[MX, CM]|
+-----+----------+-----+--------+



In [50]:
udf_get_team_1 = udf(lambda x: x[0], StringType())
udf_get_team_2 = udf(lambda x: x[1], StringType())

def result_team_2(result):
    if (result == 2):
        return 1.0
    elif (result == 1):
        return 2.0
    else:
        return 0.0
udf_result_team_2 = udf(lambda result: result_team_2(result), FloatType())

In [51]:
data = (data
.withColumn("team_1", udf_get_team_1(col("matches")))
.withColumn("team_2", udf_get_team_2(col("matches")))
.withColumn("result_team_2", udf_result_team_2(col("prediction")))
.withColumnRenamed("prediction", "result_team_1") 
        
)
data.show()

+-----+-------------+-----+--------+------+------+-------------+
|label|result_team_1|group| matches|team_1|team_2|result_team_2|
+-----+-------------+-----+--------+------+------+-------------+
|  2.0|          2.0|    A|[BR, HR]|    BR|    HR|          1.0|
|  0.0|          2.0|    A|[BR, MX]|    BR|    MX|          1.0|
|  2.0|          2.0|    A|[BR, CM]|    BR|    CM|          1.0|
|  2.0|          2.0|    A|[HR, CM]|    HR|    CM|          1.0|
|  2.0|          2.0|    A|[MX, CM]|    MX|    CM|          1.0|
+-----+-------------+-----+--------+------+------+-------------+



In [52]:
data = (data.join(teams, data.team_1 == teams.team)
.withColumnRenamed("country", "country_1").drop("team")
.join(teams, data.team_2 == teams.team)
.withColumnRenamed("country", "country_2").drop("team"))

data.show()
 

+-----+-------------+-----+--------+------+------+-------------+---------+---------+
|label|result_team_1|group| matches|team_1|team_2|result_team_2|country_1|country_2|
+-----+-------------+-----+--------+------+------+-------------+---------+---------+
|  2.0|          2.0|    A|[BR, HR]|    BR|    HR|          1.0|   Brazil|  Croatia|
|  0.0|          2.0|    A|[BR, MX]|    BR|    MX|          1.0|   Brazil|   Mexico|
|  2.0|          2.0|    A|[BR, CM]|    BR|    CM|          1.0|   Brazil| Cameroon|
|  2.0|          2.0|    A|[HR, CM]|    HR|    CM|          1.0|  Croatia| Cameroon|
|  2.0|          2.0|    A|[MX, CM]|    MX|    CM|          1.0|   Mexico| Cameroon|
+-----+-------------+-----+--------+------+------+-------------+---------+---------+



In [59]:
rdd_team_1 = (data
              .groupBy(["country_1", "result_team_1"]).count()
              .rdd
              .map(lambda x: ((x["country_1"], x["result_team_1"]), x["count"])))

rdd_team_2 = (data
              .groupBy(["country_2", "result_team_2"]).count()
              .rdd
              .map(lambda x: ((x["country_2"], x["result_team_2"]), x["count"])))

print(rdd_team_1.collect())
print(rdd_team_2.collect())

[((u'Mexico', 2.0), 1), ((u'Croatia', 2.0), 1), ((u'Brazil', 2.0), 3)]
[((u'Croatia', 1.0), 1), ((u'Cameroon', 1.0), 3), ((u'Mexico', 1.0), 1)]


In [67]:
rdd_union = (rdd_team_1
             .union(rdd_team_2)
             .reduceByKey(lambda x,y: x + y)
             .map(lambda x: (x[0][0], [(x[0][1], x[1])]))
             .reduceByKey(lambda x,y: x + y)
             .map(lambda x: (x[0], sorted(x[1], key=lambda tup: tup[0], reverse=True))))

In [68]:
rdd_union.collect()

[(u'Brazil', [(2.0, 3)]),
 (u'Croatia', [(2.0, 1), (1.0, 1)]),
 (u'Cameroon', [(1.0, 3)]),
 (u'Mexico', [(2.0, 1), (1.0, 1)])]

In [69]:
schema = StructType([
            StructField("group", StringType(), True),
            StructField("country_1", StringType(), True),
            StructField("country_2", StringType(), True),
            StructField("country_3", StringType(), True),
            StructField("country_4", StringType(), True)])

groups = spark.read.csv("../data/groups.csv", sep=",", schema=schema, header=False)

groups.show()

+-----+------------+--------------------+------------+--------------+
|group|   country_1|           country_2|   country_3|     country_4|
+-----+------------+--------------------+------------+--------------+
|    A|      Brazil|             Croatia|      Mexico|      Cameroon|
|    B|       Spain|         Netherlands|       Chile|     Australia|
|    C|    Colombia|              Greece| Ivory Coast|         Japan|
|    D|     Uruguay|          Costa Rica|     England|         Italy|
|    E| Switzerland|             Ecuador|      France|      Honduras|
|    F|   Argentina| Bosnia and Herze...|        Iran|       Nigeria|
|    G|     Germany|            Portugal|       Ghana| United States|
|    H|     Belgium|             Algeria|      Russia|   South Korea|
+-----+------------+--------------------+------------+--------------+



In [73]:
udf_list_country = udf(lambda country_1,country_2,country_3,country_4: [country_1, country_2, country_3, country_4], ArrayType(StringType()))
groups.withColumn("test", udf_list_country(col("country_1"),col("country_2"),col("country_3"),col("country_4"))).show()

+-----+------------+--------------------+------------+--------------+--------------------+
|group|   country_1|           country_2|   country_3|     country_4|                test|
+-----+------------+--------------------+------------+--------------+--------------------+
|    A|      Brazil|             Croatia|      Mexico|      Cameroon|[ Brazil,  Croati...|
|    B|       Spain|         Netherlands|       Chile|     Australia|[ Spain,  Netherl...|
|    C|    Colombia|              Greece| Ivory Coast|         Japan|[ Colombia,  Gree...|
|    D|     Uruguay|          Costa Rica|     England|         Italy|[ Uruguay,  Costa...|
|    E| Switzerland|             Ecuador|      France|      Honduras|[ Switzerland,  E...|
|    F|   Argentina| Bosnia and Herze...|        Iran|       Nigeria|[ Argentina,  Bos...|
|    G|     Germany|            Portugal|       Ghana| United States|[ Germany,  Portu...|
|    H|     Belgium|             Algeria|      Russia|   South Korea|[ Belgium,  Alger...|