<a href="https://colab.research.google.com/github/RafaelNovais/MasterAI/blob/master/Assignment_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import java.io.IOException;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;

/**
 * Clusters tweets by the geographic coordinates of the city named in their
 * {@code tweet_location} field, using k-means over (latitude, longitude).
 *
 * <p>Reads {@code tweets.csv} and {@code cities.csv} from the working directory,
 * prints up to {@link #NUM_ROWS_TO_SHOW} (tweet text, cluster index) rows, and
 * saves the fitted pipeline to {@code tweet_clustering_model}.
 */
public class TweetClustering {

    /** Number of k-means clusters to fit. */
    private static final int NUM_CLUSTERS = 100;

    /** Seed passed to KMeans so repeated runs produce the same clustering. */
    private static final long SEED = 12345L;

    /** How many (tweet_text, prediction) rows to print. */
    private static final int NUM_ROWS_TO_SHOW = 1000;

    /**
     * Runs the full load → join → cluster → print → save workflow.
     *
     * @param args unused
     * @throws IOException if the fitted model cannot be written to disk
     */
    public static void main(String[] args) throws IOException {
        SparkSession spark = SparkSession.builder()
                .appName("Tweet Clustering")
                .master("local[*]") // run locally on all available cores
                .getOrCreate();
        try {
            Dataset<Row> locatedTweets = loadLocatedTweets(spark);

            PipelineModel model = buildPipeline().fit(locatedTweets);
            Dataset<Row> predictions = model.transform(locatedTweets);

            // show() prints only 20 rows by default; pass the count explicitly so the
            // requested number of tweets is actually printed (false = no truncation).
            predictions.select("tweet_text", "prediction").show(NUM_ROWS_TO_SHOW, false);

            model.write().overwrite().save("tweet_clustering_model");
        } finally {
            // Release Spark resources even if loading, fitting or saving fails.
            spark.stop();
        }
    }

    /**
     * Loads tweets and cities and matches each tweet's location case-insensitively
     * to a city name. Returns only tweets that have numeric coordinates.
     *
     * <p>Assumes {@code cities.csv} has {@code name}, {@code latitude} and
     * {@code longitude} columns, and {@code tweets.csv} has {@code tweet_text} and
     * {@code tweet_location} — TODO confirm against the actual files.
     *
     * @param spark active session used to read the CSV files
     * @return rows of (tweet_text, latitude: double, longitude: double) with no nulls
     *     in the coordinate columns
     */
    private static Dataset<Row> loadLocatedTweets(SparkSession spark) {
        Dataset<Row> tweetsDF = spark.read().option("header", true).csv("tweets.csv");
        Dataset<Row> cityDF = spark.read().option("header", true).csv("cities.csv");

        // An inner join gives the same result as the former left join plus a
        // non-null filter, because unmatched tweets would have null coordinates anyway.
        Dataset<Row> joinedDF = tweetsDF.join(cityDF,
                functions.lower(tweetsDF.col("tweet_location"))
                        .equalTo(functions.lower(cityDF.col("name"))),
                "inner");

        // Without inferSchema every CSV column is a string, and VectorAssembler
        // rejects string inputs, so cast the coordinates explicitly. Values that are
        // not valid numbers become null and are removed by the filter below.
        return joinedDF
                .select(
                        functions.col("tweet_text"),
                        functions.col("latitude").cast("double").alias("latitude"),
                        functions.col("longitude").cast("double").alias("longitude"))
                .filter("latitude is not null AND longitude is not null");
    }

    /**
     * Builds the two-stage pipeline: assemble (latitude, longitude) into a
     * {@code features} vector, then fit k-means on that vector.
     *
     * @return an unfitted pipeline whose model adds a {@code prediction} column
     */
    private static Pipeline buildPipeline() {
        VectorAssembler assembler = new VectorAssembler()
                .setInputCols(new String[]{"latitude", "longitude"})
                .setOutputCol("features");

        KMeans kmeans = new KMeans()
                .setFeaturesCol("features")
                .setK(NUM_CLUSTERS)
                .setSeed(SEED);

        return new Pipeline().setStages(new PipelineStage[]{assembler, kmeans});
    }
}
