# In [None]:
import torch
import torch.nn as nn
import torch.nn.utils as utils
from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer, ViTModel
import evaluate

# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

#############################################
########### Configuration ###################
#############################################

MAX_STEPS = 50
LOG_INTERVER = 1
BATCH_SIZE = 5  # number of audio files processed per batch

# SPSA settings. The value each constant had in the previous configuration is
# noted next to it for reference.
EPSILON = 0.001  # perturbation size (previously 0.0001)
ALPHA = 0.602  # learning-rate scaling (previously 0.602)
GAMMA = 0.101  # decay (previously 0.101)
AK = 0.00001  # previously 0.0004
CK = 0.005  # value used when estimating the gradient (previously 0.000025)
O = 7  # previously 10
P_TRIGGER_EPSILON = 0.0000001  # used when updating p_trigger (previously 0.00000005)

MAX_GRAD = 3000  # gradient clipping limit

# Loss type. The alternative setting is "wer" (WER-based loss).
LOSS_FN = "cross entropy"
MAX_FRAMES = 3000
meshgrid_HIDDEN_DIM = 768  # ViT hidden_dim
# Whisper default max tokens (448) minus 4, the number of decoder_input_ids.
MAX_NEW_TOKENS = 444
ENCODER_NAME = "google/vit-base-patch16-224-in21k"

# The alternative checkpoint is "openai/whisper-large-v3".
whisper_version = "openai/whisper-small"

# Number of mel bins for each supported Whisper checkpoint.
_MEL_BINS_BY_VERSION = {
    "openai/whisper-small": 80,
    "openai/whisper-large-v3": 128,
}
if whisper_version not in _MEL_BINS_BY_VERSION:
    raise ValueError("위스퍼 버전 확인 필요")
NUM_MEL_BINS = _MEL_BINS_BY_VERSION[whisper_version]

#############################################
#############################################


#############################################
#####  MeshGridMask (Binary Mask) #####
#############################################

# ✅ **MeshGridMask: 필터 생성 (0 또는 1)**
class MeshGridMask(nn.Module):
    """Binary mask multiplied element-wise onto a mel spectrogram.

    The mask lives in ``self.filter``, an ``nn.Parameter`` of shape
    ``(mel_bins, frames)`` created with ``requires_grad=False``. Every entry
    is drawn at random as 0.0 or 1.0 when the module is constructed.
    """

    def __init__(self, mel_bins, frames):
        super().__init__()
        # Random 0/1 start values, stored as float32 so they multiply directly
        # with the spectrogram.
        initial_mask = torch.randint(0, 2, (mel_bins, frames), dtype=torch.float32)
        self.filter = nn.Parameter(initial_mask, requires_grad=False)

    def forward(self, mel):
        """Return ``mel`` multiplied by the mask.

        ``mel`` is documented as ``(batch_size, mel_bins, frames)``; the
        ``(mel_bins, frames)`` mask is broadcast over the leading dimension.
        """
        masked_mel = mel * self.filter
        return masked_mel

############################

# Build the mask module, then move it to the selected device.
meshgrid = MeshGridMask(mel_bins=NUM_MEL_BINS, frames=MAX_FRAMES)
meshgrid = meshgrid.to(device)


############################
###### Update logic setup ######
############################

from whisper.normalizers import EnglishTextNormalizer

# Text normalizer applied to transcripts, and the WER metric that scores them.
normalizer = EnglishTextNormalizer()
wer_metric = evaluate.load("wer")

# def calculate_wer(references, predictions): 
#     return wer_metric.compute(references=references, predictions=predictions)

