In [None]:
#!pip install ultralytics opencv-python pillow

In [None]:
# Mount Google Drive (Colab only) so the dataset and weight paths under /content/drive resolve.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# ---------- Bibliotecas ----------

import os
import shutil
import random
from sklearn.model_selection import train_test_split
from PIL import Image
import yaml
import numpy as np
import cv2
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
import timm
from sklearn.metrics import accuracy_score, f1_score
from itertools import product
from tqdm import tqdm
from sklearn.model_selection import StratifiedShuffleSplit
import time
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

In [None]:
# Kernel working directory (Colab starts in /content); bound to `path` and echoed for reference.
print(path := os.getcwd())

#os.chdir(path)
#file_log = open(path + "/mensagem_final_classificar_V2.txt", "a")

/content


In [None]:
# ================= Configurações =================
# Training hyper-parameters for the fusion MLP.
NUM_CLASSES = 15   # classes shared by the leaf and bark datasets (pairs are matched by class index)
INPUT_SIZE = 224   # images are resized to INPUT_SIZE x INPUT_SIZE before normalization
BATCH_SIZE = 16
EPOCHS = 30
LR = 1e-3
PATIENCE = 5       # epochs without a val-F1 improvement before early stopping

FUSION_MODE = "concat"  # "concat" OU "sum"
# Downstream code treats any value other than "concat" as "sum", so reject typos early.
if FUSION_MODE not in ("concat", "sum"):
    raise ValueError(f"FUSION_MODE must be 'concat' or 'sum', got {FUSION_MODE!r}")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Root of the project folder on Google Drive. Set the TCC_DATA_ROOT environment variable
# to run against a copy stored elsewhere; every path below is derived from it.
DATA_ROOT = os.environ.get("TCC_DATA_ROOT", "/content/drive/MyDrive/TCC/Datasets")

# Single-modality image folders (one sub-directory per class).
DATASET_FOLHA = os.path.join(DATA_ROOT, "Imagens Folhas", "Especies")
DATASET_CASCA = os.path.join(DATA_ROOT, "Imagens tronco", "EspeciesCascas")

# Fine-tuned single-modality MobileNetV4 weights, used as frozen feature extractors.
PESOS_FOLHA = os.path.join(DATA_ROOT, "main_weights", "folha", "mobilenetv4_best_leaf.pt")
PESOS_CASCA = os.path.join(DATA_ROOT, "main_weights", "casca", "mobilenetv4_best_bark.pt")

# Per-epoch resume checkpoints, and the final best fusion-MLP weights.
CKPT_DIR = os.path.join(DATA_ROOT, "checkpointsHybridFeature")
FINAL_PATH = os.path.join(DATA_ROOT, "main_weights", "hybrid_mobilenet_best.pt")

os.makedirs(CKPT_DIR, exist_ok=True)

In [None]:
def stratified_split(dataset, test_split=0.1, valid_split=0.2, seed=42):
    """Split ``dataset`` into stratified train / validation / test ``Subset``s.

    The test set is carved out of the full dataset first. The validation set is then
    taken from the remaining pool, so ``valid_split`` is a fraction of train+valid,
    not of the whole dataset.

    Args:
        dataset: Object exposing ``samples``, a list of tuples whose last item is the label.
        test_split: Fraction of all samples held out for testing.
        valid_split: Fraction of the remaining (non-test) samples used for validation.
        seed: ``random_state`` passed to both splitters, making the split deterministic.

    Returns:
        ``(train_ds, valid_ds, test_ds)`` as ``torch.utils.data.Subset`` views of ``dataset``.
    """
    label_array = np.array([sample[-1] for sample in dataset.samples])

    def _one_split(y, fraction):
        # Single stratified shuffle split; features are irrelevant, so a dummy X is used.
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=fraction, random_state=seed)
        return next(splitter.split(np.zeros(len(y)), y))

    # Stage 1: hold out the test set from every sample.
    pool_idx, test_idx = _one_split(label_array, test_split)

    # Stage 2: split the remaining pool; inner indices are positions inside pool_idx,
    # so map them back to dataset indices.
    inner_train, inner_valid = _one_split(label_array[pool_idx], valid_split)
    train_idx = pool_idx[inner_train]
    valid_idx = pool_idx[inner_valid]

    make_subset = torch.utils.data.Subset
    return (
        make_subset(dataset, train_idx),
        make_subset(dataset, valid_idx),
        make_subset(dataset, test_idx),
    )

