In [1]:
import torch
import torch.fft as fft
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from scipy import io
import scipy as sc

from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import pandas as pd
from tqdm.notebook import trange, tqdm
from utility import *
from ONN_one_layer import *
from ONN_two_layer import *
%matplotlib inline
# NOTE(review): CONFIG is defined in the next cell of this notebook. On a fresh kernel this
# line only works if one of the star-imports above also exports a CONFIG class — TODO confirm.
config = CONFIG()

In [2]:
class CONFIG():
    """Simulation grid, optical parameters and detector layout.

    Args:
        array_size: Side length (in samples) of the square simulation grid.
        pixel_size: Size of one grid sample (presumably metres — TODO confirm).
        wavelength: Wavelength of the light, same length unit as ``pixel_size``.
        image_size: Side length (in samples) the input images are resized to.
        distance: Propagation distance used by the model.
        out_h: Height of the rectangle on which the ten detector zones are laid out.
        out_w: Width of that rectangle.
        mask_size: Side length of one square detector zone.

    Attributes:
        K: Wave number ``2*pi/wavelength``.
        aperture_size: ``array_size * pixel_size``.
        coords: ``(10, 2)`` tensor with the ``(x, y)`` centre of every detector zone,
            arranged as a row of 4, a row of 2 and a row of 4.
    """
    def __init__(self, array_size: int = 1024,
                 pixel_size: float = 7e-6,
                 wavelength: float = 532e-9,
                 image_size: int = 62,
                 distance: float = 0.3,
                 out_h: float = 4e-3,
                 out_w: float = 5e-3,
                 mask_size: float = 1e-3):
        self.array_size: int = array_size
        self.pixel_size: float = pixel_size
        self.wavelength: float = wavelength
        self.K: float = 2*np.pi/self.wavelength  # wave number
        self.aperture_size: float = self.array_size * self.pixel_size
        self.image_size: int = image_size
        self.distance: float = distance
        # Detector geometry: previously hard-coded, now configurable with the same defaults.
        self.out_h: float = out_h
        self.out_w: float = out_w
        self.mask_size: float = mask_size

        # Offsets of the zone centres from the origin: outer rows/columns sit half a zone
        # inside the out_w x out_h rectangle; out_w3 is the x offset of the two inner zones
        # of the top and bottom rows.
        out_h2 = self.out_h/2-self.mask_size/2
        out_w2 = self.out_w/2-self.mask_size/2
        out_w3 = (self.out_w-2*self.mask_size)/2-(self.out_w-4*self.mask_size)/3-self.mask_size/2

        self.coords = torch.tensor([[-out_w2, -out_h2],
                                    [-out_w3, -out_h2],
                                    [out_w3, -out_h2],
                                    [out_w2, -out_h2],
                                    [-out_w2, 0],
                                    [out_w2, 0],
                                    [-out_w2, out_h2],
                                    [-out_w3, out_h2],
                                    [out_w3, out_h2],
                                    [out_w2, out_h2]])
config = CONFIG()

In [3]:
batch_size = 10
# Dataset root. A raw string keeps the Windows backslashes literal: "\V" and "\d" are invalid
# escape sequences in a normal string literal (warning today, error in future Python versions).
# NOTE(review): this is an absolute, machine-specific path — consider a relative data directory.
DATA_ROOT = r"D:\Visual Studio Code\data"

# Load and transform the data: convert to tensors, then resize to config.image_size
# with nearest-neighbour interpolation.
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Resize(config.image_size, interpolation = transforms.InterpolationMode.NEAREST)])

train_data = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_data = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle= True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle= False)



In [4]:
def custom_function(self,
                    image: torch.Tensor,
                    k: int = 0) -> torch.Tensor:
    """Turn per-zone intensity maps into class probabilities.

    Args:
        self: The calling model; not used here.
        image: 4-D tensor whose last two dimensions are summed; after the sum it must
            reshape to ``(batch, 10)``.
        k: Unused; kept for signature compatibility with other read-out functions.

    Returns:
        ``(batch, 10)`` tensor: softmax over the ten summed zone values.
    """
    zone_energy = image.sum(dim=(2, 3))
    scores = zone_energy.reshape([len(zone_energy), 10])
    return F.softmax(scores, dim=1)


