# CSCI 4253 / 5253 - Lab #4 - Patent Problem with Spark DataFrames
<div>
 <h2> CSCI 4253 / 5253 
  <IMG SRC="https://www.colorado.edu/cs/profiles/express/themes/cuspirit/logo.png" WIDTH=50 ALIGN="right"/> </h2>
</div>

This [Spark cheatsheet](https://s3.amazonaws.com/assets.datacamp.com/blog_assets/PySpark_SQL_Cheat_Sheet_Python.pdf) is useful, as is [this reference on doing joins with Spark DataFrames](http://www.learnbymarketing.com/1100/pyspark-joins-by-example/).

The [DataBricks company has one of the better reference manuals for PySpark](https://docs.databricks.com/spark/latest/dataframes-datasets/index.html) -- they show you how to perform numerous common data operations such as joins, aggregation operations following `groupBy` and the like.

In [39]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

The following aggregation functions may be useful -- [these can be used to aggregate results of `groupby` operations](https://docs.databricks.com/spark/latest/dataframes-datasets/introduction-to-dataframes-python.html#example-aggregations-using-agg-and-countdistinct). More documentation is at the [PySpark SQL Functions manual](https://spark.apache.org/docs/2.3.0/api/python/pyspark.sql.html#module-pyspark.sql.functions). Feel free to use other functions from that library.

In [40]:
from pyspark.sql.functions import col, count, countDistinct

Create our session as described in the tutorials

In [41]:
# Build (or reuse, via getOrCreate) the session for this lab, using the
# `local[*]` master.
spark = (
    SparkSession.builder
    .appName("Lab4-Dataframe")
    .master("local[*]")
    .getOrCreate()
)

Read in the citations and patents data and check that the data makes sense. Note that, unlike in the RDD solution, the column types are inferred automatically (via `inferSchema`) -- for example, the citation IDs are read as integers.

In [42]:
# Load the gzip-compressed citations CSV. The first line is a header row,
# and inferSchema asks Spark to infer the column types from the data.
citations = (
    spark.read
    .format("csv")
    .options(sep=",", header=True, compression="gzip", inferSchema="true")
    .load('cite75_99.txt.gz')
)

In [43]:
# Sanity check: display the first five citation rows.
citations.show(n=5)

+-------+-------+
| CITING|  CITED|
+-------+-------+
|3858241| 956203|
|3858241|1324234|
|3858241|3398406|
|3858241|3557384|
|3858241|3634889|
+-------+-------+
only showing top 5 rows



In [44]:
# Load the gzip-compressed patents CSV. The first line is a header row,
# and inferSchema asks Spark to infer the column types from the data.
patents = (
    spark.read
    .format("csv")
    .options(sep=",", header=True, compression="gzip", inferSchema="true")
    .load('apat63_99.txt.gz')
)

In [45]:
# Sanity check: display the first five patent rows.
patents.show(n=5)

+-------+-----+-----+-------+-------+-------+--------+-------+------+------+---+------+-----+--------+--------+-------+--------+--------+--------+--------+--------+--------+--------+
| PATENT|GYEAR|GDATE|APPYEAR|COUNTRY|POSTATE|ASSIGNEE|ASSCODE|CLAIMS|NCLASS|CAT|SUBCAT|CMADE|CRECEIVE|RATIOCIT|GENERAL|ORIGINAL|FWDAPLAG|BCKGTLAG|SELFCTUB|SELFCTLB|SECDUPBD|SECDLWBD|
+-------+-----+-----+-------+-------+-------+--------+-------+------+------+---+------+-----+--------+--------+-------+--------+--------+--------+--------+--------+--------+--------+
|3070801| 1963| 1096|   NULL|     BE|   NULL|    NULL|      1|  NULL|   269|  6|    69| NULL|       1|    NULL|    0.0|    NULL|    NULL|    NULL|    NULL|    NULL|    NULL|    NULL|
|3070802| 1963| 1096|   NULL|     US|     TX|    NULL|      1|  NULL|     2|  6|    63| NULL|       0|    NULL|   NULL|    NULL|    NULL|    NULL|    NULL|    NULL|    NULL|    NULL|
|3070803| 1963| 1096|   NULL|     US|     IL|    NULL|      1|  NULL|     2|  6|    6

In [None]:
from pyspark.sql.functions import coalesce, lit, col, count
# Pipeline: for every US patent with a known state, count how many of the
# patents it cites were granted in that same state, then keep the ten
# patents with the highest counts.

# Columns carried through to the final report, in display order.
patent_columns = [
    "PATENT", "GYEAR", "GDATE", "APPYEAR",
    "COUNTRY", "POSTATE", "ASSIGNEE", "ASSCODE",
    "CLAIMS", "NCLASS", "CAT", "SUBCAT",
    "CMADE", "CRECEIVE", "RATIOCIT", "GENERAL",
    "ORIGINAL", "FWDAPLAG", "BCKGTLAG", "SELFCTUB",
    "SELFCTLB", "SECDUPBD", "SECDLWBD",
]

# Keep only US patents that carry a non-null, non-empty state code.
is_us_with_state = (
    (col("COUNTRY") == "US")
    & (col("POSTATE").isNotNull())
    & (col("POSTATE") != "")
)
patents_clean = patents.select(*patent_columns).where(is_us_with_state)

# Patent -> state lookup; cached because it is joined twice below.
patent_states = patents_clean.select("PATENT", "POSTATE").cache()

# Attach the state of the citing patent to each citation.
citing_lookup = patent_states.alias("p1")
citations_with_citing_state = (
    citations.alias("c")
    .join(citing_lookup, col("c.CITING") == col("p1.PATENT"), "inner")
    .select(
        col("c.CITING"),
        col("c.CITED"),
        col("p1.POSTATE").alias("CITING_STATE"),
    )
)

# Attach the state of the cited patent as well.
cited_lookup = patent_states.alias("p2")
citations_with_both_states = (
    citations_with_citing_state.alias("ccs")
    .join(cited_lookup, col("ccs.CITED") == col("p2.PATENT"), "inner")
    .select(
        col("ccs.CITING"),
        col("ccs.CITED"),
        col("ccs.CITING_STATE"),
        col("p2.POSTATE").alias("CITED_STATE"),
    )
)

# Count, per citing patent, the citations whose two states match.
in_same_state = col("CITING_STATE") == col("CITED_STATE")
same_state_citations = (
    citations_with_both_states
    .where(in_same_state)
    .groupBy("CITING")
    .agg(count("*").alias("SAME_STATE"))
)

# Left join so patents without any same-state citation are kept with a
# count of 0, then take the top ten (ties broken by ascending patent number).
same_state_or_zero = coalesce(col("s.SAME_STATE"), lit(0)).alias("SAME_STATE")
final_result = (
    patents_clean.alias("p")
    .join(
        same_state_citations.alias("s"),
        col("p.PATENT") == col("s.CITING"),
        "left",
    )
    .select(col("p.*"), same_state_or_zero)
    .orderBy(col("SAME_STATE").desc(), col("PATENT").asc())
    .limit(10)
)

print("final_result DataFrame created successfully!")

# Print final_result as a fixed-width text table, then shut Spark down.

# `builtins.sum` is used for the separator line below. No visible cell imports
# `builtins`, so import it here; otherwise that line raises NameError.
import builtins

# Bring the rows of final_result to the driver for manual formatting.
results = final_result.collect()
headers = ['PATENT', 'GYEAR', 'GDATE', 'APPYEAR', 'COUNTRY', 'POSTATE', 'ASSIGNEE', 'ASSCODE', 
           'CLAIMS', 'NCLASS', 'CAT', 'SUBCAT', 'CMADE', 'CRECEIVE', 'RATIOCIT', 'GENERAL', 
           'ORIGINAL', 'FWDAPLAG', 'BCKGTLAG', 'SELFCTUB', 'SELFCTLB', 'SECDUPBD', 'SECDLWBD', 'SAME_STATE']
# One width per header, in the same order.
col_widths = [8, 6, 8, 8, 8, 8, 12, 8, 7, 7, 5, 7, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 11]

# Columns that are rendered with four decimal places.
float_fields = ('RATIOCIT', 'GENERAL', 'ORIGINAL', 'FWDAPLAG', 'BCKGTLAG',
                'SELFCTUB', 'SELFCTLB', 'SECDUPBD', 'SECDLWBD')


def _format_float(value):
    """Return *value* formatted with four decimals, or 'null' when it is None."""
    return f"{value:.4f}" if value is not None else 'null'


def _fit(value, width):
    """Return *value* as text, left-aligned and padded to *width* characters.

    None is shown as 'null'. Text longer than ``width - 1`` characters is cut
    and suffixed with ".." so at least one space separates adjacent columns.
    """
    text = str(value) if value is not None else 'null'
    if len(text) > width - 1:
        text = text[:width - 3] + ".."
    return f"{text:<{width}}"


# Printing header
header_line = "".join(
    f"{header:<{width}}" for header, width in zip(headers, col_widths)
)
print(header_line)
print("-" * builtins.sum(col_widths))

# Print data rows
for row in results:
    # Assignee is shown as at most ten characters ('null' when missing).
    assignee_str = str(row.ASSIGNEE) if row.ASSIGNEE is not None else 'null'
    assignee_short = assignee_str[:10]

    values = [
        row.PATENT, row.GYEAR, row.GDATE, row.APPYEAR, row.COUNTRY, row.POSTATE,
        assignee_short, row.ASSCODE, row.CLAIMS, row.NCLASS, row.CAT, row.SUBCAT,
        row.CMADE, row.CRECEIVE,
        *(_format_float(row[name]) for name in float_fields),
        row.SAME_STATE,
    ]

    row_line = "".join(
        _fit(value, width) for value, width in zip(values, col_widths)
    )
    print(row_line)

# Release the Spark session now that the results have been printed.
spark.stop()

final_result DataFrame created successfully!
