In [1]:
import ROOT
from ROOT import TMVA, TFile, TTree, TCut
from subprocess import call
from os.path import isfile

import torch
from torch import nn

Welcome to JupyROOT 6.26/04


In [2]:
TMVA.Tools.Instance()
TMVA.PyMethodBase.PyInitialize()

In [3]:
output = TFile.Open('TMVA.root', 'RECREATE')
factory = TMVA.Factory('TMVAClassification', output,
                       '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')


In [4]:
if not isfile('tmva_class_example.root'):
    call(['curl', '-L', '-O', 'http://root.cern.ch/files/tmva_class_example.root'])


In [5]:
data = TFile.Open('tmva_class_example.root')
signal = data.Get('TreeS')
background = data.Get('TreeB')

In [6]:
dataloader = TMVA.DataLoader('dataset')

In [7]:
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

In [8]:
dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)
dataloader.PrepareTrainingAndTestTree(TCut(''),
                                      'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')


DataSetInfo              : [dataset] : Added class "Signal"
                         : Add Tree TreeS of type Signal with 6000 events
DataSetInfo              : [dataset] : Added class "Background"
                         : Add Tree TreeB of type Background with 6000 events
                         : Dataset[dataset] : Class index : 0  name : Signal
                         : Dataset[dataset] : Class index : 1  name : Background


In [9]:
#model
model = nn.Sequential()
model.add_module('linear_1', nn.Linear(in_features=4, out_features=64))
model.add_module('relu', nn.ReLU())
model.add_module('linear_2', nn.Linear(in_features=64, out_features=4))
model.add_module('relu', nn.ReLU())
model.add_module('linear_3', nn.Linear(in_features=4, out_features=64))
model.add_module('relu', nn.ReLU())
model.add_module('linear_4', nn.Linear(in_features=64, out_features=2))
model.add_module('softmax', nn.Softmax(dim=1))

In [10]:
# Construct loss function and Optimizer.
loss = torch.nn.MSELoss()
optimizer = torch.optim.SGD


In [11]:
# Define train function
def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None

    for epoch in range(num_epochs):
        # Training Loop
        # Set to train mode
        model.train()
        running_train_loss = 0.0
        running_val_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()
            output = model(X)
            train_loss = criterion(output, y)
            train_loss.backward()
            trainer.step()

            # print train statistics
            running_train_loss += train_loss.item()
            if i % 32 == 31:    # print every 32 mini-batches
                print("[{}, {}] train loss: {:.3f}".format(epoch+1, i+1, running_train_loss / 32))
                running_train_loss = 0.0

        if schedule:
            schedule(optimizer, epoch, schedulerSteps)

        # Validation Loop
        # Set to eval mode
        model.eval()
        with torch.no_grad():
            for i, (X, y) in enumerate(val_loader):
                output = model(X)
                val_loss = criterion(output, y)
                running_val_loss += val_loss.item()

            curr_val = running_val_loss / len(val_loader)
            if save_best:
                if best_val==None:
                    best_val = curr_val
                best_val = save_best(model, curr_val, best_val)

            # print val statistics per epoch
            print("[{}] val loss: {:.3f}".format(epoch+1, curr_val))
            running_val_loss = 0.0

    print("Finished Training on {} Epochs!".format(epoch+1))

    return model



In [12]:
# Define predict function
def predict(model, test_X, batch_size=32):
    # Set to eval mode
    model.eval()

    test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    predictions = []
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            X = data[0]
            outputs = model(X)
            predictions.append(outputs)
        preds = torch.cat(predictions)

    return preds.numpy()


In [13]:
load_model_custom_objects = {"optimizer": optimizer, "criterion": loss, "train_func": train, "predict_func": predict}


In [14]:
m = torch.jit.script(model)
torch.jit.save(m, "model.pt")
print(m)


RecursiveScriptModule(
  original_name=Sequential
  (linear_1): RecursiveScriptModule(original_name=Linear)
  (relu): RecursiveScriptModule(original_name=ReLU)
  (linear_2): RecursiveScriptModule(original_name=Linear)
  (linear_3): RecursiveScriptModule(original_name=Linear)
  (linear_4): RecursiveScriptModule(original_name=Linear)
  (softmax): RecursiveScriptModule(original_name=Softmax)
)


