!pip install transformers
!pip install ipywidgets
!pip install pytorch-lightning==1.5.10
!pip install nvidia-ml-py3
!pip install neptune-client
!pip install lightning-bolts
!pip install torchmetrics

from transformers import Wav2Vec2Processor, Wav2Vec2Model, Wav2Vec2Config
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2FeatureEncoder, Wav2Vec2NoLayerNormConvLayer, Wav2Vec2LayerNormConvLayer
from torch import nn
from transformers.activations import ACT2FN
import ipywidgets
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import csv
import torchaudio
import torchtext
import pytorch_lightning as pl
import nvidia_smi
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers import NeptuneLogger
from IPython.display import display, HTML
from dataclasses import dataclass, field
from torch.utils.data import DataLoader
from typing import Any, Dict, List, Optional, Union
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torchmetrics import Accuracy
from torchmetrics import F1Score
import torch.nn.functional as F
import numpy as np
import contextlib

!nvidia-smi

# Report the installed PyTorch Lightning version so runs can be reproduced.
print(f"Pytorch Lightning Version: {pl.__version__}")
# Initialise NVML and look up the GPU at index 0; `handle` stays at module
# level so it can be reused for further NVML queries.
nvidia_smi.nvmlInit()
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
print(f"Device name: {nvidia_smi.nvmlDeviceGetName(handle)}")

# Neptune experiment logger.
# SECURITY: the API token must never be committed to source control. It is
# read from the NEPTUNE_API_TOKEN environment variable instead; when the
# variable is unset, None is passed and the Neptune client applies its own
# token lookup (and reports a missing token itself).
# NOTE(review): the token that was previously hard-coded here should be
# revoked in the Neptune UI, since it has been exposed in the file history.
neptune_logger = NeptuneLogger(
    api_key=os.environ.get("NEPTUNE_API_TOKEN"),
    project='kgrosero/IA025-Project-wav2vec2')

# Colab form fields: the trailing `#@param` annotations are read by Colab to
# render the input form, so they must stay on the assignment lines.
version = "wav2vec2-sound_detection_train1" #@param {type: "string"}
lr = 1e-5#@param {type: "number"}
w_decay = 0#@param {type: "number"}
bs = 16#@param {type: "integer"}
accum_grads = 4#@param {type: "integer"}
patience = 30#@param {type: "integer"}
max_epochs = 300#@param {type: "integer"}
warmup_steps = 1000#@param {type: "integer"}
hold_epochs = 20#@param {type: "integer"}
pretrained = "facebook/wav2vec2-base-960h"#@param {type: "string"}
wav2vec2_processor = "facebook/wav2vec2-base-960h"#@param {type: "string"}
freeze_finetune_updates = 0#@param {type: "integer"}
warmup_epochs = 40#@param {type: "integer"}
apply_mask=False#@param {type: "boolean"}
mask_time_length= 10#@param {type: "integer"}, was 1

# Gather the run configuration into a single mapping. Key order is kept
# stable because the dict is displayed below.
# NOTE: `warmup_steps` is defined above but is not part of `hparams`.
hparams = dict(
    version=version,
    lr=lr,
    w_decay=w_decay,
    bs=bs,
    patience=patience,
    hold_epochs=hold_epochs,
    accum_grads=accum_grads,
    pretrained=pretrained,
    wav2vec2_processor=wav2vec2_processor,
    freeze_finetune_updates=freeze_finetune_updates,
    warmup_epochs=warmup_epochs,
    apply_mask=apply_mask,
    mask_time_length=mask_time_length,
    max_epochs=max_epochs,
)
# Bare expression: shows the dict as the notebook cell output.
hparams

# Load the processor named by hparams["wav2vec2_processor"], asking it to
# return an attention mask alongside the input values.
# NOTE(review): confirm that the chosen checkpoint is meant to be used with an
# attention mask; the Transformers docs advise against it for some wav2vec2
# base checkpoints -- TODO verify.
processor = Wav2Vec2Processor.from_pretrained(hparams["wav2vec2_processor"], return_attention_mask=True)

# Print the loaded processor for inspection.
print(processor)

