# Análise dos testes

### Imports

In [None]:
from os import makedirs
from os.path import join
from json import load, dump

from scripts.test import Test
from scripts.data import SimpleLesionData, SimpleDatasetAnalysis

import scripts.definitions as defs
import scripts.analysis as analysis

### Configuração

In [None]:
# File name of the test-results JSON to analyse. It is read from
# <defs.RESULTS_PATH>/tests/ and is also reused to derive the names of the
# analysis output file and of the generated plots.
TEST_NAME = 'LLaDerm-0.1-11B-4bit-SC_test_2025-04-08T07_21_49.243036.json'

### Carregamento dos testes e dados

In [None]:
# Test run selected by TEST_NAME, deserialized into a Test object.
with open(join(defs.RESULTS_PATH, 'tests', TEST_NAME), mode='r', encoding='utf-8') as file:
    test = Test(**load(file))

# Both lesion datasets live in the same sub-directory of the data path.
stt_data_dir = join(defs.DATA_PATH, 'stt_data')

with open(join(stt_data_dir, 'test_dataset.json'), mode='r', encoding='utf-8') as file:
    test_dataset = [SimpleLesionData(**entry) for entry in load(file)]

with open(join(stt_data_dir, 'training_dataset.json'), mode='r', encoding='utf-8') as file:
    training_dataset = [SimpleLesionData(**entry) for entry in load(file)]

# The dataset analyses are stored directly under the data path.
with open(join(defs.DATA_PATH, 'test_dataset_analysis.json'), mode='r', encoding='utf-8') as file:
    test_analysis_payload = load(file)
    test_dataset_analysis = SimpleDatasetAnalysis(**test_analysis_payload)

with open(join(defs.DATA_PATH, 'training_dataset_analysis.json'), mode='r', encoding='utf-8') as file:
    training_analysis_payload = load(file)
    training_dataset_analysis = SimpleDatasetAnalysis(**training_analysis_payload)

### Processamento dos testes

In [None]:
# The same prompt type is passed for both result sets, so look it up once.
prompt_type = test.model.prompt_type

# Structure the raw answers obtained on the test data and on the training data.
sanitized_results_on_test = analysis.structure_answers(
    prompt_type,
    test.results_on_test_data,
)
sanitized_results_on_training = analysis.structure_answers(
    prompt_type,
    test.results_on_training_data,
)

### Salvamento dos testes processados

In [None]:
# Partition the structured results on the test data by the validity flag of
# each answer.
valid_results_on_test = [result for result in sanitized_results_on_test if result.answer.valid]
invalid_results_on_test = [result for result in sanitized_results_on_test if not result.answer.valid]

print(f'Resultados válidos para testes sobre dados de teste: {len(valid_results_on_test)}')
print(f'Resultados inválidos para testes sobre dados de teste: {len(invalid_results_on_test)}')

# Same partition for the results obtained on the training data.
valid_results_on_training = [result for result in sanitized_results_on_training if result.answer.valid]
invalid_results_on_training = [result for result in sanitized_results_on_training if not result.answer.valid]

# The valid count must come from the training results themselves; it was
# previously derived from the size of the *test* result list.
print(f'Resultados válidos para testes sobre dados de treinamento: {len(valid_results_on_training)}')
print(f'Resultados inválidos para testes sobre dados de treinamento: {len(invalid_results_on_training)}')

# Bundle the four partitions together with the name of the analysed test.
test_analysis = analysis.TestAnalysis(
    test_name=TEST_NAME,
    valid_results_on_test_data=valid_results_on_test,
    valid_results_on_training_data=valid_results_on_training,
    invalid_results_on_test_data=invalid_results_on_test,
    invalid_results_on_training_data=invalid_results_on_training
)

# Make sure the output directory exists before writing into it.
makedirs(join(defs.RESULTS_PATH, 'tests', 'analysis'), exist_ok=True)

# The output file keeps the test name, with '.json' replaced by '_analysis.json'.
with open(join(defs.RESULTS_PATH, 'tests', 'analysis', TEST_NAME.replace('.json', '_analysis.json')), 'w', encoding='utf-8') as file:
    dump(test_analysis.model_dump(), file, indent=4, ensure_ascii=False)

### Associação de pares com as respostas corretas

In [None]:
# Pair every structured result obtained on the test data with its entry in the
# test dataset.
result_pairs_on_test = analysis.associate_results_with_data(test_dataset,
                                                            test.model.prompt_type,
                                                            sanitized_results_on_test)

# Results obtained on the training data are paired with the training dataset
# (they were previously paired with the test dataset).
result_pairs_on_training = analysis.associate_results_with_data(training_dataset,
                                                                test.model.prompt_type,
                                                                sanitized_results_on_training)

### Análise

In [None]:
model_name = test.tested_model
quantized = '(Quantizado)' if test.model.quantized else ''

# Pieces shared by every plot title and output path below.
title_prefix = f'{model_name} {quantized}'
plots_dir = join(defs.RESULTS_PATH, 'plots')
# NOTE(review): [:-4] removes only 'json' and keeps the trailing dot of the
# file name - confirm this is what create_confusion_matrix expects.
plot_suffix = TEST_NAME[:-4]

# Skin-lesion confusion matrix, built from the results on the test data.
skin_lesion_pairs = analysis.get_label_pairs(result_pairs_on_test, 'skin_lesion')

skin_lesion_labels = [
    analysis.sanitize_domain_class(class_key)
    for class_key in training_dataset_analysis.skin_lesion_distribution.classes
]

# TODO: Improve the visualization of the confusion matrix
skin_lesion_accuracy = analysis.create_confusion_matrix(
    skin_lesion_pairs,  # type: ignore
    skin_lesion_labels,
    f'{title_prefix} - Lesões de pele',
    join(plots_dir, f'skin_lesions_{plot_suffix}'),
)

# The risk confusion matrix is only built for the REPORT prompt type.
if test.model.prompt_type == defs.PromptType.REPORT:
    risk_pairs = analysis.get_label_pairs(result_pairs_on_test, 'risk')

    # Each label is the first element of the corresponding class key.
    risk_labels = [class_key[0] for class_key in training_dataset_analysis.risk_distribution.classes]

    risk_accuracy = analysis.create_confusion_matrix(
        risk_pairs,  # type: ignore
        risk_labels,
        f'{title_prefix} - Classificação de risco',
        join(plots_dir, f'risk_{plot_suffix}'),
    )