# ReadMe
The objective of this project is to read the National Vulnerability Database (NVD) feed of Common Vulnerabilities and Exposures (CVEs) and build a tool, used as part of a data pipeline, that determines how similar CVEs are to one another.

The first goal of such a pipeline is to clean and tokenize the CVE description data. The second goal is to read and tokenize the data that will be used as training labels.

# Requirements
Required packages:
- pyspark
- numpy

TF-IDF use cases in this project:
- LDA
- Similarity between CVE descriptions

More general TF-IDF use cases:
1. **Feature Selection**:
   - Examine the top TF-IDF terms to identify important features. You can select a subset of these features for further analysis or modeling.
   - Consider removing low-TF-IDF terms (common words) that might not contribute significantly to your task.

2. **Clustering and Topic Modeling**:
   - Apply clustering algorithms (e.g., K-means, DBSCAN) to group similar documents based on their TF-IDF vectors.
   - Explore topic modeling techniques (e.g., Latent Dirichlet Allocation, Non-Negative Matrix Factorization) to discover latent topics within your corpus.

3. **Document Similarity**:
   - Calculate cosine similarity between TF-IDF vectors of different documents. This helps identify similar documents.
   - Use similarity scores to recommend related articles, products, or content.

4. **Classification and Sentiment Analysis**:
   - Convert text data into TF-IDF vectors and use them as input features for classifiers (e.g., SVM, Random Forest). This is useful for tasks like sentiment analysis, spam detection, or document categorization.

5. **Search and Information Retrieval**:
   - Build an inverted index that maps each term to the documents containing it, along with the term's TF-IDF weight in each document, to create an efficient search engine.
   - Retrieve relevant documents for a user query by ranking them on their TF-IDF scores.

6. **Visualizations**:
   - Visualize TF-IDF scores using word clouds, scatter plots, or bar charts to gain insights into term importance.
   - Plot the distribution of TF-IDF values across the entire dataset.

7. **Optimize Hyperparameters**:
   - Experiment with different parameters (e.g., n-grams, stop words, max features) in your TF-IDF vectorization process.
   - Use cross-validation to find optimal settings.
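To make the document-similarity use case concrete, here is a minimal pure-Python sketch (no Spark) of cosine similarity between two sparse TF-IDF vectors. It assumes each vector is a `dict` mapping a token index to its TF-IDF weight; `cosine_similarity_sparse` is a hypothetical helper, not part of this project's code.

```python
import math

def cosine_similarity_sparse(a, b):
    """Cosine similarity of two sparse vectors given as {index: weight} dicts."""
    # Iterate over the smaller vector: only shared indices contribute to the dot product
    if len(a) > len(b):
        a, b = b, a
    dot = sum(w * b[i] for i, w in a.items() if i in b)
    norm_a = math.sqrt(sum(w * w for w in a.values()))
    norm_b = math.sqrt(sum(w * w for w in b.values()))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # undefined for an all-zero vector; treat it as dissimilar
    return dot / (norm_a * norm_b)

doc_a = {0: 0.5, 3: 0.2}
doc_b = {0: 0.5, 3: 0.2}
doc_c = {1: 0.7}
print(cosine_similarity_sparse(doc_a, doc_b))  # identical vectors, so approximately 1.0
print(cosine_similarity_sparse(doc_a, doc_c))  # no shared terms -> 0.0
```

Cosine similarity ignores vector length, so a long CVE description and a short one that use the same terms in the same proportions still score close to 1.0.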

TODO: Find sources to back up these claims.

In [1]:
import os, math, re
import pyspark
import urllib.request
import zipfile
from pyspark.sql import SparkSession
# Note: `sum` is intentionally not imported here; it would shadow Python's built-in sum()
from pyspark.sql.functions import col, count, countDistinct, concat_ws, explode, expr, lit, udf, split
from pyspark.sql.types import DoubleType, FloatType, StringType, ArrayType
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.linalg import VectorUDT, SparseVector

In [2]:
# Download and extract the dataset if it doesn't already exist.
data_dir = "data"
os.makedirs(data_dir, exist_ok=True)

# One NVD JSON 1.1 feed per year from 2002 through 2024, plus the "recent" feed
base_url = "https://nvd.nist.gov/feeds/json/cve/1.1/"
fileUrls = [f"{base_url}nvdcve-1.1-{year}.json.zip" for year in range(2002, 2025)]
fileUrls.append(f"{base_url}nvdcve-1.1-recent.json.zip")

# Iterate through each URL
for url in fileUrls:
    filename = url.split("/")[-1]
    outputfile = os.path.join(data_dir, filename)
    checkfile = os.path.join(data_dir, os.path.splitext(filename)[0])

    # Check if the file already exists
    if not os.path.exists(checkfile):
        # Download the file
        urllib.request.urlretrieve(url, outputfile)
        print(f"Downloaded: {filename}")

        # Extract the file
        with zipfile.ZipFile(outputfile, "r") as zip_ref:
            zip_ref.extractall(data_dir)

        # Delete the original zip file
        os.remove(outputfile)



In [3]:
spark = (
    SparkSession.builder
        .master("local[*]")
        .appName("voltcve")
        .config("spark.driver.host", "127.0.0.1")
        .config("spark.driver.bindAddress", "127.0.0.1")
        .config("spark.default.parallelism", 8)
        .config("spark.driver.memory", "25g")
        .config("spark.executor.memory", "10g")
        .getOrCreate()
)
# read the data
cves = spark.read.option("multiline", "true").json("data/nvdcve-1.1-2020.json")

# Explode the CVE_Items array so that each CVE becomes its own row
exploded = cves.select(explode(col("CVE_Items")).alias("cves"))

# CPE naming scheme: https://en.wikipedia.org/wiki/Common_Platform_Enumeration
# cpe:<cpe_version>:<part>:<vendor>:<product>:<version>:<update>:<edition>:<language>:<sw_edition>:<target_sw>:<target_hw>:<other>
cpe_df = exploded.select(col("cves.cve.CVE_data_meta.ID").alias("id"),
                         explode(col("cves.configurations.nodes.cpe_match")).alias("cpe"))

# Keep only the first CPE match of each configuration node. The cpe23Uri list
# repeats mostly the same data, because every affected version gets its own entry.
cpe_df = cpe_df.select(col("id"), col("cpe")[0].alias("cpe"))
# Field 4 of the split URI is <product>, which is used as the training label
cpe_df = cpe_df.select("id", split(cpe_df.cpe.cpe23Uri, ":", -1)[4].alias("label"))

cpe_df.show(truncate=200)
cpe_df.printSchema()

# Collect the description strings of each CVE and join them into a single string
descr_df = exploded.select(col("cves.cve.CVE_data_meta.ID").alias("id"),
                           col("cves.cve.description.description_data.value").alias("description"))

descr_df = descr_df.withColumn("description_single", concat_ws(" ", descr_df["description"]))


doc_count = descr_df.selectExpr("count(distinct id)").first()[0]
print("Number of docs: {}".format(doc_count))

This recorded run failed while reading the JSON file:

Py4JJavaError: An error occurred while calling o33.json.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 4 times, most recent failure: Lost task 0.3 in stage 0.0 (TID 3) (192.168.56.14 executor 0): java.io.InvalidClassException: org.apache.spark.rdd.RDD; local class incompatible: stream classdesc serialVersionUID = 3516924559342767982, local class serialVersionUID = 823754013007382808

A `serialVersionUID` mismatch on `org.apache.spark.rdd.RDD` means the executors received a class serialized by a different Spark build than their own, which usually points to mismatched Spark versions between the driver's `pyspark` package and the executors. The remote executors (192.168.56.x) also suggest that this output came from a run against a cluster rather than the `local[*]` master configured above.
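To check the extraction logic without starting Spark, the following stdlib-only sketch mirrors the three selections in the cell above on a single `CVE_Items` entry. The record is a hypothetical, hand-written example containing only the fields this notebook reads (`cve.CVE_data_meta.ID`, `cve.description.description_data[].value`, `configurations.nodes[].cpe_match[].cpe23Uri`); the helper name `extract_item` and the placeholder ID are illustrative, not real NVD data.

```python
def extract_item(item):
    """Return (id, description_single, labels) for one CVE_Items entry."""
    cve_id = item["cve"]["CVE_data_meta"]["ID"]
    # Equivalent of concat_ws(" ", description_data.value)
    description = " ".join(d["value"] for d in item["cve"]["description"]["description_data"])
    labels = []
    for node in item["configurations"]["nodes"]:
        matches = node.get("cpe_match", [])
        # First CPE match of each node, like col("cpe")[0]. Spark would yield a
        # null label for a node with no matches; this sketch skips such nodes.
        if matches:
            # cpe:2.3:<part>:<vendor>:<product>:... -> field 4 is <product>
            labels.append(matches[0]["cpe23Uri"].split(":")[4])
    return cve_id, description, labels

sample_item = {  # hypothetical record, shaped like an NVD JSON 1.1 CVE_Items entry
    "cve": {
        "CVE_data_meta": {"ID": "CVE-0000-0000"},
        "description": {"description_data": [
            {"lang": "en", "value": "Hypothetical description of a flaw in example_product."}
        ]},
    },
    "configurations": {"nodes": [
        {"operator": "OR", "cpe_match": [
            {"vulnerable": True, "cpe23Uri": "cpe:2.3:a:example_vendor:example_product:1.0:*:*:*:*:*:*:*"},
            {"vulnerable": True, "cpe23Uri": "cpe:2.3:a:example_vendor:example_product:1.1:*:*:*:*:*:*:*"},
        ]}
    ]},
}
print(extract_item(sample_item))
# ('CVE-0000-0000', 'Hypothetical description of a flaw in example_product.', ['example_product'])
```

On a real feed, the items would come from `json.load(f)["CVE_Items"]` applied to one of the extracted `data/nvdcve-1.1-<year>.json` files.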


## Questions from the data:
- Should CPE be trained on the same data with multiple labels?
- How much data is enough to adequately train a word2vec model?

## Tokenization: 
### Applying the right tokenization methods:
Much of the dataset is text that ordinary tokenization handles poorly. For example, file names contain punctuation, and splitting on that punctuation can leave the resulting tokens meaningless.
An effort was made to preserve all meaningful punctuation that describes software or hardware configurations, while removing stop words and ordinary English punctuation.

In [None]:
@udf(returnType=StringType())
def string_cleaner(input_str):
    # Cleaning steps:
    # 1. Replace runs of ".", ":" or "," that are followed by whitespace with a single space
    #    (dots inside version numbers and file names such as "config.xml" are kept).
    # 2. Remove a trailing period.
    # 3. Remove apostrophes, "(TM)", "(R)", double quotes, and the parentheses
    #    themselves (the text inside the parentheses is kept).
    # 4. Lowercase and strip leading/trailing whitespace.
    if input_str is None:
        return ""
    cleaned_text = re.sub(r"[.:,]+\s+", " ", input_str)
    cleaned_text = re.sub(r"\.$", "", cleaned_text)
    cleaned_text = re.sub(r"\'|\(TM\)|\(R\)|\(|\)|\"", "", cleaned_text)
    cleaned_text = cleaned_text.strip().lower()
    return cleaned_text

# Split on runs of whitespace so repeated spaces do not produce empty tokens
clean_tokens = descr_df.withColumn("token", split(string_cleaner(col("description_single")), r"\s+"))
stop_words_remover = StopWordsRemover(inputCol="token", outputCol="cleanToken")
clean_tokens = stop_words_remover.transform(clean_tokens)
clean_tokens = clean_tokens.select("id", explode("cleanToken").alias("token"))
clean_tokens = clean_tokens.filter(col("token").isNotNull() & (col("token") != ""))

# Show the resulting DataFrame
clean_tokens.show(truncate=200)
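Because the cleaning rules live in a Python UDF, they can be checked without Spark. This sketch reapplies the same three regular expressions in plain Python and then drops stop words. The `STOP_WORDS` set is a deliberately tiny, hypothetical stand-in: Spark's `StopWordsRemover` uses its own, much larger default English list, so real token lists will differ. The sample sentence is invented for illustration.

```python
import re

# Hypothetical, deliberately tiny stop-word list (assumption, not Spark's list)
STOP_WORDS = {"a", "an", "and", "in", "of", "or", "the", "to", "via"}

def clean_description(text):
    # Same three substitutions as the string_cleaner UDF
    cleaned = re.sub(r"[.:,]+\s+", " ", text)                    # ". ", ": ", ", " -> " "
    cleaned = re.sub(r"\.$", "", cleaned)                         # trailing period
    cleaned = re.sub(r"\'|\(TM\)|\(R\)|\(|\)|\"", "", cleaned)    # quotes, (TM), (R), parentheses
    return cleaned.strip().lower()

def tokenize(text):
    # Split on whitespace runs, then drop empty strings and stop words
    return [t for t in re.split(r"\s+", clean_description(text)) if t and t not in STOP_WORDS]

sample = 'Buffer overflow in Foo(TM) Server, version 2.1.3, allows attackers to read "config.xml" files.'
print(tokenize(sample))
# ['buffer', 'overflow', 'foo', 'server', 'version', '2.1.3', 'allows', 'attackers', 'read', 'config.xml', 'files']
```

The version string `2.1.3` and the file name `config.xml` survive as single tokens, which is the point of only stripping punctuation that is followed by whitespace.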



In [None]:
from pyspark.sql.functions import log

def tfidf(tokens, doc_count):
    # Total number of tokens in each document (the tf denominator)
    allTokensForId = tokens.groupBy("id").agg(count("id").alias("allTokensForId"))

    # Raw term frequency: how often each token occurs in each document
    tfds = tokens.groupBy("id", "token").agg(count("id").alias("rawtf"))
    # Document frequency: number of distinct documents that contain each token
    dfds = tokens.groupBy("token").agg(countDistinct("id").alias("df"))

    # tf = rawtf / allTokensForId
    tfds = (
        tfds.join(allTokensForId, on="id")
            .withColumn("tf", col("rawtf") / col("allTokensForId"))
    )

    # Smoothed idf using the natural log: log((doc_count + 1) / (df + 1)).
    # The built-in log() avoids the overhead of a Python UDF.
    tokens_idf = dfds.withColumn("idf", log((lit(doc_count) + 1.0) / (col("df") + 1.0)))

    # Attach idf to every (id, token) row and multiply to get tf-idf
    tfidfds = (
        tokens_idf.join(tfds, "token", "left")
            .withColumn("tf_idf", col("tf") * col("idf"))
    )
    return tfidfds
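The formulas behind `tfidf` are tf = rawtf / allTokensForId, idf = ln((N + 1) / (df + 1)) with N = `doc_count`, and tf-idf = tf * idf. This smoothed idf has the same form that Spark ML's `IDF` estimator documents. The pure-Python reference below is a sketch for checking results on a toy corpus; `tfidf_reference` is a hypothetical helper, and here N is simply the number of documents passed in.

```python
import math
from collections import Counter

def tfidf_reference(docs):
    """docs: {doc_id: [token, ...]} -> {(doc_id, token): tf-idf}."""
    n_docs = len(docs)
    # df: number of documents in which each token appears at least once
    df = Counter(tok for tokens in docs.values() for tok in set(tokens))
    scores = {}
    for doc_id, tokens in docs.items():
        total = len(tokens)                        # allTokensForId
        for tok, raw in Counter(tokens).items():   # rawtf
            tf = raw / total
            idf = math.log((n_docs + 1.0) / (df[tok] + 1.0))
            scores[(doc_id, tok)] = tf * idf
    return scores

docs = {"d1": ["buffer", "overflow", "overflow", "server"],
        "d2": ["sql", "injection", "server"]}
scores = tfidf_reference(docs)
print(scores[("d1", "overflow")])  # 0.5 * ln(3/2)
print(scores[("d1", "server")])    # 0.0, because "server" appears in every document
```

A token that appears in every document gets idf = ln(1) = 0, so very common terms drop out of the vectors entirely.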


In [None]:
tfidf_df = tfidf(clean_tokens, doc_count)

tfidf_df.show()

The TF-IDF scores for every (document, token) pair are now calculated. The next step is to turn them into one TF-IDF vector per document, starting from the `tfidf_df` DataFrame (its columns include `id`, `token`, and `tf_idf`):

1. **Assign each token a fixed index**:
   - Give every distinct token in the corpus one position in the vector. The same position must refer to the same token in every document; otherwise the vectors cannot be compared.
   - The vector size is the number of distinct tokens.

2. **Group the TF-IDF scores by document ID**:
   - Group the DataFrame by `id` and collect the (token index, `tf_idf`) pairs of each document.

3. **Create a sparse vector for each document**:
   - Use PySpark's `SparseVector` class. Only the indices of tokens that occur in the document are stored, each with its TF-IDF score; all other entries are implicitly zero.
   - `SparseVector` requires strictly increasing indices, so sort the pairs by index first.

In [None]:
from pyspark.ml.feature import StringIndexer
from pyspark.ml.linalg import SparseVector, VectorUDT
from pyspark.sql.functions import col, collect_list, struct, udf

# 1. Give every distinct token a fixed index (0 .. num_tokens - 1) so that
#    position i refers to the same token in every document's vector.
indexer = StringIndexer(inputCol="token", outputCol="token_idx").fit(tfidf_df)
indexed_tfidf = indexer.transform(tfidf_df)
num_tokens = tfidf_df.select("token").distinct().count()

# 2. Collect the (index, tf-idf score) pairs that belong to each document.
grouped_tfidf = indexed_tfidf.groupBy("id").agg(
    collect_list(
        struct(col("token_idx").cast("int").alias("idx"), col("tf_idf").alias("score"))
    ).alias("pairs")
)

# 3. Build one sparse vector per document.
def create_sparse_vector(pairs):
    # SparseVector requires strictly increasing indices, so sort by index first
    pairs = sorted(pairs, key=lambda p: p["idx"])
    return SparseVector(num_tokens, [p["idx"] for p in pairs], [float(p["score"]) for p in pairs])

sparse_vector_udf = udf(create_sparse_vector, VectorUDT())
final_tfidf_vector = grouped_tfidf.withColumn("tfidf_features", sparse_vector_udf("pairs"))

# Show the resulting DataFrame
final_tfidf_vector.select("id", "tfidf_features").show(truncate=400)
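The essential property of a TF-IDF vector is that the token-to-index mapping is built once for the whole corpus and shared by every document. This pure-Python sketch shows that idea on a `{(doc_id, token): weight}` dictionary; `build_sparse_vectors` is a hypothetical helper, and alphabetical ordering of the vocabulary is an arbitrary choice (any fixed mapping works). Each returned `(size, indices, values)` triple has the same shape as the arguments of `pyspark.ml.linalg.SparseVector(size, indices, values)`.

```python
def build_sparse_vectors(scores):
    """{(doc_id, token): weight} -> (vocab, {doc_id: (size, indices, values)})."""
    # One corpus-wide mapping, so index i means the same token in every document
    vocab = {tok: i for i, tok in enumerate(sorted({tok for _, tok in scores}))}
    size = len(vocab)
    per_doc = {}
    for (doc_id, tok), weight in scores.items():
        per_doc.setdefault(doc_id, []).append((vocab[tok], weight))
    vectors = {}
    for doc_id, pairs in per_doc.items():
        pairs.sort()  # sparse layouts expect strictly increasing indices
        vectors[doc_id] = (size, [i for i, _ in pairs], [w for _, w in pairs])
    return vocab, vectors

vocab, vectors = build_sparse_vectors(
    {("d1", "overflow"): 0.2, ("d1", "buffer"): 0.1, ("d2", "sql"): 0.3})
print(vocab)          # {'buffer': 0, 'overflow': 1, 'sql': 2}
print(vectors["d1"])  # (3, [0, 1], [0.1, 0.2])
```

Both documents have size 3 even though each contains fewer tokens, which is what makes their vectors directly comparable.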


Now that the documents have TF-IDF vectors, here are some useful things to do with the results:

1. **Document Similarity**:
   - Calculate the similarity between documents using cosine similarity or another distance metric. The closer the vectors, the more similar the documents.
   - For example, find documents similar to a given query document by comparing their TF-IDF vectors.

2. **Topic Modeling**:
   - Apply topic modeling techniques (such as Latent Dirichlet Allocation or Non-Negative Matrix Factorization) to discover underlying topics in the corpus.
   - NMF is commonly run on TF-IDF vectors; LDA, by contrast, is usually fit on raw term-count vectors.

3. **Classification and Clustering**:
   - Train machine learning models (e.g., SVM, Random Forest, or k-means) using the TF-IDF vectors as features.
   - Classify documents into predefined categories or cluster similar documents together.

4. **Keyword Extraction**:
   - Identify important keywords or phrases within each document based on their TF-IDF scores; higher scores indicate more significant terms.

5. **Search and Retrieval**:
   - Use the TF-IDF vectors to build an efficient search index for the documents.
   - Given a query, retrieve relevant documents based on their similarity to the query.

6. **Visualizations**:
   - Visualize the TF-IDF vectors in lower dimensions using techniques like t-SNE or PCA.
   - Explore the distribution of documents in the vector space.
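For keyword extraction, a minimal pure-Python sketch: keep the k highest-scoring tokens of each document. It assumes scores in a `{(doc_id, token): tf_idf}` dictionary, and `top_keywords` is a hypothetical helper. In Spark, the same idea could be expressed with a window partitioned by `id` and ordered by `tf_idf`, which is not shown here.

```python
def top_keywords(scores, k=3):
    """Return the k highest-scoring tokens per document from {(doc_id, token): tf_idf}."""
    per_doc = {}
    for (doc_id, tok), weight in scores.items():
        per_doc.setdefault(doc_id, []).append((weight, tok))
    # Sort by descending score; break ties alphabetically so results are deterministic
    return {doc_id: [tok for _, tok in sorted(pairs, key=lambda p: (-p[0], p[1]))[:k]]
            for doc_id, pairs in per_doc.items()}

example_scores = {("d1", "overflow"): 0.3, ("d1", "buffer"): 0.3,
                  ("d1", "server"): 0.1, ("d2", "injection"): 0.2}
print(top_keywords(example_scores, k=2))  # {'d1': ['buffer', 'overflow'], 'd2': ['injection']}
```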



In [None]:
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import col, udf
from pyspark.sql.types import FloatType

# L2-normalize every TF-IDF vector
normalizer = Normalizer(inputCol="tfidf_features", outputCol="normFeatures", p=2.0)
data = normalizer.transform(final_tfidf_vector)

# Use the first document returned as the query document. Without an explicit
# ordering, which document Spark returns first is not guaranteed.
first_doc_features = data.first().normFeatures

# Cosine similarity between the query document and a document vector v
# (for L2-normalized vectors this equals their dot product)
def cosine_similarity(v):
    norm_product = float(first_doc_features.norm(2)) * float(v.norm(2))
    if norm_product == 0.0:
        return 0.0  # an all-zero vector has no direction; treat it as dissimilar
    return float(first_doc_features.dot(v)) / norm_product

cosine_similarity_udf = udf(cosine_similarity, FloatType())

# Compute the cosine similarity and add it as a new column
data = data.withColumn("cosine_sim", cosine_similarity_udf(col("normFeatures")))

In [None]:
N = 5
data.orderBy(col('cosine_sim').desc()).select('id', 'cosine_sim').limit(N).show()
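Once vectors are L2-normalized, cosine similarity reduces to a dot product, and the query document will normally rank first against itself with a score of about 1.0. The pure-Python sketch below ranks the other documents instead, skipping the query. It assumes vectors stored as `{index: weight}` dictionaries; `l2_normalize` and `most_similar` are hypothetical helpers, and the toy vectors are invented for illustration.

```python
import math

def l2_normalize(vec):
    # Scale a sparse {index: weight} vector to unit length (zero vectors are left as-is)
    norm = math.sqrt(sum(w * w for w in vec.values()))
    return {i: w / norm for i, w in vec.items()} if norm else dict(vec)

def most_similar(query_id, vectors, n=5):
    """Rank documents by dot product with the query; vectors must be L2-normalized."""
    query = vectors[query_id]
    ranked = []
    for doc_id, vec in vectors.items():
        if doc_id == query_id:
            continue  # skip the trivial self-match
        dot = sum(w * vec.get(i, 0.0) for i, w in query.items())
        ranked.append((doc_id, dot))
    # Highest similarity first; ties broken by document ID
    ranked.sort(key=lambda p: (-p[1], p[0]))
    return ranked[:n]

raw = {"q": {0: 1.0, 1: 1.0}, "a": {0: 1.0}, "b": {2: 1.0}, "c": {0: 2.0, 1: 2.0}}
vectors = {doc: l2_normalize(vec) for doc, vec in raw.items()}
print(most_similar("q", vectors, n=2))  # "c" first (about 1.0), then "a" (about 0.707)
```

Document `c` has the same term proportions as the query but larger weights, so after normalization it is as similar to the query as the query is to itself.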

In [None]:
spark.stop()