# Imports

In [None]:
import os
import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer
)
from torchvision.transforms import (
    Compose, Normalize, RandomHorizontalFlip,
    RandomResizedCrop, Resize, ToTensor
)

# --- 1. CONFIGURAÇÕES INICIAIS ---

In [None]:
# Base model name (Vision Transformer - SOTA for classification)
MODEL_CHECKPOINT = "google/vit-base-patch16-224-in21k"
BATCH_SIZE = 16  # Used as the per-device batch size for both training and evaluation
LEARNING_RATE = 2e-5  # Learning rate handed to TrainingArguments
NUM_EPOCHS = 5  # 5 epochs is usually a good starting point for transfer learning
DATASET_DIR = "dataset_soja" # Folder that contains the class sub-folders (Caterpillar, etc.)

Configurar dispositivo (GPU se disponível, senão CPU)

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
# NOTE(review): no other visible cell reads `device` - it only feeds the log line below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"üå≤ Iniciando projeto SojaTech usando: {device}")

# --- 2. CARREGAMENTO E DIVISÃO DOS DADOS ---

In [None]:
print("üìÇ Carregando imagens...")
# The 'datasets' library reads the folder structure under DATASET_DIR automatically.
dataset = load_dataset(
    path="imagefolder",
    data_dir=DATASET_DIR,
)

: 

Criar mapeamento de Labels (Ex: 0 -> Caterpillar, 1 -> Saudavel...)

In [None]:
# Class names as declared by the "label" feature of the training split.
labels = dataset["train"].features["label"].names
# Build the index -> name table first, then invert it to get name -> index.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
num_labels = len(labels)
print(f"Classes encontradas: {labels}")

Divisão Estratificada (Crucial devido ao desbalanceamento das classes)
80% Treino / 20% Validação

In [None]:
# Hold out 20% of the images for validation, stratifying on the class label
# and fixing the seed so the split is the same on every run.
splits = dataset["train"].train_test_split(
    test_size=0.2,
    stratify_by_column="label",
    seed=42,
)
train_ds, val_ds = splits["train"], splits["test"]

print(f"Imagens de Treino: {len(train_ds)}")
print(f"Imagens de Valida√ß√£o: {len(val_ds)}")

# --- 3. PRÉ-PROCESSAMENTO (TRANSFORMS) ---
Carregar o processador original do modelo

In [None]:
# Load the image processor published with the base checkpoint. Later cells read its
# `size`, `image_mean` and `image_std` attributes to build the torchvision pipelines,
# and it is also handed to the Trainer and saved next to the final model.
processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)

Definir normalização (média e desvio padrão fornecidos pelo processador do modelo)

In [None]:
# Per-channel normalization built from the mean/std stored on the model's image processor,
# so the inputs follow the statistics that processor declares for this checkpoint.
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)

Transformações: O modelo espera 224x224, mas as tuas imagens são 500x500.
Vamos redimensionar e aplicar Data Augmentation suave.

In [None]:
# Input resolution declared by the processor, as an explicit (height, width) pair.
# Passing a pair (instead of a single int) matters for Resize: with an int, torchvision
# only matches the SHORTER edge and keeps the aspect ratio, so a non-square image would
# come out non-square and break batching / the fixed-size model input.
_target_size = (processor.size["height"], processor.size["width"])

# Training pipeline: light augmentation followed by tensor conversion and normalization.
_train_transforms = Compose([
    RandomResizedCrop(_target_size),  # Random crop, then resize to the model resolution
    RandomHorizontalFlip(),           # Mirroring (helps generalization)
    ToTensor(),
    normalize,
])

# Validation pipeline: deterministic, no augmentation.
_val_transforms = Compose([
    Resize(_target_size),             # Always yields exactly (height, width), even for non-square inputs
    ToTensor(),
    normalize,
])

def transform_train(examples):
    """Add augmented, normalized tensors to a batch of training examples.

    Args:
        examples: batch dict whose "image" entry is a list of objects exposing
            `.convert("RGB")` (presumably PIL images - confirm against the loader).

    Returns:
        The same dict, with a new "pixel_values" list holding one transformed
        item per input image.
    """
    tensors = []
    for picture in examples["image"]:
        # Force three channels so grayscale / RGBA files fit the pipeline.
        tensors.append(_train_transforms(picture.convert("RGB")))
    examples["pixel_values"] = tensors
    return examples

