In [3]:
### DenseNet121 N_FFT 256 - 23 EPOCHS - 0.96 ACC ON 20% TEST
import os, math
import numpy as np
import pandas as pd 

import librosa
from scipy import signal

import tqdm 
from matplotlib import pyplot as plt

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

from torchvision.models import densenet121
import torch
import torch.nn as nn; import torch.nn.functional as F
import torch.optim as optim; from torch.optim import lr_scheduler
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
#from keras.layers import Input

#from keras.callbacks import ModelCheckpoint
#from keras.callbacks import EarlyStopping
#from keras.callbacks import ReduceLROnPlateau
#from keras.callbacks import CSVLogger

#from keras import Model
#from keras import backend as K

#from keras.utils import np_utils
#from keras.preprocessing import image
 
#from keras.applications.densenet import DenseNet121


In [4]:
current_model = densenet121


model_name = 'wingbeats_' + current_model.__name__

best_weights_path = model_name + '.pt'
log_path = model_name + '.log'
monitor = 'val_acc'
batch_size = 32
epochs = 100
es_patience = 7
rlr_patience = 3

SR = 8000
N_FFT = 256
HOP_LEN = N_FFT // 6
input_shape = (129, 120, 1)
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available else 'cpu'
seed = 2018
np.random.seed(seed)

In [6]:
target_names = ['Ae. aegypti', 'Ae. albopictus', 'An. gambiae', 'An. arabiensis', 'C. pipiens', 'C. quinquefasciatus']

X_names = []
y = []
target_count = []

for i, target in enumerate(target_names):
    target_count.append(0)
    path = './../Wingbeats/' + target + '/'
    for [root, dirs, files] in os.walk(path, topdown = False):
        for filename in files:
            name,ext = os.path.splitext(filename)
            if ext == '.wav':
                name = os.path.join(root, filename)
                y.append(i)
                X_names.append(name)
                target_count[i]+=1
                # if target_count[i] > 20000:
                #     break
    print (target, '#recs = ', target_count[i])

print ('total #recs = ', len(y))

X_names, y = shuffle(X_names, y, random_state = seed)
X_train, X_test, y_train, y_test = train_test_split(X_names, y, stratify = y, test_size = 0.20, random_state = seed)

n_samples = len(X_train)
n_tests = len(X_test)
print('train #recs = ', len(X_train))
print('test #recs = ', len(X_test))
print('Total : ', len(X_names))

Ae. aegypti #recs =  85553
Ae. albopictus #recs =  20231
An. gambiae #recs =  49471
An. arabiensis #recs =  19297
C. pipiens #recs =  30415
C. quinquefasciatus #recs =  74599
total #recs =  279566
train #recs =  223652
test #recs =  55914
Total :  279566


In [7]:
def shift(x, wshift, hshift, row_axis=1, col_axis=2, channel_axis=0, fill_mode='constant', cval=0.):
    """
    Versión corregida de la función shift para PyTorch
    """
    # Asegurar que el tensor es 4D (batch, channels, height, width)
    if x.dim() == 3:
        # Si es (C, H, W), añadir dimensión batch
        x = x.unsqueeze(0)
        was_3d = True
    else:
        was_3d = False
    
    # Obtener dimensiones
    batch_size, channels, h, w = x.shape
    
    # Calcular desplazamiento en píxeles
    tx = hshift * h  # Desplazamiento vertical
    ty = wshift * w  # Desplazamiento horizontal
    
    # Crear matriz de transformación para PyTorch (2x3)
    theta = torch.tensor([
        [1, 0, ty],  # Transformación en X (ancho)
        [0, 1, tx]   # Transformación en Y (alto)
    ], dtype=torch.float32)
    
    # Repetir para todo el batch
    theta = theta.unsqueeze(0).repeat(batch_size, 1, 1)
    
    # Mover al mismo dispositivo que x
    theta = theta.to(x.device)
    
    # Crear grid de transformación (necesita tamaño 4D)
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    
    # Mapear modos de relleno
    padding_mode = 'zeros' if fill_mode == 'constant' else fill_mode
    
    # Aplicar la transformación
    x = F.grid_sample(x, grid, mode='bilinear', padding_mode=padding_mode, align_corners=False)
    
    # Remover dimensión batch si originalmente era 3D
    if was_3d:
        x = x.squeeze(0)
    
    return x

