# Notebook de Inferência CoDE para Colab

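Este notebook testa o modelo CoDE no dataset ELSA_D3 (carregado via streaming): faz inferência visual com os 3 classificadores (Linear, KNN e SVM) e executa uma bateria de testes (benchmark). Requer um fork do CoDE que contenha os módulos `elsa_streaming_loader` e `benchmark_classifiers` (ver seção 2).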

## 1. Montar Drive e verificar GPU

In [None]:
# Mount Google Drive under /content/drive (Colab asks for authorization interactively).
# NOTE(review): no later cell in this notebook reads from or writes to /content/drive.
from google.colab import drive

drive.mount('/content/drive')

# Report whether PyTorch can see a CUDA GPU and, if so, which one.
import torch

has_cuda = torch.cuda.is_available()
print('CUDA disponível:', has_cuda)
if has_cuda:
    print('Device:', torch.cuda.get_device_name(0))

## 2. Clonar repositório e instalar dependências

In [None]:
# Clone do CoDE. Use seu fork se tiver elsa_streaming_loader e benchmark_classifiers.
%cd /content
!git clone https://github.com/aimagelab/CoDE.git
%cd CoDE
!pip install transformers torch torchvision Pillow scikit-learn joblib huggingface-hub datasets requests

## 3. Configurar path e carregar os 3 modelos

In [None]:
import os
import sys

import torch

REPO_ROOT = '/content/CoDE'
INFERENCE_DIR = os.path.join(REPO_ROOT, 'src', 'inference')
# Guard the insert so re-running this cell does not keep prepending duplicate entries.
if INFERENCE_DIR not in sys.path:
    sys.path.insert(0, INFERENCE_DIR)

from code_model import VITContrastiveHF

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One CoDE model per classifier head; the dict keys are reused by the plotting cell.
CLASSIFIER_TYPES = ('linear', 'knn', 'svm')
models = {ctype: VITContrastiveHF(classificator_type=ctype) for ctype in CLASSIFIER_TYPES}

for clf in models.values():
    clf.eval()
    # Only the `.model` submodule is moved to the device; presumably the classifier head
    # is a CPU-side (joblib/sklearn) object — TODO confirm against code_model.py.
    clf.model.to(device)

print(f'{len(models)} modelos carregados: {", ".join(models)}')

## 4. Carregar dados via streaming (ELSA_D3)

In [None]:
import torchvision.transforms as transforms
from elsa_streaming_loader import elsa_streaming_samples

# 224x224 center crop, then tensor conversion and normalization with ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Collect NUM_VISUAL images for the visual grid (fakes only, which is faster to stream).
NUM_VISUAL = 12
visual_samples = [
    (transform(image), label, arch)
    for image, label, arch in elsa_streaming_samples(max_samples=NUM_VISUAL, use_real_images=False)
]

print(f'Coletadas {len(visual_samples)} imagens para visualização')

## 5. Inferência visual – grid com predições dos 3 classificadores

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Class names indexed by the integer label: 0 -> 'Real', 1 -> 'Fake'.
LABELS = ['Real', 'Fake']
# Same mean/std as the Normalize transform in the data-loading cell; used to undo it for display.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


def tensor_to_img(t):
    """Turn a normalized (C, H, W) image tensor into an (H, W, C) array in [0, 1] for imshow.

    The normalization is inverted with the module-level ``mean``/``std`` and the
    result is clipped to the displayable range.
    """
    hwc = np.transpose(t.cpu().numpy(), (1, 2, 0))
    denormalized = hwc * std + mean
    return np.clip(denormalized, 0, 1)


def pred_to_label(pred, ctype):
    """Map a raw classifier output to 'Real' or 'Fake'.

    For ``ctype == 'svm'`` the value -1 is reported as 'Real' and anything else as
    'Fake'. For every other classifier ``pred`` is used as an index into LABELS.
    """
    if ctype != 'svm':
        return LABELS[int(pred)]
    # NOTE(review): sklearn's OneClassSVM.predict returns +1 for inliers and -1 for
    # outliers; if the SVM was fit on real images, -1 would mean fake — confirm mapping.
    return 'Real' if pred == -1 else 'Fake'

# Fixed number of columns; rows are derived from the sample count so every collected
# image gets a panel, whatever NUM_VISUAL is set to.
cols = 4
rows = max(1, -(-len(visual_samples) // cols))  # ceil division, at least one row
# squeeze=False keeps `axes` 2-D even when rows == 1, so flatten() is always valid.
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)
axes = axes.flatten()

# Human-readable classifier names for the panel titles.
DISPLAY_NAMES = {'linear': 'Linear', 'knn': 'KNN', 'svm': 'SVM'}

for ax, (img_tensor, true_label, arch) in zip(axes, visual_samples):
    x = img_tensor.unsqueeze(0).to(device)  # batch of one image
    with torch.no_grad():
        preds = {ctype: pred_to_label(model(x).item(), ctype) for ctype, model in models.items()}

    # int() so tensor / numpy integer labels index LABELS as well as plain ints.
    true_str = LABELS[int(true_label)]
    pred_str = ' | '.join(f'{DISPLAY_NAMES.get(c, c)}: {p}' for c, p in preds.items())
    # Two lines: a single ~75-character line at 9 pt is wider than a 4-inch panel
    # and overlaps the neighbouring subplots.
    title = f'Esperado: {true_str} ({arch})\n{pred_str}'

    # Green only when every classifier agrees with the ground truth.
    all_correct = all(p == true_str for p in preds.values())

    ax.imshow(tensor_to_img(img_tensor))
    ax.axis('off')
    ax.set_title(title, fontsize=9, color='green' if all_correct else 'red')

# Blank out the unused panels of the last row.
for ax in axes[len(visual_samples):]:
    ax.axis('off')

fig.suptitle(f'Inferência CoDE – {len(models)} classificadores', fontsize=12)
fig.tight_layout()
plt.show()

## 6. Bateria de testes (benchmark dos 3 classificadores)

In [None]:
# Run the benchmark script in a shell subprocess. Note that `%cd` also changes the
# kernel's working directory for every cell executed after this one.
# Presumably `--streaming` makes the script read ELSA_D3 via the streaming loader — TODO confirm.
%cd /content/CoDE/src/inference
!python benchmark_classifiers.py --streaming --max_samples 500 --result_folder /content/CoDE/results

### Alternativa: bateria no próprio notebook

In [None]:
import argparse
import sys

INFERENCE_DIR = '/content/CoDE/src/inference'
# Guard the insert so re-running the notebook does not stack duplicate sys.path entries.
if INFERENCE_DIR not in sys.path:
    sys.path.insert(0, INFERENCE_DIR)

from benchmark_classifiers import run_benchmark_streaming, print_results

# Options for run_benchmark_streaming as an argparse.Namespace, i.e. shaped like parsed
# CLI arguments. Unlike a class with class attributes, the values live on the instance,
# so vars(opt) reports them. Assumes the script consumes argparse output — TODO confirm.
# max_samples matches the `--max_samples 500` of the CLI cell above.
opt = argparse.Namespace(
    max_samples=500,
    use_real_images=False,
    batch_size=128,
)

results, n_total = run_benchmark_streaming(opt)
print_results(results, n_total)