In [16]:
import os
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from models.PatchNr import create_NF
import datetime
from PatchDataset import PatchDataset
from torchvision import transforms
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)

In [17]:
def show_patch(patch, dims):
    """Display a single patch as a grayscale image with the axes hidden.

    Args:
        patch: Array-like or ``torch.Tensor`` holding the patch values. It is
            reshaped to ``dims`` before plotting and rendered on a fixed
            gray scale from 0 to 1.
        dims: Iterable of ints giving the image shape to reshape to.
    """
    # matplotlib converts its input to a NumPy array, which fails for tensors
    # that live on the GPU or that require grad. Bring them to host memory
    # and drop the autograd link first.
    if isinstance(patch, torch.Tensor):
        patch = patch.detach().cpu()
    plt.imshow(patch.reshape(*dims), cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
    plt.show()

In [18]:
# --- Patch geometry ---
p_size = 6                                   # side length of the square patches
p_dims = (p_size, p_size)
channels = 1

# --- Batch / tensor shapes (derived, so they stay in sync with the values above) ---
batch_size = 32
net_dims = channels * p_size**2              # length of one flattened patch; fed to create_NF
data_dims = (batch_size, channels, *p_dims)

# --- Network size (both forwarded to create_NF) ---
layers = 5
hidden_nodes = 512

# --- Optimisation ---
optimizer_steps = 1000                       # number of training-loop iterations
learning_rate = 1e-4                         # learning rate handed to Adam

# Timestamp of this run; used to name the checkpoint directory.
current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

In [21]:
# Per-patch preprocessing, applied in order: divide by 255, move the patch to
# the compute device, then flatten it to one dimension.
transform = transforms.Compose([
    transforms.Lambda(lambda patch: patch / 255.0),
    transforms.Lambda(lambda patch: patch.to(DEVICE)),
    transforms.Lambda(lambda patch: patch.flatten()),
])

# --- Training data ---
training_data_set = PatchDataset('./data/set12/train', p_dims, transform=transform, device=DEVICE)
training_data_loader = DataLoader(training_data_set, batch_size=batch_size, shuffle=True)

# Peek at one batch through a throw-away iterator so that the iterator handed
# to the training loop still covers a complete pass over the data.
first_batch = next(iter(training_data_loader))
training_data_loader_iter = iter(training_data_loader)
# show_patch(first_batch[20], p_dims)

print(f'The training set consists of {len(training_data_set)} patches with dimensions {first_batch.shape}')
print(f'dType {first_batch.dtype}')

# --- Test data ---
# NOTE(review): this reads the same directory as the training set, so the
# "test" loss logged during training is not a held-out metric. Point this at
# a separate split once one is available.
# The dataset is built with the same `device` argument as the training set so
# both loaders are configured identically.
test_data_set = PatchDataset('./data/set12/train', p_dims, transform=transform, device=DEVICE)
test_data_loader = DataLoader(test_data_set, batch_size=batch_size, shuffle=True)
test_data_loader_iter = iter(test_data_loader)

The training set consists of 250000 patches with dimensions torch.Size([32, 36])
dType torch.float32


In [23]:
print(f'Learning on device {DEVICE}')


def _endless_batches(loader):
    """Yield batches from ``loader`` forever, restarting it whenever it runs out.

    A plain ``iter(loader)`` raises ``StopIteration`` after one pass, which
    aborts training as soon as the number of optimizer steps exceeds the
    number of batches. Re-entering the loader starts a new pass instead.

    Raises:
        ValueError: If a pass over ``loader`` yields no batch at all (this
            would otherwise spin forever).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError('data loader produced no batches')


def _batch_loss(flow, batch):
    """Return mean(0.5 * sum(z**2, dim=1) - z_jac) for ``z, z_jac = flow(batch, rev=True)``.

    NOTE(review): z_jac is presumably the log-determinant of the Jacobian,
    which would make this the negative log-likelihood under a standard
    normal latent — confirm against create_NF.
    """
    z, z_jac = flow(batch, rev=True)
    return torch.mean(0.5 * torch.sum(z**2, dim=1) - z_jac)


# Flow
patch_nr_flow = create_NF(layers, hidden_nodes, net_dims)
patch_nr_flow.to(DEVICE)
# Optimizer
optimizer = torch.optim.Adam(patch_nr_flow.parameters(), lr=learning_rate)
# One entry per logging step: [timestamp, training loss, test loss].
loss_list = []

# Create the state dir. makedirs also creates a missing './model-states'
# parent and does not fail when the directory already exists.
state_dir = f'./model-states/{current_time}'
os.makedirs(state_dir, exist_ok=True)

train_batches = _endless_batches(training_data_loader)
test_batches = _endless_batches(test_data_loader)

loss = None  # stays None only when optimizer_steps == 0
for it in tqdm(range(optimizer_steps)):
    loss = _batch_loss(patch_nr_flow, next(train_batches))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Detached host copy for logging and checkpoints: it carries no autograd
    # history and can be loaded on a machine without the training device.
    loss_snapshot = loss.detach().cpu()

    if it % 100 == 0:
        with torch.no_grad():
            test_loss = _batch_loss(patch_nr_flow, next(test_batches)).item()
        loss_list.append([datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'), loss_snapshot, test_loss])
    if it % 500 == 0:
        torch.save(
            {'net_state': patch_nr_flow.state_dict(), 'net_loss': loss_snapshot, 'lr': learning_rate, 'p_dims': p_dims},
            f'{state_dir}/it-{it}-{datetime.datetime.now().strftime("%H-%M-%S")}',
        )

# Final checkpoint; carries the same keys as the intermediate ones.
final_loss = None if loss is None else loss.detach().cpu()
torch.save(
    {'net_state': patch_nr_flow.state_dict(), 'net_loss': final_loss, 'lr': learning_rate, 'p_dims': p_dims},
    f'{state_dir}/{datetime.datetime.now().strftime("final-%H-%M-%S")}',
)

Learning on device cpu


  0%|          | 273/750000 [00:06<5:18:22, 39.25it/s]


KeyboardInterrupt: 