### Separate results

This notebook is used to explore the disease (Alternaria) present in the field.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from configs import configs
from database.db import SQLiteDatabase
from utils.metrics import calculate_classification_metrics
from utils.plot_utils import save_plot_figure
from utils.utils import ensure_dir

In [None]:
# Name of the model whose latest stored records are analysed below.
# The other model name used with this notebook is "alternaria_clf".
model_name = "alternaria_b_clf"

db = SQLiteDatabase()
records = db.get_records(is_latest=True, model_name=model_name)

# Dataset name -> (data content, prediction content). A later record
# overwrites an earlier entry stored under the same dataset name.
container = {}
for record in records:
    for data, predictions in zip(record.data, record.predictions):
        container[data.name] = (data.content, predictions.content)

# Human-readable names for the raw class labels.
mapping = {"class 1": "Healthy", "class 2": "Infected"}

In [None]:
def countplot(data_name: str) -> None:
    """Plot label counts per variety for one dataset, or for train and test pooled.

    Args:
        data_name: Key of a dataset in ``container``, or ``"all"`` to pool the
            ``"train"`` and ``"test"`` datasets into a single plot.

    The figure is produced inside ``save_plot_figure`` with the target path
    ``<SAVE_RESULTS_DIR>/<model_name>/<data_name>_distribution.pdf``
    (presumably written when the context manager exits -- TODO confirm).
    """
    save_path = (
        ensure_dir(configs.SAVE_RESULTS_DIR / model_name)
        / f"{data_name}_distribution.pdf"
    )
    data_names = ["train", "test"] if data_name == "all" else [data_name]

    with save_plot_figure(save_path=save_path, figsize=(8, 3)) as (fig, ax):
        frames = []
        # Use a separate loop variable: reusing ``data_name`` here would
        # overwrite the argument and mislabel the title of the "all" plot.
        for split_name in data_names:
            content = container[split_name][0]
            meta = content.meta.reset_index(drop=True)
            labels = pd.DataFrame(content.target.label)
            frame = pd.concat((meta, labels), axis=1)
            frame["label"] = frame["label"].map(mapping)
            frames.append(frame)

        df = pd.concat(frames, axis=0)

        # Draw on the managed axis explicitly, matching the legend and spine
        # calls below, instead of relying on pyplot's current-axes state.
        # NOTE(review): hue order is not pinned, so the colour assigned to each
        # label may differ between plots -- consider passing ``hue_order``.
        sns.countplot(
            data=df,
            x="varieties",
            hue="label",
            palette=["darkgoldenrod", "forestgreen"],
            alpha=0.5,
            ax=ax,
        )
        ax.legend(framealpha=0)

        ax.set_xlabel("Variety")
        ax.set_ylabel("Count")
        ax.set_title(data_name)
        # Hide the top and right frame lines.
        ax.spines["right"].set_linewidth(0)
        ax.spines["top"].set_linewidth(0)

# Draw the label distribution for each split and for both splits combined.
for split_name in ("train", "test", "all"):
    countplot(split_name)


In [None]:
# Dataset whose predictions are evaluated in this cell.
data_name = "test"
meta = container[data_name][0].meta.reset_index(drop=True)
target = container[data_name][0].target

y_true = target.value.to_numpy()
y_pred = container[data_name][1].predictions

# One row per sample: its variety, true class and predicted class.
df = pd.DataFrame.from_dict(
    {
        "varieties": meta["varieties"],
        "y_true": y_true,
        "y_pred": y_pred,
    }
)

average = "weighted"
metrics = {}  # variety -> metrics aggregated with `average`
metrics_by_class = {}  # variety -> metrics computed with average=None
for name, group in df.groupby("varieties"):
    metrics_by_class[name] = calculate_classification_metrics(
        group.y_true, group.y_pred, average=None
    )
    metrics[name] = calculate_classification_metrics(
        group.y_true, group.y_pred, average=average
    )

# Lookup from positional class index to raw label. It is read from the dataset
# selected by `data_name` (not a hard-coded "test") and built once, outside the
# printing loops.
class_names = target.encoding.to_dict()

print(f"Metrics on {data_name} data:")
print("\n--> Metrics per treatment:")
for variety, metric in metrics_by_class.items():
    print(variety)
    for idx, (f1, precision, recall) in enumerate(zip(metric.f1, metric.precision, metric.recall)):
        class_label = mapping[class_names[idx]]
        print(f"   {class_label:<9} -> F1: {f1:.2f} | Precision: {precision:.2f} | Recall: {recall:.2f}")

print("\n--> Metrics per variety:")
print("Variety    | Accuracy |  F1  | Precision | Recall")
for variety, metric in metrics.items():
    print(
        f"{variety:<10} |     {metric.accuracy:.2f} | {metric.f1:.2f} "
        f"| {metric.precision:.2f}      | {metric.recall:.2f}"
    )

# Metrics over the whole dataset, without grouping by variety.
overall = calculate_classification_metrics(y_true=y_true, y_pred=y_pred, average=average)
print(f"\n--> Average metrics on:\n{overall}")