#### 환경설정

##### 1. Wandb

In [5]:
import wandb

# wandb 로그인
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mvanillahub12[0m ([33mboaz_woony-boaz[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


##### 2. 라이브러리 로드

In [6]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
import os
import math
import random
import pickle
import wandb
from tqdm import tqdm
from datetime import datetime
from zoneinfo import ZoneInfo

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import librosa
import librosa.display

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import torchvision
import torchvision.models as models
from torch import Tensor
from torchsummary import summary
from torch.hub import load_state_dict_from_url
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.lr_scheduler import CosineAnnealingLR

from sklearn.metrics import confusion_matrix, f1_score
from sklearn.manifold import TSNE

##### 3. 경로 설정

In [8]:
from google.colab import drive

drive.mount('/content/drive')

ROOT = "/content/drive/MyDrive/ADV 프로젝트/data/ICBHI/ICBHI_final_database"
CHECKPOINT_PATH = "/content/drive/MyDrive/ADV 프로젝트/checkpoints"
PICKLE_PATH = "/content/drive/MyDrive/ADV 프로젝트/pickle"
text = "/content/drive/MyDrive/ADV 프로젝트/data/ICBHI/ICBHI_challenge_train_test.txt"

Mounted at /content/drive


## 1. Data Load

#### 1.1 Data Load

In [9]:
# WAV 파일이 있는 디렉토리 경로
data_dir = ROOT
txt_dir = ROOT

df = pd.read_csv(text, sep='\t', header=None)

# 컬럼 이름 변경
df.columns = ['filename', 'set']

# train, test split
train_df = df[df['set'] == 'train']
test_df = df[df['set'] == 'test']

# filename list
train_list = sorted(train_df['filename'].tolist())
test_list = sorted(test_df['filename'].tolist())

print(f'Train :{len(train_list)}, Test: {len(test_list)}, Total: {len(train_list) + len(test_list)}')

Train :539, Test: 381, Total: 920


#### 1.2 Pretext-Finetune Split

In [10]:
# shuffle train data
df_shuffled = train_df.sample(frac=1, random_state=42)

# split ratio
train_size = int(0.8 * len(df_shuffled))

# pretrain, finetune split
pretrain_df = df_shuffled[:train_size]
finetune_df = df_shuffled[train_size:]

# filename list (pretext_list -> pretrain list)
pretrain_list = sorted(pretrain_df['filename'].tolist())
finetune_list = sorted(finetune_df['filename'].tolist())

# patient id list
pretrain_patient_list = []
for filename in pretrain_list:
    number = int(filename.split('_')[0])
    pretrain_patient_list.append(number)

finetune_patient_list = []
for filename in finetune_list:
    number = int(filename.split('_')[0])
    finetune_patient_list.append(number)

pretrain_patient_counts = pd.Series(pretrain_patient_list).value_counts()
finetune_patient_counts = pd.Series(finetune_patient_list).value_counts()

print(f"[Pretrain] 환자 수: {len(pretrain_patient_counts.index)}, 샘플 수: {pretrain_patient_counts.sum()}")
print(f"[Finetune] 환자 수: {len(finetune_patient_counts.index)}, 샘플 수: {finetune_patient_counts.sum()}")

[Pretrain] 환자 수: 74, 샘플 수: 431
[Finetune] 환자 수: 43, 샘플 수: 108


## 2. Data Preprocessing

#### 2.1 Args

        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)

In [11]:
class Args:
    # Audio & Spectrogram
    target_sr = 4000    # 4KHz
    frame_size = 1024
    hop_length = 512    # frame_size 절반
    n_mels = 128
    target_sec = 8

    # Augmentation
    time_mask_param = 0.5
    freq_mask_param = 0.5

    # Train
    lr = 0.03
    warm = True                     # warm-up 사용 여부
    warm_epochs = 10                # warm-up 적용할 초기 epoch 수
    warmup_from = lr * 0.1          # warm-up 시작 learning rate (보통 lr의 10%)
    warmup_to = lr

    batch_size = 128
    workers = 2
    epochs = 300
    weight_decay = 1e-4

    resume = None
    schedule=[120, 160] # schedule

    # MLS
    K = 512
    momentum = 0.999
    T = 0.07
    dim_prj = 128
    top_k = 20
    lambda_bce = 0.3
    out_dim = 2048

    # Linear Evaluation
    ft_epochs = 100

    # etc
    gpu = 0
    data = "./data_path"

args = Args()

#### 2.2 Utils (func)

In [12]:
import torch.nn.functional as F
import random

# cycle의 클래스를 추출
def get_class(cr, wh):
    if cr == 1 and wh == 1:
        return 3
    elif cr == 0 and wh == 1:
        return 2
    elif cr == 1 and wh == 0:
        return 1
    elif cr == 0 and wh == 0:
        return 0
    else:
        return -1

# Mel Spectrogram 생성 ( sr=4KHz, frame_size=1024, hop_length=512, n_mels=128 )
def generate_mel_spectrogram(waveform, sample_rate, frame_size, hop_length, n_mels):
    if hop_length is None:
        hop_length = frame_size // 2
    mel_spec_transform = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=frame_size,
        hop_length=hop_length,
        n_mels=n_mels
    )
    mel_spectrogram = mel_spec_transform(waveform)
    mel_db = T.AmplitudeToDB()(mel_spectrogram)
    return mel_db

# Cycle Repeat 또는 Crop
def repeat_or_truncate_segment(mel_segment, target_frames):
    current_frames = mel_segment.shape[-1]
    if current_frames >= target_frames:
        return mel_segment[:, :, :target_frames]
    else:
        repeat_ratio = math.ceil(target_frames / current_frames)
        mel_segment = mel_segment.repeat(1, 1, repeat_ratio)
        return mel_segment[:, :, :target_frames]

def preprocess_waveform_segment(waveform, unit_length):
    """unit_length 기준으로 waveform을 repeat + padding 또는 crop하여 길이 정규화"""
    waveform = waveform.squeeze(0)  # (1, L) → (L,) 로 바꿔도 무방
    length_adj = unit_length - len(waveform)

    if length_adj > 0:
        # waveform이 너무 짧은 경우 → repeat + zero-padding
        half_unit = unit_length // 2

        if length_adj < half_unit:
            # 길이 차이가 작으면 단순 padding
            half_adj = length_adj // 2
            waveform = F.pad(waveform, (half_adj, length_adj - half_adj))
        else:
            # 반복 후 부족한 부분 padding
            repeat_factor = unit_length // len(waveform)
            waveform = waveform.repeat(repeat_factor)[:unit_length]
            remaining = unit_length - len(waveform)
            half_pad = remaining // 2
            waveform = F.pad(waveform, (half_pad, remaining - half_pad))
    else:
        # waveform이 너무 길면 앞쪽 1/4 내에서 랜덤 crop
        length_adj = len(waveform) - unit_length
        start = random.randint(0, length_adj // 4)
        waveform = waveform[start:start + unit_length]

    return waveform.unsqueeze(0)  # 다시 (1, L)로

# 데이터 Spec Augmentation ( 0~80% Random Masking )
def apply_spec_augment(mel_segment):
    M = mel_segment.shape[-1]
    F = mel_segment.shape[-2]

    # torchaudio의 마스킹은 0부터 mask_param까지 균등분포에서 랜덤하게 길이를 선택
    time_masking = T.TimeMasking(time_mask_param=int(M * 0.8))
    freq_masking = T.FrequencyMasking(freq_mask_param=int(F * 0.8) )

    aug1 = freq_masking(mel_segment.clone())
    aug2 = time_masking(mel_segment.clone())
    aug3 = freq_masking(time_masking(mel_segment.clone()))

    return aug1, aug2, aug3

# Waveform resample
def resample_waveform(waveform, orig_sr, target_sr=args.target_sr):
    if orig_sr != target_sr:
        resampler = torchaudio.transforms.Resample(
            orig_freq=orig_sr,
            new_freq=target_sr
        )
        return resampler(waveform), target_sr
    return waveform, orig_sr

In [13]:
def aug(repeat_mel):
    aug1, aug2, aug3 = apply_spec_augment(repeat_mel)
    return aug1, aug2, aug3

def get_timestamp():
    """Outputs current time in KST like 2404070830"""
    kst_time = datetime.now(ZoneInfo("Asia/Seoul"))
    return kst_time.strftime('%y%m%d%H%M')

#### 2.3 CycleDataset

In [14]:
import os
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm

class CycleDataset(Dataset):
    def __init__(self, filename_list, wav_dir, txt_dir, target_sec=args.target_sec, target_sr=args.target_sr, frame_size=args.frame_size, hop_length=args.hop_length, n_mels=args.n_mels):
        self.filename_list = filename_list
        self.wav_dir = wav_dir
        self.txt_dir = txt_dir
        self.target_sec = target_sec
        self.target_sr = target_sr
        self.frame_size = frame_size
        self.hop_length = hop_length
        self.n_mels = n_mels

        self.cycle_list = []

        print("[INFO] Preprocessing cycles...")
        for filename in tqdm(self.filename_list):
            txt_path = os.path.join(self.txt_dir, filename + '.txt')
            wav_path = os.path.join(self.wav_dir, filename + '.wav')

            if not os.path.exists(txt_path):
                print(f"[WARNING] Missing file: {txt_path}")
            if not os.path.exists(wav_path):
                print(f"[WARNING] Missing file: {wav_path}")

            # Load annotation
            cycle_data = np.loadtxt(txt_path, usecols=(0, 1))
            lung_label = np.loadtxt(txt_path, usecols=(2, 3))

            # Load waveform
            waveform, orig_sr = torchaudio.load(wav_path)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)  # Stereo to mono

            # Resample to target sample rate (4kHz)
            waveform, sample_rate = resample_waveform(waveform, orig_sr, self.target_sr)

            for idx in range(len(cycle_data)):
                # 호흡 주기 start, end
                start_sample = int(cycle_data[idx, 0] * sample_rate)
                end_sample = int(cycle_data[idx, 1] * sample_rate)
                lung_duration = cycle_data[idx, 1] - cycle_data[idx, 0]

                if end_sample <= start_sample:
                    continue  # 잘못된 구간 스킵

                # Waveform repeat + padding 후 Mel_db
                cycle_wave = waveform[:, start_sample:end_sample]
                normed_wave = preprocess_waveform_segment(cycle_wave, unit_length=int(self.target_sec * self.target_sr))
                mel = generate_mel_spectrogram(normed_wave, sample_rate, frame_size=self.frame_size, hop_length=self.hop_length, n_mels=self.n_mels)

                # crackle, wheeze -> class
                cr = int(lung_label[idx, 0])
                wh = int(lung_label[idx, 1])
                label = get_class(cr, wh)

                multi_label = torch.tensor([
                    float(label in [1, 3]),
                    float(label in [2, 3])
                ])  # 변환된 multi-label 반환

                # meta_data
                meta_data = (filename, lung_duration)

                self.cycle_list.append((mel, multi_label, meta_data))

        print(f"[INFO] Total cycles collected: {len(self.cycle_list)}")

    def __len__(self):
        return len(self.cycle_list)

    def __getitem__(self, idx):
        mel, label, meta_data = self.cycle_list[idx]
        return mel, label, meta_data

##### Pickle.dump

CycleDataset 객체 생성

In [15]:
# import random
# import matplotlib.pyplot as plt
# import librosa.display

# wav_dir = ROOT
# txt_dir = ROOT

# # 1. Dataset 로드
# train_dataset = CycleDataset(train_list, wav_dir, txt_dir)
# test_dataset = CycleDataset(test_list, wav_dir, txt_dir)

In [16]:
# # 2. 간단 통계
# print(f"Total cycles: {len(train_dataset)}")

# label_counter = [0] * 4  # normal, crackle, wheeze, both
# for _, multi_label,_ in train_dataset:
#     if torch.equal(multi_label, torch.tensor([0., 0.])):
#         label_counter[0] += 1
#     elif torch.equal(multi_label, torch.tensor([1., 0.])):
#         label_counter[1] += 1
#     elif torch.equal(multi_label, torch.tensor([0., 1.])):
#         label_counter[2] += 1
#     elif torch.equal(multi_label, torch.tensor([1., 1.])):
#         label_counter[3] += 1

# for idx, count in enumerate(label_counter):
#     print(f"Class {idx}: {count} cycles")

pickle로 train_dataset, test_dataset 외부 저장

In [17]:
# pickle_dict = {
#     'train_dataset': train_dataset,
#     'test_dataset': test_dataset
# }

# save_path = os.path.join(PICKLE_PATH, 'saved_datasets_multilabel.pkl')
# with open(save_path, 'wb') as f:
#     pickle.dump(pickle_dict, f)

##### Pickle.load
저장된 train_dataset, test_dataset을 로드

In [18]:
save_path = os.path.join(PICKLE_PATH, 'saved_datasets_multilabel.pkl')
with open(save_path, 'rb') as f:
    pickle_dict = pickle.load(f)

train_dataset = pickle_dict['train_dataset']
test_dataset = pickle_dict['test_dataset']

print(f"[Train] Cycles: {len(train_dataset)}")
print(f"[Test] Cycles: {len(train_dataset)}")

[Train] Cycles: 4142
[Test] Cycles: 4142


#### 2.4 DataLoader

In [19]:
seed = 42
random.seed(seed)
np.random.seed(seed)

# train_dataset 내에서 각 파일의 인덱스를 추출
pretrain_idx = []
finetune_idx = []

for i in range(len(train_dataset)):
    filename = train_dataset[i][2][0]

    if filename in pretrain_list:
        pretrain_idx.append(i)
    elif filename in finetune_list:
        finetune_idx.append(i)

# 인덱스 순서 셔플
random.shuffle(pretrain_idx)
random.shuffle(finetune_idx)

print(f"Pretrain set size: {len(pretrain_idx)}, Finetune set size: {len(finetune_idx)}")

Pretrain set size: 3257, Finetune set size: 885


코드 실행 환경에 따라 num_workers를 적절한 값으로 지정해주세요!

In [20]:
# Dataset 생성 (Subset)
pretrain_dataset = Subset(train_dataset, pretrain_idx)
finetune_dataset = Subset(train_dataset, finetune_idx)

# DataLoader 생성
# DataLoader에서 shuffle=True로 지정하면 매 epoch마다 셔플 순서가 달라짐 => 재현성 문제 발생
# pretrain_dataset, finetune_dataset은 이미 셔플이 완료된 것으로, 이것을 DataLoader에 입력함
pretrain_loader = DataLoader(
    pretrain_dataset,
    batch_size=args.batch_size,
    num_workers=2,
    drop_last=True,
    pin_memory=True,
    shuffle=False
)

finetune_loader = DataLoader(
    finetune_dataset,
    batch_size=args.batch_size,
    num_workers=2,
    drop_last=True,
    pin_memory=True,
    shuffle=False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    num_workers=2,
    pin_memory=True,
    shuffle=False
)

label 분포 확인 (단순 참고용, 실제 환경에서는 pretrain set의 label 분포가 어떤지 알 수 없음)

In [21]:
from collections import Counter

# label
labels = torch.stack([multi_label for _, multi_label, _ in train_dataset])

# pretext와 finetune 데이터셋의 라벨 분포 출력
pretrain_labels = labels[pretrain_idx]
pretrain_labels_class = (
    pretrain_labels[:, 0].long() * 1 +  # crackle bit → *1
    pretrain_labels[:, 1].long() * 2    # wheeze bit  → *2
)  # [N] shape, values in {0, 1, 2, 3}
finetune_labels = labels[finetune_idx]
finetune_labels_class = (
    finetune_labels[:, 0].long() * 1 +  # crackle bit → *1
    finetune_labels[:, 1].long() * 2    # wheeze bit  → *2
)  # [N] shape, values in {0, 1, 2, 3}

# test 데이터셋의 라벨 분포 출력
test_labels = torch.stack([multi_label for _, multi_label, _ in test_dataset])
test_labels_class = (
    test_labels[:, 0].long() * 1 +  # crackle bit → *1
    test_labels[:, 1].long() * 2    # wheeze bit  → *2
)  # [N] shape, values in {0, 1, 2, 3}

print(f"Pretrain sample: {len(pretrain_labels_class)}")
print("Pretrain label distribution:", Counter(pretrain_labels_class.tolist()))
print(f"\nFinetune sample: {len(finetune_labels_class)}")
print("Finetune label distribution:", Counter(finetune_labels_class.tolist()))
print(f"Test sample: {len(test_labels_class)}")
print("Test label distribution:", Counter(test_labels_class.tolist()))

Pretrain sample: 3257
Pretrain label distribution: Counter({0: 1607, 1: 953, 2: 417, 3: 280})

Finetune sample: 885
Finetune label distribution: Counter({0: 456, 1: 262, 2: 84, 3: 83})
Test sample: 2756
Test label distribution: Counter({0: 1579, 1: 649, 2: 385, 3: 143})


## 3. Modeling

#### 3.1 Pre-trained ResNet50

In [22]:
def backbone_resnet():
    # 1. 기본 ResNet50 생성 (pretrained=False로 시작)
    resnet = models.resnet50(pretrained=False)

    # 2. 첫 번째 conv 레이어를 1채널용으로 수정
    resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    # 먼저 fc 제거
    resnet.fc = nn.Identity()

    # 3. ImageNet 가중치 로드 (conv1 제외)
    state_dict = load_state_dict_from_url(
        'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        progress=True
    )
    if 'conv1.weight' in state_dict:
        del state_dict['conv1.weight']
    resnet.load_state_dict(state_dict, strict=False)

    return resnet

In [23]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torchvision.models as models
# from torch.hub import load_state_dict_from_url
# from torchvision.models import resnet50

# def backbone_resnet(dim=args.mlp_dim, mlp=False):
#     resnet = resnet50(weights=None, num_classes=dim)  # deprecated 대응

#     # 1채널 입력 변경
#     resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

#     # Pretrained weight 불러오기 (conv1, fc 제외)
#     state_dict = load_state_dict_from_url(
#         'https://download.pytorch.org/models/resnet50-19c8e357.pth',
#         progress=True
#     )
#     for k in ['conv1.weight', 'fc.weight', 'fc.bias']:
#         if k in state_dict:
#             del state_dict[k]
#     resnet.load_state_dict(state_dict, strict=False)

#     return resnet

In [24]:
# summary 함수 사용: (채널, 높이, 너비) 크기를 지정
summary(backbone_resnet().to(device), input_size=(1, 224, 64))

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 309MB/s]


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 64, 112, 32]           3,136
       BatchNorm2d-2          [-1, 64, 112, 32]             128
              ReLU-3          [-1, 64, 112, 32]               0
         MaxPool2d-4           [-1, 64, 56, 16]               0
            Conv2d-5           [-1, 64, 56, 16]           4,096
       BatchNorm2d-6           [-1, 64, 56, 16]             128
              ReLU-7           [-1, 64, 56, 16]               0
            Conv2d-8           [-1, 64, 56, 16]          36,864
       BatchNorm2d-9           [-1, 64, 56, 16]             128
             ReLU-10           [-1, 64, 56, 16]               0
           Conv2d-11          [-1, 256, 56, 16]          16,384
      BatchNorm2d-12          [-1, 256, 56, 16]             512
           Conv2d-13          [-1, 256, 56, 16]          16,384
      BatchNorm2d-14          [-1, 256,

#### 3.2 MoCo (MLS)

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# K: queue_g의 크기
# dim_enc: projector 통과 전, g1,g2 벡터의 차원
# dim_prj: projector 통과 후, z1,z2 벡터의 차원
class MoCo(nn.Module):
    def __init__(self, base_encoder, dim_prj=128, K=512, m=0.999, T=0.07, top_k=10, lambda_bce=0.5):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T
        self.top_k = top_k
        self.lambda_bce = lambda_bce

        self.encoder_q = base_encoder()
        self.encoder_k = base_encoder()

        dim_enc = 2048
        self.proj_head_q = nn.Sequential(
            nn.Linear(dim_enc, dim_enc),
            nn.BatchNorm1d(dim_enc),
            nn.GELU(),
            nn.Linear(dim_enc, dim_prj)
        )
        self.proj_head_k = nn.Sequential(
            nn.Linear(dim_enc, dim_enc),
            nn.BatchNorm1d(dim_enc),
            nn.GELU(),
            nn.Linear(dim_enc, dim_prj)
        )

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        self.register_buffer("queue_g", F.normalize(torch.randn(dim_enc, K), dim=0))      # g2를 정규화한 후 열 단위로 Qg에 저장
        self.register_buffer("queue_z", F.normalize(torch.randn(dim_prj, K), dim=0))      # z2를 정규화한 후 열 단위로 Qz에 저장
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))               # 현재 queue에 새로 쓸 위치(인덱스)를 추적하는 포인터 역할

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, g2, z2):
        batch_size = g2.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0
        # self.queue_g[:, ptr:ptr + batch_size] = g2.T.detach()[:, :self.queue_g.shape[1] - ptr]  # Update only available space
        # self.queue_z[:, ptr:ptr + batch_size] = z2.T.detach()[:, :self.queue_z.shape[1] - ptr]  # Update only available space
        self.queue_g[:, ptr:ptr+batch_size] = g2.T.detach()
        self.queue_z[:, ptr:ptr+batch_size] = z2.T.detach()
        self.queue_ptr[0] = (ptr + batch_size) % self.K

    def forward(self, im_q, im_k, epoch=None, warmup_epochs=10):
        # encoder_q → g1 (feature)
        g1 = F.normalize(self.encoder_q(im_q), dim=1)  # shape: [B, 2048]

        # projection head → z1
        z1 = F.normalize(self.proj_head_q(g1), dim=1)  # shape: [B, 128]

        # encoder k
        with torch.no_grad():
            self._momentum_update_key_encoder()
            g2 = F.normalize(self.encoder_k(im_k), dim=1)
            z2 = F.normalize(self.proj_head_k(g2), dim=1)

        # top-k mining
        sim_g = torch.matmul(g1, self.queue_g.clone().detach())  # [N, K]
        topk_idx = torch.topk(sim_g, self.top_k, dim=1).indices
        y = torch.zeros_like(sim_g)
        y.scatter_(1, topk_idx, 1.0)

        # logits from z1 · Qz
        sim_z = torch.matmul(z1, self.queue_z.clone().detach())
        bce_loss = F.binary_cross_entropy_with_logits(sim_z / self.T, y)

        # InfoNCE loss
        l_pos = torch.sum(z1 * z2, dim=1, keepdim=True)
        l_neg = torch.matmul(z1, self.queue_z.clone().detach())
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(logits.device)
        info_nce_loss = F.cross_entropy(logits, labels)

        # Total loss (with optional warmup)
        if epoch is not None and epoch < warmup_epochs:
            loss = info_nce_loss
        else:
            loss = info_nce_loss + self.lambda_bce * bce_loss

        self._dequeue_and_enqueue(g2, z2)

        return loss, logits, labels

