In [None]:
import os
import numpy as np
import pickle as pkl
from nn_model import vanilla_nn
from pprint import pprint
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [None]:
# Directories of .npy files for the two shot sets. Each directory also holds
# 'header_arr.npy' (signal names) and 'time_arr.npy' (time axis), which the
# cells below load separately from the per-shot arrays.
dis_npy_dir = './disrupt_630_npy'
nondis_npy_dir = './nondisrupt_1136_npy/'

In [None]:
# Load the signal-name header and the time axis stored with each shot set.
dis_header = np.load(os.path.join(dis_npy_dir, 'header_arr.npy'))
nondis_header = np.load(os.path.join(nondis_npy_dir, 'header_arr.npy'))
dis_time = np.load(os.path.join(dis_npy_dir, 'time_arr.npy'))
nondis_time = np.load(os.path.join(nondis_npy_dir, 'time_arr.npy'))

# Sanity check: both shot sets should share the same signal layout and time axis.
# np.array_equal reduces each comparison to a single bool and returns False
# (instead of an element-wise array or a broadcasting error) when shapes differ.
print(np.array_equal(dis_header.astype(str), nondis_header.astype(str)))
print(np.array_equal(dis_time, nondis_time))

### Decide final time index

In [None]:
# Cut-off on the time axis (same units as dis_time). final_time_idx is the index
# of the sample immediately before the first sample whose time exceeds the
# cut-off; it is later used as the exclusive end of the time slice.
final_time = -150
late_idx = np.flatnonzero(dis_time > final_time)
if late_idx.size == 0:
    # Previously this left final_time_idx undefined (or stale from an earlier run).
    raise ValueError(
        'No sample in dis_time is later than final_time={}'.format(final_time))
if late_idx[0] == 0:
    # Previously this produced -1, which silently slices "all but the last sample".
    raise ValueError(
        'final_time={} is earlier than the first sample in dis_time'.format(final_time))
final_time_idx = int(late_idx[0]) - 1
print(final_time_idx)

### Decide signals to use

In [None]:
signal_list = dis_header.astype(str).tolist()
pprint(signal_list)

# Order matters: downstream cells treat every signal except the last as state
# and the last one ('pinj') as the action; index 1 ('efsbetan') is the target.
use_signal = ['ip', 'efsbetan','efsli', 'efsvolume', 'pinj']
missing_signals = [x for x in use_signal if x not in signal_list]
if missing_signals:
    raise ValueError('Signals not present in header: {}'.format(missing_signals))
use_signal_idx = [signal_list.index(x) for x in use_signal]

# Derive the state size from the signal list so the two cannot drift apart.
num_action = 1
num_state = len(use_signal) - num_action
print(len(signal_list))
print(use_signal_idx)

In [None]:
# Collect, for every shot file, the selected signals truncated at final_time_idx.
dis_arrs = []
# sorted(): os.listdir returns entries in arbitrary order, which would make the
# shot ordering (and therefore any index-based split) differ between runs.
for item in sorted(os.listdir(dis_npy_dir)):
    # The header and time-axis files live in the same directory but are not shots.
    if 'header' in item or 'time' in item:
        continue
    arr = np.load(os.path.join(dis_npy_dir, item))
    try:
        filtered_arr = arr[use_signal_idx, :final_time_idx]
        has_nans = np.isnan(filtered_arr).any()
    except (IndexError, TypeError) as err:
        # Wrong rank / too few signals (IndexError) or a non-numeric dtype
        # (TypeError): skip the shot but say which file and why, instead of
        # swallowing every exception.
        print('{} skipped ({}): shape {}'.format(item, err, arr.shape))
        continue
    if has_nans:
        print('{} has nans'.format(item))
        continue
    dis_arrs.append(filtered_arr)

In [None]:
# Number of shots that survived loading and NaN filtering.
num_dis_shots = len(dis_arrs)

In [None]:
# Combine the per-shot (signal, time) arrays into one array; later cells index
# it as [shot, signal, time], which requires every shot to share one shape.
full_dis_arr = np.asarray(dis_arrs)
print(full_dis_arr.shape)

