In [None]:
:dep csv = { version = "1.1" }
:dep ndarray = { version = "0.13.1" }
:dep linfa = { version = "0.7.0" }
:dep linfa-trees = { version = "0.7" }
:dep ndarray-rand = { version = "0.15.0" }
:dep plotters = { version = "0.3.7" }

extern crate csv;
extern crate ndarray;
extern crate linfa;
extern crate linfa-trees;
extern crate ndarray-rand;
extern crate plotters;


In [None]:
use csv::ReaderBuilder;
use ndarray::{Array1, Array2, s};
use linfa::prelude::*;
use linfa_trees::DecisionTree;
use std::collections::HashMap;
use std::error::Error;
use plotters::prelude::*;

/// Encodes string class labels as dense integer ids.
///
/// Ids are handed out in order of first appearance, starting at 0, so the
/// first distinct label seen becomes 0, the next distinct label 1, and so on.
///
/// Returns the encoded labels (same order and length as `labels`) together
/// with the label -> id lookup table that was built along the way.
fn encode_labels(labels: Vec<String>) -> (Array1<usize>, HashMap<String, usize>) {
    let mut label_map: HashMap<String, usize> = HashMap::new();
    let mut codes: Vec<usize> = Vec::new();

    for label in labels {
        // The map holds exactly one entry per id issued so far, so its
        // current size is the next unused id.
        let next_id = label_map.len();
        let id = *label_map.entry(label).or_insert(next_id);
        codes.push(id);
    }

    (Array1::from_vec(codes), label_map)
}

/// Reads a CSV file into a feature matrix and a vector of raw string labels.
///
/// The first line of the file is treated as a header and skipped. In every
/// following record the last field is the label and all preceding fields are
/// features that must parse as `f64`.
///
/// Records in which any feature field fails to parse (or which have no fields
/// at all) are skipped rather than aborting the load; the number of skipped
/// records is reported on stderr so the data loss is not silent.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, if the CSV reader fails on
/// a record, or if the collected values cannot be arranged into a
/// `rows x columns` matrix.
fn read_csv_to_ndarray(file_path: &str) -> Result<(Array2<f64>, Vec<String>), Box<dyn Error>> {
    let mut reader = ReaderBuilder::new()
        .has_headers(true)
        .from_path(file_path)?;

    // Features are stored flat in row-major order and reshaped at the end.
    let mut features: Vec<f64> = Vec::new();
    let mut labels: Vec<String> = Vec::new();
    let mut skipped_rows: usize = 0;

    for result in reader.records() {
        let record = result?;

        // Index of the label column; a record without any field has none.
        let label_idx = match record.len().checked_sub(1) {
            Some(idx) => idx,
            None => {
                skipped_rows += 1;
                continue;
            }
        };

        // Collecting into `Result` stops at the first field that fails to parse.
        let parsed: Result<Vec<f64>, _> = record
            .iter()
            .take(label_idx)
            .map(|field| field.parse::<f64>())
            .collect();

        match parsed {
            Ok(row) => {
                features.extend(row);
                labels.push(record[label_idx].to_string());
            }
            Err(_) => skipped_rows += 1,
        }
    }

    if skipped_rows > 0 {
        eprintln!(
            "Warning: skipped {} record(s) with missing or non-numeric feature values",
            skipped_rows
        );
    }

    // Column count is recovered from the flat buffer; this relies on every
    // kept record contributing the same number of features.
    let num_cols = if labels.is_empty() { 0 } else { features.len() / labels.len() };
    let feature_array = Array2::from_shape_vec((labels.len(), num_cols), features)?;

    Ok((feature_array, labels))
}

/// Splits samples into a training part and a testing part.
///
/// The first 80% of the rows (rounded down) become the training set and the
/// remaining rows the test set. Rows keep their original order; nothing is
/// shuffled.
///
/// Returns `(x_train, y_train, x_test, y_test)`.
fn split_data(data: Array2<f64>, targets: Array1<usize>) -> (Array2<f64>, Array1<usize>, Array2<f64>, Array1<usize>) {
    // Row index at which the test portion begins.
    let boundary = (data.nrows() as f64 * 0.8) as usize;

    let train_features = data.slice(s![..boundary, ..]).to_owned();
    let test_features = data.slice(s![boundary.., ..]).to_owned();

    let train_targets = targets.slice(s![..boundary]).to_owned();
    let test_targets = targets.slice(s![boundary..]).to_owned();

    (train_features, train_targets, test_features, test_targets)
}