def random_data_shift(data, w_limit=(-0.25, 0.25), h_limit=(-0.0, 0.0), cval=0., u=0.5):
    """Versión corregida de random_data_shift"""
    if torch.rand(1) < u:
        wshift = torch.empty(1).uniform_(w_limit[0], w_limit[1]).item()
        hshift = torch.empty(1).uniform_(h_limit[0], h_limit[1]).item()
        data = shift(data, wshift, hshift, cval=cval)
    return data

In [8]:
class AudioDataset(Dataset):
    def __init__(self, file_paths, labels, target_names, sr=22050, n_fft=2048, hop_len=512):
        self.file_paths = file_paths
        self.labels = labels
        self.target_names = target_names
        self.sr = sr
        self.n_fft = n_fft
        self.hop_len = hop_len
    
    def __len__(self):
        return len(self.file_paths)
    
    def __getitem__(self, idx):
        data, _ = librosa.load(self.file_paths[idx], sr=self.sr)
        
        data = librosa.stft(data, n_fft=self.n_fft, hop_length=self.hop_len)
        data = librosa.amplitude_to_db(np.abs(data))
        data = np.flipud(data)  # Voltear verticalmente
        data = data.copy()
        data = torch.tensor(data, dtype=torch.float32)
        data = data.unsqueeze(0)  # Añadir canal: (1, H, W)
        
        data = random_data_shift(data, 
                               w_limit=(-0.25, 0.25), 
                               h_limit=(-0.0, 0.0), 
                               cval=float(data.min()), 
                               u=1.0)
        
        label = self.labels[idx]
        label_tensor = torch.tensor(label, dtype=torch.long)
        
        return data, label_tensor

In [9]:
def create_loader(X_train, y_train, target_names, batch_size=32, sr=22050, n_fft=2048, hop_len=512, shuffle=True):
    dataset = AudioDataset(X_train, y_train, target_names, sr, n_fft, hop_len)
    
    def collate_fn(batch):
        data_list, label_list = [], []
        for data, label in batch:
            data_list.append(data)
            label_list.append(label)
        
        data_batch = torch.stack(data_list)
        
        labels_batch = torch.stack(label_list)
        labels_one_hot = torch.nn.functional.one_hot(labels_batch, num_classes=len(target_names)).float()
        
        return data_batch, labels_one_hot
    
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
    return loader

In [10]:
train_dataloader = create_loader(X_train, y_train, target_names, batch_size=batch_size,
                        sr=SR, n_fft=N_FFT, hop_len=HOP_LEN, shuffle=True)
test_dataloader = create_loader(X_test, y_test, target_names, batch_size=batch_size,
                        sr=SR, n_fft=N_FFT, hop_len=HOP_LEN, shuffle=False)

In [11]:
def setup_model_for_audio(current_model, num_classes, input_channels=1):
    """
    Configura automáticamente el modelo para espectrogramas de audio
    """
    model = current_model(num_classes=num_classes)
    
    # Detectar si el modelo es de torchvision y necesita ajuste de canales
    model_name = model.__class__.__name__.lower()
    
    # Lista de modelos que normalmente esperan 3 canales
    rgb_models = ['resnet', 'densenet', 'alexnet', 'vgg', 'mobilenet', 'inception']
    
    needs_channel_adjustment = any(rgb_model in model_name for rgb_model in rgb_models)
    
    if needs_channel_adjustment:
        print(f"Ajustando {model_name} para {input_channels} canal(es) de entrada")
        model = adjust_first_conv_layer(model, input_channels)
    
    if torch.cuda.is_available():
        model = model.cuda()
    
    return model

def adjust_first_conv_layer(model, input_channels):
    """
    Ajusta la primera capa convolucional para aceptar input_channels
    """
    if hasattr(model, 'conv1'):
        # Para modelos como ResNet
        original_conv = model.conv1
        model.conv1 = nn.Conv2d(
            input_channels, 
            original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=original_conv.bias is not None
        )
    elif hasattr(model, 'features') and hasattr(model.features, 'conv0'):
        # Para DenseNet
        original_conv = model.features.conv0
        model.features.conv0 = nn.Conv2d(
            input_channels,
            original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            bias=original_conv.bias is not None
        )
    elif hasattr(model, 'features') and len(model.features) > 0:
        # Para VGG, AlexNet, etc.
        first_conv = model.features[0]
        if isinstance(first_conv, nn.Conv2d):
            model.features[0] = nn.Conv2d(
                input_channels,
                first_conv.out_channels,
                kernel_size=first_conv.kernel_size,
                stride=first_conv.stride,
                padding=first_conv.padding,
                bias=first_conv.bias is not None
            )
    
    return model

