In [26]:
import os
import re
import json
import sqlite3
import torch
from tqdm import tqdm
import torchvision.transforms as T

from file_utils import create_versioned_dir, get_version_dir
from flow_models import FlowModel
from flow_models.PatchFlowModel import PatchFlowModel
from img_utils import ImageLoader, PatchExtractor
from transforms import image_dequantization, image_normalization

In [27]:
# Compute device shared by the notebook: CUDA when torch reports it as available, CPU otherwise.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [30]:
def patch_flow_trainer(name: str, path: str, model: FlowModel, loss_fn, train_images: ImageLoader, validation_images: ImageLoader,
                       patch_size=6, batch_size=64, steps=750000, val_each_steps=1000, loss_log_each_step=100, device='cpu',
                       quiet=False, lr=0.005):
    """Train ``model`` on random image patches and persist losses and checkpoints.

    A new versioned directory is created via ``create_versioned_dir(path, name)``.
    The following files are written into it:

    * ``hparams.yaml`` - the model hyper-parameters plus the training settings,
      serialized with ``json.dump`` (the file name is kept as-is).
    * ``loss.db`` - a SQLite database with the tables ``flow_model_train_loss``
      and ``flow_model_validation_loss`` (columns ``step``, ``loss``).
    * ``optimizer_dict.pth`` and ``{name}_intermediate.pth`` - overwritten on
      every validation step.
    * ``{name}_final.pth`` - written once after the last step.

    Args:
        name: Model name; used for the versioned directory and checkpoint file names.
        path: Base directory handed to ``create_versioned_dir``.
        model: Flow model. It is called as ``model(batch, rev=True)`` and must
            return a pair that ``loss_fn`` accepts; it must also provide
            ``get_hparams()``, ``get_state()``, ``parameters()`` and ``to()``.
        loss_fn: Callable taking the two values returned by the model and
            returning a scalar tensor supporting ``backward()`` and ``item()``.
        train_images: Loader providing ``get_random_image()`` and ``path``.
        validation_images: Loader providing ``get_random_image()`` and ``path``.
        patch_size: Patch size passed to ``PatchExtractor``.
        batch_size: Number of patches requested from the extractor per step.
        steps: Number of optimization steps.
        val_each_steps: A validation pass, DB flush and checkpoint happen every
            this many steps (including step 0).
        loss_log_each_step: The training loss is recorded every this many steps.
        device: Device the model is moved to and the patch extractor is created on.
        quiet: When True, suppresses prints and the tqdm progress bar.
        lr: Learning rate for the Adam optimizer.
    """
    if not quiet:
        print(f'Started training for model {name}. \n Will train {steps} steps in device={device}')
    out_dir = create_versioned_dir(path, name)
    if not quiet:
        print(f'The weights, loss and the parameters will be stored at this location: {out_dir}')

    # NOTE(review): the dict returned by get_hparams() is extended in place, as before;
    # whether it is the model's own dict or a copy is not visible here.
    hparams = model.get_hparams()
    hparams['patch_size'] = patch_size
    hparams['batch_size'] = batch_size  # previously stored patch_size by mistake
    # A torch.device cannot be serialized by json.dump, so store its string form
    # (a plain string such as 'cpu' is unchanged by str()).
    hparams['device'] = str(device)
    hparams['train_img_path'] = train_images.path
    hparams['validation_img_path'] = validation_images.path
    hparams['model_name'] = name
    hparams['lr'] = lr
    with open(os.path.join(out_dir, 'hparams.yaml'), 'w') as hparams_file:
        json.dump(hparams, hparams_file)

    # Create the sqlite3 connection used to store the loss values.
    connection = sqlite3.connect(os.path.join(out_dir, 'loss.db'))
    try:
        cursor = connection.cursor()
        cursor.execute("CREATE TABLE flow_model_train_loss(step, loss)")
        cursor.execute("CREATE TABLE flow_model_validation_loss(step, loss)")
        connection.commit()

        # Use the `device` argument (not a notebook-level global) so that the model
        # and the patch extractor always live on the same device.
        model.to(device)

        patch_extractor = PatchExtractor(p_size=patch_size, device=device)
        progress_bar = tqdm(range(steps)) if not quiet else range(steps)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        if not quiet:
            print(f'Optimizer is initialized with this parameters: {optimizer.state_dict()}')

        tmp_validation_loss = 0
        tmp_loss = 0

        # Training losses recorded since the last flush to the database.
        loss_buffer = []

        for step in progress_bar:
            train_image = train_images.get_random_image()
            train_patch_batch = patch_extractor.extract(train_image, batch_size)

            z, z_log_det = model(train_patch_batch, rev=True)
            loss = loss_fn(z, z_log_det)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % loss_log_each_step == 0:
                tmp_loss = loss.item()
                if not quiet:
                    progress_bar.set_description_str(f'T: {tmp_loss}, V:{tmp_validation_loss}')
                loss_buffer.append((step, tmp_loss))

            if step % val_each_steps == 0:
                with torch.no_grad():
                    val_image = validation_images.get_random_image()
                    val_patch_batch = patch_extractor.extract(val_image, batch_size)
                    z_val, z_val_log_det = model(val_patch_batch, rev=True)
                    val_loss = loss_fn(z_val, z_val_log_det)
                tmp_validation_loss = val_loss.item()

                # Write the buffered train losses and the validation loss to the db.
                cursor.execute("INSERT INTO flow_model_validation_loss VALUES(?, ?)", (step, tmp_validation_loss))
                cursor.executemany("INSERT INTO flow_model_train_loss VALUES(?, ?)", loss_buffer)
                connection.commit()
                # Empty the buffer, otherwise every later flush would insert the
                # already-stored rows again.
                loss_buffer.clear()

                # Save checkpoint (both files are overwritten each time).
                torch.save(optimizer.state_dict(), os.path.join(out_dir, 'optimizer_dict.pth'))
                torch.save(model.get_state(), os.path.join(out_dir, f'{name}_intermediate.pth'))
                if not quiet:
                    progress_bar.set_description_str(f'T: {tmp_loss}, V:{tmp_validation_loss}')

        # Persist the training losses logged after the last validation step.
        if loss_buffer:
            cursor.executemany("INSERT INTO flow_model_train_loss VALUES(?, ?)", loss_buffer)
            connection.commit()
    finally:
        # Release the database handle even when training is interrupted or fails.
        connection.close()
    torch.save(model.get_state(), os.path.join(out_dir, f'{name}_final.pth'))


