In [1]:
import trainer
import torch
import random
import numpy as np
import os

from dataset import SeismogramDataset
from neural_networks.segnet import SegNet_3Head
from trainer import BaseTrainer
from utils import number_of_parameters
from neural_networks.segnet import SegNet_3Head

In [2]:
# Seed every RNG the run touches so training is reproducible.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Explicitly seed every visible GPU; silently ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(seed)
# Force deterministic cuDNN kernels and disable cuDNN autotuning, which can pick
# different (non-deterministic) algorithms between runs. Pinned explicitly in case
# an earlier import switched benchmarking on.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Base directory for the relative paths below: the kernel's current working directory.
ROOT_DIR = os.getcwd()

In [4]:
# Directory holding the training dataset, one level above the working directory.
_dataset_parts = ('..', 'datasets', 'heterogeneity')
dataset_path = os.path.join(ROOT_DIR, *_dataset_parts)
train_dataset = SeismogramDataset(dataset_path)

In [5]:
# Build the network and print its total / trainable parameter counts.
model = SegNet_3Head()
number_of_parameters(model)

Total number of parameters: 4567619
Trainable number of parameters: 4567619


In [6]:
# Display the module tree of the network (the bare expression is rendered as this cell's output).
model

SegNet_3Head(
  (adapter): Sequential(
    (adaptive_pool): AdaptiveAvgPool2d(output_size=(128, 128))
    (input): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (activation): CELU(alpha=1.0)
  )
  (encoder): Sequential(
    (block_1): Sequential(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (act1): CELU(alpha=1.0)
      (dropout1): Dropout2d(p=0.0, inplace=False)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (act2): CELU(alpha=1.0)
      (dropout2): Dropout2d(p=0.0, inplace=False)
      (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm3): Instance

In [7]:
# Path to the numerical-solver configuration shipped alongside the dataset.
solver_config = os.path.join(dataset_path, 'solver_config.yaml')

# Optimiser settings and checkpoint cadence, named so they are easy to find and tune.
LEARNING_RATE = 1e-3
SNAPSHOT_INTERVAL = 10  # forwarded as `snapshot_interval`; unit (epochs?) — TODO confirm in BaseTrainer

# TODO: test with nonempty logger
t = BaseTrainer(model,
                device,
                train_dataset,
                solver_config,
                optimizer_type=torch.optim.Adam,
                optimizer_params={'lr': LEARNING_RATE},
                snapshot_interval=SNAPSHOT_INTERVAL)

In [8]:
# Launch training. The saved output shows this run was stopped manually (KeyboardInterrupt).
# TODO: test if everything is all right with batch_size > 1
t.train(
    batch_size=1,
    epochs=100,
    num_solver_type='dolfin_adjoint',
)

Running training procedure:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 of 100:   0%|          | 0/1000 [00:00<?, ?it/s]

Calling FFC just-in-time (JIT) compiler, this may take some time.
epoch: 0; loss: 0.002722756212214414


KeyboardInterrupt: 