In [1]:
import os 
import os.path as path

import matplotlib.pyplot as plt
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


from src.data.load_dataset import load_frequency_detection
from src.models.networks import sensilla_RFNet, classical_RFNet
from src.models.utils import train, test

In [2]:
# Select the compute device once: the GPU if PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# load data
# Dataset configuration. `seed` is forwarded to load_frequency_detection so the
# value declared here is the one actually used. It is 2 because that is the
# value the call previously hard-coded, which keeps the generated dataset
# unchanged. (An unused `seed = 5` was declared before.)
num_samples, sampling_rate, duration, freq, snr, seed = 7000, 1500, 0.1, 5, 0.8, 2
train_batch_size, train_percentage = 2048, 0.8
train_loader, val_loader, test_loader = load_frequency_detection(
    num_samples, sampling_rate, freq, duration, snr,
    train_batch_size, train_percentage, seed=seed,
)

#### Mechanosensory RFNet with $\omega_a=20$ Hz, $\omega_b = 80$ Hz, and $\gamma=6$ (set in code as `omega_a, omega_b = 2, 8`, i.e. the Hz values times `duration` = 0.1 s)

In [None]:
# Sweep the hidden-layer width over a log-spaced grid of 1..1000 neurons.
# Truncating to int creates duplicates, which the set removes. For each width,
# train `num_trials` freshly initialised sensilla RFNets and record the mean
# and (population) std of their test accuracy.
num_neurons = [int(n) for n in sorted(set(np.logspace(0, 3, 50).astype('int')))]
inp_size = int(sampling_rate * duration)  # number of samples in a `duration`-long signal
omega_a, omega_b, gamma = 2, 8, 6  # 2 and 8 times 1 / duration (0.1 s) give the 20 / 80 Hz in the header
lr = 0.01
num_epochs = 30
num_trials = 50
log_interval = 100  # forwarded to train(); training runs with verbose=False

test_sensilla = {'hidden_size': [], 'mean': [], 'std': []}
for hidden_size in tqdm(num_neurons):
    accuracy = []
    for trial in range(num_trials):
        # Seed every trial (torch and numpy, in case the model draws weights
        # from either) so the sweep is reproducible under Restart & Run All.
        torch.manual_seed(trial)
        np.random.seed(trial)
        model = sensilla_RFNet(inp_size, hidden_size,
                               omega_a, omega_b, gamma).to(device)
        optimizer = optim.SGD(model.parameters(), lr=lr)

        for epoch in range(num_epochs):
            _ = train(log_interval, device, model, train_loader, optimizer, epoch, verbose=False)
        accuracy.append(test(model, device, test_loader, verbose=False))

    test_sensilla['hidden_size'].append(hidden_size)
    test_sensilla['mean'].append(np.mean(accuracy))
    test_sensilla['std'].append(np.std(accuracy))

  2%|▏         | 1/42 [00:57<39:18, 57.54s/it]

#### Classical RFNet

In [None]:
# Sweep the hidden-layer width over a log-spaced grid of 1..1000 neurons.
# Truncating to int creates duplicates, which the set removes. For each width,
# train `num_trials` freshly initialised classical RFNets and record the mean
# and (population) std of their test accuracy.
num_neurons = [int(n) for n in sorted(set(np.logspace(0, 3, 50).astype('int')))]
inp_size = int(sampling_rate * duration)  # number of samples in a `duration`-long signal
lr = 0.01
num_epochs = 30
num_trials = 50
log_interval = 100  # forwarded to train(); training runs with verbose=False

test_classical = {'hidden_size': [], 'mean': [], 'std': []}
for hidden_size in tqdm(num_neurons):
    accuracy = []
    for trial in range(num_trials):
        # Seed every trial (torch and numpy, in case the model draws weights
        # from either) so the sweep is reproducible under Restart & Run All.
        torch.manual_seed(trial)
        np.random.seed(trial)
        model = classical_RFNet(inp_size, hidden_size).to(device)
        optimizer = optim.SGD(model.parameters(), lr=lr)

        for epoch in range(num_epochs):
            _ = train(log_interval, device, model, train_loader, optimizer, epoch, verbose=False)
        accuracy.append(test(model, device, test_loader, verbose=False))

    test_classical['hidden_size'].append(hidden_size)
    test_classical['mean'].append(np.mean(accuracy))
    test_classical['std'].append(np.std(accuracy))

#### Mechanosensory RFNet with incompatible parameters: $\omega_a=10$ Hz, $\omega_b = 40$ Hz, and $\gamma=6$ (set in code as `omega_a, omega_b = 1, 4`, i.e. the Hz values times `duration` = 0.1 s)

In [None]:
# incompatible sensilla RFNet 
num_neurons = sorted(set(np.logspace(0, 3, 50).astype('int')))
inp_size = int(sampling_rate * duration)
omega_a, omega_b, gamma = 1, 4, 6
lr = 0.01
num_epochs = 30
log_interval = 100

test_incompatible = {'hidden_size': [], 'mean': [], 'std': []}
for hidden_size in tqdm(num_neurons):
    accuracy = []
    for trial in range(50):
        model = sensilla_RFNet(inp_size, hidden_size, 
                                 omega_a, omega_b, gamma).to(device)
        optimizer = optim.SGD(model.parameters(), lr=lr)

        for epoch in range(num_epochs):
            _ = train(log_interval, device, model, train_loader, optimizer, epoch, verbose=False)
        accuracy.append(test(model, device, test_loader, verbose=False))
        
    test_incompatible['hidden_size'].append(hidden_size)
    test_incompatible['mean'].append(np.mean(accuracy))
    test_incompatible['std'].append(np.std(accuracy))

In [None]:
# save
# Bundle the three width sweeps and pickle them. The dict is named `results`
# rather than `test` so it does not shadow the imported `test()` function;
# shadowing it broke any re-run of the training cells after this one.
results = {'sensilla': test_sensilla, 'classical': test_classical, 'incompatible': test_incompatible}
# Two directories above the working directory (presumably the project root
# when the notebook runs from its own folder -- TODO confirm).
data_dir = path.abspath(path.join(os.getcwd(), '../../'))
results_dir = path.join(data_dir, 'models', 'results', 'freq_detection')
os.makedirs(results_dir, exist_ok=True)  # open() fails if the directory is missing
with open(path.join(results_dir, 'freq_detection_sensilla.pickle'), 'wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)