In [1]:
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm
import os
import pickle
import time

%load_ext autoreload
%autoreload 2

This notebooks is for training deep learning models

In [2]:
available_gpus = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
print(torch.cuda.is_available())
display([(gpu, gpu.idx) for gpu in available_gpus])

False




[]

In [3]:
# data = np.load('../Processed/TrainData/R2B5_vcg_20221118-153949.npz')
data_path = '../local_data/TrainData/20230131-171851-R2B5_y13y16_vcg-fluxes_rho_fluct.npz'
# data_path = '../local_data/TrainData/20230210-131835-R2B5_y13y16_vcg-fluxes_rho_fluct_neglect.npz'
data = np.load(data_path)
print(data.files)

X_train, X_val, X_test, Y_train, Y_val, Y_test, X_expl, Y_expl = \
data['X_train'], data['X_val'], data['X_test'], data['Y_train'], data['Y_val'], data['Y_test'], data['X_expl'], data['Y_expl']

# Convert Data to torch Tensors and permute to conform to pytorch channels first format
transform_to_unet_shape = False
if transform_to_unet_shape:
    x_transform = nn.Upsample(size=(32), mode='linear')
else:
    x_transform = nn.Identity()

X_train = x_transform(torch.Tensor(X_train).permute(0,2,1))
X_val = x_transform(torch.Tensor(X_val).permute(0,2,1))
X_test = x_transform(torch.Tensor(X_test).permute(0,2,1))
Y_train = torch.Tensor(Y_train)
Y_val = torch.Tensor(Y_val)
Y_test = torch.Tensor(Y_test)

with open('../local_data/TrainData/20230131-171851-R2B5_y13y16_vcg-fluxes_rho_fluct_Ymask.pickle', 'rb') as handle:
    Y_mask = pickle.load(handle)

print('X_train shape: ', X_train.shape)
print('X_val shape: ', X_val.shape)
print('X_test shape: ', X_test.shape)
print('Y_train shape: ', Y_train.shape)
print('Y_val shape: ', Y_val.shape)
print('Y_test shape: ', Y_test.shape)
print('len X_expl', len(X_expl))
print('len Y_expl', len(Y_expl))

['X_train', 'X_val', 'X_test', 'Y_train', 'Y_val', 'Y_test', 'X_expl', 'Y_expl', 'train_coords', 'val_coords', 'test_coords']
X_train shape:  torch.Size([1613616, 9, 23])
X_val shape:  torch.Size([201702, 9, 23])
X_test shape:  torch.Size([201702, 9, 23])
Y_train shape:  torch.Size([1613616, 189])
Y_val shape:  torch.Size([201702, 189])
Y_test shape:  torch.Size([201702, 189])
len X_expl 207
len Y_expl 189


In [4]:
from HelperFuncs import unique_unsorted

# vars_to_neglect = ['qr','qi','qs']
vars_to_neglect = ['qr','qs']
# vars_to_neglect = []
vars_to_neglect_mask = ~np.isin(unique_unsorted([e[0] for e in X_expl]), vars_to_neglect)
print(vars_to_neglect_mask)

X_train = X_train[:,vars_to_neglect_mask,:]
X_val = X_val[:,vars_to_neglect_mask,:]
X_test = X_test[:,vars_to_neglect_mask,:]
X_expl = np.array([e for e in X_expl if e[0] not in vars_to_neglect])

[ True  True  True  True False False  True  True  True]


In [6]:
# batch_size = 512
batch_size = 1024
batch_size_val = 1024
# Create data loaders.
train_data = TensorDataset(X_train, Y_train)
val_data = TensorDataset(X_val, Y_val)
test_data = TensorDataset(X_test, Y_test)
# torch.save([train_data, val_data, test_data], '../local_data/TrainData/20230131-171851-R2B5_y13y16_vcg-fluxes_rho_fluct_woqrqs.torch_data')
# torch.save([train_data, val_data, test_data], '../local_data/TrainData/20230131-171851-R2B5_y13y16_vcg-fluxes_rho_fluct_woqrqs_rhofluctneglect.torch_data')
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size_val, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size_val, shuffle=False)

for X, y in val_dataloader:
    print('---------------------------------------')
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

---------------------------------------
Shape of X [N, C, H, W]: torch.Size([1024, 7, 23])
Shape of y: torch.Size([1024, 189]) torch.float32


In [74]:
from convection_param.NetworksTorch import ResDNN, Sequential, Unet, SeqConv

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

n_channels=7
# model = ResDNN(in_size=23*n_channels,
#                out_size=189,
#                n_neurons=2048,
#                bn=True,
#                n_layers_per_block=4,
#                n_levels=10,
#                activation=nn.LeakyReLU())
model = Unet(n_channels=n_channels,
                n_classes=8,
                output_channels_total=189,
                n_levels=2,
                n_features=512,
                bn1=False,
                bn2=False,
                column_height=23,
                activation=F.leaky_relu,
                linear=False)
