#### 환경설정

##### 1. Wandb

In [None]:
import wandb

# wandb 로그인
wandb.login(key="")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/HyeonSeok/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mvanillahub12[0m ([33mboaz_woony-boaz[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

##### 2. 라이브러리 로드

In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
import os
import math
import random
import pickle
import wandb
from tqdm import tqdm
from datetime import datetime
from zoneinfo import ZoneInfo

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import librosa
import librosa.display

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import torchvision
import torchvision.models as models
from torch import Tensor
from torchsummary import summary
from torch.hub import load_state_dict_from_url
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.lr_scheduler import CosineAnnealingLR

from sklearn.metrics import confusion_matrix, f1_score
from sklearn.manifold import TSNE

##### 3. 경로 설정

In [4]:
ROOT = "/home/HyeonSeok/BOAZ-Chungzins/data/raw"
CHECKPOINT_PATH = "/home/HyeonSeok/BOAZ-Chungzins/save_path/checkpoint"
PICKLE_PATH = "/home/HyeonSeok/BOAZ-Chungzins/save_path/pickle"
text = "/home/HyeonSeok/BOAZ-Chungzins/data/metadata/train_test_split.txt"

demo_info = "/home/HyeonSeok/demographic_info.txt"

##### 4. Seed 설정

In [5]:
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = False  # type: ignore

seed_everything(42) # Seed 고정

## 1. Data Load

#### 1-1. Demographic Info

In [6]:
demo_df = pd.read_csv(demo_info, sep=' ', header=None).iloc[:,0:3]
demo_df.columns = ['patient', 'age', 'gender']

유일하게 나이, 성별에 결측값이 존재하는 환자 -> 223번 환자는 test data에만 존재

In [7]:
demo_df.loc[122]

patient    223
age        NaN
gender     NaN
Name: 122, dtype: object

0-18세는 소아청소년, 18세 이상은 성인으로 이진분류

In [8]:
def categorize_age(age):
    if age <= 18:
        return 1
    else:   # NaN도 0에 할당 (0이 최빈값이므로)
        return 0

demo_df['age'] = demo_df['age'].apply(categorize_age)
demo_df.columns = ['patient', 'child', 'gender']

NaN 제외, child 48명, adult 77명

In [9]:
demo_df['child'].value_counts()

child
0    77
1    49
Name: count, dtype: int64

In [10]:
demo_df['gender'] = demo_df['gender'].apply(lambda x: 1 if x=='F' else 0)   # NaN도 0에 할당 (남자가 최빈값이므로)

demo_df.columns = ['patient', 'child', 'female']
demo_df['female'].value_counts()

female
0    80
1    46
Name: count, dtype: int64

In [11]:
demo_dict = demo_df.set_index('patient')[['child', 'female']].apply(list, axis=1).to_dict()

#### 1.2 Train-Test Split

In [12]:
# WAV 파일이 있는 디렉토리 경로
data_dir = ROOT
txt_dir = ROOT

df = pd.read_csv(text, sep='\t', header=None)

# 컬럼 이름 변경
df.columns = ['filename', 'set']

In [13]:
for i in range(len(df)):
    filename = int(df.iloc[i,0].split('_')[0])

    if demo_dict[filename][0] == 1:
        df.loc[i,'child'] = 1
    else:
        df.loc[i,'child'] = 0
    
    if demo_dict[filename][1] == 1:
        df.loc[i,'female'] = 1
    else:
        df.loc[i,'female'] = 0

In [14]:
df.loc[40:50]

Unnamed: 0,filename,set,child,female
40,107_3p2_Tc_mc_AKGC417L,train,0.0,1.0
41,108_1b1_Al_sc_Meditron,train,1.0,0.0
42,109_1b1_Al_sc_Litt3200,test,0.0,1.0
43,109_1b1_Ar_sc_Litt3200,test,0.0,1.0
44,109_1b1_Ll_sc_Litt3200,test,0.0,1.0
45,109_1b1_Lr_sc_Litt3200,test,0.0,1.0
46,109_1b1_Pl_sc_Litt3200,test,0.0,1.0
47,109_1b1_Pr_sc_Litt3200,test,0.0,1.0
48,110_1b1_Pr_sc_Meditron,train,0.0,0.0
49,110_1p1_Al_sc_Meditron,train,0.0,0.0


In [15]:
# train, test split
train_df = df[df['set'] == 'train']
test_df = df[df['set'] == 'test']

# filename list
train_file_list = train_df['filename'].tolist()
test_file_list = test_df['filename'].tolist()

# age list (각 파일의 환자의 소아 여부)
train_age_list = train_df['child'].tolist()
test_age_list = test_df['child'].tolist()

# gender list (각 파일의 환자의 여성 여부)
train_gender_list = train_df['female'].tolist()
test_gender_list = test_df['female'].tolist()

# list 통합
train_list = [list(x) for x in zip(train_file_list, train_age_list, train_gender_list)]
test_list = [list(x) for x in zip(test_file_list, test_age_list, test_gender_list)]

print(f'Train : {len(train_list)}, Test : {len(test_list)}, Total : {len(train_list) + len(test_list)}')

Train : 539, Test : 381, Total : 920


## 2. Data Preprocessing

#### 2.1 Args

        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)

In [None]:
class Args:
    # Audio & Spectrogram
    target_sr = 4000   
    frame_size = 1024
    hop_length = 512    # frame_size 절반
    n_mels = 128
    target_sec = 8

    # Augmentation
    time_mask_param = 0.5
    freq_mask_param = 0.5

    # Train
    lr = 0.03
    warm = True                     # warm-up 사용 여부
    warm_epochs = 10                # warm-up 적용할 초기 epoch 수
    warmup_from = lr * 0.1          # warm-up 시작 learning rate (보통 lr의 10%)
    warmup_to = lr

    batch_size = 128
    workers = 4
    epochs = 300
    weight_decay = 1e-3

    resume = None
    schedule=[120, 160] # schedule

    # MLS
    K = 512
    momentum = 0.999
    T = 0.07
    dim_prj = 64
    top_k = 15
    lambda_bce = 0.5
    out_dim = 512

    # Linear Evaluation
    ft_epochs = 100

    # etc
    gpu = 0
    data = "./data_path"
    seed=42

args = Args()

#### 2.2 Utils (func)

In [17]:
import torch.nn.functional as F
import random

# cycle의 클래스를 추출
def get_class(cr, wh):
    if cr == 1 and wh == 1:
        return 3
    elif cr == 0 and wh == 1:
        return 2
    elif cr == 1 and wh == 0:
        return 1
    elif cr == 0 and wh == 0:
        return 0
    else:
        return -1

# Mel Spectrogram 생성 ( sr=4KHz, frame_size=1024, hop_length=512, n_mels=128 )
def generate_mel_spectrogram(waveform, sample_rate, frame_size, hop_length, n_mels):
    if hop_length is None:
        hop_length = frame_size // 2
    mel_spec_transform = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=frame_size,
        hop_length=hop_length,
        n_mels=n_mels
    )
    mel_spectrogram = mel_spec_transform(waveform)
    mel_db = T.AmplitudeToDB()(mel_spectrogram)

    # scaling
    # mean = mel_db.mean()
    # std = mel_db.std() + 1e-6
    # mel.db = (mel_db - mean) / std

    return mel_db

# Cycle Repeat 또는 Crop
def repeat_or_truncate_segment(mel_segment, target_frames):
    current_frames = mel_segment.shape[-1]
    if current_frames >= target_frames:
        return mel_segment[:, :, :target_frames]
    else:
        repeat_ratio = math.ceil(target_frames / current_frames)
        mel_segment = mel_segment.repeat(1, 1, repeat_ratio)
        return mel_segment[:, :, :target_frames]

def preprocess_waveform_segment(waveform, unit_length):

    """unit_length 기준으로 waveform을 repeat + padding 또는 crop하여 길이 정규화"""
    waveform = waveform.squeeze(0)  # (1, L) → (L,) 로 바꿔도 무방
    length_adj = unit_length - len(waveform)

    if length_adj > 0:
        # waveform이 너무 짧은 경우 → repeat + zero-padding
        half_unit = unit_length // 2

        if length_adj < half_unit:
            # 길이 차이가 작으면 단순 padding
            half_adj = length_adj // 2
            waveform = F.pad(waveform, (half_adj, length_adj - half_adj))
        else:
            # 반복 후 부족한 부분 padding
            repeat_factor = unit_length // len(waveform)
            waveform = waveform.repeat(repeat_factor)[:unit_length]
            remaining = unit_length - len(waveform)
            half_pad = remaining // 2
            waveform = F.pad(waveform, (half_pad, remaining - half_pad))
    else:
        # waveform이 너무 길면 앞쪽 1/4 내에서 랜덤 crop
        length_adj = len(waveform) - unit_length
        start = random.randint(0, length_adj // 4)
        waveform = waveform[start:start + unit_length]

    return waveform.unsqueeze(0)  # 다시 (1, L)로

def preprocess_waveform_with_fade_repeat(waveform, unit_length, fade_ratio=0.1):
    """
    길이 unit_length까지 반복하며 fade-in/out으로 연결하는 방식의 padding
    waveform: (1, L) or (L,)
    fade_ratio: 각 반복 연결부에서 fade-in/out 적용 비율 (0.1 → 10%)
    """
    if waveform.dim() == 2:
        waveform = waveform.squeeze(0)  # (1, L) → (L,)

    orig_len = len(waveform)
    fade_len = int(orig_len * fade_ratio)

    if orig_len >= unit_length:
        # 너무 길면 crop
        length_adj = orig_len - unit_length
        start = random.randint(0, length_adj // 4)
        waveform = waveform[start:start + unit_length]
        return waveform.unsqueeze(0)

    # 만들고자 하는 길이만큼 반복
    full_wave = waveform.clone()
    while len(full_wave) < unit_length:
        next_cycle = waveform.clone()

        # fade-out 마지막 구간
        fade_out = torch.linspace(1.0, 0.0, fade_len)
        full_wave[-fade_len:] *= fade_out

        # fade-in 앞부분
        fade_in = torch.linspace(0.0, 1.0, fade_len)
        next_cycle[:fade_len] *= fade_in

        # 이어붙이기
        full_wave = torch.cat([full_wave, next_cycle], dim=0)

    # 최종 길이 맞추기
    waveform = full_wave[:unit_length]
    return waveform.unsqueeze(0)  # (1, L)

# 데이터 Spec Augmentation ( 0~80% Random Masking )
def apply_spec_augment(mel_segment):

    M = mel_segment.shape[-1]
    F = mel_segment.shape[-2]

    # torchaudio의 마스킹은 0부터 mask_param까지 균등분포에서 랜덤하게 길이를 선택
    time_masking = T.TimeMasking(time_mask_param=int(M * 0.8))
    freq_masking = T.FrequencyMasking(freq_mask_param=int(F * 0.8) )

    aug1 = freq_masking(mel_segment.clone())
    aug2 = time_masking(mel_segment.clone())
    aug3 = freq_masking(time_masking(mel_segment.clone()))

    return aug1, aug2, aug3

# Waveform resample
def resample_waveform(waveform, orig_sr, target_sr=args.target_sr):
    if orig_sr != target_sr:
        resampler = torchaudio.transforms.Resample(
            orig_freq=orig_sr,
            new_freq=target_sr
        )
        return resampler(waveform), target_sr
    return waveform, orig_sr


# Normalize - Mean/Std
def get_mean_and_std(dataset):
    """ 전체 mel-spectrogram에서 mean과 std 계산 """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    cnt = 0
    fst_moment = torch.zeros(1)
    snd_moment = torch.zeros(1)
    for inputs, _, _ in tqdm(dataloader, desc="[Calculating Mean/Std]"):
        b, c, h, w = inputs.shape  # inputs: [1, 1, n_mels, time]
        nb_pixels = b * h * w

        fst_moment += torch.sum(inputs, dim=[0, 2, 3])
        snd_moment += torch.sum(inputs**2, dim=[0, 2, 3])
        cnt += nb_pixels

    mean = fst_moment / cnt
    std = torch.sqrt(snd_moment / cnt - mean**2)
    return mean.item(), std.item()

In [18]:
##############################################
import torch
import torch.nn.functional as F
import torchaudio.transforms as T
import numpy as np
import random

# -------------------- Augmentation functions (ICBHI 멜스펙트로그램에 최적화) --------------------

def spec_augment(mel, time_mask_ratio=0.15, freq_mask_ratio=0.15):
    """
    SpecAugment: 시간/주파수 영역 마스킹
    - 시간축 마스킹: 63 * 0.15 ≈ 9 프레임
    - 주파수 마스킹: 128 * 0.1 ≈ 12 채널
    """
    M = mel.shape[-1]  # 시간 축
    F = mel.shape[-2]  # 주파수 축

    time_masking = T.TimeMasking(time_mask_param=max(1, int(M * time_mask_ratio)))
    freq_masking = T.FrequencyMasking(freq_mask_param=max(1, int(F * freq_mask_ratio)))

    mel = freq_masking(mel.clone())
    mel = time_masking(mel)
    return mel

def add_noise(mel, noise_level=0.001):
    """
    노이즈 추가: 적당한 수준의 표준 정규분포 노이즈 (너무 높으면 손실 커짐)
    """
    noise = torch.randn_like(mel) * noise_level
    return mel + noise

def pitch_shift(mel, n_steps=2):
    """
    주파수 축 순환 이동 (mel axis). shape은 그대로 유지됨.
    n_steps=2면 ±2 멜 채널만 이동.
    """
    shift = random.randint(-n_steps, n_steps)
    if shift == 0:
        return mel
    if shift > 0:
        mel = torch.cat([mel[:, :, shift:, :], mel[:, :, :shift, :]], dim=2)
    else:
        shift = abs(shift)
        mel = torch.cat([mel[:, :, -shift:, :], mel[:, :, :-shift, :]], dim=2)
    return mel

def time_stretch(mel, min_rate=0.95, max_rate=1.05):
    """
    시간 축 길이 조절. 너무 심하지 않게 ±5% 범위로만 조정.
    - shape 유지 위해 interpolation 후 crop/pad
    """
    rate = random.uniform(min_rate, max_rate)
    if rate == 1.0:
        return mel

    orig_size = mel.shape[-1]
    target_size = int(orig_size * rate)

    mel_stretched = F.interpolate(
        mel, size=(mel.shape[-2], target_size),  # (mel_bins, time)
        mode='bilinear',
        align_corners=False
    )

    if target_size > orig_size:
        return mel_stretched[..., :orig_size]
    else:
        pad = orig_size - target_size
        return F.pad(mel_stretched, (0, pad))

# -------------------- Dispatcher --------------------

AUGMENTATION_FUNCTIONS_TORCH = {
    "spec_augment": spec_augment,
    "add_noise": add_noise,
    "pitch_shift": pitch_shift,
    "time_stretch": time_stretch
}

def apply_augmentations_torch(x, methods=[], **kwargs):
    for method in methods:
        func = AUGMENTATION_FUNCTIONS_TORCH.get(method)
        if func is None:
            raise ValueError(f"Unknown augmentation: {method}")
        x = func(x, **kwargs.get(method, {}))
    return x

In [19]:
def aug(repeat_mel):
    # 먼저 복사본 준비
    mel1 = repeat_mel.clone()
    mel2 = repeat_mel.clone()

    # 각각 다른 증강 A, B 적용
    aug1 = apply_augmentations_torch(mel1, methods=["add_noise"], add_noise={"noise_level": 0.005})
    aug2 = apply_augmentations_torch(mel2, methods=["time_stretch"], time_stretch={"min_rate": 0.8, "max_rate": 1.2})
    # aug3 = apply_augmentations_torch(mel3, methods=["pitch_shift"], pitch_shift={"n_steps": 2})

    # # 각 결과에 spec_augment 추가 적용
    aug1_spec = spec_augment(aug1, time_mask_ratio=0.6, freq_mask_ratio=0.4)
    aug2_spec = spec_augment(aug2, time_mask_ratio=0.6, freq_mask_ratio=0.4)
    # aug3_spec = spec_augment(aug3, time_mask_ratio=0.6, freq_mask_ratio=0.4)

    return aug1_spec, aug2_spec, None

# classwise하게 cycle_list를 만들지 않고, 모든 cycle_list에서 랜덤으로 샘플링
def window_mix(repeat_mel, cycle_list):
    
    mel1 = repeat_mel.clone()
    mel2 = repeat_mel.clone()

    B, C, F, T = repeat_mel.shape  # 4D: [B, 1, F, T]

    # mel1 증강
    for _ in range(random.randint(1, 3)):
        window_width = random.randint(int(T * 0.1), int(T * 0.4))
        start = random.randint(int(T * 0.1), T - window_width)
        end = start + window_width

        random_normal_cycle1 = random.choice(cycle_list)[0]
        random_normal_cycle1 = random_normal_cycle1.expand(B, -1, -1, -1)
        mel1[:, :, :, start:end] = random_normal_cycle1[:, :, :, start:end]

    # mel2 증강
    for _ in range(random.randint(1, 3)):
        window_width = random.randint(int(T * 0.1), int(T * 0.4))
        start = random.randint(int(T * 0.1), T - window_width)
        end = start + window_width

        random_normal_cycle2 = random.choice(cycle_list)[0]
        random_normal_cycle2 = random_normal_cycle2.expand(B, -1, -1, -1)
        mel2[:, :, :, start:end] = random_normal_cycle2[:, :, :, start:end]

    return mel1, mel2, None


def get_timestamp():
    """Outputs current time in KST like 2404070830"""
    kst_time = datetime.now(ZoneInfo("Asia/Seoul"))
    return kst_time.strftime('%y%m%d%H%M')

# Origin
# def aug(repeat_mel):
#     aug1, aug2, aug3 = apply_spec_augment(repeat_mel)
#     return aug1, aug2, aug3

#### 2.3 CycleDataset

In [20]:
import os
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm

class CycleDataset(Dataset):
    def __init__(self, file_list, wav_dir, txt_dir, target_sec=args.target_sec, target_sr=args.target_sr, frame_size=args.frame_size, hop_length=args.hop_length, n_mels=args.n_mels):
        self.file_list = file_list
        self.wav_dir = wav_dir
        self.txt_dir = txt_dir
        self.target_sec = target_sec
        self.target_sr = target_sr
        self.frame_size = frame_size
        self.hop_length = hop_length
        self.n_mels = n_mels

        self.cycle_list = []

        print("[INFO] Preprocessing cycles...")
        for filename, child, female in tqdm(self.file_list):
            txt_path = os.path.join(self.txt_dir, filename + '.txt')
            wav_path = os.path.join(self.wav_dir, filename + '.wav')

            if not os.path.exists(txt_path):
                print(f"[WARNING] Missing file: {txt_path}")
            if not os.path.exists(wav_path):
                print(f"[WARNING] Missing file: {wav_path}")

            # Load annotation
            cycle_data = np.loadtxt(txt_path, usecols=(0, 1))
            lung_label = np.loadtxt(txt_path, usecols=(2, 3))

            # Load waveform
            waveform, orig_sr = torchaudio.load(wav_path)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)  # Stereo to mono

            # Resample to target sample rate (4kHz)
            waveform, sample_rate = resample_waveform(waveform, orig_sr, self.target_sr)

            for idx in range(len(cycle_data)):
                # 호흡 주기 start, end
                start_sample = int(cycle_data[idx, 0] * sample_rate)
                end_sample = int(cycle_data[idx, 1] * sample_rate)
                lung_duration = cycle_data[idx, 1] - cycle_data[idx, 0]

                if end_sample <= start_sample:
                    continue  # 잘못된 구간 스킵

                # Waveform repeat + padding 후 Mel_db
                cycle_wave = waveform[:, start_sample:end_sample]
                normed_wave = preprocess_waveform_with_fade_repeat(cycle_wave, unit_length=int(self.target_sec * self.target_sr))
                mel = generate_mel_spectrogram(normed_wave, sample_rate, frame_size=self.frame_size, hop_length=self.hop_length, n_mels=self.n_mels)

                # crackle, wheeze -> class
                cr = int(lung_label[idx, 0])
                wh = int(lung_label[idx, 1])
                label = get_class(cr, wh)

                multi_label = torch.tensor([
                    float(label in [1, 3]),
                    float(label in [2, 3])
                ])  # 변환된 multi-label 반환

                # meta_data (나이, 성별 정보를 추가하였음 - 0720)
                meta_data = (filename, lung_duration, child, female)

                self.cycle_list.append((mel, multi_label, meta_data))

        print(f"[INFO] Total cycles collected: {len(self.cycle_list)}")

    def __len__(self):
        return len(self.cycle_list)

    def __getitem__(self, idx):
        mel, label, meta_data = self.cycle_list[idx]
        return mel, label, meta_data

In [21]:
# def multilabel_to_class(label_tensor):
#     crackle = label_tensor[0].item()
#     wheeze = label_tensor[1].item()
#     if crackle and wheeze:
#         return 3  # Both
#     elif crackle:
#         return 1  # Crackle
#     elif wheeze:
#         return 2  # Wheeze
#     else:
#         return 0  # Normal

# def group_by_class_multilabel(cycle_list):
#     from collections import defaultdict

#     classwise_dict = defaultdict(list)

#     for sample in cycle_list:
#         label_tensor = sample[1]
#         cls = multilabel_to_class(label_tensor)
#         classwise_dict[cls].append(sample)

#     return classwise_dict

##### Pickle.dump

CycleDataset 객체 생성

In [22]:
import random
import matplotlib.pyplot as plt
import librosa.display

seed_everything(42)

wav_dir = ROOT
txt_dir = ROOT

# 1. Dataset 로드
train_dataset = CycleDataset(train_list, wav_dir, txt_dir)
test_dataset = CycleDataset(test_list, wav_dir, txt_dir)

# train_dataset 셔플
random.shuffle(train_dataset.cycle_list)

[INFO] Preprocessing cycles...


100%|██████████| 539/539 [00:11<00:00, 45.43it/s]


[INFO] Total cycles collected: 4142
[INFO] Preprocessing cycles...


100%|██████████| 381/381 [00:07<00:00, 48.44it/s]

[INFO] Total cycles collected: 2756





pickle로 train_dataset, test_dataset 외부 저장

In [23]:
pickle_name = f'MLS_age_gen_fade_normall_{args.target_sr//1000}kHz_{args.frame_size}win_{args.hop_length}hop_{args.n_mels}mel_{args.target_sec}s'

In [24]:
pickle_dict = {
    'train_dataset': train_dataset,
    'test_dataset': test_dataset
}

save_path = os.path.join(PICKLE_PATH, pickle_name + '.pkl')
with open(save_path, 'wb') as f:
    pickle.dump(pickle_dict, f)

In [25]:
# # 2. 간단 통계
# print(f"Total cycles: {len(train_dataset)}")

# label_counter = [0] * 4  # normal, crackle, wheeze, both
# for _, multi_label,_ in train_dataset:
#     if torch.equal(multi_label, torch.tensor([0., 0.])):
#         label_counter[0] += 1
#     elif torch.equal(multi_label, torch.tensor([1., 0.])):
#         label_counter[1] += 1
#     elif torch.equal(multi_label, torch.tensor([0., 1.])):
#         label_counter[2] += 1
#     elif torch.equal(multi_label, torch.tensor([1., 1.])):
#         label_counter[3] += 1

# for idx, count in enumerate(label_counter):
#     print(f"Class {idx}: {count} cycles")

##### Pickle.load
저장된 train_dataset, test_dataset을 로드  
(> Aug 는 Moco 모델에서 사용)

In [26]:
save_path = os.path.join(PICKLE_PATH, pickle_name + '.pkl')
with open(save_path, 'rb') as f:
    pickle_dict = pickle.load(f)

train_dataset = pickle_dict['train_dataset']
test_dataset = pickle_dict['test_dataset']

print(f"[Train] Cycles: {len(train_dataset)}")
print(f"[Test] Cycles: {len(test_dataset)}")

[Train] Cycles: 4142
[Test] Cycles: 2756


In [27]:
wav_dir = ROOT
txt_dir = ROOT

# classwise 분리
# mel_train_classwise_dict = group_by_class_multilabel(train_dataset.cycle_list)

In [28]:
mean, std = get_mean_and_std(train_dataset)
mean, std = mean, std

[Calculating Mean/Std]: 100%|██████████| 4142/4142 [00:09<00:00, 442.80it/s]


#### 2.4 DataLoader

코드 실행 환경에 따라 num_workers를 적절한 값으로 지정해주세요!

In [29]:
seed_everything(42)

train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=4,
    drop_last=True,
    pin_memory=True,
    shuffle=False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    num_workers=4,
    drop_last=False,
    pin_memory=True,
    shuffle=True
)

label 분포 확인 (단순 참고용, 실제 환경에서는 pretrain set의 label 분포가 어떤지 알 수 없음)

In [30]:
from collections import Counter

# label
labels = torch.stack([multi_label for _, multi_label, _ in train_dataset])

# test 데이터셋의 라벨 분포 출력
train_labels = torch.stack([multi_label for _, multi_label, _ in train_dataset])
train_labels_class = (
    train_labels[:, 0].long() * 1 +  # crackle bit → *1
    train_labels[:, 1].long() * 2    # wheeze bit  → *2
)  # [N] shape, values in {0, 1, 2, 3}

# test 데이터셋의 라벨 분포 출력
test_labels = torch.stack([multi_label for _, multi_label, _ in test_dataset])
test_labels_class = (
    test_labels[:, 0].long() * 1 +  # crackle bit → *1
    test_labels[:, 1].long() * 2    # wheeze bit  → *2
)  # [N] shape, values in {0, 1, 2, 3}


print(f"Train sample: {len(train_labels_class)}")
print("Train label distribution:", Counter(train_labels_class.tolist()))
print(f"\nTest sample: {len(test_labels_class)}")
print("Test label distribution:", Counter(test_labels_class.tolist()))

Train sample: 4142
Train label distribution: Counter({0: 2063, 1: 1215, 2: 501, 3: 363})

Test sample: 2756
Test label distribution: Counter({0: 1579, 1: 649, 2: 385, 3: 143})


## 3. Modeling

#### 3.1 Pre-trained ResNet50

In [31]:
# def backbone_resnet():
#     # 1. 기본 ResNet50 생성 (pretrained=False로 시작)
#     resnet = models.resnet50(pretrained=False)

#     # 2. 첫 번째 conv 레이어를 1채널용으로 수정
#     resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

#     # 먼저 fc 제거
#     resnet.fc = nn.Identity()

#     # 3. ImageNet 가중치 로드 (conv1 제외)
#     state_dict = load_state_dict_from_url(
#         'https://download.pytorch.org/models/resnet50-19c8e357.pth',
#         progress=True
#     )
#     if 'conv1.weight' in state_dict:
#         del state_dict['conv1.weight']
#     resnet.load_state_dict(state_dict, strict=False)

#     return resnet

ResNet34

In [32]:
from torchvision import models
from torch.hub import load_state_dict_from_url
import torch.nn as nn

def backbone_resnet():
    # 1. 기본 ResNet34 생성
    resnet = models.resnet34(pretrained=False)

    # 2. 첫 번째 conv 레이어를 1채널용으로 수정
    resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    # fc 제거
    resnet.fc = nn.Identity()

    # 3. ImageNet 가중치 로드 (conv1 제외)
    state_dict = load_state_dict_from_url(
        'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
        progress=True
    )
    if 'conv1.weight' in state_dict:
        del state_dict['conv1.weight']
    resnet.load_state_dict(state_dict, strict=False)

    return resnet

ResNet18

In [33]:
# from torchvision import models
# from torch.hub import load_state_dict_from_url
# import torch.nn as nn

# def backbone_resnet():
#     # 1. 기본 ResNet18 생성
#     resnet = models.resnet18(pretrained=False)

#     # 2. 첫 번째 conv 레이어를 1채널용으로 수정
#     resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

#     # fc 제거
#     resnet.fc = nn.Identity()

#     # 3. ImageNet 가중치 로드 (conv1 제외)
#     state_dict = load_state_dict_from_url(
#         'https://download.pytorch.org/models/resnet18-f37072fd.pth',
#         progress=True
#     )
#     if 'conv1.weight' in state_dict:
#         del state_dict['conv1.weight']
#     resnet.load_state_dict(state_dict, strict=False)

#     return resnet

In [34]:
# summary 함수 사용: (채널, 높이, 너비) 크기를 지정
summary(backbone_resnet().to(device), input_size=(1, 224, 64))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 64, 112, 32]           3,136
       BatchNorm2d-2          [-1, 64, 112, 32]             128
              ReLU-3          [-1, 64, 112, 32]               0
         MaxPool2d-4           [-1, 64, 56, 16]               0
            Conv2d-5           [-1, 64, 56, 16]          36,864
       BatchNorm2d-6           [-1, 64, 56, 16]             128
              ReLU-7           [-1, 64, 56, 16]               0
            Conv2d-8           [-1, 64, 56, 16]          36,864
       BatchNorm2d-9           [-1, 64, 56, 16]             128
             ReLU-10           [-1, 64, 56, 16]               0
       BasicBlock-11           [-1, 64, 56, 16]               0
           Conv2d-12           [-1, 64, 56, 16]          36,864
      BatchNorm2d-13           [-1, 64, 56, 16]             128
             ReLU-14           [-1, 64,

#### 3.1 Other Bacbones

DenseNet

In [35]:
def backbone_densenet121():
    # 1. DenseNet121 구조만 (pretrained=False)
    densenet = models.densenet121(pretrained=False)

    # 2. 첫번째 conv 레이어를 1채널로 교체
    old_conv = densenet.features.conv0
    new_conv = nn.Conv2d(
        1, old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=(old_conv.bias is not None)
    )
    densenet.features.conv0 = new_conv

    # 3. ImageNet 가중치 불러오기 (conv0 제외)
    state_dict = load_state_dict_from_url(
        'https://download.pytorch.org/models/densenet121-a639ec97.pth', progress=True
    )
    # conv0 (features.conv0.weight) 삭제
    if 'features.conv0.weight' in state_dict:
        del state_dict['features.conv0.weight']
    densenet.load_state_dict(state_dict, strict=False)

    densenet.classifier = nn.Identity()

    return densenet

def backbone_densenet161():
    # 1. DenseNet161 구조만 (pretrained=False)
    densenet = models.densenet161(pretrained=False)

    # 2. 첫번째 conv 레이어를 1채널로 교체
    old_conv = densenet.features.conv0
    new_conv = nn.Conv2d(
        in_channels=1,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=(old_conv.bias is not None)
    )
    densenet.features.conv0 = new_conv

    # 3. ImageNet 가중치 불러오기 (conv0 제외)
    state_dict = load_state_dict_from_url(
        'https://download.pytorch.org/models/densenet161-8d451a50.pth', progress=True
    )
    if 'features.conv0.weight' in state_dict:
        del state_dict['features.conv0.weight']
    densenet.load_state_dict(state_dict, strict=False)

    # 4. classifier 제거
    densenet.classifier = nn.Identity()

    return densenet

def backbone_densenet201():
    # 1. DenseNet201 구조만 (pretrained=False)
    densenet = models.densenet201(pretrained=False)

    # 2. 첫번째 conv 레이어를 1채널로 교체
    old_conv = densenet.features.conv0
    new_conv = nn.Conv2d(
        in_channels=1,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=(old_conv.bias is not None)
    )
    densenet.features.conv0 = new_conv

    # 3. ImageNet 가중치 불러오기 (conv0 제외)
    state_dict = load_state_dict_from_url(
        'https://download.pytorch.org/models/densenet201-c1103571.pth', progress=True
    )
    if 'features.conv0.weight' in state_dict:
        del state_dict['features.conv0.weight']
    densenet.load_state_dict(state_dict, strict=False)

    # 4. classifier 제거
    densenet.classifier = nn.Identity()

    return densenet

In [36]:
# !pip install torchinfo

In [37]:
# from torchinfo import summary

# model = backbone_densenet121().to(device)
# summary(model, input_size=(1, 1, 224, 64))

GRU + ATT

In [38]:
class HAN_GRU(nn.Module):
    def __init__(self, freq_dim=128, hidden_freq=100, hidden_time=250, output_dim=512):
        super(HAN_GRU, self).__init__()
        self.hidden_freq = hidden_freq
        self.hidden_time = hidden_time
        self.output_dim = output_dim

        # Bidirectional GRU (주파수 영역)
        self.freq_gru = nn.GRU(input_size=freq_dim, hidden_size=hidden_freq,
                               batch_first=True, bidirectional=True)

        # Attention for frequency
        self.freq_attn_fc = nn.Linear(hidden_freq * 2, hidden_freq * 2)
        self.freq_context_vector = nn.Parameter(torch.randn(hidden_freq * 2))

        # Bidirectional GRU (시간 영역)
        self.time_gru = nn.GRU(input_size=hidden_freq * 2, hidden_size=hidden_time,
                               batch_first=True, bidirectional=True)

        # Attention for time
        self.time_attn_fc = nn.Linear(hidden_time * 2, hidden_time * 2)
        self.time_context_vector = nn.Parameter(torch.randn(hidden_time * 2))

        # 마지막 출력 벡터 차원 맞추기
        self.fc_out = nn.Linear(hidden_time * 2, output_dim)

    def attention(self, rnn_output, attn_fc, context_vector, mask=None):
        u = torch.tanh(attn_fc(rnn_output))  # [B, T, D]
        attn_scores = torch.matmul(u, context_vector)  # [B, T]

        if mask is not None:
            # mask가 0인 (padding) 위치의 attention score를 -1e9로 설정 → softmax 이후 거의 0이 됨
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        attn_weights = torch.softmax(attn_scores, dim=1)  # [B, T]
        attn_output = torch.sum(rnn_output * attn_weights.unsqueeze(-1), dim=1)  # [B, D]
        return attn_output

    def forward(self, x, mask=None):
        # x: (B, C=1, F, T) → squeeze channel
        x = x.squeeze(1)   # (B, F, T)
        x = x.permute(0, 2, 1)  # (B, T, F)

        # Frequency GRU
        freq_output, _ = self.freq_gru(x)  # (B, T, 2*hidden_freq)
        freq_attn_output = self.attention(freq_output, self.freq_attn_fc, self.freq_context_vector, mask)  # (B, 2*hidden_freq)

        # 시간 축을 따라 Attention-GRU
        time_input = freq_output  # (B, T, 2*hidden_freq)
        time_output, _ = self.time_gru(time_input)  # (B, T, 2*hidden_time)
        time_attn_output = self.attention(time_output, self.time_attn_fc, self.time_context_vector, mask)  # (B, 2*hidden_time)

        # 최종 임베딩 차원으로 투사
        # out = self.fc_out(time_attn_output)  # (B, output_dim)
        return time_attn_output


In [39]:
def backbone_han_gru():
    """
    MoCo에 사용 가능한 GRU + 계층적 Attention 기반 백본 모델 정의.
    - 입력: log-mel spectrogram 형태 (B, 1, F, T)
    - 출력: 512차원 feature vector
    """
    return HAN_GRU()

#### 3.2 MoCo (MLS)

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# K: queue_g의 크기
# dim_enc: projector 통과 전, g1,g2 벡터의 차원
# dim_prj: projector 통과 후, z1,z2 벡터의 차원
class MoCo(nn.Module):
    def __init__(self, base_encoder, dim_enc=args.out_dim, dim_prj=64, K=512, m=0.999, T=0.07, top_k=15, lambda_bce=0.5):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T
        self.top_k = top_k
        self.lambda_bce = lambda_bce

        self.encoder_q = base_encoder()
        self.encoder_k = base_encoder()

        dim_enc = dim_enc
        self.proj_head_q = nn.Sequential(
            nn.Linear(dim_enc, dim_enc),
            nn.BatchNorm1d(dim_enc),
            nn.GELU(),
            nn.Linear(dim_enc, dim_prj)
        )
        self.proj_head_k = nn.Sequential(
            nn.Linear(dim_enc, dim_enc),
            nn.BatchNorm1d(dim_enc),
            nn.GELU(),
            nn.Linear(dim_enc, dim_prj)
        )

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        self.register_buffer("queue_g", F.normalize(torch.randn(dim_enc, K), dim=0))      # g2를 정규화한 후 열 단위로 Qg에 저장
        self.register_buffer("queue_z", F.normalize(torch.randn(dim_prj, K), dim=0))      # z2를 정규화한 후 열 단위로 Qz에 저장
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))               # 현재 queue에 새로 쓸 위치(인덱스)를 추적하는 포인터 역할

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, g2, z2):
        batch_size = g2.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0
        self.queue_g[:, ptr:ptr+batch_size] = g2.T.detach()
        self.queue_z[:, ptr:ptr+batch_size] = z2.T.detach()
        self.queue_ptr[0] = (ptr + batch_size) % self.K

    def forward(self, im_q, im_k, epoch=None, warmup_epochs=10):
        # encoder_q → g1 (feature)
        g1 = F.normalize(self.encoder_q(im_q), dim=1)  # shape: [B, 2048]

        # projection head → z1
        z1 = F.normalize(self.proj_head_q(g1), dim=1)  # shape: [B, 128]

        # encoder k
        with torch.no_grad():
            self._momentum_update_key_encoder()
            g2 = F.normalize(self.encoder_k(im_k), dim=1)
            z2 = F.normalize(self.proj_head_k(g2), dim=1)

        # top-k mining
        sim_g = torch.matmul(g1, self.queue_g.clone().detach())  # [N, K]
        # Ablation(1-1) Hard top-k
        topk_idx = torch.topk(sim_g, self.top_k, dim=1).indices
        y = torch.zeros_like(sim_g)
        y.scatter_(1, topk_idx, 1.0)
        # # Ablation(1-2) Soft top-k
        # topk_sim, topk_idx = torch.topk(sim_g, self.top_k, dim=1)
        # y = torch.zeros_like(sim_g)
        # y.scatter_(1, topk_idx, F.softmax(topk_sim / self.T, dim=1))

        ##################################################################
        # logits from z1 · Qz
        sim_z = torch.matmul(z1, self.queue_z.clone().detach())
        # Ablation(2-1) BCE Loss
        bce_loss = F.binary_cross_entropy_with_logits(sim_z / self.T, y) # 개선-> sigmoid(sim_z), 1/D

        # # Ablation(2-2) Weighted BCE Loss
        # pos_weight = torch.ones_like(sim_z) * (self.K / self.top_k)
        # bce_loss = F.binary_cross_entropy_with_logits(sim_z / self.T, y, pos_weight=pos_weight)
        # # Ablation(2-3) another Weighted BCE Loss (비추, top-k만 보는 느낌)
        # raw_loss = F.binary_cross_entropy_with_logits(sim_z / self.T, y, reduction='none')  # shape: [B, K]
        # bce_loss = raw_loss.sum() / (y.sum() + 1e-6)

        ###################################################################
        # InfoNCE loss
        l_pos = torch.sum(z1 * z2, dim=1, keepdim=True)
        l_neg = torch.matmul(z1, self.queue_z.clone().detach())
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(logits.device)
        info_nce_loss = F.cross_entropy(logits, labels)

        # Total loss (with optional warmup) # MLS 논문에서는 warmup 아예 안쓴다고 함
        if epoch is not None and epoch < warmup_epochs:
            loss = info_nce_loss
        # else:
        loss = info_nce_loss + self.lambda_bce * bce_loss
        # print(f"INFO_NCE: {info_nce_loss}")
        # print(f"TRIPLET: {triplet_loss}")
        # print(f"BCE: {bce_loss}")

        self._dequeue_and_enqueue(g2, z2)

        return loss, logits, labels

