In [1]:
use strict;
use warnings;
use Data::Dump qw(dump);
use List::Util qw(zip min max sum);
use sml;
use AI::MXNet qw(mx);

In [2]:
sub evaluate_algorithm_train_test_split {
    my ($self, $dataset, $algorithm, %args) = ((splice @_, 0, 3), split =>0.6, metric => undef, @_);
    my ($train, $test) = train_test_split($dataset, split => $args{split});
    my $train_x = $train->slice_axis(axis => 1, begin => 0, end => $train->shape->[1] - 1);
    my $train_y = $train->slice_axis(axis => 1, begin => $train->shape->[1] - 1, end => $train->shape->[1]);

    my $test_x  = $test->slice_axis(axis => 1, begin => 0, end => $test->shape->[1] - 1);
    my $test_y  = $test->slice_axis(axis => 1, begin => $test->shape->[1] - 1, end => $test->shape->[1]);

    my $predicted = $algorithm->('sml', $train, $test_x, @_);

    my $actual    = $test_y->reshape([$test_y->shape->[0]]);  
    my $score;
    if (defined $args{metric}) {
        if ($args{metric} =~ /accuracy/i) {
            $score = sml->accuracy_metric($actual, $predicted);
        } elsif ($args{metric} =~ /rmse/i) {
            $score = sml->rmse_metric($actual, $predicted);
        }
    } else {
        # Heurística: si los valores reales parecen flotantes, usar RMSE
        $score = (grep { $_ =~ /\d+\.\d+/ } @$actual)
               ? sml->rmse_metric($actual, $predicted)
               : sml->accuracy_metric($actual, $predicted);
    }

    return wantarray ? ($score, $train, $test, $actual, $predicted) : $score;
}

sml->add_to_class('evaluate_algorithm_train_test_split', \&evaluate_algorithm_train_test_split);

*sml::evaluate_algorithm_train_test_split

In [3]:
mx->random->seed(1);
my $filename = '../data/pima-indians-diabetes.csv';
my $dataset = sml->load_csv($filename);
$dataset = mx->nd->array($dataset);
my $split = 0.6;
my ($accuracy, $train, $test, $actual, $predicted) = 
sml-> evaluate_algorithm_train_test_split($dataset, \&{'sml::zero_rule_algorithm_classification'}, split => $split, metric => 'accuracy');
print $accuracy;

66.56

1

In [4]:
my ($unique, $matrix) = sml->confusion_matrix($actual, $predicted);
sml->print_confusion_matrix($unique, $matrix)

A/P<AI::MXNet::NDArray 1 @cpu(0)> <AI::MXNet::NDArray 1 @cpu(0)>
A/P [0 1]
[
 [  0 205   0]
 [  1 103   0]
]


1