# model = Sequential(input_dim=X_train.shape[1]*X_train.shape[2],
#                     output_dim=189,
#                     n_hidden=2048,
#                     n_layers=5,
#                     activation=F.leaky_relu,#nn.Identity(),#
#                     bn=True)
# model = SeqConv(n_channels=X_train.shape[1],
#                 n_feature_channels=1024,
#                 column_height=23,
#                 n_hidden=1024,
#                 n_layers=1,
#                 output_dim=189,
#                 activation=F.leaky_relu,
#                 kernel_size=5).to(device)

# model = nn.DataParallel(model)
model.to(device)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

Using cpu device


In [75]:
X_train[:2].shape

torch.Size([2, 7, 23])

In [76]:
test_pred = model(X_train[:2].to(device));

In [77]:
# from torchview import draw_graph

# # model_graph = draw_graph(model, input_data=X_train[:1], depth=2, graph_dir='BT')#, save_graph=True, filename='torchview_unet_small')#, expand_nested=True)
# model_graph = draw_graph(model, input_data=X_train[:1], expand_nested=True)#, graph_dir='BT')#, save_graph=True, filename='torchview_unet_small')#, expand_nested=True)
# model_graph.visual_graph

In [78]:
# save_path = f"../Models/Optimized/Torch/test"
# torch.save(model, save_path)

In [79]:
def train(tepoch, model, loss_fn, optimizer, epoch, writer=None):
    # size = len(dataloader.dataset)
    size = tepoch.total
    model.train()
    # for batch, (X, y) in enumerate(dataloader):
    loss_sum = 0
    for batch, (X, y) in enumerate(tepoch):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        loss_sum += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 300 == 0:
            tepoch.set_postfix(loss=loss.item())
    if writer:
        writer.add_scalar('epoch_loss', loss_sum/size, epoch)

def test(dataloader, model, loss_fn, epoch, writer=None):
    # size = len(dataloader.dataset) # number of samples
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches
    if writer:
        writer.add_scalar('epoch_loss', test_loss, epoch)
    print(f"Avg val loss: {test_loss:>8f} \n")

In [13]:
from torch.utils.tensorboard import SummaryWriter
import datetime

now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# model_name = now + "R2B5_vlr_seqbndouble_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluct_alldays_linear"
# model_name = now + "R2B5_vlr_seqbndouble_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluct_alldays_hpoed"
# model_name = now + "R2B5_vlr_resDNNbn_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluct_hpoed"
model_name = now + "R2B5_vlr_unet_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluctneglect_alldays_woqrqs"
# model_name = model_name + "further_train"
# model_name = now + "R2B5_vlr_conv_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluct_alldays"
log_dir = "../logs/from011222/" + model_name

writer_train = SummaryWriter(log_dir=os.path.join(log_dir, 'train'))
writer_val = SummaryWriter(log_dir=os.path.join(log_dir, 'validation'))

epochs = 150
for t in range(epochs):
    with tqdm(train_dataloader, unit="batch") as tepoch:
        tepoch.set_description(f'Epoch {t+1}')
        train(tepoch, model, loss_fn, optimizer, t, writer_train)
        test(val_dataloader, model, loss_fn, t, writer_val)
print("Done!")

  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.254294 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.238535 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.184547 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.180252 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.168314 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.165747 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.156167 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.152745 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.154183 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.149269 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.162412 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.150933 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.160097 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.142658 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.143328 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.142092 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.140812 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.140354 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.138458 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.137427 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.138280 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.138448 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.137276 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.135696 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.135135 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.142681 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.132961 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.135331 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.135286 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.134823 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.132662 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.132704 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.132314 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.131309 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.131680 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.130382 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.131300 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.129968 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.129582 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.132202 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.130790 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.130822 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.130261 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.129823 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.130027 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.130562 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.130246 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.128746 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.130493 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.128840 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.128722 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.128834 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.129643 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.129107 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.129347 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.128644 



  0%|          | 0/1576 [00:00<?, ?batch/s]

Avg val loss: 0.130094 



  0%|          | 0/1576 [00:00<?, ?batch/s]


KeyboardInterrupt



In [None]:
model_path = f'../Models/NewFormat/Torch/{model_name}'
os.makedirs(model_path, exist_ok=True)
save_path = os.path.join(model_path, 'model')
data_desc_path = os.path.join(model_path, 'data_desc.txt')
with open(data_desc_path, 'w') as f:
    f.write('data_path: ' + data_path)
    f.write('\nvars_to_neglect' + str(vars_to_neglect))
# model_scripted = torch.jit.script(model) # Export to TorchScript
# model_scripted.save(f'{save_path}.pt') # Save
# torch.save(model, save_path)
torch.save({
            'model_state_dict': model.state_dict(),
            }, f"{save_path}.state_dict")

In [70]:
# def load_trainval_data(data_file, config):
#     datasets = torch.load(data_file)

#     train_data = datasets[0]
#     val_data = datasets[1]
#     # test_data = datasets[2]
    
