In [7]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import pandas as pd
from datetime import datetime
from data_handling_L1 import get_data, sliding_data

In [8]:
#Define paths
PATH_TO_L1 = '//NAS24/solo/remote/data/L1'
PATH_TO_MODEL = 'C:/Githubs/kandidat/Low_freq_files/Neural Network/model_low_freq.pt'
#Use GPU if possible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
def pre_process(E, V, EPOCH, start_ind, window_size=512):
    '''Preprocesses data for model.
    1. Slices data from start_ind and window_size
    2. Shapes into correct format
    3. Removes bias
    4. Normalizes each input channel with respect to max
    Returns pytorch tensor'''
    ind = start_ind
    #Slice the data for prediction
    time_processed = (np.array(EPOCH[ind:ind+window_size]) - EPOCH[ind]) / 10**9 #convert ns to s
    E1_window = np.array(E[ind:ind+window_size, 0])
    E2_window = np.array(E[ind:ind+window_size, 1])
    V_window = np.array(V[ind:ind+window_size])
    
    #Reshape the data
    data_shaped = np.array([V_window, E1_window, E2_window]).reshape(3, 512)
    
    #Remove bias
    median = np.median(data_shaped, axis=1, keepdims=True)
    data_nobias = data_shaped - median
    
    #Normalize data for each channel (3)
    max_vals = np.max(np.abs(data_nobias), axis=1, keepdims=True)
    data_normalized = data_nobias / max_vals
    
    maxes = np.max(data_nobias, axis=1, keepdims=True)
    minis = np.min(data_nobias, axis=1, keepdims=True)
    '''
    for mn, mx in minis, maxes:
        if mx-mn > 10000:
            '''
    return time_processed, data_normalized#, problem

In [10]:
#Define the architecture for the neural net

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()

        self.conv1 = nn.Conv1d(in_channels=3, out_channels=128, kernel_size=8, stride=1)
        self.bn1 = nn.BatchNorm1d(128)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=5, stride=1)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm1d(128)
        self.relu3 = nn.ReLU()

        self.avgpool = nn.AdaptiveAvgPool1d(1)
        #self.maxpool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Linear(128, 2)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = torch.tensor(x, dtype=self.conv1.weight.dtype).to(device)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)

        x = self.avgpool(x)
        #x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.softmax(x)

        return x

In [11]:
#Create neural network
model = ConvNet()

#Load trained variables
model.load_state_dict(torch.load(PATH_TO_MODEL, map_location=device))
model.to(device)

#Set evaluation mode
model.eval()

ConvNet(
  (conv1): Conv1d(3, 128, kernel_size=(8,), stride=(1,))
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (conv2): Conv1d(128, 256, kernel_size=(5,), stride=(1,))
  (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (conv3): Conv1d(256, 128, kernel_size=(3,), stride=(1,))
  (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (avgpool): AdaptiveAvgPool1d(output_size=1)
  (fc): Linear(in_features=128, out_features=2, bias=True)
  (softmax): Softmax(dim=1)
)

In [12]:
window_size = 512
overlap = 0.2
batch_size = 100
start_date_str = '20220101' #20220302
end_date_str = '20220102'

start_date = datetime.strptime(start_date_str, '%Y%m%d')
end_date = datetime.strptime(end_date_str, '%Y%m%d')

plot = 'day'


%matplotlib inline
%matplotlib qt
data_dic = {}
for root, dirs, files in os.walk(PATH_TO_L1):    #iterate over L1 data
    for file in files:
        if 'rpw-lfr-surv-cwf-cdag' in file:
                date_str = file.split('_')[3]
                date = datetime.strptime(date_str, '%Y%m%d')
                if start_date <= date < end_date:
                    CURRENT_PATH = f'{PATH_TO_L1}/{file[-16:-12]}/{file[-12:-10]}/{file[-10:-8]}/{file}'
                    #Load file
                    E, V, EPOCH  = get_data(CURRENT_PATH)
                    #Slice day into windows
                    start_indices = sliding_data(E, overlap, window_size)
                    good_ind = []
                    ind_dust = []
                    predictions = np.array([])
                    good_pred = []
                    recent_pos = False
                    for i in range(0, len(start_indices), batch_size):
                        batch_indices = start_indices[i:i+batch_size]
                        batch_time = []
                        batch_data = []
                        for ind in batch_indices:
                            #Preprocess data for prediction
                            time, data = pre_process(E, V, EPOCH, ind)
                            if time[-1] - time[0] < 32: 
                                batch_time.append(time)
                                batch_data.append(data)
                                good_ind.append(ind)
                            else:
                                print('Gap in data detected')
                        model_data = torch.from_numpy(np.stack(batch_data, 0)).to(device)       
                        batch_pred = model(model_data).cpu().detach().numpy()[:,1]
                        predictions = np.append(predictions, batch_pred)
                        print(len(predictions))
                    print(len(predictions))
                    print(len(good_ind))    
                    for i in range(len(predictions)):
                        if predictions[i] > 0.5:
                            ind_dust.append(good_ind[i]) 
                            
                    '''if plot == 'window' and predictions[-1] > 0.5:
                        fig, axs = plt.subplots(3, 1, sharex=True, sharey=True)
                        titles = ['V', 'E1', 'E2']
                        for i in range(3):
                            axs[i].plot(time, data[0, i, :])
                            axs[i].set_title(titles[i])
                        fig.suptitle(f'Prediction = {predictions[-1]:.2f}')
                        fig.supxlabel('Time [s]')
                        plt.show()'''

                    print(date)
                    data_dic[date_str] = EPOCH[ind_dust]
                          
                    if plot == 'day':
                        fig, axs = plt.subplots(3, 1, sharex=True)
                        titles = ['V', 'E1', 'E2']
                        Ys = [V, E[:,0], E[:,1]]
                        for i in range(3):
                            axs[i].plot(EPOCH, Ys[i])
                            axs[i].set_title(titles[i])
                            for ind in ind_dust:
                                axs[i].axvspan(EPOCH[ind], EPOCH[ind+window_size], alpha=0.5, color='green')
                        fig.supxlabel('Time [s]')
                        fig.suptitle(date)
                        plt.show()
                    
new_df = pd.DataFrame.from_dict(data_dic, orient='index')
new_df = new_df.transpose()
new_df.to_pickle('data.pkl')

  x = torch.tensor(x, dtype=self.conv1.weight.dtype).to(device)


100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
Gap in data detected
3399
3499
3599
3699
3799
3899
3999
4099
4199
4299
4399
4499
4599
4699
4799
Gap in data detected
Gap in data detected
4897
4997
Gap in data detected
5096
5196
5296
5396
5496
5596
5696
5796
5896
5996
6096
6196
6296
6396
6496
6596
6696
6796
6896
6996
7096
7196
7296
7396
7496
7596
7696
7796
7896
7996
8096
8196
8296
8396
8496
8596
8696
8796
8896
8996
9096
9196
9296
9396
9496
9596
9696
9796
9896
9996
10096
10196
10296
10396
10496
10596
10696
10796
10896
10996
11096
11196
11296
11396
11496
11596
11696
11796
11896
11996
12096
12196
12296
12396
12496
12596
12696
12796
12896
12996
13096
13196
13296
13396
13496
13596
13696
13796
13896
13996
14096
14196
14296
14396
14496
14596
14696
14796
14896
14996
15096
15196
15296
15396
15496
15596
15696
15796
15896
15996
16096
16196
16296
16396
16496
16596
16696
16796
16896
16996
17096