In [9]:
class ONN(nn.Module):
    """Optical network with one trainable phase mask, simulated with FFTs.

    ``forward`` filters the input with a 3x3 convolution, normalises it, zero-pads it to the
    simulation grid, applies ``fft2``, multiplies by ``exp(1j*phase)``, applies ``fft2`` again
    and reads the intensity inside ten square detector zones.

    All helper tensors (``U``, ``p``, ``mask``, ``one``, ``eta``, ``zeros``) are registered as
    non-persistent buffers: they follow ``.cuda()`` / ``.cpu()`` / ``.to()`` together with the
    module but are not written to ``state_dict``, so existing checkpoints still load.
    Re-assigning them by hand (``onn.mask = onn.mask.cuda()``) keeps working.

    Args:
        config: Object providing ``array_size``, ``pixel_size``, ``wavelength``, ``K``,
            ``aperture_size``, ``image_size``, ``distance``, ``coords`` and ``mask_size``.
        phase: Optional fixed phase mask; when ``None`` a random trainable
            ``(array_size, array_size)`` parameter is created.
        out_function: Read-out callable invoked as ``out_function(self, intensity, array_size)``.
    """
    def __init__(self, config,
                 phase: None | torch.Tensor = None,
                 out_function = sum_func): 
        super(ONN, self).__init__()
        
        self.array_size = config.array_size
        self.pixel_size = config.pixel_size
        self.wavelength = config.wavelength
        self.K = config.K # wave number
        self.aperture_size = config.aperture_size
        self.image_size = config.image_size
        self.distance = config.distance
        
        # Squared spatial-frequency grid, rolled by half the grid so that the zero frequency
        # sits at index (0, 0), and the resulting longitudinal wave-vector component p.
        border = np.pi * self.array_size / self.aperture_size
        arr = torch.linspace(-border, border, self.array_size+1)[:self.array_size]
        xv, yv = torch.meshgrid(arr, arr, indexing='ij')
        xx = xv**2 + yv**2
        half = int(self.array_size/2)
        self.register_buffer("U", torch.roll(xx, (half, half), dims = (0, 1)), persistent=False)
        self.register_buffer("p", torch.sqrt(-self.U+self.K**2), persistent=False)

        # Ten boolean detector masks: squares of side config.mask_size centred on config.coords.
        # NOTE(review): the coordinate axis spans +/- array_size*pixel_size, i.e. twice the
        # aperture width — confirm this scaling is intended.
        coords = config.coords
        l = torch.linspace(-config.array_size*config.pixel_size, config.array_size*config.pixel_size, config.array_size)
        Y, X = torch.meshgrid(l, l, indexing='ij')
        half_side = config.mask_size/2
        mask = torch.stack([(X > coords[x][0]-half_side) * (X < coords[x][0]+half_side) * (Y > coords[x][1]-half_side) * (Y < coords[x][1]+half_side) for x in range(10)])
        self.register_buffer("mask", mask, persistent=False)

        self.maxpool = nn.MaxPool2d(kernel_size = self.array_size)
        self.dropout = nn.Dropout(0.5)
        self.phase: torch.Tensor
        if(phase is not None):
            self.phase = phase
        else:
            self.phase = nn.Parameter(torch.rand(self.array_size, self.array_size, dtype=torch.float))
        # Pads an image_size input up to array_size (exact only when the difference is even).
        self.zero = nn.ZeroPad2d(int((self.array_size - self.image_size)/2))
        self.zero_add = nn.ZeroPad2d(int(self.array_size/2))
        self.softmax = nn.Softmax(dim=1)
        self.register_buffer("one", torch.ones((self.array_size, self.array_size)), persistent=False)
        self.function = out_function
        self.register_buffer("eta", torch.exp(1j*self.distance*self.p), persistent=False)
        # Previously created with a hard-coded .cuda(), which made construction fail on
        # machines without CUDA; as a buffer it simply follows the module's device.
        self.register_buffer("zeros", torch.zeros((10,1)), persistent=False)

        # 3x3 discrete Laplacian used as the initial weight of the input convolution.
        # The weight is an nn.Parameter, so it is updated by the optimizer during training.
        kernel = torch.tensor(
            [[0, 1, 0],
            [1, -4, 1], 
            [0, 1, 0]], dtype=torch.float)
        kernel = kernel.reshape(1, 1, 3, 3)
        self.one_zero = nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)
        self.conv.weight = nn.Parameter(kernel)
        
    def propagation(self, field, z, p):
        """Multiply the spectrum of ``field`` by ``exp(1j*z*p)`` and transform back."""
        eta = torch.exp(1j*z*p)
        res = fft.ifft2(fft.fft2(field) * eta)
        return res
    
    def DOE(self):
        """Return the complex transmission ``exp(1j*phase)`` of the phase mask."""
        return torch.exp(1j*self.phase)
    
    def forward(self, x):
         """Run the optical pipeline.

         Returns:
             ``(x, res)``: the complex output field before masking, and the value of the
             read-out function applied to the masked intensity.
         """
         # The convolution shrinks the image by one pixel per side; pad it back.
         x = self.one_zero(self.conv(x))
         x = x/(torch.sum(x**2, dim = (1, 2, 3))[:, None, None, None]**0.5)
         x = self.zero(x)
         x = fft.fft2(x)
         x = x *  self.DOE()
         x = fft.fft2(x)
         # NOTE(review): x is complex here, so x**2 is the complex square rather than |x|**2;
         # left unchanged because trained checkpoints depend on it — confirm intent.
         x = x/(torch.sum(x**2, dim = (1, 2, 3))[:, None, None, None]**0.5)*np.sqrt(500)
         # Broadcasting the ten masks against the single-channel field gives ten zone channels.
         res = x * self.mask
         res = torch.abs(res)**2
         res=self.function(self, res, self.array_size)
         return x, res

In [10]:
# Build the network with the softmax read-out and move it to the GPU. The helper tensors
# listed below are moved explicitly as well.
onn = ONN(config,  out_function=custom_function).cuda()
onn.mask, onn.p, onn.one, onn.eta = (
    onn.mask.cuda(),
    onn.p.cuda(),
    onn.one.cuda(),
    onn.eta.cuda(),
)

In [None]:
# Re-create the network from scratch (new random phase). This rebinds `onn`, so the instance
# moved to the GPU in the previous cell is discarded; this cell does not move the new one.
onn = ONN(config,  out_function=custom_function)

In [11]:
# Training setup, variant 1: mean-squared-error loss, Adam over all model parameters, and an
# exponential learning-rate decay applied whenever scheduler.step() is called.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(onn.parameters(), lr=0.03)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

In [None]:
# Training setup, variant 2: same optimizer and scheduler, with cross-entropy loss. Running
# this cell rebinds criterion / optimizer / scheduler from the previous cell.
# NOTE(review): custom_function already ends with a softmax, and nn.CrossEntropyLoss applies
# log-softmax to its input — confirm the double softmax is intended.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(onn.parameters(), lr=0.03)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

