In [None]:
import numpy as np
import pickle
from torch.utils.data import DataLoader, Sampler
import random
import os
import fnmatch
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from MLP import MultiLayerPerceptron

In [None]:
from MyDataPreparationLSTM import CustomDataset

In [None]:
def find_files(directory, pattern, maxdepth=None):
    """Recursively collect files under ``directory`` whose basename matches ``pattern``.

    Parameters
    ----------
    directory : str
        Root directory to walk. A trailing path separator is accepted and does
        not affect the depth computation.
    pattern : str
        Shell-style wildcard matched against each file's basename with
        ``fnmatch.fnmatch`` (e.g. ``'*.pkl'``).
    maxdepth : int or None, optional
        Maximum depth of returned files, counted like ``find -maxdepth``:
        files directly inside ``directory`` are at depth 1, files in an
        immediate subdirectory at depth 2, and so on. ``None`` (default)
        means unlimited; ``0`` returns nothing.

    Returns
    -------
    list of str
        Matching paths built with ``os.path.join(root, basename)``, in
        ``os.walk`` order (callers needing a stable order should sort).
    """
    flist = []
    for root, dirs, files in os.walk(directory):
        # Depth of `root` relative to `directory` (0 for `directory` itself).
        # relpath normalises away trailing separators, which plain separator
        # counting did not (it was off by one for 'dir/').
        rel_root = os.path.relpath(root, directory)
        root_depth = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
        file_depth = root_depth + 1

        if maxdepth is not None:
            if file_depth > maxdepth:
                # Nothing in or below this directory can be within maxdepth.
                dirs[:] = []
                continue
            if file_depth == maxdepth:
                # Files here are allowed, but anything in sub-directories is
                # too deep: stop os.walk from descending (in-place pruning).
                dirs[:] = []

        for basename in files:
            if fnmatch.fnmatch(basename, pattern):
                filename = os.path.join(root, basename)
                # Collapse literal double backslashes into the native separator.
                filename = filename.replace('\\\\', os.sep)
                flist.append(filename)
    return flist

In [None]:
# Input wind arrays (*.pkl). Sorting makes the file order, and therefore the
# dataset indexing, deterministic regardless of the file system's listing order.
WIND_DATA_DIR = '/mnt/hippocamp/asavin/data/wind/wind_arrays_kara_norm'
wind_files_pkl = sorted(find_files(WIND_DATA_DIR, '*.pkl'))
# Fail fast with a clear message rather than building an empty dataset downstream.
if not wind_files_pkl:
    raise FileNotFoundError(f'no *.pkl files found under {WIND_DATA_DIR}')

In [None]:
# Seed every RNG the data pipeline may draw from, so that the selected years and
# the shuffled batch order are reproducible between notebook runs.
# NOTE(review): assumes CustomDataset.select_random_years draws from `random`,
# `numpy.random` or torch's default generator; confirm in MyDataPreparationLSTM.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also drives DataLoader's shuffle order

dataset = CustomDataset(wind_files_pkl, num_days=14, num_years=1)
dataset.select_random_years()

dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [None]:
pretrained_autoencoder_name = 'wind_pre_autoencoder_run002'
lstm2D_name = 'wind_pre_lstm2D_run003'

MODELS_DIR = '/app/Kara_plume_movement/wind/models'


def load_model(run_name, part):
    """Load the pickled model file ``model_{run_name}_{part}.pth`` onto the CPU.

    These checkpoints are complete ``nn.Module`` objects (the notebook calls
    ``.cuda()``, ``.eval()`` and ``.forward()`` on them), not state dicts. Loading them
    therefore needs ``weights_only=False``, because PyTorch >= 2.6 defaults to
    ``weights_only=True`` and refuses to unpickle whole modules. The keyword
    itself requires PyTorch >= 1.13.

    Security: this unpickles arbitrary objects, so only load trusted checkpoints.
    """
    path = os.path.join(MODELS_DIR, f'model_{run_name}_{part}.pth')
    return torch.load(path, map_location=torch.device('cpu'), weights_only=False)


pretrained_encoder = load_model(pretrained_autoencoder_name, 'encoder')
lstm_network = load_model(lstm2D_name, 'lstm_network')
lstm_decoder = load_model(lstm2D_name, 'MLPdecoder')
pretrained_decoder = load_model(pretrained_autoencoder_name, 'decoder')

In [None]:
# Move all four networks to the default CUDA device. Module.cuda() moves the
# parameters in place and returns the same module, so rebinding is harmless.
pretrained_encoder, lstm_network, lstm_decoder, pretrained_decoder = (
    model.cuda()
    for model in (pretrained_encoder, lstm_network, lstm_decoder, pretrained_decoder)
)

In [None]:
# Pull one (shuffled) batch for qualitative inspection.
# NOTE(review): `target` is not used by the evaluation below, which compares
# against the autoencoder's reconstruction of `data` instead. Confirm this is intended.
batch_iterator = iter(dataloader)
data, target = next(batch_iterator)

In [None]:
# Put every network in inference mode (affects layers such as dropout and
# batch norm, if the models contain any).
for network in (pretrained_encoder, lstm_network, lstm_decoder, pretrained_decoder):
    network.eval()

