In [None]:
import torch
import torch.nn as nn
from est_lib.util.obspy_util import *
from est_lib.dataset.seismic_dataset import CNDataset
from obspy import UTCDateTime as dt
from torch.utils.data import DataLoader

In [None]:
from torch_geometric_temporal.nn.attention import MTGNN

In [None]:
import os
from tqdm.notebook import tqdm

In [None]:
# Env Stuff
# Cache file names for the StationXML inventory and the pickled waveform stream.
inv_file_name = 'inv.xml'
str_file_name = 'str.pkl'
# Absolute path of the local cache directory (relative to the notebook's cwd).
file_path = os.path.abspath(os.path.join('..', 'est_lib', '_data'))
# Station codes requested from the CN network.
sta_list = ['QEPB', 'HOPB']

def check_files(event_time='2021-05-26T08:46:00', seconds_before=600,
                seconds_after=1500, channel_list=None):
    """Make sure the inventory and stream cache files exist under ``file_path``.

    Any missing file is fetched from the CN network for the stations in
    ``sta_list`` and written into ``file_path``. The directory is created if it
    does not exist. Files that are already present are left untouched. Finally,
    the directory path and its contents are printed.

    Args:
        event_time: Reference time (anything ``UTCDateTime`` accepts) around
            which the stream window is requested.
        seconds_before: Seconds of data requested before ``event_time``.
        seconds_after: Seconds of data requested after ``event_time``.
        channel_list: Channel codes to request; defaults to ``['HHE']``.
    """
    if channel_list is None:
        channel_list = ['HHE']

    # On a fresh checkout the cache directory may not exist yet; without this
    # os.listdir would raise FileNotFoundError.
    os.makedirs(file_path, exist_ok=True)
    contents = os.listdir(file_path)

    if inv_file_name not in contents:
        inv = inventory_retriever(network="CN",
                                  sta_list=sta_list,
                                  level='response')
        inventory_writer(inv,
                         os.path.join(file_path, inv_file_name),
                         file_format="STATIONXML")

    if str_file_name not in contents:
        stream = stream_retriever(event_time=dt(event_time),
                                  seconds_before=seconds_before,
                                  seconds_after=seconds_after,
                                  network="CN",
                                  sta_list=sta_list,
                                  channel_list=channel_list)
        # NOTE(review): stored as a pickle; only unpickle files you created yourself.
        stream_writer(stream,
                      os.path.join(file_path, str_file_name),
                      file_format="PICKLE")

    print(file_path)
    print(os.listdir(file_path))

In [None]:
# Retrieve Stream and Inventory
# Fetches and caches the inventory/stream files if they are missing, then prints the cache directory listing.
check_files()

In [None]:
# Set Up Dataset Object
# Build the dataset from the cached inventory and stream files.
# num_nodes is derived from sta_list (assumed one graph node per station) so it
# cannot drift when stations are added or removed.
# TODO: seq_length=101 is flagged as a logical length bug; the model cell
# consumes obj.seq_length - 1 input steps to compensate.
obj = CNDataset(os.path.join(file_path, inv_file_name),
                os.path.join(file_path, str_file_name),
                sta_list=sta_list,
                ip_dim=1,
                num_nodes=len(sta_list),
                seq_length=101)

In [None]:
# Set Up Dataloader
# Mini-batches are drawn in dataset order (no shuffling).
proto_batch_size = 16
train_loader = DataLoader(obj, shuffle=False, batch_size=proto_batch_size)

In [None]:
# Peek at the first (features, labels) batch to sanity-check tensor shapes.
(feat, lab) = next(iter(train_loader))

In [None]:
# Shapes of the first feature batch and label batch, one per line.
print(feat.shape, lab.shape, sep="\n")

In [None]:
# Set Up CUDA
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move the dataset's stored data and labels onto the chosen device.
obj.data = obj.data.to(device)
obj.labels = obj.labels.to(device)

In [None]:
# Confirm where the dataset tensors now live, one device per line.
print(obj.data.device, obj.labels.device, sep="\n")