In [26]:
# import torch
# import torch.nn as nn

# class MoCo(nn.Module):
#     """
#     Build a MoCo model with: a query encoder, a key encoder, and a queue
#     https://arxiv.org/abs/1911.05722
#     """

#     def __init__(self, base_encoder, dim=args.mlp_dim, K=args.K, m=args.m, T=args.T, mlp=args.mlp):
#         """
#         dim: feature dimension (default: 128)
#         K: queue size; number of negative keys (default: 8192) # original=65536
#         m: moco momentum of updating key encoder (default: 0.999)
#         T: softmax temperature (default: 0.07)
#         mlp: if True, use MLP head (default: True)
#         """
#         super(MoCo, self).__init__()

#         self.K = K
#         self.m = m
#         self.T = T

#         # create the encoders
#         # num_classes is the output fc dimension
#         self.encoder_q = base_encoder(dim=args.mlp_dim, mlp=args.mlp)
#         self.encoder_k = base_encoder(dim=args.mlp_dim, mlp=args.mlp)

#         if mlp:  # hack: brute-force replacement
#             dim_mlp = self.encoder_q.fc.weight.shape[1]
#             self.encoder_q.fc = nn.Sequential(
#                 nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
#             )
#             self.encoder_k.fc = nn.Sequential(
#                 nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
#             )