def transform_val(examples):
    """Add deterministically resized, normalized tensors to a batch of validation examples.

    Args:
        examples: batch dict whose "image" entry is a list of objects exposing
            `.convert("RGB")` (presumably PIL images - confirm against the loader).

    Returns:
        The same dict, with a new "pixel_values" list holding one transformed
        item per input image.
    """
    tensors = []
    for picture in examples["image"]:
        # Force three channels so grayscale / RGBA files fit the pipeline.
        tensors.append(_val_transforms(picture.convert("RGB")))
    examples["pixel_values"] = tensors
    return examples

Aplicar as transformações ao dataset

In [None]:
# Attach the augmenting transform to the training split and the deterministic
# one to the validation split. Both calls modify the split objects in place.
train_ds.set_transform(transform_train)
val_ds.set_transform(transform_val)

# --- 4. DEFINIÇÃO DO MODELO ---

In [None]:
print("ü§ñ Baixando e configurando o modelo...")
# Instantiate an image classifier from the base checkpoint, sized for our classes and
# carrying the id <-> label maps so predictions can be reported by class name.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True # Needed because we are changing the final number of classes
)

# --- 5. CONFIGURAÇÃO DO TREINO ---

In [None]:
def compute_metrics(eval_pred):
    """Compute classification accuracy for one evaluation pass.

    Accuracy is computed directly with NumPy instead of `datasets.load_metric`,
    which is deprecated and removed in datasets 3.0 and fetches a metric script
    at runtime. The returned mapping has the same shape the old metric
    produced, so `metric_for_best_model="accuracy"` keeps working.

    Args:
        eval_pred: object exposing `predictions` (one row of per-class scores
            per example) and `label_ids` (the reference class indices).

    Returns:
        dict: `{"accuracy": float}` - the fraction of examples whose highest
        scoring class equals the reference label.
    """
    predictions = np.argmax(eval_pred.predictions, axis=1)
    references = np.asarray(eval_pred.label_ids)
    # Cast to a plain Python float so the value serializes cleanly in the Trainer logs.
    return {"accuracy": float(np.mean(predictions == references))}

# Training configuration: evaluate and checkpoint once per epoch, then reload the
# checkpoint with the best validation accuracy when training finishes.
args = TrainingArguments(
    output_dir="resultado_soja",
    # Keep the raw `image` column, which the dataset transforms read, from being dropped.
    remove_unused_columns=False,
    # NOTE(review): newer transformers releases rename this argument to `eval_strategy` -
    # confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    save_total_limit=2, # Caps how many checkpoints are kept on disk (older ones are deleted) to save space
)

def collate_fn(examples):
    """Assemble a model-ready batch from individual dataset examples.

    Only the transformed tensors and the labels are forwarded. Each example
    also carries the raw `image` object (the training arguments keep unused
    columns and the transforms leave it in place), which the default collator
    cannot turn into a tensor and the model does not accept as input.

    Args:
        examples: list of dicts, each with a "pixel_values" tensor and an
            integer "label".

    Returns:
        dict with "pixel_values" (stacked along a new batch dimension) and
        "labels" (1-D tensor of class indices).
    """
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


# NOTE(review): newer transformers releases prefer `processing_class=` over `tokenizer=` -
# confirm against the installed version before switching.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=processor,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

# --- 6. EXECUTAR TREINO E SALVAR ---

In [None]:
print("üöÄ Iniciando treinamento...")
trainer.train()

# Write the model weights and the processor config into the same folder.
print("üíæ Salvando modelo final...")
final_dir = "modelo_soja_final"
trainer.save_model(final_dir)
processor.save_pretrained(final_dir)

# Final evaluation on the validation split.
# NOTE(review): the training args request load_best_model_at_end, so this is presumably
# scoring the best checkpoint rather than the last one - confirm.
metrics = trainer.evaluate()
divider = "-" * 30
print(divider)
print("RESULTADOS FINAIS:")
print(f"Acur√°cia: {metrics['eval_accuracy']:.2%}")
print(divider)