In [None]:

!pip install --quiet scikit-learn matplotlib torchvision pillow tqdm scikit-image

import os
import random
import time
from pathlib import Path
from glob import glob
from tqdm import tqdm
import cv2

import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision import models

from sklearn.metrics import (roc_auc_score, roc_curve, precision_recall_curve,
                             confusion_matrix, accuracy_score, classification_report,
                             auc)
import pandas as pd

# ---- PATHS ----
DRIVE_MOUNT_POINT = '/content/drive'
COVER_DIR = '/content/drive/MyDrive/WIFD_dataset/STEGANO/cover'
# Root directory that already contains the sub-directories: cover, stego, pseudo.
OUTPUT_ROOT = '/content/drive/MyDrive/stegano_dataset'
RESULTS_DIR = os.path.join(OUTPUT_ROOT, 'results_stego_vs_pseudo')
SAVED_MODEL = os.path.join(RESULTS_DIR, 'siamese_stego_vs_pseudo_best.pth')
METRICS_CSV = os.path.join(RESULTS_DIR, 'metrics.csv')
# Per the original note this directory is meant to be used dynamically inside
# the per-epoch sub-directories.
SAMPLE_OUTPUT = os.path.join(RESULTS_DIR, 'samples')

# ---- DATA / MODEL HYPER-PARAMETERS ----
IMAGE_SIZE = 256
PATCH_SIZE = 128
BATCH_SIZE = 12
EMBED_DIM = 256
NUM_EPOCHS = 15  # CHANGE: raised from 8 to 15
LR = 1e-4
MARGIN = 1.0
MSE_THRESHOLD = 0.1
# NEG_PER_POS_RATIO = 1.0 was removed / is ignored: a fixed cap on the number
# of pairs is used instead.
MAX_PAIRS_LIMIT = 1500  # CHANGE: raised from 300 to 1500

# ---- RUNTIME ----
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Seed every RNG used by the script so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ---- MOUNT DRIVE (Colab) ----
# The google.colab import is done inside the try block: outside Colab the
# package is not installed, and importing it at module level would abort the
# script with ImportError before the "running locally" hint below could be
# shown. Mounting stays best-effort; any failure is reported and the script
# continues with whatever already exists under the configured paths.
try:
    from google.colab import drive
    drive.mount(DRIVE_MOUNT_POINT)
except Exception as e:
    print('Mount drive (jeśli uruchamiasz lokalnie, upewnij się, że ścieżka istnieje):', e)

# Make sure the dataset root and the results directory exist.
os.makedirs(OUTPUT_ROOT, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)


# Collect every entry of the three prepared sub-directories; sorting gives a
# stable order regardless of how the file system enumerates them.
all_cover_paths, all_stego_paths, all_pseudo_paths = (
    sorted(glob(os.path.join(OUTPUT_ROOT, subdir, '*')))
    for subdir in ('cover', 'stego', 'pseudo')
)

# Abort early when any of the three categories is empty.
if not (all_cover_paths and all_stego_paths and all_pseudo_paths):
    raise RuntimeError(
        'Nie znaleziono plików w gotowych katalogach! '
        f'Oczekiwano danych w {OUTPUT_ROOT}/{{cover, stego, pseudo}}'
    )

print('Załadowano gotowe dane:')
print(f'- Cover files: {len(all_cover_paths)}')
print(f'- Stego files: {len(all_stego_paths)}')
print(f'- Pseudo files: {len(all_pseudo_paths)}')


def _load_gray_unit(path, size):
    """Load *path* as a size x size grayscale float32 array scaled to [0, 1]."""
    # The context manager closes the underlying file as soon as the pixels
    # have been converted; without it the handle stays open until the Image
    # object is garbage-collected, which adds up when thousands of files are
    # compared.
    with Image.open(path) as img:
        return np.array(img.convert('L').resize((size, size)), dtype=np.float32) / 255.0