In [None]:
# Element-wise mean squared error (default reduction='mean').
loss_function = nn.MSELoss()

In [None]:
with torch.no_grad():
    # Send the batch to whatever device the models live on, instead of hard-coding 'cuda'.
    model_device = next(pretrained_encoder.parameters()).device
    data_gpu = data.to(device=model_device, dtype=torch.float)

    # Encode every step along dim 1 independently with the pretrained encoder.
    # NOTE(review): assumes data is (batch, time, channel, H, W); dim 1 is treated
    # as the sequence axis here, and the drawing cell indexes [sample, step, channel, :, :].
    data_list = data_gpu.unbind(dim=1)
    encoded_data_list = [pretrained_encoder(t) for t in data_list]
    # Reference sequence: the autoencoder's own reconstruction of each input step.
    decoded_target_list = [pretrained_decoder(t) for t in encoded_data_list]

    wind_vector = torch.stack(encoded_data_list, dim=1)

    # Run the LSTM over the latent sequence. Map each output step back through
    # the MLP decoder and then the pretrained decoder to the input grid.
    # NOTE(review): output step t is compared with the reconstruction of input step t.
    # If the LSTM is trained to forecast step t+1, this comparison is shifted; confirm.
    encoded_data, (_, _) = lstm_network(wind_vector)
    split_tensors = encoded_data.unbind(dim=1)
    processed_tensors = [pretrained_decoder(lstm_decoder(t)) for t in split_tensors]
    decoded_data = torch.stack(processed_tensors, dim=1)

    decoded_target = torch.stack(decoded_target_list, dim=1)

    # MSELoss expects (input, target): prediction first, reference second.
    loss = loss_function(decoded_data, decoded_target)
    # Scale the mean loss by the number of samples actually in this batch, which
    # can be smaller than dataloader.batch_size (e.g. a partial batch).
    test_loss = loss * data_gpu.size(0)

In [None]:
# Display the evaluation-batch loss: mean MSE multiplied by the batch size.
test_loss

In [None]:
# Prediction and reference must align element for element. Otherwise MSELoss
# only warns and broadcasts, and the drawing cell would pair mismatched frames.
assert decoded_target.shape == decoded_data.shape, (decoded_target.shape, decoded_data.shape)
decoded_target.shape, decoded_data.shape

In [None]:
def drawing(tensor1, tensor2, indexes, labels=('Tensor1', 'Tensor2')):
    """Plot channels 0 (U) and 1 (V) of two tensors side by side for selected samples.

    For every sample index ``i`` in ``indexes`` and every step ``j`` along
    dim 1 of ``tensor1``, one 2x2 figure is shown:
    - Rows: channel 0 (titled U) on top, channel 1 (titled V) at the bottom.
    - Columns: ``tensor1`` on the left, ``tensor2`` on the right.
    All four panels share one colour scale so they are directly comparable.
    A dashed separator line is printed after each sample.

    Parameters
    ----------
    tensor1, tensor2 : torch.Tensor
        Indexed as ``[sample, step, channel, row, col]`` with at least two
        channels. Assumed to share the same shape (TODO: confirm at call sites).
    indexes : iterable of int
        Sample indices (dim 0) to plot.
    labels : tuple of str, optional
        Column titles for ``tensor1`` and ``tensor2``. The defaults reproduce
        the 'Tensor1' / 'Tensor2' titles.
    """
    label1, label2 = labels
    for i in indexes:
        # Transfer each selected sample to host memory once, not once per panel.
        sample1 = tensor1[i].detach().cpu().numpy()
        sample2 = tensor2[i].detach().cpu().numpy()
        for j in range(tensor1.shape[1]):
            panels = {
                (0, 0): (sample1[j, 0], f'{label1} U'),
                (1, 0): (sample1[j, 1], f'{label1} V'),
                (0, 1): (sample2[j, 0], f'{label2} U'),
                (1, 1): (sample2[j, 1], f'{label2} V'),
            }
            # Common min/max so every panel uses the same colour normalisation.
            vmin = min(array.min() for array, _ in panels.values())
            vmax = max(array.max() for array, _ in panels.values())

            # constrained_layout supports a colorbar spanning several axes.
            # tight_layout does not, and only emits a warning in that case.
            fig, axs = plt.subplots(2, 2, figsize=(12, 8), constrained_layout=True)
            for (row, col), (array, title) in panels.items():
                axs[row, col].imshow(array, cmap='viridis', aspect='auto', vmin=vmin, vmax=vmax)
                axs[row, col].set_title(title)
            fig.suptitle(f'sample {i}, step {j}')

            # Single colour bar shared by all four panels.
            cbar = fig.colorbar(axs[0, 0].images[0], ax=axs, orientation='vertical', fraction=0.02, pad=0.04)
            cbar.ax.set_ylabel('wind speed')

            plt.show()
            # Release the figure so the nested loop does not accumulate open figures.
            plt.close(fig)

        print('--------------------------------------------------------------------------------------------------')

In [None]:
# Compare LSTM predictions (left column) with autoencoder reconstructions
# (right column) for the first two samples of the batch.
drawing(decoded_data, decoded_target, range(2))