In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import librosa
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import Wav2Vec2ForPreTraining, Trainer, TrainingArguments

In [None]:
class AudioDataset(Dataset):
    """In-memory dataset of mono waveforms decoded from every ``.wav`` file in a directory.

    All files are decoded eagerly in ``__init__`` with ``librosa.load`` (resampled to
    ``sr`` Hz, ``mono=True``) and kept as ``float32`` tensors.

    Args:
        directory: folder to scan (non-recursive) for names ending in ``.wav``.
        sr: target sampling rate passed to ``librosa.load`` (default 16000).
    """

    def __init__(self, directory, sr=16000):
        self.directory = directory
        self.sr = sr
        # Sorted so the index -> file mapping (and therefore seeded shuffling) is
        # reproducible; os.listdir returns entries in arbitrary order.
        self.filenames = [
            os.path.join(directory, f)
            for f in sorted(os.listdir(directory))
            if f.endswith('.wav')
        ]
        self.files = []

        for file in tqdm(self.filenames, desc="Loading audio files"):
            audio, _ = librosa.load(file, sr=self.sr, mono=True)
            self.files.append(torch.tensor(audio, dtype=torch.float32))

    def __len__(self):
        """Number of decoded clips."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return the ``idx``-th decoded waveform as a float32 tensor."""
        return self.files[idx]

# Instantiate the dataset (every clip is decoded into memory up front).
dataset = AudioDataset('./reduce/data/train')
# NOTE(review): `loader` is never used -- Trainer builds its own DataLoader from
# `train_dataset`. Kept only in case a later cell references it.
loader = DataLoader(dataset, batch_size=1, shuffle=True)

# Load the pretrained checkpoint with the pre-training (contrastive) head.
# NOTE(review): if this checkpoint lacks quantizer / project_q / project_hid weights,
# transformers warns that they were newly initialized -- check the load log.
model = Wav2Vec2ForPreTraining.from_pretrained("kresnik/wav2vec2-large-xlsr-korean")


def pretraining_collator(batch):
    """Batch raw waveforms and add the inputs Wav2Vec2ForPreTraining needs for a loss.

    ``Wav2Vec2ForPreTraining`` only computes its contrastive loss when it receives
    ``mask_time_indices`` and ``sampled_negative_indices``. A collator that sends
    ``input_values`` alone leaves ``outputs.loss`` as None, and ``Trainer`` cannot train.

    Args:
        batch: list of waveform tensors from ``AudioDataset`` (expected 1-D; lengths
            may differ).

    Returns:
        dict with ``input_values`` (zero-padded, ``(batch, max_len)``),
        ``mask_time_indices`` and ``sampled_negative_indices`` (long tensors over
        encoder frames), plus ``attention_mask`` when the model's feature extractor
        uses layer norm.
    """
    from transformers.models.wav2vec2.modeling_wav2vec2 import (
        _compute_mask_indices,
        _sample_negative_indices,
    )

    input_values = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True)
    lengths = torch.tensor([x.shape[0] for x in batch])
    # 1 on real samples, 0 on padding.
    attention_mask = (torch.arange(input_values.shape[1]) < lengths[:, None]).long()

    # Number of frames the convolutional feature encoder emits for the padded length.
    num_frames = int(model._get_feat_extract_output_lengths(torch.tensor(input_values.shape[-1])))
    # Frame-level mask so padded frames are never masked or sampled as negatives.
    frame_mask = model._get_feature_vector_attention_mask(num_frames, attention_mask)
    features_shape = (input_values.shape[0], num_frames)

    # NOTE(review): uses the checkpoint config's mask_time_prob / mask_time_length; a
    # config saved after fine-tuning may mask far less than a pre-training recipe would.
    # Very short clips (fewer frames than mask_time_length) may make masking raise.
    mask_time_indices = _compute_mask_indices(
        features_shape,
        model.config.mask_time_prob,
        model.config.mask_time_length,
        attention_mask=frame_mask,
    )
    sampled_negative_indices = _sample_negative_indices(
        features_shape,
        model.config.num_negatives,
        mask_time_indices=mask_time_indices,
    )

    collated = {
        'input_values': input_values,
        'mask_time_indices': torch.tensor(mask_time_indices, dtype=torch.long),
        'sampled_negative_indices': torch.tensor(sampled_negative_indices, dtype=torch.long),
    }
    # transformers documents that attention_mask should only be passed to wav2vec2
    # models whose feature extractor uses layer norm.
    if model.config.feat_extract_norm == "layer":
        collated['attention_mask'] = attention_mask
    return collated


# Trainer configuration
training_args = TrainingArguments(
    output_dir='./wav2vec2_pretrained',
    per_device_train_batch_size=1,  # adjust to available GPU memory
    num_train_epochs=10,            # set an appropriate number of epochs
    logging_dir='./logs',
    logging_steps=10,
    save_steps=500,
    do_train=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=pretraining_collator,
    train_dataset=dataset
)

# Start pre-training, then save the final weights explicitly: with save_steps=500 a
# short run can finish before any periodic checkpoint is written.
train_result = trainer.train()
trainer.save_model()
train_result

In [1]:
import os
import torch
from transformers import Wav2Vec2ForSequenceClassification, Trainer, TrainingArguments
import torchaudio

