In [None]:
import torch

from tqdm import tqdm

# Run on the first GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

In [None]:
from pde.models.fuxi_v2 import FuxiV2VRWKV
from pde.train_utils.train_v3 import train
from pde.train_utils.dataloader_openstl import load_data
from pde.train_utils.losses import calculate_pde_and_continuity_loss, weighted_mae_torch

In [3]:
# Dataset location and forecast window.
data_root = "/home/fa.buzaev/data_5/"
last_hour_in = 13   # input frames use indices 1 .. last_hour_in - 1
last_hour_out = 85  # target frames use indices last_hour_in .. last_hour_out - 1

# Experiment directory; the best checkpoint is saved beneath it.
exp_root = 'exp3_default_params_vrwkv'
model_root = "/".join(["/home/asutkin/kursach", exp_root, "best_models", "model_with_fno.pt"])

In [None]:
# Build train / validation / test loaders for the 5.625-degree 'uv10' data.
# `mean` and `std` are the normalisation statistics returned by load_data;
# the plotting cell passes them to `denorm` to undo normalisation.
dataloader_train, dataloader_vali, dataloader_test, mean, std = load_data(
    batch_size=12,
    val_batch_size=12,
    data_root=data_root,
    num_workers=6,
    # Alternative configurations tried previously:
    #   data_split='2_8125', data_name='mv3'
    #   data_name='mv6'
    #   train_time=['1979', '2015']
    data_split='5_625',
    data_name='uv10',
    train_time=['2010', '2011'],
    val_time=['2016', '2016'],
    test_time=['2017', '2018'],
    # Input window and forecast targets, as frame indices within each sample.
    idx_in=list(range(1, last_hour_in)),
    idx_out=list(range(last_hour_in, last_hour_out)),
    step=1,
    level=1,
    distributed=False,
    use_augment=False,
    use_prefetcher=False,
    drop_last=False,
)

In [None]:
from tqdm import tqdm
import random
import json

# Autoregressive training of FuxiV2VRWKV.
# For each batch the model is rolled forward `n_steps` times: each prediction is
# appended to the input window (oldest frame dropped) and the per-step L1 losses
# are summed before a single backward pass. The checkpoint with the lowest
# validation loss is written to `model_root`; one JSON line per epoch goes to logs.txt.
import os

n_epochs = 5
# One rollout step per target frame (equals len(idx_out) passed to load_data).
n_steps = last_hour_out - last_hour_in

criterion = torch.nn.L1Loss()
model = FuxiV2VRWKV(img_size=(2, 32, 64), patch_size=(2, 4, 4), in_chans=12, out_chans=2, embed_dim=192, num_groups=16, num_heads=8, window_size=7, depth=12).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
best_loss = float('inf')

# torch.save and open() below fail if the experiment directories do not exist yet.
os.makedirs(os.path.dirname(model_root), exist_ok=True)

with open(f"/home/asutkin/kursach/{exp_root}/logs.txt", 'a') as file:
    for i in range(n_epochs):
        model.train()
        train_loss = 0

        for x_train, y_train in tqdm(dataloader_train, desc='Training'):
            x_train, y_train = x_train.to(device), y_train.to(device)
            optimizer.zero_grad()
            big_loss = 0
            # Lead time j = 1..n_steps is scored against target frame j - 1.
            # The first prediction (made from the raw input window) is compared with
            # the first output frame, and every target frame is used exactly once.
            for j in range(1, n_steps + 1):
                time_tensor = j * torch.ones(x_train.shape[0], device=x_train.device).unsqueeze(-1)
                prediction = model(x_train, time_tensor)
                big_loss += criterion(prediction, y_train[:, j - 1, :, :, :])
                # Slide the window: drop the oldest frame, append the new prediction.
                x_train = torch.cat((x_train[:, 1:, :, :, :], prediction.unsqueeze(1)), dim=1)
            big_loss.backward()
            optimizer.step()
            # Mean per-step loss: big_loss is a sum of exactly n_steps terms.
            train_loss += big_loss.item() / n_steps
        train_loss /= len(dataloader_train)

        model.eval()
        val_loss = 0
        # Validation needs no gradients. Without no_grad the autograd graph of the
        # whole rollout would be built (and held) for every batch.
        with torch.no_grad():
            for x_val, y_val in tqdm(dataloader_vali, desc='Validating'):
                x_val, y_val = x_val.to(device), y_val.to(device)
                big_loss_val = 0
                for j in range(1, n_steps + 1):
                    time_tensor = j * torch.ones(x_val.shape[0], device=x_val.device).unsqueeze(-1)
                    prediction = model(x_val, time_tensor)
                    big_loss_val += criterion(prediction, y_val[:, j - 1, :, :, :])
                    x_val = torch.cat((x_val[:, 1:, :, :, :], prediction.unsqueeze(1)), dim=1)
                val_loss += big_loss_val.item() / n_steps
        val_loss /= len(dataloader_vali)

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': best_loss,
            }, model_root)

        print(f"Epoch: {i + 1}; Train_loss: {train_loss}; Vall_loss: {val_loss}")
        log = json.dumps({'Epoch': i+1, 'Train_loss': train_loss, 'Val_loss': val_loss})
        file.write(f"{log}\n")
        # Persist each epoch's line immediately so an interrupted run keeps its log.
        file.flush()

