Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-495] Raster data source uses shared FileSystem connections which lead to race condition #1236

Merged
merged 2 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

package org.apache.spark.sql.sedona_sql.io.raster

import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
Expand All @@ -29,7 +29,6 @@ import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType

import java.io.IOException
import java.nio.file.Paths
import java.util.UUID

private[spark] class RasterFileFormat extends FileFormat with DataSourceRegister {
Expand Down Expand Up @@ -82,7 +81,7 @@ private class RasterFileWriter(savePath: String,
dataSchema: StructType,
context: TaskAttemptContext) extends OutputWriter {

private val hfs = new Path(savePath).getFileSystem(context.getConfiguration)
private val hfs = FileSystem.newInstance(new Path(savePath).toUri, context.getConfiguration)
private val rasterFieldIndex = if (rasterOptions.rasterField.isEmpty) getRasterFieldIndex else dataSchema.fieldIndex(rasterOptions.rasterField.get)

private def getRasterFieldIndex: Int = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,9 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
val buildingDataLocation: String = resourceFolder + "813_buildings_test.csv"
val smallRasterDataLocation: String = resourceFolder + "raster/test1.tiff"
private val factory = new GeometryFactory()
var hdfsURI: String = _


override def beforeAll(): Unit = {
SedonaContext.create(sparkSession)
// Set up HDFS minicluster
val baseDir = new File("./target/hdfs/").getAbsoluteFile
FileUtil.fullyDelete(baseDir)
val hdfsConf = new HdfsConfiguration
hdfsConf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath)
val builder = new MiniDFSCluster.Builder(hdfsConf)
val hdfsCluster = builder.build
hdfsURI = "hdfs://127.0.0.1:" + hdfsCluster.getNameNodePort + "/"
}

override def afterAll(): Unit = {
Expand Down Expand Up @@ -237,4 +227,18 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
}).sum
}).sum
}

/**
* Create a mini HDFS cluster and return the HDFS instance and the URI.
* @return (MiniDFSCluster, HDFS URI)
*/
def creatMiniHdfs(): (MiniDFSCluster, String) = {
val baseDir = new File("./target/hdfs/").getAbsoluteFile
FileUtil.fullyDelete(baseDir)
val hdfsConf = new HdfsConfiguration
hdfsConf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath)
val builder = new MiniDFSCluster.Builder(hdfsConf)
val hdfsCluster = builder.build
(hdfsCluster, "hdfs://127.0.0.1:" + hdfsCluster.getNameNodePort + "/")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.sedona.sql

import org.apache.commons.io.FileUtils
import org.apache.hadoop.hdfs.MiniDFSCluster
import org.apache.spark.sql.SaveMode
import org.junit.Assert.assertEquals
import org.scalatest.{BeforeAndAfter, GivenWhenThen}
Expand Down Expand Up @@ -149,12 +150,14 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
}

it("should read geotiff using binary source and write geotiff back to hdfs using raster source") {
var rasterDf = sparkSession.read.format("binaryFile").load(rasterdatalocation)
val miniHDFS: (MiniDFSCluster, String) = creatMiniHdfs()
var rasterDf = sparkSession.read.format("binaryFile").load(rasterdatalocation).repartition(3)
val rasterCount = rasterDf.count()
rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(hdfsURI + "/raster-written")
rasterDf = sparkSession.read.format("binaryFile").load(hdfsURI + "/raster-written/*")
rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(miniHDFS._2 + "/raster-written")
rasterDf = sparkSession.read.format("binaryFile").load(miniHDFS._2 + "/raster-written/*")
rasterDf = rasterDf.selectExpr("RS_FromGeoTiff(content)")
assert(rasterDf.count() == rasterCount)
miniHDFS._1.shutdown()
}
}

Expand Down
Loading