In [8]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
for path in (os.path.join(root_dir, fname)
             for root_dir, _subdirs, fnames in os.walk('/kaggle/input')
             for fname in fnames):
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [9]:
# fixmatch_train.py
# ==========================================================
# FixMatch implementation for CIFAR-10 with multiple backbones (ResNet-34/101/152 and ViT-B/16; ResNet-18/50 excluded)
# ==========================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as T
import torchvision.models as models
import math
import copy
from tqdm import tqdm
import numpy as np
import os
import random
import matplotlib.pyplot as plt
import pandas as pd
from torchvision.transforms import InterpolationMode
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support


# ==========================================================
# Configurações globais
# ==========================================================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Seed every RNG this notebook draws from (Python `random`, NumPy -- used for
# the labeled/unlabeled split in make_cifar10_datasets -- and PyTorch) so that
# a Restart & Run All reproduces the same split and initialisation.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

NUM_CLASSES = 10
BATCH_L = 64        # labeled batch size (default for train_fixmatch)
BATCH_U = 128       # unlabeled batch size (default for train_fixmatch)
LR = 0.03
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
TAU = 0.95          # confidence threshold for accepting a pseudo-label
LAMBDA_U = 1.0      # weight of the unlabeled (consistency) loss term
PRINT_EVERY = 50    # progress-bar postfix refresh interval, in steps

# ==========================================================
# Augmentações FixMatch
# ==========================================================
# CIFAR-10 per-channel normalization statistics shared by every CIFAR transform.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.247, 0.243, 0.261)

# Weak view: horizontal flip + 4px-padded random crop. Also used for the
# labeled images.
weak_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Strong view: the weak augmentation followed by RandAugment (2 ops, magnitude 10).
strong_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.RandAugment(num_ops=2, magnitude=10),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation: normalization only, no augmentation.
test_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])


# ==========================================================
# Augmentações ViT
# ==========================================================

