# This notebook trains the network with and without the additional Betti curve backbone

In [1]:
from data.generate_datasets import make_gravitational_waves
from pathlib import Path
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from net import Clasificator, train
import torch.nn.functional as F
import torch 
from pipeline import pipeline_complex, pipeline_embedder

# Seed every random number generator the notebook draws from so that
# Restart Kernel -> Run All reproduces the same run:
#   * NumPy's global RNG,
#   * torch's default generator, which drives parameter initialisation and
#     the shuffling order of DataLoader(shuffle=True).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Experiment configuration.
DATA = Path("./data")   # directory handed to the generator as `path_to_data`
n_signals = 1000        # number of signals requested from the generator
R = 0.65                # used as both r_min and r_max, so the range is a single value

# The generator returns the noisy signals, the gw_signals and the labels.
noisy_signals, gw_signals, labels = make_gravitational_waves(
    path_to_data=DATA,
    n_signals=n_signals,
    r_min=R,
    r_max=R,
    n_snr_values=1,
)

# Quick sanity check on the size of what was generated.
print(f"Number of noisy signals: {len(noisy_signals)}")
print(f"Number of timesteps per series: {len(noisy_signals[0])}")

Number of noisy signals: 1000
Number of timesteps per series: 8692


### Create the dataset. This may take a while because the topological features need to be computed.

In [3]:
class GravWavesDataSet(Dataset):
    """Dataset pairing every noisy signal with its Betti-curve features.

    The Betti curves are computed once, eagerly, in ``__init__`` by passing
    the signals through ``pipeline_embedder`` and then ``pipeline_complex``.

    Args:
        noise_signal: Indexable collection of signals, one entry per sample.
        labels: Indexable collection of class labels aligned with
            ``noise_signal``.
        fit: When ``True`` (the default, matching the original behaviour)
            both pipelines are fitted on ``noise_signal`` via
            ``fit_transform``. When ``False`` the pipelines are applied with
            ``transform`` only, so previously fitted pipelines are reused
            instead of being refit on this split.
            NOTE(review): assumes both pipelines expose a scikit-learn style
            ``transform`` method -- confirm against ``pipeline.py``.

    Raises:
        ValueError: If ``noise_signal`` and ``labels`` differ in length.
    """

    def __init__(self, noise_signal, labels, fit=True):
        # Fail before the expensive topological computation rather than with
        # an IndexError somewhere in the middle of training.
        if len(noise_signal) != len(labels):
            raise ValueError(
                f"noise_signal and labels must have the same length, "
                f"got {len(noise_signal)} and {len(labels)}"
            )

        self.noise_signal = noise_signal
        self.labels = labels

        if fit:
            embedded = pipeline_embedder.fit_transform(self.noise_signal)
            self.betti_curves = pipeline_complex.fit_transform(embedded)
        else:
            embedded = pipeline_embedder.transform(self.noise_signal)
            self.betti_curves = pipeline_complex.transform(embedded)

    def __len__(self):
        """Return the number of samples (one per noisy signal)."""
        return len(self.noise_signal)

    def _transform(self, data):
        """Standardise ``data`` and make sure it has a leading axis.

        The input is converted to a float32 tensor, shifted by its global
        mean and divided by its global standard deviation (plus ``1e-8`` to
        avoid division by zero). One-dimensional inputs get an extra leading
        dimension of size 1; inputs with more dimensions keep their shape.
        """
        # torch.as_tensor with an explicit dtype replaces the legacy
        # torch.Tensor(...) constructor, which treats a bare integer as a
        # size instead of a value.
        data = torch.as_tensor(data, dtype=torch.float32)

        mean = data.mean()
        # The (unbiased) std of a single value is NaN; treat it as 0 so the
        # result is a finite zero instead of NaN.
        std = data.std() if data.numel() > 1 else data.new_zeros(())
        data = (data - mean) / (std + 1e-8)

        if data.dim() == 1:
            data = data.unsqueeze(0)

        return data

    def __getitem__(self, idx):
        """Return ``((signal, betti), label)`` for sample ``idx``.

        ``signal`` and ``betti`` are standardised independently by
        ``_transform``; ``label`` is a ``torch.long`` tensor.
        """
        signal = self._transform(self.noise_signal[idx])
        betti = self._transform(self.betti_curves[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return (signal, betti), label
    
# Hold out 10% of the samples for evaluation. The split is stratified on the
# labels and uses a fixed random_state so it is the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    np.array(noisy_signals),
    np.array(labels),
    test_size=0.1,
    stratify=labels,
    random_state=42,
)

# NOTE(review): each GravWavesDataSet construction below runs fit_transform
# on its own split, so the test-set features come from pipelines that were
# refit on the test data -- confirm this is intended.
train_dataset = GravWavesDataSet(X_train, y_train)
test_dataset = GravWavesDataSet(X_test, y_test)

In [4]:
# Batches of 32 samples; only the training loader shuffles its data.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,
)

# Train the neural network

## Train with Betti curve

In [5]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"device: {device}")

# Classifier built with its default arguments, optimised by Adam with
# default hyperparameters.
model = Clasificator()
optim = torch.optim.Adam(model.parameters())

# 40 epochs; the logged "Test set" lines come from this call.
train(model, device, train_loader, test_loader, optim, epochs=40)

device: cuda

Test set: Average loss: 0.0275, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0259, Accuracy: 74/100 (74%)


Test set: Average loss: 0.0112, Accuracy: 84/100 (84%)


Test set: Average loss: 0.0043, Accuracy: 96/100 (96%)


Test set: Average loss: 0.0037, Accuracy: 99/100 (99%)


Test set: Average loss: 0.0039, Accuracy: 98/100 (98%)


Test set: Average loss: 0.0040, Accuracy: 97/100 (97%)


Test set: Average loss: 0.0055, Accuracy: 95/100 (95%)


Test set: Average loss: 0.0033, Accuracy: 98/100 (98%)


Test set: Average loss: 0.0069, Accuracy: 94/100 (94%)


Test set: Average loss: 0.0103, Accuracy: 89/100 (89%)


Test set: Average loss: 0.0049, Accuracy: 96/100 (96%)


Test set: Average loss: 0.0027, Accuracy: 99/100 (99%)


Test set: Average loss: 0.0031, Accuracy: 99/100 (99%)


Test set: Average loss: 0.0025, Accuracy: 99/100 (99%)


Test set: Average loss: 0.0028, Accuracy: 99/100 (99%)


Test set: Average loss: 0.0030, Accuracy: 97/100 (97%)


Test set: Average

## Train without Betti curve

In [6]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"device: {device}")

# Same setup as the previous experiment, but the classifier is built with
# with_betti=False; the optimiser is again Adam with default hyperparameters.
model = Clasificator(with_betti=False)
optim = torch.optim.Adam(model.parameters())

# 40 epochs on the same loaders as the previous experiment.
train(model, device, train_loader, test_loader, optim, epochs=40)

device: cuda

Test set: Average loss: 0.0276, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0277, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0276, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0276, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0277, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0277, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0277, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0276, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0276, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0276, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0276, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0277, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0277, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0276, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0277, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0276, Accuracy: 51/100 (51%)


Test set: Average loss: 0.0277, Accuracy: 51/100 (51%)


Test set: Average