In [13]:
sub evaluate_algorithm_cross_validation_split {
    my ($self, $dataset, $algorithm, %args) = splice(@_, 0, 3);
    $args{n_folds} //= 10;
    $args{metric}  //= undef;

    my $folds_nd    = sml->cross_validation_split($dataset, n_folds => $args{n_folds});
    my $n_folds     = $folds_nd->shape->[0];
    my $fold_size   = $folds_nd->shape->[1];
    my $n_features  = $folds_nd->shape->[2];

    my (@scores, @train_losses, @test_losses, @actuals, @preds);

    for my $i (0 .. $n_folds - 1) {
        # test_nd y train_nd contiguos
        my $test_nd = $folds_nd->slice_axis(axis=>0,begin=>$i,end=>$i+1)
                              ->squeeze(axis=>0)->copy;
        my @train_parts;
        for my $j (0 .. $n_folds - 1) {
            next if $j == $i;
            push @train_parts,
                   $folds_nd->slice_axis(axis=>0,begin=>$j,end=>$j+1)
                            ->squeeze(axis=>0)->copy;
        }
        my $train_nd = mx->nd->concat(@train_parts, { dim=>0 })->copy;

        # test_input: todas las columnas menos la última
        my $test_input = $n_features>1
          ? $test_nd->slice_axis(axis=>1, begin=>0, end=>$n_features-1)
          : $test_nd;

        # ejecutar
        my ($pred_nd, $train_loss, $test_loss) =
            $algorithm->('sml', $train_nd, $test_input, %args);

        # extraer etiqueta real
        my $actual_nd = $test_nd->slice_axis(axis=>1,
                             begin=>$n_features-1, end=>$n_features)
                                ->copy;

        # aplanar si es 2D->[N,1]
        if ($actual_nd->ndim==2 && $actual_nd->shape->[1]==1) {
            $actual_nd = $actual_nd->reshape([-1]);
        }
        if ($pred_nd->ndim==2 && ($pred_nd->shape->[1]//0)==1) {
            $pred_nd   = $pred_nd  ->reshape([-1]);
        }

        # calcular métrica, con captura de fallo en aspdl
        my $score;
        if (defined $args{metric}) {
            if ($args{metric}=~ /accuracy/i) {
                $score = sml->accuracy_metric($actual_nd, $pred_nd);
            }
            else {
                $score = sml->rmse_metric($actual_nd, $pred_nd);
            }
        }
        else {
            my $aref;
            eval { $aref = $actual_nd->aspdl->list };
            if ($@) {
                die <<"ERR";
*** ASPDL ERROR AT FOLD $i ***
folds_nd shape   = [${n_folds}x${fold_size}x${n_features}]
train_nd shape   = [@{\$train_nd->shape}]
test_nd shape    = [@{\$test_nd->shape}]
test_input shape = [@{\$test_input->shape}]
raw actual_nd    = [@{\$actual_nd->shape}] (size=@{\$actual_nd->size})
exception        = $@
ERR
            }
            $score = (grep{/\\d+\\.\\d+/}@{$aref})
                   ? sml->rmse_metric($actual_nd, $pred_nd)
                   : sml->accuracy_metric($actual_nd, $pred_nd);
        }

        push @scores,       $score;
        push @train_losses, $train_loss;
        push @test_losses,  $test_loss;
        push @actuals,      $actual_nd;
        push @preds,        $pred_nd;
    }

    my $train_losses_tensor = mx->nd->array(\@train_losses);
    my $test_losses_tensor  = mx->nd->array(\@test_losses);
    my $actuals_tensor      = mx->nd->concat(@actuals, { dim=>0 });
    my $preds_tensor        = mx->nd->concat(@preds,   { dim=>0 });

    return wantarray
      ? (\@scores, $train_losses_tensor, $test_losses_tensor, $actuals_tensor, $preds_tensor)
      : \@scores;
}

sml->add_to_class('evaluate_algorithm_cross_validation_split', \&evaluate_algorithm_cross_validation_split);



*sml::evaluate_algorithm_cross_validation_split

Warning: Subroutine evaluate_algorithm_cross_validation_split redefined at reply input line 1.

Subroutine sml::evaluate_algorithm_cross_validation_split redefined at /usr/local/share/perl5/5.34/x86_64-linux-thread-multi/sml.pm line 16.


In [14]:
mx->random->seed(1);
$filename = '../data/pima-indians-diabetes.csv';
$dataset = sml->load_csv($filename);
$dataset = mx->nd->array($dataset);
my $n_fold = 5;
my ($scores, $train_losses, $test_losses, $actuals, $predictions) = sml->evaluate_algorithm_cross_validation_split(
    $dataset,
    \&sml::zero_rule_algorithm_classification,
    n_folds => $n_fold,
    metric  => 'accuracy');

Error: TBlob.get_with_shape: Check failed: this->shape_.Size() == static_cast<size_t>(shape.Size()) (756 vs. 684) : new and old shape do not match total elements
Stack trace:
  File "/opt/softwares/apache-mxnet-src-1.9.1-incubating/include/mxnet/././tensor_blob.h", line 310
 at /usr/local/share/perl5/5.34/AI/MXNet/Base.pm line 303.
	AI::MXNet::Base::check_call(-1) called at /usr/local/share/perl5/5.34/AI/MXNet/NDArray.pm line 329
	AI::MXNet::NDArray::aspdl(AI::MXNet::NDArray=HASH(0x563c14f5cbb8)) called at /usr/local/share/perl5/5.34/x86_64-linux-thread-multi/sml.pm line 502
	main::__ANON__("sml", AI::MXNet::NDArray=HASH(0x563c137db588), AI::MXNet::NDArray=HASH(0x563c134556b8), "metric", undef, "n_folds", 10) called at reply input line 32
	main::evaluate_algorithm_cross_validation_split(undef, undef, undef, "n_folds", 5, "metric", "accuracy") called at reply input line 6
	Eval::Closure::Sandbox_668::__ANON__() called at /usr/local/share/perl5/5.34/Reply/Plugin/Defaults.pm line 71
	Reply::Plugin::Defaults::execute(Reply::Plugin::Defaults=HASH(0x563c108a7eb8), CODE(0x563c144355d0), CODE(0x563c1516f2e0)) called at /usr/local/share/perl5/5.34/Reply.pm line 217
	Reply::_wrapped_plugin(Reply=HASH(0x563c1092db78), ARRAY(0x563c109cc340), "execute", CODE(0x563c1516f2e0)) called at /usr/local/share/perl5/5.34/Reply.pm line 215
	Reply::__ANON__(CODE(0x563c1516f2e0)) called at /usr/local/share/perl5/5.34/Reply/Plugin/IPerl.pm line 28
	Reply::Plugin::IPerl::__ANON__() called at /usr/share/perl5/Capture/Tiny.pm line 382
	eval {...} called at /usr/share/perl5/Capture/Tiny.pm line 382
	Capture::Tiny::_capture_tee(1, 1, 0, 0, CODE(0x563c142cc738)) called at /usr/local/share/perl5/5.34/Reply/Plugin/IPerl.pm line 29
	Reply::Plugin::IPerl::execute(Reply::Plugin::IPerl=HASH(0x563c1096f890), CODE(0x563c1545f938), CODE(0x563c1516f2e0)) called at /usr/local/share/perl5/5.34/Reply.pm line 217
	Reply::_wrapped_plugin(Reply=HASH(0x563c1092db78), "execute", CODE(0x563c1516f2e0)) called at /usr/local/share/perl5/5.34/Reply.pm line 174
	Reply::_eval(Reply=HASH(0x563c1092db78), "\x{a}#line 1 \"reply input\"\x{a}mx->random->seed(1);\x{a}\$filename = '../d"...) called at /usr/local/share/perl5/5.34/Reply.pm line 66
	Reply::try {...} () called at /usr/share/perl5/vendor_perl/Try/Tiny.pm line 102
	eval {...} called at /usr/share/perl5/vendor_perl/Try/Tiny.pm line 93
	Try::Tiny::try(CODE(0x563c1542de20), Try::Tiny::Catch=REF(0x563c146ee8a8)) called at /usr/local/share/perl5/5.34/Reply.pm line 71
	Reply::step(Reply=HASH(0x563c1092db78), "mx->random->seed(1);\x{a}\$filename = '../data/pima-indians-diabet"..., 0) called at /usr/local/share/perl5/5.34/Devel/IPerl/Kernel/Backend/Reply.pm line 48
	Devel::IPerl::Kernel::Backend::Reply::__ANON__() called at /usr/share/perl5/Capture/Tiny.pm line 382
	eval {...} called at /usr/share/perl5/Capture/Tiny.pm line 382
	Capture::Tiny::_capture_tee(1, 1, 0, 0, CODE(0x563c12431448)) called at /usr/local/share/perl5/5.34/Devel/IPerl/Kernel/Backend/Reply.pm line 49
	Devel::IPerl::Kernel::Backend::Reply::run_line(Devel::IPerl::Kernel::Backend::Reply=HASH(0x563c0de29280), "mx->random->seed(1);\x{a}\$filename = '../data/pima-indians-diabet"...) called at /usr/local/share/perl5/5.34/Devel/IPerl/Kernel/Callback/REPL.pm line 42
	Devel::IPerl::Kernel::Callback::REPL::execute(Devel::IPerl::Kernel::Callback::REPL=HASH(0x563c0e1af600), Devel::IPerl::Kernel=HASH(0x563c0d1a13a8), Devel::IPerl::Message::ZMQ=HASH(0x563c14425460)) called at (eval 69) line 6
	Devel::IPerl::Kernel::Callback::REPL::execute(Devel::IPerl::Kernel::Callback::REPL=HASH(0x563c0e1af600), Devel::IPerl::Kernel=HASH(0x563c0d1a13a8), Devel::IPerl::Message::ZMQ=HASH(0x563c14425460)) called at /usr/local/share/perl5/5.34/Devel/IPerl/Kernel/Callback/REPL.pm line 156
	Devel::IPerl::Kernel::Callback::REPL::msg_execute_request(Devel::IPerl::Kernel::Callback::REPL=HASH(0x563c0e1af600), Devel::IPerl::Kernel=HASH(0x563c0d1a13a8), Devel::IPerl::Message::ZMQ=HASH(0x563c14425460), ZMQ::FFI::ZMQ4_1::Socket=HASH(0x563c11094760)) called at /usr/local/share/perl5/5.34/Devel/IPerl/Kernel.pm line 236
	Devel::IPerl::Kernel::route_message(Devel::IPerl::Kernel=HASH(0x563c0d1a13a8), ARRAY(0x563c110b9a28), ZMQ::FFI::ZMQ4_1::Socket=HASH(0x563c11094760)) called at /usr/local/share/perl5/5.34/Devel/IPerl/Kernel.pm line 209
	Devel::IPerl::Kernel::__ANON__(Net::Async::ZMQ::Socket=HASH(0x563c110c6578)) called at /usr/local/share/perl5/5.34/IO/Async/Loop/Poll.pm line 171
	IO::Async::Loop::Poll::post_poll(IO::Async::Loop::Poll=HASH(0x563c10f26f78)) called at /usr/local/share/perl5/5.34/IO/Async/Loop/Poll.pm line 291
	IO::Async::Loop::Poll::loop_once(IO::Async::Loop::Poll=HASH(0x563c10f26f78), undef) called at /usr/local/share/perl5/5.34/IO/Async/Loop.pm line 537
	IO::Async::Loop::run(IO::Async::Loop::Poll=HASH(0x563c10f26f78)) called at /usr/local/share/perl5/5.34/IO/Async/Loop.pm line 574
	IO::Async::Loop::loop_forever(IO::Async::Loop::Poll=HASH(0x563c10f26f78)) called at /usr/local/share/perl5/5.34/Devel/IPerl/Kernel.pm line 219
	Devel::IPerl::Kernel::run(Devel::IPerl::Kernel=HASH(0x563c0d1a13a8)) called at /usr/local/share/perl5/5.34/Devel/IPerl.pm line 14
	Devel::IPerl::main() called at -e line 1
