In [None]:
from netCDF4 import Dataset
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import datetime
from tqdm import tqdm
from glob import glob
import pickle
import random
import os
import fnmatch

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
import torch.nn.functional as F
from torchvision import datasets, transforms
from typing import Tuple, List, Type, Dict, Any
from torch.utils.tensorboard import SummaryWriter

from SGDR import CosineAnnealingWarmRestarts
from mish import Mish
from coord_conv import CoordConv
from MyResidualNetwork import MyResNet, MyBasicBlock
from MyDataPreparation import CustomDataset, Sampler
from autoencoder import Encoder, Decoder

In [None]:
def drawing(tensor1, tensor2, indexes=None):
    """Show original vs. reconstructed U/V fields side by side, one figure per sample.

    For every requested sample index a 2x2 figure is drawn:
      - top row: channels 0 and 1 of ``tensor1`` (titled "Sample U" / "Sample V");
      - bottom row: channels 0 and 1 of ``tensor2`` (titled "Decoded U" / "Decoded V").
    All four panels share one colour scale, so a single colorbar describes them all.

    Args:
        tensor1: tensor indexed as ``[sample, channel, row, col]`` with at least
            two channels (the original samples).
        tensor2: tensor with the same indexing (the decoded samples).
        indexes: iterable of sample indices to plot. ``None`` (default) plots
            every sample along the first dimension of ``tensor1``.
    """
    if indexes is None:
        indexes = range(tensor1.shape[0])

    sources = (tensor1, tensor2)
    titles = (('Sample U', 'Sample V'), ('Decoded U', 'Decoded V'))

    for i in indexes:
        # detach + cpu so tensors that track gradients or live on a GPU can go to numpy.
        panels = [
            [src[i, ch, :, :].detach().cpu().numpy() for ch in (0, 1)]
            for src in sources
        ]

        # One shared colour range across all four panels keeps them comparable.
        vmin = min(panel.min() for row in panels for panel in row)
        vmax = max(panel.max() for row in panels for panel in row)

        # constrained_layout instead of tight_layout: the colorbar below takes
        # space from all four axes, which tight_layout does not handle (it warns
        # and can overlap panels).
        fig, axs = plt.subplots(2, 2, figsize=(8, 8), constrained_layout=True)

        images = []
        for r in range(2):
            for c in range(2):
                images.append(
                    axs[r, c].imshow(panels[r][c], cmap='viridis', aspect='auto', vmin=vmin, vmax=vmax)
                )
                axs[r, c].set_title(titles[r][c])

        cbar = fig.colorbar(images[0], ax=axs, orientation='vertical', fraction=0.02, pad=0.04)
        cbar.ax.set_ylabel('wind speed')

        plt.show()
        # Release the figure explicitly; plotting a whole batch would otherwise
        # keep every figure open (matplotlib warns after 20) along with its memory.
        plt.close(fig)

        print('--------------------------------------------------------------------------------------------------')

In [None]:
# Identifier of the training run whose saved encoder/decoder are loaded below.
run_name = "wind_pre_autoencoder_run004"

In [None]:
# GPU this notebook is meant to run on.
device = torch.device('cuda:1')
# Make it the *current* CUDA device. Later cells use `.cuda()` and
# `.to(device='cuda')`, which both target the current device. Without this they
# would silently run on cuda:0 instead of the device chosen here.
if torch.cuda.is_available():
    torch.cuda.set_device(device)

In [None]:
# Load both halves of the trained autoencoder for `run_name`. Each .pth holds a
# whole pickled module (they are later used via .eval()/.cuda()/.forward()), so
# the Encoder/Decoder classes imported at the top must be importable. Tensors
# are mapped to the CPU first.
# NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which refuses
# pickled nn.Module objects. On such versions pass weights_only=False, and only
# for trusted files, because unpickling can execute arbitrary code.
load_device = torch.device('cpu')
encoder = torch.load(f'/app/Kara_plume_movement/wind/models/model_{run_name}_encoder.pth', map_location=load_device)
decoder = torch.load(f'/app/Kara_plume_movement/wind/models/model_{run_name}_decoder.pth', map_location=load_device)

In [None]:
# Put both networks into inference mode (changes the behaviour of layers such as
# dropout / batch-norm, if the models contain any).
for net in (encoder, decoder):
    net.eval()

In [None]:
# Move both networks to the current CUDA device (`.cuda()` with no argument).
encoder, decoder = encoder.cuda(), decoder.cuda()

In [None]:
from wind_pre_autoencoder_004 import validate_single_epoch, find_files

In [None]:
# Every *.pkl file found under the wind data directory, in sorted order.
wind_files_pkl = sorted(
    find_files('/mnt/hippocamp/asavin/data/wind/wind_arrays_kara_norm_n80_s70_w55_e105', '*.pkl')
)

In [None]:
n_files = 30  # forwarded to CustomDataset as its `n_files` argument
dataset = CustomDataset(wind_files_pkl, n_files=n_files)

In [None]:
batch_size: int = 32  # samples per DataLoader batch

In [None]:
# make_new_data() presumably (re)builds dataset.wind_array, which is read right
# after it. TODO confirm against MyDataPreparation.
dataset.make_new_data()

# Shuffled sampler over every row index of the current wind array.
all_indices = list(range(dataset.wind_array.shape[0]))
sampler = Sampler(all_indices, shuffle=True)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=8,
)

In [None]:
# Shape of the array currently held by the dataset (its first dimension sizes the sampler).
dataset.wind_array.shape

In [None]:
# Pull one batch from the loader to sanity-check the pipeline, and display its shape.
data = next(iter(dataloader))
data.shape

In [None]:
# Reconstruction error: mean squared error between input and decoded output.
loss_function = nn.MSELoss()

In [None]:
# Reconstruct one batch and measure its MSE. This is inference only, so no_grad
# avoids building an autograd graph (and holding its memory) for tensors that
# are only inspected and plotted.
wind_gpu = data.to(device='cuda', dtype=torch.float)
with torch.no_grad():
    # Call the modules rather than .forward() so any registered hooks run.
    encoded_data = encoder(wind_gpu)
    decoded_data = decoder(encoded_data)
    loss = loss_function(wind_gpu, decoded_data)
# MSELoss averages over elements. Scale by the number of samples actually in
# this batch, which can be smaller than batch_size for a short batch.
test_loss = loss.detach() * wind_gpu.shape[0]
dataset.clear_cache()

In [None]:
# Latent representation from the encoder: its shape and tensor type string.
encoded_data.shape, encoded_data.type()

In [None]:
# MSE of the inspected batch, scaled by its batch size.
test_loss

In [None]:
# Compare every original sample in the batch with its reconstruction.
drawing(wind_gpu, decoded_data, list(range(len(wind_gpu))))