def calculate_wer(references, predictions, tokenizer):
    """Compute the word error rate between references and predictions.

    Args:
        references: One entry per utterance. Each entry is either an
            already-decoded transcript (``str``) or a sequence of token IDs
            that may be padded with ``tokenizer.pad_token_id``.
        predictions: Decoded hypothesis transcripts, one per reference.
        tokenizer: Object providing ``pad_token_id`` and ``decode``. It is
            only consulted for references given as token IDs.

    Returns:
        Whatever ``wer_metric.compute`` returns for the normalized texts.
    """
    reference_texts = []
    for ref in references:
        if isinstance(ref, str):
            # Already text: there is no padding to strip. Filtering the
            # characters against the integer pad ID never removed anything.
            reference_texts.append(ref)
            continue
        # Token IDs: drop the padding first, then decode back to text.
        token_ids = [int(tok) for tok in ref if int(tok) != tokenizer.pad_token_id]
        reference_texts.append(tokenizer.decode(token_ids, skip_special_tokens=True))

    # Normalize both sides identically so formatting differences are not
    # counted as word errors.
    normalized_references = [normalizer(text) for text in reference_texts]
    normalized_predictions = [normalizer(text) for text in predictions]

    return wer_metric.compute(references=normalized_references, predictions=normalized_predictions)


import torch.nn.functional as F

def calculate_cross_entropy_loss(whisper_model, mel_with_delta, labels):
    """Return the loss the model reports for a batch of masked spectrograms.

    Args:
        whisper_model: Callable invoked with ``input_features`` and ``labels``
            keyword arguments; its result must expose a ``loss`` attribute.
        mel_with_delta: Mel spectrogram after the mask has been applied.
        labels: Ground-truth token IDs.

    Returns:
        The ``loss`` attribute of the model output.
    """
    # NOTE(review): labels arrive padded with the tokenizer's pad token rather
    # than an ignore index, so padded positions presumably count towards the
    # loss - confirm this is intended.
    # The forward pass runs under no_grad so no autograd graph is recorded.
    with torch.no_grad():
        model_output = whisper_model(input_features=mel_with_delta, labels=labels)
        loss = model_output.loss
    return loss

# SPSA 업데이트 함수 (Trigger Vector + Decoder 학습)
# 디코더 전체를 벡터화하여 perturbation 후, 다시 벡터에서 파라미터로 변환


