In [24]:
import os
import re

import torch
from tqdm import tqdm
import torchvision.transforms as T

from file_utils import create_versioned_dir
from flow_models import FlowModel
from flow_models.PatchFlowModel import PatchFlowModel
from img_utils import ImageLoader, PatchExtractor
from transforms import image_dequantization, image_normalization

In [25]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [26]:
def patch_flow_trainer(name: str, path: str, model: FlowModel, loss_fn, train_images: ImageLoader, validation_images: ImageLoader,
                       patch_size=6, batch_size=64, steps=750000, val_each_steps=1000, loss_log_each_step=100, device='cpu',
                       quiet=False, lr=0.005):

    dir = create_versioned_dir(path, name)

    model_hparams = model.get_hparams()
    torch.save(model_hparams, os.path.join(dir, 'hparams.yaml'))

    if not quiet:
        print(f'Started training for model {name}. \n Will train {steps} steps in device={device}')

    model.to(DEVICE)

    patch_extractor = PatchExtractor(p_size=patch_size, device=device)
    progress_bar = tqdm(range(steps)) if not quiet else range(steps)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    if not quiet:
        print(f'Optimizer is initialized with this parameters: {optimizer.state_dict()}')

    tmp_validation_loss = 0
    tmp_loss =0

    for step in progress_bar:
        train_image = train_images.get_random_image()
        train_patch_batch = patch_extractor.extract(train_image, batch_size)

        loss = 0
        z, z_log_det = model(train_patch_batch, rev=True)
        loss += loss_fn(z, z_log_det)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % loss_log_each_step == 0:
            tmp_loss = loss
            if not quiet:
                progress_bar.set_description_str(f'T: {tmp_loss}, V:{tmp_validation_loss}')

        if step % val_each_steps == 0:
            with torch.no_grad():
                val_image = validation_images.get_random_image()
                val_patch_batch = patch_extractor.extract(val_image, batch_size)
                z_val, z_val_log_det = model(val_patch_batch, rev=True)
                val_loss = loss_fn(z_val, z_val_log_det)
                tmp_validation_loss = val_loss
                if not quiet:
                    progress_bar.set_description_str(f'T: {tmp_loss}, V:{tmp_validation_loss}')

In [27]:
patch_size = 7

In [28]:
def log_likelihood_loss(z, z_log_det):
    return torch.mean(0.5 * torch.sum(z**2, dim=1) - z_log_det)

In [29]:
model = PatchFlowModel(hparams={"num_layers": 5, "sub_net_size": 512, "dimension": patch_size ** 2}) #create_NF(num_layers=10, sub_net_size=512, dimension=patch_size**2)

In [30]:
deq_normalization = T.Compose([
    image_dequantization(device=DEVICE),
    image_normalization()])

In [31]:
train_images = ImageLoader('data/material_pt_nr/train.png', transform=deq_normalization, device=DEVICE)
validation_images = ImageLoader('data/material_pt_nr/validate.png', transform=deq_normalization, device=DEVICE)

In [32]:
patch_flow_trainer('custom_patch_nr', 'results/patch_nr', model, log_likelihood_loss, train_images, validation_images, steps=750000, patch_size=patch_size, device=DEVICE)

results/patch_nr/custom_patch_nr
[]


AttributeError: 'PatchFlowModel' object has no attribute 'get_hparams'