In [1]:
%AddJar file:///home/jovyan/data/magpie/dist/Magpie.jar

Starting download from file:///home/jovyan/data/magpie/dist/Magpie.jar
Finished download of Magpie.jar


# Build a Hierarchical Model for Predicting Band Gap Energies
Create a model that predicts the band gap energies of materials given their compositions. 

In [2]:
import java.io.PrintWriter;

In [3]:
import magpie.data.Dataset;
import magpie.data.materials.CompositionDataset;
import magpie.data.utilities.filters.ContainsElementFilter;
import magpie.data.utilities.splitters.{MultipleElementGroupsSplitter, PredictedClassIntervalSplitter};
import magpie.data.materials.util.PropertyLists;
import magpie.models.BaseModel;
import magpie.models.regression.{SplitRegression, WekaRegression, RandomGuessRegression};
import magpie.models.classification.WekaClassifier;
import magpie.optimization.rankers.TargetEntryRanker;

## Load in the Data
Read the data from disk (this version of Magpie automatically removes duplicate entries when reading) and remove materials that contain noble gases.
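The noble-gas filter below can be sketched as follows. This sketch assumes a composition is a map from element symbol to amount; Magpie uses its own entry classes, so it is illustrative only. An entry is kept only if it contains none of the excluded elements, which is the behaviour the notebook requests with `setExclude(true)`. The list matches the one used in this notebook, which does not include radon (Rn).

```scala
object NobleGasFilterSketch {
  // Same elements as the notebook's setElementList call
  val excluded: Set[String] = Set("He", "Ne", "Ar", "Kr", "Xe")

  // Assumed representation for illustration: element symbol -> amount
  type Composition = Map[String, Double]

  // Keep an entry only if it contains none of the excluded elements
  def keep(c: Composition): Boolean = c.keySet.intersect(excluded).isEmpty

  // Drop every entry that contains at least one excluded element
  def filterOut(entries: Seq[Composition]): Seq[Composition] = entries.filter(keep)
}
```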

In [4]:
val data = new CompositionDataset()

In [5]:
data.importText("/home/jovyan/data/datasets/bandgap.data", null);
println(s"Read in ${data.NEntries} entries")

Read in 25212 entries


In [6]:
val filter = new ContainsElementFilter();

In [7]:
filter.setElementList(Array[String]("He","Ne","Ar","Kr","Xe"))

In [8]:
filter.setExclude(true)

In [9]:
filter.filter(data)
println(s"Trimmed data to ${data.NEntries} entries.")

Trimmed data to 25186 entries.


In [10]:
data.setTargetProperty("bandgap", false);

## Compute the Representation
Generate the attributes (features) that serve as inputs to the ML model.
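Magpie's elemental-property attributes are statistics of each elemental property taken over the elements in a composition, such as the fraction-weighted mean and mean absolute deviation, plus the minimum, maximum, and range. The sketch below computes those statistics for a single property. The composition and property tables are hypothetical inputs, and Magpie's own attribute generator computes more than this.

```scala
object ElementalStatsSketch {
  final case class Stats(mean: Double, meanAbsDev: Double, min: Double, max: Double) {
    def range: Double = max - min
  }

  // composition: element -> amount (need not sum to 1)
  // prop: element -> value of one elemental property (hypothetical lookup table)
  def stats(composition: Map[String, Double], prop: Map[String, Double]): Stats = {
    val total = composition.values.sum
    // (property value, atomic fraction) for each element
    val fracs = composition.toSeq.map { case (el, amt) => (prop(el), amt / total) }
    val mean = fracs.map { case (v, f) => f * v }.sum
    val mad = fracs.map { case (v, f) => f * math.abs(v - mean) }.sum
    val values = fracs.map(_._1)
    Stats(mean, mad, values.min, values.max)
  }
}
```

For a hypothetical compound AB3 where the property is 1 for A and 3 for B, the fractions are 0.25 and 0.75, so the mean is 2.5 and the mean absolute deviation is 0.25 × 1.5 + 0.75 × 0.5 = 0.75.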

In [11]:
data.setDataDirectory("/home/jovyan/data/magpie/Lookup Data")

In [12]:
for (prop <- PropertyLists.getPropertySet("general")) {
    data.addElementalProperty(prop);
}

In [13]:
data.generateAttributes()
println(s"Generated ${data.NAttributes} attributes")

Generated 145 attributes


## Create the Machine Learning Models
We create two separate models: a single ensemble of decision trees, and a hierarchical collection of several different machine learning models.

### Create the Single Model
The single model is an ensemble of Reduced Error Pruning Trees (REPTrees), built with Weka's `RandomSubSpace` meta-learner.
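For regression, Weka's `RandomSubSpace` trains each ensemble member on a randomly chosen subset of the attributes and averages the members' predictions. The sketch below shows that random subspace idea in plain Scala. To keep it short, it swaps in a 1-nearest-neighbour regressor as the base learner instead of a REPTree, and the ensemble size, subspace fraction, and seed are illustrative parameters, not the settings used in this notebook.

```scala
import scala.util.Random

object RandomSubspaceSketch {
  type Row = Array[Double]

  // Stand-in base learner (NOT a REPTree): 1-nearest-neighbour regression
  // that measures distance using only the feature indices in `feats`
  def nearestNeighbour(train: Seq[(Row, Double)], feats: Seq[Int])(x: Row): Double =
    train.minBy { case (r, _) => feats.map(f => math.pow(r(f) - x(f), 2)).sum }._2

  // Build `nModels` members, each restricted to a random subset of
  // round(frac * nFeatures) features, and average their predictions
  def fit(train: Seq[(Row, Double)], nModels: Int, frac: Double, seed: Long): Row => Double = {
    val rng = new Random(seed)
    val nFeats = train.head._1.length
    val k = math.max(1, math.round(frac * nFeats).toInt)
    val members: Seq[Row => Double] = Seq.fill(nModels) {
      val feats = rng.shuffle((0 until nFeats).toList).take(k)
      (x: Row) => nearestNeighbour(train, feats)(x)
    }
    (x: Row) => members.map(m => m(x)).sum / members.size
  }
}
```

Because each member sees a different slice of the features, the members make partly independent errors, and averaging them tends to reduce variance compared with a single learner.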

In [14]:
val singleModel = new WekaRegression("meta.RandomSubSpace",
        Array[String]("-num-slots", "0", "-W", "weka.classifiers.trees.REPTree", "--", "-N", "7")
    );

### Create the Hierarchical Model
The model we converged on in the paper trains different regression models on different parts of the dataset. During training, the model first trains a classifier to predict the range of the band gap: <=0, 0-1.5, 1.5-3, or >3 eV. We then train a regression model on all entries predicted to be in the '<=0 eV' class. For entries assigned any of the other labels, we successively separate out the halogen-containing entries, then the chalcogen-containing entries, then the pnictogen-containing entries; together with the entries that contain none of these, this yields 4 separate subsets. We then train a regression model on each subset.
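The four band gap classes can be written as intervals bounded by the edges 0, 1.5, and 3 eV. In this sketch, a band gap's class index is the number of edges it exceeds. Which class a value lying exactly on an edge falls into is an assumption of the sketch, not something taken from Magpie's `PredictedClassIntervalSplitter`.

```scala
object BandGapClassSketch {
  // Class edges (eV) from the notebook: 0, 1.5 and 3 -> 4 classes
  val edges: Seq[Double] = Seq(0.0, 1.5, 3.0)

  // Class index = number of edges strictly below the band gap:
  // <=0 -> 0, (0, 1.5] -> 1, (1.5, 3] -> 2, >3 -> 3 (boundary handling assumed)
  def classOf(bandGap: Double): Int = edges.count(e => bandGap > e)
}
```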

When evaluating the model, we first assign a label to the entry using the classifier and then select the appropriate regression model based on that label and the composition of the entry. 
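The prediction-time routing can be sketched as a function that takes a predicted class and the set of elements in a composition and returns a key naming the regression model that handles the entry. This sketch assumes the element groups are checked in the order they are added (halogens, then chalcogens, then pnictogens), with entries containing none of them going to a final "other" subset. The `(Int, Option[Int])` key is a hypothetical representation, not Magpie's.

```scala
object RoutingSketch {
  // Element groups in the order the notebook adds them: halogens, chalcogens, pnictogens
  val groups: Seq[Set[String]] = Seq(
    Set("F", "Cl", "Br", "I", "At"),
    Set("O", "S", "Se", "Te", "Po"),
    Set("N", "P", "As", "Sb", "Bi")
  )

  // Index of the first group the composition contains; groups.size ("other") if none
  def groupOf(elements: Set[String]): Int = {
    val i = groups.indexWhere(g => g.exists(elements.contains))
    if (i < 0) groups.size else i
  }

  // Class 0 (<=0 eV) uses one model; other classes are split further by element group
  def modelKey(predictedClass: Int, elements: Set[String]): (Int, Option[Int]) =
    if (predictedClass == 0) (0, None) else (predictedClass, Some(groupOf(elements)))
}
```

For example, an oxide that also contains arsenic lands in the chalcogen subset, because chalcogens are checked before pnictogens.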

The model looks something like this:
<img src="./figures/band_gap_model_fig.png" style="width: 450px;"/>

In [15]:
val hierModel = new SplitRegression();

Add the first level of splitting: a classifier that predicts the range of the band gap energy.

In [16]:
val clfrSplitter = new PredictedClassIntervalSplitter();
val clfr = new WekaClassifier("meta.RandomSubSpace",
        Array[String]("-num-slots", "0", "-W", "weka.classifiers.trees.REPTree", "--", "-N", "7")
    );
clfrSplitter.setClassifier(clfr);
clfrSplitter.setEdges(Array[Double](0, 1.5, 3))
hierModel.setPartitioner(clfrSplitter);

Add the second level, which splits entries based on the groups of elements they contain.

In [17]:
val secondLevel = new SplitRegression();

In [18]:
val elemSplitter = new MultipleElementGroupsSplitter();
elemSplitter.addElementGroup("F Cl Br I At");
elemSplitter.addElementGroup("O S Se Te Po");
elemSplitter.addElementGroup("N P As Sb Bi");
secondLevel.setPartitioner(elemSplitter);

In [19]:
secondLevel.setGenericModel(singleModel);

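Assemble the full model: entries predicted to be in class 0 (band gap <=0 eV) go directly to the single model, and entries in every other class go to the element-group splitter.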

In [20]:
hierModel.setGenericModel(secondLevel);
hierModel.setModel(0, singleModel);
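The assembled model can be pictured as a tree. This assumes that `SplitRegression` gives each of the three non-zero band gap classes its own copy of `secondLevel`, and that the element-group splitter produces four subsets (three groups plus the remainder). Under those assumptions, the tree holds 1 + 3 × 4 = 13 regression models in addition to the classifier. The hypothetical `Node` types below only encode that structure and count its leaves.

```scala
object ModelTreeSketch {
  sealed trait Node
  final case class Leaf(name: String) extends Node
  final case class Split(splitter: String, children: Seq[Node]) extends Node

  // Number of regression models (leaves) in the tree
  def leaves(n: Node): Int = n match {
    case Leaf(_)            => 1
    case Split(_, children) => children.map(leaves).sum
  }

  // Second level: 3 element groups plus "none of the above" -> 4 regression models
  def secondLevel: Node = Split("element groups", Seq.fill(4)(Leaf("REPTree ensemble")))

  // First level: 4 band gap classes; class 0 uses the single model directly
  val hierarchical: Node =
    Split("band gap class", Leaf("REPTree ensemble") +: Seq.fill(3)(secondLevel))
}
```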

## Test the Model
Our overall goal is to find new materials with band gap energies suitable for solar cells. To test how well our model identifies such materials, we use a cross-validation test: we train the model on 90% of the original dataset and evaluate its ability to identify materials with band gaps between 0.9 and 1.7 eV in the remaining 10%. Specifically, we select the 30 test materials whose predicted band gaps are closest to the center of this range (1.3 eV) and count how many of them actually have a band gap in the target range. We repeat this test 100 times and compare the hierarchical model against the single model and against random guessing.
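The scoring rule can be written without any Magpie classes. This sketch assumes the predicted and measured band gaps of the test set are given as two aligned sequences. It ranks entries by the distance of their prediction from 1.3 eV and counts how many of the top 30 have a measured band gap in [0.9, 1.7] eV. The default parameter values mirror the numbers in the text.

```scala
object Top30ScoreSketch {
  // predicted(i) and measured(i) are the band gaps (eV) of test entry i
  def score(predicted: Seq[Double], measured: Seq[Double],
            target: Double = 1.3, nSelect: Int = 30,
            lo: Double = 0.9, hi: Double = 1.7): Int = {
    // Indices ordered by how close the prediction is to the target
    val ranked = predicted.indices.sortBy(i => math.abs(predicted(i) - target))
    // Count selected entries whose measured band gap is in the target range
    ranked.take(nSelect).count(i => measured(i) >= lo && measured(i) <= hi)
  }
}
```

A score of 30 means every selected material is a true hit; a random ranking gives a baseline to compare against.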

In [21]:
val test_size = 0.1;
val n_repeats = 100;

In [22]:
val randomGuess = new RandomGuessRegression();

In [23]:
val ranker = new TargetEntryRanker(1.3);
ranker.setMaximizeFunction(false);

In [24]:
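/**
  * Run one cross-validation test: train `model` on `trainData`, predict `testData`,
  * and count how many of the 30 entries predicted closest to 1.3 eV
  * have a measured band gap in the 0.9-1.7 eV target range
  */
def runTest(trainData : Dataset, testData : Dataset, model : BaseModel) : Int = {
    model.train(trainData);
    model.run(testData);

    // Rank the test entries by the distance of their predicted band gap
    // (second argument false: use predicted, not measured, values) from the center of the range
    val ranks = ranker.rankEntries(testData, false);

    // Count how many of the top 30 have a measured band gap within the desired range
    var score = 0;
    for (i <- 0 until 30) {
        val x = testData.getEntry(ranks(i)).getMeasuredClass();
        if (x >= 0.9 && x <= 1.7) {
            score = score + 1;
        }
    }

    score
}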

In [25]:
val fp = new PrintWriter("cv-results.csv");
fp.println("random,single,hierarchical");
for (i <- 1 to n_repeats) {
    print(s"\rIteration ${i}/${n_repeats}")
    // Get the training and test sets (test_size of the entries go to the test set)
    val trainData = data.clone();
    val testData = trainData.getRandomSubset(test_size);

    // Train and score each model on the same split: random guessing, single model, hierarchical model
    fp.print(s"${runTest(trainData, testData, randomGuess)},")
    fp.print(s"${runTest(trainData, testData, singleModel)},")
    fp.println(s"${runTest(trainData, testData, hierModel)}")
    fp.flush();
}
fp.close()

Iteration 100/100
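Each row of `cv-results.csv` holds one iteration's scores for the random, single, and hierarchical models. One way to summarize the file is to average each column. The sketch below does this and assumes the file has the header written above and a numeric value in every column.

```scala
import scala.io.Source

object CvSummarySketch {
  // Mean score per column, given the lines of cv-results.csv (header first)
  def columnMeans(lines: Seq[String]): Map[String, Double] = {
    val header = lines.head.split(",").map(_.trim)
    val rows = lines.tail.filter(_.trim.nonEmpty).map(_.split(",").map(_.trim.toDouble))
    header.indices.map(j => header(j) -> rows.map(_(j)).sum / rows.size).toMap
  }

  // Convenience wrapper that reads the CSV from disk
  def fromFile(path: String): Map[String, Double] = {
    val src = Source.fromFile(path)
    try columnMeans(src.getLines().toList) finally src.close()
  }
}
```

Calling `CvSummarySketch.fromFile("cv-results.csv")` gives the average number of hits (out of 30) for each model, which can be compared directly.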

## Save the Model
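Save the model using Java serialization so that it can be reused in a later script. An empty copy of the dataset (same settings, no entries) is saved alongside it as a template for generating the same attributes for new compositions.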
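`saveState` relies on Java serialization. The sketch below shows a generic serialization round trip with `ObjectOutputStream` and `ObjectInputStream`, using a hypothetical `ToyModel` case class as a stand-in for a trained model. It does not show how Magpie itself reloads a saved model.

```scala
import java.io.{FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}

object SerializationSketch {
  // Hypothetical stand-in for a trained model; case classes are Serializable by default
  final case class ToyModel(edges: List[Double], groups: List[String])

  // Write any serializable object to `path`
  def save(obj: AnyRef, path: String): Unit = {
    val out = new ObjectOutputStream(new FileOutputStream(path))
    try out.writeObject(obj) finally out.close()
  }

  // Read an object back from `path` and cast it to the expected type
  def load[T](path: String): T = {
    val in = new ObjectInputStream(new FileInputStream(path))
    try in.readObject().asInstanceOf[T] finally in.close()
  }
}
```

Java serialization ties the saved file to the class definitions, so the loading script needs the same Magpie version on its classpath.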

In [26]:
hierModel.saveState("bandgap-model-template.obj")

In [27]:
data.emptyClone().saveState("bandgap-model-dataset-template.obj")