In [1]:
import sys
import os
import random
import csv
from typing import List

# Make the directory two levels up importable so the project-local `lvq` and
# `data_util` modules resolve. The membership guard keeps the cell idempotent:
# re-running it does not keep appending duplicate entries to sys.path.
module_path = os.path.abspath(os.path.join("..", ".."))
if module_path not in sys.path:
    sys.path.append(module_path)

import lvq
import data_util


# Load data

In [2]:
# Parse the CSV: every non-blank row is a run of feature columns (converted to
# float) followed by a trailing label column. Blank lines are skipped before
# unpacking, since an empty row cannot be split into features + label.
# newline="" lets the csv module handle line endings itself, as its docs require.
with open("iris.csv", newline="") as f:
    dataset = [
        [float(value) for value in features] + [label]
        for *features, label in (row for row in csv.reader(f) if row)
    ]
assert dataset, "iris.csv contained no data rows"

# Map the string labels to integer classes; `mapping` records the correspondence.
mapping, encoded = data_util.encode_labels(dataset)

# NOTE(review): `normalized` is not consumed by any cell shown here. Model
# initialization, training and cross-validation all use `encoded`, i.e. the raw
# feature scale. LVQ is distance-based, so confirm whether `normalized` was
# meant to be used downstream.
normalized = data_util.normalize(encoded)


# Initialize model

In [3]:
# Model hyper-parameters; also reused unchanged by the cross-validation cell.
model_config = dict(
    codebook_size=6,
    features_count=4,  # Unused if codebook_init_method == "sample"
    labels_count=3,
    codebook_init_method="sample",
    codebook_init_dataset=encoded,  # Needed only in case codebook_init_method == "sample"
)

# Seed *before* constructing the model. With codebook_init_method="sample", the
# initial codebook is presumably drawn at random from `encoded` (TODO confirm in
# lvq.LVQ). Seeding only afterwards left the codebook unreproducible on a fresh
# Restart & Run All.
random.seed(0)
model = lvq.LVQ(**model_config)

# Draw the demo sample from a dedicated generator seeded with the same value.
# It yields the same sample as `random.seed(0); random.choice(encoded)`, no
# matter how many draws the model constructor consumed from the global RNG.
# Note: with sampled initialization the demo sample can itself be a codebook
# vector, in which case the prediction below is trivially correct.
sample = random.Random(0).choice(encoded)
*features, label = sample

print("Random sample:")
print(sample)
print("Prediction:", model.predict(features))

print("\nInitialized codebook:")
for row in model.codebook:
    print(row)


Random sample:
[5.1, 2.5, 3.0, 1.1, 0]
Prediction: 0

Initialized codebook:
[5.7, 2.8, 4.1, 1.3, 0]
[5.0, 3.3, 1.4, 0.2, 1]
[5.9, 3.0, 5.1, 1.8, 2]
[5.1, 2.5, 3.0, 1.1, 0]
[5.3, 3.7, 1.5, 0.2, 1]
[6.2, 3.4, 5.4, 2.3, 2]


# Train model

In [4]:
# Training hyper-parameters; the cross-validation cell reuses this dict as well.
train_config = {
    "base_learning_rate": 0.1,
    "learning_rate_decay": "linear",
    "epochs": 10,
}

# Reset the global RNG so any randomness used during training does not depend
# on how much the earlier cells consumed from it.
random.seed(0)
model.train_codebook(train_vectors=encoded, **train_config)

# Re-check the demo sample from the previous cell after training. The sample
# was drawn from `encoded`, i.e. the training data, so this is not a held-out check.
print("Random sample:", sample, sep="\n")
print("Prediction:", model.predict(features))


Training: 100% |████████████████████████████████████████████████████████| 10/10, acc=0.953, sse=6.29

Random sample:
[5.1, 2.5, 3.0, 1.1, 0]
Prediction: 0





# Cross-validation

In [5]:
# 3-fold cross-validation reusing the model and training settings above.
# NOTE(review): model_config carries codebook_init_dataset=encoded (the whole
# dataset). Unless lvq.cross_validate replaces it with each fold's training
# split, every fold's codebook is sampled from data that includes its own
# validation rows, which would bias these accuracies upward. Confirm in
# lvq.cross_validate.
random.seed(0)
scores = lvq.cross_validate(encoded, fold_count=3, **model_config, **train_config)

# One rounded accuracy per fold.
print("Validation accuracies:")
for fold_accuracy in scores:
    print(round(fold_accuracy, 3))


Training: 100% |█████████████████████████████████████████████████████████| 10/10, acc=0.96, sse=4.43
Training: 100% |█████████████████████████████████████████████████████████| 10/10, acc=0.96, sse=3.49
Training: 100% |█████████████████████████████████████████████████████████| 10/10, acc=0.98, sse=3.97

Validation accuracies:
0.9
0.96
0.96