In [None]:
# Restore model weights from the checkpoint written by the torch.save cell further down.
# NOTE(review): torch.load is called without map_location, so the tensors are restored to the
# device they were saved from — confirm this is acceptable on machines without a GPU.
onn.load_state_dict(torch.load("onn_binar_2.ckpt"))
# Alternative checkpoint:
#onn.load_state_dict(torch.load('Lerning\\Part1\\onn_sum_3_97.ckpt'))

In [None]:
# Move the model and its helper tensors to the GPU. `p_extended` is not defined by the ONN
# class in this notebook (it appears to belong to another model variant), so attributes the
# current model lacks are skipped instead of raising AttributeError.
onn = onn.cuda()
for _attr in ("mask", "p", "p_extended", "one"):
    if hasattr(onn, _attr):
        setattr(onn, _attr, getattr(onn, _attr).cuda())

In [None]:
# Bring the model back to the CPU, together with the helper tensors listed here.
onn = onn.cpu()
onn.mask, onn.p, onn.one = onn.mask.cpu(), onn.p.cpu(), onn.one.cpu()

In [None]:
def train(onn: nn.Module,
          train_loader1: torch.utils.data.DataLoader,
          train_loader2: torch.utils.data.DataLoader,
          criterion: nn.Module,
          optimizer: torch.optim.Optimizer,
          scheduler,
          device: torch.device = torch.device('cpu'),
          test_loader: torch.utils.data.DataLoader = None,
          num_epochs: int = 5,
          func_transform = None,
          get_train_data: bool = False,
          loss_list: torch.Tensor | None = None,
          acc_list: torch.Tensor | None = None): #Double
    """Train a two-input model, one batch from each loader per optimisation step.

    NOTE(review): a later cell defines another ``train`` (three loaders) that shadows this one
    once both cells have been executed.

    Args:
        onn: Model called as ``onn(images1, images2)``; returns ``(field, outputs)``.
        train_loader1, train_loader2: Sources of the two input batches. The number of steps
            is ``num_epochs * len(train_loader1)``.
        criterion, optimizer, scheduler: Loss, optimizer and LR scheduler (stepped every
            5th iteration).
        device: Device the batches are moved to.
        test_loader: Optional loader; when given, one test batch is evaluated per step.
        num_epochs: Number of passes over ``train_loader1``.
        func_transform: Optional callable applied to the label tensor before the loss
            (e.g. one-hot encoding).
        get_train_data: When True, loss/accuracy are written into ``loss_list`` / ``acc_list``.
        loss_list, acc_list: Pre-allocated containers; required when ``get_train_data``.

    Returns:
        Tensor of per-step test accuracies when ``test_loader`` is given, otherwise ``None``.
    """
    total_steps = num_epochs*len(train_loader1)

    # `divider` used to be left unbound when get_train_data was False, which crashed as soon
    # as a test loader was supplied. It is also clamped to >= 1 to avoid a modulo by zero.
    divider = None
    if get_train_data:
        divider = max(1, int(total_steps*2/len(loss_list)))

    acc_test_list = None
    if test_loader is not None:
        # The list is indexed by the step number, so it needs at least one slot per step;
        # the original (divider-based) size is kept whenever it is larger.
        slots = total_steps if divider is None else max(total_steps, int(total_steps*2/divider))
        acc_test_list = torch.zeros(slots)

    onn.train()

    data_iterator1 = init_batch_generator(train_loader1)
    data_iterator2 = init_batch_generator(train_loader2)
    test_iterator = init_batch_generator(test_loader) if test_loader is not None else None
    progress = trange(total_steps)

    for epoch in progress:
        onn.train()
        # Forward pass
        images1, labels1 = next(data_iterator1)
        images2, labels2 = next(data_iterator2)
        labels = torch.cat((labels1, labels2), 0).to(device)
        images1 = images1.to(device)
        images2 = images2.to(device)
        targets = labels if func_transform is None else func_transform(labels)
        _, outputs = onn(images1, images2)
        loss = criterion(outputs, targets)

        # Backpropagation and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track accuracy on the training batch
        total = labels.size(0)
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        if get_train_data and (epoch + 1) % divider == 0:
            # Re-evaluate the same training batch in eval mode. The original referenced an
            # undefined `images` here; the labels and `total` belong to the training batch.
            onn.eval()
            with torch.no_grad():
                _, eval_outputs = onn(images1, images2)
                loss_list[int(epoch/divider)] = criterion(eval_outputs, targets).item()
            _, predicted2 = torch.max(eval_outputs.data, 1)
            acc_list[int(epoch/divider)] = (predicted2 == labels).sum().item() / total

        # Track accuracy on one test batch (eval mode, no gradient graph)
        if test_loader is not None:
            test_images, test_labels = next(test_iterator)
            test_images = test_images.to(device)
            test_labels = test_labels.to(device)
            onn.eval()
            with torch.no_grad():
                # NOTE(review): single-argument call, unlike the two-argument training call;
                # confirm the model's forward accepts it.
                _, test_outputs = onn(test_images)
            _, test_predicted = torch.max(test_outputs.data, 1)
            acc_test_list[epoch] = (test_predicted == test_labels).sum().item() / test_labels.size(0)

        if (epoch + 1) % 5 == 0:
            scheduler.step()
        if (epoch+1) % 7 == 0:
            if test_loader is not None:
                progress.set_postfix_str(f"Loss: {loss.item() :.4f}, Accuracy: {(correct / total) * 100 :.2f}%, Test accuracy: {acc_test_list[epoch]*100 :.2f} lr: {scheduler.get_last_lr()[0] :e}")
            else:
                progress.set_postfix_str(f"Loss: {loss.item() :.4f}, Accuracy: {(correct / total) * 100 :.2f}%, lr: {scheduler.get_last_lr()[0] :e}")
    return acc_test_list

