In [43]:
import Torch hiding (div)
import qualified Data.Array as A

In [44]:
-- 0/1 labels: binaryInput is used as the prediction, binaryTarget as the truth
let binaryInput  = asTensor ([1, 0, 1] :: [Float])
let binaryTarget = asTensor ([0, 1, 0] :: [Float])

In [45]:
-- | Count the four cells of a binary confusion matrix for labels encoded as
-- exactly 0 or 1.
--
-- The first argument holds the ground-truth labels, the second the predicted
-- labels. The result is ordered
-- @(truePositives, falsePositives, trueNegatives, falseNegatives)@.
-- Positions whose value is neither 0 nor 1 are not counted in any cell.
binaryConfusionMatrix :: Tensor -> Tensor -> (Int, Int, Int, Int)
binaryConfusionMatrix yTrue yPred =
    (cell 1 1, cell 0 1, cell 0 0, cell 1 0)
  where
    -- number of positions where the truth equals @t@ and the prediction equals @p@
    cell :: Tensor -> Tensor -> Int
    cell t p = asValue (sumAll (logicalAnd (eq yTrue t) (eq yPred p)))

In [46]:
-- arguments are (yTrue, yPred); the result is (tp, fp, tn, fn)
binaryConfusionMatrix binaryTarget binaryInput

(0,2,0,1)

In [47]:
-- integer class labels stored as floats; prediction and truth are identical here
let multiclassInput  = asTensor ([3, 4, 0, 3] :: [Float])
let multiclassTarget = asTensor ([3, 4, 0, 3] :: [Float])

In [48]:
-- | Build a square confusion matrix from ground-truth and predicted labels.
--
-- Both tensors are read as flat lists of 'Float' and every value is rounded
-- to an integer class label. The matrix has @n = 1 + largest label@ rows and
-- columns (0 for empty input). Entry @[i][j]@ counts the positions whose true
-- label is @i@ and whose predicted label is @j@, so rows are true classes and
-- columns are predicted classes. Pairs containing a negative label are not
-- counted.
--
-- Counts are accumulated in a single pass with 'A.accumArray', so the cost is
-- O(samples + n²) instead of rescanning all samples for each of the n² cells.
multiclassConfusionMatrix :: Tensor -> Tensor -> [[Int]]
multiclassConfusionMatrix yTrue yPred =
    [ [ counts A.! (i, j) | j <- classes ] | i <- classes ]
  where
    toLabels :: Tensor -> [Int]
    toLabels t = map round (asValue t :: [Float])

    pairs = zip (toLabels yTrue) (toLabels yPred)

    n | null pairs = 0
      | otherwise  = 1 + maximum (map fst pairs ++ map snd pairs)

    classes = [0 .. n - 1]

    -- every non-negative label is <= n - 1 by construction of n, so only
    -- negative labels must be filtered out to stay inside the array bounds
    counts :: A.Array (Int, Int) Int
    counts = A.accumArray (+) 0 ((0, 0), (n - 1, n - 1))
        [ (tp, 1) | tp@(t, p) <- pairs, t >= 0, p >= 0 ]



In [49]:
-- arguments are (yTrue, yPred): rows are true labels, columns predicted labels
multiclassConfusionMatrix multiclassTarget multiclassInput

[[1,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0],[0,0,0,2,0],[0,0,0,0,1]]

In [50]:
import Control.Monad (forM_)
import Text.Printf     (printf)

-- | Print a confusion matrix with numeric class indices as row and column
-- headers, every field right-aligned in 4 characters.
printConfusionMatrix :: [[Int]] -> IO ()
printConfusionMatrix m = mapM_ putStrLn (header : rows)
  where
    -- one right-aligned, 4-character-wide field
    cell :: Int -> String
    cell = printf "%4d"

    -- four blanks above the row-index column, then the column indices
    header = "    " ++ concatMap cell [0 .. length m - 1]

    -- row index followed by that row's counts
    rows = [ cell i ++ concatMap cell row | (i, row) <- zip [0 ..] m ]