def calculate_mse(img1_path, img2_path, size=IMAGE_SIZE):
    """Return the mean squared error between two images compared in grayscale.

    Both images are converted to single-channel, resized to ``size`` x ``size``
    and scaled to [0, 1] before the comparison, so the result also lies in
    [0, 1].

    Args:
        img1_path: Path of the first image.
        img2_path: Path of the second image.
        size: Edge length both images are resized to before comparing.

    Returns:
        The MSE as a Python ``float``. If either image cannot be loaded or
        processed, a warning is printed and ``1.0`` (the largest possible
        value) is returned so that the pair fails any "MSE below threshold"
        filter.
    """
    try:
        img1 = _load_gray_unit(img1_path, size)
        img2 = _load_gray_unit(img2_path, size)
    except Exception as e:
        # Deliberately best-effort: one unreadable file must not stop the
        # whole pair-building pass.
        print(f"Ostrzeżenie: Błąd ładowania lub przetwarzania pliku dla MSE: {e}. Zwracam dużą wartość MSE.")
        return 1.0

    # Both arrays are (size, size) after the resize above, so no shape check
    # is needed before subtracting.
    diff = img1 - img2
    # float() keeps the return type consistent with the 1.0 error value.
    return float(np.mean(diff ** 2))

def create_filtered_pairs(candidate_paths, label, max_limit):
    """Build labelled (cover, candidate) pairs that pass the MSE filter.

    Cover images (from the module-level ``all_cover_paths``) are visited in
    random order. For every cover, each candidate whose file name starts with
    the cover's base name (the part of the file name before the first ``.``)
    is compared to it with ``calculate_mse``; the pair is kept when the MSE is
    below ``MSE_THRESHOLD``. The search stops as soon as ``max_limit`` pairs
    have been collected.

    Args:
        candidate_paths: Paths of the candidate images (stego or pseudo).
        label: Label stored with every pair; ``1`` selects the "Stego"
            progress-bar caption, any other value the "Pseudo" one.
        max_limit: Maximum number of pairs to return.

    Returns:
        A list of ``(cover_path, candidate_path, label)`` tuples with at most
        ``max_limit`` entries.
    """
    pairs = []
    desc = "Tworzenie par Stego (Label 1)" if label == 1 else "Tworzenie par Pseudo (Label 0)"

    # STEP 1: shuffle a copy of the cover list so that the (at most max_limit)
    # selected pairs are a random sample rather than the first covers only.
    cover_paths_shuffled = all_cover_paths.copy()
    random.shuffle(cover_paths_shuffled)

    # Candidate file names do not depend on the cover being processed, so they
    # are extracted once here instead of once per (cover, candidate)
    # combination inside the loop.
    named_candidates = [(Path(s).name, s) for s in candidate_paths]

    # STEP 2: iterate with the possibility of stopping early.
    for c_path in tqdm(cover_paths_shuffled, desc=desc, total=len(cover_paths_shuffled)):

        # Stop the outer loop once the limit is reached.
        if len(pairs) >= max_limit:
            break

        original_base_name = Path(c_path).name.split('.')[0]

        # NOTE(review): this is a plain prefix match, so a cover named "img1"
        # would also match a candidate named "img10..." - confirm that the
        # dataset's naming scheme rules such collisions out.
        for s_name, s_path in named_candidates:
            if not s_name.startswith(original_base_name):
                continue

            # Re-check so the limit is not exceeded within a single cover.
            if len(pairs) >= max_limit:
                break

            mse = calculate_mse(c_path, s_path)

            if mse < MSE_THRESHOLD:
                pairs.append((c_path, s_path, label))

    return pairs

# Positive class: (cover, stego) pairs, label 1.
pairs_positive = create_filtered_pairs(all_stego_paths, 1, MAX_PAIRS_LIMIT)
print(f'Po MSE znaleziono par Stego (Label 1): {len(pairs_positive)}')

# Negative class: (cover, pseudo) pairs, label 0.
pairs_negative = create_filtered_pairs(all_pseudo_paths, 0, MAX_PAIRS_LIMIT)
print(f'Po MSE znaleziono par Pseudo (Label 0): {len(pairs_negative)}')


# Both classes share the same upper bound (MAX_PAIRS_LIMIT); no further
# balancing is applied here, so the two counts may still differ.
print(f'Zbalansowane (ograniczone) pary: Stego (L1): {len(pairs_positive)} | Pseudo (L0): {len(pairs_negative)}')

# Merge both classes and mix them before the train/validation split.
pairs = [*pairs_positive, *pairs_negative]
random.shuffle(pairs)
print('Łącznie par dla treningu/walidacji:', len(pairs))