In [None]:
# ================= Dataset =================
class ImageFolderDataset(Dataset):
    """Image-classification dataset laid out as ``root_dir/<class_name>/<image file>``.

    Class indices follow the sorted names of the class sub-directories. ``self.samples``
    holds ``(image_path, class_index)`` tuples; ``stratified_split`` and
    ``CartesianFusionDataset`` read this list directly.
    """

    def __init__(self, root_dir):
        # Only sub-directories are classes. Stray files (e.g. .DS_Store, notes) must not
        # consume a class index, or leaf/bark labels (paired by index) would drift apart.
        classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {c: i for i, c in enumerate(classes)}

        self.samples = []
        for cls in classes:
            cls_path = os.path.join(root_dir, cls)
            # Sorted so the sample order, and therefore the seeded split, is reproducible
            # regardless of the filesystem's listing order.
            for f in sorted(os.listdir(cls_path)):
                if f.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.samples.append((os.path.join(cls_path, f),
                                         self.class_to_idx[cls]))

    def __len__(self):
        """Number of image samples found under ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``."""
        img_path, label = self.samples[idx]
        return self.preprocess(img_path), torch.tensor(label, dtype=torch.long)

    def preprocess(self, img_path):
        """Load an image and return a normalized ``(3, INPUT_SIZE, INPUT_SIZE)`` float tensor.

        Raises:
            OSError: If OpenCV cannot read the file (missing or corrupt).
        """
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"cv2 could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (INPUT_SIZE, INPUT_SIZE))
        img = img.astype(np.float32) / 255.0

        # ImageNet mean/std normalization.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        img = (img - mean) / std

        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
        return torch.from_numpy(img)

class CartesianFusionDataset(Dataset):
    """Pairs every leaf image with every bark image of the same class index.

    For each label present in both inputs, it emits the full Cartesian product
    ``leaf_images x bark_images``. The dataset therefore grows as the product of the
    per-class counts. Labels only in one of the inputs are dropped.

    Args:
        ds_folha: Leaf dataset (anything with ``samples`` of ``(path, label)``), or a
            ``torch.utils.data.Subset`` of one. Nested Subsets are resolved too.
        ds_casca: Bark dataset, same requirements as ``ds_folha``.
    """

    def __init__(self, ds_folha, ds_casca):
        folhas_por_classe = self._group_by_label(self._resolve_samples(ds_folha))
        cascas_por_classe = self._group_by_label(self._resolve_samples(ds_casca))

        # Cartesian product within each class, in leaf-class insertion order.
        self.samples = [
            (f_img, c_img, label)
            for label, f_imgs in folhas_por_classe.items()
            if label in cascas_por_classe
            for f_img in f_imgs
            for c_img in cascas_por_classe[label]
        ]

    @staticmethod
    def _resolve_samples(ds):
        """Return the ``(path, label)`` list visible through ``ds``, unwrapping (nested) Subsets."""
        if isinstance(ds, torch.utils.data.Subset):
            parent_samples = CartesianFusionDataset._resolve_samples(ds.dataset)
            return [parent_samples[i] for i in ds.indices]
        return list(ds.samples)

    @staticmethod
    def _group_by_label(samples):
        """Group image paths by label, preserving first-seen label order."""
        grouped = {}
        for img_path, label in samples:
            grouped.setdefault(label, []).append(img_path)
        return grouped

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(leaf_tensor, bark_tensor, label_tensor)`` for pair ``idx``."""
        folha_path, casca_path, label = self.samples[idx]

        img_f = self.preprocess(folha_path)
        img_c = self.preprocess(casca_path)

        return img_f, img_c, torch.tensor(label, dtype=torch.long)

    def preprocess(self, img_path):
        """Load an image and return a normalized ``(3, INPUT_SIZE, INPUT_SIZE)`` float tensor.

        Raises:
            OSError: If OpenCV cannot read the file (missing or corrupt).
        """
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"cv2 could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (INPUT_SIZE, INPUT_SIZE))
        img = img.astype(np.float32) / 255.0

        # ImageNet mean/std normalization.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        img = (img - mean) / std

        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
        return torch.from_numpy(img)

In [None]:
# ================= Feature Extractors =================
def create_feature_extractor(weight_path, num_classes=None,
                             model_name='mobilenetv4_conv_small.e1200_r224_in1k'):
    """Build a frozen timm backbone from a fine-tuned single-modality classifier checkpoint.

    Args:
        weight_path: Path to a state_dict of a ``model_name`` classifier with a
            ``num_classes``-way head.
        num_classes: Head size the checkpoint was trained with; defaults to NUM_CLASSES.
        model_name: timm architecture identifier; must match the checkpoint.

    Returns:
        The same architecture created with ``num_classes=0`` (no classification head),
        holding the checkpoint's backbone weights. It is in eval mode, has every
        parameter frozen, and is moved to DEVICE.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    # 1. Load into the full classifier first, so key names and shapes are checked strictly.
    model_full = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
    model_full.load_state_dict(torch.load(weight_path, map_location=DEVICE))

    # 2. Same architecture without a classifier.
    model = timm.create_model(model_name, pretrained=False, num_classes=0)

    # 3. Copy every backbone tensor; classifier weights have no counterpart and are dropped.
    #    load_state_dict is strict, so a missing backbone tensor raises instead of passing silently.
    backbone_keys = set(model.state_dict().keys())
    backbone_weights = {k: v for k, v in model_full.state_dict().items() if k in backbone_keys}
    model.load_state_dict(backbone_weights)

    # 4. Freeze: only the fusion MLP is trained.
    for p in model.parameters():
        p.requires_grad = False
    model.eval()

    return model.to(DEVICE)