### Make the dataset by scanning through time

In [None]:
# Prediction horizon, in the same units as dis_time: the dataset cell below
# looks for the first sample later than start_time + pred_interval.
pred_interval = 250
# Number of time samples kept per shot (last axis of full_dis_arr).
total_timesteps = full_dis_arr.shape[-1]
print(total_timesteps)

In [None]:
# Scratch array (2 rows x 5 columns, values 0..9) used to try out slicing.
a = np.arange(10).reshape((2, 5))

In [None]:
# Scratch check: first three entries of row 0.
a[0, :3]

In [None]:
# Reminder of the signal order used along axis 1 of full_dis_arr (same literal
# as use_signal); this expression has no effect beyond being displayed.
['ip', 'efsbetan','efsli', 'efsvolume', 'pinj']

In [None]:
# Build (state, action) -> reward samples by sliding a start index along the
# time axis. For each start index:
#   state  = all signals except the last, at the start index
#   action = mean of the last signal over [start index, end index)
#   reward = signal 1 ('efsbetan' in use_signal) at the end index
# where the end index is the first sample later than start_time + pred_interval.
states = []
actions = []
rewards = []
sa_pairs = []

for start_time_idx in range(total_timesteps):
    start_time = dis_time[start_time_idx]

    # First index whose time exceeds the prediction horizon (vectorised lookup
    # instead of a Python scan over the whole time axis on every iteration).
    later_idx = np.flatnonzero(dis_time > start_time + pred_interval)

    # Stop when the horizon falls outside the time axis altogether (previously
    # a TypeError from comparing None with an int) or outside the retained data.
    if later_idx.size == 0 or later_idx[0] >= total_timesteps:
        break
    end_time_idx = int(later_idx[0])

    curr_state = full_dis_arr[:,:-1,start_time_idx]
    curr_action = np.mean(full_dis_arr[:,-1,start_time_idx:end_time_idx], axis=1).reshape(-1,1)
    curr_reward = full_dis_arr[:,1,end_time_idx]
    curr_sa = np.concatenate([curr_state, curr_action], axis=1)

    states.append(curr_state)
    actions.append(curr_action)
    rewards.append(curr_reward)
    sa_pairs.append(curr_sa)

In [None]:
# Shape of the state block for one start index (one row per shot).
states[0].shape

In [None]:
# Split by shot (not by sample) so that all samples of one shot end up on the
# same side of the train/test boundary.
# NOTE(review): no random seed is set, so the split differs between runs.
train_prop = 0.8
num_train_shots = int(train_prop * num_dis_shots)
train_shot_idx = np.random.choice(num_dis_shots, size=num_train_shots, replace=False)
train_shot_set = set(train_shot_idx.tolist())
test_shot_idx = [shot for shot in range(num_dis_shots) if shot not in train_shot_set]

In [None]:
# Shape of the stacked rewards: one row per start index, one column per shot.
np.array(rewards).shape

In [None]:
# Stack the per-start-index lists once, select shots along axis 1, then
# flatten (start index, shot) into a single sample axis.
feature_dim = num_state + num_action
sa_by_time = np.array(sa_pairs)
reward_by_time = np.array(rewards)

train_X = sa_by_time[:, train_shot_idx, :].reshape(-1, feature_dim)
train_y = reward_by_time[:, train_shot_idx].reshape(-1, 1)
test_X = sa_by_time[:, test_shot_idx, :].reshape(-1, feature_dim)
test_y = reward_by_time[:, test_shot_idx].reshape(-1, 1)

print('Train: {}, {}'.format(train_X.shape, train_y.shape))
print('Test: {}, {}'.format(test_X.shape, test_y.shape))

### Define model and train

#### Normalize features

In [None]:
# Per-column statistics, computed on the training split only so that the test
# split can be scaled with the same constants.
train_X_mean = train_X.mean(axis=0)
train_X_std = train_X.std(axis=0)
train_y_mean = train_y.mean(axis=0)
train_y_std = train_y.std(axis=0)