## 4. Pretrain

In [41]:
pretrain_project_name = f'SHS_age_gen_fade_res34_PT_top{args.top_k}_{get_timestamp()}'

In [42]:
# 모델 지정하기 전 seed 고정 필요
# seed_everything(args.seed) # Seed 고정

wandb.init(
    project="ICBHI_MSL_Ablation_all",           # 프로젝트 이름
    name=f"{pretrain_project_name}", # 실험 이름
    config={
        "epochs": args.ft_epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay,
    }
)

# 1. MoCo 모델 생성
model = MoCo(
    base_encoder = backbone_resnet,
    dim_enc = args.out_dim,
    dim_prj = args.dim_prj,
    K = args.K,
    m = args.momentum,
    T = args.T,
    top_k = args.top_k,
    lambda_bce = args.lambda_bce
).cuda()

# 2. Optimizer
# optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=args.weight_decay,
    nesterov=True
)

# 3. Cosine Scheduler
scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)

# 4. Train
# Best loss 초기화
best_loss = float('inf')
best_epoch = -1

for epoch in range(args.epochs):
    # ===============================
    # Training
    # ===============================
    model.train()
    total_train_loss = 0.0

    for i, (repeat_mel, label, _) in enumerate(train_loader): # label 여기선 사용 X
        # im_q, im_k, _ = aug(repeat_mel)
        im_q, im_k, _ = window_mix(repeat_mel, train_dataset.cycle_list)

        # scaling augs
        im_q = (im_q - mean) / (std + 1e-6)
        im_k = (im_k - mean) / (std + 1e-6)

        im_q = im_q.cuda(device=args.gpu, non_blocking=True)
        im_k = im_k.cuda(device=args.gpu, non_blocking=True)

        optimizer.zero_grad()
        loss, output, target = model(im_q=im_q, im_k=im_k)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Epoch {epoch} | Avg Train Loss: {avg_train_loss:.4f}")
    print(f"[Epoch {epoch} | Step {i}] im_q: {im_q.shape}, im_k: {im_k.shape}")

    # =====================================
    # Scheduler
    # =====================================
    scheduler.step()

    # # =====================================
    # Logging with wandb
    # =====================================
    current_lr = optimizer.param_groups[0]['lr']
    wandb.log({
        # "epoch": epoch,
        "train_loss": avg_train_loss,
        # "lr": current_lr
    })

    # =====================================
    # Checkpoint (Every 100 epochs)
    # =====================================
    if (epoch + 1) % 100 == 0:
        ckpt_path = CHECKPOINT_PATH + f"/{pretrain_project_name}_{epoch:03d}.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, ckpt_path)
        print(f"💾 Saved checkpoint to {ckpt_path}")

    # ===============================
    # Save Best Checkpoint
    # ===============================
    if avg_train_loss < best_loss:
        best_loss = avg_train_loss
        best_epoch = epoch
        best_ckpt_path = CHECKPOINT_PATH + f"/{pretrain_project_name}_best_checkpoint.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': best_loss
        }, best_ckpt_path)
        print(f"=> Saved best checkpoint (epoch: {epoch}, loss: {best_loss:.4f})")

