In [None]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch import nn, optim
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import datetime
import os
import readligo as rl
from gwpy.timeseries import TimeSeries
import math
import random
from matplotlib.lines import Line2D

from scipy import integrate

In [None]:
# --- Hyper-parameters shared by the cells below ---------------------------
# Number of training epochs run for each model.
epochs = 400
# Mini-batch size handed to the DataLoaders.
batch_size = 32
# Hold-out fractions for the validation and test subsets.
validation_sample_ratio = 0.1
test_sample_ratio = 0.2
# Weight of an auxiliary loss term; 0 means that term contributes nothing.
coef_delta = 0

In [None]:
class SupervisedModel(nn.Module):
    """Fully connected binary classifier with layer widths 2-8-6-4-1.

    Each of the first three linear layers is followed by a ReLU; the single
    output unit is passed through a sigmoid, so ``forward`` returns values in
    the open interval (0, 1) with shape ``(batch, 1)``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable on purpose: they
        # determine the state_dict keys and the parameter initialisation order.
        self.fc1 = nn.Linear(in_features=2, out_features=8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=8, out_features=6)
        self.fc3 = nn.Linear(in_features=6, out_features=4)
        self.fc4 = nn.Linear(in_features=4, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a ``(batch, 2)`` input to a ``(batch, 1)`` sigmoid score."""
        hidden = x
        # Hidden stack: linear layer followed by the shared ReLU.
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.relu(layer(hidden))
        # Output head: linear projection to one unit, squashed by the sigmoid.
        return self.sigmoid(self.fc4(hidden))

In [None]:
# Open the two cached reconstruction-error archives, one per detector
# (presumably L = Livingston and H = Hanford -- confirm against the cache writer).
Reconstruction_error_L, Reconstruction_error_H = (
    np.load(archive_path)
    for archive_path in (
        '../Data_cached/Reconstruction_error_GWAKset_L_1.npz',
        '../Data_cached/Reconstruction_error_GWAKset_H_1.npz',
    )
)

In [None]:
def _pair_detector_patterns(key):
    """Stack the L and H arrays stored under ``key`` and transpose the result.

    With 1-D patterns of equal length this yields one row per event and one
    column per detector (column 0 = L, column 1 = H).
    """
    stacked = np.vstack((Reconstruction_error_L[key], Reconstruction_error_H[key]))
    return stacked.T


# Two-detector feature spaces for each data category.
Feature_space_noise_train = _pair_detector_patterns('noise_train_pattern')
Feature_space_noise_test = _pair_detector_patterns('noise_test_pattern')
Feature_space_bbh = _pair_detector_patterns('bbh_pattern')
Feature_space_sg = _pair_detector_patterns('sg_pattern')

In [None]:
def _geometric_mean(features):
    """Return, per row, the square root of the product of columns 0 and 1."""
    first, second = features[:, 0], features[:, 1]
    return np.sqrt(first * second)


# Per-event geometric mean of the two feature columns, for each data category.
Geometric_mean_train = _geometric_mean(Feature_space_noise_train)
Geometric_mean_test = _geometric_mean(Feature_space_noise_test)
Geometric_mean_bbh = _geometric_mean(Feature_space_bbh)
Geometric_mean_sg = _geometric_mean(Feature_space_sg)

In [None]:
# Background pool: the noise-training feature space (same array object, not a copy).
data_train = Feature_space_noise_train

# Keep only the events whose geometric mean lies strictly above the cut; the
# boolean masks select whole rows and the reshape fixes the result to two columns.
geometric_mean_cut = 0.0033
noise_above_cut = Geometric_mean_test > geometric_mean_cut
bbh_above_cut = Geometric_mean_bbh > geometric_mean_cut
data_signal_noise = Feature_space_noise_test[noise_above_cut].reshape(-1, 2)
data_signal_bbh = Feature_space_bbh[bbh_above_cut].reshape(-1, 2)

In [None]:
# Integer grid 1, 8, 15, 22, 29, 36 (the stop value 37 is exclusive).
snr_list = np.arange(start=1, stop=37, step=7)