In [None]:
def train(onn: nn.Module,
          train_loader1: torch.utils.data.DataLoader,
          train_loader2: torch.utils.data.DataLoader,
          train_loader3: torch.utils.data.DataLoader,
          criterion: nn.Module,
          optimizer: torch.optim.Optimizer,
          scheduler,
          device: torch.device = torch.device('cpu'),
          test_loader: torch.utils.data.DataLoader = None,
          num_epochs: int = 5,
          func_transform = None,
          get_train_data: bool = False,
          loss_list: torch.Tensor | None = None,
          acc_list: torch.Tensor | None = None): #Triple
    """Train a three-input model, one batch from each loader per optimisation step.

    NOTE(review): this definition shadows the two-loader ``train`` defined in an earlier cell.

    Args:
        onn: Model called as ``onn(images1, images2, images3)``; returns ``(field, outputs)``.
        train_loader1, train_loader2, train_loader3: Sources of the three input batches.
            The number of steps is ``num_epochs * len(train_loader1)``.
        criterion, optimizer, scheduler: Loss, optimizer and LR scheduler (stepped every
            5th iteration).
        device: Device the batches are moved to.
        test_loader: Optional loader; when given, one test batch is evaluated per step.
        num_epochs: Number of passes over ``train_loader1``.
        func_transform: Optional callable applied to the label tensor before the loss
            (e.g. one-hot encoding).
        get_train_data: When True, loss/accuracy are written into ``loss_list`` / ``acc_list``.
        loss_list, acc_list: Pre-allocated containers; required when ``get_train_data``.

    Returns:
        Tensor of per-step test accuracies when ``test_loader`` is given, otherwise ``None``.
    """
    total_steps = num_epochs*len(train_loader1)

    # `divider` used to be left unbound when get_train_data was False, which crashed as soon
    # as a test loader was supplied. It is also clamped to >= 1 to avoid a modulo by zero.
    divider = None
    if get_train_data:
        divider = max(1, int(total_steps*2/len(loss_list)))

    acc_test_list = None
    if test_loader is not None:
        # The list is indexed by the step number, so it needs at least one slot per step;
        # the original (divider-based) size is kept whenever it is larger.
        slots = total_steps if divider is None else max(total_steps, int(total_steps*2/divider))
        acc_test_list = torch.zeros(slots)

    onn.train()

    data_iterator1 = init_batch_generator(train_loader1)
    data_iterator2 = init_batch_generator(train_loader2)
    data_iterator3 = init_batch_generator(train_loader3)
    test_iterator = init_batch_generator(test_loader) if test_loader is not None else None
    progress = trange(total_steps)

    for epoch in progress:
        onn.train()
        # Forward pass
        images1, labels1 = next(data_iterator1)
        images2, labels2 = next(data_iterator2)
        images3, labels3 = next(data_iterator3)
        labels = torch.cat((labels1, labels2, labels3), 0).to(device)
        images1 = images1.to(device)
        images2 = images2.to(device)
        images3 = images3.to(device)
        targets = labels if func_transform is None else func_transform(labels)
        _, outputs = onn(images1, images2, images3)
        loss = criterion(outputs, targets)

        # Backpropagation and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track accuracy on the training batch
        total = labels.size(0)
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        if get_train_data and (epoch + 1) % divider == 0:
            # Re-evaluate the same training batch in eval mode. The original referenced an
            # undefined `images` here; the labels and `total` belong to the training batch.
            onn.eval()
            with torch.no_grad():
                _, eval_outputs = onn(images1, images2, images3)
                loss_list[int(epoch/divider)] = criterion(eval_outputs, targets).item()
            _, predicted2 = torch.max(eval_outputs.data, 1)
            acc_list[int(epoch/divider)] = (predicted2 == labels).sum().item() / total

        # Track accuracy on one test batch (eval mode, no gradient graph)
        if test_loader is not None:
            test_images, test_labels = next(test_iterator)
            test_images = test_images.to(device)
            test_labels = test_labels.to(device)
            onn.eval()
            with torch.no_grad():
                # NOTE(review): single-argument call, unlike the three-argument training call;
                # confirm the model's forward accepts it.
                _, test_outputs = onn(test_images)
            _, test_predicted = torch.max(test_outputs.data, 1)
            acc_test_list[epoch] = (test_predicted == test_labels).sum().item() / test_labels.size(0)

        if (epoch + 1) % 5 == 0:
            scheduler.step()
        if (epoch+1) % 7 == 0:
            if test_loader is not None:
                progress.set_postfix_str(f"Loss: {loss.item() :.4f}, Accuracy: {(correct / total) * 100 :.2f}%, Test accuracy: {acc_test_list[epoch]*100 :.2f} lr: {scheduler.get_last_lr()[0] :e}")
            else:
                progress.set_postfix_str(f"Loss: {loss.item() :.4f}, Accuracy: {(correct / total) * 100 :.2f}%, lr: {scheduler.get_last_lr()[0] :e}")
    return acc_test_list

In [13]:
# Train for 3 epochs with one-hot float targets (as needed by an MSE-style loss).
# F.one_hot replaces the nested Python loop that rebuilt the target tensor element by element
# on every step; values, dtype (float32) and the final .cuda() placement are the same.
# NOTE(review): this call passes a single train loader, which matches neither `train`
# definition in this notebook (two and three loaders) — it presumably targets a single-loader
# `train` from the star-imported modules; confirm which one is in scope on a fresh run.
train(onn, train_loader, criterion, optimizer, scheduler, num_epochs=3, device=torch.device('cuda'),
      func_transform=lambda labels: F.one_hot(labels, num_classes=10).float().cuda())

  0%|          | 0/18000 [00:00<?, ?it/s]

