In [None]:
import os
# Expose only GPU 0 to CUDA in this process (set before torch touches CUDA)
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [None]:
# Import standard modules
import numpy as np
import h5py
import os
import pickle

In [3]:
# Import torch modules
import torch
from torch import FloatTensor, cat, from_numpy
from torch.autograd import Variable
from torch.optim import Adam
from torchsummary import summary

In [4]:
import sys
sys.path.append('../src/')

In [5]:
# Import transform net
from model_transform_net import ModelTransformNet

In [6]:
# Import utils function
from utils import dot_dict, load_model, gram_matrix

In [7]:
# Import C3D model
from c3d import C3D, load_c3d_weights

#### 1. Specify options

In [8]:
# Use the GPU when one is visible; otherwise fall back to the CPU so the notebook still runs
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [9]:
data_dir = "./data/"  # holds the training HDF5 model files and the hard-data pickle

In [10]:
# Training configuration; dot_dict gives attribute access (args.epochs, args.batch_size, ...)
args = dot_dict({
    # --- input data ---
    'm_petrel': os.path.join(data_dir, 'm_petrel_train3000_case1.h5'),    # Petrel training models
    'm_pca_rec': os.path.join(data_dir, 'm_pca_rec_train3000_case1.h5'),  # Reconstructed PCA models
    'm_pca': os.path.join(data_dir, 'm_pca_train3000_case1.h5'),          # New PCA models
    'hard_data': os.path.join(data_dir, 'hard_data_case1.pickle'),
    # --- loss weights ---
    'gamma_s': 100.0,  # style loss
    'gamma_r': 500.0,  # reconstruction loss
    'gamma_h': 10.0,   # hard data loss
    # --- networks / optimization ---
    'c3d_model': '../src/c3d.pickle',
    'epochs': 10,
    'batch_size': 8,  # reduce batch size if GPU memory overflows
    'save_model': './weights/fw_weights_case1',
    'log_interval': 10,
})

#### 2. Load data

In [11]:
# Load m_petrel: cast to float32 and move the last (size-1) axis to position 1,
# giving the (N, 1, 40, 60, 60) layout the networks consume
m_petrel = np.moveaxis(load_model(args.m_petrel).astype(np.float32), -1, 1)
m_petrel.shape

(3000, 1, 40, 60, 60)

In [12]:
# Load m_pca_rec: cast to float32 and move the last (size-1) axis to position 1,
# giving the (N, 1, 40, 60, 60) layout the networks consume
m_pca_rec = np.moveaxis(load_model(args.m_pca_rec).astype(np.float32), -1, 1)
m_pca_rec.shape

(3000, 1, 40, 60, 60)

In [13]:
# Load m_pca: cast to float32 and move the last (size-1) axis to position 1,
# giving the (N, 1, 40, 60, 60) layout the networks consume
m_pca = np.moveaxis(load_model(args.m_pca).astype(np.float32), -1, 1)
m_pca.shape

(3000, 1, 40, 60, 60)

In [14]:
# Normalize data: linearly map model values from [min_, max_] to [0, 255]
# (presumably the pixel-like range the pretrained C3D weights expect -- TODO confirm).
# NOTE: not idempotent -- re-running this cell rescales the arrays a second time.
# Arrays are rebound one at a time so only one extra full-size array exists at once.
max_, min_ = 1., 0.  # assumed value range of the raw models; also reused for the hard data
m_petrel = (m_petrel - min_) / (max_ - min_) * 255.
m_pca_rec = (m_pca_rec - min_) / (max_ - min_) * 255.
m_pca = (m_pca - min_) / (max_ - min_) * 255.

In [15]:
# Load hard data (a mapping of per-well arrays whose last column holds the observed value).
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(args.hard_data, 'rb') as hd_file:
    well_hd = pickle.load(hd_file)

# Rescale the observed values in place with the same [min_, max_] -> [0, 255] map as the models
for hd_array in well_hd.values():
    hd_array[:, -1] = (hd_array[:, -1] - min_) / (max_ - min_) * 255.

# Stack every well's rows into one array for vectorized indexing in the loss
well_hd_all = np.concatenate([*well_hd.values()], axis=0)
print('Total number of hard data:', well_hd_all.shape[0])

Total number of hard data: 160


#### 3. Construct C3D Net

In [16]:
# C3D serves only as a fixed feature extractor for the style loss
c3d = C3D()
c3d.load_state_dict(load_c3d_weights(c3d, args.c3d_model))
c3d = c3d.to(device)
# Freeze its weights: the optimizer only updates trans_net, so gradients for C3D's
# parameters would otherwise be computed and accumulated (never zeroed) every iteration.
# Gradients still flow through C3D back to its input, which the style loss needs.
c3d.requires_grad_(False)
c3d.eval()

C3D(
  (conv1): Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool1): MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool2): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv3a): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv3b): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool3): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv4a): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv4b): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool4): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
  (softmax): Softmax(dim=None)
)

#### 4. Construct Model Transform Net

