In [1]:
import math
import random
import threading
import numpy as np

from matplotlib import pyplot as plt
from tqdm import tqdm

import librosa
import librosa.display

import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import TensorDataset, DataLoader, Dataset, random_split
from torch.cuda.amp import autocast, GradScaler
# from torchsummary import summary

from torch import nn, optim
import torch.nn.functional as F

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.get_device_name(torch.cuda.current_device()))

directory = 'E:/split_data/'

NVIDIA GeForce RTX 4070


### Model

In [2]:
# Basic Block for the CNN part of the network, using 1D convolutions
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        # First 1D convolutional layer
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        # Second 1D convolutional layer
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

# CRNN Model for EEG Encoding, with 1D convolutions
class CRNNEncoder(nn.Module):
    def __init__(self, rnn_type='LSTM', in_channel=64):
        super(CRNNEncoder, self).__init__()
        self.current_channels = in_channel  # Current number of channels

        # 1D CNN layers with residual connections
        self.layer1 = self._make_layer(BasicBlock, 128, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 256, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 512, 2, stride=2)
        # self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)

        # RNN layer
        self.hidden_size = 128
        if rnn_type == 'LSTM':
            self.rnn = nn.LSTM(512, self.hidden_size, batch_first=True)
        elif rnn_type == 'GRU':
            self.rnn = nn.GRU(512, self.hidden_size, batch_first=True)
        else: # RNN
            self.rnn = nn.RNN(512, self.hidden_size, batch_first=True)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.current_channels, out_channels, stride))
            self.current_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        # Pass input through 1D CNN layers
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        # out = self.layer4(out)

        # Flatten and pass through RNN
        out = out.view(out.size(0), -1, 512)
        out, _ = self.rnn(out)

        # Option 1: Return the output of the last CNN layer for encoding
        # cnn_encoding = out.view(out.size(0), -1)
        # return cnn_encoding

        # Option 2: Return the output of the RNN layer for encoding
        rnn_encoding = out[:, -1, :]  # Using the last time step
        return rnn_encoding




### Training