In [17]:
# Inspect the input-convolution kernel. It is initialised to a 3x3 Laplacian in ONN.__init__
# but is a trainable parameter, so the recorded values differ from the initial kernel.
onn.conv.weight

Parameter containing:
tensor([[[[ 1.8379, -0.8667,  1.0641],
          [ 2.0208,  0.1466,  3.1497],
          [-1.0503, -1.7738,  0.3553]]]], device='cuda:0', requires_grad=True)

In [None]:
# Persist the model's state_dict; this is the file the load_state_dict cell above reads.
torch.save(onn.state_dict(),'onn_binar_2.ckpt')

In [None]:
def test(onn: nn.Module,
         test_dataloader1: torch.utils.data.DataLoader,
         test_dataloader2: torch.utils.data.DataLoader,
         device: torch.device = torch.device('cpu'),
         get_acc_df: bool = False,
         get_energy_df: bool = False): #Double
    """Evaluate a two-input model separately on each input branch.

    Branch 1 is evaluated as ``onn(images, None)`` with ``onn.mask``, branch 2 as
    ``onn(None, images)`` with ``onn.mask2``. Running accuracy is shown on a progress bar.

    NOTE(review): a later cell defines another ``test`` (three loaders) that shadows this one
    once both cells have been executed.

    Args:
        onn: Model returning ``(output_field, outputs)``.
        test_dataloader1, test_dataloader2: One loader per branch.
        device: Device the inputs and the accumulators live on.
        get_acc_df: Build confusion tables (rows: predicted digit, columns: fed digit, in
            percent of the samples of each fed digit).
        get_energy_df: Build tables with the share of zone energy per fed digit, in percent.

    Returns:
        ``(df_energy1, df_energy2)`` when only ``get_energy_df`` is set, otherwise
        ``(df_acc1, df_energy1, df_acc2, df_energy2)``; tables that were not requested are
        ``None``.
    """
    pd.options.styler.format.precision = 1

    def new_counters(enabled):
        # (10x10 table, 10x1 per-label totals) on `device`, or (None, None) when disabled.
        if not enabled:
            return None, None
        return torch.zeros((10, 10)).to(device), torch.zeros((10, 1)).to(device)

    def evaluate(dataloader, data_iterator, forward, get_mask, acc, acc_max, energy, energy_max):
        # One pass over `dataloader`, accumulating the requested tables in place.
        progress = trange(len(dataloader))
        with torch.no_grad():
            correct = 0
            total = 0
            for _ in progress:
                images, labels = next(data_iterator)
                out_image, outputs = forward(images.to(device))

                if get_energy_df:
                    out_image = out_image*get_mask()
                    out_energy = torch.sum(torch.abs(out_image)**2, dim = (2, 3))
                    for sample in range(len(labels)):
                        for zone in range(10):
                            energy[labels[sample]][zone] += out_energy[sample][zone]
                            energy_max[labels[sample]][0] += out_energy[sample][zone]

                _, predicted = torch.max(outputs.data, 1)

                if get_acc_df:
                    for sample in range(len(predicted)):
                        acc[predicted[sample]][labels[sample]] += 1
                        acc_max[labels[sample]][0] += 1
                total += labels.size(0)
                correct += (predicted == labels.to(device)).sum().item()
                progress.set_postfix_str(f"Accuracy: {(correct / total) * 100}%")

    def styled(values, row_title, cmap=None):
        # Wrap a 10x10 tensor into the styled table used for display.
        frame = pd.DataFrame(values.cpu(),
                             index=pd.MultiIndex.from_product([[row_title], range(10)]),
                             columns=pd.MultiIndex.from_product([['Поданное число'], range(10)]))
        if cmap is None:
            gradient = frame.style.background_gradient()
        else:
            gradient = frame.style.background_gradient(cmap=cmap)
        return gradient.set_table_styles([
            {'selector': 'th.col_heading', 'props': 'text-align: center;'},
            {'selector': 'th.row_heading.level0', 'props': 'writing-mode: vertical-lr; transform: rotate(180deg); text-align: center;'},
        ], overwrite=False)

    acc1, acc_max1 = new_counters(get_acc_df)
    acc2, acc_max2 = new_counters(get_acc_df)
    energy1, energy_max1 = new_counters(get_energy_df)
    energy2, energy_max2 = new_counters(get_energy_df)

    onn.eval()

    data_iterator1 = init_batch_generator(test_dataloader1)
    data_iterator2 = init_batch_generator(test_dataloader2)

    # The mask is fetched lazily so that onn.mask2 is only touched when energy tables are requested.
    evaluate(test_dataloader1, data_iterator1, lambda batch: onn(batch, None), lambda: onn.mask,
             acc1, acc_max1, energy1, energy_max1)
    evaluate(test_dataloader2, data_iterator2, lambda batch: onn(None, batch), lambda: onn.mask2,
             acc2, acc_max2, energy2, energy_max2)

    df_acc1 = df_energy1 = df_acc2 = df_energy2 = None

    if get_acc_df:
        # acc[predicted][label]: columns are the fed digits, so every column must be divided
        # by the number of samples of that digit. The totals are stored as a (10, 1) column
        # indexed by label, hence the transpose (the original divided rows instead).
        acc1 /= acc_max1.T/100
        df_acc1 = styled(acc1, 'Точность предсказания', cmap='YlOrBr')
        acc2 /= acc_max2.T/100
        df_acc2 = styled(acc2, 'Точность предсказания', cmap='YlOrBr')

    if get_energy_df:
        # energy[label][zone]: rows are the fed digits, so the (10, 1) totals broadcast correctly;
        # the table is transposed for display so that columns are the fed digits.
        energy1 /= energy_max1/100
        df_energy1 = styled(torch.transpose(energy1, 0, 1), 'Расперделение энергии')
        energy2 /= energy_max2/100
        df_energy2 = styled(torch.transpose(energy2, 0, 1), 'Расперделение энергии')

    if get_energy_df and not get_acc_df:
        return df_energy1, df_energy2
    else:
        return df_acc1, df_energy1, df_acc2, df_energy2