def make_srm_bank():
    """Build the bank of fixed 3x3 high-pass filters used as a pre-processing layer.

    The bank starts from three first-order difference kernels and one
    Laplacian-style kernel, then appends each of them rotated by 90 and by
    180 degrees. Any kernel whose coefficients do not sum to zero is
    mean-centred.

    Returns:
        A float32 tensor of shape (num_filters, 1, 3, 3), i.e. laid out as a
        single-input-channel ``conv2d`` weight.
    """
    seeds = [
        np.array([[0, 0, 0], [0, 1, -1], [0, 0, 0]]),
        np.array([[0, 0, 0], [0, 1, 0], [-1, 0, 0]]),
        np.array([[0, 0, 0], [0, 1, 0], [0, -1, 0]]),
        np.array([[0, -1, 0], [-1, 4, -1], [0, -1, 0]]),
    ]
    # For every seed kernel: its 90-degree rotation followed by its
    # 180-degree rotation, in seed order.
    rotated = [np.rot90(seed, turns) for seed in seeds for turns in (1, 2)]
    # Hard cap of 30 filters (the list built above is shorter than that).
    selected = (seeds + rotated)[:30]

    filters = []
    for kernel in selected:
        as_float = np.array(kernel, dtype=np.float32)
        if as_float.sum() != 0:
            as_float = as_float - as_float.mean()
        filters.append(as_float)

    # Insert the singleton input-channel axis expected by conv2d weights.
    return torch.from_numpy(np.stack(filters)[:, None, :, :])


SRM_BANK = make_srm_bank().float()
print('SRM bank (ilość filtrów):', SRM_BANK.shape[0])

