In [1]:
from netCDF4 import Dataset
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import datetime
from tqdm import tqdm
from glob import glob
import pickle
import random
import os
import fnmatch

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
import torch.nn.functional as F
from torchvision import datasets, transforms
from typing import Tuple, List, Type, Dict, Any
from torch.utils.tensorboard import SummaryWriter

from SGDR import CosineAnnealingWarmRestarts
from mish import Mish
from coord_conv import CoordConv
from MyResidualNetwork import MyResNet, MyBasicBlock
from MyDataPreparationLSTM import CustomDataset
from autoencoder import Encoder, Decoder

In [None]:
def drawing(tensor1, tensor2, indexes):
    for i in indexes:
        # Извлечение частей тензоров
        extracted_tensor1_0 = tensor1[i, 0, :, :]
        extracted_tensor1_1 = tensor1[i, 1, :, :]
        extracted_tensor2_0 = tensor2[i, 0, :, :]
        extracted_tensor2_1 = tensor2[i, 1, :, :]

        # Конвертация в numpy
        array1_0 = extracted_tensor1_0.detach().cpu().numpy()
        array1_1 = extracted_tensor1_1.detach().cpu().numpy()
        array2_0 = extracted_tensor2_0.detach().cpu().numpy()
        array2_1 = extracted_tensor2_1.detach().cpu().numpy()

        # Получение общих минимумов/максимумов для нормализации отображения
        vmin = min(array1_0.min(), array1_1.min(), array2_0.min(), array2_1.min())
        vmax = max(array1_0.max(), array1_1.max(), array2_0.max(), array2_1.max())

        # Построение графиков
        fig, axs = plt.subplots(2, 2, figsize=(12, 8))
        
        # Tensor 1
        axs[0, 0].imshow(array1_0, cmap='viridis', aspect='auto', vmin=vmin, vmax=vmax)
        axs[0, 0].set_title('Tensor1 U')

        axs[1, 0].imshow(array1_1, cmap='viridis', aspect='auto', vmin=vmin, vmax=vmax)
        axs[1, 0].set_title('Tensor1 V')

        # Tensor 2
        axs[0, 1].imshow(array2_0, cmap='viridis', aspect='auto', vmin=vmin, vmax=vmax)
        axs[0, 1].set_title('Tensor2 U')

        axs[1, 1].imshow(array2_1, cmap='viridis', aspect='auto', vmin=vmin, vmax=vmax)
        axs[1, 1].set_title('Tensor2 V')

        # Добавление единого цветового бара
        cbar = fig.colorbar(axs[0, 0].images[0], ax=axs, orientation='vertical', fraction=0.02, pad=0.04)
        cbar.ax.set_ylabel('wind speed')

        plt.tight_layout()
        plt.show()
        
        print('--------------------------------------------------------------------------------------------------')

In [3]:
# Experiment whose LSTM + MLP decoder are evaluated, plus the file stems of the
# pretrained autoencoder encoder/decoder used around it.
run_name, pretrained_encoder_name, pretrained_decoder_name = (
    'wind_pre_lstm_run002',
    'model_wind_pre_autoencoder_run002_encoder',
    'model_wind_pre_autoencoder_run002_decoder',
)

In [None]:
# Directory holding the trained networks, saved as whole pickled nn.Module objects.
MODELS_DIR = '/app/Kara_plume_movement/wind/models'


def _load_model(file_stem):
    """Load ``{MODELS_DIR}/{file_stem}.pth`` onto the CPU and return it.

    The files store complete modules rather than state_dicts, so
    ``weights_only=False`` is required (PyTorch >= 2.6 defaults to
    ``weights_only=True`` and would refuse them). SECURITY: this unpickles
    arbitrary objects — only load files you trust.
    """
    return torch.load(os.path.join(MODELS_DIR, f'{file_stem}.pth'),
                      map_location=torch.device('cpu'), weights_only=False)


pretrained_encoder = _load_model(pretrained_encoder_name)
lstm = _load_model(f'model_{run_name}_lstm_network')
decoder = _load_model(f'model_{run_name}_MLPdecoder')
pretrained_decoder = _load_model(pretrained_decoder_name)

