In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

from scripts.imagedata import ImageData
from scripts.neutralliner import NeutralLiner
from scripts.config import device, PATH_TO_MCINTOSH

In [6]:
def test_me(help_step_size,
            n_runs=100,
            num_epochs=50000,
            lr=5e-3,
            arch=(3, 6, 12, 24, 12, 6, 3, 1),
            weight_decay=1e-4):
    """Fit every map in ``PATH_TO_MCINTOSH`` ``n_runs`` times and save the mean prediction.

    For each directory entry a single ``ImageData`` is loaded (``data_mode='fits'``)
    and ``n_runs`` independent ``NeutralLiner`` models are constructed, moved to
    ``device`` and trained for ``num_epochs``. After every run the model's
    ``save_state_dict`` is called with ``map_<i>_<n>.pt``. The per-run predictions
    are stacked and averaged along the run axis, and a two-panel figure (target on
    the left, mean prediction on the right) is written to ``map_<i>_mean.png``.

    All artifacts go to ``./Tests/Fits/Mean/big/step<help_step_size>``, which is
    created if missing.

    Parameters
    ----------
    help_step_size
        Forwarded unchanged to ``NeutralLiner``; also used verbatim in the output
        directory name (so ``None`` yields ``stepNone``).
    n_runs : int
        Number of independently trained models per map. Must be at least 1,
        because an empty prediction list cannot be stacked.
    num_epochs : int
        Passed to ``model.start_training``.
    lr, arch, weight_decay
        Forwarded unchanged to ``NeutralLiner``.

    Returns
    -------
    None
        Results are written to disk only.
    """
    # os.listdir returns entries in arbitrary order. The enumeration index is
    # baked into every artifact name, so sort to keep "map_<i>" pointing at the
    # same input file on every run and for every help_step_size.
    filenames = sorted(os.listdir(PATH_TO_MCINTOSH))
    path_to_save = f'./Tests/Fits/Mean/big/step{help_step_size}'
    os.makedirs(path_to_save, exist_ok=True)

    for i, filename in enumerate(filenames):
        # The same ImageData instance is shared by all runs for this map.
        imgdata = ImageData(os.path.join(PATH_TO_MCINTOSH, filename), data_mode='fits')
        predictions = []
        for n in tqdm(range(n_runs), desc=f'Predicting map {i+1}/{len(filenames)}'):
            model = NeutralLiner(image_list=[imgdata],
                                 lr=lr,
                                 help_step_size=help_step_size,
                                 mode='3d',
                                 arch=arch,
                                 weight_decay=weight_decay)
            model.to(device)
            model.start_training(num_epochs=num_epochs, need_plot=False)

            # Take the first element returned by test_model, reshape it to the
            # image's shape and move it off the device / out of the autograd
            # graph so that only plain CPU tensors accumulate in the list.
            target_shape = model.image_list[0].img_array.shape
            prediction = model.test_model(need_plot=False)[0].view(target_shape).cpu().detach()
            predictions.append(prediction)
            model.save_state_dict(os.path.join(path_to_save, f'map_{i:02d}_{n:02d}.pt'))

        # Average over the run axis (dim 0 of the stacked tensor).
        mean_prediction = torch.mean(torch.stack(predictions), dim=0)

        # Explicit figure/axes handles: the overlay, save and close below act on
        # this figure regardless of what pyplot considers "current".
        fig, (ax_target, ax_prediction) = plt.subplots(1, 2, figsize=(20, 10))
        ax_target.imshow(imgdata.target_img, cmap='PuOr')
        ax_target.set_title(f'Target for map {i}')
        ax_prediction.imshow(mean_prediction, cmap='PuOr')
        ax_prediction.set_title(f'Mean prediction for map {i}')

        # Mark every pixel whose img_array value is below the array maximum.
        # np.where yields (row, col) indices, while scatter expects (x, y),
        # hence the swapped argument order.
        rows, cols = np.where(imgdata.img_array < imgdata.img_array.max())
        ax_prediction.scatter(cols, rows, s=0.2, c='green', alpha=0.5)

        fig.savefig(os.path.join(path_to_save, f'map_{i:02d}_mean.png'),
                    bbox_inches='tight', pad_inches=0, facecolor='white')
        # Close explicitly so figures do not pile up across maps.
        plt.close(fig)

In [3]:
# Run the full experiment once per help_step_size setting. None is passed
# through to test_me unchanged (its meaning is defined by NeutralLiner).
step_sizes = (1, 31, 63, 127, None)
for hsz in step_sizes:
    test_me(help_step_size=hsz)

Predicting map 1/2: 100%|██████████| 1/1 [00:15<00:00, 15.72s/it]
Predicting map 2/2: 100%|██████████| 1/1 [00:12<00:00, 12.21s/it]
Predicting map 1/2: 100%|██████████| 1/1 [00:12<00:00, 12.60s/it]
Predicting map 2/2: 100%|██████████| 1/1 [00:13<00:00, 13.86s/it]
Predicting map 1/2: 100%|██████████| 1/1 [00:13<00:00, 13.84s/it]
Predicting map 2/2: 100%|██████████| 1/1 [00:14<00:00, 14.07s/it]
Predicting map 1/2: 100%|██████████| 1/1 [00:13<00:00, 13.39s/it]
Predicting map 2/2: 100%|██████████| 1/1 [00:14<00:00, 14.11s/it]
Predicting map 1/2: 100%|██████████| 1/1 [00:12<00:00, 12.86s/it]
Predicting map 2/2: 100%|██████████| 1/1 [00:13<00:00, 13.16s/it]