In [None]:
def test(onn: nn.Module,
         test_dataloader1: torch.utils.data.DataLoader,
         test_dataloader2: torch.utils.data.DataLoader,
         test_dataloader3: torch.utils.data.DataLoader,
         device: torch.device = torch.device('cpu'),
         get_acc_df: bool = False,
         get_energy_df: bool = False): #Triple
    """Evaluate a three-input model separately on each input branch.

    Branch k is evaluated by passing the batch as the k-th argument of ``onn`` (``None`` for
    the others) and, for energy tables, masking the output with ``onn.mask`` / ``onn.mask2``
    / ``onn.mask3``. Running accuracy is shown on a progress bar.

    NOTE(review): this definition shadows the two-loader ``test`` defined in an earlier cell.

    Args:
        onn: Model returning ``(output_field, outputs)``.
        test_dataloader1, test_dataloader2, test_dataloader3: One loader per branch.
        device: Device the inputs and the accumulators live on.
        get_acc_df: Build confusion tables (rows: predicted digit, columns: fed digit, in
            percent of the samples of each fed digit).
        get_energy_df: Build tables with the share of zone energy per fed digit, in percent.

    Returns:
        ``(df_energy1, df_energy2)`` when only ``get_energy_df`` is set, otherwise
        ``(df_acc1, df_energy1, df_acc2, df_energy2, df_acc3, df_energy3)``; tables that were
        not requested are ``None``.
    """
    pd.options.styler.format.precision = 1

    def new_counters(enabled):
        # (10x10 table, 10x1 per-label totals) on `device`, or (None, None) when disabled.
        if not enabled:
            return None, None
        return torch.zeros((10, 10)).to(device), torch.zeros((10, 1)).to(device)

    def evaluate(dataloader, data_iterator, forward, get_mask, acc, acc_max, energy, energy_max):
        # One pass over `dataloader`, accumulating the requested tables in place.
        progress = trange(len(dataloader))
        with torch.no_grad():
            correct = 0
            total = 0
            for _ in progress:
                images, labels = next(data_iterator)
                out_image, outputs = forward(images.to(device))

                if get_energy_df:
                    out_image = out_image*get_mask()
                    out_energy = torch.sum(torch.abs(out_image)**2, dim = (2, 3))
                    for sample in range(len(labels)):
                        for zone in range(10):
                            energy[labels[sample]][zone] += out_energy[sample][zone]
                            energy_max[labels[sample]][0] += out_energy[sample][zone]

                _, predicted = torch.max(outputs.data, 1)

                if get_acc_df:
                    for sample in range(len(predicted)):
                        acc[predicted[sample]][labels[sample]] += 1
                        acc_max[labels[sample]][0] += 1
                total += labels.size(0)
                correct += (predicted == labels.to(device)).sum().item()
                progress.set_postfix_str(f"Accuracy: {(correct / total) * 100}%")

    def styled(values, row_title, cmap=None):
        # Wrap a 10x10 tensor into the styled table used for display.
        frame = pd.DataFrame(values.cpu(),
                             index=pd.MultiIndex.from_product([[row_title], range(10)]),
                             columns=pd.MultiIndex.from_product([['Поданное число'], range(10)]))
        if cmap is None:
            gradient = frame.style.background_gradient()
        else:
            gradient = frame.style.background_gradient(cmap=cmap)
        return gradient.set_table_styles([
            {'selector': 'th.col_heading', 'props': 'text-align: center;'},
            {'selector': 'th.row_heading.level0', 'props': 'writing-mode: vertical-lr; transform: rotate(180deg); text-align: center;'},
        ], overwrite=False)

    acc1, acc_max1 = new_counters(get_acc_df)
    acc2, acc_max2 = new_counters(get_acc_df)
    acc3, acc_max3 = new_counters(get_acc_df)
    energy1, energy_max1 = new_counters(get_energy_df)
    energy2, energy_max2 = new_counters(get_energy_df)
    energy3, energy_max3 = new_counters(get_energy_df)

    onn.eval()

    data_iterator1 = init_batch_generator(test_dataloader1)
    data_iterator2 = init_batch_generator(test_dataloader2)
    data_iterator3 = init_batch_generator(test_dataloader3)

    # Masks are fetched lazily so that mask2 / mask3 are only touched when energy tables are requested.
    evaluate(test_dataloader1, data_iterator1, lambda batch: onn(batch, None, None), lambda: onn.mask,
             acc1, acc_max1, energy1, energy_max1)
    evaluate(test_dataloader2, data_iterator2, lambda batch: onn(None, batch, None), lambda: onn.mask2,
             acc2, acc_max2, energy2, energy_max2)
    evaluate(test_dataloader3, data_iterator3, lambda batch: onn(None, None, batch), lambda: onn.mask3,
             acc3, acc_max3, energy3, energy_max3)

    df_acc1 = df_energy1 = df_acc2 = df_energy2 = df_acc3 = df_energy3 = None

    if get_acc_df:
        # acc[predicted][label]: columns are the fed digits, so every column must be divided
        # by the number of samples of that digit. The totals are stored as a (10, 1) column
        # indexed by label, hence the transpose (the original divided rows instead).
        acc1 /= acc_max1.T/100
        df_acc1 = styled(acc1, 'Точность предсказания', cmap='YlOrBr')
        acc2 /= acc_max2.T/100
        df_acc2 = styled(acc2, 'Точность предсказания', cmap='YlOrBr')
        acc3 /= acc_max3.T/100
        df_acc3 = styled(acc3, 'Точность предсказания', cmap='YlOrBr')

    if get_energy_df:
        # energy[label][zone]: rows are the fed digits, so the (10, 1) totals broadcast correctly;
        # the table is transposed for display so that columns are the fed digits.
        energy1 /= energy_max1/100
        df_energy1 = styled(torch.transpose(energy1, 0, 1), 'Расперделение энергии')
        energy2 /= energy_max2/100
        df_energy2 = styled(torch.transpose(energy2, 0, 1), 'Расперделение энергии')
        energy3 /= energy_max3/100
        df_energy3 = styled(torch.transpose(energy3, 0, 1), 'Расперделение энергии')

    if get_energy_df and not get_acc_df:
        # NOTE(review): df_energy3 is computed but not returned on this path. Kept as is
        # because returning it would change the arity callers unpack — confirm and extend.
        return df_energy1, df_energy2
    else:
        return df_acc1, df_energy1, df_acc2, df_energy2, df_acc3, df_energy3