# ImageNet normalization statistics (what the pretrained ViT weights expect).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# All ViT pipelines first upsample the 32x32 CIFAR image to 224 (bicubic).
# The crop padding of 28 equals the 4px CIFAR padding scaled by 224/32.
weak_transform_vit = T.Compose(
    [
        T.Resize(224, interpolation=InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(),
        T.RandomCrop(224, padding=28),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Strong view: same as the weak view plus RandAugment before tensor conversion.
strong_transform_vit = T.Compose(
    [
        T.Resize(224, interpolation=InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(),
        T.RandomCrop(224, padding=28),
        T.RandAugment(num_ops=2, magnitude=10),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Evaluation: resize + normalize only.
test_transform_vit = T.Compose(
    [
        T.Resize(224, interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)



# ==========================================================
# Dataset split CIFAR10
# ==========================================================
def make_cifar10_datasets(n_labeled_per_class, root='./data',
                          transform_labeled=weak_transform,
                          transform_test=test_transform,
                          seed=None):
    """Split the CIFAR-10 training set into disjoint labeled / unlabeled subsets.

    For each of the NUM_CLASSES classes, `n_labeled_per_class` randomly chosen
    training images go to the labeled subset and the remaining images of that
    class go to the unlabeled subset.

    Args:
        n_labeled_per_class: number of labeled images kept per class (>= 0).
        root: CIFAR-10 directory (downloaded if missing).
        transform_labeled: transform applied to labeled training images.
        transform_test: transform applied to the test split.
        seed: optional int seed for the split. None (default) uses the global
            NumPy RNG, as before; an int gives a split that does not depend on
            any other random calls made in the session.

    Returns:
        (labeled_ds, unlabeled_ds, test_ds). `unlabeled_ds` yields untransformed
        images; its weak/strong views are produced by UnlabeledWrapper.

    Raises:
        ValueError: if n_labeled_per_class is negative (a negative slice would
            silently move almost every image into the labeled set).
    """
    if n_labeled_per_class < 0:
        raise ValueError(f"n_labeled_per_class must be >= 0, got {n_labeled_per_class}")
    shuffle = np.random.shuffle if seed is None else np.random.default_rng(seed).shuffle

    base_train_raw = torchvision.datasets.CIFAR10(root=root, train=True,
                                                  download=True, transform=None)
    base_test = torchvision.datasets.CIFAR10(root=root, train=False,
                                             download=True, transform=transform_test)
    targets = np.array(base_train_raw.targets)
    labeled_idx, unlabeled_idx = [], []
    for c in range(NUM_CLASSES):
        idx = np.where(targets == c)[0]
        shuffle(idx)
        labeled_idx.extend(idx[:n_labeled_per_class])
        unlabeled_idx.extend(idx[n_labeled_per_class:])

    # Labeled images get their transform here, through a second dataset instance.
    labeled_ds = Subset(torchvision.datasets.CIFAR10(root=root, train=True, transform=transform_labeled),
                        labeled_idx)
    # Unlabeled images stay raw; the weak/strong transforms are applied in the wrapper.
    unlabeled_ds = Subset(base_train_raw, unlabeled_idx)
    return labeled_ds, unlabeled_ds, base_test


# ==========================================================
# Wrappers
# ==========================================================
class LabeledWrapper(torch.utils.data.Dataset):
    """Pass-through dataset that returns each `(x, y)` pair of `dataset` unchanged."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return sample, target

class UnlabeledWrapper(torch.utils.data.Dataset):
    """Dataset that yields a (weak view, strong view) pair per image and drops the label.

    `weak_t` and `strong_t` are called, in that order, on the same raw image
    taken from `dataset[idx]`, which must be an `(image, label)` pair.
    """

    def __init__(self, dataset, weak_t, strong_t):
        self.dataset, self.weak_t, self.strong_t = dataset, weak_t, strong_t

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, _label = self.dataset[idx]
        weak_view = self.weak_t(image)
        strong_view = self.strong_t(image)
        return weak_view, strong_view

# ==========================================================
# Modelos (CNNs: ResNet family + DenseNet121 + MobileNetV3; Transformer: ViT-B/16)
# ==========================================================
class ResNet34(nn.Module):
    """torchvision ResNet-34 whose `fc` head is replaced by a `num_classes`-way linear layer.

    With `pretrained=True` the backbone starts from the IMAGENET1K_V1 weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        weights = models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.resnet34(weights=weights)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Return class logits for the input batch `x`."""
        return self.backbone(x)

class ResNet101(nn.Module):
    """torchvision ResNet-101 whose `fc` head is replaced by a `num_classes`-way linear layer.

    With `pretrained=True` the backbone starts from the IMAGENET1K_V1 weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        weights = models.ResNet101_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.resnet101(weights=weights)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Return class logits for the input batch `x`."""
        return self.backbone(x)

class ResNet152(nn.Module):
    """torchvision ResNet-152 whose `fc` head is replaced by a `num_classes`-way linear layer.

    With `pretrained=True` the backbone starts from the IMAGENET1K_V1 weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        weights = models.ResNet152_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.resnet152(weights=weights)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Return class logits for the input batch `x`."""
        return self.backbone(x)

class DenseNet121(nn.Module):
    """torchvision DenseNet-121 whose `classifier` is replaced by a `num_classes`-way linear layer.

    With `pretrained=True` the backbone starts from the IMAGENET1K_V1 weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.densenet121(weights=weights)
        self.backbone.classifier = nn.Linear(self.backbone.classifier.in_features, num_classes)

    def forward(self, x):
        """Return class logits for the input batch `x`."""
        return self.backbone(x)

class MobileNetV3(nn.Module):
    """torchvision MobileNetV3-Large whose last classifier layer is resized to `num_classes`.

    With `pretrained=True` the backbone starts from the DEFAULT weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        weights = models.MobileNet_V3_Large_Weights.DEFAULT if pretrained else None
        self.backbone = models.mobilenet_v3_large(weights=weights)
        head = self.backbone.classifier
        head[-1] = nn.Linear(head[-1].in_features, num_classes)

    def forward(self, x):
        """Return class logits for the input batch `x`."""
        return self.backbone(x)

# Vision Transformer backbone (expects 224x224 inputs, see the *_vit transforms).
class ViTB16(nn.Module):
    """torchvision ViT-B/16 whose `heads.head` is replaced by a `num_classes`-way linear layer.

    With `pretrained=True` the backbone starts from the IMAGENET1K_V1 weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        weights = models.ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.vit_b_16(weights=weights)
        heads = self.backbone.heads
        heads.head = nn.Linear(heads.head.in_features, num_classes)

    def forward(self, x):
        """Return class logits for the input batch `x`."""
        return self.backbone(x)


# ==========================================================
# Fábrica
# ==========================================================
class ConvolutionalModelFactory:
    """Build classifier backbones by (case-insensitive) name.

    Despite the class name, the registry also holds the ViT-B/16 transformer.
    Registered keys, in iteration order: resnet34, resnet101, resnet152, vit_b16.
    """

    def __init__(self):
        # Insertion order matters: run_all_models iterates over these keys.
        self.model_map = dict(
            resnet34=ResNet34,
            resnet101=ResNet101,
            resnet152=ResNet152,
            vit_b16=ViTB16,
        )

    def get_model(self, model_name='resnet34', num_classes=10, pretrained=True):
        """Instantiate the model registered under `model_name`.

        Raises:
            ValueError: if `model_name` (lower-cased) is not registered.
        """
        name = model_name.lower()
        builder = self.model_map.get(name)
        if builder is None:
            raise ValueError(f"Modelo {name} não disponível. Opções: {list(self.model_map.keys())}")
        return builder(num_classes=num_classes, pretrained=pretrained)


# ==========================================================
# EMA
# ==========================================================
class ModelEMA:
    """Exponential-moving-average (EMA) copy of a model's state.

    The EMA copy is a deep copy of `model`, kept in eval mode with gradients
    disabled. Each `update` blends floating-point state-dict entries
    (parameters and float buffers) as ``ema = decay * ema + (1 - decay) * model``
    and copies non-floating entries (e.g. integer counters) verbatim.
    """

    def __init__(self, model, decay=0.999, device=None):
        """
        Args:
            model: module whose weights are tracked.
            decay: EMA factor, expected in [0, 1]; larger values move more slowly.
            device: where to keep the EMA copy; defaults to the global DEVICE.
        """
        target_device = DEVICE if device is None else device
        self.ema = copy.deepcopy(model).eval().to(target_device)
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the current state of `model` into the EMA copy, in place."""
        with torch.no_grad():
            msd = model.state_dict()
            # state_dict() tensors share storage with the EMA module, so the
            # in-place ops below update the EMA weights directly.
            for k, v in self.ema.state_dict().items():
                # Move the source to the EMA's device so the two may live apart.
                src = msd[k].detach().to(v.device)
                if v.dtype.is_floating_point:
                    v.mul_(self.decay)
                    v.add_((1. - self.decay) * src)
                else:
                    v.copy_(src)

# ==========================================================
# Avaliação
# ==========================================================
def evaluate(model, loader, device):
    """Return the top-1 accuracy of `model` over `loader`.

    Switches `model` to eval mode (and leaves it there) and runs without
    gradients. Returns 0.0 when the loader yields no samples instead of
    raising ZeroDivisionError.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total if total else 0.0

# ==========================================================
# Treinamento FixMatch
# ==========================================================
# Supervised loss on labeled batches; 'mean' is CrossEntropyLoss's default reduction.
criterion = nn.CrossEntropyLoss()

def train_fixmatch(model_name='resnet34',
                   n_labeled_per_class=250,
                   epochs=100,
                   root='./data',
                   drop_last=True,
                   patience=50,
                   weak_t=None, strong_t=None, test_t=None,
                   batch_l=None, batch_u=None):
    """Train `model_name` with FixMatch on CIFAR-10 and track an EMA copy.

    One epoch is one pass over the labeled loader; every labeled batch is paired
    with the next unlabeled (weak, strong) batch, restarting the unlabeled
    iterator when it runs out. After each epoch the EMA model is evaluated on
    the test set and the best EMA weights are saved to
    `best_fixmatch_{n_labeled_per_class}_{model_name}.pth`. Per-epoch history is
    written to `history_fixmatch_{model_name}_{n_labeled_per_class}.pt/.csv`.

    Args:
        model_name: key understood by ConvolutionalModelFactory.
        n_labeled_per_class: labeled images per class.
        epochs: maximum number of epochs.
        root: CIFAR-10 directory.
        drop_last: drop_last flag for the unlabeled loader.
        patience: epochs without test-accuracy improvement before stopping.
        weak_t, strong_t, test_t: transforms (default: CIFAR-10 transforms).
        batch_l, batch_u: labeled / unlabeled batch sizes (default BATCH_L / BATCH_U).

    Returns:
        (best_acc, history): best EMA test accuracy and a dict of per-epoch lists.
    """
    weak_t = weak_t if weak_t is not None else weak_transform
    strong_t = strong_t if strong_t is not None else strong_transform
    test_t = test_t if test_t is not None else test_transform
    local_batch_l = batch_l if batch_l is not None else BATCH_L
    local_batch_u = batch_u if batch_u is not None else BATCH_U

    labeled_ds, unlabeled_ds, test_ds = make_cifar10_datasets(
        n_labeled_per_class, root=root,
        transform_labeled=weak_t, transform_test=test_t
    )

    # With tiny label budgets (e.g. 1 or 4 per class -> 10 or 40 images) the
    # labeled set is smaller than one batch: drop_last=True would then give an
    # empty loader and the model would never be trained. Only drop the partial
    # batch when at least one full batch exists.
    labeled_drop_last = len(labeled_ds) >= local_batch_l
    labeled_loader = DataLoader(LabeledWrapper(labeled_ds), batch_size=local_batch_l,
                                shuffle=True, num_workers=2, drop_last=labeled_drop_last)
    unlabeled_loader = DataLoader(UnlabeledWrapper(unlabeled_ds, weak_t, strong_t),
                                  batch_size=local_batch_u, shuffle=True,
                                  num_workers=2, drop_last=drop_last)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2)

    factory = ConvolutionalModelFactory()
    model = factory.get_model(model_name=model_name, num_classes=NUM_CLASSES, pretrained=True).to(DEVICE)
    ema = ModelEMA(model, decay=0.999)

    optimizer = SGD(model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY, nesterov=True)
    steps_per_epoch = len(labeled_loader)
    total_steps = epochs * steps_per_epoch

    def lr_lambda(step):
        # FixMatch cosine decay: lr = LR * cos(7*pi*step / (16*total_steps)).
        if total_steps <= 0:
            return 1.0
        return math.cos((7.0 * math.pi * step) / (16.0 * total_steps))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def safe_mean(values):
        return sum(values) / len(values) if values else 0.0

    best_acc = 0.0
    epochs_no_improve = 0
    history = {"epoch": [], "loss_total": [], "loss_supervised": [], "loss_unsupervised": [], "mask_rate": [], "accuracy": []}
    checkpoint_path = f'best_fixmatch_{n_labeled_per_class}_{model_name}.pth'

    unlabeled_iter = iter(unlabeled_loader)
    for epoch in range(epochs):
        model.train()
        loss_epoch, loss_sup_epoch, loss_unsup_epoch, mask_rate_epoch = [], [], [], []
        pbar = tqdm(enumerate(labeled_loader), total=steps_per_epoch, desc=f"Epoch {epoch}")
        for i, (x_l, y_l) in pbar:
            try:
                x_ul_w, x_ul_s = next(unlabeled_iter)
            except StopIteration:
                unlabeled_iter = iter(unlabeled_loader)
                x_ul_w, x_ul_s = next(unlabeled_iter)

            x_l, y_l = x_l.to(DEVICE), y_l.to(DEVICE)
            x_ul_w, x_ul_s = x_ul_w.to(DEVICE), x_ul_s.to(DEVICE)

            logits_l = model(x_l)
            loss_l = criterion(logits_l, y_l)

            # Pseudo-labels from the weak view; keep only confident ones (>= TAU).
            with torch.no_grad():
                probs_ul_w = torch.softmax(model(x_ul_w), dim=1)
                max_probs, p_hat = torch.max(probs_ul_w, dim=1)
                mask = (max_probs >= TAU).float()

            logits_ul_s = model(x_ul_s)
            loss_u_all = F.cross_entropy(logits_ul_s, p_hat, reduction='none')
            # Normalize by the number of selected samples (more stable than the batch size).
            denom = torch.clamp(mask.sum(), min=1.0)
            loss_u = (mask * loss_u_all).sum() / denom

            loss = loss_l + LAMBDA_U * loss_u
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            ema.update(model)

            loss_epoch.append(loss.item())
            loss_sup_epoch.append(loss_l.item())
            loss_unsup_epoch.append(loss_u.item())
            mask_rate_epoch.append(mask.mean().item())

            # Also refresh on the last step so short epochs still show stats.
            if (i + 1) % PRINT_EVERY == 0 or (i + 1) == steps_per_epoch:
                pbar.set_postfix({
                    "loss_s": f"{loss_l.item():.4f}",
                    "loss_u": f"{loss_u.item():.4f}",
                    "mask": f"{mask.mean().item():.2f}"
                })

        acc = evaluate(ema.ema, test_loader, DEVICE)
        # Always checkpoint after the first epoch so a "best" file exists for
        # later reloading even if accuracy never rises above its initial value.
        if acc > best_acc or epoch == 0:
            best_acc = acc
            epochs_no_improve = 0
            torch.save({'model_state': ema.ema.state_dict(), 'acc': best_acc}, checkpoint_path)
        else:
            epochs_no_improve += 1

        # Record the epoch BEFORE the early-stopping check so the final epoch
        # is not dropped from the history.
        history["epoch"].append(epoch)
        history["loss_total"].append(safe_mean(loss_epoch))
        history["loss_supervised"].append(safe_mean(loss_sup_epoch))
        history["loss_unsupervised"].append(safe_mean(loss_unsup_epoch))
        history["mask_rate"].append(safe_mean(mask_rate_epoch))
        history["accuracy"].append(acc)

        print(f"Epoch {epoch}: acc={acc:.4f} (best={best_acc:.4f})")

        if epochs_no_improve >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}. Best acc={best_acc:.4f}")
            break

    history_path = f"history_fixmatch_{model_name}_{n_labeled_per_class}.pt"
    torch.save(history, history_path)
    df = pd.DataFrame(history)
    csv_path = f"history_fixmatch_{model_name}_{n_labeled_per_class}.csv"
    df.to_csv(csv_path, index=False)
    print(f"Saved history CSV: {csv_path}")

    return best_acc, history

# ==========================================================
# Métricas de Avaliação 
# ==========================================================

@torch.no_grad()
def evaluate_with_metrics(model, loader, device, class_names=None, return_preds=False, num_classes=None):
    """Evaluate `model` on `loader` and compute per-class classification metrics.

    Args:
        model: classifier returning logits.
        loader: iterable of (x, y) batches; y must be a CPU tensor of class ids.
        device: device the inputs are moved to.
        class_names: optional names used in the text report.
        return_preds: if True, also return y_true / y_pred / y_conf arrays.
        num_classes: number of classes to report on (default: NUM_CLASSES).

    Returns:
        dict with accuracy, per-class precision/recall/F1/support, a text report,
        and the confusion matrix (rows = true, cols = predicted); plus
        y_true, y_pred and y_conf (max softmax probability) when return_preds.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    n_cls = NUM_CLASSES if num_classes is None else num_classes
    labels = list(range(n_cls))
    model.eval()
    all_preds, all_conf, all_targets = [], [], []
    for x, y in loader:
        logits = model(x.to(device))
        probs = torch.softmax(logits, dim=1)
        all_preds.append(logits.argmax(dim=1).cpu().numpy())
        all_conf.append(probs.max(dim=1).values.cpu().numpy())
        all_targets.append(y.cpu().numpy())
    if not all_targets:
        raise ValueError("evaluate_with_metrics: loader yielded no batches")

    y_true = np.concatenate(all_targets)
    y_pred = np.concatenate(all_preds)
    y_conf = np.concatenate(all_conf)

    acc = float((y_true == y_pred).mean())
    pr, rc, f1, support = precision_recall_fscore_support(
        y_true, y_pred, average=None, labels=labels, zero_division=0)
    report = classification_report(y_true, y_pred, labels=labels, target_names=class_names,
                                   zero_division=0, digits=4)
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    out = {"accuracy": acc, "precision_per_class": pr, "recall_per_class": rc, "f1_per_class": f1,
           "support_per_class": support, "report": report, "confusion_matrix": cm,
           "y_true": y_true, "y_pred": y_pred, "y_conf": y_conf}
    if not return_preds:
        for key in ("y_true", "y_pred", "y_conf"):
            out.pop(key)
    return out

def plot_history_curves(history, title_prefix, save_dir="."):
    """Save one PNG line plot per tracked metric of `history`, against the epoch.

    Files go to `save_dir` as `{title_prefix}_{suffix}.png` with suffixes
    loss_sup, loss_unsup, mask_rate and accuracy.
    """
    # (history key, plot label, y-axis label, file suffix)
    panels = (
        ("loss_supervised", "Supervised Loss", "loss", "loss_sup"),
        ("loss_unsupervised", "Unsupervised Loss", "loss", "loss_unsup"),
        ("mask_rate", "Mask Rate", "rate", "mask_rate"),
        ("accuracy", "Test Accuracy (EMA)", "accuracy", "accuracy"),
    )
    x_values = history["epoch"]
    for key, label, y_label, suffix in panels:
        plt.figure()
        plt.plot(x_values, history[key])
        plt.title(f"{title_prefix} - {label}")
        plt.xlabel("epoch")
        plt.ylabel(y_label)
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f"{title_prefix}_{suffix}.png"))
        plt.close()

def denormalize(img_tensor, mean, std):
    """Undo `Normalize(mean, std)` on one (C, H, W) image tensor and clamp to [0, 1].

    Works for any channel count C, provided `mean` and `std` each have C
    entries. The input tensor is not modified.

    Raises:
        ValueError: if len(mean) or len(std) differs from the channel count.
    """
    channels = img_tensor.shape[0]
    if len(mean) != channels or len(std) != channels:
        raise ValueError(f"mean/std must have {channels} entries, got {len(mean)} and {len(std)}")
    # Reshape the statistics to (C, 1, 1) so they broadcast over H and W.
    mean_t = torch.as_tensor(mean, dtype=img_tensor.dtype, device=img_tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img_tensor.dtype, device=img_tensor.device).view(-1, 1, 1)
    return torch.clamp(img_tensor * std_t + mean_t, 0, 1)

@torch.no_grad()
def show_predictions_grid(model, dataset, transform_stats, n=16, save_path="pred_grid.png", class_names=None):
    """Save a grid of random `dataset` samples annotated with predicted vs. true labels.

    Args:
        model: classifier returning logits.
        dataset: indexable dataset of (normalized (C, H, W) image tensor, label);
            the label must be usable as an index into `class_names`.
        transform_stats: dict with "mean" and "std" used to undo normalization.
        n: number of samples; clipped to len(dataset). Nothing is saved if 0.
        save_path: output image path.
        class_names: optional label names; raw class indices are shown otherwise.
    """
    model.eval()
    n = min(n, len(dataset))
    if n <= 0:
        return
    # Send inputs to wherever the model actually lives (fallback: DEVICE).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE
    idxs = np.random.choice(len(dataset), size=n, replace=False)
    ncols = int(math.sqrt(n))
    nrows = int(math.ceil(n / ncols))
    plt.figure(figsize=(1.6 * ncols, 1.6 * nrows))
    for i, idx in enumerate(idxs, 1):
        x, y = dataset[idx]
        logits = model(x.unsqueeze(0).to(device))
        conf, pred = torch.softmax(logits, dim=1).max(dim=1)
        img = denormalize(x.cpu(), transform_stats["mean"], transform_stats["std"]).permute(1, 2, 0).numpy()
        plt.subplot(nrows, ncols, i)
        plt.imshow(img)
        pred_label = class_names[pred.item()] if class_names else pred.item()
        true_label = class_names[y] if class_names else y
        plt.title(f"pred={pred_label} ({conf.item():.2f})\ntrue={true_label}", fontsize=8)
        plt.axis("off")
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


# ==========================================================
# Execução para múltiplos modelos
# ==========================================================
def run_all_models(root='./data', epochs=200, n_labeled_per_class_list=(1, 4, 25, 400), model_names=None):
    """Train each backbone with FixMatch at each label budget and save the best accuracies.

    Args:
        root: CIFAR-10 directory.
        epochs: maximum epochs per run.
        n_labeled_per_class_list: label budgets (labeled images per class) to try.
        model_names: models to train; defaults to every factory-registered model.

    Returns:
        Nested dict results[model_name][n_labeled] = best EMA test accuracy,
        also saved to fixmatch_model_comparison.pth.
    """
    if model_names is None:
        model_names = list(ConvolutionalModelFactory().model_map.keys())
    results = {}
    for model_name in model_names:
        results[model_name] = {}
        for n_lab in n_labeled_per_class_list:
            print(f"\n=== Training {model_name} with {n_lab}/class ===")
            best_acc, _ = train_fixmatch(model_name=model_name, n_labeled_per_class=n_lab, epochs=epochs, root=root)
            results[model_name][n_lab] = best_acc
    torch.save(results, "fixmatch_model_comparison.pth")
    print("Saved comparison to fixmatch_model_comparison.pth")
    return results

# ==========================================================
# if __name__ == "__main__":
#     run_all_models(root='./data', epochs=200, n_labeled_per_class_list=[1, 4, 25, 400])


Using device: cuda


In [10]:
# Marker printed before the training cells run.
print("xsd Training", end="\n")

xsd Training


In [11]:
def _load_best_ema_checkpoint(model_name, n_labeled_per_class):
    """Rebuild `model_name` and load the best EMA weights saved by train_fixmatch."""
    # pretrained=False: every weight is overwritten by the checkpoint below,
    # so downloading the ImageNet weights would be wasted work.
    model = ConvolutionalModelFactory().get_model(model_name, NUM_CLASSES, pretrained=False).to(DEVICE)
    ckpt = torch.load(f'best_fixmatch_{n_labeled_per_class}_{model_name}.pth', map_location=DEVICE)
    model.load_state_dict(ckpt['model_state'])
    return model


def _run_fixmatch_and_evaluate(model_name, n_labeled_per_class, root, epochs, patience,
                               weak_t, strong_t, test_t, norm_mean, norm_std,
                               batch_l, batch_u, eval_batch_size, class_names):
    """Train one backbone with FixMatch, then evaluate and plot its best EMA checkpoint.

    Returns:
        (best_acc, history, metrics) where metrics come from evaluate_with_metrics.
    """
    best_acc, history = train_fixmatch(
        model_name=model_name,
        n_labeled_per_class=n_labeled_per_class,
        epochs=epochs,
        root=root,
        patience=patience,
        weak_t=weak_t, strong_t=strong_t, test_t=test_t,
        batch_l=batch_l, batch_u=batch_u
    )

    ema_model = _load_best_ema_checkpoint(model_name, n_labeled_per_class)
    # Only the test split is needed here: build it directly instead of
    # re-splitting the training set (which would also consume the global RNG).
    test_ds = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=test_t)
    test_loader = DataLoader(test_ds, batch_size=eval_batch_size, shuffle=False, num_workers=2)
    metrics = evaluate_with_metrics(ema_model, test_loader, DEVICE, class_names=class_names)

    tag = f"{model_name}_{n_labeled_per_class}"
    plot_history_curves(history, tag, save_dir=".")
    show_predictions_grid(ema_model, test_ds, transform_stats={"mean": norm_mean, "std": norm_std},
                          n=16, save_path=f"pred_grid_{tag}.png", class_names=class_names)
    return best_acc, history, metrics


def _summarize_run(name, best_acc, metrics):
    """Build one JSON-serialisable row of the comparison summary."""
    return {
        "model": name,
        "best_test_acc(EMA)": float(best_acc),
        # Accuracy of the reloaded best checkpoint (expected to match best_acc).
        "final_test_acc(EMA)": float(metrics["accuracy"]),
        "macro_F1": float(np.mean(metrics["f1_per_class"])),
    }


def _plot_confusion_matrix(cm, title, path, class_names=None):
    """Save confusion matrix `cm` (rows = true, cols = predicted) as an image at `path`."""
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(cm, interpolation='nearest')
    ax.set_title(title)
    ax.set_xlabel("Pred")
    ax.set_ylabel("True")
    if class_names is not None:
        ticks = list(range(len(class_names)))
        ax.set_xticks(ticks)
        ax.set_xticklabels(class_names, rotation=45, ha="right")
        ax.set_yticks(ticks)
        ax.set_yticklabels(class_names)
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def compare_fixmatch_resnet34_vs_vit(root='./data',
                                     n_labeled_per_class=25,
                                     epochs=100,
                                     patience=50):
    """Train ResNet-34 and ViT-B/16 with FixMatch on the same label budget and compare them.

    Writes per-model history curves, prediction grids, classification reports and
    confusion matrices, plus `comparison_summary_{n_labeled_per_class}.json`.

    Returns:
        dict with "resnet34" and "vit_b16" entries (best_acc, metrics, history)
        and the "summary" rows.
    """
    import json

    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    # ResNet-34: native 32x32 CIFAR transforms and default batch sizes.
    print(f"\n[ResNet34] FixMatch com CIFAR transforms | n_label/class={n_labeled_per_class}")
    best_acc_rn, hist_rn, metrics_rn = _run_fixmatch_and_evaluate(
        'resnet34', n_labeled_per_class, root, epochs, patience,
        weak_t=weak_transform, strong_t=strong_transform, test_t=test_transform,
        norm_mean=(0.4914, 0.4822, 0.4465), norm_std=(0.247, 0.243, 0.261),
        batch_l=BATCH_L, batch_u=BATCH_U, eval_batch_size=256, class_names=class_names)

    # ViT-B/16: 224x224 ImageNet-style transforms; inputs that large need more
    # memory, so the batch sizes are halved (with lower bounds).
    vit_batch_l = max(16, BATCH_L // 2)
    vit_batch_u = max(32, BATCH_U // 2)
    print(f"\n[ViT-B/16] FixMatch com ImageNet transforms | n_label/class={n_labeled_per_class}")
    best_acc_vit, hist_vit, metrics_vit = _run_fixmatch_and_evaluate(
        'vit_b16', n_labeled_per_class, root, epochs, patience,
        weak_t=weak_transform_vit, strong_t=strong_transform_vit, test_t=test_transform_vit,
        norm_mean=IMAGENET_MEAN, norm_std=IMAGENET_STD,
        batch_l=vit_batch_l, batch_u=vit_batch_u, eval_batch_size=128, class_names=class_names)

    summary = [_summarize_run("ResNet34", best_acc_rn, metrics_rn),
               _summarize_run("ViT-B/16", best_acc_vit, metrics_vit)]
    with open(f"comparison_summary_{n_labeled_per_class}.json", "w") as f:
        json.dump(summary, f, indent=2)

    # Per-model text reports and confusion-matrix images.
    for tag, title, metrics in (("resnet34", "ResNet34", metrics_rn),
                                ("vit_b16", "ViT-B/16", metrics_vit)):
        with open(f"classification_report_{tag}_{n_labeled_per_class}.txt", "w") as f:
            f.write(metrics["report"])
        _plot_confusion_matrix(metrics["confusion_matrix"], f"CM {title} (n={n_labeled_per_class})",
                               f"cm_{tag}_{n_labeled_per_class}.png", class_names=class_names)

    print("\n=== COMPARAÇÃO FINAL ===")
    for row in summary:
        print(row)

    return {
        "resnet34": {"best_acc": best_acc_rn, "metrics": metrics_rn, "history": hist_rn},
        "vit_b16":  {"best_acc": best_acc_vit, "metrics": metrics_vit, "history": hist_vit},
        "summary": summary
    }

In [12]:
if __name__ == "__main__":
    # Example: 25 labels per class, 200 epochs.
    # `root` must be passed as a keyword argument (`=`): `root=='./data'` would
    # compare an undefined/leftover `root` variable instead of passing the path.
    compare_fixmatch_resnet34_vs_vit(root='./data', n_labeled_per_class=25, epochs=200, patience=50)


[ResNet34] FixMatch com CIFAR transforms | n_label/class=25


Epoch 0: 100%|██████████| 3/3 [00:00<00:00,  5.50it/s]


Epoch 0: acc=0.1015 (best=0.1015)


Epoch 1: 100%|██████████| 3/3 [00:00<00:00,  8.74it/s]


Epoch 1: acc=0.1020 (best=0.1020)


Epoch 2: 100%|██████████| 3/3 [00:00<00:00,  8.97it/s]


Epoch 2: acc=0.1034 (best=0.1034)


Epoch 3: 100%|██████████| 3/3 [00:00<00:00,  9.65it/s]


Epoch 3: acc=0.1047 (best=0.1047)


Epoch 4: 100%|██████████| 3/3 [00:00<00:00,  9.07it/s]


Epoch 4: acc=0.1054 (best=0.1054)


Epoch 5: 100%|██████████| 3/3 [00:00<00:00,  8.17it/s]


Epoch 5: acc=0.1067 (best=0.1067)


Epoch 6: 100%|██████████| 3/3 [00:00<00:00,  9.31it/s]


Epoch 6: acc=0.1078 (best=0.1078)


Epoch 7: 100%|██████████| 3/3 [00:00<00:00,  8.13it/s]


Epoch 7: acc=0.1152 (best=0.1152)


Epoch 8: 100%|██████████| 3/3 [00:00<00:00,  9.15it/s]


Epoch 8: acc=0.1209 (best=0.1209)


Epoch 9: 100%|██████████| 3/3 [00:00<00:00,  8.06it/s]


Epoch 9: acc=0.1106 (best=0.1209)


Epoch 10: 100%|██████████| 3/3 [00:00<00:00,  8.53it/s]


Epoch 10: acc=0.1002 (best=0.1209)


Epoch 11: 100%|██████████| 3/3 [00:00<00:00,  8.74it/s]


Epoch 11: acc=0.0994 (best=0.1209)


Epoch 12: 100%|██████████| 3/3 [00:00<00:00,  9.34it/s]


Epoch 12: acc=0.0999 (best=0.1209)


Epoch 13: 100%|██████████| 3/3 [00:00<00:00,  8.21it/s]


Epoch 13: acc=0.1000 (best=0.1209)


Epoch 14: 100%|██████████| 3/3 [00:00<00:00,  9.26it/s]


Epoch 14: acc=0.1000 (best=0.1209)


Epoch 15: 100%|██████████| 3/3 [00:00<00:00,  7.95it/s]


Epoch 15: acc=0.1000 (best=0.1209)


Epoch 16: 100%|██████████| 3/3 [00:00<00:00,  8.91it/s]


Epoch 16: acc=0.1000 (best=0.1209)


Epoch 17: 100%|██████████| 3/3 [00:00<00:00,  8.33it/s]


Epoch 17: acc=0.1000 (best=0.1209)


Epoch 18: 100%|██████████| 3/3 [00:00<00:00,  8.32it/s]


Epoch 18: acc=0.1000 (best=0.1209)


Epoch 19: 100%|██████████| 3/3 [00:00<00:00,  7.95it/s]


Epoch 19: acc=0.1000 (best=0.1209)


Epoch 20: 100%|██████████| 3/3 [00:00<00:00,  8.21it/s]


Epoch 20: acc=0.1000 (best=0.1209)


Epoch 21: 100%|██████████| 3/3 [00:00<00:00,  8.43it/s]


Epoch 21: acc=0.1000 (best=0.1209)


Epoch 22: 100%|██████████| 3/3 [00:00<00:00,  9.12it/s]


Epoch 22: acc=0.1000 (best=0.1209)


Epoch 23: 100%|██████████| 3/3 [00:00<00:00,  8.15it/s]


Epoch 23: acc=0.1000 (best=0.1209)


Epoch 24: 100%|██████████| 3/3 [00:00<00:00,  7.99it/s]


Epoch 24: acc=0.1000 (best=0.1209)


Epoch 25: 100%|██████████| 3/3 [00:00<00:00,  8.96it/s]


Epoch 25: acc=0.1000 (best=0.1209)


Epoch 26: 100%|██████████| 3/3 [00:00<00:00,  8.55it/s]


Epoch 26: acc=0.1000 (best=0.1209)


Epoch 27: 100%|██████████| 3/3 [00:00<00:00,  8.28it/s]


Epoch 27: acc=0.1000 (best=0.1209)


Epoch 28: 100%|██████████| 3/3 [00:00<00:00,  8.57it/s]


Epoch 28: acc=0.1000 (best=0.1209)


Epoch 29: 100%|██████████| 3/3 [00:00<00:00,  8.63it/s]


Epoch 29: acc=0.1000 (best=0.1209)


Epoch 30: 100%|██████████| 3/3 [00:00<00:00,  8.54it/s]


Epoch 30: acc=0.1000 (best=0.1209)


Epoch 31: 100%|██████████| 3/3 [00:00<00:00,  8.15it/s]


Epoch 31: acc=0.1000 (best=0.1209)


Epoch 32: 100%|██████████| 3/3 [00:00<00:00,  8.17it/s]


Epoch 32: acc=0.1000 (best=0.1209)


Epoch 33: 100%|██████████| 3/3 [00:00<00:00,  8.46it/s]


Epoch 33: acc=0.1000 (best=0.1209)


Epoch 34: 100%|██████████| 3/3 [00:00<00:00,  8.35it/s]


Epoch 34: acc=0.1000 (best=0.1209)


Epoch 35: 100%|██████████| 3/3 [00:00<00:00,  8.50it/s]


Epoch 35: acc=0.1000 (best=0.1209)


Epoch 36: 100%|██████████| 3/3 [00:00<00:00,  9.44it/s]


Epoch 36: acc=0.1000 (best=0.1209)


Epoch 37: 100%|██████████| 3/3 [00:00<00:00,  8.51it/s]


Epoch 37: acc=0.1000 (best=0.1209)


Epoch 38: 100%|██████████| 3/3 [00:00<00:00,  7.74it/s]


Epoch 38: acc=0.1283 (best=0.1283)


Epoch 39: 100%|██████████| 3/3 [00:00<00:00,  8.77it/s]


Epoch 39: acc=0.1000 (best=0.1283)


Epoch 40: 100%|██████████| 3/3 [00:00<00:00,  8.66it/s]


Epoch 40: acc=0.1000 (best=0.1283)


Epoch 41: 100%|██████████| 3/3 [00:00<00:00,  9.12it/s]


Epoch 41: acc=0.1000 (best=0.1283)


Epoch 42: 100%|██████████| 3/3 [00:00<00:00,  8.03it/s]


Epoch 42: acc=0.1000 (best=0.1283)


Epoch 43: 100%|██████████| 3/3 [00:00<00:00,  7.55it/s]


Epoch 43: acc=0.1000 (best=0.1283)


Epoch 44: 100%|██████████| 3/3 [00:00<00:00,  8.74it/s]


Epoch 44: acc=0.1000 (best=0.1283)


Epoch 45: 100%|██████████| 3/3 [00:00<00:00,  8.20it/s]


Epoch 45: acc=0.1000 (best=0.1283)


Epoch 46: 100%|██████████| 3/3 [00:00<00:00,  9.00it/s]


Epoch 46: acc=0.1000 (best=0.1283)


Epoch 47: 100%|██████████| 3/3 [00:00<00:00,  9.03it/s]


Epoch 47: acc=0.1000 (best=0.1283)


Epoch 48: 100%|██████████| 3/3 [00:00<00:00,  8.15it/s]


Epoch 48: acc=0.1000 (best=0.1283)


Epoch 49: 100%|██████████| 3/3 [00:00<00:00,  8.71it/s]


Epoch 49: acc=0.1000 (best=0.1283)


Epoch 50: 100%|██████████| 3/3 [00:00<00:00,  9.17it/s]


Epoch 50: acc=0.1000 (best=0.1283)


Epoch 51: 100%|██████████| 3/3 [00:00<00:00,  9.00it/s]


Epoch 51: acc=0.1000 (best=0.1283)


Epoch 52: 100%|██████████| 3/3 [00:00<00:00,  9.10it/s]


Epoch 52: acc=0.1000 (best=0.1283)


Epoch 53: 100%|██████████| 3/3 [00:00<00:00,  8.90it/s]


Epoch 53: acc=0.1000 (best=0.1283)


Epoch 54: 100%|██████████| 3/3 [00:00<00:00,  8.94it/s]


Epoch 54: acc=0.1000 (best=0.1283)


Epoch 55: 100%|██████████| 3/3 [00:00<00:00,  9.02it/s]


Epoch 55: acc=0.1000 (best=0.1283)


Epoch 56: 100%|██████████| 3/3 [00:00<00:00,  8.26it/s]


Epoch 56: acc=0.1000 (best=0.1283)


Epoch 57: 100%|██████████| 3/3 [00:00<00:00,  8.62it/s]


Epoch 57: acc=0.1000 (best=0.1283)


Epoch 58: 100%|██████████| 3/3 [00:00<00:00,  7.98it/s]


Epoch 58: acc=0.1000 (best=0.1283)


Epoch 59: 100%|██████████| 3/3 [00:00<00:00,  8.02it/s]


Epoch 59: acc=0.1000 (best=0.1283)


Epoch 60: 100%|██████████| 3/3 [00:00<00:00,  8.91it/s]


Epoch 60: acc=0.1000 (best=0.1283)


Epoch 61: 100%|██████████| 3/3 [00:00<00:00,  8.36it/s]


Epoch 61: acc=0.1000 (best=0.1283)


Epoch 62: 100%|██████████| 3/3 [00:00<00:00,  7.97it/s]


Epoch 62: acc=0.1000 (best=0.1283)


Epoch 63: 100%|██████████| 3/3 [00:00<00:00,  8.47it/s]


Epoch 63: acc=0.1000 (best=0.1283)


Epoch 64: 100%|██████████| 3/3 [00:00<00:00,  8.90it/s]


Epoch 64: acc=0.1000 (best=0.1283)


Epoch 65: 100%|██████████| 3/3 [00:00<00:00,  8.78it/s]


Epoch 65: acc=0.1000 (best=0.1283)


Epoch 66: 100%|██████████| 3/3 [00:00<00:00,  8.26it/s]


Epoch 66: acc=0.1000 (best=0.1283)


Epoch 67: 100%|██████████| 3/3 [00:00<00:00,  8.09it/s]


Epoch 67: acc=0.1000 (best=0.1283)


Epoch 68: 100%|██████████| 3/3 [00:00<00:00,  8.83it/s]


Epoch 68: acc=0.1000 (best=0.1283)


Epoch 69: 100%|██████████| 3/3 [00:00<00:00,  8.24it/s]


Epoch 69: acc=0.1000 (best=0.1283)


Epoch 70: 100%|██████████| 3/3 [00:00<00:00,  8.88it/s]


Epoch 70: acc=0.1000 (best=0.1283)


Epoch 71: 100%|██████████| 3/3 [00:00<00:00,  7.84it/s]


Epoch 71: acc=0.1000 (best=0.1283)


Epoch 72: 100%|██████████| 3/3 [00:00<00:00,  8.25it/s]


Epoch 72: acc=0.1000 (best=0.1283)


Epoch 73: 100%|██████████| 3/3 [00:00<00:00,  8.18it/s]


Epoch 73: acc=0.1000 (best=0.1283)


Epoch 74: 100%|██████████| 3/3 [00:00<00:00,  8.88it/s]


Epoch 74: acc=0.1000 (best=0.1283)


Epoch 75: 100%|██████████| 3/3 [00:00<00:00,  8.07it/s]


Epoch 75: acc=0.1000 (best=0.1283)


Epoch 76: 100%|██████████| 3/3 [00:00<00:00,  7.95it/s]


Epoch 76: acc=0.1000 (best=0.1283)


Epoch 77: 100%|██████████| 3/3 [00:00<00:00,  9.22it/s]


Epoch 77: acc=0.1000 (best=0.1283)


Epoch 78: 100%|██████████| 3/3 [00:00<00:00,  9.04it/s]


Epoch 78: acc=0.1000 (best=0.1283)


Epoch 79: 100%|██████████| 3/3 [00:00<00:00,  8.15it/s]


Epoch 79: acc=0.1000 (best=0.1283)


Epoch 80: 100%|██████████| 3/3 [00:00<00:00,  9.18it/s]


Epoch 80: acc=0.1000 (best=0.1283)


Epoch 81: 100%|██████████| 3/3 [00:00<00:00,  8.27it/s]


Epoch 81: acc=0.1000 (best=0.1283)


Epoch 82: 100%|██████████| 3/3 [00:00<00:00,  7.81it/s]


Epoch 82: acc=0.1000 (best=0.1283)


Epoch 83: 100%|██████████| 3/3 [00:00<00:00,  8.59it/s]


Epoch 83: acc=0.1000 (best=0.1283)


Epoch 84: 100%|██████████| 3/3 [00:00<00:00,  7.84it/s]


Epoch 84: acc=0.1000 (best=0.1283)


Epoch 85: 100%|██████████| 3/3 [00:00<00:00,  7.86it/s]


Epoch 85: acc=0.1000 (best=0.1283)


Epoch 86: 100%|██████████| 3/3 [00:00<00:00,  8.17it/s]


Epoch 86: acc=0.1000 (best=0.1283)


Epoch 87: 100%|██████████| 3/3 [00:00<00:00,  7.92it/s]


Epoch 87: acc=0.1000 (best=0.1283)


Epoch 88: 100%|██████████| 3/3 [00:00<00:00,  8.88it/s]



Early stopping at epoch 89. Best acc=0.1283
Saved history CSV: history_fixmatch_resnet34_25.csv

[ViT-B/16] FixMatch com ImageNet transforms | n_label/class=25


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:03<00:00, 113MB/s]  
Epoch 0: 100%|██████████| 7/7 [00:27<00:00,  3.98s/it]


Epoch 0: acc=0.1681 (best=0.1681)


Epoch 1: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 1: acc=0.1774 (best=0.1774)


Epoch 2: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 2: acc=0.1882 (best=0.1882)


Epoch 3: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 3: acc=0.2003 (best=0.2003)


Epoch 4: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 4: acc=0.2114 (best=0.2114)


Epoch 5: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 5: acc=0.2201 (best=0.2201)


Epoch 6: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 6: acc=0.2278 (best=0.2278)


Epoch 7: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 7: acc=0.2333 (best=0.2333)


Epoch 8: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 8: acc=0.2367 (best=0.2367)


Epoch 9: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 9: acc=0.2363 (best=0.2367)


Epoch 10: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 10: acc=0.2256 (best=0.2367)


Epoch 11: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 11: acc=0.2154 (best=0.2367)


Epoch 12: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 12: acc=0.2055 (best=0.2367)


Epoch 13: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 13: acc=0.1960 (best=0.2367)


Epoch 14: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 14: acc=0.1853 (best=0.2367)


Epoch 15: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 15: acc=0.1747 (best=0.2367)


Epoch 16: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 16: acc=0.1661 (best=0.2367)


Epoch 17: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 17: acc=0.1571 (best=0.2367)


Epoch 18: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 18: acc=0.1491 (best=0.2367)


Epoch 19: 100%|██████████| 7/7 [00:27<00:00,  3.86s/it]


Epoch 19: acc=0.1449 (best=0.2367)


Epoch 20: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 20: acc=0.1414 (best=0.2367)


Epoch 21: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 21: acc=0.1380 (best=0.2367)


Epoch 22: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 23: acc=0.1350 (best=0.2367)


Epoch 24: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 24: acc=0.1370 (best=0.2367)


Epoch 25: 100%|██████████| 7/7 [00:26<00:00,  3.85s/it]


Epoch 25: acc=0.1364 (best=0.2367)


Epoch 26: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 26: acc=0.1349 (best=0.2367)


Epoch 27: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 27: acc=0.1298 (best=0.2367)


Epoch 28: 100%|██████████| 7/7 [00:26<00:00,  3.85s/it]


Epoch 28: acc=0.1296 (best=0.2367)


Epoch 29: 100%|██████████| 7/7 [00:26<00:00,  3.85s/it]


Epoch 29: acc=0.1288 (best=0.2367)


Epoch 30: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 30: acc=0.1307 (best=0.2367)


Epoch 31: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 31: acc=0.1326 (best=0.2367)


Epoch 32: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 32: acc=0.1327 (best=0.2367)


Epoch 33: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 33: acc=0.1335 (best=0.2367)


Epoch 34: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 34: acc=0.1340 (best=0.2367)


Epoch 35: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 35: acc=0.1346 (best=0.2367)


Epoch 36: 100%|██████████| 7/7 [00:27<00:00,  3.86s/it]


Epoch 36: acc=0.1350 (best=0.2367)


Epoch 37: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 37: acc=0.1322 (best=0.2367)


Epoch 38: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 38: acc=0.1299 (best=0.2367)


Epoch 39: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 39: acc=0.1267 (best=0.2367)


Epoch 40: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 40: acc=0.1259 (best=0.2367)


Epoch 41: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 41: acc=0.1291 (best=0.2367)


Epoch 42: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]
Epoch 43: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 43: acc=0.1415 (best=0.2367)


Epoch 44: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 44: acc=0.1508 (best=0.2367)


Epoch 45: 100%|██████████| 7/7 [00:26<00:00,  3.84s/it]


Epoch 45: acc=0.1590 (best=0.2367)


Epoch 46: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 46: acc=0.1632 (best=0.2367)


Epoch 47: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 47: acc=0.1675 (best=0.2367)


Epoch 48: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 48: acc=0.1740 (best=0.2367)


Epoch 49: 100%|██████████| 7/7 [00:26<00:00,  3.81s/it]


Epoch 49: acc=0.1804 (best=0.2367)


Epoch 50: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 50: acc=0.1871 (best=0.2367)


Epoch 51: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 51: acc=0.1907 (best=0.2367)


Epoch 52: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 52: acc=0.1947 (best=0.2367)


Epoch 53: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 53: acc=0.1982 (best=0.2367)


Epoch 54: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 54: acc=0.2000 (best=0.2367)


Epoch 55: 100%|██████████| 7/7 [00:26<00:00,  3.83s/it]


Epoch 55: acc=0.2002 (best=0.2367)


Epoch 56: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 56: acc=0.2025 (best=0.2367)


Epoch 57: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 57: acc=0.2035 (best=0.2367)


Epoch 58: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]



Early stopping at epoch 59. Best acc=0.2367
Saved history CSV: history_fixmatch_vit_b16_25.csv

=== COMPARAÇÃO FINAL ===
{'model': 'ResNet34', 'best_test_acc(EMA)': 0.1283, 'final_test_acc(EMA)': 0.1283, 'macro_F1': 0.042982158733577115}
{'model': 'ViT-B/16', 'best_test_acc(EMA)': 0.2367, 'final_test_acc(EMA)': 0.2367, 'macro_F1': 0.22230080453986437}


In [14]:
# NOTE(review): leftover smoke-test cell — it only echoes a fixed marker string
# and reads no state from the training cells above; consider removing it from
# the final notebook so Restart & Run All output stays focused on results.
print("oi", end="\n")

oi
