In [1]:
from comet_ml import Experiment
import trainer
import torch
import random
import numpy as np
import os

from dataset import SeismogramDataset
from neural_networks.segnet import SegNet_3Head
from logger import CometMlLogger

In [2]:
# Reproducibility: give every random number generator used by this notebook
# (Python, NumPy, PyTorch CPU and PyTorch CUDA) the same fixed seed.
seed = 0
for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(seed)

# Ask cuDNN to use deterministic algorithms only.
torch.backends.cudnn.deterministic = True

In [3]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All data paths in this notebook are resolved relative to the current working directory.
ROOT_DIR = os.getcwd()

In [4]:
# Dataset directory, resolved relative to ROOT_DIR.
dataset_path = os.path.join(ROOT_DIR, 'dolfin_adjoint', '2_subdomains')

# Training data: a SeismogramDataset constructed from that directory.
train_dataset = SeismogramDataset(dataset_path)

In [5]:
# Helper used below to report the model's parameter counts.
from utils import number_of_parameters

# Network to be trained, built with its default constructor arguments.
model = SegNet_3Head()
# Reports the total and trainable parameter counts (see the output of this cell).
number_of_parameters(model)

Total number of parameters: 4567619
Trainable number of parameters: 4567619


In [6]:
# Show the module's repr (its layer-by-layer structure) as the cell output.
model