Epoch 0 | Avg Train Loss: 8.0771
[Epoch 0 | Step 31] im_q: torch.Size([128, 1, 128, 63]), im_k: torch.Size([128, 1, 128, 63])
=> Saved best checkpoint (epoch: 0, loss: 8.0771)
Epoch 1 | Avg Train Loss: 6.8493
[Epoch 1 | Step 31] im_q: torch.Size([128, 1, 128, 63]), im_k: torch.Size([128, 1, 128, 63])
=> Saved best checkpoint (epoch: 1, loss: 6.8493)
Epoch 2 | Avg Train Loss: 6.4527
[Epoch 2 | Step 31] im_q: torch.Size([128, 1, 128, 63]), im_k: torch.Size([128, 1, 128, 63])
=> Saved best checkpoint (epoch: 2, loss: 6.4527)
Epoch 3 | Avg Train Loss: 6.3260
[Epoch 3 | Step 31] im_q: torch.Size([128, 1, 128, 63]), im_k: torch.Size([128, 1, 128, 63])
=> Saved best checkpoint (epoch: 3, loss: 6.3260)
Epoch 4 | Avg Train Loss: 6.0583
[Epoch 4 | Step 31] im_q: torch.Size([128, 1, 128, 63]), im_k: torch.Size([128, 1, 128, 63])
=> Saved best checkpoint (epoch: 4, loss: 6.0583)
Epoch 5 | Avg Train Loss: 5.7364
[Epoch 5 | Step 31] im_q: torch.Size([128, 1, 128, 63]), im_k: torch.Size([128, 1, 128,

## 5. Linear Evaluation

In [277]:
label_weights = torch.tensor([1.1, 1.3, .9, .9]).to(device)

In [278]:
import torch
import torch.nn.functional as F

class JointWeightedBCELoss(torch.nn.Module):
    def __init__(self, label_weights=label_weights, lambda_joint=20.0):
        super(JointWeightedBCELoss, self).__init__()
        self.label_weights = label_weights  # [crackle, wheeze, child, female]
        self.lambda_joint = lambda_joint

    def forward(self, logits, labels):
        """
        logits: [B, 4] raw model outputs
        labels: [B, 4] binary labels (0 or 1)
        """

        # 1. 기본 가중치 BCE 손실 계산
        bce_loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')  # [B,4]
        weighted_loss = bce_loss * self.label_weights                                   # [B,4]
        base_loss = weighted_loss.sum(dim=1).mean()                                     # scalar

        # 2. crackle & wheeze에 대한 오류 확률 계산
        probs = torch.sigmoid(logits)
        
        # crackle (0번 label), wheeze (1번 label)
        crackle_labels = labels[:, 0]
        wheeze_labels = labels[:, 1]
        crackle_probs = probs[:, 0]
        wheeze_probs = probs[:, 1]

        # 예측 오류 확률 (높을수록 잘못 예측)
        crackle_error = torch.where(crackle_labels == 1, 1 - crackle_probs, crackle_probs)
        wheeze_error = torch.where(wheeze_labels == 1, 1 - wheeze_probs, wheeze_probs)

        # 3. crackle-wheeze 공동 페널티 (오류 확률 곱 → 동시에 틀릴수록 큼)
        joint_loss = crackle_error * wheeze_error  # [B]

        # 최종 공동 손실 계산
        joint_loss_mean = joint_loss.mean()  # 평균을 사용해 배치 크기와 독립적으로 유지

        # 4. 총 손실
        total_loss = base_loss + self.lambda_joint * joint_loss_mean

        return total_loss

#### validate

In [279]:
len(test_dataset)

2756

In [280]:
def validate(model, val_loader, criterion, device):
    import numpy as np
    from sklearn.metrics import confusion_matrix

    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    total_cr_loss = 0.0
    total_wh_loss = 0.0
    total_ch_loss = 0.0
    total_fe_loss = 0.0

    with torch.no_grad():
        for cycle, labels, meta in val_loader:
            cycle, labels, meta = cycle.to(device), labels.to(device), meta

            # 나이
            child = meta[2].unsqueeze(1).to(device)

            # 성별
            female = meta[3].unsqueeze(1).to(device)

            # (crackle, wheeze, child, female)
            labels = torch.cat([labels, child, female], dim=1).to(device)

            cycle = (cycle - mean) / (std + 1e-6)
            outputs = model(cycle)
            
            if criterion is None:
                bce_loss = F.binary_cross_entropy_with_logits(outputs, labels, reduction='none')      # [B, 4]
                weighted_loss = bce_loss * label_weights                    # shape: [B, 4]
                loss_base = weighted_loss.sum(dim=1).mean()                      # 각 샘플별 sum → batch mean

                # crackle, wheeze loss만 추출 (B, 0) and (B, 1)
                crackle_loss = bce_loss[:, 0]   # [B]
                wheeze_loss = bce_loss[:, 1]    # [B]
                child_loss = bce_loss[:, 2]    # [B]
                female_loss = bce_loss[:, 3]    # [B]

                # joint_loss 계산 (둘 다 어려운 샘플 penalize)
                joint_loss = F.relu(crackle_loss*label_weights[0]  + wheeze_loss*label_weights[1] - 1.) ** 2   # [B]
                joint_loss_mean = joint_loss.mean()

                # --- Final Loss ---
                loss = loss_base + joint_loss_mean

                # [DEBUG] 각 loss 로깅용
                total_cr_loss += crackle_loss.mean()
                total_wh_loss += wheeze_loss.mean()
                total_ch_loss += child_loss.mean()
                total_fe_loss += female_loss.mean()

            else:
                loss = criterion(outputs, labels)

            running_loss += loss.item()

            preds = (torch.sigmoid(outputs) > 0.5).int()  # threshold = 0.5
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    all_preds = torch.cat(all_preds, dim=0).numpy()   # [N, 2]
    all_labels = torch.cat(all_labels, dim=0).numpy() # [N, 2]

    avg_loss = running_loss / len(val_loader)

    test_crackle_loss = total_cr_loss / len(val_loader)
    test_wheeze_loss = total_wh_loss / len(val_loader)
    test_child_loss = total_ch_loss / len(val_loader)
    test_female_loss = total_fe_loss / len(val_loader)

    return avg_loss, test_crackle_loss, test_wheeze_loss, test_child_loss, test_female_loss, all_labels, all_preds


### Weighted loss

In [281]:
a = torch.tensor([1,1])
b = torch.tensor([2])
c = torch.tensor([2])

torch.cat([a,b,c],dim=0)

tensor([1, 1, 2, 2])

In [282]:
from collections import Counter
import torch
import numpy as np

# 💡 다중 라벨 예시: targets는 [B, C] binary matrix (e.g., [1, 0, 1, 0])
label_list = []

# 👇 train_dataset이 (x, multi_label_tensor, _) 형태라고 가정
for _, label, meta in test_dataset:

        # multi-label
        label = label.to(device)

        # 나이
        child = torch.tensor([meta[2]]).to(device)

        # 성별
        female = torch.tensor([meta[3]]).to(device)

        # (crackle, wheeze, child, female)
        labels = torch.cat([label, child, female], dim=0).to(device)

        label_list.append(labels) 

# 전체 label을 합치기
all_labels = torch.stack(label_list, dim=0)  # shape: [N, C]
num_classes = all_labels.size(1)
total_samples = all_labels.size(0)

# 클래스별 1의 개수 세기
class_counts = all_labels.sum(dim=0)  # shape: [C]
class_weights = total_samples / (num_classes * class_counts + 1e-6)  # smoothed

# tensor로 변환
class_weights_tensor = class_weights.float().to(device)

# 🔹 출력
for i, count in enumerate(class_counts.tolist()):
    print(f"Class {i} - Positives (1): {int(count)} / {total_samples} samples")
print(f"Class Weights: {class_weights_tensor}")

alpha_norm = class_weights_tensor / class_weights_tensor.sum()
print(f"alpha_norm: {alpha_norm}")

Class 0 - Positives (1): 792 / 2756 samples
Class 1 - Positives (1): 528 / 2756 samples
Class 2 - Positives (1): 333 / 2756 samples
Class 3 - Positives (1): 525 / 2756 samples
Class Weights: tensor([0.8699, 1.3049, 2.0691, 1.3124], device='cuda:0')
alpha_norm: tensor([0.1566, 0.2349, 0.3724, 0.2362], device='cuda:0')


In [283]:
# import torch

# # ⚙️ 각 클래스의 positive 개수 (from label distribution)
# crackle_pos = 262 + 83  # label 1 or 3
# wheeze_pos  = 84 + 83   # label 2 or 3

# total_samples = 885
# num_classes = 2

# # ⚖️ 기본 class weight 계산: inverse frequency
# class_counts = torch.tensor([crackle_pos, wheeze_pos], dtype=torch.float)
# class_weights = total_samples / (num_classes * class_counts + 1e-6)

# # ✅ 정규화: sum = 1
# alpha_norm = class_weights / class_weights.sum()

# # 출력
# print("Raw Class Weights:", class_weights)
# print("Normalized Alpha (sum=1):", alpha_norm)


### Multi-label Focal Loss

In [284]:
import torch.nn.functional as F
import torch.nn as nn

class MultiLabelFocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha  # Tensor of shape [C], or scalar
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """
        logits: [B, C] - raw scores
        targets: [B, C] - binary or soft labels
        """
        probs = torch.sigmoid(logits)  # [B, C]
        ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')  # [B, C]

        pt = probs * targets + (1 - probs) * (1 - targets)  # p_t
        focal_weight = (1 - pt) ** self.gamma               # (1 - pt)^γ

        loss = focal_weight * ce_loss                       # focal weight 적용

        if self.alpha is not None:
            alpha_factor = self.alpha * targets + (1 - self.alpha) * (1 - targets)  # [B, C]
            loss = alpha_factor * loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

class StableMultiLabelFocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean', eps=1e-6):
        super().__init__()
        self.alpha = alpha  # tensor of shape [C] or None
        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        probs = torch.clamp(probs, min=self.eps, max=1.0 - self.eps)

        # Focal weight
        pt = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - pt) ** self.gamma

        # BCE loss
        ce_loss = - (targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))

        loss = focal_weight * ce_loss

        # Safe alpha (class weights) application
        if self.alpha is not None:
            if self.alpha.dim() == 1:
                alpha = self.alpha.view(1, -1)  # reshape for broadcasting
            else:
                alpha = self.alpha
            loss = alpha * loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


