In [61]:
import math
import pickle

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.masked import masked_tensor, as_masked_tensor
from tqdm import trange

from src.focus_lens_quadric_primitive import FocusLensQuadricPrimitive
import src.utils_energy_distribution as en_util
from src.nn_modules import *
from src.utils_scaling import *
import src.utils_optimization as opt_util
import src.utils_display as disp_util

# torch.autograd.set_detect_anomaly(True)

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using " + device + " device")

Using cuda device


In [62]:
# The network checkpoint is written to and read from the same file.
save_NN_file_name = load_NN_file_name = 'data/16x16_UNextNN.ebv'

# Inputs: the lens surface primitive and the pickled design project.
load_surface_file_name = 'data/16x16_surf.ebv'
load_project_file_name = 'data/30k16x16.ebv'

Load data

In [63]:
# Load the lens surface primitive used for conversions and ray tracing.
quadric_surface = FocusLensQuadricPrimitive.load(load_surface_file_name)
# Alternative DOE surface (note: FocusDOEQuadricPrimitive is not imported in this notebook):
# quadric_surface = FocusDOEQuadricPrimitive.load(load_surface_file_name)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load project files produced by a trusted source.
with open(load_project_file_name, 'rb') as file:
    project = pickle.load(file)

# The file handle is not needed once the project dict is unpickled, so unpack outside the `with`.
quadrics = project['quadrics']
req_distrs = project['req_distrs']
rays = project['rays']
optim_tracing_results = project['design_logs']['tracing_results']

# Match the torch dtype to the stored quadric precision: float32 stays float32, anything else -> float64.
torch_dtype = torch.float32 if quadrics.dtype == np.single else torch.double

Prepare data

In [64]:
# First-axis length of the quadrics array (presumably the number of samples — TODO confirm).
num_imgs = quadrics.shape[0]
# Boolean mask of finite quadric values. Only the disabled rescaling code below uses it.
# It is computed before to_CNN, so order matters if to_CNN mutates `quadrics`.
good_quads = np.isfinite(quadrics)

# # # Rescale quadric values to a new interval (disabled)
# new_interval = [10, 255]
# quadrics[good_quads], linear_coeffs, old_interval = linear_scaling(quadrics[good_quads], new_interval)

# Identity output scaling (presumably [scale, offset] — TODO confirm against utils_scaling).
# Later passed to the model constructor as output_linear_coeffs.
linear_coeffs = [1, 0]
# quadrics[~good_quads] = 0
# Convert the quadrics to the CNN training representation. The return value is discarded,
# so this presumably modifies `quadrics` (and/or `req_distrs`) in place — TODO confirm.
quadric_surface.to_CNN(quadrics, req_distrs)

In [65]:
# # DATA CONSISTENCY CHECK (disabled)
# Round-trips the quadrics through inverse scaling and the numpy<->torch batch helpers,
# re-traces them, and shows the result next to the optimization-time tracing results.
# tmp_quadrics = linear_inverse_scaling(quadrics, linear_coeffs)
# tmp_quadrics = opt_util.numpy2torch_batch(tmp_quadrics, device)
# tmp_quadrics = opt_util.torch2numpy_batch(tmp_quadrics)
# tmp_quadrics[req_distrs <= 0] = -np.inf
# tmp_traced, _ = quadric_surface.trace_rays_batch(tmp_quadrics, rays)
#
# i = disp_util.display_random_pairs([req_distrs, req_distrs], [optim_tracing_results, tmp_traced],
#                                    titles=['Req Distr(Ideal)', 'Req Distr(optim)', 'Req Distr(Ideal)', 'Traced Quads'])

Pack data into tensors and split into train/test sets

In [66]:
# Convert both arrays to batched torch tensors on the selected device (quadrics first).
quadrics, req_distrs = (opt_util.numpy2torch_batch(arr, device) for arr in (quadrics, req_distrs))

# Required distributions are the network inputs and quadrics the regression targets; hold out 15%.
X_train, X_test, y_train, y_test = train_test_split(
    req_distrs,
    quadrics,
    test_size=0.15,
    random_state=33,
)