In [None]:
# Standardise the training split: subtract the column mean, divide by the column std.
centered_train_X = train_X - train_X_mean
centered_train_y = train_y - train_y_mean
normalized_train_X = centered_train_X / train_X_std
normalized_train_y = centered_train_y / train_y_std
print(normalized_train_X.shape, normalized_train_y.shape)

In [None]:
# Standardise the test split with the *training* statistics.
centered_test_X = test_X - train_X_mean
centered_test_y = test_y - train_y_mean
normalized_test_X = centered_test_X / train_X_std
normalized_test_y = centered_test_y / train_y_std
print(normalized_test_X.shape, normalized_test_y.shape)


In [None]:
# Sample counts of the two splits (length of the leading axis).
num_train = len(normalized_train_X)
num_test = len(normalized_test_X)

In [None]:
# Draw a fixed-size random training subset. np.random.choice samples with
# replacement by default, so rows may repeat.
# NOTE(review): the subset is taken from the raw train_X / train_y, not from
# normalized_train_X / normalized_train_y computed above -- confirm this is
# intended (the test evaluation further down also uses the raw arrays).
rand_train_idx = np.random.choice(num_train, size=100000)
subset_train_X, subset_train_y = train_X[rand_train_idx], train_y[rand_train_idx]

#### Define hyperparams

In [None]:
# Network shape (passed to vanilla_nn below).
hidden_size = 2048
num_layers = 5
# Adam learning rate.
lr = 1e-3
# DataLoader batch size.
batch_size = 4096
# Upper bound on training epochs.
# NOTE(review): at 1e6 this is effectively "run until interrupted" -- confirm.
max_epochs = int(1e6)

In [None]:
# Preferred GPU index; used only when CUDA is available.
gpu_idx = 3
use_cuda = torch.cuda.is_available()
if use_cuda and gpu_idx >= torch.cuda.device_count():
    # The requested index does not exist on this host; fall back to the first
    # GPU rather than failing later when the model is moved to the device.
    print('GPU {} not available ({} visible), falling back to GPU 0'.format(
        gpu_idx, torch.cuda.device_count()))
    gpu_idx = 0
cuda_device = "cuda:{}".format(gpu_idx)
device = torch.device(cuda_device if use_cuda else "cpu")
print('Using device ', device)

#### Define model

In [None]:
# Regression network: (state, action) features in, one scalar target out.
# The input width is derived from num_state / num_action so it stays in sync
# with the feature arrays, which are reshaped to that same width.
betan_net = vanilla_nn(input_size=num_state + num_action, output_size=1,
                       hidden_size=hidden_size, num_layers=num_layers,
                       use_bn=True).to(device)

In [None]:
# Adam over all network parameters, with the learning rate set above.
opt = optim.Adam(betan_net.parameters(), lr=lr)

In [None]:
# Loss function supplied by the model itself; called as criterion(pred, target).
criterion = betan_net.loss