In [285]:
# from collections import Counter
# import torch

# label_dist = Counter({0:456, 1:262, 2:84, 3:83})  # Finetune 분포

# # Crackle: (1 + Both), Wheeze: (2 + Both)
# n_crackle = label_dist[1] + label_dist[3]  # 262 + 83
# n_wheeze  = label_dist[2] + label_dist[3]  # 84 + 83
# n_total   = sum(label_dist.values())       # 885

# pos_weight = torch.tensor([
#     (n_total - n_crackle) / (n_crackle + 1e-6),
#     (n_total - n_wheeze) / (n_wheeze + 1e-6)
# ], device=device)

# print(pos_weight)

## Linear Evaluation

In [286]:
wandb.finish()

In [287]:
## Wandb 정의

# import wandb
finetune_project_name = f'SHS_age_gen_fade_res34_LE_top{args.top_k}_{get_timestamp()}'

wandb.init(
    project="ICBHI_MSL_Ablation_all",           # 프로젝트 이름
    name=f"{finetune_project_name}", # 실험 이름
    config={
        "epochs": args.ft_epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay,
    }
)

In [None]:
import os
from torch.utils.data import DataLoader
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score

# 1. Model Load
# 위에서부터 했다면
load_ckpt_path = CHECKPOINT_PATH + f"/{pretrain_project_name}_best_checkpoint.pth.tar"
# 중간부터 이어서 한다면
# load_ckpt_path = CHECKPOINT_PATH + "/SHS_aug(T.N)_PT_128bs_top15_0.5ld_2507201907_best_checkpoint.pth.tar"

