## Conta os tokens

### Configura o ambiente

In [None]:
from os import environ

# Restrict CUDA to the GPU chosen interactively. This cell runs before the
# torch/unsloth imports below, so the setting is in place before CUDA starts.
gpu_id = input('Enter GPU ID: ')
environ['CUDA_VISIBLE_DEVICES'] = gpu_id

### Imports

In [None]:
from os.path import join
from json import load

from unsloth import FastVisionModel
from tqdm.notebook import tqdm
from PIL import Image

import torch

from scripts.authentication import authenticate_huggingface
from scripts.messages import create_training_message
from scripts.data import SimpleLesionData, SimpleDatasetAnalysis

import scripts.definitions as defs

### Autenticação

In [None]:
# Log in to the Hugging Face Hub through the project helper.
# NOTE(review): presumably required to download defs.BASE_MODEL_NAME (e.g. a gated model) — confirm.
authenticate_huggingface()

### Configurações

In [None]:
# Passed to FastVisionModel.from_pretrained as `load_in_4bit` when the base model is loaded.
QUANTIZED = True

### Carregamento do dataset

In [None]:
# Load the per-lesion records and the dataset-wide analysis; both are later
# passed to create_training_message when building the conversations.
dataset_file_path = join(defs.DATA_PATH, 'stt_data', 'simple_dataset.json')
analysis_file_path = join(defs.DATA_PATH, 'simple_dataset_analysis.json')

with open(dataset_file_path, 'r', encoding='utf-8') as file:
    raw_records = load(file)
dataset = [SimpleLesionData(**record) for record in raw_records]

with open(analysis_file_path, 'r', encoding='utf-8') as file:
    analysis_fields = load(file)
dataset_analysis = SimpleDatasetAnalysis(**analysis_fields)

### Preparação das mensagens

In [None]:
# Build one training conversation (REPORT prompt type) per lesion record.
# A comprehension replaces the manual append loop and keeps the loop variable
# out of the notebook's global namespace.
messages = [
    create_training_message(defs.PromptType.REPORT, data, dataset_analysis)
    for data in tqdm(dataset, desc='Criando mensagens: ')
]

### Carregamento do modelo e do tokenizador

In [None]:
# Load the base vision model together with its tokenizer/processor; only the
# tokenizer is used by the counting cell.
# NOTE(review): the whole model (plus `use_gradient_checkpointing`, presumably
# only relevant for training) is loaded just to obtain the tokenizer — consider
# loading only the processor if GPU memory is a concern.
load_options = {
    'load_in_4bit': QUANTIZED,
    'use_gradient_checkpointing': 'unsloth',
}
model, tokenizer = FastVisionModel.from_pretrained(defs.BASE_MODEL_NAME, **load_options)

# Put the model into unsloth's inference mode.
FastVisionModel.for_inference(model)

### Contagem dos tokens

In [None]:
# Count the tokens of every training sample as one sequence: the whole
# conversation (prompt + answer) is rendered with the chat template in a single
# pass and tokenized together with its image.
#
# The previous version templated the prompt and the answer separately. That can
# repeat template prefix tokens (e.g. BOS or a default system header, depending
# on the model's template). It also passed add_generation_prompt=True for the
# answer, which appends an extra assistant header after it. Both inflated the
# counts.
token_counts = []

for message in tqdm(messages, desc='Contando os tokens das mensagens: '):
    conversation = message['messages']

    # The image path is resolved relative to the parent directory.
    # Assumes content[1] of the first message is the image item (as before).
    # The context manager closes the file handle; convert() returns a new,
    # fully loaded image.
    with Image.open(join('..', conversation[0]['content'][1]['image'])) as raw_image:
        image = raw_image.convert('RGB')

    # No generation prompt: the assistant answer is already part of the text.
    text = tokenizer.apply_chat_template(conversation, add_generation_prompt=False)

    input_ids = tokenizer(
        images=image,
        text=text,
        add_special_tokens=False,  # the chat template already inserts special tokens
    )['input_ids'][0]

    token_counts.append(len(input_ids))

print(f'Mínimo: {min(token_counts)}')
print(f'Máximo: {max(token_counts)}')
print(f'Média: {sum(token_counts) / len(token_counts)}')