#         for param_q, param_k in zip(
#             self.encoder_q.parameters(), self.encoder_k.parameters()
#         ):
#             param_k.data.copy_(param_q.data)  # initialize
#             param_k.requires_grad = False  # not update by gradient

#         # create the queue
#         self.register_buffer("queue", torch.randn(dim, K))
#         self.queue = nn.functional.normalize(self.queue, dim=0)

#         self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

#     @torch.no_grad()
#     def _momentum_update_key_encoder(self):
#         """
#         Momentum update of the key encoder
#         """
#         for param_q, param_k in zip(
#             self.encoder_q.parameters(), self.encoder_k.parameters()
#         ):
#             param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

#     @torch.no_grad()
#     def _dequeue_and_enqueue(self, keys):
#         # gather keys before updating queue
#         keys = keys

#         batch_size = keys.shape[0]

#         ptr = int(self.queue_ptr)
#         assert self.K % batch_size == 0  # for simplicity

#         # replace the keys at ptr (dequeue and enqueue)
#         self.queue[:, ptr : ptr + batch_size] = keys.T
#         ptr = (ptr + batch_size) % self.K  # move pointer

#         self.queue_ptr[0] = ptr

#     def forward(self, im_q, im_k):
#         """
#         Input:
#             im_q: a batch of query images
#             im_k: a batch of key images
#         Output:
#             logits, targets
#         """

#         # compute query features
#         q = self.encoder_q(im_q)  # queries: NxC
#         q = nn.functional.normalize(q, dim=1)

#         # compute key features
#         with torch.no_grad():  # no gradient to keys
#             self._momentum_update_key_encoder()  # update the key encoder
#             k = self.encoder_k(im_k)  # keys: NxC
#             k = nn.functional.normalize(k, dim=1)

#         # compute logits
#         # Einstein sum is more intuitive
#         # positive logits: Nx1
#         l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) # [N, 1]
#         # negative logits: NxK
#         l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])  # [N,dim] * [dim,K] = [N,K]


#         # logits: Nx(1+K)
#         logits = torch.cat([l_pos, l_neg], dim=1) # [N, 1+K]

#         # apply temperature
#         logits /= self.T

#         # labels: positive key indicators
#         labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

#         # dequeue and enqueue
#         self._dequeue_and_enqueue(k)

#         return logits, labels

## 4. Pretrain

In [27]:
next(iter(pretrain_loader))[0][0].shape

torch.Size([1, 128, 62])

In [145]:
pretrain_project_name = f'Moco_MLS_PT_{args.batch_size}bs_top{args.top_k}_{args.lambda_bce}ld_{get_timestamp()}'

# wandb 초기화 (프로젝트명, 실험 이름 등 설정)
wandb.init(
    project="SHS_ICBHI_MLS", # 프로젝트 이름
    name=f"{pretrain_project_name}",  # 실험 이름
    config={
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    }
)

# 1. MoCo 모델 생성
model = MoCo(
    base_encoder = backbone_resnet,
    dim_prj = args.dim_prj,
    K = args.K,
    m = args.momentum,
    T = args.T,
    top_k = args.top_k,
    lambda_bce = args.lambda_bce
).cuda()

# 2. Optimizer
optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)

# 3. Cosine Scheduler
scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)

# 4. Train
# Best loss 초기화
best_loss = float('inf')
best_epoch = -1

for epoch in range(args.epochs):
    # ===============================
    # Training
    # ===============================
    model.train()
    total_train_loss = 0.0

    for i, (repeat_mel, label, _) in enumerate(pretrain_loader): # label 여기선 사용 X
        im_q, im_k, _ = aug(repeat_mel)
        im_q = im_q.cuda(device=args.gpu, non_blocking=True)
        im_k = im_k.cuda(device=args.gpu, non_blocking=True)

        optimizer.zero_grad()
        loss, output, target = model(im_q=im_q, im_k=im_k)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(pretrain_loader)
    print(f"Epoch {epoch} | Avg Train Loss: {avg_train_loss:.4f}")

    # =====================================
    # Scheduler
    # =====================================
    scheduler.step()

    # =====================================
    # Logging with wandb
    # =====================================
    current_lr = optimizer.param_groups[0]['lr']
    wandb.log({
        "epoch": epoch,
        "train_loss": avg_train_loss,
        "lr": current_lr
    })

    # =====================================
    # Checkpoint (Every 100 epochs)
    # =====================================
    if (epoch + 1) % 100 == 0:
        ckpt_path = CHECKPOINT_PATH + f"/{pretrain_project_name}_{epoch:03d}.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, ckpt_path)
        print(f"💾 Saved checkpoint to {ckpt_path}")

    # ===============================
    # Save Best Checkpoint
    # ===============================
    if avg_train_loss < best_loss:
        best_loss = avg_train_loss
        best_epoch = epoch
        best_ckpt_path = CHECKPOINT_PATH + f"/{pretrain_project_name}_best_checkpoint.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': best_loss
        }, best_ckpt_path)
        print(f"=> Saved best checkpoint (epoch: {epoch}, loss: {best_loss:.4f})")

Epoch 0 | Avg Train Loss: 7.3025
=> Saved best checkpoint (epoch: 0, loss: 7.3025)
Epoch 1 | Avg Train Loss: 6.7359
=> Saved best checkpoint (epoch: 1, loss: 6.7359)
Epoch 2 | Avg Train Loss: 6.4584
=> Saved best checkpoint (epoch: 2, loss: 6.4584)
Epoch 3 | Avg Train Loss: 5.9246
=> Saved best checkpoint (epoch: 3, loss: 5.9246)
Epoch 4 | Avg Train Loss: 5.5460
=> Saved best checkpoint (epoch: 4, loss: 5.5460)
Epoch 5 | Avg Train Loss: 5.4201
=> Saved best checkpoint (epoch: 5, loss: 5.4201)
Epoch 6 | Avg Train Loss: 5.2855
=> Saved best checkpoint (epoch: 6, loss: 5.2855)
Epoch 7 | Avg Train Loss: 5.3276
Epoch 8 | Avg Train Loss: 5.1401
=> Saved best checkpoint (epoch: 8, loss: 5.1401)
Epoch 9 | Avg Train Loss: 5.0494
=> Saved best checkpoint (epoch: 9, loss: 5.0494)
Epoch 10 | Avg Train Loss: 4.9193
=> Saved best checkpoint (epoch: 10, loss: 4.9193)
Epoch 11 | Avg Train Loss: 4.7761
=> Saved best checkpoint (epoch: 11, loss: 4.7761)
Epoch 12 | Avg Train Loss: 4.6123
=> Saved best ch

In [146]:
wandb.finish()

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇█
lr,█████████▇▇▇▇▇▇▇▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▁▁▁▁
train_loss,█▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▃▂▂▂▂▂▂▁▂▂▁▂▂▂▂▁▂▂▁▂▁▂▁▁

0,1
epoch,299.0
lr,0.0
train_loss,1.26079


## 5. Linear Evaluation

In [148]:
import os
from torch.utils.data import DataLoader
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix

finetune_project_name = f'Moco_MLS_LE_{args.batch_size}bs_top{args.top_k}_{args.lambda_bce}ld_{get_timestamp()}'

# wandb 초기화 (프로젝트명, 실험 이름 등 설정)
wandb.init(
    project="SHS_ICBHI_MLS",          # 프로젝트 이름
    name=f"{finetune_project_name}",  # 실험 이름
    config={
        "epochs": args.ft_epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    }
)

# 1. Model Load
# ckpt_path
load_ckpt_path = CHECKPOINT_PATH + f"/{pretrain_project_name}_best_checkpoint.pth.tar"
save_ckpt_path = CHECKPOINT_PATH

# Load Encoder
model_eval = MoCo(
    base_encoder=backbone_resnet,
    dim_prj = args.dim_prj,
    K = args.K,
    m = args.momentum,
    T = args.T,
    top_k = args.top_k,
    lambda_bce = args.lambda_bce
)
checkpoint = torch.load(load_ckpt_path,
                        map_location=device)  # map_location 파라미터 추가
model_eval.load_state_dict(checkpoint["state_dict"])
encoder = model_eval.encoder_q.eval().to(device)