class SPSA:
    """Greedy random bit-flip search over a binary mel-spectrogram mask.

    Each call to ``spsa_update`` builds two candidate masks by inverting a
    random subset of the current mask's entries, scores both candidates with
    the Whisper model, and overwrites the mask with the lower-loss candidate.
    ``epsilon`` is the fraction of mask entries flipped per candidate; it is
    decayed geometrically by ``parameter_update``.
    """

    def __init__(self, epsilon=0.01, epsilon_decay=0.99, min_epsilon=0.0001):
        """
        Args:
            epsilon: Initial perturbation size (fraction of entries flipped).
            epsilon_decay: Per-step multiplicative decay, 0 < value < 1.
            min_epsilon: Lower bound that ``epsilon`` never drops below.
        """
        super().__init__()
        self.epsilon_0 = epsilon
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.min_epsilon = min_epsilon
        self.step_count = 0  # number of decay steps applied so far

    def parameter_update(self):
        """Advance one step and decay ``epsilon``, clamped at ``min_epsilon``."""
        self.step_count += 1
        decayed = self.epsilon_0 * (self.epsilon_decay ** self.step_count)
        self.epsilon = max(decayed, self.min_epsilon)

    @staticmethod
    def _flip_random_entries(mask, num_to_flip):
        """Return a copy of ``mask`` with randomly chosen entries inverted (0 <-> 1).

        Indices are drawn with replacement, so fewer than ``num_to_flip``
        distinct entries may change. ``mask`` itself is left untouched.
        """
        flipped = mask.detach().clone().contiguous()
        flat = flipped.view(-1)  # view shares storage, so writes land in `flipped`
        indices = torch.randint(0, flat.numel(), (num_to_flip,), device=flat.device)
        flat[indices] = 1 - flat[indices]
        return flipped

    def spsa_update(self, meshgrid: nn.Module, whisper_model, mel: torch.Tensor, labels: torch.Tensor):
        """Try two random perturbations of the mask and keep the better one.

        Args:
            meshgrid: Module whose ``filter`` tensor is the current mask; it
                is overwritten in place with the winning candidate.
            whisper_model: Model used for ``generate`` and for loss scoring.
            mel: Batch of mel spectrograms; each candidate mask is broadcast
                over the leading dimension.
            labels: Padded ground-truth token IDs for the batch.

        Returns:
            ``(loss_1, loss_2)``, the losses of the two candidates.

        Raises:
            ValueError: If ``LOSS_FN`` is neither "wer" nor "cross entropy".
        """
        torch.cuda.empty_cache()  # release cached GPU memory before the forward passes

        # Work on a detached copy so candidates never alias the live mask.
        filter_params = meshgrid.filter.detach().clone()

        # epsilon is the fraction of entries to flip. Flip at least one entry
        # (when the mask is non-empty) so a step can never compare two
        # unchanged masks.
        total_entries = filter_params.numel()
        num_to_flip = min(total_entries, max(1, int(self.epsilon * total_entries)))

        # Two independently perturbed candidates.
        filter_1 = self._flip_random_entries(filter_params, num_to_flip)
        filter_2 = self._flip_random_entries(filter_params, num_to_flip)
        mel_filtered_1 = mel * filter_1.unsqueeze(0)
        mel_filtered_2 = mel * filter_2.unsqueeze(0)

        predictions_1 = whisper_model.generate(input_features=mel_filtered_1, max_new_tokens=MAX_NEW_TOKENS)
        predictions_2 = whisper_model.generate(input_features=mel_filtered_2, max_new_tokens=MAX_NEW_TOKENS)

        ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)
        pred_1_texts = tokenizer.batch_decode(predictions_1, skip_special_tokens=True)
        pred_2_texts = tokenizer.batch_decode(predictions_2, skip_special_tokens=True)

        print("Ref", ref_texts)
        print("Pred_1",pred_1_texts)
        print("Pred_2",pred_2_texts)

        if LOSS_FN == "wer":
            loss_1 = calculate_wer(ref_texts, pred_1_texts, tokenizer)
            loss_2 = calculate_wer(ref_texts, pred_2_texts, tokenizer)
        elif LOSS_FN == "cross entropy":
            loss_1 = calculate_cross_entropy_loss(whisper_model, mel_filtered_1, labels)
            loss_2 = calculate_cross_entropy_loss(whisper_model, mel_filtered_2, labels)
        else:
            raise ValueError("Loss function not supported")

        # Copy in place so the module keeps the same parameter object.
        winner = filter_1 if loss_1 < loss_2 else filter_2
        meshgrid.filter.data.copy_(winner)

        return loss_1, loss_2


#########################################
####### Whisper 모델 로드 및 Freeze ########
#########################################

processor = WhisperProcessor.from_pretrained(whisper_version)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_version)
whisper_model = whisper_model.to(device)
tokenizer = WhisperTokenizer.from_pretrained(whisper_version, language="en", task="transcribe")

# Freeze Whisper: switch off requires_grad on every parameter of the model.
whisper_model.requires_grad_(False)



#######################################
######### Data preparation ############
#######################################

from datasets import load_dataset

DATASET_ID = "Jzuluaga/atcosim_corpus"
# Only the first 2% of the train split is loaded.
dataset = load_dataset(DATASET_ID, "default", split="train[:2%]")