# 저장 경로
save_ckpt_path = CHECKPOINT_PATH+"/LE_pth"

# 재현성을 위한 시드 재설정
seed_everything(args.seed)

# MoCo 모델 생성 및 체크포인트 로드
model_eval = MoCo(
    base_encoder=backbone_resnet,
    dim_enc=args.out_dim,
    dim_prj=args.dim_prj,
    K=args.K,
    m=args.momentum,
    T=args.T,
    top_k=args.top_k,
    lambda_bce=args.lambda_bce
)

checkpoint = torch.load(load_ckpt_path, map_location=device)
model_eval.load_state_dict(checkpoint["state_dict"])

# 사전 학습된 encoder 추출
encoder = model_eval.encoder_q.to(device)

# 3. Fine-tuning을 위한 분류 모델 정의
# !!!!!!!!!!!!!!!!!!!!!!!!! num_classes = 4로 변경 !!!!!!!!!!!!!!!!!!!!!!!!!
# !!!!!!!!!!!!!!!!!!!!!!!!! num_classes = 4로 변경 !!!!!!!!!!!!!!!!!!!!!!!!!
# !!!!!!!!!!!!!!!!!!!!!!!!! num_classes = 4로 변경 !!!!!!!!!!!!!!!!!!!!!!!!!
# !!!!!!!!!!!!!!!!!!!!!!!!! num_classes = 4로 변경 !!!!!!!!!!!!!!!!!!!!!!!!!
class FineTuningModel(nn.Module):
    def __init__(self, encoder, out_dim=args.out_dim, num_classes=4):   # (crackle, wheeze, child, female)
        super().__init__()
        self.encoder = encoder
        # 마지막 FC layer를 제외한 encoder의 모든 레이어 freeze
        for param in self.encoder.parameters():
            param.requires_grad = False

        # 새로운 분류 헤드 추가
        self.classifier = nn.Linear(out_dim, num_classes)
        # self.classifier = nn.Sequential(
        #     nn.Linear(out_dim, out_dim),
        #     nn.GELU(),
        #     nn.Dropout(0.3),
        #     nn.Linear(out_dim, 128),
        #     nn.GELU(),
        #     nn.Linear(128, num_classes)
        # )

    def forward(self, x):
        features = self.encoder(x)
        return self.classifier(features)

