# CSCI 4253 / 5253 - Lab #4 - Patent Problem with Spark RDD - SOLUTION
<div>
 <h2> CSCI 4253 / 5253 
  <IMG SRC="https://www.colorado.edu/cs/profiles/express/themes/cuspirit/logo.png" WIDTH=50 ALIGN="right"/> </h2>
</div>

This [Spark cheatsheet](https://s3.amazonaws.com/assets.datacamp.com/blog_assets/PySpark_SQL_Cheat_Sheet_Python.pdf) is a useful reference.

In [1]:
from pyspark import SparkContext, SparkConf
import numpy as np
import operator

In [2]:
conf=SparkConf().setAppName("Lab4-rdd").setMaster("local[*]")
sc = SparkContext(conf=conf)

Using PySpark and RDDs on the https://coding.csel.io machines is slow: most of the code is executed in Python, which is much less efficient than the JVM-based code behind PySpark DataFrames. Be patient, and try using `.cache()` to cache the output of joins. You may want to start with a reduced set of data before running the full task; for example, you can use the `sample()` method to extract just a sample of the data.

The two RDDs loaded below, `rddCitations` and `rddPatents`, hold raw text lines, so you will probably want to process them further (e.g. convert fields to integer types).

The `textFile` function returns data in strings. This should work fine for this lab.

Other methods you use might return data of type `bytes`. If you haven't used Python `bytes` objects before, look them up in the Python documentation. You can convert a `bytes` value `x` into a string using `x.decode('utf-8')`. Alternatively, you can use the `open` method of the gzip library to read in all the lines as UTF-8 strings like this:
```
import gzip
with gzip.open('cite75_99.txt.gz', 'rt',encoding='utf-8') as f:
    rddCitations = sc.parallelize( f.readlines() )
```
This is less efficient than using `textFile`, because `textFile` uses the underlying HDFS or other file system to read the file on the worker nodes, while `gzip.open(...).readlines()` reads all the data on the frontend and then distributes it to the worker nodes.
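
The difference between `bytes` and `str` can be checked without Spark. The following is a minimal sketch that builds a tiny gzip payload in memory (the two lines only mimic the citation file; the in-memory buffer is an assumption made so the example is self-contained) and reads it back in binary mode and in text mode:

```python
import gzip
import io

# Tiny in-memory gzip payload; the content mimics the citation file.
raw = gzip.compress('"CITING","CITED"\n3858241,956203\n'.encode("utf-8"))

# Binary mode ("rb") yields bytes objects, which must be decoded explicitly.
with gzip.open(io.BytesIO(raw), "rb") as f:
    byte_lines = f.readlines()
decoded = [b.decode("utf-8") for b in byte_lines]

# Text mode ("rt") decodes each line to str for you.
with gzip.open(io.BytesIO(raw), "rt", encoding="utf-8") as f:
    text_lines = f.readlines()

print(type(byte_lines[0]).__name__, type(text_lines[0]).__name__)
print(decoded == text_lines)
```

Both approaches end up with the same strings; text mode simply saves the explicit `decode` call.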

In [3]:
rddCitations = sc.textFile("cite75_99.txt.gz")
rddPatents = sc.textFile("apat63_99.txt.gz")

The data looks like the following.

In [4]:
rddCitations.take(5)

['"CITING","CITED"',
 '3858241,956203',
 '3858241,1324234',
 '3858241,3398406',
 '3858241,3557384']

In [5]:
rddPatents.take(5)

['"PATENT","GYEAR","GDATE","APPYEAR","COUNTRY","POSTATE","ASSIGNEE","ASSCODE","CLAIMS","NCLASS","CAT","SUBCAT","CMADE","CRECEIVE","RATIOCIT","GENERAL","ORIGINAL","FWDAPLAG","BCKGTLAG","SELFCTUB","SELFCTLB","SECDUPBD","SECDLWBD"',
 '3070801,1963,1096,,"BE","",,1,,269,6,69,,1,,0,,,,,,,',
 '3070802,1963,1096,,"US","TX",,1,,2,6,63,,0,,,,,,,,,',
 '3070803,1963,1096,,"US","IL",,1,,2,6,63,,9,,0.3704,,,,,,,',
 '3070804,1963,1096,,"US","OH",,1,,2,6,63,,3,,0.6667,,,,,,,']

In other words, each element is a single string of comma-separated values. You will need to convert these to (K,V) pairs, probably convert the keys to `int`, and so on. You'll also need to `filter` out the header string, since an RDD has no simple way to select every line except the first.
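
As a rough illustration of that conversion, here is a plain-Python sketch on the first few citation lines shown above. The list comprehension stands in for what would be a `filter` followed by a `map` on the RDD; it is one possible approach, not the only one.

```python
# First lines of the citation data, copied from the output above.
lines = [
    '"CITING","CITED"',
    '3858241,956203',
    '3858241,1324234',
]

header = lines[0]

# Drop the header by comparing against it, then split each line
# into a (citing, cited) pair of ints.
pairs = [
    tuple(int(v) for v in line.split(","))
    for line in lines
    if line != header
]

print(pairs)
```

On an RDD, the same idea would be expressed as a `filter` that rejects the header line followed by a `map` that splits and converts each remaining line.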

## Step 1: Initialize Spark Context and Load Raw Data Files
- Start a Spark session and get its SparkContext. Read the citation and patent data files as RDDs and cache them for reuse.

In [6]:
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("Lab4-RDD") \
    .master("local[*]") \
    .getOrCreate()

sc = spark.sparkContext


In [7]:
CITATIONS_PATH = "cite75_99.txt.gz"
PATENTS_PATH = "apat63_99.txt.gz"

rddCitations = sc.textFile(CITATIONS_PATH).cache()
rddPatents = sc.textFile(PATENTS_PATH).cache()


In [8]:
print(rddCitations.count())
print(rddPatents.count())


16522439
2923923


## Step 2: Parse Patent Header and Identify Column Indices

- Extract the header from the patents data to determine the indices of the key columns: PATENT, COUNTRY, and POSTATE.

In [9]:
import csv

pat_header = rddPatents.first()
hdr = next(csv.reader([pat_header]))

i_PATENT = hdr.index("PATENT")
i_COUNTRY = hdr.index("COUNTRY")
i_POSTATE = hdr.index("POSTATE")
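
To see why `csv.reader` is used here rather than a plain `split(",")`, consider this small sketch on a shortened copy of the header (only the first six columns are kept, for brevity):

```python
import csv

# Shortened copy of the patent header: first six columns only.
header_line = '"PATENT","GYEAR","GDATE","APPYEAR","COUNTRY","POSTATE"'

# A naive split keeps the surrounding double quotes in each name.
naive = header_line.split(",")

# csv.reader strips the quotes, so names can be looked up directly.
hdr = next(csv.reader([header_line]))

i_patent = hdr.index("PATENT")
i_country = hdr.index("COUNTRY")
i_postate = hdr.index("POSTATE")

print(naive[0], hdr[0])
print(i_patent, i_country, i_postate)
```

With the naive split, `"PATENT"` would not be found because the stored name still includes its quote characters.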


## Step 3: Extract US Patents with State Information
- Define a function to extract US patents with non-empty state information. Create an RDD mapping patent_id to state, filtering out invalid entries, and cache it.

In [10]:
def extract_us_state(line):
    cols = next(csv.reader([line]))
    patent_id = cols[i_PATENT]
    country = (cols[i_COUNTRY] or "").upper()
    state = (cols[i_POSTATE] or "").upper()
    if country == "US" and state:
        return (patent_id, state)
    else:
        return None

patent_to_state = (
    rddPatents
      .filter(lambda line: line != pat_header)
      .map(extract_us_state)
      .filter(lambda x: x is not None)
      .cache()
)
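
The behavior of the extraction function can be checked locally on two of the sample rows shown earlier. This sketch hard-codes the column indices (0, 4, 5) that the header lookup yields for the header shown above; the upper-case constant names are used only in this example.

```python
import csv

# Column indices for PATENT, COUNTRY, POSTATE in the header shown earlier.
I_PATENT, I_COUNTRY, I_POSTATE = 0, 4, 5

def extract_us_state(line):
    """Return (patent_id, state) for US patents with a state, else None."""
    cols = next(csv.reader([line]))
    country = (cols[I_COUNTRY] or "").upper()
    state = (cols[I_POSTATE] or "").upper()
    if country == "US" and state:
        return (cols[I_PATENT], state)
    return None

rows = [
    '3070801,1963,1096,,"BE","",,1,,269,6,69,,1,,0,,,,,,,',
    '3070802,1963,1096,,"US","TX",,1,,2,6,63,,0,,,,,,,,,',
]

results = [extract_us_state(r) for r in rows]
print(results)
```

The Belgian patent is rejected and the Texas patent is kept, which is what the `map` plus `filter` pipeline relies on.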


## Step 4: Parse Citation Pairs
- Define a function to parse citation lines into (citing, cited) pairs. It handles lines separated by commas or whitespace and filters out invalid entries, including the header. Cache the resulting RDD.


In [11]:
def parse_citation(line):
    line = line.strip()
    if not line:
        return None
    parts = line.split(",") if "," in line else line.split()
    if len(parts) < 2:
        return None
    citing = parts[0].strip()
    cited  = parts[1].strip()
    if not citing.isdigit() or not cited.isdigit():
        return None
    return (citing, cited)

cit_pairs = (
    rddCitations
      .map(parse_citation)
      .filter(lambda x: x is not None)
      .cache()
)


## Step 5: Count Same-State Citations per Citing Patent

- Join citation pairs with patent states for citing patents.
- Rearrange data keyed by cited patent.
- Join again to attach cited patent states.
- Filter citations where citing and cited patents have the same state.
- Count the number of same-state citations per citing patent and cache the result.

In [12]:
# Join on CITING to attach the citing patent's state:
# (citing, (cited, citing_state))
citing_join = cit_pairs.join(patent_to_state)

# Re-key by CITED: (cited, (citing, citing_state))
by_cited = citing_join.map(lambda x: (x[1][0], (x[0], x[1][1])))

# Join on CITED to attach the cited patent's state:
# (cited, ((citing, citing_state), cited_state))
both_states = by_cited.join(patent_to_state)

# Keep same-state citations and count them per citing patent
same_state_counts = (
    both_states
      .filter(lambda x: x[1][0][1] == x[1][1])
      .map(lambda x: (x[1][0][0], 1))
      .reduceByKey(lambda a, b: a + b)
      .cache()
)
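
The nested tuple shapes produced by the two joins can be hard to follow. The sketch below traces the same steps in plain Python on toy data. The patent IDs and states are hypothetical, and a dictionary lookup stands in for the RDD `join`, which is a fair stand-in here only because each patent has a single state.

```python
from collections import Counter

# Hypothetical toy data, not taken from the real files.
patent_to_state = {"1": "CO", "2": "CO", "3": "TX", "4": "CO"}
cit_pairs = [("1", "2"), ("1", "3"), ("1", "4"), ("3", "2"), ("2", "9")]

# First "join": (citing, (cited, citing_state)); unmatched keys are dropped.
citing_join = [
    (citing, (cited, patent_to_state[citing]))
    for citing, cited in cit_pairs
    if citing in patent_to_state
]

# Re-key by cited: (cited, (citing, citing_state)).
by_cited = [(cited, (citing, cs)) for citing, (cited, cs) in citing_join]

# Second "join": (cited, ((citing, citing_state), cited_state)).
both_states = [
    (cited, ((citing, cs), patent_to_state[cited]))
    for cited, (citing, cs) in by_cited
    if cited in patent_to_state
]

# Keep same-state citations and count them per citing patent.
same_state_counts = Counter(
    x[1][0][0] for x in both_states if x[1][0][1] == x[1][1]
)

print(dict(same_state_counts))
```

Patent "1" cites two other Colorado patents, so it ends up with a count of 2. The citation to "9" disappears at the second join because "9" has no state entry, mirroring how an inner join drops unmatched keys.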


In [13]:
same_state_counts.count()

571919

In [14]:
top10 = same_state_counts.takeOrdered(10, key=lambda x: -x[1])
top10


[('5959466', 125),
 ('5983822', 103),
 ('6008204', 100),
 ('5952345', 98),
 ('5958954', 96),
 ('5998655', 96),
 ('5936426', 94),
 ('5739256', 90),
 ('5951547', 90),
 ('5913855', 90)]
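
`takeOrdered` returns the smallest elements under the given key, so negating the count is what turns it into a "largest first" query. A plain-Python analog using `heapq.nsmallest` on the first four results above shows the same effect:

```python
import heapq

# First four (patent, count) pairs from the output above.
counts = [
    ("5959466", 125),
    ("5983822", 103),
    ("6008204", 100),
    ("5952345", 98),
]

# Smallest by negated count is the same as largest by count.
top2 = heapq.nsmallest(2, counts, key=lambda x: -x[1])

print(top2)
```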

In [15]:
# Get top 10 patents by same-state count
top10 = (
    same_state_counts
      .takeOrdered(10, key=lambda x: -x[1])
)

for patent, count in top10:
    print(patent, count)


5959466 125
5983822 103
6008204 100
5952345 98
5958954 96
5998655 96
5936426 94
5739256 90
5951547 90
5913855 90