# 데이터 전처리
def preprocess_data(batch):
    """Turn one dataset record into tensors placed on ``device``.

    Args:
        batch: Record with an ``"audio"`` entry (holding ``"array"`` and
            ``"sampling_rate"``) and a ``"text"`` entry.

    Returns:
        Dict with ``"mel"``, the float32 input features with a leading
        dimension of size 1 added, and ``"labels"``, the tokenized text with
        its leading dimension squeezed away.
    """
    audio = batch["audio"]
    extracted = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
    mel = extracted.input_features[0]

    encoded = processor.tokenizer(batch["text"], return_tensors="pt", padding="longest")
    label_ids = encoded.input_ids.squeeze(0)

    mel_tensor = torch.tensor(mel, dtype=torch.float32).unsqueeze(0)
    return {"mel": mel_tensor.to(device), "labels": label_ids.to(device)}


# Preprocess every record up front.
processed_dataset = [preprocess_data(record) for record in dataset]

# CustomDataset 및 DataLoader 정의
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

class CustomDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of samples."""

    def __init__(self, dataset):
        # Keep a reference to the wrapped collection; nothing is copied.
        self.dataset = dataset

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.dataset[idx]
        return sample

    def __len__(self):
        """Return the number of wrapped samples."""
        return len(self.dataset)
    

# ✅ collate_fn 정의 → 다른 길이의 labels 처리
def collate_fn(batch):
    """Combine samples into one batch, padding labels to a common length.

    Args:
        batch: List of dicts, each with a ``"mel"`` tensor and a ``"labels"``
            tensor.

    Returns:
        Dict with ``"mel"``, the stacked spectrograms after ``squeeze(0)``,
        and ``"labels"``, the labels padded with ``tokenizer.pad_token_id``
        to the longest sequence in the batch.
    """
    spectrograms = [sample["mel"].squeeze(0) for sample in batch]
    label_sequences = [sample["labels"] for sample in batch]

    pad_id = tokenizer.pad_token_id
    return {
        "mel": torch.stack(spectrograms),
        "labels": pad_sequence(label_sequences, batch_first=True, padding_value=pad_id),
    }

# ✅ DataLoader 적용 (배치 크기 지정)
dataset = CustomDataset(processed_dataset)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

###################
##### Training loop #####
###################

print("Update method: ", LOSS_FN)
# SPSA.__init__ only takes epsilon, epsilon_decay and min_epsilon. Passing the
# other hyperparameters (alpha, gamma, ak, ck, o, p_trigger_epsilon) raised a
# TypeError before training could start.
spsa = SPSA(epsilon=EPSILON)
avg_losses = []

# range() already bounds the loop at MAX_STEPS, so no extra break is needed.
for epoch in range(MAX_STEPS):
    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        mel = batch["mel"].to(device)
        labels = batch["labels"].to(device)

        loss_1, loss_2 = spsa.spsa_update(meshgrid, whisper_model, mel, labels)

        # Track the better of the two candidates. float() turns a tensor loss
        # into a plain number so no GPU tensor is kept alive across epochs.
        total_loss += float(min(loss_1, loss_2))
        num_batches += 1

        spsa.parameter_update()  # decay epsilon

    # Each batch loss is already a batch-level mean, so average over the
    # number of batches only. An empty loader gives NaN rather than 0.0.
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    avg_losses.append(avg_loss)

    print()
    print(f"Epoch {epoch}: Avg {LOSS_FN} Loss = {avg_loss:.4f}--------------------%%%%%%%%%%%%%%%%%%%@@@@@@@@@@@@@@@+++++++++++++++++++++")
    print()

# Save the mask module's state once training has finished.
torch.save(meshgrid.state_dict(), "meshgrid.pth")
print("meshgrid saved!")

import matplotlib.pyplot as plt

# Tensor losses are moved to the CPU and unwrapped to Python numbers; plain
# numbers pass through unchanged.
avg_losses_cpu = [
    value if not isinstance(value, torch.Tensor) else value.cpu().item()
    for value in avg_losses
]

plt.plot(avg_losses_cpu)
plt.title(f'Average {LOSS_FN} Over Epochs')
plt.xlabel('Epoch')
plt.ylabel(f'Average {LOSS_FN}')
plt.show()


369