In [67]:
# NumPy copies of the held-out split, taken before the inputs are binarized, for evaluation.
X_test_np, y_test_np = (opt_util.torch2numpy_batch(tensor.cpu()) for tensor in (X_test, y_test))

In [68]:
def binarize_support(distrs: torch.Tensor) -> torch.Tensor:
    """Return a 0/1 tensor with the same dtype/device as `distrs`: 1 where the value is > 0, else 0.

    Replaces the former `X / X` trick, which divided 0 by 0 (creating NaNs that then had to be
    overwritten) and let NaN or +inf entries leak through to the network as NaN.
    """
    return (distrs > 0).to(distrs.dtype)


# The network only receives the positive-valued support of each required distribution.
X_train = binarize_support(X_train)
X_test = binarize_support(X_test)

Prepare model

In [69]:
# ndepth = 3
# nwidth = 2
# ndeg = 4
#
# model = OlikerUNet(quad_linear_coeffs=linear_coeffs, net_depth=ndepth, net_width=nwidth, base_degree=ndeg)

In [70]:
# model = convnext_base(output_linear_coeffs=linear_coeffs,output_dim=16)

In [71]:
# Build the unext_base network with output_dim=16 and the output scaling coefficients from data prep.
model = unext_base(output_dim=16, output_linear_coeffs=linear_coeffs)

In [72]:
# Regression loss between predicted and target quadrics. Disabled alternatives:
# criterion = nn.CrossEntropyLoss()
# criterion = nn.BCELoss()
# criterion = nn.KLDivLoss(reduction="batchmean")
criterion = nn.MSELoss(reduction='mean')

In [73]:
# Drop the full datasets; only the train/test splits are used from here on.
# NOTE: not idempotent — running this cell a second time raises NameError.
del quadrics
del req_distrs
del optim_tracing_results

Load model (optional: resume from a saved checkpoint instead of the freshly built one)

In [74]:
# model = OlikerUNet.from_file(load_NN_file_name)
# model = ConvNeXt.from_file(load_NN_file_name)
# model.eval()

In [75]:
# Count only trainable parameters (those with requires_grad set).
num_model_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'Module parameters count: {num_model_params}')

Module parameters count: 13664788


Train model

In [76]:
# Training logs filled by the training loop; they keep accumulating if that cell is re-run.
loss_per_iter, loss_per_epoch, lr_per_iter = [], [], []

In [77]:
# Mini-batch schedule.
train_size = X_train.shape[0]
num_epochs = 100
batch_size = 128
num_batches_in_epoch = math.ceil(train_size / batch_size)
num_iters_to_step = 10  # scheduler.step() is called once every this many iterations