In [51]:
-- numeric-header rendering of the multiclass example
printConfusionMatrix $ multiclassConfusionMatrix multiclassTarget multiclassInput

       0   1   2   3   4
   0   1   0   0   0   0
   1   0   0   0   0   0
   2   0   0   0   0   0
   3   0   0   0   2   0
   4   0   0   0   0   1

In [66]:
-- | Print a confusion matrix with a text label for every row and column.
--
-- Row @i@ of the matrix is printed on its own line, prefixed by label @i@.
-- The column labels form a header line under a centred "Expected" title.
-- When the number of labels differs from the number of matrix rows, the
-- labels @0 .. n-1@ are used instead.
--
-- Every field is right-aligned in a column at least 10 characters wide. The
-- column is widened whenever a label or count would otherwise fill it
-- completely, so neighbouring fields always stay separated by a space.
--
-- NOTE(review): for matrices whose rows are true classes and columns are
-- predicted classes, the "Expected" title sits over the predicted axis —
-- confirm the intended naming.
printConfusionMatrix
  :: [String]   -- ^ class labels, one per row/column
  -> [[Int]]    -- ^ confusion matrix, one inner list per row
  -> IO ()
printConfusionMatrix labels m = do
  let n        = length m
      lbls     = if length labels == n then labels else map show [0..n-1]
      -- widest printed field (label or count) plus one separating space,
      -- never narrower than the default width of 10
      fields   = lbls ++ map show (concat m)
      colW     = maximum (10 : map ((+ 1) . length) fields)
      widthExp = n * colW                     -- total width of the value columns
      padExp   = (widthExp - length "Expected") `div` 2
      -- add colW to skip the row-label column, so the title is centred
      -- over the value columns only
      padL     = colW + padExp

  -- title centred above the column labels
  putStrLn $ replicate padL ' ' ++ "Expected"

  -- header: a blank cell above the row labels, then one label per column
  putStr   $ replicate colW ' '
  forM_ lbls $ \l -> putStr $ printf "%*s" colW l
  putStrLn ""

  -- body: each row's label followed by its counts
  forM_ (zip lbls m) $ \(lab,row) -> do
    putStr $ printf "%*s" colW lab
    forM_ row $ \x -> putStr $ printf "%*d" colW x
    putStrLn ""

In [68]:
-- labelled rendering of the multiclass example (one label per class 0..4)
printConfusionMatrix ["chat", "chien", "bonjour", "moi", "kk"] $ multiclassConfusionMatrix multiclassTarget multiclassInput

                               Expected
                chat     chien   bonjour       moi        kk
      chat         1         0         0         0         0
     chien         0         0         0         0         0
   bonjour         0         0         0         0         0
       moi         0         0         0         2         0
        kk         0         0         0         0         1

In [6]:
import qualified Torch.Functional.Internal as FI
import Torch

-- two random [3,4] tensors used below as demo target and model output
target <- randIO' [3, 4]
input  <- randIO' [3, 4]

-- | Cross-entropy loss between @output@ (unnormalised class scores) and
-- @target@, returned as a Haskell 'Float'.
--
-- The class dimension is dim 0 for an unbatched @(C)@ input and dim 1 for a
-- batched @(N, C, d1, ..., dK)@ input, following PyTorch's @cross_entropy@
-- shape convention. An all-ones class weight of that size is passed, i.e. no
-- class re-weighting. Raises an error for a 0-dimensional @output@.
crossEntropyLoss :: Tensor -> Tensor -> Float
crossEntropyLoss target output =
  let
      -- size of the class dimension; taking the last dimension would pick a
      -- spatial dimension for (N, C, d1, ..., dK) inputs
      numClasses = case shape output of
        [c]         -> c
        (_ : c : _) -> c
        []          -> error "crossEntropyLoss: output must have at least one dimension"
      weight = ones' [numClasses]
      -- remaining args: reduction = 1, ignore_index = -100, label_smoothing = 0.0
      -- NOTE(review): 1 is presumably "mean" in ATen's Reduction enum — confirm
      loss = FI.cross_entropy_loss output target weight 1 (-100) 0.0
  in
      asValue loss

In [7]:
-- NOTE(review): target is a random tensor (presumably uniform values), not
-- normalised class probabilities, so this value is only a smoke test
crossEntropyLoss target input

2.866855