# 재현성을 위한 시드 재설정
seed_everything(args.seed)

# 4. 모델, 손실 함수, 옵티마이저 설정
model = FineTuningModel(encoder, out_dim = args.out_dim).to(device)
##############################

# Ablation(3-1) LE -> BCE Loss
# criterion = nn.BCEWithLogitsLoss()

# Ablation(3-2) LE -> 임의로 Wheeze에 가중치를 더 주는 BCE Loss
criterion = JointWeightedBCELoss()

# Ablation(3-3) LE -> Weighted BCE Loss
# criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Ablation(3-4) LE -> Multi-label Focal Loss
# criterion = MultiLabelFocalLoss(
#     alpha=alpha_norm.to(device),  # 정규화된 값
#     gamma=1.0,                    # hard label일 경우
#     reduction='mean'
# )

############################
# optimizer = optim.AdamW(model.classifier.parameters(), lr=args.lr, weight_decay=args.weight_decay)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=args.weight_decay,
    nesterov=True
)
scheduler = CosineAnnealingLR(optimizer, T_max=args.ft_epochs, eta_min=1e-6)  # Linear Evaluation에서 epochs는 다르게 적용

# Best loss 초기화
best_loss = float('inf')
best_epoch = -1

# 5. Linear Evaluation
for epoch in range(args.ft_epochs):

    # ===============================
    # 1. Training
    # ===============================
    model.train()
    total_loss = 0.0
    total_predictions = 0.0
    correct_predictions = 0.0

    total_cr_loss = 0.0
    total_wh_loss = 0.0
    total_ch_loss = 0.0
    total_fe_loss = 0.0

    all_preds = []
    all_labels = []
    all_outputs = []

    pbar = tqdm(train_loader, desc='Linear Evaluation')
    for i, (cycle, labels, meta) in enumerate(pbar):

        # 나이
        child = meta[2].unsqueeze(1)

        # 성별
        female = meta[3].unsqueeze(1)

        # (crackle, wheeze, child, female)
        labels = torch.cat([labels, child, female], dim=1)   

        # Forward pass
        cycle = cycle.cuda(args.gpu)
        cycle = (cycle - mean) / (std + 1e-6)
        labels = labels.cuda(args.gpu)

        # backpropagation
        optimizer.zero_grad()
        output = model(cycle)

        if criterion is None:
            bce_loss = F.binary_cross_entropy_with_logits(output, labels, reduction='none')      # [B, 4]
            weighted_loss = bce_loss * label_weights                    # shape: [B, 4]
            loss_base = weighted_loss.sum(dim=1).mean()                 # 각 샘플별 sum → batch mean

            crackle_loss = bce_loss[:, 0]   # [B]
            wheeze_loss = bce_loss[:, 1]    # [B]
            child_loss = bce_loss[:, 2]    # [B]
            female_loss = bce_loss[:, 3]    # [B]

            # joint_loss 계산 (둘 다 어려운 샘플 penalize)
            joint_loss = F.relu(crackle_loss*label_weights[0]  + wheeze_loss*label_weights[1] - 1.) ** 2   # [B]
            joint_loss_mean = joint_loss.mean()

            # --- Final Loss ---
            loss = loss_base + joint_loss_mean

            # [DEBUG] 각 loss 로깅용
            total_cr_loss += crackle_loss.mean()
            total_wh_loss += wheeze_loss.mean()
            total_ch_loss += child_loss.mean()
            total_fe_loss += female_loss.mean()

        else:
            loss = criterion(output, labels)

        loss.backward()
        optimizer.step()

        # loss 계산
        total_loss += loss.item() # loss : -> float

        # 예측값과 실제값 저장 ( Ablation(4-1) threshold )
        predicted = (torch.sigmoid(output) > 0.5).float()
        all_preds.append(predicted.detach().cpu())
        all_labels.append(labels.detach().cpu())
        all_outputs.append(output.detach().cpu())

    # train loss
    train_loss = total_loss / len(train_loader)

    train_crackle_loss = total_cr_loss / len(train_loader)
    train_wheeze_loss = total_wh_loss / len(train_loader)
    train_child_loss = total_ch_loss / len(train_loader)
    train_female_loss = total_fe_loss / len(train_loader)

    # Concatenate
    all_preds = torch.cat(all_preds, dim=0).numpy()    # shape: [N, 4]
    all_labels = torch.cat(all_labels, dim=0).numpy()  # shape: [N, 4]
    all_output = torch.cat(all_outputs, dim=0).numpy()

    print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}")
    
    # =====================================
    # 2-Edited. Multi-class 민감도/특이도 계산
    # =====================================
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    import wandb
    from sklearn.metrics import confusion_matrix

    def multilabel_to_multiclass(y):
        # Crackle → 1, Wheeze → 2, Both → 3, None → 0
        y = np.array(y)
        return y[:, 0] + y[:, 1]*2

    def evaluate_multiclass_confusion(y_true, y_pred, class_names=["Normal", "Wheeze", "Crackle", "Both"]):
        y_true_cls = multilabel_to_multiclass(y_true)
        y_pred_cls = multilabel_to_multiclass(y_pred)

        cm = confusion_matrix(y_true_cls, y_pred_cls, labels=[0, 1, 2, 3])

        # N_n: 정상 → 정상
        N_n = cm[0, 0]
        N_total = cm[0].sum()

        # 이상 클래스 정답 수: W, C, B
        W_total = cm[1].sum()
        C_total = cm[2].sum()
        B_total = cm[3].sum()

        # 각각의 정답 → 정확한 예측만 고려
        W_w = cm[1, 1]
        C_c = cm[2, 2]
        B_b = cm[3, 3]

        SP = N_n / (N_total + 1e-6) #spec
        SE = (W_w + C_c + B_b) / (W_total + C_total + B_total + 1e-6) #sense

        AS = (SP + SE) / 2
        HS = 2 * SP * SE / (SP + SE + 1e-6)

        return cm, SE, SP, y_true_cls, y_pred_cls

    def log_multiclass_conf_matrix_wandb(cm, class_names, sens, spec, normalize, tag):
        # Normalize (비율) 옵션
        if normalize:
            cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
            fmt = '.2f'
            title = "Confusion Matrix (Normalized %)"
        else:
            fmt = 'd'
            title = "Confusion Matrix (Raw Count)"

        fig, ax = plt.subplots(figsize=(7, 6))
        sns.heatmap(cm, annot=True, fmt=fmt, cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)

        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(title)

        icbhi_score = (sens + spec) / 2
        # 우하단에 성능 출력
        ax.text(
            0.99, 0.15,
            f"Sensitivity: {sens*100:.2f}%\nSpecificity: {spec*100:.2f}%\nICBHI Score: {icbhi_score*100:.2f}%",
            ha='right', va='bottom',
            transform=plt.gca().transAxes,
            fontsize=10, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
        )

        plt.tight_layout()
        # wandb.log({tag: wandb.Image(fig)})
        # plt.close(fig)
        return fig

    # 1. 4-class Confusion Matrix 평가
    class_names = ["Normal", "Crackle", "Wheeze", "Both"]
    cm_4x4, finetune_train_sens, finetune_train_spec, y_true_cls, y_pred_cls = evaluate_multiclass_confusion(all_labels, all_preds, class_names)
    finetune_icbhi_score = (finetune_train_sens + finetune_train_spec)/2

    print("4-Class Confusion Matrix:\n", cm_4x4)
    print(f"Sensitivity: {finetune_train_sens:.4f}, Specificity: {finetune_train_spec:.4f}, ICBHI Score: {finetune_icbhi_score:.4f}")


    # ===============================
    # 3. Validation
    # ===============================
    test_loss, test_crackle_loss, test_wheeze_loss, test_child_loss, test_female_loss, test_labels, test_preds = validate(
        model, test_loader, criterion, device
    )

    precision = precision_score(test_labels, test_preds, average='macro')
    recall = recall_score(test_labels, test_preds, average='macro')
    f1 = f1_score(test_labels, test_preds, average='macro')

    test_cm_4x4, test_sens, test_spec, test_y_true_cls, test_y_pred_cls = evaluate_multiclass_confusion(test_labels, test_preds)
    test_icbhi_score = (test_sens+test_spec)/2

    print("[Validation] Confusion Matrix:\n", test_cm_4x4)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"[VALIDATION] Sensitivity: {test_sens:.4f}, Specificity: {test_spec:.4f}, Avg ICBHI Score: {(test_sens+test_spec)/2:.4f}")
    print("##################################################")


    # ===============================
    # 4. Confusion Matrix
    # ===============================

    # 2. Finetune Count Confusion Matrix 시각화
    fig_finetune_raw = log_multiclass_conf_matrix_wandb(cm_4x4, class_names, finetune_train_sens, finetune_train_spec, normalize=False, tag="finetune_conf_matrix_raw")
    fig_finetune_norm = log_multiclass_conf_matrix_wandb(cm_4x4, class_names, finetune_train_sens, finetune_train_spec, normalize=True, tag="finetune_conf_matrix_norm")

    # 3. Test Confusion Matrix 시각화
    fig_test_raw = log_multiclass_conf_matrix_wandb(test_cm_4x4, class_names, test_sens, test_spec, normalize=False, tag="test_conf_matrix_raw")
    fig_test_norm = log_multiclass_conf_matrix_wandb(test_cm_4x4, class_names, test_sens, test_spec, normalize=True, tag="test_conf_matrix_norm")

    # 4. log dictionary 생성
    wandb_log_dict = {
        "finetune_conf_matrix_raw": wandb.Image(fig_finetune_raw),
        "finetune_conf_matrix_norm": wandb.Image(fig_finetune_norm),
        "test_conf_matrix_raw": wandb.Image(fig_test_raw),
        "test_conf_matrix_norm": wandb.Image(fig_test_norm)
    }

    # =====================================
    # 5. Checkpoint (Every 50 epochs)
    # =====================================
    if (epoch + 1) % 50 == 0:
        ckpt_path = save_ckpt_path + f"{finetune_project_name}_{epoch:03d}.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, ckpt_path)
        print(f"💾 Saved checkpoint to {save_ckpt_path}")

    # ===============================
    # 6. Save Best Checkpoint
    # ===============================
    if test_loss < best_loss:
        best_loss = test_loss
        best_epoch = epoch
        best_ckpt_path = save_ckpt_path + f"{finetune_project_name}_best.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': best_loss
        }, best_ckpt_path)
        print(f"=> Saved best checkpoint (epoch: {epoch}, loss: {best_loss:.4f})")


        # 🔹 Confusion Matrix Logging for Best
        cm_best, sens_best, spec_best,_, _ = evaluate_multiclass_confusion(test_labels, test_preds, class_names)
        fig_best_raw = log_multiclass_conf_matrix_wandb(cm_best, class_names, sens_best, spec_best, normalize=False, tag="best_test_conf_matrix_raw")

        fig_best_norm = log_multiclass_conf_matrix_wandb(cm_best, class_names, sens_best, spec_best, normalize=True, tag="best_test_conf_matrix_norm")

        wandb_log_dict.update({
            "best_test_conf_matrix_raw": wandb.Image(fig_best_raw),
            "best_test_conf_matrix_norm": wandb.Image(fig_best_norm)
        })


    if epoch == args.ft_epochs - 1:
        # 🔸 Confusion Matrix Logging for Last Epoch
        cm_last, sens_last, spec_last, _, _  = evaluate_multiclass_confusion(test_labels, test_preds, class_names)
        fig_last_raw = log_multiclass_conf_matrix_wandb(cm_last, class_names, sens_last, spec_last, normalize=False, tag="last_test_conf_matrix_raw")

        fig_last_norm = log_multiclass_conf_matrix_wandb(cm_last, class_names, sens_last, spec_last, normalize=True, tag="last_test_conf_matrix_norm")

        wandb_log_dict.update({
            "last_test_conf_matrix_raw": wandb.Image(fig_last_raw),
            "last_test_conf_matrix_norm": wandb.Image(fig_last_norm)
        })
    # =====================================
    # 7. Logging with wandb confusion matrix
    # =====================================

    # step 1. metrics
    wandb.log({
        # Train metrics
        "Finetune/epoch": epoch,
        "Finetune/train_loss": train_loss,
        "Finetune/train_crackle_loss": train_crackle_loss,
        "Finetune/train_wheeze_loss": train_wheeze_loss,
        "Finetune/train_child_loss": train_child_loss,
        "Finetune/train_female_loss": train_female_loss,
        "Finetune/train_sens": finetune_train_sens,
        "Finetune/train_spec": finetune_train_spec,
        "Finetune/icbhi_score": finetune_icbhi_score,

        # Test metrics
        "Test/loss": test_loss,
        "Test/test_crackle_loss": test_crackle_loss,
        "Test/test_wheeze_loss": test_wheeze_loss,
        "Test/test_child_loss": test_child_loss,
        "Test/test_female_loss": test_female_loss,
        "Test/sensitivity": test_sens,
        "Test/specificity": test_spec,
        "Test/icbhi_score": test_icbhi_score
    })

    # step 2. Confusion matrix
    wandb.log(wandb_log_dict)

    plt.close(fig_finetune_raw)
    plt.close(fig_finetune_norm)
    plt.close(fig_test_raw)
    plt.close(fig_test_norm)
    if 'fig_best_raw' in locals(): plt.close(fig_best_raw)
    if 'fig_best_norm' in locals(): plt.close(fig_best_norm)
    if 'fig_last_raw' in locals(): plt.close(fig_last_raw)
    if 'fig_last_norm' in locals(): plt.close(fig_last_norm)

    # ===============================
    # 8. Scheduler Step
    # ===============================
    scheduler.step()