#     train_dataloader = DataLoader(train_data, batch_size=config, shuffle=True)
#     val_dataloader = DataLoader(val_data, batch_size=1024, shuffle=False)
#     # test_dataloader = DataLoader(test_data, batch_size=1024, shuffle=False)

#     return train_dataloader, val_dataloader

# data_file = '../local_data/TrainData/20230131-171851-R2B5_y13y16_vcg-fluxes_rho_fluct.torch_data'
# trainloader, valloader = load_trainval_data(data_file, 1024)

In [80]:
# model_path = "../Models/NewFormat/Torch/20230415-021732R2B5_vlr_unet_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluct_alldays_hpoed"
# model_path = "../Models/NewFormat/Torch/20230324-110858R2B5_vlr_resDNNbn_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluct_hpoed"
# model_path = "../Models/NewFormat/Torch/20230414-111730R2B5_vlr_seqbndouble_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluct_alldays_hpoed"
# model_path = "../Models/NewFormat/Torch/20230314-172546R2B5_vlr_conv_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluct_alldays_hpoed"
model_path = "../Models/NewFormat/Torch/20230510-012946R2B5_vlr_unet_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluct_alldays_woqrqs_hpoed"

model_path_normed = os.path.normpath(model_path)
model_name = os.path.basename(model_path_normed)
load_path = os.path.join(model_path, "model.state_dict")
print('model_name: ', model_name)

state_dict = torch.load(load_path, map_location=torch.device('cpu'))['model_state_dict']
print(f'Number of parameters in model from state_dict approximated: {sum(p.numel() for p in state_dict.values())}')

checkpoint = torch.load(load_path, map_location=torch.device(device))
model.load_state_dict(checkpoint['model_state_dict'])

model_name:  20230509-221550R2B5_vlr_unet_adam_lr0.0003_y13y16full_fluxes_prescaledeps1_wqrqstend_worhoprestemp_torch_rhofluct_alldays_woqrqs_hpoed
Number of parameters in model from state_dict approximated: 41524687


<All keys matched successfully>

In [81]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [82]:
count_parameters(model)
#seq:      17615037
#unet:     41507269
#resdnn:   168831165
#conv:     46566589

#WoQiQrQs
#seq:     17473725
#unet:    41502661
#resdnn:  101482685
#conv:    6493629

#WoQrQs
#seq:     21721277
#unet:    41504197
#resdnn:  135133373
#conv:    5445053

41504197

In [17]:
from sklearn.metrics import r2_score, mean_squared_error, d2_pinball_score
from HelperFuncs import calc_correlation, compute_correlation_per_var
import datetime

def predict(dataloader, model):
    model.eval()
    Y_pred = []
    with torch.no_grad():
        for X, _ in tqdm(dataloader):
            X = X.to(device)
            pred = model(X)
            Y_pred.append(pred)
    return torch.concatenate(Y_pred)


model.to(device)
Y_pred = predict(test_dataloader, model)

# now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# np.savez(os.path.join(model_path, 'TestPredictions'), Y_test=Y_test, Y_pred=Y_pred.cpu().numpy(), X_expl=X_expl, Y_expl=Y_expl)
# # np.savez(f'../local_data/TestPredictions/{now}-{model_name}', Y_test=Y_test, Y_pred=Y_pred.cpu().numpy(), X_expl=X_expl, Y_expl=Y_expl)

print('R2: ', r2_score(Y_test, Y_pred.cpu().numpy(), multioutput='variance_weighted'))
print('RMSE: ', loss_fn(Y_test.to(device), Y_pred))
print('R: ', compute_correlation_per_var(Y_test.numpy()[:,Y_mask], Y_pred.cpu().numpy()[:,Y_mask], multioutput='variance_weighted'))
# print('R: ', compute_correlation_per_var(Y_test.numpy(), Y_pred.cpu().numpy(), multioutput='variance_weighted'))
print('R_flattened_data: ', np.corrcoef(Y_test.numpy().flatten(), Y_pred.cpu().numpy().flatten())[0,1])

  0%|          | 0/197 [00:00<?, ?it/s]

R2:  0.8695710138292051
RMSE:  tensor(0.1279, device='cuda:0')
R:  0.924169401780359
R_flattened_data:  0.9347754415361269


In [None]:
# from captum.attr import FeaturePermutation
# 
# def forward_func(X, Y):
    # # return r2_score(model(X).cpu().numpy(), Y.cpu().numpy())
    # return loss_fn(Y, model(X))
# 
# fp = FeaturePermutation(forward_func)
# 
# fp_result = fp.attribute((X_test[:512].to(device), Y_test[:512].to(device)))
# 
# permutation_importances = fp_result[0].cpu().numpy().flatten()
# 
# plt.figure()
# plt.bar(range(len(permutation_importances)), permutation_importances)
# plt.show()

# from captum.attr import IntegratedGradients

# ig = IntegratedGradients(model)
# X_explain_captum = torch.clone(X_explain)
# X_explain_captum.requires_grad_()
# attr, delta = ig.attribute(X_explain_captum, target=1, return_convergence_delta=True)
# attr.detach().numpy()