In [31]:
# Patch size shared by the notebook: it sizes the model input (dimension = patch_size ** 2)
# and is passed to patch_flow_trainer as its patch_size argument.
patch_size = 8

In [32]:
def log_likelihood_loss(z, z_log_det):
    """Return the mean over samples of ``0.5 * sum(z**2, dim=1) - z_log_det``.

    Args:
        z: Tensor whose squared entries are summed along dimension 1.
        z_log_det: Tensor subtracted from the per-sample squared-norm term.

    Returns:
        A scalar tensor holding the mean of the per-sample values.
    """
    half_squared_norm = z.pow(2).sum(dim=1).mul(0.5)
    per_sample = half_squared_norm - z_log_det
    return per_sample.mean()

In [24]:
# Build the patch flow model with explicit hyper-parameters; the input dimension is patch_size ** 2.
# Earlier variant kept for reference: create_NF(num_layers=10, sub_net_size=512, dimension=patch_size**2)
model = PatchFlowModel(hparams={"num_layers": 5, "sub_net_size": 512, "dimension": patch_size ** 2})

In [17]:
# Transform pipeline handed to the image loaders: image_dequantization (created with DEVICE)
# followed by image_normalization, composed in that order.
deq_normalization = T.Compose([
    image_dequantization(device=DEVICE),
    image_normalization()])

In [18]:
# Training and validation image loaders; both use the same deq_normalization transform
# and are given DEVICE.
train_images = ImageLoader('data/material_pt_nr/train.png', transform=deq_normalization, device=DEVICE)
validation_images = ImageLoader('data/material_pt_nr/validate.png', transform=deq_normalization, device=DEVICE)

In [25]:
# Run a 3000-step training of `model` with the notebook's patch size on DEVICE;
# results are stored under results/patch_nr in a versioned 'custom_patch_nr' folder.
patch_flow_trainer(
    name='custom_patch_nr',
    path='results/patch_nr',
    model=model,
    loss_fn=log_likelihood_loss,
    train_images=train_images,
    validation_images=validation_images,
    patch_size=patch_size,
    steps=3000,
    device=DEVICE,
)

Started training for model custom_patch_nr. 
 Will train 3000 steps in device=cuda