# Dataset for demonstration (placeholder labels)
class AudioDataset(torch.utils.data.Dataset):
    """Dataset of ``.wav`` files in a directory paired with placeholder binary labels.

    NOTE(review): this redefines ``AudioDataset`` from the pre-training cell with a
    different item format (``(waveform, label)`` instead of a bare waveform); whichever
    cell ran last determines which class the name refers to.

    Files are decoded lazily in ``__getitem__``, averaged down to one channel, and
    resampled to ``sr`` Hz when their native rate differs.

    Args:
        directory: folder to scan (non-recursive) for names ending in ``.wav``.
        sr: target sampling rate in Hz (default 16000).
    """

    def __init__(self, directory, sr=16000):
        self.directory = directory
        self.sr = sr  # target sampling rate
        # Sorted so every file keeps the same index (and placeholder label) across runs
        # and machines; os.listdir order is arbitrary.
        self.files = sorted(f for f in os.listdir(directory) if f.endswith('.wav'))
        # Placeholder labels alternating 0/1 by index -- for demonstration only.
        self.labels = [i % 2 for i in range(len(self.files))]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(waveform, label)``: a 1-D float tensor at ``self.sr`` Hz and an int."""
        file_path = os.path.join(self.directory, self.files[idx])
        waveform, orig_sr = torchaudio.load(file_path)  # (channels, time)
        # Average the channels: squeeze(0) alone left multi-channel files 2-D, which
        # breaks pad_sequence-based collation. For mono input this equals squeeze(0).
        waveform = waveform.mean(dim=0)
        if orig_sr != self.sr:
            waveform = torchaudio.functional.resample(waveform, orig_freq=orig_sr, new_freq=self.sr)
        return waveform, self.labels[idx]

# Binary classifier on top of the XLSR-53 encoder; the classification head is freshly
# initialized (the load warning in this cell's output reports it) and learned in training.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53",
    num_labels=2,
)

# Training clips with placeholder 0/1 labels.
dataset = AudioDataset(
    './reduce/data/train'
)

# Custom data collator (pads input_values and builds a matching attention mask)
def data_collator(data):
    """Collate ``(waveform, label)`` pairs into a padded batch.

    Args:
        data: list of ``(waveform, label)`` pairs; waveforms are expected to be 1-D
            tensors (possibly of different lengths), labels are ints.

    Returns:
        dict with
          - ``input_values``: ``(batch, max_len)`` waveforms right-padded with zeros,
          - ``attention_mask``: ``(batch, max_len)`` long tensor, 1 on real samples and
            0 on padding,
          - ``labels``: ``(batch,)`` class labels.
    """
    waveforms, labels = zip(*data)
    # Pad sequences to the same length (zeros on the right).
    batched_waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
    lengths = torch.tensor([w.shape[0] for w in waveforms])
    # Lets the model ignore the zero padding instead of treating it as real audio.
    attention_mask = (torch.arange(batched_waveforms.shape[1]) < lengths[:, None]).long()
    return {
        'input_values': batched_waveforms,
        'attention_mask': attention_mask,
        'labels': torch.tensor(labels),  # class labels
    }

# Custom Trainer class inheriting from Trainer
class CustomTrainer(Trainer):
    """Trainer whose loss forwards only the collated audio keys to the model."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on a collated batch and return its loss.

        Args:
            model: the model Trainer is training (possibly wrapped).
            inputs: batch dict from ``data_collator`` with ``input_values`` and
                ``labels``, and optionally ``attention_mask``.
            return_outputs: when True (Trainer's prediction step requests this), return
                ``(loss, outputs)`` instead of just the loss.
            **kwargs: extra arguments newer Trainer versions pass (e.g.
                ``num_items_in_batch``); accepted and ignored so the override keeps
                matching the base-class signature.
        """
        model_inputs = {'input_values': inputs['input_values'], 'labels': inputs['labels']}
        # Forward the padding mask when the collator provides one.
        if 'attention_mask' in inputs:
            model_inputs['attention_mask'] = inputs['attention_mask']
        outputs = model(**model_inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

# Trainer configuration
training_args = TrainingArguments(
    # Kept separate from the pre-training cell's './wav2vec2_pretrained' so the
    # classifier's checkpoints do not overwrite or mix with the pre-training ones.
    output_dir='./wav2vec2_classifier',
    per_device_train_batch_size=1,  # reduced batch size
    gradient_accumulation_steps=8,  # effective batch per device = 1 x 8
    num_train_epochs=3,
    logging_dir='./logs',
    logging_steps=10,
    save_steps=500,
    do_train=True
)

# Custom Trainer initialization
trainer = CustomTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset
)

# Start training, then save the final model explicitly: the recorded run took only
# 3 optimizer steps, far short of save_steps=500, so no periodic checkpoint is written.
train_result = trainer.train()
trainer.save_model()
train_result


  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at facebook/wav2vec2-large-xlsr-53 were not used when initializing Wav2Vec2ForSequenceClassification: ['quantizer.weight_proj.bias', 'project_q.bias', 'quantizer.codevectors', 'quantizer.weight_proj.weight', 'project_hid.bias', 'project_hid.weight', 'project_q.weight']
- This IS expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are ne

{'train_runtime': 757.3309, 'train_samples_per_second': 0.008, 'train_steps_per_second': 0.004, 'train_loss': 0.17487879594167074, 'epoch': 3.0}





TrainOutput(global_step=3, training_loss=0.17487879594167074, metrics={'train_runtime': 757.3309, 'train_samples_per_second': 0.008, 'train_steps_per_second': 0.004, 'train_loss': 0.17487879594167074, 'epoch': 3.0})