<a href="https://colab.research.google.com/github/Artur4ik2304/Yandex_ML_3.0/blob/main/hw_06/hw_06.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [110]:
import pandas as pd
import numpy as np
from PIL import Image, ImageOps
import torch
from torch.optim.lr_scheduler import StepLR
import itertools
from scipy.signal import convolve2d

In [111]:
algos = pd.read_csv('algos.csv', header=None)

In [112]:
filter_1 = algos.iloc[0, :].values

In [113]:
filter_2 = algos.iloc[1, :].values

In [114]:
filter_3 = np.random.rand(3, 3)

In [121]:
def optimize_filter_3(img, target, filter_1, filter_2, filter_3_init, lr=0.01, num_iters=100000):
    # Проверка данных
    assert not np.isnan(img).any(), "Есть NaN в img!"
    assert not np.isinf(img).any(), "Есть Inf в img!"
    assert not np.isnan(target).any(), "Есть NaN в target!"
    assert not np.isinf(target).any(), "Есть Inf в target!"

    # Конвертация в тензоры
    img_tensor = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0)
    target_tensor = torch.from_numpy(target).float().unsqueeze(0).unsqueeze(0)
    f1 = torch.from_numpy(filter_1).float().unsqueeze(0).unsqueeze(0)
    f2 = torch.from_numpy(filter_2).float().unsqueeze(0).unsqueeze(0)
    f3 = torch.from_numpy(filter_3_init).float().unsqueeze(0).unsqueeze(0).requires_grad_(True)

    optimizer = torch.optim.Adam([f3], lr=lr)
    scheduler = StepLR(optimizer, step_size=10000, gamma=0.1)

    for i in range(num_iters):
        optimizer.zero_grad()

        # Перебираем все порядки фильтров
        filters = [f1, f2, f3]
        best_loss = float('inf')
        best_output = None

        for order in itertools.permutations([0, 1, 2]):
            output = img_tensor
            for idx in order:
                output = torch.nn.functional.conv2d(output, filters[idx], padding='same')

            loss = torch.mean((output - target_tensor)**2)
            if loss < 1e-8:
              print(order)
              return f3.squeeze().detach().numpy()
            if loss < best_loss:
                best_loss = loss
                best_output = output
            elif best_output == None:
                best_output = output

        # Обратное распространение для лучшего порядка
        loss = torch.mean((best_output - target_tensor)**2)

        if torch.isnan(loss):
            print(f"NaN в loss на итерации {i}! Выход...")
            break

        loss.backward()
        optimizer.step()
        scheduler.step()

        if i % 1000 == 0:
            print(f"Iter {i}, Loss: {loss.item():.6f}")

    return f3.squeeze().detach().numpy()

In [122]:
for i in range(1, 1001):
  img = Image.open(f'{i}.png')
  mas = np.loadtxt(f'{i}.txt')
  filter_3 = optimize_filter_3(np.array(img), mas, filter_1.reshape(3, 3), filter_2.reshape(3, 3), filter_3.reshape(3, 3))
  print(filter_3)

Iter 0, Loss: 15085.978516
Iter 1000, Loss: 1.356217
Iter 2000, Loss: 0.530284
Iter 3000, Loss: 0.323739
Iter 4000, Loss: 0.173075
Iter 5000, Loss: 0.082704
Iter 6000, Loss: 0.046734
Iter 7000, Loss: 0.031552
Iter 8000, Loss: 0.020395
Iter 9000, Loss: 0.010639
Iter 10000, Loss: 0.006449
Iter 11000, Loss: 0.004128
Iter 12000, Loss: 0.003423
Iter 13000, Loss: 0.002514
Iter 14000, Loss: 0.001510
Iter 15000, Loss: 0.000724
Iter 16000, Loss: 0.000359
Iter 17000, Loss: 0.000122
Iter 18000, Loss: 0.000053
Iter 19000, Loss: 0.000024
Iter 20000, Loss: 0.000010
Iter 21000, Loss: 0.000009
Iter 22000, Loss: 0.000007
Iter 23000, Loss: 0.000006
Iter 24000, Loss: 0.000003
Iter 25000, Loss: 0.000009
Iter 26000, Loss: 0.000001
Iter 27000, Loss: 0.000000
Iter 28000, Loss: 0.000000
Iter 29000, Loss: 0.000000
Iter 30000, Loss: 0.000000
Iter 31000, Loss: 0.000000
Iter 32000, Loss: 0.000000
Iter 33000, Loss: 0.000000
Iter 34000, Loss: 0.000000
(2, 1, 0)
[[0.12496181 0.25006548 0.12497157]
 [0.250069   0.499

In [123]:
filter_3.reshape(-1)

array([0.12501355, 0.24997701, 0.12501109, 0.24997613, 0.50004417,
       0.24997681, 0.12501192, 0.24997613, 0.12501368], dtype=float32)

In [105]:
filter_2

array([0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625,
       0.0625])

In [106]:
filter_1

array([-1. , -0.5,  0. , -0.5,  0.5,  0.5,  0. ,  0.5,  1. ])

In [126]:
import csv

# Данные для записи (список списков или список словарей)
data = [
    [0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625],
    [0.12501355, 0.24997701, 0.12501109, 0.24997613, 0.50004417, 0.24997681, 0.12501192, 0.24997613, 0.12501368],
    [-1. , -0.5,  0. , -0.5,  0.5,  0.5,  0. ,  0.5,  1. ]
]

# Открываем файл на запись
with open('sample_data/reconstructed_algos.csv', 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    for row in data:
        writer.writerow(row)  # Записываем строку

In [128]:
output = convolve2d(img, filter_2.reshape(3, 3), mode='same')

In [129]:
output = convolve2d(output, filter_3.reshape(3, 3), mode='same')

In [130]:
output = convolve2d(img, filter_1.reshape(3, 3), mode='same')

In [131]:
output

array([[-214.5, -134. , -149.5, ..., -168.5, -166. ,   84.5],
       [-154. ,   66.5,   57. , ...,   77.5,  102.5,  337. ],
       [-162. ,   57.5,   61. , ...,   83. ,  119.5,  328.5],
       ...,
       [-135. ,    5.5,  -13.5, ...,   54.5,   56. ,  354.5],
       [-121.5,   40. ,    4. , ...,   85.5,   95. ,  417. ],
       [  76.5,  262.5,  253.5, ...,  437.5,  428. ,  540. ]])

In [132]:
mas

array([[ 248.15625   ,  276.62890625,  275.24609375, ...,  285.9765625 ,
         200.7421875 ,   30.5859375 ],
       [ 285.12109375,  244.01171875,  200.51171875, ...,  176.73046875,
          71.4609375 , -106.234375  ],
       [ 289.86328125,  209.4296875 ,  142.421875  , ...,   87.65625   ,
         -20.56640625, -182.0234375 ],
       ...,
       [ 242.890625  ,  184.47265625,  143.4296875 , ...,   90.96484375,
         -15.50390625, -208.0078125 ],
       [ 174.55078125,   92.44921875,   37.9921875 , ...,  -43.5703125 ,
        -125.30859375, -260.25      ],
       [  28.2578125 ,  -79.5       , -146.421875  , ..., -267.484375  ,
        -282.7734375 , -284.484375  ]])