In [1]:
# Instalacje
!pip install pandas torch torchaudio lightning kagglehub scikit-learn ipython soundfile wandb gdown torchcodec

import os
import shutil
import random
import glob
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from sklearn.model_selection import train_test_split
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import kagglehub
import gdown
from pathlib import Path
import wandb

# Ustawienie ziarna losowo≈õci dla powtarzalno≈õci
pl.seed_everything(42)
wandb.login()



Seed set to 42
[34m[1mwandb[0m: Currently logged in as: [33mrymer[0m ([33mrymer-agh-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

***Integracja dzia≈Çania g.collab vs lokalne***


In [2]:
IS_COLAB = os.path.exists('/content')

local = not IS_COLAB ## powinno samo wykryƒá
if local:
    dataset_path = "dataset"
else:
    dataset_path = "/content/dataset"

In [3]:
# 1. Pobieranie Datasetu MAD
target_dir = dataset_path
if os.path.exists(target_dir) and len(os.listdir(target_dir)) > 0:
    print(f"Dataset ju≈º istnieje w '{target_dir}'.")
else:
    print("Pobieranie datasetu MAD...")
    path = kagglehub.dataset_download("junewookim/mad-dataset-military-audio-dataset")
    os.makedirs(target_dir, exist_ok=True)
    shutil.copytree(path, target_dir, dirs_exist_ok=True)
    print("Pobrano dataset MAD.")

# 2. Pobieranie Szum√≥w (z Twojego pliku)
noise_folder = dataset_path + "/noises"
os.makedirs(noise_folder, exist_ok=True)
url = "https://drive.google.com/drive/folders/14Q_0KNDXACkFQ2oTF1T-gnjIaNbNuaKL?usp=sharing"

if not list(Path(noise_folder).glob("*.wav")):
    print("Pobieranie szum√≥w z Google Drive...")
    try:
        gdown.download_folder(url, output=noise_folder, quiet=False, use_cookies=False)
        print("Pobrano szumy.")
    except Exception as e:
        print(f"B≈ÇƒÖd pobierania szum√≥w: {e}")
else:
    print(f"Szumy ju≈º istniejƒÖ w '{noise_folder}'.")

noise_files = list(glob.glob(os.path.join(noise_folder, "*.wav")))
print(f"Liczba dostƒôpnych plik√≥w szumu: {len(noise_files)}")


Dataset ju≈º istnieje w 'dataset'.
Szumy ju≈º istniejƒÖ w 'dataset/noises'.
Liczba dostƒôpnych plik√≥w szumu: 5


In [4]:
csv_path = dataset_path + "/MAD_dataset/training.csv"
df_full = pd.read_csv(csv_path)

# Mapowanie nazw kolumn
rename_map = {
    'filename': 'path',
    'class': 'label',
    'class_name': 'label'
}
df_full = df_full.rename(columns=rename_map)

# Funkcja naprawiajƒÖca ≈õcie≈ºki
def fix_path(path):
    path = str(path)
    if not path.startswith("training/"):
        return os.path.join("training", path)
    return path

df_full['path'] = df_full['path'].apply(fix_path)
print(f"Za≈Çadowano DataFrame: {len(df_full)} plik√≥w.")


Za≈Çadowano DataFrame: 6429 plik√≥w.


In [5]:
class CachedAudioDataset(Dataset):
    def __init__(self, df, root_dir, noise_files=None, training=True, target_len=150000, expansion_factor=1):
        """
        expansion_factor: Ile razy powieliƒá dataset w jednej epoce.
        Np. expansion_factor=5 sprawi, ≈ºe dataset 6000 plik√≥w bƒôdzie "widziany" jako 30000.
        Ka≈ºda kopia dostanie innƒÖ, losowƒÖ augmentacjƒô.
        """
        self.df = df.reset_index(drop=True)
        self.root_dir = os.path.abspath(str(root_dir).strip())
        self.noise_files = noise_files
        self.training = training
        self.target_len = target_len
        self.expansion_factor = expansion_factor if training else 1
        self.target_sr = 48000

        self.labels_to_indices = self.df.groupby('label').groups
        self.all_labels = list(self.labels_to_indices.keys())

        # 1. Cache SZUM√ìW
        self.cached_noises = []
        if noise_files:
            print(f"Cache'owanie {len(noise_files)} plik√≥w szumu...")
            for nf in noise_files:
                try:
                    wav, sr = torchaudio.load(nf)
                    if sr != self.target_sr: wav = T.Resample(sr, self.target_sr)(wav)
                    if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True)
                    self.cached_noises.append(wav)
                except: pass

        # 2. Cache DANYCH TRENINGOWYCH
        print(f"≈Åadowanie {len(self.df)} plik√≥w treningowych z: {self.root_dir}")
        self.audio_cache = []
        errors = 0
        for i, row in enumerate(self.df.itertuples()):
            csv_path = str(row.path).strip()
            full_path = os.path.join(self.root_dir, csv_path)

            try:
                if not os.path.exists(full_path):
                    raise FileNotFoundError("Plik nie istnieje")

                wav, sr = torchaudio.load(full_path)
                if sr != self.target_sr: wav = T.Resample(sr, self.target_sr)(wav)
                if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True)

                if wav.shape[1] < self.target_len:
                    wav = F.pad(wav, (0, self.target_len - wav.shape[1]))
                elif wav.shape[1] > self.target_len:
                    start = (wav.shape[1] - self.target_len) // 2
                    wav = wav[:, start:start+self.target_len]

                self.audio_cache.append(wav)

            except Exception as e:
                errors += 1
                self.audio_cache.append(torch.randn(1, self.target_len) * 0.001)

        if errors > 0:
            print(f" uwaga: {errors} plik√≥w nie za≈Çadowano (wstawiono szum).")
        else:
            print("sukces: Wszystkie pliki w pamiƒôci RAM.")

        if self.training and self.expansion_factor > 1:
            print(f"üöÄ DATASET ROZSZERZONY: {len(self.df)} plik√≥w -> {len(self)} wirtualnych pr√≥bek na epokƒô.")

    def aggressive_augment(self, waveform):
        # Prosta agresywna augmentacja
        gain = random.uniform(0.5, 1.5)
        waveform = waveform * gain


        #extra masking
        #freq_mask = T.FrequencyMasking(freq_mask_param=5)
        #time_mask = T.TimeMasking(time_mask_param=10)
        #waveform = freq_mask(waveform)
        #waveform = time_mask(waveform)

        if self.cached_noises and random.random() > 0.3:
            noise_wav = random.choice(self.cached_noises)
            sig_len = waveform.shape[1]
            if noise_wav.shape[1] < sig_len:
                repeats = int(sig_len / noise_wav.shape[1]) + 1
                curr_noise = noise_wav.repeat(1, repeats)[:, :sig_len]
            else:
                start = random.randint(0, noise_wav.shape[1] - sig_len)
                curr_noise = noise_wav[:, start:start+sig_len]
            snr_db = random.uniform(5.0, 25.0)
            signal_power = waveform.norm(p=2)
            noise_power = curr_noise.norm(p=2)
            if noise_power > 0:
                snr = 10 ** (snr_db / 20)
                scale = signal_power / (noise_power * snr + 1e-9)
                waveform = waveform + (curr_noise * scale)
        return waveform

    def __len__(self):
        # Dataset udaje, ≈ºe jest wiƒôkszy ni≈º w rzeczywisto≈õci
        return len(self.df) * self.expansion_factor

    def __getitem__(self, idx):
        # Mapujemy wirtualny indeks na prawdziwy
        real_idx = idx % len(self.df)

        wav_a = self.audio_cache[real_idx].clone()
        label_a = self.df.iloc[real_idx]['label']

        if self.training:
            wav_a = self.aggressive_augment(wav_a)

        # Positive
        idxs_p = self.labels_to_indices[label_a]
        possible_p = idxs_p.drop(real_idx, errors='ignore')
        idx_p = random.choice(possible_p) if len(possible_p) > 0 else real_idx

        wav_p = self.audio_cache[idx_p].clone()
        if self.training: wav_p = self.aggressive_augment(wav_p)

        # Negative
        label_n = random.choice([l for l in self.all_labels if l != label_a])
        idx_n = random.choice(self.labels_to_indices[label_n])
        wav_n = self.audio_cache[idx_n].clone()
        if self.training: wav_n = self.aggressive_augment(wav_n)

        return wav_a, wav_p, wav_n


In [6]:
class ResNetTripletGPU(pl.LightningModule):
    def __init__(self, df, root_dir, noise_files, margin=1.0, learning_rate=1e-4):
        super().__init__()
        self.save_hyperparameters(ignore=['df', 'root_dir', 'noise_files'])
        self.df = df
        self.root_dir = root_dir
        self.noise_files = noise_files

        # 1. Transformacja na GPU
        self.spec_layer = T.MelSpectrogram(
            sample_rate=48000, n_fft=1024, hop_length=512, n_mels=64, f_min=20, f_max=24000
        )
        self.db_layer = T.AmplitudeToDB()

        # 2. Backbone
        self.backbone = resnet18(weights='IMAGENET1K_V1')

        # Dostosowanie ResNet do 1 kana≈Çu
        original_conv1 = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.backbone.conv1.weight.data = original_conv1.weight.data.sum(dim=1, keepdim=True)

        # V2 version with more complex head
        self.backbone.fc = nn.Sequential(
            nn.Linear(512, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
            nn.Linear(512, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
            nn.Linear(256, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(256, 128)
        )

        # V3 version even bigger head
        # self.backbone.fc = nn.Sequential(
        #     nn.Linear(512, 1024, bias=False),
        #     nn.BatchNorm1d(1024),
        #     nn.ReLU(inplace=True),
        #     nn.Dropout(0.2),
        #     nn.Linear(1024, 512, bias=False),
        #     nn.BatchNorm1d(512),
        #     nn.ReLU(inplace=True),
        #     nn.Dropout(0.2),
        #     nn.Linear(512, 256, bias=False),
        #     nn.BatchNorm1d(256),
        #     nn.LeakyReLU(0.1, inplace=True),
        #     nn.Dropout(0.2),
        #     nn.Linear(256, 256, bias=False),
        #     nn.BatchNorm1d(256),
        #     nn.LeakyReLU(0.1, inplace=True),
        #     nn.Linear(256, 128)
        # )

        # self.backbone.fc = nn.Sequential(
        #     nn.Linear(512, 512, bias=False),
        #     nn.BatchNorm1d(512),
        #     nn.ReLU(inplace=True),
        #     nn.Dropout(0.3),
        #     nn.Linear(512, 256, bias=False),
        #     nn.ReLU(inplace=True),
        #     nn.Dropout(0.2),
        #     nn.Linear(256, 256, bias=False),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(256, 128)
        # )


        self.loss_fn = nn.TripletMarginLoss(margin=margin, p=2)

        #soft margin with cosine similarity
        #self.loss_fn = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - F.cosine_similarity(x, y))

    def compute_features(self, wav):
        # To dzieje siƒô na GPU!
        spec = self.spec_layer(wav)
        spec = self.db_layer(spec)
        spec = (spec - spec.mean()) / (spec.std() + 1e-6)
        return self.backbone(spec)

    def forward(self, x):
        return F.normalize(self.compute_features(x), p=2, dim=1)

    def training_step(self, batch, batch_idx):
        wav_a, wav_p, wav_n = batch
        emb_a = self(wav_a)
        emb_p = self(wav_p)
        emb_n = self(wav_n)

        loss = self.loss_fn(emb_a, emb_p, emb_n)
        acc = (F.pairwise_distance(emb_a, emb_p) < F.pairwise_distance(emb_a, emb_n)).float().mean()
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        wav_a, wav_p, wav_n = batch
        emb_a = self(wav_a)
        emb_p = self(wav_p)
        emb_n = self(wav_n)

        loss = self.loss_fn(emb_a, emb_p, emb_n)
        acc = (F.pairwise_distance(emb_a, emb_p) < F.pairwise_distance(emb_a, emb_n)).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        wav_a, wav_p, wav_n = batch
        emb_a = self(wav_a)
        emb_p = self(wav_p)
        emb_n = self(wav_n)

        loss = self.loss_fn(emb_a, emb_p, emb_n)
        acc = (F.pairwise_distance(emb_a, emb_p) < F.pairwise_distance(emb_a, emb_n)).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", acc)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def train_dataloader(self):
        train_df, _ = train_test_split(self.df, test_size=0.2, random_state=42, stratify=self.df['label'])
        ds = CachedAudioDataset(
            train_df,
            self.root_dir,
            self.noise_files,
            training=True,
            expansion_factor=5
        )
        return DataLoader(ds, batch_size=64, shuffle=True, num_workers=2, persistent_workers=True)

    def val_dataloader(self):
        _, val_df = train_test_split(self.df, test_size=0.2, random_state=42, stratify=self.df['label'])
        ds = CachedAudioDataset(val_df, self.root_dir, noise_files=None, training=False)
        return DataLoader(ds, batch_size=64, shuffle=False, num_workers=2, persistent_workers=True)


In [7]:
import datetime

# ==========================================
# KONFIGURACJA UNIKALNEGO TRENINGU
# ==========================================

#OPCJALNIE: customowy dodatek nazy runu
model_big_name = "ResNet_Triplet_big_head_more_dropout"


# 1. Tworzymy unikalnƒÖ nazwƒô folderu na podstawie daty i godziny
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
run_name = f"{model_big_name}_{timestamp}"
checkpoint_dir = os.path.join("checkpoints", run_name)

# Tworzymy folder fizycznie
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"üìÇ Checkpointy z tego treningu trafiƒÖ do: {checkpoint_dir}")

# ==========================================
# CALLBACKI I LOGGER
# ==========================================

# 2. Checkpoint najlepszego modelu wg Accuracy
checkpoint_best = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    dirpath=checkpoint_dir,
    filename="best-epoch={epoch:02d}-acc={val_acc:.4f}",
    save_top_k=1,
    auto_insert_metric_name=False
)

# 3. Checkpoint okresowy (co 5 epok)
checkpoint_periodic = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="periodic-epoch={epoch:02d}",
    every_n_epochs=5,
    save_last=True,
    save_top_k=-1
)

# 4. WandB Logger
wandb_logger = WandbLogger(
    project="siamese-audio-classifier",
    entity="deep-neural-network-course",
    name=run_name,
    log_model=False
)

# ==========================================
# START TRENINGU
# ==========================================

if local:
    ROOT_DIR = "dataset/MAD_dataset"  # upewnij siƒô, ≈ºe to poprawna ≈õcie≈ºka!
else:
    ROOT_DIR = "/content/dataset/MAD_dataset" # upewnij siƒô, ≈ºe to poprawna ≈õcie≈ºka!


model = ResNetTripletGPU(
    df=df_full,
    root_dir=ROOT_DIR,
    noise_files=noise_files,
    margin=0.5
)

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
    logger=wandb_logger,
    callbacks=[checkpoint_best, checkpoint_periodic],
    log_every_n_steps=10,
    precision=32,
    gradient_clip_val=1.0
)

print(f"üöÄ Rozpoczynam trening: {run_name}")
trainer.fit(model)

print("Trening zako≈Ñczony.")
wandb.finish()


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 4050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


üìÇ Checkpointy z tego treningu trafiƒÖ do: checkpoints/ResNet_Triplet_big_head_more_dropout_2026-01-15_00-26-48
üöÄ Rozpoczynam trening: ResNet_Triplet_big_head_more_dropout_2026-01-15_00-26-48




LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type              | Params | Mode  | FLOPs
-----------------------------------------------------------------
0 | spec_layer | MelSpectrogram    | 0      | train | 0    
1 | db_layer   | AmplitudeToDB     | 0      | train | 0    
2 | backbone   | ResNet            | 11.7 M | train | 0    
3 | loss_fn    | TripletMarginLoss | 0      | train | 0    
-----------------------------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.656    Total estimated model params size (MB)
86        Modules in train mode
0         Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

≈Åadowanie 1286 plik√≥w treningowych z: /home/iwo/GSN/siamese-audio-classifier/dataset/MAD_dataset
sukces: Wszystkie pliki w pamiƒôci RAM.
Cache'owanie 5 plik√≥w szumu...
≈Åadowanie 5143 plik√≥w treningowych z: /home/iwo/GSN/siamese-audio-classifier/dataset/MAD_dataset
sukces: Wszystkie pliki w pamiƒôci RAM.
üöÄ DATASET ROZSZERZONY: 5143 plik√≥w -> 25715 wirtualnych pr√≥bek na epokƒô.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


Trening zako≈Ñczony.


0,1
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà
train_acc,‚ñÅ‚ñÅ‚ñÅ‚ñÜ‚ñÜ‚ñÅ‚ñÅ‚ñà‚ñÜ‚ñà‚ñÜ‚ñà‚ñÜ‚ñÉ‚ñà‚ñà‚ñà‚ñÜ‚ñà‚ñà‚ñà‚ñà‚ñÅ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñÜ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñÜ‚ñà‚ñà
train_loss,‚ñà‚ñÖ‚ñÉ‚ñÑ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
trainer/global_step,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà
val_acc,‚ñÅ‚ñÑ‚ñá‚ñá‚ñá‚ñá‚ñÜ‚ñá‚ñÖ‚ñÖ‚ñÜ‚ñá‚ñÖ‚ñá‚ñÜ‚ñÜ‚ñà‚ñÜ‚ñá‚ñà
val_loss,‚ñà‚ñÑ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÇ‚ñÑ‚ñÉ‚ñÉ‚ñÇ‚ñÖ‚ñÇ‚ñÉ‚ñÇ‚ñÅ‚ñÇ‚ñÉ‚ñÅ

0,1
epoch,19.0
train_acc,1.0
train_loss,0.00304
trainer/global_step,8039.0
val_acc,0.97667
val_loss,0.03515


In [9]:
import os
import shutil
import glob
import datetime
import pandas as pd
from torch.utils.data import DataLoader
import time

#OPCJALNIE: customowy dodatek nazy runu
model_pet_name = "bigger_head_more_dropout_old" # np. V1, V2 itp.

print("üõ†Ô∏è KONFIGURACJA ≈öRODOWISKA I FOLDER√ìW...")

if IS_COLAB:
    from google.colab import drive
    # --- Montowanie Dysku Google ---
    if os.path.exists('/content/drive') and not os.path.exists('/content/drive/MyDrive'):
        print("‚ö†Ô∏è Wykryto b≈Çƒôdne montowanie. Naprawiam...")
        shutil.rmtree('/content/drive')
    if not os.path.exists('/content/drive'):
        drive.mount('/content/drive', force_remount=True)
    BASE_DIR = "/content/drive/MyDrive/studia/ProjektGsn/Models"
else:
    # ≈öcie≈ºka lokalna
    BASE_DIR = os.path.abspath("Models")

# Tworzymy unikalnƒÖ nazwƒô runu
if 'run_name' not in globals():
    run_name = f"Run_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

TARGET_RUN_DIR = os.path.join(BASE_DIR, run_name)
os.makedirs(TARGET_RUN_DIR, exist_ok=True)
print(f"‚úÖ Folder na wyniki: {TARGET_RUN_DIR}")

# ==========================================
# Testowanie najlepszego modelu
# ==========================================
print("\nüìä ROZPOCZYNAM TESTOWANIE NAJLEPSZEGO MODELU...")

if 'ROOT_DIR' not in globals():
    ROOT_DIR = "dataset/MAD_dataset" if not IS_COLAB else "/content/dataset/MAD_dataset"

test_csv_path = os.path.join(ROOT_DIR, "test.csv")
if not os.path.exists(test_csv_path):
    test_csv_path = os.path.join(ROOT_DIR, "test.csv")  # fallback lokalny

df_test = pd.read_csv(test_csv_path)

def fix_test_path(path):
    path = str(path)
    if not path.startswith("test/") and not path.startswith("training/"):
        return os.path.join("test", path)
    return path

df_test['path'] = df_test['path'].apply(fix_test_path)

test_ds = CachedAudioDataset(df_test, root_dir=ROOT_DIR, noise_files=None, training=False, expansion_factor=1)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=2)

# Pobieranie ≈õcie≈ºki najlepszego modelu
if 'checkpoint_best' in globals():
    best_local_path = checkpoint_best.best_model_path
    last_local_path = checkpoint_best.last_model_path
else:
    local_dir = f"checkpoints/{run_name}"
    files = glob.glob(f"{local_dir}/best*.ckpt")
    best_local_path = files[0] if files else None
    last_local_path = os.path.join(local_dir, "last.ckpt")

if not best_local_path or not os.path.exists(best_local_path):
    raise FileNotFoundError("Nie znaleziono modelu 'best' w folderze checkpoints!")

trainer.logger = None  # wy≈ÇƒÖcz logger
best_model = ResNetTripletGPU.load_from_checkpoint(best_local_path, df=df_full, root_dir=ROOT_DIR, noise_files=[])
time_ms_start = datetime.datetime.now()
results = trainer.test(best_model, dataloaders=test_loader)
time_ms_end = datetime.datetime.now()
time_taken = (time_ms_end - time_ms_start).total_seconds() * 1000
print(f"Czas testowania: {time_taken:.2f} ms")
avg_inference_time = time_taken / len(test_ds)
print(f"≈öredni czas inferencji na pr√≥bkƒô: {avg_inference_time:.4f} ms")

acc, loss = results[0]['test_acc'], results[0]['test_loss']

print(f"\nüìù WYNIKI BEST: ACC={acc:.4f}, LOSS={loss:.4f}")

# ==========================================
# Kopiowanie lub zapis lokalny
# ==========================================
print(f"\nüöö Zapis wynik√≥w do folderu: {TARGET_RUN_DIR}")

# A) Zapis najlepszego modelu
best_filename = f"{model_pet_name}BEST_ACC={acc:.4f}_LOSS={loss:.4f}_time={avg_inference_time:.4f}.ckpt"
target_best = os.path.join(TARGET_RUN_DIR, best_filename)
try:
    shutil.copy(best_local_path, target_best)
    print(f"‚úÖ Zapisano NAJLEPSZY model: {best_filename}")
except Exception as e:
    print(f"‚ùå B≈ÇƒÖd kopiowania BEST: {e}")

# B) Zapis ostatniego stanu
if last_local_path and os.path.exists(last_local_path):
    last_filename = "LAST_STATE.ckpt"
    target_last = os.path.join(TARGET_RUN_DIR, last_filename)
    try:
        shutil.copy(last_local_path, target_last)
        print(f"‚úÖ Zapisano OSTATNI model: {last_filename}")
    except Exception as e:
        print(f"‚ùå B≈ÇƒÖd kopiowania LAST: {e}")
else:
    print("‚ö†Ô∏è Nie znaleziono pliku 'last.ckpt'. Je≈õli chcesz, ustaw save_last=True przy treningu.")

print("\nüéâ Zako≈Ñczono! Sprawd≈∫ folder z wynikami.")


üõ†Ô∏è KONFIGURACJA ≈öRODOWISKA I FOLDER√ìW...
‚úÖ Folder na wyniki: /home/iwo/GSN/siamese-audio-classifier/Models/ResNet_Triplet_big_head_more_dropout_2026-01-15_00-26-48

üìä ROZPOCZYNAM TESTOWANIE NAJLEPSZEGO MODELU...
≈Åadowanie 1037 plik√≥w treningowych z: /home/iwo/GSN/siamese-audio-classifier/dataset/MAD_dataset
sukces: Wszystkie pliki w pamiƒôci RAM.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
       Test metric             DataLoader 0
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        test_acc            0.9334619045257568
        test_loss           0.09609925001859665
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