In [17]:
# Build the model transform net on the target device and print a layer-by-layer summary
# for one (1, 40, 60, 60) model, the per-sample layout of the training data
trans_net = ModelTransformNet().to(device)
summary(trans_net, input_size=(1, 40, 60, 60), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
      CirularPad3d-1        [-1, 1, 42, 68, 68]               0
            Conv3d-2       [-1, 32, 40, 60, 60]           7,808
         ConvLayer-3       [-1, 32, 40, 60, 60]               0
       BatchNorm3d-4       [-1, 32, 40, 60, 60]              64
              ReLU-5       [-1, 32, 40, 60, 60]               0
      CirularPad3d-6       [-1, 32, 42, 62, 62]               0
            Conv3d-7       [-1, 64, 20, 30, 30]          55,360
         ConvLayer-8       [-1, 64, 20, 30, 30]               0
       BatchNorm3d-9       [-1, 64, 20, 30, 30]             128
             ReLU-10       [-1, 64, 20, 30, 30]               0
     CirularPad3d-11       [-1, 64, 22, 32, 32]               0
           Conv3d-12      [-1, 128, 20, 15, 15]         221,312
        ConvLayer-13      [-1, 128, 20, 15, 15]               0
      BatchNorm3d-14      [-1, 128, 20,

#### 5. Training

In [18]:
# Mean-absolute-error (L1) criterion for the training loss terms
mae_loss = torch.nn.L1Loss()
# Adam over the transform-net weights only, with PyTorch's default hyper-parameters
optimizer = Adam(params=trans_net.parameters())

In [19]:
# Mini-batches per epoch: ceiling division, so a final partial batch is included
num_model = m_petrel.shape[0]
num_batch = num_model // args.batch_size + int(num_model % args.batch_size != 0)
num_batch

750

In [20]:
def compute_hd_loss(y_pred, well_hd_all):
    """Mean absolute error between predicted models and hard data at well cells.

    Parameters
    ----------
    y_pred : torch.Tensor
        Batch of models indexed as ``y_pred[:, 0, iz, iy, ix]`` (channel 0 is used).
    well_hd_all : np.ndarray
        One row per hard-data point: columns 0, 1, 2 hold the x, y, z cell indices
        (truncated to int) and the last column holds the observed value.

    Returns
    -------
    torch.Tensor
        Scalar L1 loss averaged over every batch member and hard-data point, or a zero
        scalar on ``y_pred``'s device/dtype when ``well_hd_all`` has no rows.
    """
    if well_hd_all.shape[0] == 0:
        # The mean over an empty selection would be NaN; no hard data means no penalty
        return y_pred.new_zeros(())

    idx = torch.as_tensor(well_hd_all[:, :3].astype('int64'), device=y_pred.device)
    ix, iy, iz = idx[:, 0], idx[:, 1], idx[:, 2]
    pred_hd = y_pred[:, 0, iz, iy, ix]  # (batch, n_hd)

    # Build the target on y_pred's device/dtype and broadcast it over the batch explicitly,
    # so l1_loss does not warn about mismatched input/target sizes
    target = torch.as_tensor(well_hd_all[:, -1], device=y_pred.device, dtype=y_pred.dtype)
    return torch.nn.functional.l1_loss(pred_hd, target.expand_as(pred_hd))

In [None]:
for e in range(args.epochs):
    trans_net.train()

    for ib in range(num_batch):
        optimizer.zero_grad()

        # Slice bounds of this mini-batch; the last batch may be smaller than batch_size
        ind0, ind1 = ib * args.batch_size, min((ib + 1) * args.batch_size, num_model)

        # Move the numpy mini-batches to the device (torch.autograd.Variable is a no-op
        # wrapper in current PyTorch, so plain tensors are used)
        m_petrel_var = torch.from_numpy(m_petrel[ind0:ind1, ...]).float().to(device)
        m_pca_rec_var = torch.from_numpy(m_pca_rec[ind0:ind1, ...]).float().to(device)
        m_pca_var = torch.from_numpy(m_pca[ind0:ind1, ...]).float().to(device)

        # Transform PCA models with the model transform net (fw)
        fw_m_pca_var = trans_net(m_pca_var)
        fw_m_pca_rec_var = trans_net(m_pca_rec_var)

        # Reconstruction loss: transformed reconstructed-PCA models vs. Petrel models
        rec_loss = args.gamma_r * mae_loss(fw_m_pca_rec_var, m_petrel_var)

        # Hard-data loss on BOTH transformed outputs: hd_loss1 constrains the
        # reconstructed-PCA output, hd_loss2 the new-PCA output
        hd_loss1 = args.gamma_h * compute_hd_loss(fw_m_pca_rec_var, well_hd_all)
        hd_loss2 = args.gamma_h * compute_hd_loss(fw_m_pca_var, well_hd_all)

        # Style loss: match Gram matrices of C3D features. The single-channel models are
        # repeated to 3 channels because C3D's conv1 takes 3 input channels.
        # The Petrel features are only a target, so no autograd graph is built for them.
        with torch.no_grad():
            features_m_petrel = c3d(m_petrel_var.repeat(1, 3, 1, 1, 1))
        features_fw_m_pca = c3d(fw_m_pca_var.repeat(1, 3, 1, 1, 1))

        style_loss = 0.
        for feat_petrel, feat_fw in zip(features_m_petrel, features_fw_m_pca):
            gram_m_petrel = gram_matrix(feat_petrel)
            gram_fw_m_pca = gram_matrix(feat_fw)
            style_loss += args.gamma_s * mae_loss(gram_m_petrel, gram_fw_m_pca)

        # Total loss = reconstruction + style + hard-data terms
        total_loss = rec_loss + style_loss + hd_loss1 + hd_loss2
        total_loss.backward()
        optimizer.step()

        if ib % args.log_interval == 0:
            print('Epoch{}, Batch {}/{}, Rec Loss {}, Style Loss {}, Hd Loss {}'
                  .format(e+1, ib+1, num_batch, rec_loss.item(), style_loss.item(), hd_loss1.item() + hd_loss2.item()))
    print('')
    # Checkpoint after every epoch; note the filename uses the 0-based epoch index e
    torch.save(trans_net.state_dict(), args.save_model + '_sw%.1f_rw%.1f_hw%.1f_%dep.pth'
               % (args.gamma_s, args.gamma_r, args.gamma_h, e))

In [23]:
# Shut down the kernel after training (presumably to release GPU memory);
# remove this line to keep the session alive for inspecting results
exit()