In [11]:
import torch
from tqdm import tqdm
import torchvision.transforms as T
from img_utils import ImageLoader, PatchExtractor
from patch_nr_paper_model import create_NF
from transforms import image_dequantization, image_normalization

In [12]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [13]:
def patch_flow_trainer(name: str, model: torch.nn.Module, loss_fn, train_images: 'ImageLoader',
                       validation_images: 'ImageLoader', patch_size=6, batch_size=64, steps=1000000,
                       val_each_steps=1000, loss_log_each_step=100, device='cpu', quiet=False, lr=0.005):
    """Train a flow model on patches sampled from random images, using Adam.

    Every step draws one image from ``train_images``, extracts ``batch_size``
    patches from it, runs ``model(batch, rev=True)`` and minimises
    ``loss_fn(z, z_log_det)``. The model is updated in place; nothing is returned.

    Args:
        name: Label used only in the start-up message.
        model: Module called as ``model(batch, rev=True)`` and expected to return
            a ``(z, z_log_det)`` pair. It is moved to ``device`` in place.
        loss_fn: Callable ``loss_fn(z, z_log_det)`` returning a scalar tensor
            (``backward()`` and ``item()`` are called on the result).
        train_images: Source of training images (``get_random_image()`` is called).
        validation_images: Source of validation images (``get_random_image()`` is called).
        patch_size: Forwarded to ``PatchExtractor`` as ``p_size``.
        batch_size: Number of patches requested from the extractor per step.
        steps: Number of optimisation steps.
        val_each_steps: Run one validation batch every this many steps.
            ``0`` or ``None`` disables validation.
        loss_log_each_step: Refresh the logged training loss every this many
            steps. ``0`` or ``None`` disables it.
        device: Device the model is moved to; also handed to ``PatchExtractor``.
        quiet: When true, print nothing and iterate over a plain ``range``
            instead of a tqdm progress bar.
        lr: Adam learning rate.

    NOTE(review): patch batches are passed to the model exactly as the extractor
    returns them; this function does not move them. They are presumably already
    on ``device`` - confirm against ``ImageLoader`` / ``PatchExtractor``.
    """
    if not quiet:
        print(f'Started training for model {name}. \n Will train {steps} steps in device={device}')
    model.to(device)
    patch_extractor = PatchExtractor(p_size=patch_size, device=device)
    progress_bar = tqdm(range(steps)) if not quiet else range(steps)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    if not quiet:
        print(f'Optimizer is initialized with this parameters: {optimizer.state_dict()}')

    # Most recently logged losses, stored as plain Python floats so the progress
    # bar shows numbers (not tensor reprs) and no autograd graph is kept alive.
    tmp_validation_loss = 0
    tmp_loss = 0

    for step in progress_bar:
        train_image = train_images.get_random_image()
        train_patch_batch = patch_extractor.extract(train_image, batch_size)

        z, z_log_det = model(train_patch_batch, rev=True)
        loss = loss_fn(z, z_log_det)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # A falsy interval (0 / None) switches the corresponding logging off
        # instead of raising ZeroDivisionError on the modulo.
        if loss_log_each_step and step % loss_log_each_step == 0:
            tmp_loss = loss.item()
            if not quiet:
                progress_bar.set_description_str(f'T: {tmp_loss}, V:{tmp_validation_loss}')

        if val_each_steps and step % val_each_steps == 0:
            # Validation is a single batch from one random validation image,
            # evaluated without building an autograd graph.
            with torch.no_grad():
                val_image = validation_images.get_random_image()
                val_patch_batch = patch_extractor.extract(val_image, batch_size)
                z_val, z_val_log_det = model(val_patch_batch, rev=True)
                tmp_validation_loss = loss_fn(z_val, z_val_log_det).item()
            if not quiet:
                progress_bar.set_description_str(f'T: {tmp_loss}, V:{tmp_validation_loss}')

In [14]:
# Patch size shared by the model and the trainer: the flow is built with
# dimension patch_size**2 and the same value is passed to patch_flow_trainer.
# Presumably the side length of square patches - confirm against PatchExtractor.
patch_size = 9

In [15]:
def log_likelihood_loss(z, z_log_det):
    """Return the batch mean of ``0.5 * sum(z**2, dim=1) - z_log_det``.

    This has the form of a flow's negative log-likelihood with a standard-normal
    latent, with the additive constant left out.

    Args:
        z: Tensor whose squared entries are summed along ``dim=1``.
        z_log_det: Per-sample log-determinant term subtracted from that sum.

    Returns:
        A scalar tensor: the mean over all per-sample values.
    """
    half_squared_norm = 0.5 * z.pow(2).sum(dim=1)
    per_sample = half_squared_norm - z_log_det
    return per_sample.mean()

In [16]:
# Build the flow with create_NF: 10 layers, sub-network size 512, and a
# dimension of patch_size**2 (81 when patch_size is 9).
model = create_NF(num_layers=10, sub_net_size=512, dimension=patch_size**2)

In [17]:
# Chain image dequantization and image normalization into a single transform.
# NOTE(review): the recorded output of this cell is a TypeError - the loaded
# image_dequantization() does not accept a `device` keyword. Confirm its
# signature in the transforms module (or re-import it) before relying on this cell.
deq_normalization = T.Compose([image_dequantization(device=DEVICE), image_normalization()])

TypeError: image_dequantization() got an unexpected keyword argument 'device'

In [9]:
# Image sources for training and validation, both using the deq_normalization
# transform and DEVICE.
# NOTE(review): training points at a directory while validation points at a
# single PNG file - confirm ImageLoader accepts both kinds of path.
# NOTE(review): this cell's execution count (9) is lower than those of the cells
# before it (11-17), so it was last run against an earlier kernel state; re-run
# the notebook top to bottom before trusting the recorded outputs.
train_images = ImageLoader('data/set12/train', transform=deq_normalization, device=DEVICE)
validation_images = ImageLoader('data/set12/validate/05.png', transform=deq_normalization, device=DEVICE)

In [10]:
# Train the flow on the loaded images; all other settings use the trainer's
# defaults (1,000,000 steps, batch size 64, lr 0.005).
# NOTE(review): the recorded run stopped with a RuntimeError about tensors on
# cuda:0 and cpu - make sure the images/transform, the extracted patches and the
# model all end up on DEVICE before re-running.
patch_flow_trainer('set12_flow', model, log_likelihood_loss, train_images, validation_images, patch_size=patch_size, device=DEVICE)

Started training for model set12_flow. 
 Will train 1000000 steps in device=cuda


  0%|          | 0/1000000 [00:00<?, ?it/s]

Optimizer is initialized with this parameters: {'state': {}, 'param_groups': [{'lr': 0.005, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139]}]}





RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!