# 2. Dataset 정의
# Dataset 정의는 이미 되어있음 - finetune_loader, test_loader

# 3. Linear Evaluation을 위한 분류 모델 정의 ( Data 개수 작으므로, encoder 파라미터 frozen )
class FineTuningModel(nn.Module):
    def __init__(self, encoder, out_dim, num_classes=2):
        super().__init__()
        self.encoder = encoder
        # 마지막 FC layer를 제외한 encoder의 모든 레이어 freeze
        for param in self.encoder.parameters():
            param.requires_grad = False

        # 새로운 분류 헤드 추가 (Crackle, Wheeze를 독립적으로 예측하므로 num_classes=2)
        # self.classifier = nn.Sequential(
        #     nn.Linear(out_dim, 256),
        #     nn.GELU(),
        #     nn.Linear(256, num_classes)
        # )
        self.classifier = nn.Linear(out_dim, num_classes)

    def forward(self, x):
        features = self.encoder(x)
        return self.classifier(features)

# 4. 모델, 손실 함수, 옵티마이저 설정
model = FineTuningModel(encoder, out_dim = args.out_dim).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.classifier.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = CosineAnnealingLR(optimizer, T_max=args.ft_epochs, eta_min=1e-6)  # Linear Evaluation에서 epochs는 다르게 적용

# Best loss 초기화
best_loss = float('inf')
best_epoch = -1

# 5. Linear Evaluation
for epoch in range(args.ft_epochs):

    model.train()
    total_loss = 0.0
    total_samples = 0

    all_preds = []
    all_labels = []

    pbar = tqdm(finetune_loader, desc='Linear Evaluation')
    for i, (cycle, labels, _) in enumerate(pbar):
        # Forward pass
        cycle = cycle.cuda(args.gpu)
        labels = labels.cuda(args.gpu)

        # backpropagation
        optimizer.zero_grad()
        output = model(cycle)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # update loss
        total_loss += loss.item()

        # 예측값과 실제값 저장
        predicted = (torch.sigmoid(output) > 0.5).float()
        all_preds.append(predicted.detach().cpu())
        all_labels.append(labels.detach().cpu())


    # train loss
    train_loss = total_loss / len(finetune_loader)

    # Concatenate
    all_preds = torch.cat(all_preds, dim=0).numpy()    # shape: [N, 2]
    all_labels = torch.cat(all_labels, dim=0).numpy()  # shape: [N, 2]

    crackle_sens = 0
    crackle_spec = 0
    wheeze_sens = 0
    wheeze_spec = 0

    for i, label_name in enumerate(['Crackle', 'Wheeze']):
        y_true = all_labels[:, i]
        y_pred = all_preds[:, i]

        cm = confusion_matrix(y_true, y_pred)  # [[TN, FP], [FN, TP]]
        TN, FP, FN, TP = cm.ravel()

        sensitivity = TP / (TP + FN + 1e-6)
        specificity = TN / (TN + FP + 1e-6)

        if i == 0:
            crackle_sens = sensitivity
            crackle_spec = specificity
        elif i == 1:
            wheeze_sens = sensitivity
            wheeze_spec = specificity

    print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}")
    print(f"  [Average] Sens: {(crackle_sens+wheeze_sens)/2:.4f}, Spec: {(crackle_spec+wheeze_spec)/2:.4f}, Score: {(crackle_sens+crackle_spec+wheeze_sens+wheeze_spec)/4:.4f}")
    print(f"  [Crackle] Sens: {crackle_sens:.4f}, Spec: {crackle_spec:.4f}, Score: {(crackle_sens+crackle_spec)/2:.4f}")
    print(f"  [Wheeze]  Sens: {wheeze_sens:.4f}, Spec: {wheeze_spec:.4f}, Score: {(wheeze_sens+wheeze_spec)/2:.4f}")

    # learning rate scheduling
    scheduler.step()

    # Save Best Checkpoint
    if train_loss < best_loss:
        best_loss = train_loss
        best_epoch = epoch
        best_ckpt_path = save_ckpt_path + f"/{finetune_project_name}_best.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': best_loss
        }, best_ckpt_path)
        print(f"=> Saved best checkpoint (epoch: {epoch+1}, loss: {best_loss:.4f})")


wandb.finish()

Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.01it/s]


Epoch: 1, Train Loss: 8.3402
  [Average] Sens: 0.2995, Spec: 0.6539, Score: 0.4767
  [Crackle] Sens: 0.4136, Spec: 0.4715, Score: 0.4425
  [Wheeze]  Sens: 0.1854, Spec: 0.8363, Score: 0.5109
=> Saved best checkpoint (epoch: 1, loss: 8.3402)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.92it/s]


Epoch: 2, Train Loss: 7.3586
  [Average] Sens: 0.2357, Spec: 0.7486, Score: 0.4922
  [Crackle] Sens: 0.3390, Spec: 0.6723, Score: 0.5056
  [Wheeze]  Sens: 0.1325, Spec: 0.8250, Score: 0.4787
=> Saved best checkpoint (epoch: 2, loss: 7.3586)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.99it/s]


Epoch: 3, Train Loss: 5.2327
  [Average] Sens: 0.3599, Spec: 0.6653, Score: 0.5126
  [Crackle] Sens: 0.4814, Spec: 0.5137, Score: 0.4975
  [Wheeze]  Sens: 0.2384, Spec: 0.8169, Score: 0.5276
=> Saved best checkpoint (epoch: 3, loss: 5.2327)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.56it/s]


Epoch: 4, Train Loss: 3.2149
  [Average] Sens: 0.1522, Spec: 0.8109, Score: 0.4815
  [Crackle] Sens: 0.1322, Spec: 0.8097, Score: 0.4710
  [Wheeze]  Sens: 0.1722, Spec: 0.8120, Score: 0.4921
=> Saved best checkpoint (epoch: 4, loss: 3.2149)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.47it/s]


Epoch: 5, Train Loss: 1.7745
  [Average] Sens: 0.5052, Spec: 0.6579, Score: 0.5816
  [Crackle] Sens: 0.5932, Spec: 0.5687, Score: 0.5810
  [Wheeze]  Sens: 0.4172, Spec: 0.7472, Score: 0.5822
=> Saved best checkpoint (epoch: 5, loss: 1.7745)


Linear Evaluation: 100%|██████████| 6/6 [00:01<00:00,  5.97it/s]


Epoch: 6, Train Loss: 1.3691
  [Average] Sens: 0.3531, Spec: 0.8212, Score: 0.5872
  [Crackle] Sens: 0.6068, Spec: 0.6765, Score: 0.6417
  [Wheeze]  Sens: 0.0993, Spec: 0.9660, Score: 0.5327
=> Saved best checkpoint (epoch: 6, loss: 1.3691)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  7.17it/s]


Epoch: 7, Train Loss: 1.1049
  [Average] Sens: 0.5016, Spec: 0.7368, Score: 0.6192
  [Crackle] Sens: 0.5661, Spec: 0.6681, Score: 0.6171
  [Wheeze]  Sens: 0.4371, Spec: 0.8055, Score: 0.6213
=> Saved best checkpoint (epoch: 7, loss: 1.1049)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.42it/s]


Epoch: 8, Train Loss: 1.1195
  [Average] Sens: 0.4511, Spec: 0.7551, Score: 0.6031
  [Crackle] Sens: 0.4915, Spec: 0.6448, Score: 0.5682
  [Wheeze]  Sens: 0.4106, Spec: 0.8655, Score: 0.6380


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.71it/s]


Epoch: 9, Train Loss: 1.1467
  [Average] Sens: 0.3889, Spec: 0.7694, Score: 0.5792
  [Crackle] Sens: 0.4136, Spec: 0.6490, Score: 0.5313
  [Wheeze]  Sens: 0.3642, Spec: 0.8898, Score: 0.6270


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.52it/s]


Epoch: 10, Train Loss: 1.0375
  [Average] Sens: 0.3990, Spec: 0.7883, Score: 0.5936
  [Crackle] Sens: 0.4271, Spec: 0.6786, Score: 0.5529
  [Wheeze]  Sens: 0.3709, Spec: 0.8979, Score: 0.6344
=> Saved best checkpoint (epoch: 10, loss: 1.0375)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.17it/s]


Epoch: 11, Train Loss: 0.9427
  [Average] Sens: 0.4125, Spec: 0.7980, Score: 0.6052
  [Crackle] Sens: 0.4475, Spec: 0.6786, Score: 0.5631
  [Wheeze]  Sens: 0.3775, Spec: 0.9173, Score: 0.6474
=> Saved best checkpoint (epoch: 11, loss: 0.9427)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.89it/s]


Epoch: 12, Train Loss: 0.8639
  [Average] Sens: 0.4276, Spec: 0.8060, Score: 0.6168
  [Crackle] Sens: 0.4712, Spec: 0.6913, Score: 0.5813
  [Wheeze]  Sens: 0.3841, Spec: 0.9206, Score: 0.6523
=> Saved best checkpoint (epoch: 12, loss: 0.8639)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.23it/s]


Epoch: 13, Train Loss: 0.7799
  [Average] Sens: 0.4693, Spec: 0.8162, Score: 0.6428
  [Crackle] Sens: 0.4949, Spec: 0.6892, Score: 0.5921
  [Wheeze]  Sens: 0.4437, Spec: 0.9433, Score: 0.6935
=> Saved best checkpoint (epoch: 13, loss: 0.7799)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.11it/s]


Epoch: 14, Train Loss: 0.7497
  [Average] Sens: 0.4712, Spec: 0.8025, Score: 0.6368
  [Crackle] Sens: 0.5119, Spec: 0.6892, Score: 0.6005
  [Wheeze]  Sens: 0.4305, Spec: 0.9157, Score: 0.6731
=> Saved best checkpoint (epoch: 14, loss: 0.7497)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.23it/s]


Epoch: 15, Train Loss: 0.7307
  [Average] Sens: 0.4532, Spec: 0.8327, Score: 0.6430
  [Crackle] Sens: 0.5356, Spec: 0.7125, Score: 0.6240
  [Wheeze]  Sens: 0.3709, Spec: 0.9530, Score: 0.6619
=> Saved best checkpoint (epoch: 15, loss: 0.7307)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.11it/s]


Epoch: 16, Train Loss: 0.7461
  [Average] Sens: 0.4798, Spec: 0.8410, Score: 0.6604
  [Crackle] Sens: 0.5424, Spec: 0.7209, Score: 0.6317
  [Wheeze]  Sens: 0.4172, Spec: 0.9611, Score: 0.6892


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.52it/s]


Epoch: 17, Train Loss: 0.8053
  [Average] Sens: 0.4662, Spec: 0.8202, Score: 0.6432
  [Crackle] Sens: 0.5085, Spec: 0.6956, Score: 0.6020
  [Wheeze]  Sens: 0.4238, Spec: 0.9449, Score: 0.6844


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.59it/s]


