In [8]:
import contextlib
from dataclasses import dataclass
from typing import List, Dict, Union, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
# NOTE(review): the next line binds the *module* torch.utils.data to the name
# `DataLoader`; it is only harmless because the `from torch.utils.data import
# DataLoader` below rebinds the name to the class. Keep that order.
import torch.utils.data as DataLoader
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader
from transformers import Wav2Vec2ForMaskedLM, Wav2Vec2Processor
from transformers import Wav2Vec2Processor

In [3]:
# Hyperparameters shared by the DataLoader cell and the LightningModule.
hparams = {
    "num_clusters": 5,
    "batch_size": 16,
    "lr": 0.001,
    "w_decay": 0.0001,
    "warmup_epochs": 3,
    "max_epochs": 10,
    # Checkpoint passed to `from_pretrained` by the LightningModule; this is the
    # same checkpoint the notebook loads for the processor.
    "pretrained": "facebook/wav2vec2-base-960h",
    # Forwarded as `apply_spec_augment` when the encoder is built.
    # NOTE(review): False is a conservative placeholder — confirm the intended value.
    "apply_mask": False,
    # Add any other hyperparameters your model needs
}


In [10]:
# Hub identifier used for both the processor and the pretrained weights.
model_name = "facebook/wav2vec2-base-960h"

# The processor is what the data collator uses to pad raw `input_values`.
processor = Wav2Vec2Processor.from_pretrained(model_name)

# Loading these weights reports `wav2vec2.masked_spec_embed` as newly
# initialised (see the warning captured in this cell's output).
model = Wav2Vec2ForMaskedLM.from_pretrained(model_name)

