In [20]:
import os
from datetime import date
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import xarray as xr
from torch.optim import Adam
from torch.utils.data import DataLoader

from ConvLSTM import Seq2Seq


# Number of input frames per training sequence.
# NOTE(review): a later cell reassigns SEQ_LEN = 2, so the value below is shadowed
# for everything after that cell; keep a single definition once the value is settled.
SEQ_LEN = 30
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Directory containing the processed grid files. Set the SLA_GRID_DIR environment
# variable to run the notebook on another machine; the author's path is the fallback.
BASEPATH = Path(os.environ.get(
    "SLA_GRID_DIR",
    r"C:\Users\Casper\OneDrive - Danmarks Tekniske Universitet\SKOLE\Kandidat\Syntese\ProcessedGrids",
))

In [2]:
# Pull the SLA grids and their time/lat/lon coordinates out of the processed file.
# The dataset handle is closed when the `with` block exits, so all four arrays are
# extracted inside it.
with xr.open_dataset(BASEPATH / "without_polar_v5_mss21.nc") as ds:
    sla, times, lat, lon = (
        ds[var_name].data
        for var_name in ('sla', 'time', 'Latitude', 'Longitude')
    )

In [13]:
# Split boundaries as 0-d datetime64[ns] arrays so they compare directly with `times`.
# NOTE(review): TEST_END is the end of the *validation* period; test data is after it.
TRAIN_END, TEST_END = (
    np.array(boundary, dtype="datetime64[ns]")
    for boundary in ("2014-01-01", "2019-01-01")
)

In [16]:
# Chronological split on the time axis:
#   train      : times <= TRAIN_END
#   validation : TRAIN_END < times <= TEST_END
#   test       : times > TEST_END
train_features = sla[times <= TRAIN_END]
validation_features = sla[(TRAIN_END < times) & (times <= TEST_END)]
test_features = sla[TEST_END < times]

In [75]:
# Map a list of windows, each (seq_len + 1, height, width), to
#   inputs: (batch_size, num_channels=1, seq_len, height, width)
#   target: (batch_size, num_channels=1, height, width)  -- last frame of each window
def collate(batch, device=None, dtype=torch.float32):
    """Collate sliding windows into model inputs and next-frame targets.

    Parameters
    ----------
    batch : sequence of array-like
        Samples as yielded by the DataLoader; each is assumed to be
        (seq_len + 1, height, width). An already-stacked ndarray also works.
    device : torch.device or str, optional
        Destination device. Defaults to the module-level ``DEVICE``.
    dtype : torch.dtype, optional
        Dtype of the returned tensors. Defaults to float32, the default dtype of
        ``nn.Module`` parameters.

    Returns
    -------
    tuple of torch.Tensor
        ``(inputs, target)``: all frames but the last, and the last frame, each
        with a singleton channel axis at dim 1. Values are not rescaled.
    """
    if device is None:
        device = DEVICE
    # One np.stack + a single tensor conversion is much faster than torch.tensor on a
    # list of ndarrays. The cast avoids feeding float64 (the np.empty default used
    # when building the windows) into float32 model weights.
    stacked = torch.as_tensor(np.stack(batch), dtype=dtype)
    # Add the channel axis: (batch, 1, seq_len + 1, H, W).
    windows = stacked.unsqueeze(1).to(device)
    return windows[:, :, :-1], windows[:, :, -1]

In [84]:
# Experiment overrides.
# NOTE(review): this shadows SEQ_LEN = 30 from the config cell; every later cell
# sees SEQ_LEN == 2.
SEQ_LEN = 2
BATCH_SIZE = 5
BATH_SIZE = BATCH_SIZE  # misspelled alias kept so cells that still use it keep working

In [82]:
def make_sliding_windows(frames, window):
    """Stack every run of ``window`` consecutive frames along a new axis 1.

    Parameters
    ----------
    frames : array-like, shape (num_frames, ...)
        Frames ordered along the first axis.
    window : int
        Number of consecutive frames per window (must be >= 1).

    Returns
    -------
    np.ndarray of float64, shape (max(num_frames - window + 1, 0), window, ...)
        ``out[i] == frames[i:i + window]``. Empty when there are fewer frames
        than ``window``.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    frames = np.asarray(frames)
    # N frames hold N - window + 1 full windows; the last one ends on frame N - 1.
    num_windows = max(frames.shape[0] - window + 1, 0)
    windows = np.empty((num_windows, window) + frames.shape[1:], dtype=np.float64)
    for start in range(num_windows):
        windows[start] = frames[start:start + window]
    return windows


# NOTE(review): this cell previously read from an undefined `tf` (left over from a
# deleted cell); given the output name it is assumed to be `train_features`.
num_frames = train_features.shape[0]
seq_len_result = SEQ_LEN + 1  # SEQ_LEN input frames + 1 target frame per window
reshaped_train_features = make_sliding_windows(train_features, seq_len_result)

In [92]:
# Training data loader over the windowed array; `collate` turns each batch of
# windows into an (inputs, target) pair.
train_loader = DataLoader(
    reshaped_train_features,
    batch_size=BATH_SIZE,
    shuffle=False,
    collate_fn=collate,
)
# Pull one batch to inspect its shapes.
input_features, result = next(iter(train_loader))

In [99]:
def print_shape(title, tensor, names):
    """Print ``title`` followed by each axis length of ``tensor`` next to its label.

    Raises
    ------
    ValueError
        If the number of labels does not match the tensor's rank (``zip`` would
        otherwise silently drop axes or labels).
    """
    if len(names) != tensor.ndim:
        raise ValueError(f"{len(names)} labels given for a {tensor.ndim}-d tensor")
    print(title)
    for name, size in zip(names, tensor.shape):
        print(f"\t{name} : {size}")


print_shape("Shape of input grid:", input_features,
            ("batch_size", "num_channels", "seq_len", "height", "width"))
# The target is a single frame (collate indexes batch[:, :, -1]), so its axis 1 is the
# channel axis, not a sequence axis.
print_shape("Shape of output grid:", result,
            ("batch_size", "num_channels", "height", "width"))

Shape of input grid:
	batch_size : 5
	num_channels : 1
	seq_len : 2
	height : 129
	width : 360
Shape of output grid:
	batch_size : 5
	seq_len : 1
	height : 129
	width : 360