Epoch: 18, Train Loss: 0.8374
  [Average] Sens: 0.4528, Spec: 0.8233, Score: 0.6381
  [Crackle] Sens: 0.5017, Spec: 0.6871, Score: 0.5944
  [Wheeze]  Sens: 0.4040, Spec: 0.9595, Score: 0.6817


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.99it/s]


Epoch: 19, Train Loss: 0.8375
  [Average] Sens: 0.4528, Spec: 0.8260, Score: 0.6394
  [Crackle] Sens: 0.4949, Spec: 0.6892, Score: 0.5921
  [Wheeze]  Sens: 0.4106, Spec: 0.9627, Score: 0.6867


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.22it/s]


Epoch: 20, Train Loss: 0.8089
  [Average] Sens: 0.4343, Spec: 0.8255, Score: 0.6299
  [Crackle] Sens: 0.4712, Spec: 0.6850, Score: 0.5781
  [Wheeze]  Sens: 0.3974, Spec: 0.9660, Score: 0.6817


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.67it/s]


Epoch: 21, Train Loss: 0.7734
  [Average] Sens: 0.4443, Spec: 0.8260, Score: 0.6352
  [Crackle] Sens: 0.4780, Spec: 0.6829, Score: 0.5804
  [Wheeze]  Sens: 0.4106, Spec: 0.9692, Score: 0.6899


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.23it/s]


Epoch: 22, Train Loss: 0.7435
  [Average] Sens: 0.4379, Spec: 0.8316, Score: 0.6347
  [Crackle] Sens: 0.4983, Spec: 0.6956, Score: 0.5969
  [Wheeze]  Sens: 0.3775, Spec: 0.9676, Score: 0.6725


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.39it/s]


Epoch: 23, Train Loss: 0.7171
  [Average] Sens: 0.4579, Spec: 0.8326, Score: 0.6453
  [Crackle] Sens: 0.5119, Spec: 0.6977, Score: 0.6048
  [Wheeze]  Sens: 0.4040, Spec: 0.9676, Score: 0.6858
=> Saved best checkpoint (epoch: 23, loss: 0.7171)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.38it/s]


Epoch: 24, Train Loss: 0.6984
  [Average] Sens: 0.4580, Spec: 0.8324, Score: 0.6452
  [Crackle] Sens: 0.5186, Spec: 0.6956, Score: 0.6071
  [Wheeze]  Sens: 0.3974, Spec: 0.9692, Score: 0.6833
=> Saved best checkpoint (epoch: 24, loss: 0.6984)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.29it/s]


Epoch: 25, Train Loss: 0.6869
  [Average] Sens: 0.4647, Spec: 0.8419, Score: 0.6533
  [Crackle] Sens: 0.5254, Spec: 0.7146, Score: 0.6200
  [Wheeze]  Sens: 0.4040, Spec: 0.9692, Score: 0.6866
=> Saved best checkpoint (epoch: 25, loss: 0.6869)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.71it/s]


Epoch: 26, Train Loss: 0.6819
  [Average] Sens: 0.4766, Spec: 0.8435, Score: 0.6600
  [Crackle] Sens: 0.5492, Spec: 0.7146, Score: 0.6319
  [Wheeze]  Sens: 0.4040, Spec: 0.9724, Score: 0.6882
=> Saved best checkpoint (epoch: 26, loss: 0.6819)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.62it/s]


Epoch: 27, Train Loss: 0.6804
  [Average] Sens: 0.4733, Spec: 0.8456, Score: 0.6594
  [Crackle] Sens: 0.5492, Spec: 0.7188, Score: 0.6340
  [Wheeze]  Sens: 0.3974, Spec: 0.9724, Score: 0.6849
=> Saved best checkpoint (epoch: 27, loss: 0.6804)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.42it/s]


Epoch: 28, Train Loss: 0.6788
  [Average] Sens: 0.4783, Spec: 0.8427, Score: 0.6605
  [Crackle] Sens: 0.5525, Spec: 0.7146, Score: 0.6336
  [Wheeze]  Sens: 0.4040, Spec: 0.9708, Score: 0.6874
=> Saved best checkpoint (epoch: 28, loss: 0.6788)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.69it/s]


Epoch: 29, Train Loss: 0.6754
  [Average] Sens: 0.4766, Spec: 0.8427, Score: 0.6596
  [Crackle] Sens: 0.5492, Spec: 0.7146, Score: 0.6319
  [Wheeze]  Sens: 0.4040, Spec: 0.9708, Score: 0.6874
=> Saved best checkpoint (epoch: 29, loss: 0.6754)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.77it/s]


Epoch: 30, Train Loss: 0.6691
  [Average] Sens: 0.4732, Spec: 0.8406, Score: 0.6569
  [Crackle] Sens: 0.5424, Spec: 0.7104, Score: 0.6264
  [Wheeze]  Sens: 0.4040, Spec: 0.9708, Score: 0.6874
=> Saved best checkpoint (epoch: 30, loss: 0.6691)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.58it/s]


Epoch: 31, Train Loss: 0.6607
  [Average] Sens: 0.4664, Spec: 0.8417, Score: 0.6540
  [Crackle] Sens: 0.5288, Spec: 0.7125, Score: 0.6206
  [Wheeze]  Sens: 0.4040, Spec: 0.9708, Score: 0.6874
=> Saved best checkpoint (epoch: 31, loss: 0.6607)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.06it/s]


Epoch: 32, Train Loss: 0.6506
  [Average] Sens: 0.4664, Spec: 0.8414, Score: 0.6539
  [Crackle] Sens: 0.5288, Spec: 0.7104, Score: 0.6196
  [Wheeze]  Sens: 0.4040, Spec: 0.9724, Score: 0.6882
=> Saved best checkpoint (epoch: 32, loss: 0.6506)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.06it/s]


Epoch: 33, Train Loss: 0.6398
  [Average] Sens: 0.4630, Spec: 0.8446, Score: 0.6538
  [Crackle] Sens: 0.5220, Spec: 0.7167, Score: 0.6194
  [Wheeze]  Sens: 0.4040, Spec: 0.9724, Score: 0.6882
=> Saved best checkpoint (epoch: 33, loss: 0.6398)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.83it/s]


Epoch: 34, Train Loss: 0.6287
  [Average] Sens: 0.4696, Spec: 0.8435, Score: 0.6566
  [Crackle] Sens: 0.5220, Spec: 0.7146, Score: 0.6183
  [Wheeze]  Sens: 0.4172, Spec: 0.9724, Score: 0.6948
=> Saved best checkpoint (epoch: 34, loss: 0.6287)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.57it/s]


Epoch: 35, Train Loss: 0.6176
  [Average] Sens: 0.4645, Spec: 0.8456, Score: 0.6551
  [Crackle] Sens: 0.5119, Spec: 0.7188, Score: 0.6153
  [Wheeze]  Sens: 0.4172, Spec: 0.9724, Score: 0.6948
=> Saved best checkpoint (epoch: 35, loss: 0.6176)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.37it/s]


Epoch: 36, Train Loss: 0.6064
  [Average] Sens: 0.4730, Spec: 0.8446, Score: 0.6588
  [Crackle] Sens: 0.5288, Spec: 0.7167, Score: 0.6228
  [Wheeze]  Sens: 0.4172, Spec: 0.9724, Score: 0.6948
=> Saved best checkpoint (epoch: 36, loss: 0.6064)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.84it/s]


Epoch: 37, Train Loss: 0.5954
  [Average] Sens: 0.4764, Spec: 0.8456, Score: 0.6610
  [Crackle] Sens: 0.5356, Spec: 0.7188, Score: 0.6272
  [Wheeze]  Sens: 0.4172, Spec: 0.9724, Score: 0.6948
=> Saved best checkpoint (epoch: 37, loss: 0.5954)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.46it/s]


Epoch: 38, Train Loss: 0.5845
  [Average] Sens: 0.4898, Spec: 0.8507, Score: 0.6702
  [Crackle] Sens: 0.5492, Spec: 0.7273, Score: 0.6382
  [Wheeze]  Sens: 0.4305, Spec: 0.9741, Score: 0.7023
=> Saved best checkpoint (epoch: 38, loss: 0.5845)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.48it/s]


Epoch: 39, Train Loss: 0.5738
  [Average] Sens: 0.4881, Spec: 0.8538, Score: 0.6710
  [Crackle] Sens: 0.5458, Spec: 0.7336, Score: 0.6397
  [Wheeze]  Sens: 0.4305, Spec: 0.9741, Score: 0.7023
=> Saved best checkpoint (epoch: 39, loss: 0.5738)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.48it/s]


Epoch: 40, Train Loss: 0.5632
  [Average] Sens: 0.4864, Spec: 0.8560, Score: 0.6712
  [Crackle] Sens: 0.5424, Spec: 0.7378, Score: 0.6401
  [Wheeze]  Sens: 0.4305, Spec: 0.9741, Score: 0.7023
=> Saved best checkpoint (epoch: 40, loss: 0.5632)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.84it/s]


Epoch: 41, Train Loss: 0.5528
  [Average] Sens: 0.4830, Spec: 0.8549, Score: 0.6690
  [Crackle] Sens: 0.5356, Spec: 0.7357, Score: 0.6357
  [Wheeze]  Sens: 0.4305, Spec: 0.9741, Score: 0.7023
=> Saved best checkpoint (epoch: 41, loss: 0.5528)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.41it/s]


Epoch: 42, Train Loss: 0.5424
  [Average] Sens: 0.4846, Spec: 0.8660, Score: 0.6753
  [Crackle] Sens: 0.5322, Spec: 0.7548, Score: 0.6435
  [Wheeze]  Sens: 0.4371, Spec: 0.9773, Score: 0.7072
=> Saved best checkpoint (epoch: 42, loss: 0.5424)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.31it/s]


Epoch: 43, Train Loss: 0.5321
  [Average] Sens: 0.4863, Spec: 0.8663, Score: 0.6763
  [Crackle] Sens: 0.5356, Spec: 0.7569, Score: 0.6462
  [Wheeze]  Sens: 0.4371, Spec: 0.9757, Score: 0.7064
=> Saved best checkpoint (epoch: 43, loss: 0.5321)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.81it/s]


Epoch: 44, Train Loss: 0.5218
  [Average] Sens: 0.4863, Spec: 0.8652, Score: 0.6757
  [Crackle] Sens: 0.5288, Spec: 0.7548, Score: 0.6418
  [Wheeze]  Sens: 0.4437, Spec: 0.9757, Score: 0.7097
=> Saved best checkpoint (epoch: 44, loss: 0.5218)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.37it/s]


Epoch: 45, Train Loss: 0.5115
  [Average] Sens: 0.4997, Spec: 0.8663, Score: 0.6830
  [Crackle] Sens: 0.5492, Spec: 0.7569, Score: 0.6530
  [Wheeze]  Sens: 0.4503, Spec: 0.9757, Score: 0.7130
