# Imports

In [6]:
from dataset import MyData
from model import Net
from config import device
from helper import downscale_map

import torch
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
import glob
import os 
import matplotlib.image as mpimg
import pandas as pd

## Big Test

### Map Cascade

#### Downscale visualization

In [4]:
@interact(path_to_file=glob.glob('../data/*.abz'))
def s_(path_to_file):
    """Display the downscaled map cascade for one ``.abz`` file.

    Bound to an ipywidgets dropdown listing every ``../data/*.abz`` file.
    The selected map is loaded with ``MyData`` (cylinder mode, radius 10),
    its pixels are rescaled and inverted, and ``downscale_map`` is applied
    with square sizes 128, 64, 32, 8 and 4. Each resulting map is shown in
    its own figure, titled with its shape.
    """
    source = MyData(
        path_to_file=path_to_file,
        mode='abz',
        mode_3d='cylinder',
        radius=10,
        reduce_fctor=1,
        need_help=False,
    )
    # Map pixel value 0 -> 1.0 and 255 -> 0.0 (presumably 8-bit grayscale).
    pixels = np.array(source.img_array, dtype='float64')
    inverted = 1 - pixels / 255.

    square_sizes = [2 ** power for power in (7, 6, 5, 3, 2)]
    cascade = [downscale_map(inverted, sq_size=size) for size in square_sizes]

    for level in cascade:
        plt.imshow(level, cmap="gray")
        plt.title(f'{level.shape}')
        plt.show()