In [3]:
class Encoder(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super(Encoder, self).__init__()
        self.fc = nn.Linear(input_dim, embedding_dim)
        
    def forward(self, x):
        return self.fc(x)

# Define the contrastive loss with cosine similarity
class ContrastiveCosineLoss(nn.Module):
    def __init__(self, margin=0.5):
        super(ContrastiveCosineLoss, self).__init__()
        self.margin = margin
        self.cosine_similarity = nn.CosineSimilarity(dim=1, eps=1e-6)

    def forward(self, output1, output2, label):
        # Cosine similarity
        cos_sim = self.cosine_similarity(output1, output2)
        
        # Calculate the loss
        loss = (1 - label) * 0.5 * cos_sim**2 + \
               label * 0.5 * (torch.relu(self.margin - cos_sim) ** 2)
        
        return loss.mean()
    
    
class EEGMelDataset(Dataset):
    def __init__(self, eeg_data, mel_data):
        """
        Initializes the dataset with EEG and Mel spectrogram data.
        :param eeg_data: Numpy array of EEG data with shape [10000, 320, 65]
        :param mel_data: Numpy array of Mel spectrogram data with shape [10000, 320, 10]
        """
        assert eeg_data.shape[0] == mel_data.shape[0], "EEG and Mel data must have the same number of samples"
        
        self.eeg_data = torch.from_numpy(eeg_data)
        self.mel_data = torch.from_numpy(mel_data)
        
        self.non_matching_indices = [random.choice([i for i in range(len(self)) if i != idx]) for idx in range(len(self))]

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return self.eeg_data.shape[0]

    def __getitem__(self, idx):
        """
        Returns a pair of EEG and Mel spectrogram samples along with a label indicating correspondence.
        """
        # Randomly decide whether to fetch a corresponding pair or not
        match = np.random.randint(0, 2)  # 0 or 1
        eeg_sample = self.eeg_data[idx]

        if match:
            mel_sample = self.mel_data[idx]
        else:
            non_matching_idx = self.non_matching_indices[idx]
            mel_sample = self.mel_data[non_matching_idx]

        return eeg_sample, mel_sample, torch.tensor(match, dtype=torch.float32)

In [4]:
def calculate_accuracy(cos_sim, labels, threshold=0.5):
    """
    Calculate the accuracy based on cosine similarity and labels.
    Args:
    - cos_sim (Tensor): Cosine similarity between pairs.
    - labels (Tensor): Actual labels indicating if pairs are matching.
    - threshold (float): Threshold for deciding if pairs are considered a match.
    Returns:
    - accuracy (float): The accuracy of predictions.
    """
    preds = cos_sim > threshold
    correct = torch.sum(preds == labels.bool()).item()
    total = labels.size(0)
    accuracy = correct / total
    return accuracy

def train(eeg_encoder, mel_eocoder, criterion, optimizer, train_set, test_set, batch_size, num_epochs=10, lr_decay_patience=5, early_stopping_patience=10):
    torch.cuda.empty_cache()
    
    scaler = GradScaler()
    
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
    
    # stop_training = False

    # def check_for_input():
    #     global stop_training
    #     input("Press Enter to stop training after the current epoch...\n")
    #     stop_training = True
    
    # input_thread = threading.Thread(target=check_for_input)
    # input_thread.daemon = True  # This ensures the thread will be killed when the main program exits
    # input_thread.start()

    eeg_encoder = eeg_encoder.to(device)
    mel_eocoder = mel_eocoder.to(device)

    # Scheduler for learning rate decay
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=lr_decay_patience, factor=0.5)
    prev_lr = optimizer.param_groups[0]['lr']

    best_test_loss = float('inf')
    best_test_accu = float('-inf')
    
    epochs_no_improve = 0
    early_stop = False
    
    train_losses = []
    test_losses = []
    train_accus = []
    test_accus = []
    
    for epoch in tqdm(range(num_epochs)):
        # if stop_training:
        #     print("Stopping training...")
        #     break
    
        # ========== Training Loop ==========
        eeg_encoder.train()
        mel_eocoder.train()
        
        train_loss = 0
        train_accu = 0
        for eeg, mel, labels in train_loader:
            
            optimizer.zero_grad()
            
            with autocast():
                eeg_embedding = eeg_encoder(eeg.to(device))
                mel_embedding = mel_eocoder(mel.to(device))
                
                cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6)(eeg_embedding, mel_embedding)
                train_accu += calculate_accuracy(cos_sim, labels.to(device))
                
                loss = criterion(eeg_embedding, mel_embedding, labels.to(device))
            
            # Scales loss, calls backward() to create scaled gradients
            scaler.scale(loss).backward()

            # Optimizer step and update scaler
            scaler.step(optimizer)
            scaler.update()
            
            train_loss += loss.item()
            
        train_losses.append(train_loss / len(train_loader))
        train_accus.append(train_accu / len(train_loader))

        # ========== Testing Loop ==========
        eeg_encoder.eval()
        mel_eocoder.eval()
        test_loss = 0
        test_accu = 0
        with torch.no_grad():
            for eeg, mel, labels in test_loader:
                
                eeg_embedding = eeg_encoder(eeg.to(device))
                mel_embedding = mel_eocoder(mel.to(device))
                
                cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6)(eeg_embedding, mel_embedding)
                test_accu += calculate_accuracy(cos_sim, labels.to(device))
                
                loss = criterion(eeg_embedding, mel_embedding, labels.to(device))
                test_loss += loss.item()

            test_losses.append(test_loss / len(test_loader))
            test_accus.append(test_accu / len(test_loader))
    
        # ========== Update scheduler ==========
        scheduler.step(test_losses[-1])
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr != prev_lr:
            print(f"Epoch {epoch}: Learning rate changed to {current_lr}")
            prev_lr = current_lr
            
        # Save model parameters
        curr_accu = test_accus[-1]
        if curr_accu > best_test_accu:
            best_test_accu = curr_accu
            torch.save(eeg_encoder.state_dict(), f'eeg_encoder_best.pth')
            torch.save(mel_eocoder.state_dict(), f'mel_encoder_best.pth')
            print(f"Epoch {epoch+1}: Test accuracy improved to {curr_accu:.4f}, saving model.")
        
        # Early stopping
        if test_losses[-1] < best_test_loss:
            best_test_loss = test_losses[-1]
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= early_stopping_patience:
                print("Early stopping triggered.")
                early_stop = True

        if early_stop:
            break

        print(f"Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Train Accu: {train_accus[-1]:.4f}, Test Accu: {test_accus[-1]:.4f}")
        print('===============')

    return train_losses, train_accus, test_losses, test_accus