In [None]:
from tqdm import tqdm

# Reload the best checkpoint and recompute the mean per-step L1 loss on the
# validation loader, using the same autoregressive rollout as in training.
n_steps = last_hour_out - last_hour_in

model = FuxiV2VRWKV(img_size=(2, 32, 64), patch_size=(2, 4, 4), in_chans=12, out_chans=2, embed_dim=192, num_groups=16, num_heads=8, window_size=7, depth=12).to(device)
criterion = torch.nn.L1Loss()
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
checkpoint = torch.load(model_root, map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()
running_loss = 0
with torch.no_grad():
    for x_test, y_test in tqdm(dataloader_vali):
        x_test, y_test = x_test.to(device), y_test.to(device)
        big_loss = 0
        # Lead time i = 1..n_steps is scored against target frame i - 1.
        for i in range(1, n_steps + 1):
            time_tensor = i * torch.ones(x_test.shape[0], device=x_test.device).unsqueeze(-1)
            prediction = model(x_test, time_tensor)
            big_loss += criterion(prediction, y_test[:, i - 1, :, :, :])
            # Feed the prediction back into the window (the result must be assigned,
            # otherwise every step predicts from the original input).
            x_test = torch.cat((x_test[:, 1:, :, :, :], prediction.unsqueeze(1)), dim=1)
        running_loss += big_loss.item() / n_steps
# Average over the loader that was actually iterated.
running_loss / len(dataloader_vali)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output
def denorm(item, std, mean, idx=0):
    """Undo per-channel normalisation for channel ``idx``.

    ``std`` and ``mean`` are squeezed and indexed at ``idx``; the result is
    ``item * std[idx] + mean[idx]``.
    """
    scale = std.squeeze()[idx]
    shift = mean.squeeze()[idx]
    return item * scale + shift

# Roll the model forward on the first test batch and save one comparison figure
# per lead time: prediction | ground truth | absolute error, for channel 0
# (titled "U wind") of one sample.
import os

sample_idx = 7  # sample within the batch to visualise (batch must hold > 7 samples)
n_steps = last_hour_out - last_hour_in
pred_dir = f"/home/asutkin/kursach/{exp_root}/predictions"
os.makedirs(pred_dir, exist_ok=True)

# next(iter(...)) raises on an empty loader instead of silently reusing a stale
# x_test / y_test left over from an earlier cell.
x_test, y_test = next(iter(dataloader_test))
x_test, y_test = x_test.to(device), y_test.to(device)

model.eval()
x_data = x_test

for t in range(n_steps):
    # Lead time is 1-based; target frame t is the (t + 1)-th step ahead.
    time_tensor = (t + 1) * torch.ones(x_test.shape[0], device=x_test.device).unsqueeze(-1)
    with torch.no_grad():
        prediction = model(x_data, time_tensor)
    x_data = torch.cat((x_data[:,1:,:,:,:], prediction.unsqueeze(1)), dim=1)

    pred_np = denorm(prediction[sample_idx, 0], std, mean).squeeze().detach().cpu().numpy()
    true_np = denorm(y_test[sample_idx, t, 0], std, mean).squeeze().detach().cpu().numpy()

    fig = plt.figure(constrained_layout=True, figsize=(32, 6))

    plt.subplot(131)
    plt.imshow(pred_np)
    plt.title(f"Prediction U wind by Andrew, step={t}")

    plt.subplot(132)
    plt.imshow(true_np)
    plt.title(f"True answer, step={t}")
    plt.colorbar(boundaries=np.linspace(-20, 20, 20))

    plt.subplot(133)
    plt.imshow(np.abs(pred_np - true_np))
    plt.title(f"Absolute difference, step={t}")
    plt.colorbar(boundaries=np.linspace(0, 20, 20))

    plt.savefig(f"{pred_dir}/imvp_{t}.png")
    # Show only the current step as a live preview, then release the figure.
    # Otherwise all n_steps large figures stay open and are rendered at cell end.
    clear_output(wait=True)
    plt.show()
    plt.close(fig)

In [None]:
import imageio

# Stitch the 72 per-step frames written by the plotting cell into one animated GIF.
images = [
    imageio.imread(f"/home/asutkin/kursach/{exp_root}/predictions/imvp_{frame}.png")
    for frame in range(0, 72)
]

imageio.mimsave(f"/home/asutkin/kursach/{exp_root}/predictions/wind_FNO.gif", images)