[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/FM11pp3/VC_0312/blob/main/Untitled0.ipynb)



# VC_0312 - Notebook arrumado
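Notebook dividido em: Configuracao -> Parte A (analise exploratoria + transformacoes base, sem augmentations aleatorias) -> Parte B (pesos pre-treinados para validar/testar) -> Anexos (treino de raiz, push para GitHub, GAN condicional) -> Parte C (CGAN para gerar imagens das classes minoritarias).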

**Como correr**
- Ajusta `DATA_DIR` se nao estiveres em Colab.
- Executa as celulas por ordem: Configuracao -> Parte A -> Parte B.
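- As celulas de Anexos (e a Parte C) sao opcionais: servem para treino de raiz, push para o GitHub e balanceamento com CGAN. As variantes D/E do treino precisam que a Parte C tenha sido corrida antes.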

## Configuracao

In [None]:
from pathlib import Path
import random
import zipfile
import urllib.request
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import StratifiedKFold
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
# Global matplotlib style applied to every plot in the notebook.
# NOTE(review): the "seaborn-v0_8" style name assumes a recent matplotlib (older releases named it "seaborn") — confirm for local environments.
plt.style.use("seaborn-v0_8")
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
SEED = 42


def seed_everything(seed: int = SEED) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


seed_everything()
print(f"Device: {DEVICE} | Seed: {SEED}")


In [None]:
# Main paths and dataset location.
REPO_ROOT = Path(".").resolve()  # current working directory, as an absolute path
# When /content exists (Colab) the dataset lives there; otherwise it sits next to the notebook.
DATA_DIR = Path("/content/InfraredSolarModules") if Path("/content").exists() else REPO_ROOT / "InfraredSolarModules"
DATA_URL = "https://github.com/RaptorMaps/InfraredSolarModules/raw/master/2020-02-14_InfraredSolarModules.zip"
BASE_IMAGE_DIR = DATA_DIR / "images"  # ensure_dataset() skips the download when this folder exists
MODELS_DIR = REPO_ROOT / "models"  # model weights (.pth) are downloaded/saved here
METRICS_DIR = REPO_ROOT / "metrics"  # CSV metric exports
TRAIN_CSV = REPO_ROOT / "full_train_data_list.csv"  # training list read by load_dataframes()
TEST_CSV = REPO_ROOT / "final_test_data_list.csv"  # test list read by load_dataframes()

def ensure_dataset() -> None:
    """Download and extract the InfraredSolarModules dataset only if it is missing locally.

    The archive is fetched from ``DATA_URL`` into ``DATA_DIR.with_suffix(".zip")`` and
    extracted into ``DATA_DIR.parent``; the archive is therefore expected to contain the
    ``InfraredSolarModules/images`` folder at its top level.

    Raises:
        FileNotFoundError: if ``BASE_IMAGE_DIR`` still does not exist after extraction,
            i.e. the archive layout does not match what the rest of the notebook expects.
    """
    if BASE_IMAGE_DIR.exists():
        print(f"?? Dataset pronto em {BASE_IMAGE_DIR}")
        return
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    zip_path = DATA_DIR.with_suffix(".zip")
    print("?? A descarregar InfraredSolarModules (pode demorar)...")
    urllib.request.urlretrieve(DATA_URL, zip_path)
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(DATA_DIR.parent)
    # Fail fast with a clear message here, rather than with a FileNotFoundError much later
    # when the first image is opened.
    if not BASE_IMAGE_DIR.exists():
        raise FileNotFoundError(
            f"Extracao concluida mas {BASE_IMAGE_DIR} nao existe; confirma a estrutura do zip {zip_path}."
        )
    print(f"?? Dataset extraido para {DATA_DIR}")

def load_dataframes(image_dir: Path = BASE_IMAGE_DIR):
    """Load the train/test CSV lists and re-root every image path under ``image_dir``.

    Only the file name of each CSV ``path`` entry is kept (stored in a new ``filename``
    column), so the CSVs may have been written on another machine or OS; ``path`` is
    replaced by ``image_dir / filename``.

    Returns:
        ``(train_df, test_df, classes_map, idx_to_class)``. ``classes_map`` maps each
        ``class_name`` to its integer ``label`` as found in the train CSV, and
        ``idx_to_class`` is the inverse mapping.

    Raises:
        ValueError: if the train CSV pairs a class name with several labels (or a label
            with several class names), which would otherwise be silently collapsed.
    """
    # PureWindowsPath accepts both "/" and "\\" as separators, so entries written on
    # Windows (e.g. "images\\123.jpg") still reduce to the bare file name on Linux/Colab.
    from pathlib import PureWindowsPath

    train_df = pd.read_csv(TRAIN_CSV)
    test_df = pd.read_csv(TEST_CSV)
    for df in (train_df, test_df):
        df["filename"] = df["path"].apply(lambda p: PureWindowsPath(str(p)).name)
        df["path"] = df["filename"].apply(lambda n: image_dir / n)
    class_pairs = train_df[["class_name", "label"]].drop_duplicates().sort_values("label")
    if class_pairs["class_name"].duplicated().any() or class_pairs["label"].duplicated().any():
        raise ValueError(f"Mapeamento class_name <-> label inconsistente em {TRAIN_CSV}:\n{class_pairs}")
    classes_map = {row.class_name: int(row.label) for row in class_pairs.itertuples()}
    idx_to_class = {v: k for k, v in classes_map.items()}
    return train_df, test_df, classes_map, idx_to_class

# Download the dataset if needed, load the train/test lists and preview the first rows.
ensure_dataset()
train_df, test_df, classes_map, idx_to_class = load_dataframes()
print(f"Train imgs: {len(train_df):,} | Test imgs: {len(test_df):,}")
display(train_df.head())  # `display` comes from the IPython/Jupyter environment (not imported here)


## Parte A - Analise exploratoria

In [None]:
# Class distribution (the original dataset is imbalanced), most frequent class first.
order = train_df["class_name"].value_counts().index
fig, ax = plt.subplots(figsize=(8, 6))
sns.countplot(data=train_df, y="class_name", order=order, palette="viridis", ax=ax)
ax.set_title("Distribuicao de classes (train)")
# seaborn >= 0.13 maps `palette` without `hue` onto one bar container per class, so
# labelling only ax.containers[0] would annotate a single bar; label every container.
for container in ax.containers:
    ax.bar_label(container, fontsize=8)
plt.tight_layout()
plt.show()


In [None]:
# Show one training image next to its base-transformed version (no flips/rotations).
sample = train_df.sample(1, random_state=SEED).iloc[0]
img_path = Path(sample["path"])
if not img_path.exists():
    raise FileNotFoundError(f"Imagem nao encontrada: {img_path}. Confirma a celula de download do dataset.")

img = Image.open(img_path).convert("L")
preview_size = (64, 64)
preview_transform = T.Compose(
    [T.Resize(preview_size), T.Grayscale(), T.ToTensor(), T.Normalize(mean=[0.5], std=[0.5])]
)
transformed = preview_transform(img)
# Undo Normalize(mean=0.5, std=0.5) so the tensor is back in [0, 1] for display.
denormalized = (transformed.squeeze(0) * 0.5 + 0.5).clamp(0, 1)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
panels = (
    (img, "Original"),
    (denormalized, "Transformer base (sem flips/rotacoes)"),
)
for panel_ax, (picture, caption) in zip(axes, panels):
    panel_ax.imshow(picture, cmap="gray")
    panel_ax.set_title(caption)
    panel_ax.axis("off")

plt.suptitle(f"Classe: {sample['class_name']} | ficheiro: {img_path.name}")
plt.tight_layout()
plt.show()


In [None]:
# Base transform (no flips/rotations) shared by the train and test DataLoaders.
IMAGE_SIZE = (64, 64)
_base_steps = [
    T.Resize(IMAGE_SIZE),
    T.Grayscale(),
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
]
base_transform = T.Compose(_base_steps)

# No random augmentation: training and evaluation use the very same pipeline object.
train_transform = test_transform = base_transform
print("Transformers definidos sem augmentations aleatorias.")


## Parte B - Pesos pre-treinados: validar/testar

In [None]:
# Make sure the pretrained model weights are available.
# Raw GitHub URLs of the pretrained state dicts for models A, B and C.
# NOTE(review): if these .pth files are stored with Git LFS, raw.githubusercontent.com serves the LFS pointer instead of the weights — confirm.
WEIGHT_URLS = {
    "model_A_final.pth": "https://raw.githubusercontent.com/FM11pp3/VC_0312/main/models/model_A_final.pth",
    "model_B_final.pth": "https://raw.githubusercontent.com/FM11pp3/VC_0312/main/models/model_B_final.pth",
    "model_C_final.pth": "https://raw.githubusercontent.com/FM11pp3/VC_0312/main/models/model_C_final.pth",
}
MODELS_DIR.mkdir(exist_ok=True)  # download destination (parent folder REPO_ROOT must already exist)

def ensure_weights(weight_urls=None, models_dir=None):
    """Download every pretrained weight file that is not already present.

    Args:
        weight_urls: mapping ``file name -> URL``; defaults to ``WEIGHT_URLS``.
        models_dir: destination folder (created if missing); defaults to ``MODELS_DIR``.

    Each file is downloaded to ``<name>.part`` first and only renamed to its final name
    once the download completed. An interrupted download therefore never leaves a
    truncated ``.pth`` behind that the ``exists()`` check would skip forever, with
    ``torch.load`` failing on it later.
    """
    weight_urls = WEIGHT_URLS if weight_urls is None else weight_urls
    models_dir = MODELS_DIR if models_dir is None else Path(models_dir)
    models_dir.mkdir(parents=True, exist_ok=True)
    for fname, url in weight_urls.items():
        dest = models_dir / fname
        if dest.exists():
            print(f"?? {fname} ja existe")
            continue
        print(f"?? A descarregar {fname}...")
        tmp = dest.with_name(dest.name + ".part")
        try:
            urllib.request.urlretrieve(url, tmp)
        except BaseException:
            # Also covers KeyboardInterrupt: remove the partial file, then propagate.
            tmp.unlink(missing_ok=True)
            raise
        tmp.replace(dest)
    print("Pronto.")

ensure_weights()  # fetch any missing A/B/C weight files into MODELS_DIR


In [None]:
# Dataset, model and evaluation helpers
class SolarDataset(Dataset):
    """Map-style dataset over a DataFrame with ``path`` and ``label`` columns.

    Each item is ``(image, int label)``: the image is opened with PIL, converted to
    single-channel ("L") and passed through ``transform`` when one is set.
    """

    def __init__(self, df: pd.DataFrame, transform):
        # Fresh 0..n-1 index so positional lookups match the DataLoader's indices.
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        grey = Image.open(record["path"]).convert("L")
        sample = self.transform(grey) if self.transform else grey
        return sample, int(record["label"])


def make_loader(df, transform, batch_size=256, shuffle=False, sampler=None, num_workers=2):
    """Wrap ``df`` in a SolarDataset and return a DataLoader over it.

    Args:
        df: DataFrame with ``path`` and ``label`` columns.
        transform: callable applied to each PIL image (or None).
        batch_size: samples per batch.
        shuffle: reshuffle every epoch. Ignored when ``sampler`` is given, because
            DataLoader does not accept both.
        sampler: optional sampler (e.g. WeightedRandomSampler) that decides the order.
        num_workers: number of loader worker processes. The default of 2 keeps the
            previous behaviour. Pass 0 when running locally on platforms that start
            workers with ``spawn`` (Windows/macOS), where Dataset classes defined inside
            a notebook often cannot be unpickled by the workers.
    """
    return DataLoader(
        SolarDataset(df, transform),
        batch_size=batch_size,
        shuffle=shuffle if sampler is None else False,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )


class NetworkCNN(nn.Module):
    """LeNet-style CNN for single-channel images.

    The network applies two (conv 5x5 -> ReLU -> max-pool 2x2) stages, then three
    fully connected layers.

    Args:
        num_classes: size of the output logit vector.
        input_size: ``(height, width)`` of the input images. The default (64, 64)
            matches the pretrained weights. Pass the transform size (e.g. ``IMAGE_SIZE``)
            if it changes, because ``fc1``'s input width depends on it.
    """

    def __init__(self, num_classes: int, input_size: tuple[int, int] = (64, 64)):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # Infer the flattened feature size with one dummy pass through the conv stack.
        # torch.randn (not zeros) is kept so the global RNG is consumed exactly as before,
        # and seeded runs still initialise the FC layers identically.
        dummy_input = torch.randn(1, 1, *input_size)
        with torch.no_grad():
            x = self.pool(F.relu(self.conv1(dummy_input)))
            x = self.pool(F.relu(self.conv2(x)))
            flattened_size = torch.flatten(x, 1).shape[1]

        self.fc1 = nn.Linear(flattened_size, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to raw logits of shape (batch, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def evaluate_model_metrics(model: nn.Module, loader: DataLoader):
    """Run ``model`` over ``loader`` and return accuracy plus macro/micro/weighted F1.

    The model is switched to eval mode (and left there) and gradients are disabled.
    An empty loader yields all-zero metrics instead of an error from scikit-learn.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            # Only the images go to DEVICE; the labels are consumed on the CPU by scikit-learn,
            # so sending them to the GPU and straight back was a wasted round-trip.
            preds = model(images.to(DEVICE)).argmax(1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.tolist())
    if not all_labels:
        return {"accuracy": 0.0, "f1_macro": 0.0, "f1_micro": 0.0, "f1_weighted": 0.0}
    # zero_division=0 returns the same value as sklearn's default ("warn" acts as 0) for
    # classes with no true or no predicted samples, without flooding k-fold runs with warnings.
    return {
        "accuracy": accuracy_score(all_labels, all_preds),
        "f1_macro": f1_score(all_labels, all_preds, average="macro", zero_division=0),
        "f1_micro": f1_score(all_labels, all_preds, average="micro", zero_division=0),
        "f1_weighted": f1_score(all_labels, all_preds, average="weighted", zero_division=0),
    }


def evaluate_model(model: nn.Module, loader: DataLoader) -> float:
    """Return top-1 accuracy of ``model`` on ``loader``, or 0.0 when the loader is empty.

    The model is switched to eval mode and gradients are disabled during the pass.
    """
    model.eval()
    n_hits, n_seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)
            hits = model(batch_x).argmax(1).eq(batch_y)
            n_hits += hits.sum().item()
            n_seen += batch_y.size(0)
    if n_seen == 0:
        return 0.0
    return n_hits / n_seen


def make_weighted_sampler(df: pd.DataFrame) -> WeightedRandomSampler:
    """Build a sampler that draws every class with equal expected frequency.

    Each row is weighted by 1 / (number of rows sharing its label). ``len(df)`` samples
    are drawn per epoch, with replacement.
    """
    label_col = df["label"]
    inverse_freq = 1.0 / label_col.value_counts()
    per_row = label_col.map(inverse_freq).astype(float).to_numpy()
    return WeightedRandomSampler(
        weights=torch.as_tensor(per_row, dtype=torch.double),
        num_samples=len(per_row),
        replacement=True,
    )


# Prepare one test DataFrame per pretrained model, each with the label encoding that model expects.
# Sorted list of every class except "No-Anomaly" — the label space of model B.
anomaly_classes = sorted([c for c in classes_map if c != "No-Anomaly"])
# Model B label = position of the class name in the sorted anomaly list.
# NOTE(review): this must match the encoding used when model_B_final.pth was trained — confirm.
classes_map_B = {cls: idx for idx, cls in enumerate(anomaly_classes)}

model_frames = {
    # A: binary task — 0 for "No-Anomaly", 1 for any anomaly class; whole test set.
    "A": {
        "num_classes": 2,
        "df": test_df.assign(label=test_df["class_name"].apply(lambda c: 0 if c == "No-Anomaly" else 1)),
        "weights": MODELS_DIR / "model_A_final.pth",
    },
    # B: anomaly-only subset of the test set, relabelled through classes_map_B.
    "B": {
        "num_classes": len(anomaly_classes),
        "df": test_df[test_df["class_name"] != "No-Anomaly"].assign(label=lambda d: d["class_name"].map(classes_map_B)),
        "weights": MODELS_DIR / "model_B_final.pth",
    },
    # C: every class, labelled through classes_map (built from the train CSV).
    # NOTE(review): a test class_name absent from classes_map would become NaN and fail in SolarDataset's int(label) — presumably all test classes occur in train.
    "C": {
        "num_classes": len(classes_map),
        "df": test_df.assign(label=lambda d: d["class_name"].map(classes_map)),
        "weights": MODELS_DIR / "model_C_final.pth",
    },
}

# Evaluate every pretrained model on its own view of the test set and collect the metrics.
results = []
for key, cfg in model_frames.items():
    loader = make_loader(cfg["df"], test_transform, batch_size=256)
    model = NetworkCNN(cfg["num_classes"]).to(DEVICE)
    # weights_only=True: the .pth files come from the network and only a plain state dict is
    # needed, so refuse arbitrary pickled objects (also the default from PyTorch 2.6 on).
    state = torch.load(cfg["weights"], map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    metrics = evaluate_model_metrics(model, loader)
    results.append({"Model": f"Model {key}", **metrics})
    print(f"Model {key}: acc={metrics['accuracy']:.3f} | f1_macro={metrics['f1_macro']:.3f} | f1_weighted={metrics['f1_weighted']:.3f}")

results_df = pd.DataFrame(results)
display(results_df)

# Show the metrics exported during the original training, when available, for comparison.
metrics_path = METRICS_DIR / "final_test_metrics.csv"
if metrics_path.exists():
    print("Metricas exportadas no treino original:")
    display(pd.read_csv(metrics_path))


## Anexos - Treino de raiz e push para GitHub
Inclui treino rapido, validacao k-fold e DataLoader com WeightedRandomSampler para lidar com o desbalanceamento.


In [None]:
# Treino rapido e validacao k-fold (exemplos) - variantes A-E definidas em train_frames abaixo
from sklearn.model_selection import train_test_split, StratifiedKFold

# Training/validation variants. A/B/C use real data only; D/E default to mixed real + CGAN data.
# "base" selects the label encoding applied by build_train_df; "use_cgan" is the default data source.
# A: binary (No-Anomaly vs Anomaly)   - real
# B: anomaly classes only (11)        - real
# C: all classes (12)                 - real
# D: anomaly classes only (11)        - real + CGAN
# E: all classes (12)                 - real + CGAN
train_frames = {
    "A": {"num_classes": 2, "base": "A", "use_cgan": False},
    "B": {"num_classes": len(anomaly_classes), "base": "B", "use_cgan": False},
    "C": {"num_classes": len(classes_map), "base": "C", "use_cgan": False},
    "D": {"num_classes": len(anomaly_classes), "base": "B", "use_cgan": True},
    "E": {"num_classes": len(classes_map), "base": "C", "use_cgan": True},
}


def resolve_variant(key: str, use_cgan_balance: bool | None):
    """Look up training variant ``key`` (A-E, case-insensitive) in ``train_frames``.

    Returns:
        ``(base_key, num_classes, use_cgan)``. ``use_cgan`` is the variant's default
        unless ``use_cgan_balance`` explicitly overrides it.

    Raises:
        ValueError: for an unknown variant key.
    """
    normalized = key.upper()
    variant = train_frames.get(normalized)
    if variant is None:
        raise ValueError(f"Modelo invalido: {normalized}. Use A, B, C, D ou E.")
    use_cgan = variant["use_cgan"] if use_cgan_balance is None else use_cgan_balance
    return variant["base"], variant["num_classes"], use_cgan


def build_train_df(base_df: pd.DataFrame, base_key: str):
    base_key = base_key.upper()
    if base_key == "A":
        df = base_df.assign(label=base_df["class_name"].apply(lambda c: 0 if c == "No-Anomaly" else 1))
    elif base_key == "B":
        df = base_df[base_df["class_name"] != "No-Anomaly"].assign(label=lambda d: d["class_name"].map(classes_map_B))
    elif base_key == "C":
        df = base_df.assign(label=lambda d: d["class_name"].map(classes_map))
    else:
        raise ValueError(f"Base invalida: {base_key}.")
    return df.reset_index(drop=True)


def get_base_df(use_cgan_balance: bool, target_per_class=None):
    """Return the real training DataFrame, or a CGAN-balanced version of it.

    Raises:
        NameError: if balancing is requested before ``balance_with_cgan`` was defined
            (i.e. before the Part C cell was run).
    """
    if use_cgan_balance:
        if "balance_with_cgan" in globals():
            return balance_with_cgan(target_per_class=target_per_class)
        raise NameError("balance_with_cgan nao definido. Corre a celula da Parte C primeiro.")
    return train_df


def train_model(
    model_name: str,
    base_df: pd.DataFrame,
    num_classes: int,
    epochs: int = 3,
    lr: float = 1e-3,
    use_weighted_sampler: bool = True,
    batch_size: int = 128,
):
    """Train a fresh NetworkCNN on a stratified 80/20 split of ``base_df`` and save it.

    Args:
        model_name: file stem; the state dict goes to ``MODELS_DIR / f"{model_name}.pth"``.
        base_df: DataFrame with ``path`` and integer ``label`` columns.
        num_classes: number of output logits.
        epochs: training epochs; the mean training loss and validation accuracy are
            printed after each one.
        lr: Adam learning rate.
        use_weighted_sampler: draw training batches through a class-balancing
            WeightedRandomSampler instead of plain shuffling.
        batch_size: batch size of both loaders.

    Returns:
        The trained model.
    """
    train_split, val_split = train_test_split(
        base_df,
        test_size=0.2,
        stratify=base_df["label"],
        random_state=SEED,
    )
    sampler = make_weighted_sampler(train_split) if use_weighted_sampler else None
    train_loader = make_loader(
        train_split,
        train_transform,
        batch_size=batch_size,
        shuffle=not use_weighted_sampler,
        sampler=sampler,
    )
    val_loader = make_loader(val_split, test_transform, batch_size=batch_size)

    model = NetworkCNN(num_classes).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the printed value is a per-sample mean over the epoch.
            running_loss += loss.item() * labels.size(0)
        val_acc = evaluate_model(model, val_loader)
        print(
            f"Epoch {epoch + 1}/{epochs} | loss={running_loss / len(train_loader.dataset):.4f} | val_acc={val_acc:.3f}"
        )

    out_path = MODELS_DIR / f"{model_name}.pth"
    # MODELS_DIR is otherwise only created in the Part B cell. Create it here as well so this
    # cell works on its own (run_kfold_cv does the same for METRICS_DIR).
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), out_path)
    print(f"Modelo guardado em {out_path}")
    return model


def train_model_variant(key: str, use_cgan_balance: bool | None = None, target_per_class=None, **kwargs):
    """Train variant ``key`` (A-E) from scratch; weights are saved as ``model_<key>_scratch.pth``.

    Extra keyword arguments (epochs, lr, batch_size, use_weighted_sampler) are forwarded
    to ``train_model``.
    """
    base_key, num_classes, use_cgan = resolve_variant(key, use_cgan_balance)
    source_df = get_base_df(use_cgan, target_per_class=target_per_class)
    labelled = build_train_df(source_df, base_key)
    return train_model(f"model_{key}_scratch", labelled, num_classes=num_classes, **kwargs)


def run_kfold_cv(
    model_name: str,
    base_df: pd.DataFrame,
    num_classes: int,
    n_splits: int = 5,
    epochs: int = 3,
    batch_size: int = 128,
    lr: float = 1e-3,
    use_weighted_sampler: bool = True,
):
    """Stratified k-fold cross-validation, training a fresh NetworkCNN for every fold.

    Each fold trains for ``epochs`` epochs with Adam and cross-entropy, optionally
    through a class-balancing WeightedRandomSampler, and is then scored on its held-out
    part. The per-fold metrics are written to
    ``METRICS_DIR / f"kfold_results_{model_name}.csv"``, displayed, and returned as a
    DataFrame with one row per fold.

    NOTE(review): when ``base_df`` already contains CGAN-generated rows, those synthetic
    images also end up in the validation folds — confirm this is intended.
    """
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    per_fold = []

    for fold_no, (fit_idx, hold_idx) in enumerate(splitter.split(base_df, base_df["label"]), start=1):
        print(f"\nFold {fold_no}/{n_splits}")
        fit_df = base_df.iloc[fit_idx].reset_index(drop=True)
        hold_df = base_df.iloc[hold_idx].reset_index(drop=True)

        fit_loader = make_loader(
            fit_df,
            train_transform,
            batch_size=batch_size,
            shuffle=not use_weighted_sampler,
            sampler=make_weighted_sampler(fit_df) if use_weighted_sampler else None,
        )
        hold_loader = make_loader(hold_df, test_transform, batch_size=batch_size)

        net = NetworkCNN(num_classes).to(DEVICE)
        opt = torch.optim.Adam(net.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        n_train = len(fit_loader.dataset)

        for ep in range(1, epochs + 1):
            net.train()
            loss_sum = 0.0
            for xb, yb in fit_loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad()
                batch_loss = loss_fn(net(xb), yb)
                batch_loss.backward()
                opt.step()
                loss_sum += batch_loss.item() * yb.size(0)
            acc = evaluate_model(net, hold_loader)
            print(f"  Epoch {ep}/{epochs} | loss={loss_sum / n_train:.4f} | val_acc={acc:.3f}")

        scores = evaluate_model_metrics(net, hold_loader)
        per_fold.append({"fold": fold_no, **scores})
        print(
            f"Fold {fold_no} metrics: acc={scores['accuracy']:.3f} | f1_macro={scores['f1_macro']:.3f} | f1_weighted={scores['f1_weighted']:.3f}"
        )

    fold_table = pd.DataFrame(per_fold)
    csv_path = METRICS_DIR / f"kfold_results_{model_name}.csv"
    METRICS_DIR.mkdir(exist_ok=True)
    fold_table.to_csv(csv_path, index=False)
    print(f"Metricas K-fold guardadas em {csv_path}")
    display(fold_table)
    return fold_table


def run_kfold_variant(key: str, use_cgan_balance: bool | None = None, target_per_class=None, **kwargs):
    """K-fold cross-validate variant ``key`` (A-E); the results CSV is named after the key.

    Extra keyword arguments (n_splits, epochs, batch_size, lr, use_weighted_sampler) are
    forwarded to ``run_kfold_cv``.
    """
    base_key, num_classes, use_cgan = resolve_variant(key, use_cgan_balance)
    source_df = get_base_df(use_cgan, target_per_class=target_per_class)
    labelled = build_train_df(source_df, base_key)
    return run_kfold_cv(str(key), labelled, num_classes=num_classes, **kwargs)


def run_kfold_all(keys=("A", "B", "C", "D", "E"), target_per_class=None, **kwargs):
    """Cross-validate several variants in turn, each with its default data source.

    Returns:
        ``{key: per-fold metrics DataFrame}`` in the order of ``keys``.
    """
    summaries = {}
    for key in keys:
        print(f"\n===== Modelo {key} =====")
        # use_cgan_balance=None lets resolve_variant fall back to the variant's own default
        # (real data for A/B/C, real + CGAN for D/E).
        summaries[key] = run_kfold_variant(
            key, use_cgan_balance=None, target_per_class=target_per_class, **kwargs
        )
    return summaries


# Exemplos (comentados):
# balanced_train_df = get_base_df(use_cgan_balance=True)  # gera e mistura reais+GAN
# run_kfold_variant("A", n_splits=5, epochs=3)            # Binario: real
# run_kfold_variant("D", n_splits=5, epochs=3)            # 11 classes anomalas: real + CGAN
# run_kfold_variant("B", n_splits=5, epochs=3)            # 11 classes anomalas: real
# run_kfold_variant("C", n_splits=5, epochs=3)            # 12 classes: real
# run_kfold_variant("E", n_splits=5, epochs=3)            # 12 classes: real + CGAN
# train_model_variant("E", epochs=5)                      # Treino rapido modelo E com dados mix


In [None]:
# Push rapido dos artefactos (usa HTTPS). Configura antes: git config user.email/name e token de acesso se precisa.
# Descomenta as linhas abaixo quando estiveres autenticado.
# !git status
# !git add models/*.pth metrics/*.csv
# !git commit -m "Add modelos e metricas atualizadas"
# !git push origin main


### Anexo - GAN (opcional)
Exemplo de como treinar um GAN condicional simples com o mesmo dataset (sem augmentations de flip/rotacao). Ajusta `num_epochs` e `batch_size` conforme recursos.


In [None]:
# Simple conditional GAN (fully connected G and D) that generates 64x64 grayscale images.
LATENT_DIM = 128  # length of the noise vector fed to the generator
NUM_CLASSES = len(classes_map)  # one conditioning label per class of the full label space
GAN_OUT_DIR = REPO_ROOT / "cgan_generated_outputs"  # sample grids written during CGAN training
GAN_OUT_DIR.mkdir(exist_ok=True)

class ConditionalGenerator(nn.Module):
    """Fully connected conditional generator mapping (noise, class label) to a 1x64x64 image in [-1, 1].

    The label is embedded as a ``num_classes``-dim vector and placed in front of the
    noise. The result goes through three Linear + BatchNorm1d + LeakyReLU stages and a
    final Linear + Tanh.
    """

    def __init__(self, latent_dim: int, num_classes: int):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        layers = []
        in_features = latent_dim + num_classes
        for width in (256, 512, 1024):
            layers += [nn.Linear(in_features, width), nn.BatchNorm1d(width), nn.LeakyReLU(0.2, inplace=True)]
            in_features = width
        layers += [nn.Linear(in_features, 1 * 64 * 64), nn.Tanh()]
        # Same layer order (hence the same "model.<i>" indices) as an explicit listing,
        # so previously saved state dicts still load.
        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        conditioned = torch.cat([self.label_emb(labels), noise], dim=1)
        flat = self.model(conditioned)
        return flat.view(noise.size(0), 1, 64, 64)

class ConditionalDiscriminator(nn.Module):
    """Fully connected conditional discriminator mapping (1x64x64 image, class label) to P(real) in (0, 1)."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        in_features = 1 * 64 * 64 + num_classes  # flattened image followed by the label embedding
        self.model = nn.Sequential(
            nn.Linear(in_features, 512), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, img, labels):
        features = torch.cat([img.view(img.size(0), -1), self.label_emb(labels)], dim=1)
        return self.model(features)

def make_gan_loader(df, batch_size=64):
    """DataLoader for CGAN training: shuffled, base train transform, incomplete last batch dropped."""
    dataset = SolarDataset(df, train_transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        # Presumably so every batch is full-sized (BatchNorm1d in the generator cannot
        # train on a single-sample batch).
        drop_last=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

def train_cgan(train_df: pd.DataFrame, num_epochs: int = 5, batch_size: int = 64, lr: float = 2e-4, sample_every: int = 1):
    """Train the conditional GAN on ``train_df`` and return the trained generator.

    Each step first updates the generator, trying to make D score its images as real.
    It then updates the discriminator on the real batch and on the detached generated
    batch. The real labels in ``train_df["label"]`` condition both networks.

    Args:
        train_df: DataFrame with ``path`` and ``label`` columns.
        num_epochs: number of passes over the data.
        batch_size: batch size; the loader drops the incomplete last batch.
        lr: Adam learning rate for both networks (betas=(0.5, 0.999)).
        sample_every: every this many epochs, save a grid with one fixed-noise sample
            per class to ``GAN_OUT_DIR``. 0 (or a negative value) disables sampling.

    Returns:
        The trained generator (in training mode). Its state dict is also saved to
        ``MODELS_DIR / "cgan_generator.pth"``.

    Raises:
        ValueError: if ``train_df`` has fewer rows than ``batch_size``. With
            ``drop_last=True`` the loader would yield no batch at all.
    """
    if len(train_df) < batch_size:
        raise ValueError(
            f"train_df tem {len(train_df)} linhas, menos do que batch_size={batch_size}; nao ha batches para treinar."
        )
    loader = make_gan_loader(train_df, batch_size=batch_size)
    generator = ConditionalGenerator(LATENT_DIM, NUM_CLASSES).to(DEVICE)
    discriminator = ConditionalDiscriminator(NUM_CLASSES).to(DEVICE)
    adversarial_loss = nn.BCELoss()

    opt_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    # One fixed noise vector per class, so the preview grids are comparable across epochs.
    fixed_noise = torch.randn(NUM_CLASSES, LATENT_DIM, device=DEVICE)
    fixed_labels = torch.arange(NUM_CLASSES, device=DEVICE)

    for epoch in range(num_epochs):
        for imgs, labels in loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            valid = torch.ones(imgs.size(0), 1, device=DEVICE)
            fake = torch.zeros(imgs.size(0), 1, device=DEVICE)

            # Train generator
            opt_G.zero_grad()
            z = torch.randn(imgs.size(0), LATENT_DIM, device=DEVICE)
            gen_labels = labels
            gen_imgs = generator(z, gen_labels)
            g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
            g_loss.backward()
            opt_G.step()

            # Train discriminator (gradients left in D by g_loss.backward() are cleared here)
            opt_D.zero_grad()
            real_loss = adversarial_loss(discriminator(imgs, labels), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            opt_D.step()

        # Losses of the last batch of the epoch.
        print(f"Epoch {epoch+1}/{num_epochs} | D loss: {d_loss.item():.4f} | G loss: {g_loss.item():.4f}")

        if sample_every > 0 and (epoch + 1) % sample_every == 0:
            # Sample in eval mode so BatchNorm uses its running statistics, as Part C does.
            # In train mode this fixed preview batch would also update those statistics.
            generator.eval()
            with torch.no_grad():
                samples = generator(fixed_noise, fixed_labels).cpu()
            generator.train()
            grid = vutils.make_grid(samples, nrow=4, normalize=True)
            out_path = GAN_OUT_DIR / f"cgan_samples_epoch_{epoch+1:03}.png"
            vutils.save_image(grid, out_path)
            print(f"Samples guardados em {out_path}")

    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    torch.save(generator.state_dict(), MODELS_DIR / "cgan_generator.pth")
    print("Gerador guardado em models/cgan_generator.pth")
    return generator

# Exemplo de uso (comentado): treinar GAN condicional usando o split de treino completo
# generator = train_cgan(train_df, num_epochs=5, batch_size=128, sample_every=1)


### Parte C - CGAN para classes minoritarias
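Usa o gerador condicional treinado para sintetizar imagens das classes com menos exemplos, guarda-as no disco e devolve um DataFrame balanceado (train + sinteticas). Os pesos sao lidos de `CGAN_WEIGHTS` (`cgan_generated_outputs/cgan_generator_minority_classes.pth`). A celula de treino do GAN acima guarda em `models/cgan_generator.pth`, por isso confirma o caminho se treinaste aqui. Depois podes treinar com `WeightedRandomSampler` para manter o balanceamento.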


In [None]:
# Generate synthetic images for the minority classes with the trained CGAN.
# Default generator weights. Note that train_cgan saves to MODELS_DIR / "cgan_generator.pth" instead.
CGAN_WEIGHTS = REPO_ROOT / "cgan_generated_outputs" / "cgan_generator_minority_classes.pth"
GAN_AUG_DIR = DATA_DIR / "cgan_augmented"  # destination folder for the generated PNGs
GAN_AUG_DIR.mkdir(exist_ok=True)


def load_trained_generator(weights_path=CGAN_WEIGHTS):
    """Load the trained conditional generator and put it in eval mode.

    Args:
        weights_path: state-dict file (``Path`` or ``str``). If the default
            ``CGAN_WEIGHTS`` is missing, the file written by ``train_cgan``
            (``MODELS_DIR / "cgan_generator.pth"``) is used instead. A generator trained
            in the GAN cell is therefore picked up without editing paths.

    Raises:
        FileNotFoundError: if no usable weights file exists.
    """
    weights_path = Path(weights_path)
    if not weights_path.exists():
        fallback = MODELS_DIR / "cgan_generator.pth"
        # Only fall back for the default location; an explicit path that is missing is an error.
        if weights_path != CGAN_WEIGHTS or not fallback.exists():
            raise FileNotFoundError(f"Nao encontrei pesos do CGAN em {weights_path}. Treina o CGAN na celula anterior ou ajusta o caminho.")
        print(f"Pesos em {weights_path} nao encontrados; a usar {fallback} (guardado por train_cgan).")
        weights_path = fallback
    generator = ConditionalGenerator(LATENT_DIM, NUM_CLASSES).to(DEVICE)
    # weights_only=True: a plain state dict is all that is expected, so refuse arbitrary pickled objects.
    state = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    generator.load_state_dict(state)
    generator.eval()
    return generator


def generate_class_samples(generator, class_id: int, num_images: int, batch_size: int = 64):
    """Generate ``num_images`` images of class ``class_id`` and save them as PNGs in GAN_AUG_DIR.

    Files are named ``cgan_<class_id:02d>_<n:05d>.png``, with ``n`` restarting at 1 on
    every call, so re-running overwrites earlier files of the same class.

    Returns:
        The saved file paths, in generation order (empty when ``num_images <= 0``).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1, which would otherwise loop forever.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size tem de ser >= 1, recebido {batch_size}.")
    saved_paths = []
    remaining = num_images
    while remaining > 0:
        cur = min(batch_size, remaining)
        noise = torch.randn(cur, LATENT_DIM, device=DEVICE)
        labels = torch.full((cur,), class_id, device=DEVICE, dtype=torch.long)
        with torch.no_grad():
            imgs = generator(noise, labels).cpu()
        for i in range(cur):
            out_name = f"cgan_{class_id:02d}_{len(saved_paths)+1:05d}.png"
            out_path = GAN_AUG_DIR / out_name
            # The generator ends in Tanh, so map the fixed range [-1, 1] to [0, 1].
            # normalize=True alone would min-max stretch each image on its own, giving the
            # synthetic images a different intensity distribution from the real ones.
            vutils.save_image(imgs[i], out_path, normalize=True, value_range=(-1, 1))
            saved_paths.append(out_path)
        remaining -= cur
    return saved_paths


def balance_with_cgan(target_per_class: int | None = None, batch_size: int = 64):
    """Top up every class of ``train_df`` with CGAN images until it has ``target_per_class`` rows.

    ``target_per_class`` defaults to the current largest class count. Returns ``train_df``
    itself when nothing needs to be generated. Otherwise returns a new DataFrame with the
    synthetic rows (``path``, ``class_name``, ``label``) appended after the real ones.
    """
    generator = load_trained_generator()
    counts = train_df["label"].value_counts().to_dict()
    target = max(counts.values()) if target_per_class is None else target_per_class
    synthetic_rows = []
    for cls_name, cls_id in classes_map.items():
        have = counts.get(cls_id, 0)
        if have < target:
            missing = target - have
            print(f"Classe {cls_name} (id={cls_id}): gerar {missing} para equilibrar {have}->{target}")
            synthetic_rows.extend(
                {"path": p, "class_name": cls_name, "label": cls_id}
                for p in generate_class_samples(generator, cls_id, missing, batch_size=batch_size)
            )
    if not synthetic_rows:
        print("Dataset ja esta balanceado; nada gerado.")
        return train_df
    synth_df = pd.DataFrame(synthetic_rows)
    balanced_df = pd.concat([train_df, synth_df], ignore_index=True)
    print(f"Geradas {len(synth_df)} imagens sinteticas. Total final: {len(balanced_df)}")
    return balanced_df


# Exemplo (comentado): balancear para o maximo atual por classe e usar WeightedRandomSampler depois
# balanced_train_df = balance_with_cgan()
# sampler = make_weighted_sampler(balanced_train_df)
# loader = make_loader(balanced_train_df, train_transform, batch_size=128, sampler=sampler)