In [5]:
train_eeg = np.transpose(np.load('train_x.npy'), (0, 2, 1))
train_mel = np.transpose(np.load('train_y.npy'), (0, 2, 1))
test_eeg = np.transpose(np.load('test_x.npy'), (0, 2, 1))
test_mel = np.transpose(np.load('test_y.npy'), (0, 2, 1))

train_set = EEGMelDataset(train_eeg, train_mel)
test_set = EEGMelDataset(test_eeg, test_mel)

In [6]:
eeg_encoder = CRNNEncoder(rnn_type='LSTM', in_channel=65).cuda()
audio_encoder = CRNNEncoder(rnn_type='LSTM', in_channel=10).cuda()

criterion = ContrastiveCosineLoss()
optimizer = optim.Adam(eeg_encoder.parameters(), lr=0.01)

epochs = 300
batch_size = 900
lr_decay = 10
early_stopping = 20
train_loss, train_accu, test_loss, test_accu = train(eeg_encoder, 
                                                     audio_encoder, 
                                                     criterion, 
                                                     optimizer, 
                                                     train_set, 
                                                     test_set, 
                                                     batch_size, 
                                                     epochs, 
                                                     lr_decay, 
                                                     early_stopping)

trained_epochs = len(train_loss)
x_range = np.linspace(1, trained_epochs, trained_epochs)
plt.figure(figsize=(12,6))
plt.plot(x_range, train_loss, label='Training loss')
plt.plot(x_range, test_loss, label='Test loss')
plt.title('loss')
plt.xlabel('Epochs')
plt.ylabel('loss')
plt.legend()
plt.grid(True)
plt.show()

plt.figure(figsize=(12,6))
plt.plot(x_range, train_accu, label='Training Accu')
plt.plot(x_range, test_accu, label='Test Accu')
plt.title('Accu')
plt.xlabel('Epochs')
plt.ylabel('Accu')
plt.legend()
plt.grid(True)
plt.show()

NVIDIA GeForce RTX 4070


  0%|          | 1/300 [00:03<19:05,  3.83s/it]

Epoch 1: Test accuracy improved to 0.5109, saving model.
Train Loss: 0.0471, Test Loss: 0.0431, Train Accu: 0.5005, Test Accu: 0.5109


  1%|          | 2/300 [00:06<14:42,  2.96s/it]

Train Loss: 0.0384, Test Loss: 0.0421, Train Accu: 0.4938, Test Accu: 0.5109


  1%|          | 3/300 [00:08<13:11,  2.66s/it]

Train Loss: 0.0348, Test Loss: 0.0333, Train Accu: 0.5047, Test Accu: 0.4746


  1%|▏         | 4/300 [00:10<12:31,  2.54s/it]

Train Loss: 0.0341, Test Loss: 0.0333, Train Accu: 0.4977, Test Accu: 0.4717


  2%|▏         | 5/300 [00:13<12:08,  2.47s/it]

Train Loss: 0.0339, Test Loss: 0.0328, Train Accu: 0.4974, Test Accu: 0.4891


  2%|▏         | 6/300 [00:15<11:49,  2.41s/it]

Train Loss: 0.0337, Test Loss: 0.0328, Train Accu: 0.5017, Test Accu: 0.5022


  2%|▏         | 7/300 [00:17<11:40,  2.39s/it]

Epoch 7: Test accuracy improved to 0.5269, saving model.
Train Loss: 0.0338, Test Loss: 0.0337, Train Accu: 0.5089, Test Accu: 0.5269


  3%|▎         | 8/300 [00:20<11:32,  2.37s/it]

Train Loss: 0.0337, Test Loss: 0.0331, Train Accu: 0.5054, Test Accu: 0.5080


  3%|▎         | 9/300 [00:22<11:29,  2.37s/it]

Train Loss: 0.0339, Test Loss: 0.0339, Train Accu: 0.4978, Test Accu: 0.4615


  3%|▎         | 10/300 [00:24<11:15,  2.33s/it]

Train Loss: 0.0336, Test Loss: 0.0336, Train Accu: 0.4857, Test Accu: 0.4964


  4%|▎         | 11/300 [00:27<11:08,  2.31s/it]