In [5]:
# Switch every network to evaluation mode before running inference.
for network in (pretrained_encoder, lstm, decoder, pretrained_decoder):
    network.eval()
del network

In [6]:
# Move all four networks to the default CUDA device (a GPU is required).
pretrained_encoder, lstm, decoder, pretrained_decoder = (
    network.cuda() for network in (pretrained_encoder, lstm, decoder, pretrained_decoder)
)

In [7]:
from wind_pre_lstm_001 import find_files

In [8]:
# All *.pkl wind arrays under the data directory, in sorted path order.
wind_files_pkl = sorted(
    find_files('/mnt/hippocamp/asavin/data/wind/wind_arrays_kara_norm', '*.pkl')
)

In [9]:
dataset = CustomDataset(
    wind_files_pkl,
    num_days=14,   # NOTE(review): presumably the input sequence length in days — confirm in CustomDataset
    num_years=4,   # NOTE(review): presumably how many years select_random_years() picks — confirm
)

In [10]:
# Number of samples per DataLoader batch.
batch_size: int = 32

In [11]:
# Seed every RNG first: the DataLoader shuffles, and select_random_years()
# presumably samples years at random, so without a seed each run evaluates a
# different batch and results are not reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset.select_random_years()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        generator=torch.Generator().manual_seed(SEED))

In [None]:
# Pull one batch; `next_day` is presumably the target that follows the input sequence.
batch_iterator = iter(dataloader)
data, next_day = next(batch_iterator)
del batch_iterator
(data.shape, next_day.shape)

In [13]:
# Mean squared error, averaged over all elements.
loss_function = nn.MSELoss(reduction='mean')

In [14]:
# Copy inputs and targets to the GPU as float32.
data_gpu, next_day_gpu = (
    batch_tensor.to(device='cuda', dtype=torch.float) for batch_tensor in (data, next_day)
)

In [None]:
# Pure inference: no_grad avoids building an autograd graph through four networks.
with torch.no_grad():
    # Encode each time step separately with the pretrained encoder; data_gpu is
    # indexed as [:, t, :, :, :] (assumed (batch, time, channel, H, W) — TODO confirm).
    encoded_data_list = [
        pretrained_encoder(data_gpu[:, t, :, :, :]) for t in range(data_gpu.shape[1])
    ]
    wind_vector = torch.stack(encoded_data_list, dim=1)

    # The LSTM returns (output, (h, c)) (assuming an nn.LSTM-style interface);
    # only the output is used, flattened per sample and mapped by the MLP
    # decoder to the predicted latent vector.
    encoded_data, (_, _) = lstm(wind_vector)
    result_lstm = encoded_data.reshape(encoded_data.shape[0], -1)
    decoded_data = decoder(result_lstm)

    # Compare in latent space: encode the true target with the same encoder.
    encoded_target = pretrained_encoder(next_day_gpu)
    loss = loss_function(decoded_data, encoded_target)

# Mean MSE over this batch. Do not rescale by batch_size / len(dataset): that is
# one batch's share of an epoch-wide average and is only meaningful when summed
# over every batch of the dataset.
test_loss = loss.item()
print(test_loss)

In [None]:
# Decode the predicted latent vector with the pretrained autoencoder decoder.
# Call the module (not .forward) so hooks run; no_grad since this is inference only.
with torch.no_grad():
    final_data = pretrained_decoder(decoded_data)
final_data.shape

In [None]:
# Autoencoder reconstruction of the true target (encode -> decode), for comparison.
# Call the module (not .forward) so hooks run; no_grad since this is inference only.
with torch.no_grad():
    decoded_target = pretrained_decoder(encoded_target)
decoded_target.shape

In [None]:
# For every sample in the batch, compare the decoded prediction, the decoded
# (reconstructed) target and the raw target.
sample_indexes = list(range(next_day_gpu.shape[0]))
drawing(final_data, decoded_target, next_day_gpu, sample_indexes)