In [None]:
# Set Up Model
# MTGNN hyperparameters. Values read from `obj` must stay in sync with the dataset.
gcn_true = True
build_adj = True
dropout = 0.3
subgraph_size = 2
gcn_depth = 2
num_nodes = obj.num_nodes  # Important: must match the dataset's node count
node_dim = 3  # TODO: verify against the MTGNN paper
dilation_exponential = 1
conv_channels = 32
residual_channels = 32
skip_channels = 64
end_channels = 128
in_dim = obj.ip_dim  # Important: must match the dataset's input dimension
# Important: the model consumes seq_length - 1 input steps.
# TODO: seq-len bug to be addressed (see seq_length in the dataset cell).
seq_in_len = obj.seq_length - 1
seq_out_len = 1
layers = 3
batch_size = proto_batch_size  # Set by Data Loader
propalpha = 0.05
tanhalpha = 3
# NOTE(review): num_split and num_edges are not passed to MTGNN in this notebook.
num_split = 1
num_edges = 10
kernel_set = [2, 3, 6, 7]
# Tie kernel_size to the largest kernel in kernel_set so the receptive field
# stays consistent if kernel_set is edited (currently 7, as before).
kernel_size = max(kernel_set)

In [None]:
# Build MTGNN from the hyperparameters above and move it to the selected device.
# Seed torch's global RNG first so weight initialisation (and later dropout)
# is reproducible across Restart & Run All.
torch_seed = 42
torch.manual_seed(torch_seed)

model = MTGNN(gcn_true=gcn_true, build_adj=build_adj, gcn_depth=gcn_depth, num_nodes=num_nodes,
              kernel_size=kernel_size, kernel_set=kernel_set, dropout=dropout, subgraph_size=subgraph_size,
              node_dim=node_dim, dilation_exponential=dilation_exponential,
              conv_channels=conv_channels, residual_channels=residual_channels,
              skip_channels=skip_channels, end_channels=end_channels,
              seq_length=seq_in_len, in_dim=in_dim, out_dim=seq_out_len,
              layers=layers, propalpha=propalpha, tanhalpha=tanhalpha,
              layer_norm_affline=True).to(device)  # `affline` is the library's own spelling
model

In [None]:
# Train Param
# reduction='sum' replaces the deprecated size_average=False and keeps the same
# semantics: losses are summed over all elements rather than averaged.
learning_rate = 0.01
criterion = nn.MSELoss(reduction='sum')
evaluateL2 = nn.MSELoss(reduction='sum')
evaluateL1 = nn.L1Loss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Train Loop
# Records one loss value per mini-batch in `losses` (criterion sums over elements).
losses = []
epochs = 1
model.train()  # Train Mode
for epoch in tqdm(range(epochs), desc="epochs"):
    # tqdm reads len(train_loader) automatically, so the bar shows a total.
    for x, y in tqdm(train_loader, desc=f"epoch {epoch}"):
        # Explicit device transfer: a no-op if the dataset already lives on `device`.
        x = x.float().to(device)
        # Add a singleton axis at dim 3, presumably to match the model output
        # shape (original note: "or squeeze op") — TODO confirm.
        y = torch.unsqueeze(y, 3).float().to(device)

        optimizer.zero_grad()
        op = model(x)
        loss = criterion(op, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

In [None]:
# Raw per-batch training losses recorded by the train loop.
print(losses)

In [None]:
# Number of recorded training steps (one per mini-batch, across all epochs).
len(losses)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Training-loss curve: one point per mini-batch, using the summed loss from
# `criterion`. Call ax.set_yscale('log') for a log-scale view if needed.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss per batch", xlabel="batch", ylabel="summed MSE loss")
plt.show()

In [None]:
# Display the output of MTGNN's graph constructor for the model's node-index tensor.
# NOTE(review): uses the private attributes `_graph_constructor` / `_idx`; the
# result is presumably the learned adjacency matrix — confirm against the
# installed torch_geometric_temporal version.
model._graph_constructor(model._idx.to(device))

In [None]:
# Output of the most recently executed forward pass (last train or eval batch, depending on run order).
op

In [None]:
# Eval Loop
# Single inference pass over the dataset. It is not repeated per training epoch,
# which would duplicate entries in op_eval/loss_eval.
# NOTE(review): this evaluates on the same `obj` used for training, so it
# measures fit to the training data, not generalisation.
eval_batch_size = 1000
test_loader = DataLoader(obj, batch_size=eval_batch_size, shuffle=False)
loss_eval = []
op_eval = []
model.eval()
# no_grad: inference needs no autograd graph, which also saves memory.
with torch.no_grad():
    for x, y in tqdm(test_loader, desc="eval"):
        x = x.float().to(device)
        # Same target reshaping as in the train loop (or squeeze op instead).
        y = torch.unsqueeze(y, 3).float().to(device)
        op = model(x)
        op_eval.append(op.cpu())
        loss_eval.append(criterion(op, y).item())

In [None]:
# Per-batch model outputs collected during evaluation (moved to CPU).
op_eval