class PairedStegoDataset(Dataset):
    """Dataset of (image A, image B, label) pairs for the Siamese network.

    Every item provides two views of each image: the whole image resized to
    ``image_size`` and a ``patch_size`` crop. Each view is converted to a
    single-channel tensor and then either filtered with ``SRM_BANK`` or, when
    ``use_srm`` is False, repeated so it has the same number of channels as
    the bank.

    Args:
        pairs: Sequence of ``(a_path, b_path, label)`` tuples.
        image_size: Edge length of the resized full-image view.
        patch_size: Edge length of the cropped patch view.
        use_srm: Whether to apply the SRM filter bank to every view.
    """

    def __init__(self, pairs, image_size=IMAGE_SIZE, patch_size=PATCH_SIZE, use_srm=True):
        self.pairs = pairs
        self.image_size = image_size
        self.patch_size = patch_size
        self.use_srm = use_srm
        self.tf_full = T.Compose([T.Resize((self.image_size, self.image_size)), T.Grayscale(1), T.ToTensor()])
        # Kept for backward compatibility: crops ONE image at its own random
        # position. __getitem__ no longer uses it because the two images of a
        # pair must be cropped at the same position (see _paired_patches).
        self.tf_patch = T.Compose([T.Lambda(lambda img: self.random_patch(img, self.patch_size)), T.Resize((self.patch_size,self.patch_size)), T.Grayscale(1), T.ToTensor()])
        # The part of the patch pipeline that runs after cropping.
        self.tf_patch_post = T.Compose([T.Resize((self.patch_size, self.patch_size)), T.Grayscale(1), T.ToTensor()])

    def random_patch(self, pil_img, size):
        """Return a random ``size`` x ``size`` crop of *pil_img*.

        Images smaller than ``size`` in either dimension are resized to
        ``size`` x ``size`` instead of being cropped.
        """
        w, h = pil_img.size
        if w < size or h < size:
            return pil_img.resize((size, size))
        x = random.randint(0, w - size)
        y = random.randint(0, h - size)
        return pil_img.crop((x, y, x + size, y + size))

    def _paired_patches(self, img_a, img_b):
        """Crop both images of a pair at one shared random position.

        The network compares the two patches with each other, so they have to
        show the same image region; independent random crops would make the
        patch difference reflect image content rather than the modification.
        When the two images differ in size there is no common pixel grid, and
        each image falls back to its own ``random_patch``.
        """
        size = self.patch_size
        if img_a.size != img_b.size:
            return self.random_patch(img_a, size), self.random_patch(img_b, size)
        w, h = img_a.size
        if w < size or h < size:
            # Same rule as random_patch: too-small images are resized whole.
            return img_a.resize((size, size)), img_b.resize((size, size))
        x = random.randint(0, w - size)
        y = random.randint(0, h - size)
        box = (x, y, x + size, y + size)
        return img_a.crop(box), img_b.crop(box)

    def _load_rgb(self, path):
        """Load *path* as an RGB image, or a black placeholder when that fails."""
        try:
            # Close the file handle right after the pixels have been copied.
            with Image.open(path) as img:
                return img.convert('RGB')
        except Exception as e:
            # Best-effort: a broken file must not abort an epoch, but the
            # substitution is reported instead of happening silently.
            print(f'Ostrzeżenie: nie udało się wczytać obrazu {path}: {e}. Używam czarnego obrazu zastępczego.')
            return Image.new('RGB', (self.image_size, self.image_size), color='black')

    def apply_srm(self, img_tensor):
        """Convolve a (1, H, W) tensor with every filter of ``SRM_BANK``.

        Returns a tensor with one channel per filter; the padding of half the
        kernel size keeps the spatial size unchanged for the 3x3 bank.
        """
        device = img_tensor.device
        bank = SRM_BANK.to(device)
        pad = (bank.shape[2] // 2, bank.shape[3] // 2)
        img = img_tensor.unsqueeze(0)  # add the batch axis expected by conv2d
        out = nn.functional.conv2d(img, bank, padding=pad)
        return out.squeeze(0)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``((a_full, a_patch), (b_full, b_patch), label, a_path, b_path)``.

        The four image tensors have ``SRM_BANK.shape[0]`` channels each; the
        label is a float32 scalar tensor.
        """
        a_path, b_path, label = self.pairs[idx]

        a = self._load_rgb(a_path)
        b = self._load_rgb(b_path)

        a_full = self.tf_full(a)
        b_full = self.tf_full(b)

        a_crop, b_crop = self._paired_patches(a, b)
        a_patch = self.tf_patch_post(a_crop)
        b_patch = self.tf_patch_post(b_crop)

        if self.use_srm:
            a_full_srm = self.apply_srm(a_full)
            b_full_srm = self.apply_srm(b_full)
            a_patch_srm = self.apply_srm(a_patch)
            b_patch_srm = self.apply_srm(b_patch)
        else:
            # Without SRM, repeat the single channel so the tensors still have
            # the channel count the model was built for.
            n_channels = SRM_BANK.shape[0]
            a_full_srm = a_full.repeat(n_channels, 1, 1)
            b_full_srm = b_full.repeat(n_channels, 1, 1)
            a_patch_srm = a_patch.repeat(n_channels, 1, 1)
            b_patch_srm = b_patch.repeat(n_channels, 1, 1)

        return (a_full_srm, a_patch_srm), (b_full_srm, b_patch_srm), torch.tensor(label, dtype=torch.float32), a_path, b_path

# split: the first 80% of the (already shuffled) pairs train the model, the
# remainder is used for validation.
_split_at = int(0.8 * len(pairs))
train_pairs, val_pairs = pairs[:_split_at], pairs[_split_at:]

train_ds = PairedStegoDataset(train_pairs)
val_ds = PairedStegoDataset(val_pairs)

# Options shared by both loaders; only the training loader reshuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
print('Train/Val pairs:', len(train_ds), len(val_ds))

class MultiScaleBranch(nn.Module):
    """Two-stream encoder producing one embedding per (full image, patch) input.

    ``global_stream`` encodes the full-image view and ``local`` encodes the
    patch view. Each stream is three Conv-BatchNorm-ReLU stages followed by
    global average pooling and a linear layer of width ``embed_dim // 2``;
    the two outputs are concatenated.
    """

    @staticmethod
    def _conv_unit(c_in, c_out, kernel, pad, stride=1):
        """Return the three layers of one Conv -> BatchNorm -> ReLU stage."""
        return [
            nn.Conv2d(c_in, c_out, kernel, padding=pad, stride=stride),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, in_channels, embed_dim=EMBED_DIM):
        super().__init__()
        half_dim = embed_dim // 2
        unit = self._conv_unit

        # Patch stream: one stride-2 stage, widest feature maps.
        self.local = nn.Sequential(
            *unit(in_channels, 64, 3, 1),
            *unit(64, 128, 3, 1, stride=2),
            *unit(128, 256, 3, 1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, half_dim),
        )
        # Full-image stream: 5x5 first layer, two stride-2 stages.
        self.global_stream = nn.Sequential(
            *unit(in_channels, 32, 5, 2),
            *unit(32, 64, 3, 1, stride=2),
            *unit(64, 128, 3, 1, stride=2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, half_dim),
        )

    def forward(self, full, patch):
        """Return the concatenation [global embedding, local embedding]."""
        global_emb = self.global_stream(full)
        local_emb = self.local(patch)
        return torch.cat([global_emb, local_emb], dim=1)

class SiameseMulti(nn.Module):
    """Siamese model: one shared ``MultiScaleBranch`` plus a small scoring head.

    Both images of a pair go through the same branch; the head maps the
    element-wise absolute difference of the two embeddings to a value in
    (0, 1) via a final sigmoid.
    """

    def __init__(self, in_channels, embed_dim=EMBED_DIM):
        super().__init__()
        self.branch = MultiScaleBranch(in_channels, embed_dim=embed_dim)
        head_layers = [
            nn.Linear(embed_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, a_full,a_patch,b_full,b_patch):
        """Return ``(score, embedding_a, embedding_b)`` for a batch of pairs."""
        emb_a = self.branch(a_full, a_patch)
        emb_b = self.branch(b_full, b_patch)
        # The head output has a trailing singleton dimension; drop it so the
        # score holds one value per pair.
        score = self.head(torch.abs(emb_a - emb_b)).squeeze(1)
        return score, emb_a, emb_b

# The model's input channel count follows the number of filters in the bank.
model = SiameseMulti(in_channels=SRM_BANK.shape[0], embed_dim=EMBED_DIM)
model = model.to(DEVICE)


class ContrastiveLoss(nn.Module):
    """Margin-based loss on the Euclidean distance between two embeddings.

    Pairs with label 1 are penalised while their distance is below
    ``margin`` (pushing them apart); pairs with label 0 are penalised by
    their squared distance (pulling them together). Note that this is the
    reverse of the convention in which label 1 marks "similar" pairs.
    """

    def __init__(self, margin=MARGIN):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """Return the mean loss over the batch."""
        dist = torch.norm(emb1 - emb2, p=2, dim=1)
        shortfall = torch.clamp(self.margin - dist, min=0.0)
        per_pair = label * 0.5 * shortfall**2 + (1.0 - label) * 0.5 * dist**2
        return per_pair.mean()


# Training objective parts: embedding-level contrastive loss and BCE on the
# head's sigmoid output; both are combined in train_epoch.
criterion_contrast = ContrastiveLoss()
criterion_bce = nn.BCELoss()
opt = torch.optim.Adam(model.parameters(), lr=LR)

def train_epoch(model, loader, opt):
    """Run one optimisation pass over *loader*.

    The per-batch loss is the contrastive loss on the two embeddings plus
    0.5 times the binary cross-entropy of the model's score.

    Args:
        model: Network returning ``(score, emb_a, emb_b)``.
        loader: Iterable of ``((a_full, a_patch), (b_full, b_patch), labels,
            a_paths, b_paths)`` batches; must expose ``.dataset``.
        opt: Optimizer updating the model's parameters.

    Returns:
        The sample-weighted mean loss over the loader's dataset.
    """
    model.train()
    loss_sum = 0.0
    for views_a, views_b, labels, _, _ in tqdm(loader, desc="Trening"):
        a_full, a_patch = (view.to(DEVICE) for view in views_a)
        b_full, b_patch = (view.to(DEVICE) for view in views_b)
        labels = labels.to(DEVICE)

        opt.zero_grad()
        out, ea, eb = model(a_full, a_patch, b_full, b_patch)
        contrastive_term = criterion_contrast(ea, eb, labels)
        bce_term = criterion_bce(out, labels)
        loss = contrastive_term + 0.5 * bce_term
        loss.backward()
        opt.step()

        # Weight by batch size so a smaller final batch is averaged correctly.
        loss_sum += loss.item() * labels.size(0)
    return loss_sum / len(loader.dataset)


def eval_model(model, loader):
    """Run *model* over *loader* without gradients and collect its outputs.

    Returns:
        A 4-tuple ``(labels, scores, embedding_distances, paths)``: three
        NumPy arrays with one entry per sample and a list of
        ``(a_path, b_path)`` tuples in the same order.
    """
    model.eval()
    all_labels = []
    all_scores = []
    all_dists = []
    all_paths = []
    with torch.no_grad():
        for views_a, views_b, labels, a_paths, b_paths in tqdm(loader, desc="Walidacja"):
            a_full, a_patch = (view.to(DEVICE) for view in views_a)
            b_full, b_patch = (view.to(DEVICE) for view in views_b)

            out, ea, eb = model(a_full, a_patch, b_full, b_patch)

            # Labels are never moved to DEVICE here, so they convert directly.
            all_labels.extend(labels.numpy().tolist())
            all_scores.extend(out.cpu().numpy().tolist())
            # Euclidean distance between the two embeddings of every pair.
            dist = (ea - eb).pow(2).sum(dim=1).sqrt()
            all_dists.extend(dist.cpu().numpy().tolist())
            all_paths.extend(list(zip(a_paths, b_paths)))

    return np.array(all_labels), np.array(all_scores), np.array(all_dists), all_paths

def save_eval_artifacts(model, loader, ys_val, preds_val, val_paths, epoch, chosen_thr):
    """Compute validation metrics and write every per-epoch artefact to disk.

    Written output:
      * one row in ``METRICS_CSV`` (the file is recreated when ``epoch`` is 1,
        appended to otherwise);
      * in ``RESULTS_DIR/epoch_XX``: ROC curve, precision-recall curve,
        confusion matrix, score histogram, a text classification report and up
        to 10 side-by-side images of the lowest- and highest-scored pairs.

    Args:
        model: Unused; kept so existing call sites stay valid.
        loader: Unused; kept so existing call sites stay valid.
        ys_val: NumPy array of ground-truth labels (0 = Pseudo, 1 = Stego).
        preds_val: NumPy array of model scores, aligned with ``ys_val``.
        val_paths: List of ``(a_path, b_path)`` tuples, aligned with ``ys_val``.
        epoch: 1-based epoch number; used in titles, file names and the CSV.
        chosen_thr: Score threshold at or above which a pair is predicted
            as Stego.
    """
    # ---- Metrics ----
    try:
        auc_score = roc_auc_score(ys_val, preds_val)
    except Exception:
        auc_score = float('nan')

    fpr, tpr, _ = roc_curve(ys_val, preds_val)
    precision, recall, _ = precision_recall_curve(ys_val, preds_val)
    pr_auc = auc(recall, precision)

    pred_labels = (preds_val >= chosen_thr).astype(int)
    # labels=[0, 1] forces a 2x2 matrix. Without it, a validation set in which
    # only one class occurs (in both truth and prediction) yields a 1x1 matrix
    # and the 2x2 annotation loop below raises IndexError.
    cm = confusion_matrix(ys_val, pred_labels, labels=[0, 1])
    acc = accuracy_score(ys_val, pred_labels)
    report = classification_report(ys_val, pred_labels, digits=4)

    df_metrics = pd.DataFrame({
        'epoch': [epoch],
        'val_auc_roc': [auc_score],
        'val_auc_pr': [pr_auc],
        'threshold': [chosen_thr],
        'accuracy': [acc],
    })
    # Epoch 1 starts a fresh file; later epochs append, or create the file
    # (with header) if it does not exist yet.
    append = epoch != 1 and os.path.exists(METRICS_CSV)
    df_metrics.to_csv(METRICS_CSV, mode='a' if append else 'w', header=not append, index=False)

    # Make sure the directory for this epoch exists.
    epoch_dir = os.path.join(RESULTS_DIR, f'epoch_{epoch:02d}')
    os.makedirs(epoch_dir, exist_ok=True)

    # ---- ROC curve ----
    plt.figure(figsize=(6,6))
    plt.plot(fpr, tpr, label=f'AUC ROC: {auc_score:.4f}')
    plt.plot([0,1],[0,1],'--')
    plt.xlabel('FPR (Fałszywie pozytywne)')
    plt.ylabel('TPR (Prawdziwie pozytywne)')
    plt.title(f'Krzywa ROC (Stego vs Pseudo) - Ep. {epoch}')
    plt.grid(True)
    plt.legend()
    plt.savefig(os.path.join(epoch_dir, 'roc_curve.png'))
    plt.close()

    # ---- Precision-recall curve ----
    plt.figure(figsize=(6,4))
    plt.plot(recall, precision, label=f'AUC PR: {pr_auc:.4f}')
    plt.xlabel('Recall (Czułość)')
    plt.ylabel('Precision (Precyzja)')
    plt.title(f'Krzywa Precision-Recall - Ep. {epoch}')
    plt.grid(True)
    plt.legend()
    plt.savefig(os.path.join(epoch_dir, 'pr_curve.png'))
    plt.close()

    # ---- Confusion matrix ----
    plt.figure(figsize=(5,4))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(f'Macierz pomyłek - Ep. {epoch}')
    plt.colorbar()
    plt.xticks([0,1], ['Pseudo (0)','Stego (1)'], rotation=45, ha='right')
    plt.yticks([0,1], ['Pseudo (0)','Stego (1)'])
    plt.ylabel('Prawdziwa etykieta')
    plt.xlabel('Etykieta predykcji')
    for i in range(2):
        for j in range(2):
            # White text on dark cells, black text on light cells.
            plt.text(j, i, cm[i,j], ha='center', va='center', color='white' if cm[i,j]>cm.max()/2 else 'black')
    plt.tight_layout()
    plt.savefig(os.path.join(epoch_dir, 'confusion_matrix.png'))
    plt.close()

    # ---- Score histogram per class ----
    plt.figure(figsize=(6,4))
    plt.hist(preds_val[ys_val==0], bins=50, alpha=0.6, label='negatyw (Pseudo)')
    plt.hist(preds_val[ys_val==1], bins=50, alpha=0.6, label='pozytyw (Stego)')
    plt.legend()
    plt.title(f'Rozkład wyników predykcji - Ep. {epoch}')
    plt.xlabel('Wynik predykcji')
    plt.ylabel('Liczba próbek')
    plt.savefig(os.path.join(epoch_dir, 'scores_histogram.png'))
    plt.close()

    # ---- Text report ----
    # Explicit UTF-8: the report contains non-ASCII characters, and the
    # platform default encoding is not guaranteed to support them.
    with open(os.path.join(epoch_dir, 'classification_report.txt'), 'w', encoding='utf-8') as f:
        f.write(f'--- Raport Klasyfikacji (Epoka {epoch}) ---\n')
        f.write(f'AUC ROC: {auc_score:.4f}\n')
        f.write(f'Accuracy: {acc:.4f}\n')
        f.write(f'Wybrany próg (Youden): {chosen_thr:.4f}\n')
        f.write('\nClassification Report:\n')
        f.write(str(report))

    print(f'Zapisano artefakty ewaluacyjne dla Epoki {epoch} do: {epoch_dir}')

    # ---- Example images ----
    # Font (for Colab compatibility); None makes PIL use its default font.
    font = None
    try:
        font_path = '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf'
        if os.path.exists(font_path):
            font = ImageFont.truetype(font_path, 16)
    except Exception:
        pass

    # Up to 10 examples: the 5 lowest-scored pairs (most confident negatives)
    # followed by the 5 highest-scored pairs (most confident positives).
    n_save = 10
    half = n_save // 2
    order = np.argsort(preds_val)
    # With fewer than n_save validation samples the two slices overlap;
    # dict.fromkeys removes the duplicates while keeping the order.
    idxs = list(dict.fromkeys(order[:half].tolist() + order[-half:].tolist()))

    saved = 0
    for k, idx in enumerate(idxs):
        a_path, b_path = val_paths[idx]
        score = preds_val[idx]
        lab = ys_val[idx]

        try:
            a = Image.open(a_path).convert('RGB').resize((IMAGE_SIZE,IMAGE_SIZE))
            b = Image.open(b_path).convert('RGB').resize((IMAGE_SIZE,IMAGE_SIZE))
        except Exception:
            # Unreadable image: skip this example.
            continue

        # Image A (cover) and image B (stego/pseudo) side by side, with a
        # 40-pixel strip below them for the caption.
        canvas = Image.new('RGB', (IMAGE_SIZE*2, IMAGE_SIZE+40), (240,240,240))
        canvas.paste(a, (0,0))
        canvas.paste(b, (IMAGE_SIZE,0))
        d = ImageDraw.Draw(canvas)

        # Ground truth and prediction as text.
        true_label_text = f'PRAWDZIWY: {"Stego (1)" if int(lab) == 1 else "Pseudo (0)"}'
        predicted_label_int = 1 if score >= chosen_thr else 0
        predicted_label_text = f'PRED. {"Stego (1)" if predicted_label_int == 1 else "Pseudo (0)"}'

        # Green caption for a correct prediction, red for a wrong one.
        is_correct = (predicted_label_int == int(lab))
        text_color = (0, 128, 0) if is_correct else (255, 0, 0)

        # Labels drawn over the top of the two images.
        d.text((IMAGE_SIZE//2 - 20, 5), 'Obraz A (Cover)', fill=(0,0,0), font=font)
        d.text((IMAGE_SIZE + IMAGE_SIZE//2 - 40, 5), 'Obraz B (Stego/Pseudo)', fill=(0,0,0), font=font)

        # Caption below the images.
        txt = (
            f'{true_label_text} | '
            f'{predicted_label_text} | '
            f'Prawd. Stego: {score:.3f} | '
            f'Werdykt: {"ZGODNY" if is_correct else "BŁĘDNY"}'
        )
        d.text((10, IMAGE_SIZE+5), txt, fill=text_color, font=font)

        filename_prefix = 'Poprawna' if is_correct else 'Bledna'
        canvas.save(os.path.join(epoch_dir, f'przyklad_{filename_prefix}_{k}_lab_{int(lab)}.png'))
        saved += 1

    # Report the number of images actually written, which can be lower than
    # n_save (small validation set, unreadable files).
    print(f'Zapisano {saved} przykładowych wizualizacji dla Epoki {epoch} do: {epoch_dir}')


# ---- TRAIN LOOP ----
# Per epoch: train, validate, pick a decision threshold, then either
#   * (validation AUC >= best so far) save the model and the full artefact
#     set - save_eval_artifacts also writes this epoch's row to METRICS_CSV, or
#   * (otherwise) append only the metrics row to METRICS_CSV.
# Each epoch therefore ends up in the CSV exactly once.
best_auc = 0.0
history = {'train_loss':[], 'val_auc':[]}

print('\n=== Rozpoczęcie Treningu (Stego vs Pseudo) ===')
for epoch in range(1, NUM_EPOCHS+1):
    print(f'EP. {epoch}/{NUM_EPOCHS}')
    tr_loss = train_epoch(model, train_loader, opt)

    # Validation
    ys_val, preds_val, emb_val, val_paths = eval_model(model, val_loader)

    try:
        current_auc = roc_auc_score(ys_val, preds_val)
    except Exception:
        # E.g. only one class present in the validation labels.
        current_auc = 0.0

    history['train_loss'].append(tr_loss)
    history['val_auc'].append(current_auc)
    print(f'Train loss: {tr_loss:.4f} | Val AUC: {current_auc:.4f}')

    # Decision threshold maximising Youden's J (TPR - FPR); 0.5 as fallback.
    try:
        fpr, tpr, thr = roc_curve(ys_val, preds_val)
        j = tpr - fpr
        j_idx = np.argmax(j)
        chosen_thr = thr[j_idx]
    except Exception:
        chosen_thr = 0.5

    if current_auc >= best_auc:  # >= so that the first epoch is always saved
        if current_auc > best_auc:
            best_auc = current_auc
            print(f'-> Osiągnięto nowy rekord AUC: {best_auc:.4f}')

        # 1. Save the model
        torch.save(model.state_dict(), SAVED_MODEL)
        print(f'-> Zapisano najlepszy model do: {SAVED_MODEL}')

        # 2. Save all metrics and visualisations for this epoch. This call
        #    also logs the epoch's row to METRICS_CSV, so the loop must not
        #    write it again (doing so produced duplicate rows).
        save_eval_artifacts(model, val_loader, ys_val, preds_val, val_paths, epoch, chosen_thr)
        print(f'-> Zapisano pełen zestaw artefaktów ewaluacyjnych dla Epoki {epoch}')
    else:
        # Non-record epoch: log its metrics here, since save_eval_artifacts
        # was not called.
        try:
            precision_temp, recall_temp, _ = precision_recall_curve(ys_val, preds_val)
            current_pr_auc = auc(recall_temp, precision_temp)
            current_acc = accuracy_score(ys_val, (preds_val >= chosen_thr).astype(int))
        except Exception:
            current_pr_auc = float('nan')
            current_acc = float('nan')

        df_metrics = pd.DataFrame({
            'epoch': [epoch],
            'val_auc_roc': [current_auc],
            'val_auc_pr': [current_pr_auc],
            'threshold': [chosen_thr],
            'accuracy': [current_acc],
        })
        # Append to the existing file; write the header only if the file has
        # to be created.
        csv_exists = os.path.exists(METRICS_CSV)
        df_metrics.to_csv(METRICS_CSV, mode='a', header=not csv_exists, index=False)

print('\n=== Trening zakończony ===')
print(f'Najlepszy wynik AUC ROC: {best_auc:.4f}')
print(f'Wyniki metryk i wizualizacje są w katalogu: {RESULTS_DIR}')