=> Saved best checkpoint (epoch: 45, loss: 0.5115)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.27it/s]


Epoch: 46, Train Loss: 0.5011
  [Average] Sens: 0.5065, Spec: 0.8689, Score: 0.6877
  [Crackle] Sens: 0.5627, Spec: 0.7653, Score: 0.6640
  [Wheeze]  Sens: 0.4503, Spec: 0.9724, Score: 0.7114
=> Saved best checkpoint (epoch: 46, loss: 0.5011)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.99it/s]


Epoch: 47, Train Loss: 0.4906
  [Average] Sens: 0.5082, Spec: 0.8665, Score: 0.6874
  [Crackle] Sens: 0.5661, Spec: 0.7590, Score: 0.6625
  [Wheeze]  Sens: 0.4503, Spec: 0.9741, Score: 0.7122
=> Saved best checkpoint (epoch: 47, loss: 0.4906)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.13it/s]


Epoch: 48, Train Loss: 0.4801
  [Average] Sens: 0.5099, Spec: 0.8692, Score: 0.6896
  [Crackle] Sens: 0.5695, Spec: 0.7611, Score: 0.6653
  [Wheeze]  Sens: 0.4503, Spec: 0.9773, Score: 0.7138
=> Saved best checkpoint (epoch: 48, loss: 0.4801)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.21it/s]


Epoch: 49, Train Loss: 0.4695
  [Average] Sens: 0.5133, Spec: 0.8671, Score: 0.6902
  [Crackle] Sens: 0.5763, Spec: 0.7569, Score: 0.6666
  [Wheeze]  Sens: 0.4503, Spec: 0.9773, Score: 0.7138
=> Saved best checkpoint (epoch: 49, loss: 0.4695)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.14it/s]


Epoch: 50, Train Loss: 0.4590
  [Average] Sens: 0.5167, Spec: 0.8695, Score: 0.6931
  [Crackle] Sens: 0.5831, Spec: 0.7632, Score: 0.6731
  [Wheeze]  Sens: 0.4503, Spec: 0.9757, Score: 0.7130
=> Saved best checkpoint (epoch: 50, loss: 0.4590)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  7.54it/s]


Epoch: 51, Train Loss: 0.4487
  [Average] Sens: 0.5217, Spec: 0.8747, Score: 0.6982
  [Crackle] Sens: 0.5864, Spec: 0.7738, Score: 0.6801
  [Wheeze]  Sens: 0.4570, Spec: 0.9757, Score: 0.7163
=> Saved best checkpoint (epoch: 51, loss: 0.4487)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  7.54it/s]


Epoch: 52, Train Loss: 0.4386
  [Average] Sens: 0.5251, Spec: 0.8821, Score: 0.7036
  [Crackle] Sens: 0.5932, Spec: 0.7886, Score: 0.6909
  [Wheeze]  Sens: 0.4570, Spec: 0.9757, Score: 0.7163
=> Saved best checkpoint (epoch: 52, loss: 0.4386)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  7.61it/s]


Epoch: 53, Train Loss: 0.4288
  [Average] Sens: 0.5335, Spec: 0.8843, Score: 0.7089
  [Crackle] Sens: 0.6034, Spec: 0.7928, Score: 0.6981
  [Wheeze]  Sens: 0.4636, Spec: 0.9757, Score: 0.7196
=> Saved best checkpoint (epoch: 53, loss: 0.4288)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.08it/s]


Epoch: 54, Train Loss: 0.4194
  [Average] Sens: 0.5470, Spec: 0.8895, Score: 0.7183
  [Crackle] Sens: 0.6305, Spec: 0.8034, Score: 0.7169
  [Wheeze]  Sens: 0.4636, Spec: 0.9757, Score: 0.7196
=> Saved best checkpoint (epoch: 54, loss: 0.4194)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.42it/s]


Epoch: 55, Train Loss: 0.4107
  [Average] Sens: 0.5554, Spec: 0.8927, Score: 0.7241
  [Crackle] Sens: 0.6407, Spec: 0.8097, Score: 0.7252
  [Wheeze]  Sens: 0.4702, Spec: 0.9757, Score: 0.7229
=> Saved best checkpoint (epoch: 55, loss: 0.4107)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.83it/s]


Epoch: 56, Train Loss: 0.4028
  [Average] Sens: 0.5571, Spec: 0.8969, Score: 0.7270
  [Crackle] Sens: 0.6373, Spec: 0.8182, Score: 0.7277
  [Wheeze]  Sens: 0.4768, Spec: 0.9757, Score: 0.7263
=> Saved best checkpoint (epoch: 56, loss: 0.4028)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.03it/s]


Epoch: 57, Train Loss: 0.3959
  [Average] Sens: 0.5537, Spec: 0.9001, Score: 0.7269
  [Crackle] Sens: 0.6305, Spec: 0.8245, Score: 0.7275
  [Wheeze]  Sens: 0.4768, Spec: 0.9757, Score: 0.7263
=> Saved best checkpoint (epoch: 57, loss: 0.3959)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.50it/s]


Epoch: 58, Train Loss: 0.3901
  [Average] Sens: 0.5571, Spec: 0.9025, Score: 0.7298
  [Crackle] Sens: 0.6373, Spec: 0.8309, Score: 0.7341
  [Wheeze]  Sens: 0.4768, Spec: 0.9741, Score: 0.7254
=> Saved best checkpoint (epoch: 58, loss: 0.3901)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.44it/s]


Epoch: 59, Train Loss: 0.3852
  [Average] Sens: 0.5655, Spec: 0.9033, Score: 0.7344
  [Crackle] Sens: 0.6475, Spec: 0.8309, Score: 0.7392
  [Wheeze]  Sens: 0.4834, Spec: 0.9757, Score: 0.7296
=> Saved best checkpoint (epoch: 59, loss: 0.3852)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.63it/s]


Epoch: 60, Train Loss: 0.3811
  [Average] Sens: 0.5621, Spec: 0.9001, Score: 0.7311
  [Crackle] Sens: 0.6475, Spec: 0.8245, Score: 0.7360
  [Wheeze]  Sens: 0.4768, Spec: 0.9757, Score: 0.7263
=> Saved best checkpoint (epoch: 60, loss: 0.3811)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.47it/s]


Epoch: 61, Train Loss: 0.3776
  [Average] Sens: 0.5671, Spec: 0.8999, Score: 0.7335
  [Crackle] Sens: 0.6508, Spec: 0.8224, Score: 0.7366
  [Wheeze]  Sens: 0.4834, Spec: 0.9773, Score: 0.7304
=> Saved best checkpoint (epoch: 61, loss: 0.3776)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.86it/s]


Epoch: 62, Train Loss: 0.3747
  [Average] Sens: 0.5772, Spec: 0.9038, Score: 0.7405
  [Crackle] Sens: 0.6644, Spec: 0.8288, Score: 0.7466
  [Wheeze]  Sens: 0.4901, Spec: 0.9789, Score: 0.7345
=> Saved best checkpoint (epoch: 62, loss: 0.3747)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.60it/s]


Epoch: 63, Train Loss: 0.3721
  [Average] Sens: 0.5807, Spec: 0.9057, Score: 0.7432
  [Crackle] Sens: 0.6780, Spec: 0.8309, Score: 0.7544
  [Wheeze]  Sens: 0.4834, Spec: 0.9806, Score: 0.7320
=> Saved best checkpoint (epoch: 63, loss: 0.3721)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.44it/s]


Epoch: 64, Train Loss: 0.3702
  [Average] Sens: 0.5841, Spec: 0.9070, Score: 0.7456
  [Crackle] Sens: 0.6847, Spec: 0.8351, Score: 0.7599
  [Wheeze]  Sens: 0.4834, Spec: 0.9789, Score: 0.7312
=> Saved best checkpoint (epoch: 64, loss: 0.3702)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.52it/s]


Epoch: 65, Train Loss: 0.3687
  [Average] Sens: 0.5892, Spec: 0.9070, Score: 0.7481
  [Crackle] Sens: 0.6949, Spec: 0.8351, Score: 0.7650
  [Wheeze]  Sens: 0.4834, Spec: 0.9789, Score: 0.7312
=> Saved best checkpoint (epoch: 65, loss: 0.3687)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.18it/s]


Epoch: 66, Train Loss: 0.3678
  [Average] Sens: 0.5809, Spec: 0.9064, Score: 0.7437
  [Crackle] Sens: 0.6915, Spec: 0.8372, Score: 0.7644
  [Wheeze]  Sens: 0.4702, Spec: 0.9757, Score: 0.7229
=> Saved best checkpoint (epoch: 66, loss: 0.3678)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.61it/s]


Epoch: 67, Train Loss: 0.3674
  [Average] Sens: 0.5676, Spec: 0.9073, Score: 0.7374
  [Crackle] Sens: 0.6915, Spec: 0.8372, Score: 0.7644
  [Wheeze]  Sens: 0.4437, Spec: 0.9773, Score: 0.7105
=> Saved best checkpoint (epoch: 67, loss: 0.3674)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.74it/s]


Epoch: 68, Train Loss: 0.3674
  [Average] Sens: 0.5742, Spec: 0.9091, Score: 0.7416
  [Crackle] Sens: 0.6915, Spec: 0.8457, Score: 0.7686
  [Wheeze]  Sens: 0.4570, Spec: 0.9724, Score: 0.7147
=> Saved best checkpoint (epoch: 68, loss: 0.3674)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.12it/s]


Epoch: 69, Train Loss: 0.3672
  [Average] Sens: 0.5924, Spec: 0.9075, Score: 0.7500
  [Crackle] Sens: 0.6881, Spec: 0.8393, Score: 0.7637
  [Wheeze]  Sens: 0.4967, Spec: 0.9757, Score: 0.7362
=> Saved best checkpoint (epoch: 69, loss: 0.3672)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.71it/s]


Epoch: 70, Train Loss: 0.3656
  [Average] Sens: 0.5990, Spec: 0.9077, Score: 0.7534
  [Crackle] Sens: 0.6881, Spec: 0.8478, Score: 0.7680
  [Wheeze]  Sens: 0.5099, Spec: 0.9676, Score: 0.7388
=> Saved best checkpoint (epoch: 70, loss: 0.3656)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.38it/s]


Epoch: 71, Train Loss: 0.3620
  [Average] Sens: 0.6239, Spec: 0.9039, Score: 0.7639
  [Crackle] Sens: 0.6915, Spec: 0.8436, Score: 0.7675
  [Wheeze]  Sens: 0.5563, Spec: 0.9643, Score: 0.7603
=> Saved best checkpoint (epoch: 71, loss: 0.3620)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.25it/s]