In [12]:
model = setup_model_for_audio(current_model, len(target_names), input_channels=1)
if torch.cuda.is_available():
    model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='max',
    factor=0.1, 
    patience=rlr_patience
)

Ajustando densenet para 1 canal(es) de entrada


In [14]:
class PyTorchCallbacks:
    def __init__(self, model, best_weights_path, log_path, monitor='val_loss', patience=10):
        self.model = model
        self.best_weights_path = best_weights_path
        self.log_path = log_path
        self.monitor = monitor
        self.patience = patience
        
        self.best_metric = float('inf') if monitor == 'val_loss' else 0
        self.epochs_no_improve = 0
        self.writer = SummaryWriter()  # Para logging (opcional)
        
        # Crear directorios si no existen
        #os.makedirs(os.path.dirname(best_weights_path), exist_ok=True)
        #os.makedirs(os.path.dirname(log_path), exist_ok=True)

    def on_epoch_end(self, current_metric, epoch):
        improved = False
        
        if self.monitor == 'val_loss' and current_metric < self.best_metric:
            improved = True
        elif self.monitor == 'val_accuracy' and current_metric > self.best_metric:
            improved = True
        
        if improved:
            print(f"\nMétrica mejorada de {self.best_metric:.4f} a {current_metric:.4f}")
            self.best_metric = current_metric
            self.epochs_no_improve = 0
            
            # Guardar mejores pesos (ModelCheckpoint)
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'best_metric': self.best_metric
            }, self.best_weights_path)
            print(f"Pesos guardados en {self.best_weights_path}")
        else:
            self.epochs_no_improve += 1
    
    def should_stop(self):
        if self.epochs_no_improve >= self.patience:
            print(f"\nEarlyStopping: No mejora después de {self.patience} épocas")
            return True
        return False
    
    def log_metrics(self, epoch, train_loss, val_loss, train_acc, val_acc):
        log_entry = f"{epoch},{train_loss:.4f},{val_loss:.4f},{train_acc:.4f},{val_acc:.4f}\n"
        
        with open(self.log_path, 'a') as f:
            if epoch == 0:
                f.write("epoch,train_loss,val_loss,train_accuracy,val_accuracy\n")
            f.write(log_entry)
        
        self.writer.add_scalar('Loss/train', train_loss, epoch)
        self.writer.add_scalar('Loss/val', val_loss, epoch)
        self.writer.add_scalar('Accuracy/train', train_acc, epoch)
        self.writer.add_scalar('Accuracy/val', val_acc, epoch)

In [15]:
def train_epoch(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(dataloader):
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
        
        # CONVERTIR ONE-HOT A ÍNDICES
        if target.dim() > 1 and target.shape[1] > 1:
            target = target.argmax(dim=1)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
    
    return total_loss / len(dataloader), correct / total

def validate_epoch(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in dataloader:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            
            # CONVERTIR ONE-HOT A ÍNDICES
            if target.dim() > 1 and target.shape[1] > 1:
                target = target.argmax(dim=1)
            
            output = model(data)
            loss = criterion(output, target)
            
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    
    return total_loss / len(dataloader), correct / total

In [16]:
# Inicializar callbacks
callbacks = PyTorchCallbacks(
    model=model,
    best_weights_path=best_weights_path,
    log_path=log_path,
    monitor=monitor,
    patience=es_patience
)

# Loop de entrenamiento
for epoch in range(epochs):
    # Entrenamiento
    model.train()
    train_loss, train_acc = train_epoch(model, train_dataloader, optimizer, criterion)
    
    # Validación
    model.eval()
    val_loss, val_acc = validate_epoch(model, test_dataloader, criterion)
    
    # Callbacks
    current_metric = val_loss if monitor == 'val_loss' else val_acc
    callbacks.on_epoch_end(current_metric, epoch)
    callbacks.log_metrics(epoch, train_loss, val_loss, train_acc, val_acc)
    
    # Verificar early stopping
    if callbacks.should_stop():
        break

callbacks.writer.close()


EarlyStopping: No mejora después de 7 épocas
