# Imports

In [4]:
from dataset import MyData
from model import Net
from config import device

import torch
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact

### Exploring the range of $f(x,y,z)$ on spheres and cylinders


In [8]:
def make_data(mode_3d, radius):
    x, y = torch.from_numpy(np.stack(np.indices((200, 400)), axis=2).reshape(-1, 2)).float().T
    x -= x.mean()
    x /= x.abs().max()
    y -= y.mean()
    y /= y.abs().max()
    match mode_3d:
        case 'sphere':
            z = x
            x = torch.cos(torch.pi * y) * (1 - z.pow(2)).pow(0.5)
            y = torch.sin(torch.pi * y) * (1 - z.pow(2)).pow(0.5)
            data_3d = radius * torch.stack((x, y, z), dim=1)
        case 'cylinder':
            z = x
            x = torch.cos(torch.pi * y)
            y = torch.sin(torch.pi * y)
            data_3d = radius * torch.stack((x, y, z), dim=1)
        case _:
            print('Incorrect mode\n')
            raise SystemError(f'Incorrect mode: {mode_3d}\n')
    return data_3d

In [9]:
# Net is built from a list of datasets plus a learning rate, moved to the
# configured `device`, then its weights are restored from a saved state dict.
# NOTE(review): `reduce_fctor` is passed as spelled; presumably it matches
# MyData's signature in the project's dataset module — TODO confirm.
dataset1 = MyData(path_to_file='../imgs/2.png', mode='img', mode_3d='cylinder', radius=5, reduce_fctor=1, need_help=False)
dataset_list = [dataset1]
model = Net(dataset_list=dataset_list, lr=1e-3)
model.to(device)
# map_location remaps tensors saved on another device (e.g. CUDA) onto `device`,
# so the checkpoint also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load('./state_dict/test.pt', map_location=device))

<All keys matched successfully>

In [11]:
@interact(radius=(0, 20, 0.01), mode_3d=['cylinder', 'sphere'])
def my_homotopy(radius=1, mode_3d='cylinder'):
    """Plot the model's prediction f(x, y, z) sampled on a sphere or cylinder.

    Driven by the ``interact`` widgets: ``radius`` is a slider over [0, 20] with
    step 0.01, and ``mode_3d`` selects the surface passed to ``make_data``.
    One figure is drawn per tensor returned by ``model.test_model``.
    """
    batch_list = [make_data(mode_3d=mode_3d, radius=radius)]
    prediction_list = model.test_model(batch_list)
    for prediction in prediction_list:
        # make_data samples a 200 x 400 grid by default; reshape back to it.
        # detach/cpu so matplotlib can convert the tensor regardless of grad/device.
        image = prediction.detach().cpu().reshape(200, 400)
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.set_title(
            f'Visualization of the function $f(x,y,z)$ on {mode_3d} with $r={radius:g}$\n'
            f'observed range: [{image.min().item():.3f}, {image.max().item():.3f}]'
        )
        # Fixed colour limits keep frames comparable across radii; values outside
        # [-1, 1] saturate, which is why the observed range is shown in the title.
        im = ax.imshow(image, cmap='PuOr', vmin=-1, vmax=1)
        fig.colorbar(im, ax=ax)

interactive(children=(FloatSlider(value=1.0, description='radius', max=20.0, step=0.01), Dropdown(description=…