class Wav2Vec2GroupNormConvLayer(nn.Module):
    """1-D convolution followed by group normalisation and an activation.

    When used as the first layer (``layer_id == 0``) the convolution takes
    4 input channels; otherwise the input width is the previous layer's
    entry in ``config.conv_dim``.

    Args:
        config: Object providing ``conv_dim``, ``conv_kernel``,
            ``conv_stride``, ``conv_bias`` and ``feat_extract_activation``.
        layer_id: Index of this layer in the per-layer config sequences.
    """

    def __init__(self, config, layer_id=0):
        super().__init__()

        # The input layer is fixed at 4 channels; deeper layers chain from
        # the previous layer's output width.
        if layer_id > 0:
            in_channels = config.conv_dim[layer_id - 1]
        else:
            in_channels = 4
        out_channels = config.conv_dim[layer_id]
        self.in_conv_dim = in_channels
        self.out_conv_dim = out_channels

        # Submodules are assigned in the order conv -> activation ->
        # layer_norm so module registration order is unchanged.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]

        # One normalisation group per output channel.
        self.layer_norm = nn.GroupNorm(
            num_groups=out_channels,
            num_channels=out_channels,
            affine=True,
        )

    def forward(self, hidden_states):
        """Apply convolution, group norm and activation, in that order."""
        return self.activation(self.layer_norm(self.conv(hidden_states)))

class Wav2Vec2_4ChannelFeatureEncoder(nn.Module):
    """Stack of convolutional layers that turns the input into features.

    With ``config.feat_extract_norm == "group"`` the first layer is a
    ``Wav2Vec2GroupNormConvLayer`` and the remaining layers are
    ``Wav2Vec2NoLayerNormConvLayer``; with ``"layer"`` every layer is a
    ``Wav2Vec2LayerNormConvLayer``.

    NOTE(review): the "layer" branch builds layer 0 from the stock
    transformers class; confirm that it accepts 4-channel input.

    Raises:
        ValueError: If ``config.feat_extract_norm`` is neither ``"group"``
            nor ``"layer"``.
    """

    def __init__(self, config):
        super().__init__()

        norm_kind = config.feat_extract_norm
        if norm_kind == "group":
            stack = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)]
            stack.extend(
                Wav2Vec2NoLayerNormConvLayer(config, layer_id=idx)
                for idx in range(1, config.num_feat_extract_layers)
            )
        elif norm_kind == "layer":
            stack = [
                Wav2Vec2LayerNormConvLayer(config, layer_id=idx)
                for idx in range(config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

        self.conv_layers = nn.ModuleList(stack)
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        """Turn off gradients for every parameter and record that state."""
        for weight in self.parameters():
            weight.requires_grad_(False)
        self._requires_grad = False

    def forward(self, input_values):
        """Pass ``input_values`` through every convolutional layer in order.

        All channels of the input are forwarded unchanged to the first layer.
        """
        # Full slice: keeps every channel of the input.
        features = input_values[:]

        # While training an unfrozen encoder, the input must require grad so
        # that gradient checkpointing works.
        track_grad = self._requires_grad and self.training
        if track_grad:
            features.requires_grad = True
        use_checkpoint = track_grad and self.gradient_checkpointing

        def _as_checkpoint_fn(module):
            # Wrap the layer in a plain callable for the checkpoint API.
            def _run(*inputs):
                return module(*inputs)

            return _run

        for layer in self.conv_layers:
            if use_checkpoint:
                features = torch.utils.checkpoint.checkpoint(
                    _as_checkpoint_fn(layer),
                    features,
                )
            else:
                features = layer(features)

        return features

# New model that inherits the Wav2Vec2 processing but uses the 4-channel
# feature extractor.
class Wav2Vec2_4ChannelModel(Wav2Vec2Model):
    """Wav2Vec2Model subclass whose feature extractor is the 4-channel encoder."""

    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)

        # Assign the 4-channel encoder after the parent constructor has run;
        # presumably this overrides the extractor the parent set up -- TODO confirm.
        #del self.feature_extractor
        self.feature_extractor = Wav2Vec2_4ChannelFeatureEncoder(config)