In [None]:
! pip install pytorch-lightning

In [14]:
import os
import json
import torch
import torchaudio
from glob import glob

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from transformers import WhisperForConditionalGeneration, WhisperProcessor, TrainingArguments, Trainer, AdamW, get_scheduler

In [22]:
# ----- Params -----

data_dir = "datasets/OldPeople_Voice/label/"
audio_dir = "datasets/OldPeople_Voice/"
save_dir = "whisper_finetuned"

# ----- ------ -----

In [3]:
class CustomAudioDataset(Dataset):
    def __init__(self, json_list, processor):
        self.processor = processor
        self.data = []

        # 모든 JSON 파일을 리스트로
        for json_path in json_list:
            with open(json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            
            # 오디오 파일 경로
            audio_file = os.path.join(audio_dir, data["발화정보"]["fileNm"])
            
            # 파일이 실제 존재하는지 확인 (오류 방지)
            if not os.path.exists(audio_file):
                print(f"⚠️ Warning: {audio_file} 파일이 존재하지 않습니다.")
                continue  # 해당 파일 건너뛰기
            
            text = data["발화정보"]["stt"]
            self.data.append((audio_file, text))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        audio_file, text = self.data[idx]

        waveform, sample_rate = torchaudio.load(audio_file)

        # 16kHz 샘플링 for Whisper
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)

        # 오디오 데이터 변환
        input_features = self.processor(
            waveform.squeeze(0).numpy(),
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features

        # 텍스트 토큰화 하기
        labels = self.processor.tokenizer(text, return_tensors="pt").input_ids

        return {
            "input_features": input_features.squeeze(0),
            "labels": labels.squeeze(0)
        }

In [4]:
def load_data(json_path):
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    audio_file = os.path.join(audio_dir, data["발화정보"]["fileNm"])
    text = data["발화정보"]["stt"]
    # duration = float(str(data["발화정보"]["recrdTime"]))

    return {
        "audio": audio_file,  # 파일 경로 저장
        "text": text,
        # "duration": duration
    }

In [20]:
batch_size = 4
learning_rate = 1e-5
num_epochs = 3
gradient_accumulation_steps = 2
device = "cuda" if torch.cuda.is_available() else "cpu"

In [28]:
model = WhisperForConditionalGeneration.from_pretrained("SungBeom/whisper-small-ko")
processor = WhisperProcessor.from_pretrained("openai/whisper-small")

json_list = glob(data_dir+"/*.json")
print("data path :",json_list)
dataset = CustomAudioDataset(json_list, processor)
print(f"{len(dataset)}개 로드.")
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)

config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/967M [00:00<?, ?B/s]

data path : ['datasets/OldPeople_Voice/label/노인남여_노인대화77_F_김XX_62_제주_실내_84050.json', 'datasets/OldPeople_Voice/label/노인남여_노인대화77_F_김XX_62_제주_실내_84051.json', 'datasets/OldPeople_Voice/label/노인남여_노인대화77_F_김XX_62_제주_실내_84052.json', 'datasets/OldPeople_Voice/label/노인남여_노인대화77_F_김XX_62_제주_실내_84053.json', 'datasets/OldPeople_Voice/label/노인남여_노인대화77_F_김XX_62_제주_실내_84054.json']
5개 로드.


In [29]:
# optimizer & scheduler 
optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * num_epochs
)

In [31]:
# Train
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for step, batch in enumerate(train_dataloader):
        # 배치에서 input_features와 labels 추출
        input_features = [item["input_features"].to(device) for item in batch]
        labels = [item["labels"].to(device) for item in batch]

        # padding 처리
        input_features = torch.nn.utils.rnn.pad_sequence(input_features, batch_first=True, padding_value=0)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)

        # 모델에 입력
        outputs = model(input_features, labels=labels)
        loss = outputs.loss / gradient_accumulation_steps  # gradient accumulation 적용
        loss.backward()

        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1 == len(train_dataloader)):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)
    print(f"🚀 Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # 모델 저장
    model.save_pretrained(f"{save_dir}/epoch_{epoch+1}")
    processor.save_pretrained(f"{save_dir}/epoch_{epoch+1}")


🚀 Epoch [1/3], Loss: 4.7223
🚀 Epoch [2/3], Loss: 4.1509
🚀 Epoch [3/3], Loss: 3.7681


# TEST

In [32]:
from transformers import WhisperForConditionalGeneration, WhisperProcessor

model_path = "whisper_finetuned/epoch_1"  # X는 저장한 에포크 번호
model = WhisperForConditionalGeneration.from_pretrained(model_path)
processor = WhisperProcessor.from_pretrained(model_path)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# 테스트 오디오 파일
audio_file = audio_dir + "노인남여_노인대화77_F_김XX_62_제주_실내_84051.WAV"

# 오디오 파일 로드
waveform, sample_rate = torchaudio.load(audio_file)

# Whisper는 16kHz 샘플링 속도를 사용하므로 변환 필요
if sample_rate != 16000:
    waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)

# 모델의 입력으로 변환
input_features = processor(waveform.squeeze(0).numpy(), sampling_rate=16000, return_tensors="pt").input_features
input_features = input_features.to(device)

# 모델을 통해 예측 수행
with torch.no_grad():
    predicted_ids = model.generate(input_features)

# 예측된 텍스트 디코딩
transcribed_text = processor.decode(predicted_ids[0], skip_special_tokens=True)

print("예측된 텍스트:", transcribed_text)


예측된 텍스트: 어 큰 나들만 좋아하는 것 같아
