# Lab #1

The purpose of this laboratory is to get you acquainted with Python. 
More specifically, you will learn how to:
- read different types of datasets (CSV and JSON). 
- extract some useful information (mean and standard deviation) from the datasets using only basic Python.
- create a simple rule-based classifier that is already capable of performing some classification.


## Preliminaries
### Python availability
Make sure that Python 3 is installed on your device with the command `python3 --version`. The version should be in the form `3.x.x`.
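The same check can be done from inside Python, for example in the first cell of a notebook. Below is a minimal sketch that uses the standard `sys` module. The 3.9 minimum is an assumption. It comes from the fact that the code in this lab uses built-in generic annotations such as `list[list]`, and those need Python 3.9 or newer:

```python
import sys

# sys.version_info behaves like a tuple: (major, minor, micro, releaselevel, serial)
major, minor, micro = sys.version_info[:3]
print(f"Running Python {major}.{minor}.{micro}")

# Built-in generics such as list[list] in annotations need Python 3.9+
if sys.version_info < (3, 9):
    raise RuntimeError("this lab's code assumes Python 3.9 or newer")
```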

In [1]:
!python3 --version

Python 3.11.4


### Dataset Download
For this lab, three different datasets will be used. Here, you will learn more about them and how to retrieve them.

#### Iris
Iris is a particularly famous *toy dataset* (i.e. a dataset with a small number of rows and columns, mostly
used for initial small-scale tests and proofs of concept). 
This specific dataset contains information about the **Iris**, a genus that includes 260-300 species of plants. 
The Iris dataset contains measurements for 150 Iris flowers, each belonging to one of three species (50 flowers each): 

| Iris Virginica | Iris Versicolor | Iris Setosa |
|:--------------:|:---------------:|:-----------:|
| <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/f/f8/Iris_virginica_2.jpg/1200px-Iris_virginica_2.jpg" alt="Iris Virginica" width="200" /> | <img src="https://www.waternursery.it/document/img_prodotti/616/1646318149.jpeg" alt="Iris Versicolor" width="200" /> | <img src="https://d2j6dbq0eux0bg.cloudfront.net/images/28296135/2323483832.jpg" alt="Iris Setosa" width="200" /> |

Each of the 150 flowers contained in the Iris dataset is represented by 5 values:
- sepal length, in cm
- sepal width, in cm
- petal length, in cm
- petal width, in cm
- Iris species, one of: Iris-setosa, Iris-versicolor, Iris-virginica (the label)

Each row of the dataset represents a distinct flower, so the dataset has 150 rows. Each row contains 5 values (4 measurements and a species label).
The dataset is described in more detail on the [UCI Machine Learning Repository website](https://archive.ics.uci.edu/dataset/53/iris). It can be downloaded either directly from there (the iris.data file) or from a terminal with the `wget` tool. The following command downloads the dataset from the original URL and stores it in a file named iris.csv.

`wget "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" -O iris.csv`

The dataset is available as a Comma-Separated Values (CSV) file. These files are typically used to represent tabular data:
- Each row is represented on one of the lines. 
- Each of the rows contains a fixed number of columns. 
- Each of the columns (in each row) is separated by a comma (,).

To read CSV files, Python offers a module called `csv` (see the official [documentation](https://docs.python.org/3/library/csv.html)). This module provides `csv.reader()`, which reads a file row by row. For each row, it returns a list of columns that can be processed as needed.
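You can see what `csv.reader()` yields without downloading anything. It accepts any iterable of lines, such as an `io.StringIO` object. The sketch below uses two Iris rows written inline, followed by a blank line that mimics the end of a file. Note that a blank line produces an empty list, so empty rows have to be skipped when loading a file that ends with a newline:

```python
import csv
import io

# Two data rows followed by an empty line (illustrative text)
text = "5.1,3.5,1.4,0.2,Iris-setosa\n4.9,3.0,1.4,0.2,Iris-setosa\n\n"

rows = list(csv.reader(io.StringIO(text)))
for cols in rows:
    print(cols)
# ['5.1', '3.5', '1.4', '0.2', 'Iris-setosa']
# ['4.9', '3.0', '1.4', '0.2', 'Iris-setosa']
# []
```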


Let's download the dataset and print its first five rows.

In [2]:
! wget "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" -O iris.csv

print("Reading first lines of the IRIS dataset")
import csv 
with open("iris.csv") as f:
    for i, cols in enumerate(csv.reader(f)):
        print(cols)
        if i >= 4:
            break

--2023-10-30 15:36:10--  https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified
Saving to: 'iris.csv'

iris.csv                [ <=>                ]   4.44K  --.-KB/s    in 0.003s  

2023-10-30 15:36:11 (1.68 MB/s) - 'iris.csv' saved [4551]

Reading first lines of the IRIS dataset
['5.1', '3.5', '1.4', '0.2', 'Iris-setosa']
['4.9', '3.0', '1.4', '0.2', 'Iris-setosa']
['4.7', '3.2', '1.3', '0.2', 'Iris-setosa']
['4.6', '3.1', '1.5', '0.2', 'Iris-setosa']
['5.0', '3.6', '1.4', '0.2', 'Iris-setosa']


Note that `csv.reader` returns every field as a string (`str`). If you want to treat the values as numbers, remember to cast them explicitly (e.g. with `float()`).
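Here is a small illustration of the cast on the first Iris row printed above. It also shows why the cast matters: strings are compared character by character, not numerically.

```python
# One row as returned by csv.reader: every field is a str
cols = ['5.1', '3.5', '1.4', '0.2', 'Iris-setosa']

# Cast the 4 measurements to float and keep the label as a string
measurements = [float(v) for v in cols[:4]]
label = cols[4]
print(measurements, label)  # [5.1, 3.5, 1.4, 0.2] Iris-setosa

# Comparing strings is lexicographic, not numeric
print('10.0' < '5.1')                # True
print(float('10.0') < float('5.1'))  # False
```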

#### MNIST
The MNIST dataset is another particularly famous dataset. It contains several thousand hand-written digits (0 to 9).
- Each hand-written digit is contained in a $28 \times 28$ 8-bit grayscale image.
- This means that each digit has $784$ ($28^2$) pixels.
- Each pixel has a value that ranges from 0 (black) to 255 (white).

<img src="https://machinelearningmastery.com/wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-MNIST-Dataset.png" alt="MNIST images" width="500" />

The dataset can be downloaded from the following link:

[https://raw.githubusercontent.com/dbdmg/data-science-lab/master/datasets/mnist_test.csv](https://raw.githubusercontent.com/dbdmg/data-science-lab/master/datasets/mnist_test.csv)



In this case, MNIST is represented as a CSV file. As with the Iris dataset, each row of the MNIST dataset represents a digit. For the sake of simplicity, this dataset contains only a small fraction (10,000 digits out of 70,000) of the real MNIST dataset.

For each digit, 785 values are available: 
- the first one is the numerical value depicted in the image (e.g. for Figure 2 it would be 5). 
- the following 784 columns represent the grayscale image in row-major order (for more information about row- and column-major order of matrices, see [Wikipedia](https://en.wikipedia.org/wiki/Row-_and_column-major_order)).
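In row-major order, the pixel at row `r` and column `c` of the 28×28 image is stored at position `r * 28 + c` among the 784 pixel values. Because the label comes first, that is CSV column `1 + r * 28 + c`. The sketch below converts in both directions with `divmod`. The helper names are hypothetical:

```python
WIDTH = 28  # MNIST images are 28 x 28

def to_flat(r: int, c: int) -> int:
    """(row, column) -> position among the 784 pixel values (row-major)."""
    return r * WIDTH + c

def to_grid(p: int) -> tuple[int, int]:
    """Position among the 784 pixel values -> (row, column)."""
    return divmod(p, WIDTH)

print(to_flat(0, 27))  # 27: last pixel of the first row
print(to_flat(1, 0))   # 28: first pixel of the second row
print(to_grid(783))    # (27, 27): bottom-right corner
```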

The MNIST dataset in CSV format can be read with the same approach used for Iris, keeping in mind that, in this case, the digit label (i.e. the first column) is an integer from 0 to 9, while the following 784 values are integers between 0 and 255.
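The paragraph above can be turned into a small parsing helper. This is a sketch that assumes one row as returned by `csv.reader`. The function name is hypothetical, and the sample row is synthetic (all zeros except one pixel) rather than taken from the dataset:

```python
def parse_mnist_row(cols: list[str]) -> tuple[int, list[list[int]]]:
    """Split one CSV row into (label, 28 x 28 grid of ints)."""
    label = int(cols[0])                  # first column: the digit
    pixels = [int(v) for v in cols[1:]]   # next 784 columns: pixel values
    if len(pixels) != 784:
        raise ValueError(f"expected 784 pixel values, got {len(pixels)}")
    # Row-major order: image row r is pixels[r*28:(r+1)*28]
    grid = [pixels[r * 28:(r + 1) * 28] for r in range(28)]
    return label, grid

# Synthetic row: label 7, one bright pixel at row 2, column 3
cols = ['7'] + ['0'] * 784
cols[1 + 2 * 28 + 3] = '255'
label, grid = parse_mnist_row(cols)
print(label, grid[2][3], len(grid), len(grid[0]))  # 7 255 28 28
```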

In [3]:
! wget "https://raw.githubusercontent.com/dbdmg/data-science-lab/master/datasets/mnist_test.csv" -O mnist.csv

print("Reading the first line of the MNIST dataset")
import csv
with open("mnist.csv") as f:
    for cols in csv.reader(f):
        print("Label:", cols[0])
        pixels = cols[1:]  # column 0 is the label, the 784 pixels follow
        # print the image as 28 rows of 28 pixels (row-major order)
        for j in range(28):
            print(pixels[j*28:j*28+28])
        break  # only the first sample is needed


--2023-10-30 15:36:11--  https://raw.githubusercontent.com/dbdmg/data-science-lab/master/datasets/mnist_test.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18289443 (17M) [text/plain]
Saving to: 'mnist.csv'


2023-10-30 15:36:12 (19.9 MB/s) - 'mnist.csv' saved [18289443/18289443]

Reading the first line of the MNIST dataset
Label: 7
['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0']
['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0']
['0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0']
['0'

## Exercises
Note that exercises marked with (*) are optional: focus on completing the other ones first.
### Iris analysis
1. Load the previously downloaded Iris dataset as a list of lists (each of the 150 lists should have 5 elements). You can make use of the csv module presented above.

In [4]:
rows = []
with open("iris.csv") as f:
    for i, cols in enumerate(csv.reader(f)):
        if cols != []:
            rows.append(cols)
        if i % 10 == 0:
            print(f"{i}) rows read")
print(f"Number of rows read: {len(rows)}")

0) rows read
10) rows read
20) rows read
30) rows read
40) rows read
50) rows read
60) rows read
70) rows read
80) rows read
90) rows read
100) rows read
110) rows read
120) rows read
130) rows read
140) rows read
150) rows read
Number of rows read: 150


2. Compute and print the mean and the standard deviation for each of the 4 measurement columns (i.e. sepal length and width, petal length and width). Remember that, for a given list of $n$ values $x = (x_1, x_2, \dots, x_n)$, the mean $\mu$ and the standard deviation $\sigma$ are defined, respectively, as:
$$\mu = \frac{1}{n} \sum_{i=1}^{n} x_i$$

$$\sigma = \sqrt{\frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2}$$
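The two formulas can be checked on a small made-up list before applying them to the Iris columns. This is a sketch with hypothetical helper names. The standard `statistics` module serves only as a cross-check. Its `pstdev` divides by $n$, like the formula above, while `stdev` divides by $n - 1$.

```python
from math import sqrt
import statistics

def mean(xs: list[float]) -> float:
    return sum(xs) / len(xs)

def std(xs: list[float]) -> float:
    # population standard deviation: divide by n, as in the formula above
    mu = mean(xs)
    return sqrt(sum((x - mu) ** 2 for x in xs) / len(xs))

xs = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
print(mean(xs), std(xs))  # 5.0 2.0

# Cross-check: pstdev divides by n, stdev by n - 1 (so it is larger)
print(statistics.pstdev(xs), statistics.stdev(xs))
```

Here the squared deviations from $\mu = 5$ sum to $32$, so $\sigma = \sqrt{32/8} = 2$.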

In [5]:
from math import sqrt

def means_and_stds(rows: list[list[str]]) -> tuple[list[float], list[float]]:
    # Fresh accumulators on every call: with module-level lists, each call
    # would start from the results of the previous one instead of from zero
    means = [0., 0., 0., 0.]
    stds = [0., 0., 0., 0.]

    # computing means
    for row in rows:
        for i in range(4):
            means[i] += float(row[i])
    for i in range(4):
        means[i] /= len(rows)

    # computing (population) standard deviations
    for row in rows:
        for i in range(4):
            stds[i] += (means[i] - float(row[i])) ** 2

    for i in range(4):
        stds[i] /= len(rows)
        stds[i] = sqrt(stds[i])

    return means, stds

means, stds = means_and_stds(rows)
print("Means", means)
print("Standard deviations", stds)


Means [5.843333333333335, 3.0540000000000007, 3.7586666666666693, 1.1986666666666672]
Standard deviations [0.8253012917851409, 0.4321465800705435, 1.7585291834055201, 0.760612618588172]



3. Compute and print the mean and the standard deviation for each of the 4 measurement columns, separately for each of the three Iris species (`Iris-versicolor`, `Iris-virginica` and `Iris-setosa`). *Remember* that the label is stored in the $5^{th}$ (last) cell of the row.

In [6]:
class_means_stds = {
    "Iris-versicolor": (),
    "Iris-virginica": (),
    "Iris-setosa": (),
}

for iris in class_means_stds.keys():
    iris_rows = list(filter(lambda row: row[4] == iris, rows))
    print(f"Number of rows in {iris}", len(iris_rows))
    means_iris, stds_iris = means_and_stds(iris_rows)
    class_means_stds[iris] = (means_iris, stds_iris)
    # values rounded to 3 decimals for readability
    print(f"{iris} means {[round(m, 3) for m in means_iris]}")
    print(f"{iris} stds {[round(s, 3) for s in stds_iris]}")
    print("")


Number of rows in Iris-versicolor 50
Iris-versicolor means [5.936, 2.77, 4.26, 1.326]
Iris-versicolor stds [0.511, 0.311, 0.465, 0.196]

Number of rows in Iris-virginica 50
Iris-virginica means [6.588, 2.974, 5.552, 2.026]
Iris-virginica stds [0.629, 0.319, 0.546, 0.272]

Number of rows in Iris-setosa 50
Iris-setosa means [5.006, 3.418, 1.464, 0.244]
Iris-setosa stds [0.349, 0.377, 0.172, 0.106]



4. Based on the results of Exercises 2 and 3, which of the 4 measurements would you consider the most characteristic of the three species? (In other words, which measurement would you consider "best" if you had to guess the Iris species from a single one of those four values?)

**The 4th measurement (petal width) best characterizes the species: its means are the most clearly separated across the three species, while its within-species standard deviations are small.**


5. Based on the considerations of Exercise 3, assign each flower with the following measurements to the species you consider most likely.


`5.2, 3.1, 4.0, 1.2`: Iris-versicolor

`4.9, 2.5, 5.6, 2.0`: Iris-virginica

`5.4, 3.2, 1.9, 0.4`: Iris-setosa


6. (*) Create a rule-based classifier similar to the one seen in class. This classifier, again, will receive some rules and will classify each sample into one of the three species. Based on your analysis in the previous point, where you identified the most discriminative feature, provide the classifier with 3 rules, one for each Iris species, based on this feature.

In [7]:
class RuleModel:
    def __init__(self, default_class):
        """
        Create the rule-based model.
        :param default_class: default class when no rule applies
        """
        # Rules are stored per instance: a class-level `__rules = []` list
        # would be shared by every RuleModel created from this class
        self.__rules = []
        self.__default_class = default_class

    def add_rule(self, rule, output_class):
        """
        Add rule to the model.
        :param rule: lambda function with the conditions on the input sample
        :param output_class: output label to be assigned when the rule is satisfied
        """
        self.__rules.append((rule, output_class))

    def predict(self, x):
        """
        Apply rules to a sample. The first rule that applies represents the output label.
        :param x: list of values representing the input sample (one CSV row)
        """
        for rule, out_class in self.__rules:
            if rule(x):
                return out_class
        return self.__default_class
    
    
rule_clf = RuleModel('Iris-virginica')

# Add rules
rule_clf.add_rule(lambda x: float(x[3]) < 0.8, 'Iris-setosa')
rule_clf.add_rule(lambda x: float(x[3]) > 0.8 and float(x[3]) < 1.6, 'Iris-versicolor')
rule_clf.add_rule(lambda x: float(x[3]) > 1.6, 'Iris-virginica') # it could be avoided because it is the default class
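As a quick sanity check, the same first-match logic can be applied to the three flowers of Exercise 5. This self-contained sketch uses a plain list of `(condition, label)` pairs instead of `RuleModel`. It keeps the petal-width thresholds of the rules above (0.8 and 1.6) and assumes the measurements are already floats:

```python
# (condition on the sample, label) pairs, checked in order: the first match wins
rules = [
    (lambda x: x[3] < 0.8, 'Iris-setosa'),
    (lambda x: 0.8 < x[3] < 1.6, 'Iris-versicolor'),
]
DEFAULT = 'Iris-virginica'  # used when no rule matches

def predict(x: list[float]) -> str:
    for condition, label in rules:
        if condition(x):
            return label
    return DEFAULT

flowers = [[5.2, 3.1, 4.0, 1.2], [4.9, 2.5, 5.6, 2.0], [5.4, 3.2, 1.9, 0.4]]
print([predict(f) for f in flowers])
# ['Iris-versicolor', 'Iris-virginica', 'Iris-setosa']
```

The inequalities are strict, so a petal width of exactly 0.8 or 1.6 matches no rule and falls back to the default class.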


7. (*) Compute the predictions for all the elements in the dataset and store them in a list. Then, compute the accuracy of the classifier you created. The accuracy metric will be covered in more detail later, but it can be computed as:

$$ \text{Acc} = {\text{number of correct predictions} \over \text{total number of predictions}} $$

The number of correct predictions is obtained by counting how many times the predicted class is equal to the label of the sample ($5^{th}$ column).
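Once the predictions are stored in a list, as the exercise asks, the formula reduces to a one-line comparison. This is a sketch with a hypothetical helper name and made-up labels:

```python
def accuracy(predictions: list[str], labels: list[str]) -> float:
    """Fraction of predictions equal to the corresponding label."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    correct = sum(p == y for p, y in zip(predictions, labels))  # True counts as 1
    return correct / len(labels)

preds  = ['Iris-setosa', 'Iris-virginica', 'Iris-versicolor', 'Iris-setosa']
labels = ['Iris-setosa', 'Iris-versicolor', 'Iris-versicolor', 'Iris-setosa']
print(accuracy(preds, labels))  # 0.75
```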

In [8]:
correct_predictions = 0
for row in rows:
    prediction = rule_clf.predict(row)
    label = row[4]
    if prediction == label:
        correct_predictions += 1
    else:
        print(f"Error, predicted {prediction}, actual label {label}")

accuracy = correct_predictions / len(rows)
print(f"Accuracy of the model: {accuracy}")

Error, predicted Iris-virginica, actual label Iris-versicolor
Error, predicted Iris-virginica, actual label Iris-versicolor
Error, predicted Iris-virginica, actual label Iris-versicolor
Error, predicted Iris-virginica, actual label Iris-versicolor
Error, predicted Iris-virginica, actual label Iris-versicolor
Error, predicted Iris-versicolor, actual label Iris-virginica
Error, predicted Iris-versicolor, actual label Iris-virginica
Error, predicted Iris-versicolor, actual label Iris-virginica
Accuracy of the model: 0.9466666666666667


### MNIST Analysis

1. Load the previously downloaded MNIST dataset. You can make use of the csv module already presented.

In [9]:
import csv

rows = []
with open("mnist.csv") as f:
    for i, cols in enumerate(csv.reader(f)):
        if cols != []:
            rows.append(cols)
        if i % 1000 == 0:
            print(f"{i}) rows read")
print(f"Number of rows read: {len(rows)}")

0) rows read
1000) rows read
2000) rows read
3000) rows read
4000) rows read
5000) rows read
6000) rows read
7000) rows read
8000) rows read
9000) rows read
Number of rows read: 10000


2. Create a function that, given a position $1 \le k \le 10{,}000$, prints the $k^{th}$ sample of the dataset (i.e. the $k^{th}$ row of the CSV file) as a grid of $28 \times 28$ characters. More specifically, you should map each range of pixel values to the following characters:
    - [0; 64) &rarr; " "
    - [64; 128) &rarr; "."
    - [128; 192) &rarr; "*"
    - [192; 256) &rarr; "#"
So, for example, you should map the sequence `0, 72, 192, 138, 250` to the string `" .#*#"` (note the leading space).
*Note*: Remember to start a new line after every 28 characters.

Example of output of the $130^{th}$ sample: 
```
         .#      **
        .##..*#####
       #########*.
      #####***.
     ##*
    *##
    ##
   .##
    ###*
    .#####.
        *###*
           *###*
              ###
              .##
              ###
            .###
      .    *###.
     .#  .*###*
     .######.
      *##*.
```
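An `if`/`elif` chain works well for the mapping above. As an alternative sketch, the standard `bisect` module can find which range a value falls into. The names `THRESHOLDS`, `SYMBOLS` and `pixel_to_char` are hypothetical:

```python
from bisect import bisect_right

THRESHOLDS = [64, 128, 192]  # lower bounds of the 2nd, 3rd and 4th ranges
SYMBOLS = " .*#"             # one symbol per range

def pixel_to_char(value: int) -> str:
    # bisect_right counts how many thresholds are <= value, giving an index 0..3
    return SYMBOLS[bisect_right(THRESHOLDS, value)]

print(repr("".join(pixel_to_char(v) for v in [0, 72, 192, 138, 250])))  # ' .#*#'
```

With `bisect_right`, a value equal to a threshold (e.g. 64) lands in the range that starts at that threshold. This matches the half-open ranges [64; 128) etc. listed above.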


In [10]:
def print_digit(mnist_samples, index):
    # column 0 of each row is the label: the 784 pixels start at column 1
    pixels = mnist_samples[index][1:]
    for i in range(28):
        row_string = ""
        for j in range(28):
            value = float(pixels[i*28+j])
            if value < 64:
                row_string += " "
            elif value < 128:
                row_string += "."
            elif value < 192:
                row_string += "*"
            else:
                row_string += "#"
        print(row_string)

print_digit(rows, 129)  # the 130th sample (0-based index)






              .#      **
             .##..*#####
            #########*.
           #####***.
          ##*
         *##
         ##
        .##
         ###*
         .#####.
            *###*
              *###*
                ###
                .##
                ###
              .###
        .    *###.
       .# .*###*
       .######.
        *##*.





3. Compute the Euclidean distance between each pair of the 784-dimensional vectors of the digits at the following positions: $26^{th}$, $30^{th}$, $32^{nd}$, $35^{th}$.

*Note*: Remember that Python lists are indexed from 0, so the $k^{th}$ value will be at position $k-1$.

In [11]:
import math
def euclidean_dist(v1, v2):
    dist = 0
    for v1_el, v2_el in zip(v1[1:], v2[1:]):  # skip the first element: it is the label
        v1_el, v2_el = float(v1_el), float(v2_el)
        dist += (v1_el-v2_el)**2
    
    dist = math.sqrt(dist)
    return dist

for i in [25,29,31,34]:
    for j in [25,29,31,34]:
        if i!=j:
            sample_i = rows[i]
            sample_j = rows[j]
            dist_i_j = euclidean_dist(sample_i, sample_j)
            print(f"The euclidean dist between {i} and {j} samples is: {dist_i_j}")




The euclidean dist between 25 and 29 samples is: 3539.223219860539
The euclidean dist between 25 and 31 samples is: 3556.4199695761467
The euclidean dist between 25 and 34 samples is: 3223.2069434027967
The euclidean dist between 29 and 25 samples is: 3539.223219860539
The euclidean dist between 29 and 31 samples is: 1171.8293391104355
The euclidean dist between 29 and 34 samples is: 2531.0033583541526
The euclidean dist between 31 and 25 samples is: 3556.4199695761467
The euclidean dist between 31 and 29 samples is: 1171.8293391104355
The euclidean dist between 31 and 34 samples is: 2515.5599774205343
The euclidean dist between 34 and 25 samples is: 3223.2069434027967
The euclidean dist between 34 and 29 samples is: 2531.0033583541526
The euclidean dist between 34 and 31 samples is: 2515.5599774205343
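Since Python 3.8, the standard library also offers `math.dist(p, q)`. It returns the Euclidean distance between two points given as equal-length sequences of coordinates. This sketch uses it to cross-check a hand-written loop on short made-up vectors. For MNIST rows, the label must be dropped and the values converted to numbers first:

```python
import math

def euclidean(a: list[float], b: list[float]) -> float:
    # square root of the sum of squared coordinate differences
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

a = [0.0, 3.0, 0.0]
b = [4.0, 0.0, 0.0]
print(euclidean(a, b), math.dist(a, b))  # 5.0 5.0

# For two MNIST CSV rows r1 and r2 (label in column 0):
# math.dist([float(v) for v in r1[1:]], [float(v) for v in r2[1:]])
```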


4. Based on the distances computed in the previous step, and knowing that the digits listed in Exercise 3 are (not necessarily in this order) $0, 1, 1, 7$, can you assign the correct label to each of the digits of Exercise 3?

The $0$ is sample $25$, because it is the farthest sample from each of the others.

The $1$s are samples $29$ and $31$, because they are the closest pair (distance ≈ 1172).

The $7$ is sample $34$, because it is more similar to the samples representing the $1$s (≈ 2531 and ≈ 2516) than to the sample representing the $0$ (≈ 3223).

5. There are 1,135 images representing 1's and 980 images representing 0's in the dataset. For all 0's and 1's separately, count the number of times each of the 784 pixels is black, using 128 as the threshold value (here a pixel counts as "black", i.e. part of the pen stroke, when its value is at least 128). You can do this by building two lists, `Z` and `O`, each with 784 elements, containing the counts for the 0's and the 1's, respectively: `Z[i]` and `O[i]` contain the number of times the $i^{th}$ pixel was black for each class. For each value $i$, compute `abs(Z[i] - O[i])`. The $i$ with the highest value represents the pixel that best separates the digits “0” and “1” (i.e. the pixel that is most often black for one class and white for the other). Where is this pixel located within the grid? Why is it there?
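The last steps of this exercise can be sketched with the built-in `max` and its `key` argument, plus `divmod`. Those steps are: take the absolute difference for every pixel, pick the index with the largest value, and convert it to a grid position. The function name and the toy counts are hypothetical. The sketch also assumes `Z` and `O` are indexed by pixel, with the label column excluded. If the counts are built over whole CSV rows, where column 0 is the label, subtract 1 from the index before converting it:

```python
def most_discriminative(Z: list[int], O: list[int], width: int = 28) -> tuple[int, int, int]:
    """Return (i, row, col) of the pixel maximizing abs(Z[i] - O[i]).

    Assumes Z and O are indexed by pixel (label column excluded).
    On ties, max() keeps the first index found.
    """
    best = max(range(len(Z)), key=lambda i: abs(Z[i] - O[i]))
    row, col = divmod(best, width)
    return best, row, col

# Toy 2 x 2 "images" (width=2) just to show the mechanics
Z = [5, 0, 1, 9]
O = [4, 2, 1, 0]
print(most_discriminative(Z, O, width=2))  # (3, 1, 1)
```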

In [12]:
one_rows = list(filter(lambda row: row[0] == "1", rows))
print("Number of ones", len(one_rows))

zero_rows = list(filter(lambda row: row[0] == "0", rows))
print("Number of zeros", len(zero_rows))

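def count_black_pixels(samples):
    # black_pixels[c] counts how many samples have a value > 127 (i.e. >= 128)
    # in CSV column c. Column 0 is the label (0 or 1 here, so its count stays 0):
    # pixel p of the image is therefore stored at column p + 1.
    black_pixels = [0] * len(samples[0])
    for digit in samples:
        for i, pixel in enumerate(digit):
            black_pixels[i] += float(pixel) > 127
    return black_pixels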

O = count_black_pixels(one_rows)
Z = count_black_pixels(zero_rows)

print("First rows count sum for 1s")
print(O[:28])
print(O[28:28*2])
print(O[28*2:28*3])
print(O[28*3:28*4])

print("First rows count sum for 0s")
print(Z[:28])
print(Z[28:28*2])
print(Z[28*2:28*3])
print(Z[28*3:28*4])



Number of ones 1135
Number of zeros 980
First rows count sum for 1s
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 4, 5, 3, 2, 6, 4, 4, 1, 0, 0, 0, 0, 0, 0]
First rows count sum for 0s
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 3, 9, 15, 21, 27, 20, 13, 10, 7, 1, 4, 2, 0, 0, 0]


In [13]:
highest_pixel = -1
highest_dist = -1
distances = [0]*len(O)
for i in range(len(distances)):
    distances[i] = abs(Z[i] - O[i])

    print(f"{i}) d = {distances[i]}, O: {O[i]}, Z: {Z[i]}")

    if distances[i] > highest_dist:
        highest_dist = distances[i]
        highest_pixel = i

# highest_pixel is a CSV column index: column 0 is the label, so the pixel index is column - 1
pixel = highest_pixel - 1
print(f"Highest dist {highest_dist} at CSV column {highest_pixel}, i.e. pixel {pixel}, located at row {pixel // 28}, col {pixel % 28}")

0) d = 0, O: 0, Z: 0
1) d = 0, O: 0, Z: 0
2) d = 0, O: 0, Z: 0
3) d = 0, O: 0, Z: 0
4) d = 0, O: 0, Z: 0
5) d = 0, O: 0, Z: 0
6) d = 0, O: 0, Z: 0
7) d = 0, O: 0, Z: 0
8) d = 0, O: 0, Z: 0
9) d = 0, O: 0, Z: 0
10) d = 0, O: 0, Z: 0
11) d = 0, O: 0, Z: 0
12) d = 0, O: 0, Z: 0
13) d = 0, O: 0, Z: 0
14) d = 0, O: 0, Z: 0
15) d = 0, O: 0, Z: 0
16) d = 0, O: 0, Z: 0
17) d = 0, O: 0, Z: 0
18) d = 0, O: 0, Z: 0
19) d = 0, O: 0, Z: 0
20) d = 0, O: 0, Z: 0
21) d = 0, O: 0, Z: 0
22) d = 0, O: 0, Z: 0
23) d = 0, O: 0, Z: 0
24) d = 0, O: 0, Z: 0
25) d = 0, O: 0, Z: 0
26) d = 0, O: 0, Z: 0
27) d = 0, O: 0, Z: 0
28) d = 0, O: 0, Z: 0
29) d = 0, O: 0, Z: 0
30) d = 0, O: 0, Z: 0
31) d = 0, O: 0, Z: 0
32) d = 0, O: 0, Z: 0
33) d = 0, O: 0, Z: 0
34) d = 0, O: 0, Z: 0
35) d = 0, O: 0, Z: 0
36) d = 0, O: 0, Z: 0
37) d = 0, O: 0, Z: 0
38) d = 0, O: 0, Z: 0
39) d = 0, O: 0, Z: 0
40) d = 0, O: 0, Z: 0
41) d = 0, O: 0, Z: 0
42) d = 0, O: 0, Z: 0
43) d = 0, O: 0, Z: 0
44) d = 0, O: 0, Z: 0
45) d = 0, O: 0, Z: 

6. (*) Extract a subset of the MNIST dataset composed of only 0 and 1 digits. Create a rule-based classifier that takes as input the rule you discovered in Exercise 5. Then, as before, compute the predictions of this classifier on all the samples of the subset, and its accuracy.

In [14]:
class RuleModel:
    def __init__(self, default_class):
        """
        Create the rule-based model.
        :param default_class: default class when no rule applies
        """
        # Rules are stored per instance: a class-level `__rules = []` list
        # would be shared by every RuleModel created from this class
        self.__rules = []
        self.__default_class = default_class

    def add_rule(self, rule, output_class):
        """
        Add rule to the model.
        :param rule: lambda function with the conditions on the input sample
        :param output_class: output label to be assigned when the rule is satisfied
        """
        self.__rules.append((rule, output_class))

    def predict(self, x):
        """
        Apply rules to a sample. The first rule that applies represents the output label.
        :param x: list of values representing the input sample (one CSV row)
        """
        for rule, out_class in self.__rules:
            if rule(x):
                return out_class
        return self.__default_class
    
    
rule_clf_mnist = RuleModel('0')

# Add rules
rule_clf_mnist.add_rule(lambda x: float(x[407]) > 128, "1")


# Compute prediction
correct_predictions = 0
for row in one_rows + zero_rows:
    prediction = rule_clf_mnist.predict(row)
    label = row[0]
    if prediction == label:
        correct_predictions += 1
    else:
        print(f"Error, predicted {prediction}, actual label {label}")

# Compute accuracy
accuracy = correct_predictions / len(one_rows + zero_rows)
print(f"Accuracy of the model: {accuracy}")

Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 0, actual label 1
Error, predicted 1, actual label 0
Error, predicted 1, actual label 0
Error, predicted 1, actual label 0
Error, predicted 1, actual label 0
Error, predicted 1, actual label 0
Error, predicted 1, actual label 0
Error, predicted 1, actual label 0
Accuracy of the model: 0.9895981087470449