SegNet_3Head(
  (adapter): Sequential(
    (adaptive_pool): AdaptiveAvgPool2d(output_size=(128, 128))
    (input): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (activation): CELU(alpha=1.0)
  )
  (encoder): Sequential(
    (block_1): Sequential(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (act1): CELU(alpha=1.0)
      (dropout1): Dropout2d(p=0.0, inplace=False)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (act2): CELU(alpha=1.0)
      (dropout2): Dropout2d(p=0.0, inplace=False)
      (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm3): Instance

In [7]:
import matplotlib.pyplot as plt

def visualize(batch, preds):
    """Draw predictions next to reference masks for the first sample of a batch.

    The figure is a 3x2 grid: one row per field (lambda, mu, rho), with the
    prediction in the left column and the reference image in the right column.

    Args:
        batch: object exposing a ``masks`` tensor indexable by sample.
        preds: indexable of three tensors, read in the order lambda, mu, rho
            (the order implied by the panel titles). Each one is detached,
            moved to the CPU and indexed by sample before plotting.
            Assumes ``preds[k][0]`` is an image ``imshow`` can draw
            -- TODO confirm against the model output.

    Returns:
        The matplotlib figure holding the six panels.

    NOTE(review): every "true" panel shows the same ``batch.masks[sample]``
    image regardless of the field named in its title. Verify whether
    per-field reference data should be plotted instead.
    """
    sample = 0  # only the first element of the batch is visualised
    field_names = ('lambda', 'mu', 'rho')

    # Convert all three prediction heads to NumPy before any plotting happens.
    pred_maps = [preds[k].detach().cpu().numpy() for k in range(len(field_names))]

    fig, axes = plt.subplots(3, 2, figsize=(10, 18))
    for row, (name, pred_map) in enumerate(zip(field_names, pred_maps)):
        pred_ax, true_ax = axes[row][0], axes[row][1]

        # Left column: network prediction for this field.
        pred_ax.imshow(pred_map[sample])
        pred_ax.set_title(f'predictions for \n {name} distr.')
        pred_ax.axis('off')

        # Right column: reference mask of the same sample.
        true_ax.imshow(batch.masks[sample].cpu().data.numpy())
        true_ax.set_title(f'true {name} distr')
        true_ax.axis('off')

    return fig

In [9]:
# Comet.ml experiment that tracks this training run.
# The API key is a secret and must never be stored in the notebook: it is read
# from the COMET_API_KEY environment variable. When the variable is unset,
# None is passed and comet_ml falls back to its own configuration lookup.
# NOTE: a key was previously hard-coded in this cell; treat it as leaked and
# revoke it in the Comet account settings.
experiment = Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="rheology-reconstruction",
    workspace="stankevich-mipt",
    auto_metric_logging=False
)

# Free-form description of the run, stored with the experiment.
# Keep "lr" and "numerical solver" in sync with the values actually passed to
# the trainer in the cells below.
params = {
  "model"    : "Segnet with regular convolutions",
  "grid_size":"10x10",
  "lr"       :"1e-3",
  "numerical solver":"adjoint_equation",
  "dataset": "2 subdomains"
}

experiment.log_parameters(params)

# Project logger wrapping the experiment; it receives `visualize` as the
# callback used to render figures.
logger = CometMlLogger(
    experiment,
    log_interval=1,
    val_interval=1,
    visualize=visualize
)

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/stankevich-mipt/rheology-reconstruction/4014332905294de48b121ecb3a50a173
COMET INFO:   Uploads:
COMET INFO:     environment details      : 1
COMET INFO:     filename                 : 1
COMET INFO:     git metadata             : 1
COMET INFO:     git-patch (uncompressed) : 1 (2 MB)
COMET INFO:     installed packages       : 1
COMET INFO:     os packages              : 1
COMET INFO: ---------------------------
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/stankevich-mipt/rheology-reconstruction/1a52c03d959b4916873c7211a6e0242d



In [10]:
from trainer import BaseTrainer

# Detector positions: 128 points whose first coordinate is evenly spaced over
# [0, 2000] and whose second coordinate is fixed at 2000.
# NB: each point is deliberately kept as its own 2-element array. This format
# is required for dolfin_adjoint to pass gradients through indexing; the
# underlying reason is not understood.
detector_xs = np.linspace(0., 2000., 128)
detector_coords = [np.array([x, 2000.]) for x in detector_xs]

# TODO: test with nonempty logger
trainer_options = {
    'optimizer_type': torch.optim.Adam,
    'optimizer_params': {'lr': 1e-3},
    'snapshot_interval': 250,
    'logger': logger,
}
t = BaseTrainer(model, device, train_dataset, **trainer_options)

In [None]:
# TODO: test if everything is all right with batch_size > 1
try:
    t.train(detector_coords, batch_size=1, epochs=100, num_solver_type='adjoint_equation')
finally:
    # Close the Comet experiment even when training raises or the cell is
    # interrupted; otherwise the run is left open.
    experiment.end()

Running training procedure:   0%|          | 0/100 [00:00<?, ?it/s]
Epoch 1 of 100:   0%|          | 0/1 [00:00<?, ?it/s][A

integrating the state problem:   0%|          | 0/199 [00:00<?, ?it/s][A[A

integrating the state problem:  10%|█         | 20/199 [00:01<00:09, 18.18it/s][A[A

integrating the state problem:  20%|██        | 40/199 [00:02<00:08, 18.39it/s][A[A

integrating the state problem:  30%|███       | 60/199 [00:03<00:07, 18.21it/s][A[A

integrating the state problem:  40%|████      | 80/199 [00:04<00:06, 17.90it/s][A[A

integrating the state problem:  50%|█████     | 100/199 [00:05<00:05, 18.11it/s][A[A

integrating the state problem:  60%|██████    | 120/199 [00:06<00:04, 18.19it/s][A[A

integrating the state problem:  70%|███████   | 140/199 [00:07<00:03, 17.95it/s][A[A

integrating the state problem:  80%|████████  | 160/199 [00:08<00:02, 18.10it/s][A[A

integrating the state problem: 100%|██████████| 199/199 [00:11<00:00, 18.04it/s][A[A


integra

Calling FFC just-in-time (JIT) compiler, this may take some time.



Epoch 1 of 100: 100%|██████████| 1/1 [01:23<00:00, 83.47s/it][A
Running training procedure:   1%|          | 1/100 [01:23<2:17:44, 83.48s/it]
Epoch 2 of 100:   0%|          | 0/1 [00:00<?, ?it/s][A

integrating the state problem:   0%|          | 0/199 [00:00<?, ?it/s][A[A

integrating the state problem:  10%|█         | 20/199 [00:01<00:09, 18.72it/s][A[A

integrating the state problem:  20%|██        | 40/199 [00:02<00:08, 19.21it/s][A[A

integrating the state problem:  30%|███       | 60/199 [00:03<00:07, 18.32it/s][A[A

integrating the state problem:  40%|████      | 80/199 [00:04<00:06, 18.80it/s][A[A

integrating the state problem:  50%|█████     | 100/199 [00:05<00:05, 19.11it/s][A[A

integrating the state problem:  60%|██████    | 120/199 [00:06<00:04, 19.26it/s][A[A

integrating the state problem:  70%|███████   | 140/199 [00:07<00:03, 19.39it/s][A[A

integrating the state problem:  80%|████████  | 160/199 [00:08<00:02, 18.74it/s][A[A

integrating the stat

integrating adjoint problem in reverse time:  80%|████████  | 160/199 [00:26<00:06,  6.08it/s][A[A

integrating adjoint problem in reverse time: 100%|██████████| 199/199 [00:32<00:00,  6.12it/s][A[A

Epoch 5 of 100: 100%|██████████| 1/1 [00:48<00:00, 48.34s/it][A
Running training procedure:   5%|▌         | 5/100 [04:33<1:20:39, 50.94s/it]
Epoch 6 of 100:   0%|          | 0/1 [00:00<?, ?it/s][A

integrating the state problem:   0%|          | 0/199 [00:00<?, ?it/s][A[A

integrating the state problem:  10%|█         | 20/199 [00:01<00:10, 17.55it/s][A[A

integrating the state problem:  20%|██        | 40/199 [00:02<00:08, 17.71it/s][A[A

integrating the state problem:  30%|███       | 60/199 [00:03<00:07, 18.07it/s][A[A

integrating the state problem:  40%|████      | 80/199 [00:04<00:06, 17.91it/s][A[A

integrating the state problem:  50%|█████     | 100/199 [00:05<00:05, 18.20it/s][A[A

integrating the state problem:  60%|██████    | 120/199 [00:06<00:04, 18.19it/s]

integrating adjoint problem in reverse time:  70%|███████   | 140/199 [00:22<00:09,  6.30it/s][A[A

integrating adjoint problem in reverse time:  80%|████████  | 160/199 [00:25<00:06,  6.31it/s][A[A

integrating adjoint problem in reverse time: 100%|██████████| 199/199 [00:31<00:00,  6.29it/s][A[A

Epoch 9 of 100: 100%|██████████| 1/1 [00:46<00:00, 46.48s/it][A
Running training procedure:   9%|▉         | 9/100 [07:47<1:13:54, 48.73s/it]
Epoch 10 of 100:   0%|          | 0/1 [00:00<?, ?it/s][A

integrating the state problem:   0%|          | 0/199 [00:00<?, ?it/s][A[A

integrating the state problem:  10%|█         | 20/199 [00:01<00:15, 11.39it/s][A[A

integrating the state problem:  20%|██        | 40/199 [00:02<00:10, 14.76it/s][A[A

integrating the state problem:  30%|███       | 60/199 [00:03<00:08, 16.02it/s][A[A

integrating the state problem:  40%|████      | 80/199 [00:05<00:07, 16.74it/s][A[A

integrating the state problem:  50%|█████     | 100/199 [00:06<00:

integrating adjoint problem in reverse time:  50%|█████     | 100/199 [00:15<00:15,  6.22it/s][A[A

integrating adjoint problem in reverse time:  60%|██████    | 120/199 [00:19<00:12,  6.12it/s][A[A

integrating adjoint problem in reverse time:  70%|███████   | 140/199 [00:22<00:09,  6.06it/s][A[A

integrating adjoint problem in reverse time:  80%|████████  | 160/199 [00:26<00:06,  6.05it/s][A[A

integrating adjoint problem in reverse time: 100%|██████████| 199/199 [00:32<00:00,  6.18it/s][A[A

Epoch 13 of 100: 100%|██████████| 1/1 [00:47<00:00, 47.31s/it][A
Running training procedure:  13%|█▎        | 13/100 [10:59<1:09:35, 48.00s/it]
Epoch 14 of 100:   0%|          | 0/1 [00:00<?, ?it/s][A

integrating the state problem:   0%|          | 0/199 [00:00<?, ?it/s][A[A

integrating the state problem:  10%|█         | 20/199 [00:01<00:09, 18.56it/s][A[A

integrating the state problem:  20%|██        | 40/199 [00:02<00:08, 18.09it/s][A[A

integrating the state problem:  30

In [None]:
# Explicitly end the Comet experiment.
# NOTE(review): the training cell above also calls experiment.end() after
# t.train(); presumably this separate cell is kept so the run can be closed
# by hand after an interrupted training -- confirm it is still needed.
experiment.end()