wandb.finish()

In [None]:
TP = cm_last[1:, 1:].sum()
FN = cm_last[1:, 0].sum()
FP = cm_last[0, 1:].sum()
TN = cm_last[0, 0]
print( f"{TP}/{FN}/{FP}/{TN}" )

sens =FN / (TP + FN + 1e-6)
spec = TN / (TN + FP + 1e-6)

print(f"{TP}/{TP + FN}")
print(f"{TN}/{TN + FP}")
print(f"{sens}/{spec}")
print(f"{(0.31+spec)/2}")

497/680/539/1040
497/1177
1040/1579
0.5777400165014953/0.6586447114258108
0.4843223557129054


In [None]:
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))


In [None]:
# import numpy as np

# # sigmoid 적용
# sigmoid_output = sigmoid(all_output)  # shape: (N, 2)
# all_preds = (sigmoid_output > 0.5).astype(int)  # binary prediction
# all_labels = all_labels.astype(int)  # 정수형으로 일치

# # 맞춘 것들
# correct_mask = np.all(all_preds == all_labels, axis=1)
# correct = np.concatenate([sigmoid_output, all_preds, all_labels], axis=1)[correct_mask]

# # 틀린 것들
# incorrect_mask = ~correct_mask
# incorrect_preds = all_preds[incorrect_mask]
# incorrect_labels = all_labels[incorrect_mask]
# incorrect_sigmoid = sigmoid_output[incorrect_mask]
# incorrect_concat = np.concatenate([incorrect_sigmoid, incorrect_preds, incorrect_labels], axis=1)

# # 그룹별 필터링
# def get_mismatched_by_label(target_label):
#     mask = np.all(incorrect_labels == target_label, axis=1)
#     return incorrect_concat[mask]

# # 각 그룹 추출
# wrong_10 = get_mismatched_by_label([1, 0])  # crackle
# wrong_01 = get_mismatched_by_label([0, 1])  # wheeze
# wrong_11 = get_mismatched_by_label([1, 1])  # both
# wrong_00 = get_mismatched_by_label([0, 0])  # normal


In [None]:
# print("\n✅ 맞춘 것들 (예: [sigmoid1, sigmoid2, pred1, pred2, label1, label2])")
# print(correct)

In [None]:
# import numpy as np
# np.set_printoptions(threshold=10000000)
# print(np.concatenate([sigmoid(all_output), all_preds, all_labels], axis=1)[:])

In [None]:
# all_output[:128]

In [None]:
# len(all_outputs[0])