Train Loss: 0.0335, Test Loss: 0.0340, Train Accu: 0.4935, Test Accu: 0.4790


  4%|▍         | 12/300 [00:29<11:03,  2.31s/it]

Train Loss: 0.0331, Test Loss: 0.0338, Train Accu: 0.5055, Test Accu: 0.5109


  4%|▍         | 13/300 [00:31<11:02,  2.31s/it]

Train Loss: 0.0330, Test Loss: 0.0333, Train Accu: 0.4997, Test Accu: 0.5094


  5%|▍         | 14/300 [00:33<11:01,  2.31s/it]

Train Loss: 0.0329, Test Loss: 0.0331, Train Accu: 0.4992, Test Accu: 0.5080


  5%|▌         | 15/300 [00:36<10:59,  2.31s/it]

Train Loss: 0.0328, Test Loss: 0.0326, Train Accu: 0.4949, Test Accu: 0.4993


  5%|▌         | 16/300 [00:38<10:57,  2.31s/it]

Train Loss: 0.0324, Test Loss: 0.0325, Train Accu: 0.4977, Test Accu: 0.4499


  6%|▌         | 17/300 [00:41<11:01,  2.34s/it]

Epoch 17: Test accuracy improved to 0.5356, saving model.
Train Loss: 0.0324, Test Loss: 0.0325, Train Accu: 0.4929, Test Accu: 0.5356


  6%|▌         | 18/300 [00:43<11:01,  2.34s/it]

Train Loss: 0.0324, Test Loss: 0.0326, Train Accu: 0.4964, Test Accu: 0.5080


  6%|▋         | 19/300 [00:45<10:55,  2.33s/it]

Train Loss: 0.0323, Test Loss: 0.0314, Train Accu: 0.4963, Test Accu: 0.4862


  7%|▋         | 20/300 [00:47<10:50,  2.32s/it]

Train Loss: 0.0323, Test Loss: 0.0324, Train Accu: 0.5045, Test Accu: 0.5109


  7%|▋         | 21/300 [00:50<10:47,  2.32s/it]

Train Loss: 0.0322, Test Loss: 0.0329, Train Accu: 0.5115, Test Accu: 0.5094


  7%|▋         | 22/300 [00:52<10:46,  2.33s/it]

Train Loss: 0.0322, Test Loss: 0.0325, Train Accu: 0.5084, Test Accu: 0.4877


  8%|▊         | 23/300 [00:54<10:47,  2.34s/it]

Train Loss: 0.0320, Test Loss: 0.0322, Train Accu: 0.5014, Test Accu: 0.4833


  8%|▊         | 24/300 [00:57<10:45,  2.34s/it]

Train Loss: 0.0320, Test Loss: 0.0328, Train Accu: 0.5004, Test Accu: 0.4659


  8%|▊         | 25/300 [00:59<10:39,  2.33s/it]

Train Loss: 0.0318, Test Loss: 0.0325, Train Accu: 0.5007, Test Accu: 0.4978


  9%|▊         | 26/300 [01:01<10:40,  2.34s/it]

Train Loss: 0.0316, Test Loss: 0.0324, Train Accu: 0.5019, Test Accu: 0.4935


  9%|▉         | 27/300 [01:04<10:38,  2.34s/it]

Train Loss: 0.0317, Test Loss: 0.0328, Train Accu: 0.4944, Test Accu: 0.5239


  9%|▉         | 28/300 [01:06<10:36,  2.34s/it]

Train Loss: 0.0312, Test Loss: 0.0321, Train Accu: 0.5037, Test Accu: 0.5094


 10%|▉         | 29/300 [01:09<10:35,  2.34s/it]

Train Loss: 0.0311, Test Loss: 0.0316, Train Accu: 0.5145, Test Accu: 0.5036


 10%|█         | 30/300 [01:11<10:33,  2.34s/it]

Epoch 29: Learning rate changed to 0.005
Train Loss: 0.0303, Test Loss: 0.0318, Train Accu: 0.5048, Test Accu: 0.4848


 10%|█         | 31/300 [01:13<10:19,  2.30s/it]

Train Loss: 0.0297, Test Loss: 0.0313, Train Accu: 0.4927, Test Accu: 0.5152