In [None]:
# ================= MLP =================
class FusionMLP(nn.Module):
    """Two-layer classification head applied to the fused leaf/bark feature vector.

    Args:
        in_dim: Size of the fused feature vector (twice the backbone feature size for
            "concat" fusion, the backbone feature size for "sum").
        hidden_dim: Width of the hidden layer (default 512, the original value).
        dropout: Dropout probability after the hidden ReLU (default 0.3).
        num_classes: Number of output logits; defaults to NUM_CLASSES.

    The layer order (and therefore the state_dict keys ``net.0.*`` / ``net.3.*``) is
    unchanged, so previously saved weights still load.
    """

    def __init__(self, in_dim, hidden_dim=512, dropout=0.3, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Map fused features ``(batch, in_dim)`` to logits ``(batch, num_classes)``."""
        return self.net(x)

In [None]:
# ================= Avaliação =================
@torch.no_grad()
def evaluate(model_f, model_c, mlp, loader):
    """Run leaf/bark feature extraction plus the fusion MLP over ``loader`` and score it.

    Args:
        model_f: Leaf feature extractor (expected to be in eval mode already).
        model_c: Bark feature extractor (expected to be in eval mode already).
        mlp: Fusion head; switched to eval mode here.
        loader: Yields ``(img_f, img_c, labels)`` batches.

    Returns:
        ``(accuracy, weighted_f1, ms_per_image, preds, labels)``. ``ms_per_image`` is the
        total forward-pass time divided by the number of images (0.0 for an empty loader);
        host-to-device copies are excluded.
    """
    mlp.eval()
    preds, labels_all = [], []
    total_time = 0.0
    total_images = 0
    # CUDA launches are asynchronous: without synchronizing, the timer would only measure
    # kernel launch overhead, not the actual computation.
    sync_cuda = DEVICE.type == "cuda"

    for img_f, img_c, labels in loader:
        img_f, img_c = img_f.to(DEVICE), img_c.to(DEVICE)

        if sync_cuda:
            torch.cuda.synchronize()  # finish pending copies before starting the clock
        start_time = time.perf_counter()

        feat_f = model_f(img_f)
        feat_c = model_c(img_c)

        fused = torch.cat([feat_f, feat_c], 1) if FUSION_MODE == "concat" else feat_f + feat_c
        out = mlp(fused)

        if sync_cuda:
            torch.cuda.synchronize()  # wait for the forward pass to actually finish
        total_time += time.perf_counter() - start_time
        total_images += img_f.size(0)

        preds.extend(torch.argmax(out, 1).cpu().numpy())
        labels_all.extend(labels.numpy())

    acc = accuracy_score(labels_all, preds)
    f1 = f1_score(labels_all, preds, average="weighted")
    # Per-image mean over all images (not a mean of per-batch means, which over-weights a
    # short final batch); reported in milliseconds.
    avg_time = (total_time / total_images) * 1000 if total_images else 0.0

    return acc, f1, avg_time, preds, labels_all

In [None]:
# ================= Funções de Análise =================
def plot_confusion_matrix(y_true, y_pred, class_names, title="Matriz de Confus√£o", normalize=False):
    """Plot a confusion matrix as an annotated seaborn heatmap.

    Args:
        y_true: True labels (class indices).
        y_pred: Model predictions (class indices).
        class_names: Class names ordered by class index; position ``i`` labels index ``i``.
        title: Plot title.
        normalize: If True, each row is divided by its total, so cells show the fraction
            of that true class (formatted as percentages).
    """
    # Pin the label set so the matrix always has one row/column per class name. Otherwise
    # a class missing from both y_true and y_pred shrinks the matrix and shifts the ticks.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Classes with no true samples would divide by zero; show them as 0 instead of NaN.
        cm = np.divide(cm.astype(float), row_totals,
                       out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)
        fmt = '.2%'
    else:
        fmt = 'd'

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Porcentagem' if normalize else 'Contagem'},
                ax=ax)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel('Classe Verdadeira', fontsize=12)
    ax.set_xlabel('Classe Predita', fontsize=12)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    plt.show()

def print_per_class_accuracy(y_true, y_pred, class_names):
    """Print per-class accuracy (i.e. per-class recall) and a full classification report.

    Args:
        y_true: True labels (class indices).
        y_pred: Model predictions (class indices).
        class_names: Class names ordered by class index.

    Returns:
        One accuracy value per entry of ``class_names``: correct / true samples of that
        class, or 0 when the class has no true samples.
    """
    # Pin the label set to every class index. Then cm[i, i] exists for each name, and
    # classification_report does not reject target_names when a class is absent.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    print("\n" + "="*80)
    print("ACUR√ÅCIA POR CLASSE")
    print("="*80)
    print(f"{'Classe':<25} {'Corretas':>10} {'Total':>10} {'Acur√°cia':>12}")
    print("-"*80)

    per_class_acc = []
    for i, class_name in enumerate(class_names):
        correct = cm[i, i]
        total = cm[i, :].sum()
        acc = correct / total if total > 0 else 0
        per_class_acc.append(acc)
        print(f"{class_name:<25} {correct:>10} {total:>10} {acc:>12.2%}")

    print("-"*80)
    print(f"{'M√âDIA':<25} {'':<10} {'':<10} {np.mean(per_class_acc):>12.2%}")
    print("="*80)

    # Detailed per-class precision/recall/F1 report.
    print("\n" + "="*80)
    print("RELAT√ìRIO DE CLASSIFICA√á√ÉO DETALHADO")
    print("="*80)
    print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names,
                                digits=4, zero_division=0))

    return per_class_acc

In [None]:
# ================= Treinamento =================
def _checkpoint_path(epoch_tag):
    """Return the checkpoint path for ``epoch_tag`` under the current EPOCHS/LR/FUSION_MODE run."""
    return os.path.join(
        CKPT_DIR,
        f"hybrid_feature_epoch{epoch_tag}_e{EPOCHS:.0e}_lr{LR:.0e}_mode{FUSION_MODE}.pt",
    )


def _snapshot_state(module):
    """Return a detached copy of ``module``'s weights.

    ``state_dict().copy()`` is only a shallow dict copy whose tensors are the live
    parameters, which the optimizer keeps updating in place; cloning freezes them.
    """
    return {k: v.detach().clone() for k, v in module.state_dict().items()}


def _as_int_list(values):
    """Convert predictions/labels (possibly numpy integers) to plain ``int``s; ``None`` passes through."""
    return None if values is None else [int(v) for v in values]


def train(model_f, model_c, mlp, train_loader, valid_loader):
    """Train ``mlp`` on fused frozen features, with checkpoint resume and early stopping.

    After every epoch a checkpoint is written to CKPT_DIR and the previous one is deleted.
    An early stop is saved under epoch EPOCHS, so a later call sees the run as finished.
    Checkpoints also carry the best weights, so a resumed run can still restore them.

    Args:
        model_f, model_c: Frozen leaf/bark feature extractors.
        mlp: Fusion head to train; on return it holds the best weights found, if any.
        train_loader, valid_loader: Yield ``(img_f, img_c, labels)`` batches.

    Returns:
        ``(best_preds, best_labels)``: validation predictions and labels from the
        best-F1 epoch. Returns ``(None, None)`` if no epoch improved on the starting
        ``best_f1``.
    """
    optimizer = optim.Adam(mlp.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    best_f1 = 0.0
    best_epoch = None
    best_model_state = None
    best_preds = None
    best_labels = None
    patience_counter = 0
    last_epoch = 0
    early_stop = False

    # === Try to resume from the latest checkpoint of this configuration ===
    os.makedirs(CKPT_DIR, exist_ok=True)
    ckpt_pattern = f"e{EPOCHS:.0e}_lr{LR:.0e}_mode{FUSION_MODE}"
    ckpt_files = [f for f in os.listdir(CKPT_DIR) if ckpt_pattern in f and "_epoch" in f]

    if ckpt_files:
        # Newest = highest epoch number embedded after "_epoch".
        ckpt_files.sort(key=lambda name: int(name.split("_epoch")[1].split("_")[0]))
        last_ckpt = os.path.join(CKPT_DIR, ckpt_files[-1])
        print(f"Checkpoint detectado: {last_ckpt} ‚Äî retomando treinamento...")

        # weights_only=False: these files are written by this function (trusted), and
        # checkpoints from older runs may hold numpy scalars in best_preds/best_labels.
        checkpoint = torch.load(last_ckpt, map_location=DEVICE, weights_only=False)
        mlp.load_state_dict(checkpoint['mlp_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        last_epoch = checkpoint['epoch']
        best_f1 = checkpoint.get('best_f1', 0.0)
        patience_counter = checkpoint.get('patience_counter', 0)
        best_preds = checkpoint.get('best_preds')
        best_labels = checkpoint.get('best_labels')
        best_model_state = checkpoint.get('best_model_state')  # absent in older checkpoints
        best_epoch = checkpoint.get('best_epoch')

        if last_epoch >= EPOCHS:
            print("Treinamento j√° foi finalizado!")
            acc_val, f1_val = checkpoint['val_acc'], checkpoint['val_f1']
            print(f"Epoch {last_epoch}/{EPOCHS} - val_acc: {acc_val:.4f}, val_f1: {f1_val:.4f}")

            if best_model_state is not None:
                # Hand back the best weights, not the last epoch's.
                mlp.load_state_dict(best_model_state)
                if not os.path.exists(FINAL_PATH):
                    torch.save(best_model_state, FINAL_PATH)

            if best_preds is not None and best_labels is not None:
                return best_preds, best_labels
            # No stored predictions: evaluate to obtain them.
            _, _, _, preds, labels = evaluate(model_f, model_c, mlp, valid_loader)
            return preds, labels

        print(f"Retomando a partir da √©poca {last_epoch+1}/{EPOCHS} (lr={LR})")
    else:
        print("Nenhum checkpoint anterior encontrado, come√ßando do zero.")

    # === Training loop ===
    for epoch in range(last_epoch, EPOCHS):
        mlp.train()
        running_loss = 0.0

        for img_f, img_c, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            img_f, img_c, y = img_f.to(DEVICE), img_c.to(DEVICE), y.to(DEVICE)

            # Backbones are frozen: no autograd graph needed for feature extraction.
            with torch.no_grad():
                feat_f = model_f(img_f)
                feat_c = model_c(img_c)

            fused = (
                torch.cat([feat_f, feat_c], 1)
                if FUSION_MODE == "concat"
                else feat_f + feat_c
            )

            out = mlp(fused)
            loss = criterion(out, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * img_f.size(0)

        epoch_loss = running_loss / max(len(train_loader.dataset), 1)
        acc, f1, avg_time, preds, labels = evaluate(model_f, model_c, mlp, valid_loader)

        print(f"Epoch {epoch+1}/{EPOCHS} - loss: {epoch_loss:.4f}, val_acc: {acc:.4f}, val_f1: {f1:.4f}, time: {avg_time:.2f}ms")

        # === Early stopping on val_f1 ===
        if f1 > best_f1:
            best_f1 = f1
            patience_counter = 0
            best_model_state = _snapshot_state(mlp)
            best_preds = _as_int_list(preds)
            best_labels = _as_int_list(labels)
            best_epoch = epoch + 1
            print(f"  ‚úÖ Novo melhor F1: {f1:.4f}")
        else:
            patience_counter += 1
            print(f"  ‚è≥ Contador do Early stopping: {patience_counter}/{PATIENCE} - Melhor F1-Score anterior: {best_f1:.4f}")

            if patience_counter >= PATIENCE:
                print(f"  üõë Parando antecipadamente na √©poca {epoch+1}. Melhor val_f1: {best_f1:.4f}")
                early_stop = True

        # === Save the per-epoch checkpoint ===
        # An early stop is saved under epoch EPOCHS so the next run treats training as finished.
        current_epoch = EPOCHS if early_stop else epoch + 1
        ckpt_path = _checkpoint_path(current_epoch)

        torch.save({
            'epoch': current_epoch,
            'mlp_state_dict': mlp.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': acc,
            'val_f1': f1,
            'loss': epoch_loss,
            'lr': LR,
            'best_f1': best_f1,
            'patience_counter': patience_counter,
            'best_preds': best_preds,
            'best_labels': best_labels,
            'best_model_state': best_model_state,
            'best_epoch': best_epoch,
            'fusion_mode': FUSION_MODE
        }, ckpt_path)
        print(f"  üíæ Checkpoint salvo: {ckpt_path}")

        # Delete the previous epoch's checkpoint (file tagged with `epoch`, i.e. current - 1).
        prev_ckpt = _checkpoint_path(epoch)
        if os.path.exists(prev_ckpt):
            os.remove(prev_ckpt)
            print(f"  üóëÔ∏è  Checkpoint deletado: {prev_ckpt}")

        if early_stop:
            break

    # === Restore the best model ===
    if best_model_state is not None:
        print(f"\nüîÑ Restaurando melhor modelo (√©poca {best_epoch}, F1={best_f1:.4f})")
        mlp.load_state_dict(best_model_state)

        torch.save(best_model_state, FINAL_PATH)
        print(f"üíæ Melhor modelo salvo em: {FINAL_PATH}")
    else:
        print("Aviso: nenhum melhor modelo registrado; FINAL_PATH nao foi atualizado.")

    return best_preds, best_labels

In [None]:
# === Função Main ===
def _check_classes(fusion_ds, name, num_classes=None):
    """Print which class indices occur in ``fusion_ds`` and return them as a set."""
    if num_classes is None:
        num_classes = NUM_CLASSES
    classes_presentes = {label for _, _, label in fusion_ds.samples}
    print(f"{name}: {len(classes_presentes)} classes presentes de {num_classes} totais")
    print(f"Classes: {sorted(classes_presentes)}")
    return classes_presentes


def _count_pairs_per_class(fusion_ds, name):
    """Print the number of (leaf, bark) pairs per class index in ``fusion_ds``."""
    pares_por_classe = {}
    for _, _, label in fusion_ds.samples:
        pares_por_classe[label] = pares_por_classe.get(label, 0) + 1

    print(f"\n{name} - Pares por classe:")
    for label in sorted(pares_por_classe.keys()):
        print(f"  Classe {label}: {pares_por_classe[label]:4d} pares")
    print(f"Total: {sum(pares_por_classe.values())}")


if __name__ == "__main__":
    try:
        print(
            "\n--------------- Treinamento do Modelo H√≠brido 2 - MobileNetV4 ---------------"
            "\nIn√≠cio..."
        )

        # Seed the RNGs used below (DataLoader shuffling, MLP init, dropout) so runs are
        # repeatable; 42 matches stratified_split's default seed.
        random.seed(42)
        np.random.seed(42)
        torch.manual_seed(42)

        # 1. Split each single-modality dataset BEFORE pairing, so no image is in two splits.
        ds_f = ImageFolderDataset(DATASET_FOLHA)
        print(f"Dataset de Folhas:\nClasses detectadas ({len(ds_f.class_to_idx.keys())}): {ds_f.class_to_idx.keys()}")
        print(f"Total de Imagens: {len(ds_f)}")
        train_f, valid_f, test_f = stratified_split(ds_f)
        print(f"Total: {len(train_f)+len(valid_f)+len(test_f)} | Treino: {len(train_f)} | Valida√ß√£o: {len(valid_f)} | Teste: {len(test_f)}\n")

        ds_c = ImageFolderDataset(DATASET_CASCA)
        print(f"Dataset de Cascas:\nClasses detectadas ({len(ds_c.class_to_idx.keys())}): {ds_c.class_to_idx.keys()}")
        print(f"Total de Imagens: {len(ds_c)}")
        train_c, valid_c, test_c = stratified_split(ds_c)
        print(f"Total: {len(train_c)+len(valid_c)+len(test_c)} | Treino: {len(train_c)} | Valida√ß√£o: {len(valid_c)} | Teste: {len(test_c)}\n")

        # Leaf and bark images are paired by class *index*, so both folders must list the
        # same classes in the same sorted order; warn loudly if they do not.
        if list(ds_f.class_to_idx) != list(ds_c.class_to_idx):
            print("AVISO: classes de folhas e cascas diferem; o pareamento por indice pode estar incorreto.")
            print(f"  Folhas: {list(ds_f.class_to_idx)}")
            print(f"  Cascas: {list(ds_c.class_to_idx)}")

        # 2. Build the per-class Cartesian (leaf x bark) pairs inside each split.
        train_fusion = CartesianFusionDataset(train_f, train_c)
        valid_fusion = CartesianFusionDataset(valid_f, valid_c)
        test_fusion = CartesianFusionDataset(test_f, test_c)

        # Debug summary of the fused splits.
        print(f"Total: {len(train_fusion)+len(valid_fusion)+len(test_fusion)} | Treino: {len(train_fusion)} | Valida√ß√£o: {len(valid_fusion)} | Teste: {len(test_fusion)}")

        print("\n=== DEBUG ===")

        train_classes = _check_classes(train_fusion, "Train")
        valid_classes = _check_classes(valid_fusion, "Valid")
        test_classes = _check_classes(test_fusion, "Test")

        all_classes = set(range(NUM_CLASSES))
        print(f"\nClasses faltando em train: {all_classes - train_classes}")
        print(f"Classes faltando em valid: {all_classes - valid_classes}")
        print(f"Classes faltando em test: {all_classes - test_classes}")

        _count_pairs_per_class(train_fusion, "TREINO")
        _count_pairs_per_class(valid_fusion, "VALIDA√á√ÉO")
        _count_pairs_per_class(test_fusion, "TESTE")

        # 3. DataLoaders (only training data is shuffled).
        train_loader = DataLoader(train_fusion, batch_size=BATCH_SIZE, shuffle=True)
        valid_loader = DataLoader(valid_fusion, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_fusion, batch_size=BATCH_SIZE, shuffle=False)

        # Frozen feature extractors.
        model_f = create_feature_extractor(PESOS_FOLHA).to(DEVICE)
        model_c = create_feature_extractor(PESOS_CASCA).to(DEVICE)

        # create_feature_extractor already freezes them; repeated here defensively.
        for extractor in (model_f, model_c):
            extractor.eval()
            for p in extractor.parameters():
                p.requires_grad = False

        # Fused feature size: concatenation doubles it, summation keeps it.
        feat_dim = model_f.num_features
        if FUSION_MODE == "concat":
            feat_dim *= 2

        mlp = FusionMLP(feat_dim).to(DEVICE)

        # Train the MLP only.
        print(f"\n{'='*80}")
        print(f"TREINAMENTO DA MLP - FUSION_MODE: {FUSION_MODE}")
        print(f"{'='*80}\n")

        best_preds_val, best_labels_val = train(
            model_f,
            model_c,
            mlp,
            train_loader,
            valid_loader
        )

        # Load the best weights saved by train(); map_location lets GPU-saved weights load on CPU.
        mlp.load_state_dict(torch.load(FINAL_PATH, map_location=DEVICE))

        # Final evaluation on the held-out test split.
        print(f"\n{'='*80}")
        print(f"AVALIA√á√ÉO FINAL NO CONJUNTO DE TESTE")
        print(f"{'='*80}\n")

        acc_test, f1_test, time_test, preds_test, labels_test = evaluate(model_f, model_c, mlp, test_loader)

        print(f"\nüéØ TESTE FINAL:")
        print(f"   Accuracy: {acc_test:.4f}")
        print(f"   F1-Score: {f1_test:.4f}")
        print(f"   Tempo m√©dio de infer√™ncia: {time_test:.2f} ms/imagem")

        # class_to_idx was built by enumerating sorted class names, so key order == index order.
        class_names = list(ds_f.class_to_idx.keys())

        print("\n" + "="*80)
        print("AN√ÅLISE DETALHADA - FEATURE-LEVEL FUSION")
        print("="*80)
        per_class_acc = print_per_class_accuracy(labels_test, preds_test, class_names)

        # normalize=True shows row percentages instead of raw counts.
        plot_confusion_matrix(labels_test, preds_test, class_names,
                             title=f"Matriz de Confus√£o - Feature-Level Fusion ({FUSION_MODE.upper()})",
                             normalize=True)

        print("...Fim\n")

    except KeyboardInterrupt:
        print("Programa encerrado via terminal...")