In [15]:
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyTorch, 'PyTorch',
                   'H:!V:VarTransform=D,G:FilenameModel=model.pt:NumEpochs=20:BatchSize=32')



custom objects for loading model :  {'optimizer': <class 'torch.optim.sgd.SGD'>, 'criterion': MSELoss(), 'train_func': <function train at 0x7efe66c95dd0>, 'predict_func': <function predict at 0x7efe66c978c0>}


<cppyy.gbl.TMVA.MethodPyTorch object at 0x562a63ceab40>

Factory                  : Booking method: [1mFisher[0m
                         : 
Fisher                   : [dataset] : Create Transformation "D" with events from all classes.
                         : 
                         : Transformation, Variable selection : 
                         : Input : variable 'var1' <---> Output : variable 'var1'
                         : Input : variable 'var2' <---> Output : variable 'var2'
                         : Input : variable 'var3' <---> Output : variable 'var3'
                         : Input : variable 'var4' <---> Output : variable 'var4'
Fisher                   : [dataset] : Create Transformation "G" with events from all classes.
                         : 
                         : Transformation, Variable selection : 
                         : Input : variable 'var1' <---> Output : variable 'var1'
                         : Input : variable 'var2' <---> Output : variable 'var2'
                         : Input : variable 'v

In [16]:
# Run training, test and evaluation
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

RecursiveScriptModule(
  original_name=Sequential
  (linear_1): RecursiveScriptModule(original_name=Linear)
  (relu): RecursiveScriptModule(original_name=ReLU)
  (linear_2): RecursiveScriptModule(original_name=Linear)
  (linear_3): RecursiveScriptModule(original_name=Linear)
  (linear_4): RecursiveScriptModule(original_name=Linear)
  (softmax): RecursiveScriptModule(original_name=Softmax)
)
[1, 32] train loss: 0.258
[1, 64] train loss: 0.241
[1, 96] train loss: 0.232
[1, 128] train loss: 0.223
[1, 160] train loss: 0.213
[1, 192] train loss: 0.198
[1] val loss: 0.193
[2, 32] train loss: 0.195
[2, 64] train loss: 0.170
[2, 96] train loss: 0.167
[2, 128] train loss: 0.156
[2, 160] train loss: 0.151
[2, 192] train loss: 0.145
[2] val loss: 0.148
[3, 32] train loss: 0.152
[3, 64] train loss: 0.127
[3, 96] train loss: 0.136
[3, 128] train loss: 0.129
[3, 160] train loss: 0.130
[3, 192] train loss: 0.129
[3] val loss: 0.133
[4, 32] train loss: 0.138
[4, 64] train loss: 0.113
[4, 96] train los

0%, time left: unknown
7%, time left: 0 sec
13%, time left: 0 sec
19%, time left: 0 sec
25%, time left: 0 sec
32%, time left: 0 sec
38%, time left: 0 sec
44%, time left: 0 sec
50%, time left: 0 sec
57%, time left: 0 sec
63%, time left: 0 sec
69%, time left: 0 sec
75%, time left: 0 sec
82%, time left: 0 sec
88%, time left: 0 sec
94%, time left: 0 sec
0%, time left: unknown
7%, time left: 0 sec
13%, time left: 0 sec
19%, time left: 0 sec
25%, time left: 0 sec
32%, time left: 0 sec
38%, time left: 0 sec
44%, time left: 0 sec
50%, time left: 0 sec
57%, time left: 0 sec
63%, time left: 0 sec
69%, time left: 0 sec
75%, time left: 0 sec
82%, time left: 0 sec
88%, time left: 0 sec
94%, time left: 0 sec


In [17]:
# Plot ROC Curves
roc = factory.GetROCCurve(dataloader)
roc.SaveAs('ROC_ClassificationPyTorch.png')

Info in <TCanvas::Print>: png file ROC_ClassificationPyTorch.png has been created