In [None]:
# Evaluate on the test set, requesting both tables (the two positional True flags).
# NOTE(review): this call passes a single loader; the `test` functions defined in this
# notebook take two or three loaders before `device`, so it presumably targets a
# single-loader `test` from the star-imported modules — confirm which one is in scope.
out = test(onn, test_loader, torch.device('cuda'), True, True)

In [None]:
# Load the 'DOE_Phase2' array from DOE_two_98.mat and attach it to the model as `phase4`.
# NOTE(review): the ONN class in this notebook only defines `phase`; `phase4` presumably
# belongs to a multi-layer model from the imported modules — confirm.
onn.phase4 = nn.Parameter(torch.tensor(io.loadmat('DOE_two_98.mat')['DOE_Phase2']))

In [None]:
# Export two phase masks to a MATLAB file under the keys 'DOE1' and 'DOE2'.
# NOTE(review): `phase1` / `phase2` are not attributes of the ONN class defined in this
# notebook — presumably a two-layer model is expected here; confirm.
io.savemat('two_best_doe.mat', mdict={'DOE1':onn.phase1.cpu().detach().numpy(), 'DOE2':onn.phase2.cpu().detach().numpy() })

In [None]:
# Export the recorded accuracy and loss histories to a MATLAB file.
# NOTE(review): `acc_list` and `loss_list` are not created in any visible cell of this notebook.
io.savemat('Lerning\\chart_4.mat', mdict={'acc': np.array(acc_list), 'loss': np.array(loss_list)})

In [None]:
# Export the tensor `m` to a MATLAB file under the key 'MASK'.
# NOTE(review): `m` is not defined in any visible cell of this notebook.
io.savemat('mask2.mat', mdict={'MASK': m.cpu().detach().numpy()})

In [None]:
# Load DOES.mat with mat73 into a dict.
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
import mat73
data_dict = mat73.loadmat('DOES.mat')

In [None]:
# Replace the model's phase mask with the argument (angle) of the loaded "DOES" array.
# NOTE(review): this creates a new Parameter object; an optimizer built earlier from
# onn.parameters() presumably still references the old one — re-create it before training.
onn.phase = nn.Parameter(torch.angle(torch.tensor(data_dict["DOES"])))

In [None]:
# Load the 'latter' array from latter.mat as a float32 parameter named `phase1`.
# NOTE(review): `phase1` is not an attribute of the ONN class defined in this notebook — confirm the target model.
onn.phase1 = nn.Parameter(torch.tensor(io.loadmat('latter.mat')['latter'], dtype=torch.float))

In [None]:
# Take the second element of `out` (the first energy table in the 4- and 6-tuple returns of
# `test` above), unwrap the underlying DataFrame, drop the row index and keep the columns
# grouped under "Поданное число" (one column per fed digit).
e=out[1].data.reset_index(drop=True).get("Поданное число")

In [None]:
# Worst-case contrast over all columns of `e`: for each column, (top - runner_up) / (top + runner_up)
# of its two largest values. Each column is sorted once.
min(
    (ranked[9] - ranked[8]) / (ranked[9] + ranked[8])
    for ranked in (sorted(e[column]) for column in e)
)

In [None]:
# Display the second table returned by test().
out[1]

In [None]:
# Display the first table returned by test().
out[0]

In [None]:

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1, xticks=[], yticks=[])
def animate(i):
    """Draw the field magnitude at total distance ``i``.

    The path is split into three segments of 0.3 each; every segment starts from its own
    source field (``im_to``, ``im_to2``, ``im_to3``) and is propagated by the remaining
    distance with ``onn.propagation`` and ``onn.p_extended``.
    NOTE(review): the source fields are built in a later cell and must exist before this runs.
    """
    ax.clear()
    if i >= 0.6:
        source_field, z = im_to3, i - 0.6
    elif i >= 0.3:
        source_field, z = im_to2, i - 0.3
    else:
        source_field, z = im_to, i
    propagated = onn.propagation(source_field, z, onn.p_extended)
    return ax.imshow(abs(propagated[0][0].cpu().detach().numpy()), cmap='gray')
