# Testes

### Imports

In [None]:
from os.path import join, exists
from os import makedirs
from json import load, dump
from datetime import datetime

from unsloth import FastVisionModel
from tqdm.notebook import tqdm
from PIL import Image

import torch

from scripts.authentication import authenticate_huggingface
from scripts.messages import add_inference_message, format_prompt

import experiments.notebooks.scripts.definitions as defs

### Configurações

In [None]:
# Edite as duas constantes abaixo
MODEL = defs.BASE_MODEL_NAME
QUANTIZED = False  # Isso é sobreescrito no caso de modelos treinados
TEST_ON_TRAINING = False  # Testa o modelo sobre os dados de treinamento
DETERMINISTIC = True
TEMPERATURE = 0.0

with open(join(defs.TRAINING_PATH, 'models.json'), 'r', encoding='utf-8') as file:
    models = load(file)

model_stats = models[MODEL]
model_path = ''

if model_stats['local']:
    model_path = join(defs.RESULTS_PATH, 'adapter_weights', MODEL)
else:
    model_path = MODEL

quantized = model_stats['quantized'] if model_stats['quantized'] is not None else QUANTIZED
prompt_template = ''
prompt_type = None

match model_stats['type']:
    case 'base' | 'simple_classification':
        prompt_type = defs.PromptType.SIMPLE_CLASSIFICATION
        prompt_template = defs.SIMPLE_CLASSIFICATION_PROMPT_TEMPLATE
    case 'report':
        prompt_type = defs.PromptType.REPORT
        prompt_template = defs.REPORT_PROMPT_TEMPLATE
    case _:
        raise ValueError('Invalid model type')

model_version = model_stats['version']
model_size = model_stats['size']

### Autenticação

In [None]:
# Authenticate with Hugging Face through the project helper. This cell runs
# before the model is loaded; presumably it is required for gated/private
# weights (TODO confirm against scripts.authentication).
authenticate_huggingface()

### Carregamento do dataset

In [None]:
def _load_json(*path_parts):
    """Read and parse the UTF-8 JSON file located at ``join(*path_parts)``."""
    with open(join(*path_parts), 'r', encoding='utf-8') as json_file:
        return load(json_file)


# The training-set analysis is always required: the inference prompt is built
# from it in the test-preparation cell, even when TEST_ON_TRAINING is False.
training_dataset_analysis = _load_json(defs.DATA_PATH, 'training_dataset_analysis.json')

# The training exams themselves are only needed to evaluate on training data.
if TEST_ON_TRAINING:
    training_dataset = _load_json(defs.DATA_PATH, 'stt_data', 'training_dataset.json')

test_dataset = _load_json(defs.DATA_PATH, 'stt_data', 'test_dataset.json')
test_dataset_analysis = _load_json(defs.DATA_PATH, 'test_dataset_analysis.json')

### Carregamento do modelo

In [None]:
# Load the vision model (4-bit quantised when `quantized` is truthy) together
# with its tokenizer/processor, then switch Unsloth into inference mode.
loading_options = dict(
    load_in_4bit=quantized,
    use_gradient_checkpointing='unsloth',
    random_state=defs.STATIC_RANDOM_STATE,
)
model, tokenizer = FastVisionModel.from_pretrained(model_path, **loading_options)
FastVisionModel.for_inference(model)

### Preparação do teste

In [None]:
# Build the chat messages once; the prompt is parameterised by the training-set analysis.
formatted_prompt = format_prompt(prompt_template, prompt_type, training_dataset_analysis)
messages = add_inference_message(formatted_prompt)

# File-system-safe run name:
# - removeprefix drops the literal 'unsloth/' hub prefix (str.strip would remove
#   any of those *characters* from both ends, e.g. turning 'llama_x' into 'ama_x');
# - any remaining '/' is flattened so it cannot create an unintended sub-directory;
# - ':' from the ISO timestamp is replaced because Windows forbids it in file names.
model_slug = MODEL.removeprefix('unsloth/').replace('/', '_')
timestamp = datetime.now().isoformat().replace(':', '-')
test_name = f'{model_slug}_test_{timestamp}'

test_output = {'model_name': MODEL, 'model': models[MODEL], 'results_on_test': [], 'results_on_training': []}
tests_path = join(defs.RESULTS_PATH, 'tests')

# exist_ok avoids the check-then-create race of exists() + makedirs().
makedirs(tests_path, exist_ok=True)

### Testes sobre os dados de teste

In [None]:
# The chat prompt is identical for every exam, so it is rendered once here; the
# training-data cell below reuses this `input_text` as well.
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# Generation settings shared by every exam (loop-invariant).
generation_kwargs = {
    'max_new_tokens': 2048,
    # NOTE(review): generate() builds a fresh KV cache on every call, so
    # disabling it does not isolate test cases from one another; it only makes
    # decoding slower. Kept False to stay comparable with earlier runs.
    'use_cache': False,
    'do_sample': not DETERMINISTIC,
}
# `temperature` only matters when sampling: with do_sample=False it just
# triggers a transformers warning, and 0.0 with do_sample=True is rejected.
if not DETERMINISTIC:
    generation_kwargs['temperature'] = TEMPERATURE

for exam in tqdm(test_dataset, desc='Testing on test data: '):
    image_path = join(defs.DATA_PATH, 'stt_raw_data', 'dataset', 'images', exam['image'])
    # The context manager closes the source file; convert() returns a separate in-memory image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors='pt',
    ).to('cuda')

    outputs = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + continuation. Decode only the new tokens, so
    # an answer that happens to contain the word "assistant" is not truncated
    # (as splitting the full decoded text on 'assistant' would do).
    prompt_length = inputs['input_ids'].shape[-1]
    assistant_message = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    test_output['results_on_test'].append(
        {'exam_id': exam['id'], 'image': exam['image'], 'answer': assistant_message}
    )

### Testes sobre os dados de treinamento

In [None]:
if TEST_ON_TRAINING:
    # Same decoding settings as the test-data evaluation (loop-invariant).
    # `input_text` is the chat prompt rendered in the test-data cell.
    generation_kwargs = {
        'max_new_tokens': 2048,
        # NOTE(review): the KV cache is per generate() call; disabling it only
        # slows decoding. Kept False to stay comparable with earlier runs.
        'use_cache': False,
        'do_sample': not DETERMINISTIC,
    }
    # Only forward `temperature` when sampling (it is ignored with a warning
    # when greedy, and 0.0 is rejected when do_sample=True).
    if not DETERMINISTIC:
        generation_kwargs['temperature'] = TEMPERATURE

    for exam in tqdm(training_dataset, desc='Testing on training data: '):
        image_path = join(defs.DATA_PATH, 'stt_raw_data', 'dataset', 'images', exam['image'])
        # The context manager closes the source file; convert() returns a separate in-memory image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        inputs = tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors='pt',
        ).to('cuda')

        outputs = model.generate(**inputs, **generation_kwargs)

        # Decode only the newly generated tokens, so an answer containing the
        # word "assistant" is not truncated.
        prompt_length = inputs['input_ids'].shape[-1]
        assistant_message = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

        test_output['results_on_training'].append(
            {'exam_id': exam['id'], 'image': exam['image'], 'answer': assistant_message}
        )

### Salvamento do teste

In [None]:
# Persist every collected result as pretty-printed UTF-8 JSON (non-ASCII kept as-is).
test_path = join(tests_path, test_name + '.json')

with open(test_path, mode='w', encoding='utf-8') as output_file:
    dump(test_output, output_file, ensure_ascii=False, indent=4)