Epoch: 72, Train Loss: 0.3575
  [Average] Sens: 0.6057, Spec: 0.9048, Score: 0.7552
  [Crackle] Sens: 0.6881, Spec: 0.8436, Score: 0.7658
  [Wheeze]  Sens: 0.5232, Spec: 0.9660, Score: 0.7446
=> Saved best checkpoint (epoch: 72, loss: 0.3575)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.32it/s]


Epoch: 73, Train Loss: 0.3540
  [Average] Sens: 0.5990, Spec: 0.9104, Score: 0.7547
  [Crackle] Sens: 0.6881, Spec: 0.8436, Score: 0.7658
  [Wheeze]  Sens: 0.5099, Spec: 0.9773, Score: 0.7436
=> Saved best checkpoint (epoch: 73, loss: 0.3540)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.69it/s]


Epoch: 74, Train Loss: 0.3525
  [Average] Sens: 0.5924, Spec: 0.9121, Score: 0.7522
  [Crackle] Sens: 0.6881, Spec: 0.8436, Score: 0.7658
  [Wheeze]  Sens: 0.4967, Spec: 0.9806, Score: 0.7386
=> Saved best checkpoint (epoch: 74, loss: 0.3525)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.25it/s]


Epoch: 75, Train Loss: 0.3525
  [Average] Sens: 0.5924, Spec: 0.9136, Score: 0.7530
  [Crackle] Sens: 0.6881, Spec: 0.8499, Score: 0.7690
  [Wheeze]  Sens: 0.4967, Spec: 0.9773, Score: 0.7370
=> Saved best checkpoint (epoch: 75, loss: 0.3525)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.08it/s]


Epoch: 76, Train Loss: 0.3518
  [Average] Sens: 0.5990, Spec: 0.9136, Score: 0.7563
  [Crackle] Sens: 0.6881, Spec: 0.8499, Score: 0.7690
  [Wheeze]  Sens: 0.5099, Spec: 0.9773, Score: 0.7436
=> Saved best checkpoint (epoch: 76, loss: 0.3518)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.27it/s]


Epoch: 77, Train Loss: 0.3505
  [Average] Sens: 0.5924, Spec: 0.9152, Score: 0.7538
  [Crackle] Sens: 0.6881, Spec: 0.8499, Score: 0.7690
  [Wheeze]  Sens: 0.4967, Spec: 0.9806, Score: 0.7386
=> Saved best checkpoint (epoch: 77, loss: 0.3505)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.62it/s]


Epoch: 78, Train Loss: 0.3499
  [Average] Sens: 0.5907, Spec: 0.9150, Score: 0.7528
  [Crackle] Sens: 0.6847, Spec: 0.8478, Score: 0.7663
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 78, loss: 0.3499)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.65it/s]


Epoch: 79, Train Loss: 0.3495
  [Average] Sens: 0.5907, Spec: 0.9134, Score: 0.7520
  [Crackle] Sens: 0.6847, Spec: 0.8478, Score: 0.7663
  [Wheeze]  Sens: 0.4967, Spec: 0.9789, Score: 0.7378
=> Saved best checkpoint (epoch: 79, loss: 0.3495)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.78it/s]


Epoch: 80, Train Loss: 0.3489
  [Average] Sens: 0.5924, Spec: 0.9160, Score: 0.7542
  [Crackle] Sens: 0.6881, Spec: 0.8499, Score: 0.7690
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 80, loss: 0.3489)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.41it/s]


Epoch: 81, Train Loss: 0.3484
  [Average] Sens: 0.5907, Spec: 0.9152, Score: 0.7530
  [Crackle] Sens: 0.6847, Spec: 0.8499, Score: 0.7673
  [Wheeze]  Sens: 0.4967, Spec: 0.9806, Score: 0.7386
=> Saved best checkpoint (epoch: 81, loss: 0.3484)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.64it/s]


Epoch: 82, Train Loss: 0.3480
  [Average] Sens: 0.5941, Spec: 0.9163, Score: 0.7552
  [Crackle] Sens: 0.6915, Spec: 0.8520, Score: 0.7718
  [Wheeze]  Sens: 0.4967, Spec: 0.9806, Score: 0.7386
=> Saved best checkpoint (epoch: 82, loss: 0.3480)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.60it/s]


Epoch: 83, Train Loss: 0.3476
  [Average] Sens: 0.5941, Spec: 0.9160, Score: 0.7551
  [Crackle] Sens: 0.6915, Spec: 0.8499, Score: 0.7707
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 83, loss: 0.3476)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.49it/s]


Epoch: 84, Train Loss: 0.3473
  [Average] Sens: 0.5958, Spec: 0.9160, Score: 0.7559
  [Crackle] Sens: 0.6949, Spec: 0.8499, Score: 0.7724
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 84, loss: 0.3473)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.59it/s]


Epoch: 85, Train Loss: 0.3470
  [Average] Sens: 0.5958, Spec: 0.9160, Score: 0.7559
  [Crackle] Sens: 0.6949, Spec: 0.8499, Score: 0.7724
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 85, loss: 0.3470)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.07it/s]


Epoch: 86, Train Loss: 0.3468
  [Average] Sens: 0.5958, Spec: 0.9160, Score: 0.7559
  [Crackle] Sens: 0.6949, Spec: 0.8499, Score: 0.7724
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 86, loss: 0.3468)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.50it/s]


Epoch: 87, Train Loss: 0.3465
  [Average] Sens: 0.5992, Spec: 0.9160, Score: 0.7576
  [Crackle] Sens: 0.7017, Spec: 0.8499, Score: 0.7758
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 87, loss: 0.3465)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.72it/s]


Epoch: 88, Train Loss: 0.3463
  [Average] Sens: 0.5992, Spec: 0.9160, Score: 0.7576
  [Crackle] Sens: 0.7017, Spec: 0.8499, Score: 0.7758
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 88, loss: 0.3463)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.55it/s]


Epoch: 89, Train Loss: 0.3461
  [Average] Sens: 0.5975, Spec: 0.9160, Score: 0.7568
  [Crackle] Sens: 0.6983, Spec: 0.8499, Score: 0.7741
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 89, loss: 0.3461)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.72it/s]


Epoch: 90, Train Loss: 0.3459
  [Average] Sens: 0.5958, Spec: 0.9160, Score: 0.7559
  [Crackle] Sens: 0.6949, Spec: 0.8499, Score: 0.7724
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 90, loss: 0.3459)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.31it/s]


Epoch: 91, Train Loss: 0.3458
  [Average] Sens: 0.5958, Spec: 0.9160, Score: 0.7559
  [Crackle] Sens: 0.6949, Spec: 0.8499, Score: 0.7724
  [Wheeze]  Sens: 0.4967, Spec: 0.9822, Score: 0.7394
=> Saved best checkpoint (epoch: 91, loss: 0.3458)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.39it/s]


Epoch: 92, Train Loss: 0.3457
  [Average] Sens: 0.5991, Spec: 0.9160, Score: 0.7576
  [Crackle] Sens: 0.6949, Spec: 0.8499, Score: 0.7724
  [Wheeze]  Sens: 0.5033, Spec: 0.9822, Score: 0.7427
=> Saved best checkpoint (epoch: 92, loss: 0.3457)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.43it/s]


Epoch: 93, Train Loss: 0.3456
  [Average] Sens: 0.6008, Spec: 0.9160, Score: 0.7584
  [Crackle] Sens: 0.6983, Spec: 0.8499, Score: 0.7741
  [Wheeze]  Sens: 0.5033, Spec: 0.9822, Score: 0.7427
=> Saved best checkpoint (epoch: 93, loss: 0.3456)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.27it/s]


Epoch: 94, Train Loss: 0.3455
  [Average] Sens: 0.6008, Spec: 0.9160, Score: 0.7584
  [Crackle] Sens: 0.6983, Spec: 0.8499, Score: 0.7741
  [Wheeze]  Sens: 0.5033, Spec: 0.9822, Score: 0.7427
=> Saved best checkpoint (epoch: 94, loss: 0.3455)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.45it/s]


Epoch: 95, Train Loss: 0.3454
  [Average] Sens: 0.6008, Spec: 0.9160, Score: 0.7584
  [Crackle] Sens: 0.6983, Spec: 0.8499, Score: 0.7741
  [Wheeze]  Sens: 0.5033, Spec: 0.9822, Score: 0.7427
=> Saved best checkpoint (epoch: 95, loss: 0.3454)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.70it/s]


Epoch: 96, Train Loss: 0.3454
  [Average] Sens: 0.6008, Spec: 0.9160, Score: 0.7584
  [Crackle] Sens: 0.6983, Spec: 0.8499, Score: 0.7741
  [Wheeze]  Sens: 0.5033, Spec: 0.9822, Score: 0.7427
=> Saved best checkpoint (epoch: 96, loss: 0.3454)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.61it/s]


Epoch: 97, Train Loss: 0.3453
  [Average] Sens: 0.6008, Spec: 0.9160, Score: 0.7584
  [Crackle] Sens: 0.6983, Spec: 0.8499, Score: 0.7741
  [Wheeze]  Sens: 0.5033, Spec: 0.9822, Score: 0.7427
=> Saved best checkpoint (epoch: 97, loss: 0.3453)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00, 10.16it/s]


Epoch: 98, Train Loss: 0.3453
  [Average] Sens: 0.6008, Spec: 0.9160, Score: 0.7584
  [Crackle] Sens: 0.6983, Spec: 0.8499, Score: 0.7741
  [Wheeze]  Sens: 0.5033, Spec: 0.9822, Score: 0.7427
=> Saved best checkpoint (epoch: 98, loss: 0.3453)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  8.85it/s]


Epoch: 99, Train Loss: 0.3452
  [Average] Sens: 0.6008, Spec: 0.9160, Score: 0.7584
  [Crackle] Sens: 0.6983, Spec: 0.8499, Score: 0.7741
  [Wheeze]  Sens: 0.5033, Spec: 0.9822, Score: 0.7427
=> Saved best checkpoint (epoch: 99, loss: 0.3452)


Linear Evaluation: 100%|██████████| 6/6 [00:00<00:00,  9.84it/s]


Epoch: 100, Train Loss: 0.3452
  [Average] Sens: 0.6008, Spec: 0.9160, Score: 0.7584
  [Crackle] Sens: 0.6983, Spec: 0.8499, Score: 0.7741
  [Wheeze]  Sens: 0.5033, Spec: 0.9822, Score: 0.7427
=> Saved best checkpoint (epoch: 100, loss: 0.3452)


## 6. Test

In [149]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns

test_project_name = f'Moco_MLS_TT_{args.batch_size}bs_top{args.top_k}_{args.lambda_bce}ld_{get_timestamp()}'

# wandb 초기화 (프로젝트명, 실험 이름 등 설정)
wandb.init(
    project="SHS_ICBHI_MLS",          # 프로젝트 이름
    name=f"{test_project_name}"       # 실험 이름
)