The weights, loss and the parameters will be stored at this location: results/patch_nr/custom_patch_nr/version_2


T: 8.30406379699707, V:8835.1484375:   0%|          | 4/3000 [00:00<01:22, 36.45it/s]

Optimizer is initialized with this parameters: {'state': {}, 'param_groups': [{'lr': 0.005, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69]}]}


T: -218.11044311523438, V:-231.1395721435547: 100%|██████████| 3000/3000 [00:46<00:00, 64.58it/s] 


In [21]:
# Reload a checkpoint named like the intermediate checkpoints that patch_flow_trainer writes,
# taken from version 1 of the 'custom_patch_nr' results folder.
# NOTE(review): the output recorded for this cell is a RuntimeError raised while loading the
# state_dict: the checkpoint holds 512-wide subnets and has no entries for module_list.10-19,
# whereas the model constructed here has 1024-wide subnets. Presumably PatchFlowModel uses
# default hparams when only `path` is given, so the hparams used for training would have to be
# supplied as well - TODO confirm against PatchFlowModel.
# NOTE(review): the training output shown earlier in this notebook reports version_2 as the
# output folder, while version 1 is loaded here - confirm this is the intended checkpoint.
version_folder = get_version_dir('results/patch_nr', 'custom_patch_nr', 1)
loaded_model = PatchFlowModel(path=os.path.join(version_folder, 'custom_patch_nr_intermediate.pth'))

RuntimeError: Error(s) in loading state_dict for ReversibleGraphNet:
	Missing key(s) in state_dict: "module_list.10.subnet1.0.weight", "module_list.10.subnet1.0.bias", "module_list.10.subnet1.2.weight", "module_list.10.subnet1.2.bias", "module_list.10.subnet1.4.weight", "module_list.10.subnet1.4.bias", "module_list.10.subnet2.0.weight", "module_list.10.subnet2.0.bias", "module_list.10.subnet2.2.weight", "module_list.10.subnet2.2.bias", "module_list.10.subnet2.4.weight", "module_list.10.subnet2.4.bias", "module_list.11.perm", "module_list.11.perm_inv", "module_list.12.subnet1.0.weight", "module_list.12.subnet1.0.bias", "module_list.12.subnet1.2.weight", "module_list.12.subnet1.2.bias", "module_list.12.subnet1.4.weight", "module_list.12.subnet1.4.bias", "module_list.12.subnet2.0.weight", "module_list.12.subnet2.0.bias", "module_list.12.subnet2.2.weight", "module_list.12.subnet2.2.bias", "module_list.12.subnet2.4.weight", "module_list.12.subnet2.4.bias", "module_list.13.perm", "module_list.13.perm_inv", "module_list.14.subnet1.0.weight", "module_list.14.subnet1.0.bias", "module_list.14.subnet1.2.weight", "module_list.14.subnet1.2.bias", "module_list.14.subnet1.4.weight", "module_list.14.subnet1.4.bias", "module_list.14.subnet2.0.weight", "module_list.14.subnet2.0.bias", "module_list.14.subnet2.2.weight", "module_list.14.subnet2.2.bias", "module_list.14.subnet2.4.weight", "module_list.14.subnet2.4.bias", "module_list.15.perm", "module_list.15.perm_inv", "module_list.16.subnet1.0.weight", "module_list.16.subnet1.0.bias", "module_list.16.subnet1.2.weight", "module_list.16.subnet1.2.bias", "module_list.16.subnet1.4.weight", "module_list.16.subnet1.4.bias", "module_list.16.subnet2.0.weight", "module_list.16.subnet2.0.bias", "module_list.16.subnet2.2.weight", "module_list.16.subnet2.2.bias", "module_list.16.subnet2.4.weight", "module_list.16.subnet2.4.bias", "module_list.17.perm", "module_list.17.perm_inv", "module_list.18.subnet1.0.weight", "module_list.18.subnet1.0.bias", "module_list.18.subnet1.2.weight", "module_list.18.subnet1.2.bias", 