interactive(children=(Dropdown(description='path_to_file', options=('../data\\sn1996k1904eng.abz', '../data\\s…

#### Train a model cascade for every map and radius

In [48]:
def map_name_from_path(path_to_file):
    """Return the map name for *path_to_file*: its file name without directory or extension.

    Uses ``os.path`` so it works with whatever separators ``glob`` returns on
    the current OS (splitting on ``'\\'`` only worked on Windows).

    >>> map_name_from_path('./data/sn1996k1904eng.abz')
    'sn1996k1904eng'
    """
    return os.path.splitext(os.path.basename(path_to_file))[0]


def big_test(path_to_test='./BIG_TEST_MAP_CASCADE/',
             data_glob='./data/*.abz',
             radii=(1., 2., 5., 10., 15., 20.),
             cascade_powers=(7, 6, 5, 3, 2)):
    """Train a coarse-to-fine cascade of models for every map and radius.

    For each file matched by *data_glob* and each radius in *radii*, the map
    is loaded with ``MyData``, its pixels are rescaled to ``1 - value / 255``
    and downscaled with ``downscale_map`` at square sizes ``2 ** p`` for each
    ``p`` in *cascade_powers*. One ``Net`` is trained per cascade level; every
    level after the first is initialised from the checkpoint saved by the
    previous level.

    Checkpoints are written to
    ``<path_to_test>/<map name>/radius_<radius>/<k>_state_dict.pt`` for
    ``k = 1 .. len(cascade_powers)``.

    Args:
        path_to_test: root directory for the checkpoints.
        data_glob: glob pattern selecting the ``.abz`` map files.
        radii: radii to pass to ``MyData``; also used in folder names.
        cascade_powers: exponents ``p`` of the ``2 ** p`` downscale sizes,
            ordered from coarsest to finest.

    Returns:
        None. All results are written to disk.
    """
    for path_to_file in glob.glob(data_glob):
        map_dir = os.path.join(path_to_test, map_name_from_path(path_to_file))

        for radius in radii:
            cur_dir = os.path.join(map_dir, f'radius_{radius}')
            # makedirs also creates map_dir, so no separate call is needed.
            os.makedirs(cur_dir, exist_ok=True)

            source = MyData(path_to_file=path_to_file, mode='abz', mode_3d='cylinder', radius=radius, reduce_fctor=1, need_help=False)
            normal_img_array = 1 - np.array(source.img_array, dtype='float64') / 255.
            map_cascade = [downscale_map(normal_img_array, sq_size=2 ** x) for x in cascade_powers]

            for i, cas_map in enumerate(map_cascade):
                level_data = MyData(path_to_file=255 * cas_map, mode='2img', mode_3d='cylinder', radius=radius, reduce_fctor=1, need_help=False)
                model = Net(dataset_list=[level_data], lr=1e-3)
                model.to(device)

                if i:
                    # Warm-start from the checkpoint the previous (coarser) level saved as f'{i}_state_dict.pt'.
                    state = torch.load(os.path.join(cur_dir, f'{i}_state_dict.pt'), map_location=device)
                    model.load_state_dict(state)
                model.start_training(num_epochs=5e+3, my_weight=0.1, need_plot=False, need_save=False)
                model.save_state_dict(os.path.join(cur_dir, f'{i + 1}_state_dict.pt'))

# big_test()

#### Show results

In [None]:
# For every trained map, plot the final cascade level's output at each radius
# (rows 0-1 of a 4x3 grid), the target .gif (row 2) and the original map
# (row 3), then save the figure to ./BIG_TEST_MAP_CASCADE_results/<map>.png.
results_root = './BIG_TEST_MAP_CASCADE/'
os.makedirs('./BIG_TEST_MAP_CASCADE_results', exist_ok=True)

for map_name in sorted(os.listdir(results_root)):
    map_dir = os.path.join(results_root, map_name)
    if not os.path.isdir(map_dir):
        # Skip stray files (e.g. desktop.ini, .DS_Store) beside the map folders.
        continue

    # os.listdir order is arbitrary, and lexical order would put 'radius_10.0'
    # before 'radius_2.0'; order the panels by numeric radius instead.
    radius_paths = sorted(os.listdir(map_dir), key=lambda name: float(name.split('_')[-1]))
    if not radius_paths:
        continue

    grid = plt.GridSpec(4, 3)
    plt.figure(figsize=(15, 15))

    for i, radius_path in enumerate(radius_paths):
        radius = float(radius_path.split('_')[-1])

        dataset = MyData(path_to_file=f'./data/{map_name}.abz', mode='abz', mode_3d='cylinder', radius=radius, reduce_fctor=1, need_help=False)
        model = Net(dataset_list=[dataset], lr=1e-3).to(device)
        # 5_state_dict.pt is the last (finest) of the five cascade checkpoints big_test saves by default.
        state = torch.load(os.path.join(map_dir, radius_path, '5_state_dict.pt'), map_location=device)
        model.load_state_dict(state)
        with torch.no_grad():
            output_list = [model(inputs.to(device)).cpu().detach() for inputs in model.data_list]
            loss = model.compute_loss(output_list, my_weight=0.1)
        plt.subplot(grid[i // 3, i % 3])
        plt.imshow(output_list[0].view(dataset.img_array.shape), cmap='PuOr', vmin=-1, vmax=1)
        plt.title(f'Radius = {radius}, Loss = {loss:.3f}')

    plt.subplot(grid[2, :])
    path_to_target_img = f'./data/{map_name}.gif'
    plt.imshow(mpimg.imread(path_to_target_img), cmap='gray')
    plt.axis('off')

    plt.subplot(grid[3, :])
    plt.title(f'Original map {map_name}')
    # Uses the last radius's dataset; assumes img_array does not depend on radius - TODO confirm.
    plt.imshow(dataset.img_array, cmap='gray')

    plt.savefig(f'./BIG_TEST_MAP_CASCADE_results/{map_name}.png', facecolor='white')
    # Display this map's figure now rather than keeping every figure open until the cell ends.
    plt.show()


#### Tabulate the final loss for every map and radius

In [69]:
# Collect the final cascade level's loss for every (map, radius) pair:
# res_dict[map_name][radius] -> loss as a Python float.
res_dict = {}
results_root = './BIG_TEST_MAP_CASCADE/'
# Sorted so the table rows come out in a deterministic order.
for map_name in sorted(os.listdir(results_root)):
    map_dir = os.path.join(results_root, map_name)
    if not os.path.isdir(map_dir):
        # Skip stray files (e.g. desktop.ini, .DS_Store) beside the map folders.
        continue
    res_dict[map_name] = {}
    for radius_path in os.listdir(map_dir):
        radius = float(radius_path.split('_')[-1])
        dataset = MyData(path_to_file=f'./data/{map_name}.abz', mode='abz', mode_3d='cylinder', radius=radius, reduce_fctor=1, need_help=False)
        model = Net(dataset_list=[dataset], lr=1e-3).to(device)
        state = torch.load(os.path.join(map_dir, radius_path, '5_state_dict.pt'), map_location=device)
        model.load_state_dict(state)
        with torch.no_grad():
            output_list = [model(inputs.to(device)).cpu().detach() for inputs in model.data_list]
            loss = model.compute_loss(output_list, my_weight=0.1)
        # .item() gives a plain float and, unlike .numpy(), also works for CUDA tensors.
        res_dict[map_name][radius] = loss.item()

In [70]:
# One row per map and one column per radius, sorted numerically. The columns
# come from res_dict so the table tracks whichever radii were actually
# evaluated instead of a hard-coded copy of the training radii.
radius_columns = sorted({radius for losses in res_dict.values() for radius in losses})
df = pd.DataFrame.from_dict(res_dict, orient='index', columns=radius_columns)
df

Unnamed: 0,1.0,2.0,5.0,10.0,15.0,20.0
sn1996k1904eng,0.17910649,0.13343988,0.16337517,0.12618695,0.1429933,0.21951117
sn1996k1905eng,0.3080644,0.30417284,0.17025696,0.24258286,0.1130567,0.08923017
sn1996k1906eng,0.18068337,0.11606892,0.19258796,0.09723622,0.17927095,0.0938478
sn1996k1907eng,0.23718216,0.37504908,0.2142726,0.3711627,0.2098669,0.19954531
sn1996k1908eng,0.13678554,0.15115261,0.07683112,0.08129287,0.6223429,0.078284696
sn1996k1909eng,0.114870355,0.25997916,0.25830814,0.13832031,0.2772987,0.09218259
sn1996k1910eng,0.7841412,0.7093341,0.16984908,0.16212864,0.16853276,0.16014926
sn1996k1911eng,0.21512854,0.17002791,0.17895973,0.13767628,0.21002138,0.724503
sn1996k1912eng,0.16261458,0.19972321,0.25447634,0.14700937,0.15166387,0.14492929
sn1996k1913eng,0.20427893,0.1648998,0.70845187,0.14818043,0.14949252,0.17160212


In [None]:
def plot_df_by_rows(df, row_height=10, width=10):
    """Plot every row of *df* as its own line chart, stacked vertically.

    Each subplot is titled with the row's index label. The column labels are
    the x values (plotted as 'Radius') and the row's values are the y values
    (plotted as 'Loss').

    Args:
        df: DataFrame whose column labels are plottable x values.
        row_height: figure height in inches per row. The default reproduces
            the original 10 x 100 figure for a 10-row table.
        width: figure width in inches.

    Returns:
        None. The figure is shown with ``plt.show()``; nothing is drawn for
        an empty DataFrame.
    """
    n_rows = df.shape[0]
    if n_rows == 0:
        # GridSpec rejects zero rows, and there is nothing to plot anyway.
        return
    grid = plt.GridSpec(n_rows, 1)
    plt.figure(figsize=(width, row_height * n_rows))
    for i, (row_label, values) in enumerate(df.iterrows()):
        plt.subplot(grid[i, 0])
        plt.title(row_label)
        plt.plot(values.index, values.values)
        plt.xlabel('Radius')
        plt.ylabel('Loss')
    # Keep each subplot's x-label from overlapping the next subplot's title.
    plt.tight_layout()
    plt.show()


plot_df_by_rows(df)