In [88]:
import argparse
import os
from ast import arg
from pickletools import optimize

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from wlcorr import EncoderDecoderStaticDataset, EncoderDecoder1DCNN


In [89]:
class VRLoss(nn.Module):
    """Loss built from the cosine similarity of two 3-D tensors along dims 1 and 2.

    For each of the two axes the cosine similarity between ``output`` and
    ``target`` is squared, exponentiated and averaged; the two averages are
    then summed. Each ``exp(sim ** 2)`` term lies in ``[1, e]``, so the loss
    lies in ``[2, 2 * e]`` and is smallest when the similarity is 0.

    NOTE(review): minimising this loss drives ``output`` towards being
    orthogonal to ``target`` rather than aligned with it -- confirm that this
    is the intended objective.
    """

    def __init__(self) -> None:
        super().__init__()
        # One similarity module per reduced axis of the 3-D input.
        self.cosinesim1 = nn.CosineSimilarity(dim=1, eps=1e-6)
        self.cosinesim2 = nn.CosineSimilarity(dim=2, eps=1e-6)

    def forward(self, output:torch.Tensor, target:torch.Tensor) -> torch.Tensor:
        """Return the scalar loss for ``output`` against ``target``.

        Args:
            output: 3-D tensor produced by the model.
            target: tensor compared against ``output``; passed unchanged to
                ``nn.CosineSimilarity`` together with ``output``.

        Returns:
            A 0-dim tensor: ``mean(exp(sim_dim1 ** 2)) + mean(exp(sim_dim2 ** 2))``.

        Raises:
            ValueError: if ``output`` is not 3-dimensional.
        """
        # Explicit rank check instead of unpacking the shape into unused names.
        if output.dim() != 3:
            raise ValueError(
                f'VRLoss expects a 3-D output tensor, got shape {tuple(output.shape)}'
            )

        sim1 = self.cosinesim1(output, target)
        sim2 = self.cosinesim2(output, target)

        return sim1.square().exp().mean() + sim2.square().exp().mean()