/// Renders the feature values of every sample to an 800x600 bitmap file.
///
/// Each sample is drawn as one series of circles: the x coordinate is the
/// feature (column) index and the y coordinate is the feature value. The
/// series colour is picked from `Palette99` using the sample's encoded label,
/// so samples of the same class share a colour. Non-finite values are not
/// drawn.
///
/// # Errors
///
/// Returns an error if `features` and `targets` disagree on the number of
/// samples, or if any drawing or file-writing step fails.
fn plot_data(features: &Array2<f64>, targets: &Array1<usize>, file_name: &str) -> Result<(), Box<dyn Error>> {
    if features.nrows() != targets.len() {
        return Err(format!(
            "feature rows ({}) and target count ({}) differ",
            features.nrows(),
            targets.len()
        )
        .into());
    }

    // The y axis must span the actual feature values; only finite values are
    // considered so a stray NaN/inf cannot poison the range.
    let (lo, hi) = features
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
            (lo.min(v), hi.max(v))
        });
    let (y_min, y_max) = if lo > hi {
        // No finite value at all: fall back to an arbitrary unit range.
        (0.0, 1.0)
    } else if lo == hi {
        // Widen a degenerate range so the axis has a non-zero extent.
        (lo - 0.5, hi + 0.5)
    } else {
        (lo, hi)
    };
    // Keep the x axis non-degenerate even when there are no feature columns.
    let x_end = features.ncols().max(1);

    let root = BitMapBackend::new(file_name, (800, 600)).into_drawing_area();
    root.fill(&WHITE)?;
    let mut chart = ChartBuilder::on(&root)
        .caption("Feature Relationships", ("sans-serif", 40).into_font())
        .margin(10)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(0..x_end, y_min..y_max)?;

    chart
        .configure_mesh()
        .x_desc("Feature index")
        .y_desc("Feature value")
        .draw()?;

    for (row, &class) in features.outer_iter().zip(targets.iter()) {
        let color = Palette99::pick(class);
        chart.draw_series(PointSeries::of_element(
            row.iter()
                .enumerate()
                .filter(|(_, value)| value.is_finite())
                .map(|(col, &value)| (col, value)),
            5,
            &color,
            &|coord, size, style| {
                EmptyElement::at(coord) + Circle::new((0, 0), size, style.filled())
            },
        ))?;
    }

    // Flush explicitly so a failure to write the image surfaces as an error
    // instead of being lost when the backend is dropped.
    root.present()?;
    Ok(())
}

/// Runs the whole pipeline: load `data/train.csv`, encode the labels, split
/// 80/20, train a decision tree (max depth 10), report accuracy on the test
/// split, and plot the training data to `train_plot.png`.
///
/// # Errors
///
/// Propagates any failure from loading the CSV, training the model, or
/// writing the plot.
fn main() -> Result<(), Box<dyn Error>> {
    let file_path = "data/train.csv";

    println!("Reading data from CSV...");
    let (data, labels) = read_csv_to_ndarray(file_path)?;

    println!("Encoding labels...");
    let (encoded_labels, label_map) = encode_labels(labels);

    println!("Splitting data...");
    let (x_train, y_train, x_test, y_test) = split_data(data, encoded_labels);

    println!("Creating datasets...");
    // The arrays are moved into the datasets; later uses borrow them back
    // through the dataset accessors instead of keeping cloned copies around.
    let train_dataset = Dataset::from((x_train, y_train));
    let test_dataset = Dataset::from((x_test, y_test));

    println!("Training model...");
    let model = DecisionTree::params()
        .max_depth(Some(10))
        .min_weight_leaf(1.0)
        .fit(&train_dataset)
        .map_err(|e| format!("Failed to train model: {}", e))?;

    println!("Making predictions...");
    let predictions = model.predict(&test_dataset);

    let total = test_dataset.targets().len();
    let correct = predictions
        .iter()
        .zip(test_dataset.targets().iter())
        .filter(|(pred, actual)| pred == actual)
        .count();

    if total == 0 {
        // Avoid printing "NaN%" when the split left no test samples.
        println!("Model accuracy: n/a (test set is empty)");
    } else {
        let accuracy = correct as f64 / total as f64;
        println!("Model accuracy: {:.2}%", accuracy * 100.0);
    }
    println!("Label encoding map: {:?}", label_map);

    // Visualize the training data
    plot_data(train_dataset.records(), train_dataset.targets(), "train_plot.png")?;

    Ok(())
}