# Classification workflow

Anton Antonov  
RakuForPrediction at WordPress   
October 2025

----

## Setup

In [1]:
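use Data::Reshapers;
use Data::Importers;
use Data::Summarizers;
use Data::TypeSystem;
use Data::Translators;
use ML::ROCFunctions;
use ML::SparseMatrixRecommender;

# Plotting and hash utilities used in later cells
# (some notebook kernels may preload these already)
use JavaScript::D3;
use Text::Plot;
use Hash::Merge;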

In [2]:
#% js
js-d3-list-line-plot(10.rand xx 40, background => 'none', stroke-width => 2)

In [3]:
my $title-color = 'Silver';
my $stroke-color = 'SlateGray';
my $tooltip-color = 'LightBlue';
my $tooltip-background-color = 'none';
my $tick-labels-font-size = 10;
my $tick-labels-color = 'Silver';
my $tick-labels-font-family = 'Helvetica';
my $background = '#1F1F1F';
my $color-scheme = 'schemeTableau10';
my $color-palette = 'Inferno';
my $edge-thickness = 3;
my $vertex-size = 6;
my $mmd-theme = q:to/END/;
%%{
  init: {
    'theme': 'forest',
    'themeVariables': {
      'lineColor': 'Ivory'
    }
  }
}%%
END
my %force = collision => {iterations => 0, radius => 10},link => {distance => 180};
my %force2 = charge => {strength => -30, iterations => 4}, collision => {radius => 50, iterations => 4}, link => {distance => 30};

my %opts = :$background, :$title-color, :$edge-thickness, :$vertex-size;

{background => #1F1F1F, edge-thickness => 3, title-color => Silver, vertex-size => 6}

----

## Ingestion

In [4]:
my $url = 'https://raw.githubusercontent.com/antononcube/MathematicaVsR/refs/heads/master/Data/MathematicaVsR-Data-Mushroom.csv';
my @dsData = data-import($url, headers => 'auto');

@dsData.&dimensions

(8124 24)

In [5]:
deduce-type(@dsData);

Vector(Assoc(Atom((Str)), Atom((Str)), 24), 8124)

---

## Preliminary data analysis tabulation

Here is a summary of the mushroom data:

In [6]:
my @field-names = <cap-Shape cap-Surface cap-Color bruises? odor gill-Attachment gill-Spacing gill-Size gill-Color edibility>;
sink records-summary(@dsData, :@field-names)

+-----------------+-----------------+-----------------+---------------+-----------------+------------------+-----------------+----------------+-------------------+-------------------+
| cap-Shape       | cap-Surface     | cap-Color       | bruises?      | odor            | gill-Attachment  | gill-Spacing    | gill-Size      | gill-Color        | edibility         |
+-----------------+-----------------+-----------------+---------------+-----------------+------------------+-----------------+----------------+-------------------+-------------------+
| convex  => 3656 | scaly   => 3244 | brown   => 2284 | False => 4748 | none    => 3528 | free     => 7914 | close   => 6812 | broad  => 5612 | buff      => 1728 | edible    => 4208 |
| flat    => 3152 | smooth  => 2556 | gray    => 1840 | True  => 3376 | foul    => 2160 | attached => 210  | crowded => 1312 | narrow => 2512 | pink      => 1492 | poisonous => 3916 |
| knobbed => 828  | fibrous => 2320 | red     => 1500 |               | spicy   

Before classifying for edibility, consider this relationship between edibility and odor:

In [7]:
cross-tabulate(@dsData, 'odor', 'edibility')
==> to-pretty-table()

+----------+-----------+--------+
|          | poisonous | edible |
+----------+-----------+--------+
| almond   |           |  400   |
| anise    |           |  400   |
| creosote |    192    |        |
| fishy    |    576    |        |
| foul     |    2160   |        |
| musty    |     36    |        |
| none     |    120    |  3408  |
| pungent  |    256    |        |
| spicy    |    576    |        |
+----------+-----------+--------+

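We can see that odor is nearly decisive: every mushroom with an almond or anise odor is edible, every mushroom with any other odor is poisonous, and odorless mushrooms are mostly edible (3408 of 3528). Similarly, mushrooms without bruises are more likely to be poisonous (3292 of 4748), while bruised ones are mostly edible (2752 of 3376):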

In [8]:
cross-tabulate(@dsData, 'bruises?', 'edibility')
==> to-pretty-table()

+-------+--------+-----------+
|       | edible | poisonous |
+-------+--------+-----------+
| False |  1456  |    3292   |
| True  |  2752  |    624    |
+-------+--------+-----------+
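
The counting behind such a cross tabulation can be sketched in core Raku, without `Data::Reshapers`. The four records below are a made-up sample, not rows of the mushroom data, and the sketch is only meant to show the idea, not how `cross-tabulate` is implemented.

```raku
# Hypothetical mini-sample with two of the mushroom column names
my @sample =
    { odor => 'none', edibility => 'edible'    },
    { odor => 'none', edibility => 'edible'    },
    { odor => 'none', edibility => 'poisonous' },
    { odor => 'foul', edibility => 'poisonous' };

# Count the records for each (odor, edibility) pair
my %counts;
%counts{.<odor>}{.<edibility>}++ for @sample;

# Print one line per odor value, with the class counts sorted by class name
for %counts.sort(*.key) -> $row {
    my $cells = $row.value.sort(*.key).map({ "{.key} => {.value}" }).join(', ');
    say "{$row.key}: $cells";
}
```

Pairs that never occur (for example foul and edible) simply have no entry, which corresponds to the empty cells in the tables above.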

---

## Procedure outline

Let us make a full-blown classification workflow with the following steps:

- Split the data into training and testing sets
- Make an SMR object over the training set
- Classify all records of the testing set with the SMR object
- Derive (and display) classifier metrics:
    - Confusion matrix
    - ROC plots

---

## SMR object creation

Split the data:

In [9]:
my (@dsTraining, @dsTesting);
with take-drop(@dsData, floor(0.75 * @dsData.elems)) {
    @dsTraining = select-columns($_.head, ['id', |@field-names]); 
    @dsTesting =  select-columns($_.tail, ['id', |@field-names]); 
}

say deduce-type(@dsTraining);
say deduce-type(@dsTesting);

Vector(Assoc(Atom((Str)), Atom((Str)), 11), 6093)
Vector(Assoc(Atom((Str)), Atom((Str)), 11), 2031)


In [10]:
@dsTraining.&dimensions

(6093 11)
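
Note that the split above takes the first 75% of the records in file order. The full data is roughly balanced (4208 edible, 3916 poisonous), while the test sample summarized later is mostly poisonous (1483 of 2000), so the tail of the file appears to be skewed. If a more representative split is wanted, one option is to shuffle before cutting. A minimal core-Raku sketch, using stand-in integer records rather than the mushroom data:

```raku
# Stand-in records: the integers 1..20 (hypothetical data)
my @records = 1..20;

# Shuffle all records, then cut at 75%
my @shuffled = @records.pick(*);
my $n-train  = floor(0.75 * @shuffled.elems);

my @training = @shuffled.head($n-train);
my @testing  = @shuffled.skip($n-train);

say "training: {@training.elems}, testing: {@testing.elems}";
```

The same pattern should carry over to `@dsData`, at the cost of a split that differs from run to run unless the shuffled order is saved.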

Create a Sparse Matrix Recommender (SMR) object with the training data:

In [12]:
my $smrObj = 
    ML::SparseMatrixRecommender.new(:native)
    .create-from-wide-form(
        @dsTraining,
        item-column-name => "id",
        tag-types => Whatever,
        :add-tag-types-to-column-names,
        tag-value-separator => ":")
    .apply-term-weight-functions("IDF", "None", "Cosine")

ML::SparseMatrixRecommender(:matrix-dimensions((6093, 50)), :density(0.2), :tag-types(("gill-Spacing", "gill-Attachment", "gill-Color", "cap-Color", "bruises?", "edibility", "cap-Shape", "gill-Size", "odor", "cap-Surface")))
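
The `"IDF"` argument asks for inverse document frequency weighting of the tag columns, so that rare tags count for more than common ones. The exact formula used by `ML::SparseMatrixRecommender` is not shown here; the sketch below assumes one common form, log(N / n_t), and uses made-up tag counts purely to illustrate the effect.

```raku
# Hypothetical numbers: total items, and how many items carry each tag
my $n-items    = 1000;
my %tag-counts = 'odor:none' => 500, 'odor:musty' => 5;

# Assumed IDF form: log(N / n_t); the library may use a variant
my %idf = %tag-counts.map(-> $p { $p.key => log($n-items / $p.value) });

# Rare tags come out with the larger weight
say %idf.sort(-*.value).map({ "{.key}: {.value.round(0.01)}" }).join("\n");
```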

Here is an example classification:

In [13]:
my $prof = @dsTesting.pick.grep(*.key ∉ <id edibility>).List;
my @prof = |(($prof».key X~ ':') Z~ $prof».value);
$smrObj.classify-by-profile('edibility', @prof, n-top-nearest-neighbors => 4).take-value

{edibility:poisonous => 1}
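
The `X~` / `Z~` expression in the cell above builds tags of the form `column:value`, matching the `tag-value-separator => ":"` setting used when the recommender was created. An equivalent construction that some readers may find easier to follow, shown on a made-up record:

```raku
# Hypothetical record with a few of the mushroom columns
my %record = id => '42', odor => 'foul', 'gill-Size' => 'narrow', edibility => 'poisonous';

# Drop the identifier and the label, then join column name and value with ':'
my @prof = %record.grep(*.key ∉ <id edibility>).map({ .key ~ ':' ~ .value }).sort;

say @prof;   # prints [gill-Size:narrow odor:foul]
```

The `.sort` is only there to make the printed order stable; the order of the tags in a profile should not matter for the classification.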

---

## Batch classification

In [14]:
my $n-top-nearest-neighbors = 200;

my @dsResults = @dsTesting.pick(2000).race(:4degree, :500batch).map( -> %record {
    my @prof = %record.grep(*.key ∉ <id edibility>).List;
    @prof = |((@prof».key X~ ':') Z~ @prof».value);
    my %class = $smrObj.classify-by-profile('edibility', @prof, :$n-top-nearest-neighbors).take-value;
    %class .= map({ $_.key.subst('edibility:') => $_.value.Num});

    %( id => %record<id>, actual => %record<edibility>, predicted => %class.sort(-*.value).head.key, |%class)
});

deduce-type(@dsResults)

Vector((Any), 2000)

Make sure every record has the complete set of columns (both `poisonous` and `edible`):

In [15]:
my %empty = :0poisonous, :0edible;
@dsResults = @dsResults.map({ merge-hash(%empty, $_) });
deduce-type(@dsResults)

Vector((Any), 2000)
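
The merge matters because a classification result can contain a score for one class only, as in the `{edibility:poisonous => 1}` example earlier. The default-filling idea can be sketched in core Raku on a made-up result record; this is an illustration of the idea, not a claim about how `merge-hash` works internally.

```raku
my %empty  = poisonous => 0, edible => 0;

# Hypothetical result that has a score for one class only
my %result = id => '7', actual => 'poisonous', predicted => 'poisonous', poisonous => 1;

# Later pairs override earlier ones, so the defaults fill only the missing keys
my %filled = |%empty, |%result;

say %filled<edible poisonous>;   # prints (0 1)
```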

In [16]:
sink records-summary(@dsResults)

+-----------------+-------------------------------+-------------------+---------------------------------+-------------------+
| id              | poisonous                     | predicted         | edible                          | actual            |
+-----------------+-------------------------------+-------------------+---------------------------------+-------------------+
| 6920    => 1    | Min    => 0                   | poisonous => 1450 | Min    => 0                     | poisonous => 1483 |
| 7830    => 1    | 1st-Qu => 0.14604747162022702 | edible    => 550  | 1st-Qu => 0                     | edible    => 517  |
| 6732    => 1    | Mean   => 0.7385533164894529  |                   | Mean   => 0.2856636064661645    |                   |
| 7920    => 1    | Median => 1                   |                   | Median => 0.0041753653444676405 |                   |
| 7975    => 1    | 3rd-Qu => 1                   |                   | 3rd-Qu => 1                     |             

In [17]:
#% html
@dsResults.pick(10)
==> to-html(field-names => <id actual predicted edible poisonous>)

| id   | actual    | predicted | edible             | poisonous |
|------|-----------|-----------|--------------------|-----------|
| 7213 | poisonous | poisonous | 0.0                | 1         |
| 6140 | poisonous | poisonous | 0.0371517027863777 | 1         |
| 7990 | poisonous | poisonous | 0.0322085889570552 | 1         |
| 7260 | poisonous | poisonous | 0.0322085889570552 | 1         |
| 6760 | poisonous | poisonous | 0.0                | 1         |
| 7302 | edible    | edible    | 1.0                | 0         |
| 6689 | poisonous | poisonous | 0.0                | 1         |
| 7076 | poisonous | poisonous | 0.0041493775933609 | 1         |
| 6631 | poisonous | poisonous | 0.0                | 1         |
| 6384 | poisonous | poisonous | 0.0227445034116755 | 1         |


---

## Classifier metrics

Confusion matrix:

In [18]:
my $ct = cross-tabulate(@dsResults, "actual", "predicted");
to-pretty-table($ct)

+-----------+-----------+--------+
|           | poisonous | edible |
+-----------+-----------+--------+
| edible    |     52    |  465   |
| poisonous |    1398   |   85   |
+-----------+-----------+--------+

Prettier version using HTML rendering:

In [19]:
#%html
$ct.map({ (actual => $_.key, |$_.value) })».Hash.sort(*<actual>) 
==> to-html(field-names => <actual edible poisonous>)

| actual    | edible | poisonous |
|-----------|--------|-----------|
| edible    | 465    | 52        |
| poisonous | 85     | 1398      |
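
The usual summary metrics follow directly from these four counts. The sketch below copies the numbers from the confusion matrix and treats `poisonous` as the positive class; that choice of positive class is an assumption made here for illustration.

```raku
# Counts copied from the confusion matrix, with 'poisonous' as the positive class
my ($tp, $fn) = 1398, 85;   # actual poisonous: predicted poisonous / predicted edible
my ($fp, $tn) = 52, 465;    # actual edible:    predicted poisonous / predicted edible

my $accuracy  = ($tp + $tn) / ($tp + $fn + $fp + $tn);
my $precision = $tp / ($tp + $fp);
my $recall    = $tp / ($tp + $fn);

say "accuracy:  {$accuracy.round(0.0001)}";
say "precision: {$precision.round(0.0001)}";
say "recall:    {$recall.round(0.0001)}";
```

With these counts the accuracy is 1863/2000 = 0.9315. The 85 poisonous mushrooms predicted as edible are the costlier kind of error for this task.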


Receiver Operating Characteristic (ROC) metrics computation:

In [52]:
my @thRange = [|(0, 0.01 ... 0.4), |(0.4, 0.45 ... 1)].unique.sort;

my @rocs = @thRange.map(-> $th { to-roc-hash('poisonous', 'edible', 
                                                select-columns(@dsResults, 'actual')>>.values.flat, 
                                                select-columns(@dsResults, 'poisonous')>>.values.flat.map({ $_ >= $th ?? 'poisonous' !! 'edible' })) });

deduce-type(@rocs)                                        

Vector(Assoc(Atom((Str)), Atom((Int)), 4), 53)
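
Each of the 53 records holds the four confusion-matrix counts obtained at one threshold: a record is labeled `poisonous` when its `poisonous` score is at least the threshold, and `edible` otherwise. A core-Raku sketch of that counting, on made-up scores and with a hypothetical helper `roc-counts` (not part of `ML::ROCFunctions`):

```raku
# Hypothetical (actual label => score for 'poisonous') pairs
my @scored =
    poisonous => 0.95,
    poisonous => 0.40,
    edible    => 0.30,
    edible    => 0.05;

# Count the four confusion-matrix cells at threshold $th,
# treating 'poisonous' as the positive class
sub roc-counts(@pairs, $th) {
    my %h = TruePositive => 0, FalsePositive => 0, TrueNegative => 0, FalseNegative => 0;
    for @pairs -> $p {
        my $predicted-positive = $p.value >= $th;
        my $actual-positive    = $p.key eq 'poisonous';
        %h{ ($actual-positive == $predicted-positive ?? 'True' !! 'False')
            ~ ($predicted-positive ?? 'Positive' !! 'Negative') }++;
    }
    %h
}

say roc-counts(@scored, 0.35);
say roc-counts(@scored, 0.5);
```

Raising the threshold from 0.35 to 0.5 turns the record with score 0.40 from a true positive into a false negative, which is the trade-off the ROC curve traces.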

Tabulate ROC records:

In [53]:
#%html
@rocs
==> to-html(field-names => <FalsePositive FalseNegative TrueNegative TruePositive>)

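| FalsePositive | FalseNegative | TrueNegative | TruePositive |
|---------------|---------------|--------------|--------------|
| 517           | 0             | 0            | 1483         |
| 294           | 29            | 223          | 1454         |
| 197           | 43            | 320          | 1440         |
| 185           | 44            | 332          | 1439         |
| 154           | 50            | 363          | 1433         |
| 154           | 50            | 363          | 1433         |
| 136           | 51            | 381          | 1432         |
| 113           | 56            | 404          | 1427         |
| 111           | 56            | 406          | 1427         |
| 110           | 56            | 407          | 1427         |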
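
The rates used for a ROC curve come from these counts through the standard definitions TPR = TP / (TP + FN) and FPR = FP / (FP + TN). A small core-Raku sketch applying them to the first three records of the table; the helper names `tpr` and `fpr` are introduced here for illustration only.

```raku
# First three ROC records from the table above
my @rows =
    { FalsePositive => 517, FalseNegative => 0,  TrueNegative => 0,   TruePositive => 1483 },
    { FalsePositive => 294, FalseNegative => 29, TrueNegative => 223, TruePositive => 1454 },
    { FalsePositive => 197, FalseNegative => 43, TrueNegative => 320, TruePositive => 1440 };

# Standard definitions of the true and false positive rates
sub tpr(%r) { %r<TruePositive>  / (%r<TruePositive>  + %r<FalseNegative>) }
sub fpr(%r) { %r<FalsePositive> / (%r<FalsePositive> + %r<TrueNegative>)  }

for @rows -> %r {
    say "FPR = {fpr(%r).round(0.001)}, TPR = {tpr(%r).round(0.001)}";
}
```

The first record corresponds to threshold 0, where everything is labeled `poisonous`, so both rates equal 1; that is the top-right corner of the ROC plot.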


Plot ROC functions (False Positive Rate vs True Positive Rate):

In [54]:
text-list-plot(roc-functions('FPR')(@rocs), roc-functions('TPR')(@rocs),
                width => 70, height => 25, 
                x-label => 'FPR', y-label => 'TPR', 
                x-limit => (0, 1))

++------------+-------------+------------+-------------+------------++        
|                                                                    |        
+                                                                   *+  1.00  
|                                                                    |        
|                                                                    |        
|                                                                    |        
+                                                                    +  0.99  
|                                                                    |        
|                                                                    |        
+                                      *                             +  0.98  
|                                                                    |        
|                                                                    |       T
+                        * *                        