In [90]:
# Root directory of the data set. The original machine-specific location is the
# default; set the WLCORR_DATA_DIR environment variable to run elsewhere
# without editing this cell.
DATA_DIR = os.environ.get(
    'WLCORR_DATA_DIR',
    '/home/shivam/DKLabs/OilGasProject/WellLogCorrelation/data',
)
dataset = EncoderDecoderStaticDataset(DATA_DIR)
dataloader = DataLoader(dataset, batch_size = 10, shuffle = True)
data_len = len(dataloader)  # number of batches per epoch (not number of samples)
# Log roughly twice per epoch. Clamp to 1 so that a loader with a single batch
# does not make the `i % log_n` check in the training loop divide by zero.
log_n = max(1, data_len // 2)

In [91]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [92]:
# Build the encoder/decoder network and place it on the selected device
# (nn.Module.to returns the module itself, so the call can be chained).
model = EncoderDecoder1DCNN(2, 50).to(device)
# Enable training mode; as the last expression of the cell this also renders the module repr.
model.train()

EncoderDecoder1DCNN(
  (encoder): Sequential(
    (0): Conv1d(2, 5, kernel_size=(10,), stride=(1,), padding=(9,))
    (1): SELU()
    (2): Conv1d(5, 10, kernel_size=(10,), stride=(1,), padding=(9,))
    (3): SELU()
    (4): Conv1d(10, 10, kernel_size=(10,), stride=(1,), padding=(9,))
    (5): SELU()
    (6): Conv1d(10, 10, kernel_size=(10,), stride=(1,), padding=(9,))
    (7): SELU()
    (8): Conv1d(10, 1, kernel_size=(10,), stride=(1,), padding=(9,))
    (9): SELU()
    (10): Linear(in_features=145, out_features=50, bias=True)
  )
  (decoder): Sequential(
    (0): ConvTranspose1d(1, 5, kernel_size=(10,), stride=(1,))
    (1): SELU()
    (2): ConvTranspose1d(5, 10, kernel_size=(10,), stride=(1,))
    (3): SELU()
    (4): ConvTranspose1d(10, 10, kernel_size=(10,), stride=(1,))
    (5): SELU()
    (6): ConvTranspose1d(10, 5, kernel_size=(10,), stride=(1,))
    (7): SELU()
    (8): ConvTranspose1d(5, 5, kernel_size=(10,), stride=(1,))
    (9): SELU()
    (10): ConvTranspose1d(5, 2, kernel

In [93]:
criterion = VRLoss() # Loss function
# model.parameters() returns a one-shot generator that the optimizer would
# exhaust; materialise it so `params_list` really is a reusable list.
params_list = list(model.parameters()) # model parameters
optimizer = optim.AdamW(params_list, lr = 0.0007, weight_decay=0.01)

In [94]:
def compute_loss(dl, model, crt, device=None):
    """Return the sample-weighted mean of ``crt(model(batch), batch)`` over ``dl``, minus 1.

    The model is evaluated without gradient tracking and in eval mode; its
    previous train/eval mode is restored before returning, so the function can
    be called from inside a training loop.

    Args:
        dl: iterable of batches; each batch must be a tensor whose first
            dimension is the batch size.
        model: ``nn.Module`` mapping a batch to its output.
        crt: callable ``crt(output, batch)`` returning a scalar tensor.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU when the model has no parameters), so
            the function does not depend on a notebook-level global.

    Returns:
        float: ``sum(loss_i * batch_size_i) / sum(batch_size_i) - 1``.
        NOTE(review): the constant offset of 1 is kept from the original
        implementation; its purpose is not evident from this cell -- confirm.

    Raises:
        ValueError: if ``dl`` yields no samples.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0.
    cnt = 0
    # Remember the current mode so evaluation does not silently leave the
    # model in eval mode for the caller's next training step.
    was_training = model.training
    model.eval()
    try:
        # Pure evaluation: no gradients are needed for the outputs.
        with torch.no_grad():
            for data in dl:
                data = data.to(device)

                # The batch is both the input and the target of the criterion.
                output = model(data)
                loss = crt(output, data)

                # Weight each batch by its size so a short last batch is not over-counted.
                total_loss += loss.item()*data.size(0)
                cnt += data.size(0)
    finally:
        model.train(was_training)

    if cnt == 0:
        raise ValueError('compute_loss received a data loader that yielded no samples')
    return total_loss/cnt-1.

In [95]:
for epoch in range(200):
    # Loss summed over the iterations since the last log line.
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        # Bring the batch onto the training device.
        data = data.to(device)

        # Clear stale gradients, then forward pass, backward pass and parameter update.
        # The batch serves as both the model input and the criterion target.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Only report on the last iteration of each window of `log_n` iterations.
        if (i + 1) % log_n != 0:
            continue

        print(f'Epoch : {epoch}, Iteration : {i},Running loss : {running_loss}, Total loss : {compute_loss(dataloader, model, criterion)}')
        running_loss = 0

Epoch : 0, Iteration : 4,Running loss : 12.603748321533203, Total loss : 1.2728073499640642
Epoch : 0, Iteration : 9,Running loss : 11.072587013244629, Total loss : 1.1810150341111787
Epoch : 1, Iteration : 4,Running loss : 11.07320499420166, Total loss : 1.1432597637176514
Epoch : 1, Iteration : 9,Running loss : 10.594058990478516, Total loss : 1.1294843518004125
Epoch : 2, Iteration : 4,Running loss : 10.6955885887146, Total loss : 1.0888108185359409
Epoch : 2, Iteration : 9,Running loss : 10.352745056152344, Total loss : 1.0845445613471831
Epoch : 3, Iteration : 4,Running loss : 10.581680059432983, Total loss : 1.0864492241217167
Epoch : 3, Iteration : 9,Running loss : 10.351811408996582, Total loss : 1.0629205801049055
Epoch : 4, Iteration : 4,Running loss : 10.31750774383545, Total loss : 1.069719830337836
Epoch : 4, Iteration : 9,Running loss : 10.250628232955933, Total loss : 1.0422755455484194
Epoch : 5, Iteration : 4,Running loss : 10.204556465148926, Total loss : 1.0371567521

KeyboardInterrupt: 

In [None]:
# Persist only the model's state dict (not the module object) to ./basemodel.
torch.save(obj=model.state_dict(), f='./basemodel')