In [None]:
class FusionDataset(Dataset):
    """Map-style dataset pairing feature rows with a column vector of targets.

    Args:
        X_arr: Indexable collection of feature rows (must support ``len`` and
            integer indexing, e.g. a 2-D NumPy array).
        y_arr: Array of targets; it is reshaped to ``(-1, 1)`` so every sample
            carries a length-1 target.

    Raises:
        ValueError: If ``X_arr`` and the reshaped ``y_arr`` hold a different
            number of samples.
    """

    def __init__(self, X_arr, y_arr):
        self.X = X_arr
        self.y = y_arr.reshape(-1,1)
        # Fail early: a mismatch would otherwise surface as an IndexError (or as
        # silently ignored feature rows) in the middle of training.
        if len(self.X) != self.y.shape[0]:
            raise ValueError(
                'X and y must have the same number of samples, got {} and {}'.format(
                    len(self.X), self.y.shape[0]))

    def __len__(self):
        """Return the number of samples (rows of the target column)."""
        return self.y.shape[0]

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair stored at position ``idx``."""
        X = self.X[idx]
        y = self.y[idx]

        return X,y

In [None]:
# Wrap the training subset and serve it in reshuffled mini-batches.
train_dataset = FusionDataset(X_arr=subset_train_X, y_arr=subset_train_y)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Visual check of the training features fed to the loader (raw, un-normalized values).
subset_train_X

#### Train

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator before training.
torch.cuda.empty_cache()

In [None]:
# Fixed random subset of test rows, reused for every periodic test-loss
# evaluation in the training loop (sampled with replacement, the
# np.random.choice default).
rand_test_idx = np.random.choice(num_test, size=3000)

In [None]:
# Train the network, logging every 50 epochs:
#   * the loss of the last mini-batch of that epoch (not an epoch average), and
#   * the loss on the fixed random test subset selected by rand_test_idx.
running_train_loss = []
running_test_loss = []

# Ensure training mode even if an earlier run of the notebook left the model
# in eval mode.
betan_net.train()

for epoch in tqdm(range(max_epochs)):
    for batch_X, batch_y in train_loader:
        batch_X, batch_y = batch_X.float().to(device), batch_y.float().to(device)

        opt.zero_grad()

        batch_pred = betan_net(batch_X)
        loss = criterion(batch_pred, batch_y)
        loss.backward()
        opt.step()

    if epoch % 50 == 0:
        print('Epoch {0} finished: loss {1:.4f}'.format(epoch, loss.item()))
        running_train_loss.append(loss.item())

        # Evaluate in eval mode: the model is built with use_bn=True, and in
        # train mode batch-norm layers would normalise with the test batch's
        # own statistics and fold the test data into their running averages.
        betan_net.eval()
        with torch.no_grad():
            rand_test_X = (torch.from_numpy(test_X[rand_test_idx]).float()).to(device)
            rand_test_y = (torch.from_numpy(test_y[rand_test_idx]).float()).to(device)
            test_loss = criterion(betan_net(rand_test_X), rand_test_y)
            running_test_loss.append(test_loss.item())
            print('Epoch {0}: rand test loss {1:.4f}'.format(epoch, test_loss.item()))
        betan_net.train()
            

In [None]:
# Loss curves; one point was recorded every `log_every` epochs by the training loop.
log_every = 50
train_epochs = np.arange(len(running_train_loss)) * log_every
test_epochs = np.arange(len(running_test_loss)) * log_every

plt.plot(train_epochs, running_train_loss, label='train loss')
plt.plot(test_epochs, running_test_loss, label='test loss')
plt.xlabel('Training epoch')
plt.ylabel('Loss')
plt.legend()

fig = plt.gcf()
fig.set_size_inches(10, 8)

In [None]:
# Switch the network to evaluation mode for the inspection cells below.
betan_net.eval()

In [None]:
# Draw a fresh random test subset (with replacement) and move it to the device.
rand_test_idx = np.random.choice(num_test, size=5000)
rand_test_X, rand_test_y = (
    torch.from_numpy(split[rand_test_idx]).float().to(device)
    for split in (test_X, test_y)
)
# Displayed as the cell output: raw network predictions for the subset.
betan_net(rand_test_X)

In [None]:
# Shape check: (number of sampled rows, number of input features).
rand_test_X.shape

In [None]:
# Ordering of the sampled rows by feature column 1 ('efsbetan' in use_signal),
# so the plot below can be drawn along an increasing x axis.
order = torch.argsort(rand_test_X[:,1])

In [None]:
# Feature column 1 in ascending order, copied to the CPU for display.
rand_test_X[:,1][order].cpu()

In [None]:
# Compare predictions with labels on the sampled test rows, ordered by feature
# column 1 ('efsbetan' in use_signal). The forward pass is run once, with
# autograd disabled, and reused for both the printed loss and the plot.
with torch.no_grad():
    rand_test_pred = betan_net(rand_test_X)
    print(criterion(rand_test_pred, rand_test_y).item())

    sorted_betan = rand_test_X[:,1][order].cpu()
    plt.plot(sorted_betan, rand_test_y[order].cpu(), label='label')
    plt.plot(sorted_betan, rand_test_pred[order].cpu(), alpha = 0.5, label='prediction')
    plt.xlabel('Starting betan')
    plt.ylabel('Predicted betan')
    plt.legend()
fig = plt.gcf()
fig.set_size_inches(15,10)