# Model Load
# ckpt_path
load_ckpt_path = CHECKPOINT_PATH + f"/{pretrain_project_name}_best_checkpoint.pth.tar"
save_ckpt_path = CHECKPOINT_PATH

# Load Encoder
model_eval = MoCo(base_encoder=backbone_resnet)
checkpoint = torch.load(load_ckpt_path,
                        map_location=device)  # map_location 파라미터 추가
model_eval.load_state_dict(checkpoint["state_dict"])
encoder = model_eval.encoder_q.eval().to(device)

model = FineTuningModel(encoder, out_dim = args.out_dim).to(device)

# 저장된 체크포인트 로드
best_ckpt_path = CHECKPOINT_PATH + f"/{finetune_project_name}_best.pth.tar"
checkpoint = torch.load(best_ckpt_path, map_location=device)

# 모델 가중치 로드 및 평가 모드 전환
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# Test 평가 함수 (multi-label 대응)
def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels, _ in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()

            predicted = (torch.sigmoid(outputs) > 0.5).float()
            all_preds.append(predicted.cpu())
            all_labels.append(labels.cpu())

    all_preds = torch.cat(all_preds, dim=0).numpy()   # [N, 2]
    all_labels = torch.cat(all_labels, dim=0).numpy() # [N, 2]

    avg_loss = running_loss / len(val_loader)

    # 개별 label별 성능 계산
    results = {}
    for i, lbl_name in enumerate(['Crackle', 'Wheeze']):
        y_true = all_labels[:, i]
        y_pred = all_preds[:, i]

        cm = confusion_matrix(y_true, y_pred)
        TN, FP, FN, TP = cm.ravel()
        sens = TP / (TP + FN + 1e-6)
        spec = TN / (TN + FP + 1e-6)
        score = (sens + spec) / 2

        results[lbl_name] = {
            'sensitivity': sens,
            'specificity': spec,
            'ICBHI score': score
        }

    # 평균 점수
    avg_results = {
        key: sum([results[cls][key] for cls in results]) / 2
        for key in ['sensitivity', 'specificity', 'ICBHI score']
    }

    return avg_loss, avg_results, results, all_labels, all_preds


# wandb 로깅용 confusion matrix 함수 (label별)
# 이걸 비율로 바꿀 필요가 있음
def log_confusion_matrix_multilabel(y_true, y_pred, label_names):
    for i, label_name in enumerate(label_names):
        y_t = y_true[:, i]
        y_p = y_pred[:, i]
        cm = confusion_matrix(y_t, y_p)
        fig, ax = plt.subplots(figsize=(4, 4))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['0', '1'], yticklabels=['0', '1'], ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(f'Confusion Matrix - {label_name}')
        plt.tight_layout()
        wandb.log({f"confusion_matrix_{label_name}": wandb.Image(fig)})
        plt.close(fig)


test_loss, avg_results, label_results, y_true, y_pred = validate(model, test_loader, criterion, device)

print(f"[Test] Loss: {test_loss:.4f}")
for lbl in ['Crackle', 'Wheeze']:
    r = label_results[lbl]
    print(f"  [{lbl}] Sens: {r['sensitivity']:.4f}, Spec: {r['specificity']:.4f}, ICBHI Score: {r['ICBHI score']:.4f}")

# 평균 성능 출력
print(f"  [Average] Sens: {avg_results['sensitivity']:.4f}, Spec: {avg_results['specificity']:.4f}, ICBHI Score: {avg_results['ICBHI score']:.4f}")

# wandb 로그
wandb.log({
    "Test/loss": test_loss,
    "Test/sens": avg_results["sensitivity"],
    "Test/spec": avg_results["specificity"],
    "Test/Score": avg_results["ICBHI score"]
})

# Confusion matrix wandb 이미지로 로그
log_confusion_matrix_multilabel(y_true, y_pred, label_names=["Crackle", "Wheeze"])



[Test] Loss: 2.6113
  [Crackle] Sens: 0.4015, Spec: 0.7576, ICBHI Score: 0.5796
  [Wheeze] Sens: 0.3220, Spec: 0.7689, ICBHI Score: 0.5454
  [Average] Sens: 0.3617, Spec: 0.7632, ICBHI Score: 0.5625


In [150]:
def multilabel_to_multiclass(labels):
    """
    Crackle, Wheeze → 4개의 클래스로 변환
    [0,0] → 0 (Normal)
    [1,0] → 1 (Crackle)
    [0,1] → 2 (Wheeze)
    [1,1] → 3 (Both)
    """
    labels = labels.cpu().numpy() if torch.is_tensor(labels) else labels
    return (labels[:, 0] * 1 + labels[:, 1] * 2).astype(int)

Confusion Matrix를 비율로 바꿀 필요가 있음

In [151]:
def evaluate_multiclass_confusion(y_true, y_pred, class_names=["Normal", "Crackle", "Wheeze", "Both"]):
    # 4-class confusion matrix
    y_true_cls = multilabel_to_multiclass(y_true)
    y_pred_cls = multilabel_to_multiclass(y_pred)

    conf_matrix = confusion_matrix(y_true_cls, y_pred_cls, labels=[0,1,2,3])  # 4x4 matrix

    # Positive: 1,2,3 / Negative: 0
    TP = conf_matrix[1:, 1:].sum()    # 양성 중에 양성으로 예측
    FN = conf_matrix[1:, 0].sum()     # 양성인데 음성으로 예측
    FP = conf_matrix[0, 1:].sum()     # 음성인데 양성으로 예측
    TN = conf_matrix[0, 0]            # 음성인데 양성으로 예측

    sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0    # 민감도
    specificity = TN / (TN + FP) if (TN + FP) > 0 else 0    # 특이도

    return conf_matrix, sensitivity, specificity


def log_multiclass_confusion_matrix_wandb(conf_matrix, class_names, sensitivity, specificity):

    fig, ax = plt.subplots(figsize=(7, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Multi-Class Confusion Matrix')

    plt.text(
            0.99, 0.16,  # 우하단 (x=99%, y=16%) 위치
            f"Sensitivity: {sensitivity*100:.2f}\nSpecificity: {specificity*100:.2f}\nICBHI Score: {100*(sensitivity+specificity)/2:.2f}",
            ha='right', va='bottom',
            transform=plt.gca().transAxes,  # 축 기준 좌표로 해석
            fontsize=10, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
        )

    plt.tight_layout()
    wandb.log({"multiclass_confusion_matrix": wandb.Image(fig)})
    plt.close(fig)


# 4개 class에 대한 confusion matrix 계산 및 성능 평가
class_names = ["Normal", "Crackle", "Wheeze", "Both"]
cm_4x4, sens, spec = evaluate_multiclass_confusion(y_true, y_pred, class_names)

print("4x4 Confusion Matrix:\n", cm_4x4)
print(f"Sensitivity: {sens:.4f}")
print(f"Specificity: {spec:.4f}")
print(f"ICBHI Score: {(sens+spec)/2:.4f}")

# wandb 로깅
log_multiclass_confusion_matrix_wandb(cm_4x4, class_names, sens, spec)

wandb.log({
    "Metrics/sensitivity_4class": sens,
    "Metrics/specificity_4class": spec,
    "Metrics/ICHBI_score_4class": (sens+spec)/2
})

wandb.finish()

4x4 Confusion Matrix:
 [[843 311 348  77]
 [318 241  51  39]
 [182  66 115  22]
 [ 74  36  31   2]]
Sensitivity: 0.5123
Specificity: 0.5339
ICBHI Score: 0.5231


0,1
Metrics/ICHBI_score_4class,▁
Metrics/sensitivity_4class,▁
Metrics/specificity_4class,▁
Test/Score,▁
Test/loss,▁
Test/sens,▁
Test/spec,▁

0,1
Metrics/ICHBI_score_4class,0.5231
Metrics/sensitivity_4class,0.51232
Metrics/specificity_4class,0.53388
Test/Score,0.56249
Test/loss,2.61133
Test/sens,0.36174
Test/spec,0.76324


## 7. t-SNE Visualization (작업중)

In [152]:
# import torch
# import matplotlib.pyplot as plt
# from sklearn.manifold import TSNE
# from tqdm import tqdm

# # Multi-label → Multi-class 레이블 변환 함수
# def multilabel_to_multiclass(labels):
#     """
#     Crackle, Wheeze → 4개의 클래스로 변환
#     [0,0] → 0 (Normal)
#     [1,0] → 1 (Crackle)
#     [0,1] → 2 (Wheeze)
#     [1,1] → 3 (Both)
#     """
#     labels = labels.cpu().numpy() if torch.is_tensor(labels) else labels
#     return (labels[:, 0] * 1 + labels[:, 1] * 2).astype(int)


# # t-SNE를 위한 feature 추출 함수
# @torch.no_grad()
# def extract_features(encoder, dataloader, device):
#     features = []
#     labels = []

#     for x, label in tqdm(dataloader, desc="Extracting features"):
#         x = x.to(device)
#         out = encoder(x)
#         out = torch.nn.functional.normalize(out, dim=1)  # L2 정규화
#         features.append(out.cpu())
#         labels.append(label.cpu())

#     features = torch.cat(features, dim=0).numpy()
#     labels = torch.cat(labels, dim=0)
#     return features, labels


# # t-SNE 시각화 함수
# def plot_tsne(features, labels, num_classes, sensitivity, specificity, title="t-SNE Visualization"):
#     labels_cls = multilabel_to_multiclass(labels)

#     tsne = TSNE(n_components=2, random_state=42, perplexity=30)
#     reduced = tsne.fit_transform(features)

#     plt.figure(figsize=(10, 8))

#     label_names = ["Normal", "Crackle", "Wheeze", "Both"]
#     for i in range(num_classes):
#         idx = labels_cls == i
#         plt.scatter(reduced[idx, 0], reduced[idx, 1], label=label_names[i], alpha=0.6)

#     plt.text(
#         0.95, 0.1,
#         f"Sensitivity: {sensitivity*100:.2f}\nSpecificity: {specificity*100:.2f}\nICBHI Score: {(sensitivity + specificity)*50:.2f}",
#         ha='right', va='bottom',
#         transform=plt.gca().transAxes,
#         fontsize=10, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
#     )

#     plt.legend()
#     plt.title(title)
#     plt.xlabel("Dim 1")
#     plt.ylabel("Dim 2")
#     plt.grid(True)
#     plt.show()


# # 전체 파이프라인 실행 예시
# encoder = model.base_model.eval().to(device)

# # (1) Test 데이터에 대해 Feature 및 Multi-label 정답 추출
# features, labels = extract_features(encoder, test_dl, device)

# # (2) t-SNE 시각화
# plot_tsne(
#     features,
#     labels,
#     num_classes=4,
#     sensitivity=sensitivity,  # 앞서 계산된 값 사용
#     specificity=specificity,
#     title="t-SNE Visualization of Test Data"
# )