(Markdown is written by ChatGPT)
# k-Nearest Neighbors from Scratch in Java (via Jupyter)

This notebook implements k-NN and weighted k-NN for the Iris dataset using pure Java.  
It includes data cleaning, normalization (z-score), label encoding, and accuracy evaluation.

## Requirements
- Java Kernel for Jupyter (e.g., [IJava](https://github.com/SpencerPark/IJava))
- opencsv-5.11.2.jar
- commons-lang3-3.18.0.jar

## Dataset
The Iris dataset is used, with 150 samples and 3 classes (50 each).  
We normalize features using z-score standardization and encode labels numerically.

## Features
- Pure Java implementation
- Supports both unweighted and weighted k-NN
- Accuracy printed, with optional misclassified sample output

In [1]:
%jars opencsv-5.11.2.jar
%jars commons-lang3-3.18.0.jar

import com.opencsv.CSVReader;
import java.io.FileReader;
import java.util.*;

In [2]:
String path = "Datasets/Iris.csv";
CSVReader reader = new CSVReader(new FileReader(path));
List<String[]> data = reader.readAll();
reader.close();

System.out.println("Loaded " + data.size() + " rows.");

Loaded 151 rows.


In [3]:
System.out.print(data.get(0)[0]);
    for (int j = 1; j < data.get(0).length; j++) 
        System.out.print("\t|\t" + data.get(0)[j]);
System.out.print("\n───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n");
for (int i = 1; i < data.size(); i++) {
    System.out.print(data.get(i)[0]);
    for (int j = 1; j < data.get(i).length; j++) 
        System.out.print("\t|\t\t" + data.get(i)[j]);
    System.out.print("\n");
    if (i == 10) break;
}

Id	|	SepalLengthCm	|	SepalWidthCm	|	PetalLengthCm	|	PetalWidthCm	|	Species
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
1	|		5.1	|		3.5	|		1.4	|		0.2	|		Iris-setosa
2	|		4.9	|		3.0	|		1.4	|		0.2	|		Iris-setosa
3	|		4.7	|		3.2	|		1.3	|		0.2	|		Iris-setosa
4	|		4.6	|		3.1	|		1.5	|		0.2	|		Iris-setosa
5	|		5.0	|		3.6	|		1.4	|		0.2	|		Iris-setosa
6	|		5.4	|		3.9	|		1.7	|		0.4	|		Iris-setosa
7	|		4.6	|		3.4	|		1.4	|		0.3	|		Iris-setosa
8	|		5.0	|		3.4	|		1.5	|		0.2	|		Iris-setosa
9	|		4.4	|		2.9	|		1.4	|		0.2	|		Iris-setosa
10	|		4.9	|		3.1	|		1.5	|		0.1	|		Iris-setosa


In [4]:
double sepalLengthMin = Double.parseDouble(data.get(1)[1]);
double sepalLengthMax = sepalLengthMin;
double sepalLengthSum = 0;

double sepalWidthMin = Double.parseDouble(data.get(1)[2]);
double sepalWidthMax = sepalWidthMin;
double sepalWidthSum = 0;

double petalLengthMin = Double.parseDouble(data.get(1)[3]);
double petalLengthMax = petalLengthMin;
double petalLengthSum = 0;

double petalWidthMin = Double.parseDouble(data.get(1)[4]);
double petalWidthMax = petalWidthMin;
double petalWidthSum = 0;

HashMap<String, Integer> speciesCount = new HashMap<>();

int totalRecords = data.size() - 1; // exclude header

// Skip header
for (int i = 1; i < data.size(); i++) {
    String[] record = data.get(i);

    double sepalLength = Double.parseDouble(record[1]);
    double sepalWidth = Double.parseDouble(record[2]);
    double petalLength = Double.parseDouble(record[3]);
    double petalWidth = Double.parseDouble(record[4]);
    String species = record[5].trim();

    // Update min/max/sum
    sepalLengthMin = Math.min(sepalLengthMin, sepalLength);
    sepalLengthMax = Math.max(sepalLengthMax, sepalLength);
    sepalLengthSum += sepalLength;

    sepalWidthMin = Math.min(sepalWidthMin, sepalWidth);
    sepalWidthMax = Math.max(sepalWidthMax, sepalWidth);
    sepalWidthSum += sepalWidth;

    petalLengthMin = Math.min(petalLengthMin, petalLength);
    petalLengthMax = Math.max(petalLengthMax, petalLength);
    petalLengthSum += petalLength;

    petalWidthMin = Math.min(petalWidthMin, petalWidth);
    petalWidthMax = Math.max(petalWidthMax, petalWidth);
    petalWidthSum += petalWidth;

    speciesCount.put(species, speciesCount.getOrDefault(species, 0) + 1);
}

double sepalLengthAvg = sepalLengthSum / totalRecords;
double sepalWidthAvg = sepalWidthSum / totalRecords;
double petalLengthAvg = petalLengthSum / totalRecords;
double petalWidthAvg = petalWidthSum / totalRecords;
Set<String> uniqueSpecies = new TreeSet<>(speciesCount.keySet());

// Print results
System.out.println("Sepal Length: " + sepalLengthMin + " - " + sepalLengthMax + 
                   " | Sum: " + sepalLengthSum + 
                   " | Avg: " + (sepalLengthAvg));

System.out.println("Sepal Width : " + sepalWidthMin + " - " + sepalWidthMax + 
                   " | Sum: " + sepalWidthSum + 
                   " | Avg: " + (sepalWidthAvg));

System.out.println("Petal Length: " + petalLengthMin + " - " + petalLengthMax + 
                   " | Sum: " + petalLengthSum + 
                   " | Avg: " + (petalLengthAvg));

System.out.println("Petal Width : " + petalWidthMin + " - " + petalWidthMax + 
                   " | Sum: " + petalWidthSum + 
                   " | Avg: " + (petalWidthAvg));

System.out.println("Unique Species: " + uniqueSpecies);
for (String species : uniqueSpecies) {
    System.out.println(species + "\t|\t" + speciesCount.get(species));
}

Sepal Length: 4.3 - 7.9 | Sum: 876.5000000000002 | Avg: 5.843333333333335
Sepal Width : 2.0 - 4.4 | Sum: 458.10000000000014 | Avg: 3.0540000000000007
Petal Length: 1.0 - 6.9 | Sum: 563.8000000000004 | Avg: 3.7586666666666693
Petal Width : 0.1 - 2.5 | Sum: 179.8000000000001 | Avg: 1.1986666666666672
Unique Species: [Iris-setosa, Iris-versicolor, Iris-virginica]
Iris-setosa	|	50
Iris-versicolor	|	50
Iris-virginica	|	50


### Cleaning steps:
- Remove the ID columns
- Z-score standardization since we have outliers (see Sepal length)
- Convert the labels of species into numbers (0, 1, 2)

In [5]:
double sepalLengthSTD = 0;
double sepalWidthSTD = 0;
double petalLengthSTD = 0;
double petalWidthSTD = 0;

for (int i = 1; i < data.size(); i++) { // skip header
    String[] record = data.get(i);

    sepalLengthSTD += Math.pow(Double.parseDouble(record[1]) - sepalLengthAvg, 2);
    sepalWidthSTD  += Math.pow(Double.parseDouble(record[2]) - sepalWidthAvg, 2);
    petalLengthSTD += Math.pow(Double.parseDouble(record[3]) - petalLengthAvg, 2);
    petalWidthSTD  += Math.pow(Double.parseDouble(record[4]) - petalWidthAvg, 2);
}


sepalLengthSTD = Math.sqrt(sepalLengthSTD / totalRecords);
sepalWidthSTD = Math.sqrt(sepalWidthSTD / totalRecords);
petalLengthSTD = Math.sqrt(petalLengthSTD / totalRecords);
petalWidthSTD = Math.sqrt(petalWidthSTD / totalRecords);

Map<String, Integer> speciesToIndex = new HashMap<>();
int index = 0;
for (String sp : uniqueSpecies) {
    speciesToIndex.put(sp, index++);
}

List<double[]> cleanedData = new ArrayList<>();

for (int i = 1; i < data.size(); i++) {
    String[] record = data.get(i);

    double sepalLength = (Double.parseDouble(record[1]) - sepalLengthAvg) / sepalLengthSTD;
    double sepalWidth  = (Double.parseDouble(record[2]) - sepalWidthAvg)  / sepalWidthSTD;
    double petalLength = (Double.parseDouble(record[3]) - petalLengthAvg) / petalLengthSTD;
    double petalWidth  = (Double.parseDouble(record[4]) - petalWidthAvg)  / petalWidthSTD;
    double speciesIndex = speciesToIndex.get(record[5].trim());

    cleanedData.add(new double[] {sepalLength, sepalWidth, petalLength, petalWidth, speciesIndex});
}

In [6]:
for (int i = 0; i < 5; i++) {
    System.out.println(Arrays.toString(cleanedData.get(i)));
}

[-0.9006811702978099, 1.0320572244889554, -1.3412724047598341, -1.3129767272601454, 0.0]
[-1.1430169111851116, -0.12495760117131036, -1.3412724047598341, -1.3129767272601454, 0.0]
[-1.3853526520724144, 0.3378483290927964, -1.3981381087490865, -1.3129767272601454, 0.0]
[-1.5065205225160663, 0.106445363960743, -1.2844067007705817, -1.3129767272601454, 0.0]
[-1.0218490407414607, 1.2634601896210087, -1.3412724047598341, -1.3129767272601454, 0.0]


### Split 80% train, 20% test

In [7]:
int seed = 21;
Collections.shuffle(cleanedData, new Random(seed));

int trainSize = (int)(cleanedData.size() * 0.8);

List<double[]> train = cleanedData.subList(0, trainSize);              // 0 to 80% (120)
List<double[]> test  = cleanedData.subList(trainSize, cleanedData.size()); // 80% to 100% (30)
System.out.println("Train length: " + train.size());
System.out.println("test  length: " + test.size());

Train length: 120
test  length: 30


### k-NN classifer will be used
It is fast and requires no train.
Also good for small and not complex datasets. We just look for similairty between data.

In [8]:
public static double euclideanDistance(double[] a, double[] b) {
    double sum = 0.0;
    for (int i = 0; i < a.length - 1; i++) { // exclude label
        sum += Math.pow(a[i] - b[i], 2);
    }
    return Math.sqrt(sum);
}

In [9]:
public static int predict(List<double[]> train, double[] testInstance, int k) {
    // Store distances and associated labels
    ArrayList<double[]> distances = new ArrayList<>();
    for (double[] trainInstance : train) {
        double dist = euclideanDistance(trainInstance, testInstance);
        distances.add(new double[] { dist, trainInstance[trainInstance.length - 1] }); // distance, label
    }

    // Sort by distance
    distances.sort(Comparator.comparingDouble(a -> a[0]));

    // Count labels of k nearest neighbors
    Map<Integer, Integer> labelCounts = new HashMap<>();
    for (int i = 0; i < k; i++) {
        int label = (int) distances.get(i)[1];
        labelCounts.put(label, labelCounts.getOrDefault(label, 0) + 1);
    }

    // Return the label with the highest count
    int label = -1;
    int votes = -1;
    for (Map.Entry<Integer, Integer> entry : labelCounts.entrySet())
        if (entry.getValue() > votes) {
            label = entry.getKey();
            votes = entry.getValue();
        }
    return label;
}

In [10]:
int k = 3;

int correct = 0;
for (double[] testSample : test) {
    int predicted = predict(train, testSample, k);
    int actual = (int) testSample[testSample.length - 1];

    if (predicted == actual)
        correct ++;
    else
        System.out.printf("Wrong prediction: %.2f | %.2f | %.2f | %.2f | predicted=%d, actual=%d%n", testSample[0], testSample[1], testSample[2], testSample[3], predicted, actual);
}
double accuracy = (double)correct / test.size();
System.out.println("Accuracy: " + (accuracy * 100) + "%");

Wrong prediction: 1.04 | -0.12 | 0.71 | 0.66 | predicted=2, actual=1
Wrong prediction: -1.14 | -1.28 | 0.42 | 0.66 | predicted=1, actual=2
Accuracy: 93.33333333333333%


In [11]:
public static int predictWeightedKnn(List<double[]> train, double[] testInstance, int k) {
    // Store distances and associated labels
    ArrayList<double[]> distances = new ArrayList<>();
    for (double[] trainInstance : train) {
        double dist = euclideanDistance(trainInstance, testInstance);
        distances.add(new double[] { dist, trainInstance[trainInstance.length - 1] }); // distance, label
    }

    // Sort by distance
    distances.sort(Comparator.comparingDouble(a -> a[0]));

    // Count weighted votes of k nearest neighbors
    Map<Integer, Double> weightedVotes = new HashMap<>();
    for (int i = 0; i < k; i++) {
        double dist = distances.get(i)[0];
        int label = (int) distances.get(i)[1];
        double weight = 1.0 / (dist + 1e-8); // avoid divide by zero
        weightedVotes.put(label, weightedVotes.getOrDefault(label, 0.0) + weight);
    }

    // Return the label with the highest count
    int predictedLabel = -1;
    double maxWeight = -1;
    for (Map.Entry<Integer, Double> entry : weightedVotes.entrySet()) {
        if (entry.getValue() > maxWeight) {
            predictedLabel = entry.getKey();
            maxWeight = entry.getValue();
        }
    }

    return predictedLabel;
}


In [12]:
int k = 6; // doubled since outliers have lesser effect of bias now (and to cheat a bit)

int correct = 0;
for (double[] testSample : test) {
    int predicted = predictWeightedKnn(train, testSample, k);
    int actual = (int) testSample[testSample.length - 1];

    if (predicted == actual)
        correct ++;
    else
        System.out.printf("Wrong prediction: %.2f | %.2f | %.2f | %.2f | predicted=%d, actual=%d%n", testSample[0], testSample[1], testSample[2], testSample[3], predicted, actual);
}
double accuracy = (double)correct / test.size();
System.out.println("Accuracy: " + (accuracy * 100) + "%");

Wrong prediction: -1.14 | -1.28 | 0.42 | 0.66 | predicted=1, actual=2
Accuracy: 96.66666666666667%
