Skip to content

Commit

Permalink
feat: Modularize statismo-io for better composition (#425)
Browse files Browse the repository at this point in the history
Feature: Modularize statismo io for better composability
  • Loading branch information
Andreas-Forster committed Feb 5, 2024
2 parents 2ed8c7e + ad460ad commit e59696f
Showing 1 changed file with 79 additions and 29 deletions.
108 changes: 79 additions & 29 deletions src/main/scala/scalismo/io/statisticalmodel/StatismoIO.scala
Expand Up @@ -82,10 +82,22 @@ object StatismoIO {
canWarp: DomainWarp[D, DDomain],
vectorizer: Vectorizer[EuclideanVector[D]]
): Try[PointDistributionModel[D, DDomain]] = {

val modelOrFailure = for {
for {
h5file <- StatisticalModelIOUtils.openFileForReading(file)
model <- readStatismoPDM(h5file, modelPath)
_ <- Try(h5file.close())
} yield model
}

private[io] def readStatismoPDM[D: NDSpace, DDomain[D] <: DiscreteDomain[D]](
h5file: StatisticalModelReader,
modelPath: HDFPath
)(implicit
typeHelper: StatismoDomainIO[D, DDomain],
canWarp: DomainWarp[D, DDomain],
vectorizer: Vectorizer[EuclideanVector[D]]
): Try[PointDistributionModel[D, DDomain]] = {
val modelOrFailure = for {
mesh <- h5file.readStringAttribute(HDFPath(modelPath, "representer"), "datasetType") match {
case Success("POINT_SET") => readPointSetRepresentation(h5file, modelPath)
case Success("POLYGON_MESH") => readStandardMeshRepresentation(h5file, modelPath)
Expand All @@ -95,21 +107,16 @@ object StatismoIO {
case Success(datasetType) =>
Failure(new Exception(s"cannot read model of datasetType $datasetType"))
case Failure(t) => Failure(t)

}
meanVector <- readStandardMeanVector(h5file, modelPath)

(pcaVarianceVector, pcaBasis) <- readStandardPCAbasis(h5file, modelPath)

_ <- Try(h5file.close())
} yield {
val refVector: DenseVector[Double] = DenseVector(
mesh.pointSet.points.toIndexedSeq.flatMap(p => p.toBreezeVector.toArray).toArray
)
val meanDefVector: DenseVector[Double] = meanVector - refVector
PointDistributionModel[D, DDomain](mesh, meanDefVector, pcaVarianceVector, pcaBasis)
}

modelOrFailure
}

Expand All @@ -118,13 +125,23 @@ object StatismoIO {
file: File,
modelPath: HDFPath = HDFPath("/")
)(implicit typeHelper: StatismoDomainIO[D, DDomain]): Try[Unit] = {
val discretizedMean = model.mean.pointSet.points.toIndexedSeq.flatten(_.toArray)
for {
h5file <- StatisticalModelIOUtils.createFile(file)
_ <- writeStatismoPDM(model, h5file, modelPath)
_ <- Try(h5file.close())
} yield ()
}

private[io] def writeStatismoPDM[D: NDSpace, DDomain[D] <: DiscreteDomain[D]](
model: PointDistributionModel[D, DDomain],
h5file: HDF5Writer,
modelPath: HDFPath
)(implicit typeHelper: StatismoDomainIO[D, DDomain]): Try[Unit] = {
val discretizedMean = model.mean.pointSet.points.toIndexedSeq.flatten(_.toArray)
val variance = model.gp.variance

val pcaBasis = model.gp.basisMatrix.copy

val maybeError = for {
h5file <- StatisticalModelIOUtils.createFile(file)
_ <- h5file.writeArray[Float](HDFPath(modelPath, "model/mean"), discretizedMean.toArray.map(_.toFloat))
_ <- h5file.writeArray[Float](HDFPath(modelPath, "model/noiseVariance"), Array(0f))
_ <- h5file.writeNDArray[Float](
Expand All @@ -150,9 +167,6 @@ object StatismoIO {
_ <- h5file.createGroup(HDFPath(modelPath, "modelinfo/modelBuilder-0/parameters"))
_ <- h5file.createGroup(HDFPath(modelPath, "modelinfo/modelBuilder-0/dataInfo"))
_ <- h5file.write()
_ <- Try {
h5file.close()
}
} yield ()

maybeError
Expand Down Expand Up @@ -343,14 +357,24 @@ object StatismoIO {
file: File,
modelPath: HDFPath
): Try[Unit] = {
for {
h5file <- StatisticalModelIOUtils.createFile(file)
_ <- writeStatismoImageModel(gp, h5file, modelPath)
_ <- Try(h5file.close())
} yield ()
}

private[io] def writeStatismoImageModel[D: NDSpace, A: Vectorizer](
gp: DiscreteLowRankGaussianProcess[D, DiscreteImageDomain, A],
h5file: HDF5Writer,
modelPath: HDFPath
): Try[Unit] = {

val discretizedMean = gp.meanVector.map(_.toFloat)
val variance = gp.variance.map(_.toFloat)

val pcaBasis = gp.basisMatrix.copy.map(_.toFloat)

val maybeError = for {
h5file <- StatisticalModelIOUtils.createFile(file)
_ <- h5file.writeArray(HDFPath(modelPath, "model/mean"), discretizedMean.toArray)
_ <- h5file.writeArray(HDFPath(modelPath, "model/noiseVariance"), Array(0f))
_ <- h5file.writeNDArray(
Expand All @@ -375,9 +399,6 @@ object StatismoIO {
)
_ <- h5file.createGroup(HDFPath(modelPath, "modelinfo/modelBuilder-0/parameters"))
_ <- h5file.createGroup(HDFPath(modelPath, "modelinfo/modelBuilder-0/dataInfo"))
_ <- Try {
h5file.close()
}
} yield ()

maybeError
Expand Down Expand Up @@ -447,10 +468,19 @@ object StatismoIO {
file: java.io.File,
modelPath: HDFPath = HDFPath("/")
): Try[DiscreteLowRankGaussianProcess[D, DiscreteImageDomain, A]] = {

val modelOrFailure = for {
for {
h5file <- StatisticalModelIOUtils.openFileForReading(file)
model <- readStatismoImageModel(h5file, modelPath)
_ <- Try(h5file.close())
} yield model
}

private[io] def readStatismoImageModel[D: NDSpace: CreateStructuredPoints, A: Vectorizer](
h5file: StatisticalModelReader,
modelPath: HDFPath
): Try[DiscreteLowRankGaussianProcess[D, DiscreteImageDomain, A]] = {

val modelOrFailure = for {
representerName <- h5file.readStringAttribute(HDFPath(modelPath, "representer"), "name")
// read mesh according to type given in representer
image <- representerName match {
Expand Down Expand Up @@ -573,19 +603,30 @@ object StatismoIO {
euclidVecVectorizer: Vectorizer[EuclideanVector[D]],
scalarVectorizer: Vectorizer[S]
): Try[DiscreteLowRankGaussianProcess[D, DDomain, S]] = {
for {
h5file <- StatisticalModelIOUtils.openFileForReading(file)
model <- readIntensityModel(h5file, modelPath)
_ <- Try(h5file.close())
} yield model
}

private[io] def readIntensityModel[D: NDSpace, DDomain[D] <: DiscreteDomain[D], S: Scalar](
h5file: StatisticalModelReader,
modelPath: HDFPath
)(implicit
domainIO: StatismoDomainIO[D, DDomain],
euclidVecVectorizer: Vectorizer[EuclideanVector[D]],
scalarVectorizer: Vectorizer[S]
): Try[DiscreteLowRankGaussianProcess[D, DDomain, S]] = {

val modelOrFailure = for {
h5file <- StatisticalModelIOUtils.openFileForReading(file)
domain <- readStandardMeshRepresentation(h5file, modelPath)
meanArray <- h5file.readArrayFloat(HDFPath(modelPath, "model/mean"))
meanVector = DenseVector(meanArray.map(_.toDouble))
pcaBasisArray <- h5file.readNDArrayFloat(HDFPath(modelPath, "model/pcaBasis"))
pcaVarianceArray <- h5file.readArrayFloat(HDFPath(modelPath, "model/pcaVariance"))
pcaVarianceVector = DenseVector(pcaVarianceArray.map(_.toDouble))
pcaBasisMatrix = ndFloatArrayToDoubleMatrix(pcaBasisArray)
_ <- Try {
h5file.close()
}
} yield {

val dgp = new DiscreteLowRankGaussianProcess[D, DDomain, S](
Expand All @@ -605,14 +646,26 @@ object StatismoIO {
gp: DiscreteLowRankGaussianProcess[D, DDomain, S],
file: File,
modelPath: HDFPath = HDFPath("/")
)(implicit domainIO: StatismoDomainIO[D, DDomain]): Try[Unit] = {
for {
h5file <- StatisticalModelIOUtils.createFile(file = file)
model <- writeIntensityModel(gp, h5file, modelPath)
_ <- Try(h5file.close())
} yield model
}

private[io] def writeIntensityModel[D: NDSpace, DDomain[D] <: DiscreteDomain[D], S: Scalar](
gp: DiscreteLowRankGaussianProcess[D, DDomain, S],
h5file: HDF5Writer,
modelPath: HDFPath
)(implicit domainIO: StatismoDomainIO[D, DDomain]): Try[Unit] = {
val meanVector = gp.meanVector.toArray
val variance = gp.variance
val pcaBasis = gp.basisMatrix.copy

val representerPath = HDFPath(modelPath, "representer")

val maybeError = for {
h5file <- StatisticalModelIOUtils.createFile(file = file)
representerPath = HDFPath(modelPath, "representer")
_ <- writeRepresenterStatismov090(h5file, representerPath, gp.domain, modelPath)
_ <- h5file.writeArray[Float](HDFPath(modelPath, "model/mean"), meanVector.map(_.toFloat))
_ <- h5file.writeArray[Float](HDFPath(modelPath, "model/noiseVariance"), Array(0f))
Expand All @@ -632,9 +685,6 @@ object StatismoIO {
_ <- h5file.writeString(HDFPath(modelPath, "modelinfo/modelBuilder-0/builderName"), "scalismo")
_ <- h5file.createGroup(HDFPath(modelPath, "modelinfo/modelBuilder-0/parameters"))
_ <- h5file.createGroup(HDFPath(modelPath, "modelinfo/modelBuilder-0/dataInfo"))
_ <- Try {
h5file.close()
}
} yield ()

maybeError
Expand Down

0 comments on commit e59696f

Please sign in to comment.