# Imports

In [None]:
import sys; sys.path.append("..")
from pathlib import Path
from PIL import Image

import sklearn.metrics
import numpy as np

import src.model

# Helpers

In [None]:
def read_labels_file(path):
  """Read a comma-separated ``path,label`` text file.

  Every non-blank line is expected to hold two comma-separated fields:
  a path followed by an integer label. Blank lines (for example a
  trailing empty line at the end of the file) are ignored.

  Args:
    path: Location of the labels file; anything ``open`` accepts.

  Returns:
    A ``(paths, labels)`` pair where ``paths`` is a tuple of strings and
    ``labels`` is a list of ints, both in file order. A file with no
    data rows yields ``((), [])``.

  Raises:
    ValueError: If the rows do not unpack into exactly two columns or a
      label cannot be converted with ``int``.
  """
  with open(path, 'r') as f:
    # Drop blank lines up front: a single empty row would otherwise make
    # ``zip(*rows)`` collapse to one column and break the unpacking below.
    rows = [
      line.rstrip("\n").split(',')
      for line in f
      if line.strip()
    ]
  if not rows:
    # ``zip()`` of nothing produces no columns at all, so return the empty
    # result explicitly instead of failing on the tuple unpacking.
    return (), []
  paths, labels = zip(*rows)
  return paths, list(map(int, labels))

def read_img(path):
  """Load the image at ``path`` and return its contents as a NumPy array.

  Args:
    path: Image location; anything ``PIL.Image.open`` accepts.

  Returns:
    The result of ``np.array`` applied to the opened image.
  """
  # Image.open keeps the underlying file open; the context manager closes
  # it as soon as the pixel data has been copied into the array instead of
  # leaving the handle to the garbage collector.
  with Image.open(path) as im_frame:
    return np.array(im_frame)

# Data Loading

In [None]:
# Label files for each split, relative to the notebook's directory.
training_images_meta = Path('..') / 'data' / 'train' / 'labels.txt'
testing_images_meta = Path('..') / 'data' / 'test' / 'labels.txt'
# Cell output: whether each label file is present on disk.
tuple(meta.exists() for meta in (training_images_meta, testing_images_meta))

In [None]:
# Parse both label files into (paths, labels) pairs.
train_paths, train_labels = read_labels_file(training_images_meta)
test_paths, test_labels = read_labels_file(testing_images_meta)
# Cell output: sizes of the four sequences, as a quick sanity check.
tuple(map(len, (train_paths, train_labels, test_paths, test_labels)))

# Train Model

In [None]:
# Build the classifier from the project's src.model module.
model = src.model.ImgEmbeddingKnn(
  base_model="MobileNetV3Small",
  n_neighbors=3,
)
# Only the first 10 training samples are passed to train().
model.train(
  train_paths[:10],
  train_labels[:10],
)

# Evaluate Model

In [None]:
# Accuracy is reported on the first 10 test samples only, so restrict the
# inputs before predicting rather than running the model over the whole
# test set and discarding everything past the tenth prediction.
eval_paths, eval_labels = test_paths[:10], test_labels[:10]
y_pred = [model.predict(p) for p in eval_paths]
acc = sklearn.metrics.accuracy_score(eval_labels, y_pred)
print(f"Model accuracy = {acc}")

## Save Model

In [None]:
# Re-execute src.model in the running kernel so that edits made to the module
# on disk are picked up without restarting the notebook.
# NOTE(review): importlib.reload does not rebind existing objects, so the
# `model` instance created earlier still uses the class definition from
# before the reload — confirm that is intended.
import importlib
importlib.reload(src.model)

In [None]:
# Hand save_model() the target location ../models/demo_model as a plain string.
model.save_model(str(Path("..") / "models" / "demo_model"))