"module_list.18.subnet1.4.weight", "module_list.18.subnet1.4.bias", "module_list.18.subnet2.0.weight", "module_list.18.subnet2.0.bias", "module_list.18.subnet2.2.weight", "module_list.18.subnet2.2.bias", "module_list.18.subnet2.4.weight", "module_list.18.subnet2.4.bias", "module_list.19.perm", "module_list.19.perm_inv". 
	size mismatch for module_list.0.subnet1.0.weight: copying a param with shape torch.Size([512, 24]) from checkpoint, the shape in current model is torch.Size([1024, 24]).
	size mismatch for module_list.0.subnet1.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.0.subnet1.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for module_list.0.subnet1.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.0.subnet1.4.weight: copying a param with shape torch.Size([50, 512]) from checkpoint, the shape in current model is torch.Size([50, 1024]).
	size mismatch for module_list.0.subnet2.0.weight: copying a param with shape torch.Size([512, 25]) from checkpoint, the shape in current model is torch.Size([1024, 25]).
	size mismatch for module_list.0.subnet2.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.0.subnet2.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for module_list.0.subnet2.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.0.subnet2.4.weight: copying a param with shape torch.Size([48, 512]) from checkpoint, the shape in current model is torch.Size([48, 1024]).
	size mismatch for module_list.2.subnet1.0.weight: copying a param with shape torch.Size([512, 24]) from checkpoint, the shape in current model is torch.Size([1024, 24]).
	size mismatch for module_list.2.subnet1.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.2.subnet1.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for module_list.2.subnet1.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.2.subnet1.4.weight: copying a param with shape torch.Size([50, 512]) from checkpoint, the shape in current model is torch.Size([50, 1024]).
	size mismatch for module_list.2.subnet2.0.weight: copying a param with shape torch.Size([512, 25]) from checkpoint, the shape in current model is torch.Size([1024, 25]).
	size mismatch for module_list.2.subnet2.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.2.subnet2.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for module_list.2.subnet2.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.2.subnet2.4.weight: copying a param with shape torch.Size([48, 512]) from checkpoint, the shape in current model is torch.Size([48, 1024]).
	size mismatch for module_list.4.subnet1.0.weight: copying a param with shape torch.Size([512, 24]) from checkpoint, the shape in current model is torch.Size([1024, 24]).
	size mismatch for module_list.4.subnet1.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.4.subnet1.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for module_list.4.subnet1.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.4.subnet1.4.weight: copying a param with shape torch.Size([50, 512]) from checkpoint, the shape in current model is torch.Size([50, 1024]).
	size mismatch for module_list.4.subnet2.0.weight: copying a param with shape torch.Size([512, 25]) from checkpoint, the shape in current model is torch.Size([1024, 25]).
	size mismatch for module_list.4.subnet2.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.4.subnet2.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for module_list.4.subnet2.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.4.subnet2.4.weight: copying a param with shape torch.Size([48, 512]) from checkpoint, the shape in current model is torch.Size([48, 1024]).
	size mismatch for module_list.6.subnet1.0.weight: copying a param with shape torch.Size([512, 24]) from checkpoint, the shape in current model is torch.Size([1024, 24]).
	size mismatch for module_list.6.subnet1.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.6.subnet1.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for module_list.6.subnet1.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.6.subnet1.4.weight: copying a param with shape torch.Size([50, 512]) from checkpoint, the shape in current model is torch.Size([50, 1024]).
	size mismatch for module_list.6.subnet2.0.weight: copying a param with shape torch.Size([512, 25]) from checkpoint, the shape in current model is torch.Size([1024, 25]).
	size mismatch for module_list.6.subnet2.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.6.subnet2.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for module_list.6.subnet2.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.6.subnet2.4.weight: copying a param with shape torch.Size([48, 512]) from checkpoint, the shape in current model is torch.Size([48, 1024]).
	size mismatch for module_list.8.subnet1.0.weight: copying a param with shape torch.Size([512, 24]) from checkpoint, the shape in current model is torch.Size([1024, 24]).
	size mismatch for module_list.8.subnet1.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.8.subnet1.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for module_list.8.subnet1.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.8.subnet1.4.weight: copying a param with shape torch.Size([50, 512]) from checkpoint, the shape in current model is torch.Size([50, 1024]).
	size mismatch for module_list.8.subnet2.0.weight: copying a param with shape torch.Size([512, 25]) from checkpoint, the shape in current model is torch.Size([1024, 25]).
	size mismatch for module_list.8.subnet2.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.8.subnet2.2.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for module_list.8.subnet2.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for module_list.8.subnet2.4.weight: copying a param with shape torch.Size([48, 512]) from checkpoint, the shape in current model is torch.Size([48, 1024]).