In [None]:
# Train one classifier per entry of snr_list, save the trained model and a
# plot of its loss curves.
#
# Use the GPU when one is present and fall back to the CPU otherwise, so the
# cell no longer fails on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for i in range(len(snr_list)):
    snr = snr_list[i]

    # Label-1 mixture: every selected noise-test event plus snr * len(noise)
    # BBH events, drawn with replacement from the selected BBH pool.
    n_bbh = int(snr * len(data_signal_noise))
    bbh_idx = np.random.choice(len(data_signal_bbh), size=n_bbh)
    data_signal = np.vstack((data_signal_noise, data_signal_bbh[bbh_idx]))
    # Label-0 background: an equally sized subset of the noise-training set,
    # drawn without replacement (np.random.choice raises ValueError when
    # data_train has fewer rows than data_signal).
    bkg_idx = np.random.choice(len(data_train), size=len(data_signal), replace=False)
    data_bkg = data_train[bkg_idx]

    print(data_signal.shape)
    print(data_bkg.shape)

    # Append the class label as a third column.
    data_signal_labeled = np.hstack((data_signal, np.ones((len(data_signal), 1))))
    data_bkg_labeled = np.hstack((data_bkg, np.zeros((len(data_bkg), 1))))

    Final_training_and_validation_set = np.vstack((data_signal_labeled, data_bkg_labeled))
    np.random.shuffle(Final_training_and_validation_set)

    # Split according to the configured validation fraction instead of a
    # hard-coded 0.9 (identical result for validation_sample_ratio = 0.1).
    split_idx = int((1 - validation_sample_ratio) * len(Final_training_and_validation_set))
    trainData = torch.tensor(Final_training_and_validation_set[:split_idx], dtype=torch.float32)
    validationData = torch.tensor(Final_training_and_validation_set[split_idx:], dtype=torch.float32)
    # Each row already carries features and label, so the second tensor of the
    # dataset is a duplicate that the loops below ignore.
    train_dataset = TensorDataset(trainData, trainData)
    validation_dataset = TensorDataset(validationData, validationData)

    trainDataLoader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    # No shuffling for validation: the batch composition then stays fixed, so
    # the mean of per-batch losses is deterministic for a given model state.
    validationDataLoader = DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=False)

    # NOTE: despite the variable name this is the supervised classifier, not
    # an autoencoder; the name is kept for compatibility with other cells.
    autoencoder = SupervisedModel().to(device)
    optimizer = optim.Adam(autoencoder.parameters(), lr=0.00005)
    loss_func = nn.BCELoss().to(device)
    loss_train = np.zeros((epochs, 1))
    loss_validation = np.zeros((epochs, 1))

    for epoch in range(epochs):
        # Training phase
        autoencoder.train()
        train_loss = 0.0
        for x, _ in trainDataLoader:
            x = x.to(device)
            # Columns 0-1 are the features, the last column is the label.
            decoded = autoencoder(x[:, :2])
            weighted_lossTrain = loss_func(decoded.flatten(), x[:, -1])

            optimizer.zero_grad()
            weighted_lossTrain.backward()
            optimizer.step()

            train_loss += weighted_lossTrain.item()
        # Record the epoch mean of the per-batch losses (not just the last
        # batch) so the curve is comparable with the validation curve.
        train_loss /= max(len(trainDataLoader), 1)

        # Validation phase
        autoencoder.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, _ in validationDataLoader:
                x = x.to(device)
                decoded = autoencoder(x[:, :2])
                # Same objective as in training: plain BCE on the score.
                val_loss += loss_func(decoded.flatten(), x[:, -1]).item()
        val_loss /= max(len(validationDataLoader), 1)

        loss_train[epoch, 0] = train_loss
        loss_validation[epoch, 0] = val_loss
        print('Epoch: %04d, Training loss=%.8f, Validation loss=%.8f' % (epoch+1, train_loss, val_loss))

    # Persist the whole module object (not only its state_dict), moved to the
    # CPU first so it can be loaded without a GPU.
    torch.save(autoencoder.cpu(),'../Model_cached/model_weakly_supervised_2-8-6-4-1_GWAK_snr'+str(snr_list[i])+'.pt')

    # Loss curves for this model. The figure is left open so an inline
    # notebook backend still renders it at the end of the cell.
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.grid()
    ax.plot(loss_train, color=[245/255, 124/255, 0/255], linestyle='-', linewidth=2, label = 'Train set')
    ax.plot(loss_validation, color=[245/255, 124/255, 0/255], linestyle='--', linewidth=2, label = 'Validation set')
    ax.set_xlabel('Epoches')
    ax.set_ylabel('Loss')
    ax.set_title('Trainging loss, weakly supervised learning model, 0.1 anomaly cut, snr' +str(snr_list[i]))
    ax.set_ylim(0, 1)
    ax.legend()
    fig.savefig('../Pic_cached/Trainging loss, weakly supervised learning model, 0.1 anomaly cut, snr'+str(snr_list[i])+',2-8-6-4-1.png')