Some weights of Wav2Vec2ForMaskedLM were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
@dataclass
class DataCollatorForClustering:
    """Collate per-example feature dicts into one padded batch.

    Each example's ``"input_values"`` is padded through ``processor.pad`` and
    returned as PyTorch tensors. When *every* example also carries a
    ``"target"`` entry, the targets are stacked into a ``torch.long`` tensor
    and returned under ``batch["target"]``, which is the key the training /
    validation / test steps in this notebook pop from the batch. Batches
    without targets are returned exactly as ``processor.pad`` produced them.

    Attributes are forwarded unchanged to ``processor.pad``:
        processor: object exposing ``pad(...)`` (a ``Wav2Vec2Processor`` here).
        padding: padding strategy flag/string.
        max_length: optional maximum length.
        pad_to_multiple_of: optional length multiple.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` into a batch.

        Args:
            features: one dict per example; each must contain ``"input_values"``
                and may contain an integer-like ``"target"``.

        Returns:
            The mapping returned by ``processor.pad`` (with ``return_tensors="pt"``),
            plus a ``"target"`` tensor when all examples provided one.
        """
        # Only the audio values go through the processor's padding logic.
        input_features = [{"input_values": feature["input_values"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Carry labels through so downstream steps can `batch.pop("target")`.
        # int() accepts Python ints as well as 0-d tensors / NumPy scalars.
        if features and all("target" in feature for feature in features):
            batch["target"] = torch.tensor(
                [int(feature["target"]) for feature in features], dtype=torch.long
            )

        return batch


In [11]:
# Collator handed to both the training and the validation DataLoader.
# A fixed `max_length=188` was tried and is left disabled, so each batch is
# padded according to `padding=True` only.
data_collator = DataCollatorForClustering(
    padding=True,
    processor=processor,
)

In [None]:

# Read the batch size from the key that `hparams` actually defines
# ("batch_size"); there is no "bs" entry, so the old lookup raised KeyError.
batch_size = hparams["batch_size"]

# NOTE(review): `train_dataset` and `val_dataset` are not created in the cells
# shown in this notebook; they must already exist in the kernel.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                              collate_fn=data_collator,
                              shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size,
                            collate_fn=data_collator,
                            shuffle=False, num_workers=4)

print('Número de minibatches de entrenamiento:', len(train_dataloader))
print('Número de minibatches de validación:', len(val_dataloader))

# Pull a single training batch as a sanity check of the collated audio tensor.
batch = next(iter(train_dataloader))
x_train = batch['input_values']
print("\nDimensión de los datos de un minibatch - Audio:", x_train.size())
print("Valores mínimo y máximo entrada: ", torch.min(x_train), torch.max(x_train))
print("Tipo de los datos de los audios: ", type(x_train))


In [None]:
class Wav2Vec2_Clusterization(pl.LightningModule):
    """Wav2Vec2 encoder with masked mean pooling and a linear classification head.

    The encoder's last hidden state is zeroed on padded frames, averaged over
    the valid frames of each example, projected with ``projector`` and mapped
    to class scores by ``final_layer``. ``forward`` returns log-probabilities
    and every step uses ``F.nll_loss`` against the batch's ``"target"`` entry.

    Configuration is read from the module-level ``hparams`` dict:
        required: ``"pretrained"``, ``"apply_mask"``, ``"lr"``, ``"w_decay"``,
            ``"warmup_epochs"``, ``"max_epochs"``
        optional: ``"freeze_finetune_updates"`` (default 0),
            ``"n_classes"`` (default 11)
    """

    def __init__(self, *args, **kwargs):
        """Build the encoder and the classification head.

        Positional and keyword arguments are accepted but not used.
        """
        super().__init__()

        self.hparams.update(hparams)

        # Per-epoch mean F1 values, appended by the *_epoch_end hooks.
        self.val_f1_scores = []
        self.test_f1_scores = []

        # forward() keeps the encoder under torch.no_grad() until
        # trainer.global_step reaches this value. It has to be assigned here,
        # otherwise forward() fails with AttributeError. The default of 0
        # means the encoder is trainable from the first step.
        self.freeze_finetune_updates = hparams.get("freeze_finetune_updates", 0)

        self.model = Wav2Vec2_4ChannelModel.from_pretrained(
            hparams["pretrained"],
            conv_dim=(512, 512, 512, 512, 512, 512),
            conv_stride=(5, 2, 2, 2, 2, 2),
            conv_kernel=(10, 3, 3, 3, 3, 2),
            num_feat_extract_layers=6,
            apply_spec_augment=hparams["apply_mask"],
            # mask_time_length=hparams["mask_time_length"],
            ignore_mismatched_sizes=True,
        )

        # Optional freezing, left disabled:
        # self.model.feature_extractor._freeze_parameters()
        # for param in self.model.parameters():
        #     param.requires_grad = False

        self.projector = nn.Linear(self.model.config.hidden_size, self.model.config.classifier_proj_size)
        # Number of output classes; 11 unless hparams provides "n_classes".
        n_classes = hparams.get("n_classes", 11)
        self.final_layer = nn.Linear(self.model.config.classifier_proj_size, n_classes)

    def forward(self, samples):
        """Return per-example log-probabilities over the classes.

        Args:
            samples: mapping unpacked as keyword arguments into the encoder;
                it must contain ``"attention_mask"``.
                (assumes the encoder output is (batch, frames, hidden) — TODO confirm)

        Returns:
            ``log_softmax`` (over dim 1) of the classification head's output.
        """
        # Gradients flow through the encoder only once the freeze period is over.
        fine_tune = self.freeze_finetune_updates <= self.trainer.global_step
        grad_context = contextlib.nullcontext() if fine_tune else torch.no_grad()

        with grad_context:
            hidden_states = self.model(**samples).last_hidden_state

        # Frame-level mask derived from the sample-level attention mask.
        padding_mask = self.model._get_feature_vector_attention_mask(hidden_states.shape[1], samples["attention_mask"])

        # Zero the padded frames without writing into the encoder output in place.
        hidden_states = hidden_states.masked_fill(~padding_mask.unsqueeze(-1), 0.0)

        # Mean over valid frames only: sum of frames / number of valid frames.
        pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        proj_pooled = self.projector(pooled_output)

        preds = self.final_layer(proj_pooled)

        return F.log_softmax(preds, dim=1)

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        """Convert a sample-level attention mask into a frame-level boolean mask.

        Args:
            feature_vector_length: number of frames in the encoder output.
            attention_mask: per-sample mask whose row sums give the valid
                input lengths.

        Returns:
            Boolean tensor of shape ``(batch, feature_vector_length)`` that is
            True for every frame up to and including each example's last
            valid frame.
        """
        # The length helper belongs to the wrapped encoder, not to this module.
        output_lengths = self.model._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )

        # Mark the last valid frame of each row, then fill everything before it
        # with a reversed cumulative sum.
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask

    def training_step(self, train_batch, batch_idx):
        """Compute and log the NLL loss for one training batch.

        ``"target"`` is removed from ``train_batch`` before the forward pass.
        """
        y_value = train_batch.pop("target")
        log_softs = self.forward(train_batch)

        loss = F.nll_loss(log_softs, y_value)

        self.log('loss_step', loss, on_step=True, prog_bar=True)

        return loss

    def training_epoch_end(self, outputs):
        """Log the mean of the per-step training losses as ``train_loss``."""
        loss = torch.stack([x['loss'] for x in outputs]).mean()

        self.log("train_loss", loss, prog_bar=True)

    def _shared_eval_step(self, batch, stage):
        """Run one evaluation batch and log ``{stage}_acc/_f1/_loss``.

        Returns the three values keyed as ``{stage}_acc_step``,
        ``{stage}_f1_step`` and ``{stage}_loss_step``.
        """
        y_value = batch.pop("target")

        log_softs = self.forward(batch)
        preds = torch.argmax(log_softs, dim=1)

        # Metrics are computed on CPU copies; the loss stays on the batch's device.
        acc_value = accuracy(preds.cpu(), y_value.cpu())
        f1_value = f1(preds.cpu(), y_value.cpu())
        loss_value = F.nll_loss(log_softs, y_value)

        self.log(f'{stage}_acc', acc_value, prog_bar=True)
        self.log(f'{stage}_f1', f1_value, prog_bar=True)
        self.log(f'{stage}_loss', loss_value, prog_bar=True)

        return {
            f"{stage}_acc_step": acc_value,
            f"{stage}_f1_step": f1_value,
            f"{stage}_loss_step": loss_value,
        }

    def _shared_epoch_end(self, outputs, stage):
        """Average the per-step metrics, log them, and return the mean F1."""
        acc_mean = torch.stack([x[f'{stage}_acc_step'] for x in outputs]).mean()
        f1_mean = torch.stack([x[f'{stage}_f1_step'] for x in outputs]).mean()
        loss_mean = torch.stack([x[f'{stage}_loss_step'] for x in outputs]).mean()

        self.log(f"{stage}_acc", acc_mean, prog_bar=True)
        self.log(f"{stage}_f1", f1_mean, prog_bar=True)
        self.log(f"{stage}_loss", loss_mean, prog_bar=True)

        return f1_mean

    def validation_step(self, val_batch, batch_idx):
        """Evaluate one validation batch (accuracy, F1, NLL loss)."""
        return self._shared_eval_step(val_batch, "val")

    def validation_epoch_end(self, outputs):
        """Log the epoch means and record the mean F1 in ``val_f1_scores``."""
        self.val_f1_scores.append(self._shared_epoch_end(outputs, "val"))

    def test_step(self, test_batch, batch_idx):
        """Evaluate one test batch (accuracy, F1, NLL loss)."""
        return self._shared_eval_step(test_batch, "test")

    def test_epoch_end(self, outputs):
        """Log the epoch means and record the mean F1 in ``test_f1_scores``."""
        self.test_f1_scores.append(self._shared_epoch_end(outputs, "test"))

    def configure_optimizers(self):
        """Create the Adam optimizer and the warmup + cosine-annealing scheduler."""
        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=self.hparams["lr"],
                                     betas=(0.9, 0.98),
                                     eps=1e-6,
                                     weight_decay=self.hparams["w_decay"])

        # NOTE(review): warmup_start_lr equals the base lr here. If the
        # scheduler ramps linearly from warmup_start_lr to the base lr, the
        # warmup phase is flat — confirm whether a lower starting value was intended.
        scheduler = LinearWarmupCosineAnnealingLR(optimizer,
                                                  eta_min=0,
                                                  warmup_start_lr=self.hparams["lr"],
                                                  warmup_epochs=self.hparams["warmup_epochs"],
                                                  max_epochs=self.hparams["max_epochs"])

        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
    