# Exponential LR decay chosen so the rate goes from start_lr to end_lr over the whole run.
start_lr = 1e-2
end_lr = 1e-6
# Integer division is exact. max(1, ...) prevents a ZeroDivisionError in lr_gamma when the
# run has fewer than num_iters_to_step iterations in total.
num_steps = max(1, num_batches_in_epoch * num_epochs // num_iters_to_step)
lr_gamma = pow(end_lr / start_lr, 1 / num_steps)
# lr_gamma = 0.99

# Put the model on its final device/dtype BEFORE building the optimizer, as the torch.optim
# docs recommend. The training cell's own model.to(...) call then becomes a no-op.
model.to(dtype=torch_dtype, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=start_lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_gamma)

In [None]:
# Move the model and training data onto the training device (no-op if already there).
model.to(dtype=torch_dtype, device=device)
X_train = X_train.to(device)
y_train = y_train.to(device)

model.train()
inds = np.arange(train_size)
progress = trange(num_epochs * num_batches_in_epoch)

mean_loss = 0
# The loop variable is `it`, not `iter`. Notebook loop variables are globals, so `iter`
# would shadow the builtin iter() in every cell run afterwards.
for it in progress:
    # Reshuffle the sample order at the start of each epoch.
    if it % num_batches_in_epoch == 0:
        np.random.shuffle(inds)
        start_ind = 0

    # The last mini-batch of an epoch may be shorter than batch_size.
    end_ind = min(start_ind + batch_size, X_train.shape[0])
    cur_inds = inds[start_ind:end_ind]

    # Forward pass
    inputs = X_train[cur_inds]
    outputs = model(inputs)

    # mask = inputs <= 0
    # # out_masked = masked_tensor(outputs, mask)
    # outputs = outputs.masked_fill(mask, 0)

    # loss = torch.sum(((1-y_train[cur_inds]/outputs)*inputs)**2,dim=(2,3)) / torch.sum(inputs,dim=(2,3))
    # loss = loss.max()
    loss = criterion(outputs, y_train[cur_inds])

    # Backpropagation and optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Read the scalar once: on CUDA every .item() forces a device-to-host synchronization.
    loss_value = loss.item()
    loss_per_iter.append(loss_value)
    lr_per_iter.append(scheduler.get_last_lr()[0])

    # Decay the learning rate once every num_iters_to_step iterations.
    if (it + 1) % num_iters_to_step == 0:
        scheduler.step()
    # if (it + 1) % 2 == 0:
    #     progress.set_postfix_str(f"Loss: {loss.item() :.6f}, lr: {scheduler.get_last_lr()[0] :e}")

    mean_loss += loss_value

    # End of epoch: record the unweighted mean of the mini-batch losses and refresh the progress bar.
    if (it + 1) % num_batches_in_epoch == 0:
        loss_per_epoch.append(mean_loss / num_batches_in_epoch)
        progress.set_postfix_str(f"Mean_loss: {loss_per_epoch[-1] :.6f}, lr: {lr_per_iter[-1] :e}, num iters per epoch: {num_batches_in_epoch:d}")
        mean_loss = 0

    start_ind = end_ind

 39%|███▉      | 7898/20100 [4:16:45<7:17:51,  2.15s/it, Mean_loss: 0.002099, lr: 2.765611e-04, num iters per epoch: 201] 

In [None]:
# Show the collected training curves: per-iteration loss, per-epoch mean loss and learning rate.
disp_util.display_optimization_stats(
    lr=lr_per_iter,
    loss=loss_per_iter,
    loss_per_epoch=loss_per_epoch,
)

Save model

In [None]:
# Persist the trained network to save_NN_file_name (the same path the "Load model" cell reads from).
model.save(save_NN_file_name)

Model evaluation

In [None]:
# Evaluate on the held-out split, on CPU.
model.eval()
model.cpu()
X_test = X_test.cpu()
# Separate name so re-running this cell cannot overwrite the training `batch_size`.
eval_batch_size = 128

# Inference only: no_grad skips building autograd graphs and guarantees the predictions
# do not require grad when they are converted to NumPy.
with torch.no_grad():
    test_pred = opt_util.torch2numpy_batch(model.batch_eval(X_test, eval_batch_size))
    test_pred_shifted = opt_util.torch2numpy_batch(model.batch_scaled_eval(X_test, eval_batch_size))

# Convert the scaled predictions to the Oliker representation before ray tracing. The return
# value is discarded, so this presumably converts in place — TODO confirm.
quadric_surface.to_Oliker(test_pred_shifted, X_test_np)
pred_traced, _ = quadric_surface.trace_rays_batch(test_pred_shifted, rays)

# Masked RRMSE: predicted vs. target quadrics, and traced vs. required distributions.
quad_rrmse = en_util.masked_rrmse_loss_batch(y_test_np, test_pred)
distr_rrmse = en_util.masked_rrmse_loss_batch(X_test_np, pred_traced)

print(f"Quad RRMSE= {quad_rrmse.mean() :.4f}%. Illum RRMSE= {distr_rrmse.mean() :.4f}%.")

In [None]:
# Display (presumably random) held-out samples side by side:
# required vs. traced distributions, target vs. predicted quadrics.
i = disp_util.display_random_pairs(
    [X_test_np, y_test_np],
    [pred_traced, test_pred],
    titles=['Req Distr', 'Prediction traced', 'Req Quads', 'Predicted Quads'],
)