sin_animation = animation.FuncAnimation(
    fig,
    animate,
    frames=np.linspace(0, 0.9, 120),
    interval=10,
    repeat=True,
)
sin_animation.save('anim_check.gif', writer='imagemagick', fps=20)

In [None]:
# Drop the references to the three intermediate fields used by the animation.
im_to = im_to2 = im_to3 = None

In [None]:
# Build the three source fields used by the animation cell.
# NOTE(review): `image` comes from a cell further down and `onn.p_extended` is not defined by
# the ONN class in this notebook — both must exist before this cell is run.
# 1) Normalise each image by the square root of its summed squares, scale by sqrt(500), and
#    zero-pad it with onn.zero.
im_to = onn.zero(image/(torch.sum(image**2, dim = (1, 2, 3))[:, None, None, None]**0.5)*np.sqrt(500))
# 2) Propagate by 0.3, apply the phase mask, and pad by array_size/2 on every side.
im_to2 = onn.zero_add(onn.propagation(im_to.cuda(), 0.3, onn.p)*onn.DOE())
# 3) Propagate the padded field by another 0.3, crop the window [256:768] in both spatial
#    dimensions, and pad again.
im_to3 = onn.zero_add(onn.propagation(im_to2.cuda(), 0.3, onn.p_extended)[:,:,256:768,256:768])
# Finally pad the first field the same way and move it to the GPU.
im_to = onn.zero_add(im_to).cuda()

In [None]:
# Stand-alone 3x3 discrete-Laplacian convolution (one input channel, one output channel, no
# bias). The kernel is written directly in the (out, in, height, width) layout Conv2d expects.
kernel = torch.tensor(
    [[[[0, 1, 0],
       [1, -4, 1],
       [0, 1, 0]]]], dtype=torch.float)
conv = nn.Conv2d(1, 1, 3, bias=False)
conv.weight = nn.Parameter(kernel)

In [None]:
# Apply the stand-alone Laplacian convolution to the current image batch.
# NOTE(review): `image` is assigned in a later cell (next(it)); run that cell first.
res = conv(image)

In [None]:
# Pad the filtered batch by one pixel on every side, move it to the GPU and run the model;
# `p` is the output field and `t` the read-out.
# NOTE(review): ONN.forward in this notebook applies its own convolution as well — confirm
# that filtering twice is intended.
p, t = onn(nn.ZeroPad2d(1)(res.cuda()))

In [18]:
# Batch iterator over the test loader (init_batch_generator comes from the star-imported helpers).
it = init_batch_generator(test_loader)

In [19]:
# Fetch one test batch and run it through the model in eval mode; `p` is the output field and
# `t` the read-out returned by forward.
# NOTE(review): the call is not wrapped in torch.no_grad(), so autograd state is kept; the
# plotting cells below detach the tensors before converting them to NumPy.
image, lable = next(it)
onn.eval()
p, t = onn(image.cuda())

In [None]:
# Show all ten detector zones at once by summing the masks over the zone dimension.
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(np.squeeze(onn.mask.sum(dim=0).cpu().detach().numpy()), cmap='gray')

In [None]:
def show_image(image: torch.Tensor,
               number_image: torch.Tensor,
               number: int = 0,
               acc: torch.Tensor | None = None,
               showing_grid: bool = False,
               coords: torch.Tensor | None = None,
               image_size: int = 56,
               fig_size: int = 5):
    """Plot an output intensity next to the corresponding input image.

    Args:
        image: Batch of output fields; the squared magnitude of ``image[number]`` is shown on
            the left.
        number_image: Batch of input images; ``number_image[number]`` is shown on the right.
        number: Index of the sample to display.
        acc: Optional per-sample, per-zone values printed (as percentages) next to each zone.
        showing_grid: When True, draw a square outline for each of the ten zones.
        coords: Zone coordinates; required when ``showing_grid`` is True or ``acc`` is given.
        image_size: Twice the side length of the drawn squares.
        fig_size: Height of the figure; its width is twice this value.

    The zone overlays use hard-coded scale and offset constants (100, 482, 476, 598).
    """
    figure = plt.figure(figsize=(fig_size*2, fig_size))
    field_ax = figure.add_subplot(1, 2, 1)
    for zone in range(10):
        if showing_grid:
            outline = plt.Rectangle((coords[zone]*100+482)/2,
                                    image_size/2, image_size/2, fc='#00000000', ec="black")
            field_ax.add_patch(outline)
        if acc is not None:
            label_x = (coords[zone][0]*100+476)/2
            label_y = (coords[zone][1]*100+598)/2
            field_ax.text(label_x, label_y, f"{acc[number][zone].item()*100 :.1f}%", fontdict={ 'color': 'black' })
    intensity = np.squeeze(torch.abs(image[number]).cpu().detach().numpy()**2)
    plt.colorbar(field_ax.imshow(intensity, cmap='gray'))
    input_ax = figure.add_subplot(1, 2, 2)
    source_picture = np.squeeze(number_image[number].cpu().detach().numpy())
    plt.colorbar(input_ax.imshow(source_picture, cmap='gray'))

In [20]:
# Show the output intensity and the input image for sample 3 of the current batch.
# NOTE(review): the NameError recorded below this cell mentions `res`, which this call does
# not use — the saved output appears to come from an earlier version of the cell.
show_image(p, image, 3,  coords= config.coords)

NameError: name 'res' is not defined

In [None]:
# Show the magnitude of the output field for sample 1 (channel 0) of the current batch.
plt.imshow((torch.abs(p[1][0])).cpu().detach().numpy(), cmap='gray')