In [None]:
#前提ライブラリ
import random
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import os
import pandas as pd
from dataclasses import dataclass

#torch関連
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from torch import nn
from transformers import CLIPVisionConfig

#作成したモデルのインポート
from models.Denoisingmodel import DenoisingModel
from models.Diffusionmodel import Diffuser

#その他
import math
import datetime
from tqdm import tqdm

In [None]:
'''
条件設定 (experiment configuration)
シード固定 (fix random seeds for reproducibility)
'''


def torch_fix_seed(seed=42, warn_only=True):
    """Seed every RNG used in this notebook and request deterministic kernels.

    Args:
        seed: Seed applied to Python's ``random``, NumPy's legacy global RNG and
            PyTorch's CPU and CUDA generators.
        warn_only: Forwarded to ``torch.use_deterministic_algorithms``. With
            ``True`` (default) an op that has no deterministic implementation
            only emits a warning instead of raising ``RuntimeError``.
    """
    # Python random
    random.seed(seed)
    # NumPy (legacy global RNG)
    np.random.seed(seed)
    # PyTorch: CPU generator and every visible CUDA device
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cuDNN: use deterministic kernels and disable autotuning, which may pick a
    # different algorithm from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Deterministic cuBLAS (CUDA >= 10.2) requires this workspace setting before
    # the first cuBLAS call; setdefault keeps a value the user already exported.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    # This is a function and must be *called*; assigning `= True` to it would
    # replace the torch function with a bool and enable nothing.
    torch.use_deterministic_algorithms(True, warn_only=warn_only)

@dataclass
class TrainingConfig:
    """Hyper-parameters shared by the model, the diffuser and the data pipeline.

    NOTE: none of the attributes carries a type annotation, so ``@dataclass``
    treats them as plain class attributes and the generated ``__init__`` takes
    no arguments. If you annotate them, ``clipconf`` would need
    ``field(default_factory=...)`` because it is a mutable object instance.
    """

    # Side length (pixels) of the square images; also fed to `clipconf` below.
    image_size = 64
    train_batch_size = 32

    num_epochs = 20

    # Number of diffusion steps (passed to Diffuser in the assimilation cell).
    num_timesteps = 1000

    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500

    # NOTE(review): larger than num_epochs, so a save every `save_model_epochs`
    # epochs would never trigger during training — confirm intent.
    save_model_epochs = 50

    seed = 42

    # Single-channel vision-encoder config. `image_size` here resolves to the
    # class attribute defined above, so the two values cannot drift apart.
    clipconf = CLIPVisionConfig(
        projection_dim=512,
        num_channels=1,
        image_size=image_size,
    )


config = TrainingConfig()
torch_fix_seed(config.seed)

In [None]:
img_size = config.image_size
batch = config.train_batch_size
epochs = config.num_epochs
lr_rate = config.learning_rate
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = DenoisingModel(config)
# NOTE(review): absolute, machine-specific path — consider making it configurable.
save_path = '/home/kurochan/sakuma/kenkiu/workspace_DLSR/DLSR-main/saved_model/'

model.to(device)
# Map the checkpoint onto the device actually selected above. A hard-coded
# "cuda" makes this cell fail on machines without a GPU.
# weights_only=False lets torch.load unpickle arbitrary objects: only load
# checkpoints from a trusted source.
model.load_state_dict(
    torch.load(os.path.join(save_path, 'TMP_model1220.pth'),
               map_location=device, weights_only=False)
)

In [None]:
def make_tensor(data='TMP', area='manji', years=(2021, 2022, 2023), source='/mnt/nadaru/trainingdataset'):
    """Load the yearly NetCDF files for one variable/area and concatenate them along ``t``.

    Files are read from ``{source}/{data}_{year}_{area}.nc``. A year whose file
    cannot be read is skipped with a printed message (best effort).

    Args:
        data: Variable prefix of the file names (e.g. ``'TMP'``).
        area: Area suffix of the file names.
        years: Years to load, in concatenation order.
        source: Directory containing the files.

    Returns:
        The concatenated ``xarray.Dataset``. If its ``amgsd`` variable contains
        NaN, every ``y`` index that holds a NaN in any data variable is dropped.

    Raises:
        FileNotFoundError: if none of the requested files could be read.
        ValueError: if dropping NaN rows leaves no ``y`` rows at all.
    """
    dss = []
    for y in years:
        filename = f"{data}_{y}_{area}.nc"
        filepath = os.path.join(source, filename)
        try:
            d = xr.load_dataset(filepath)
        except Exception as err:
            # Best effort: skip a missing/unreadable year, but say so instead of
            # silently swallowing the error.
            print(f'skipped {filename}: {err!r}')
            continue
        dss.append(d)
        print(filename)
    if not dss:
        raise FileNotFoundError(
            f'no readable {data}_<year>_{area}.nc file for years {years} in {source}')
    ds = xr.concat(dss, dim="t")
    if ds['amgsd'].isnull().any():
        print('Warning:nan detected')
        n_y = ds.sizes['y']
        # dropna returns a new Dataset; its result must be kept, otherwise the
        # NaNs are never removed.
        ds = ds.dropna(dim='y', how='any')
        print(f'dropped {n_y - ds.sizes["y"]} of {n_y} y rows containing NaN')
        if ds.sizes['y'] == 0:
            raise ValueError("every 'y' row contained NaN; nothing left after dropna")
    return ds

# Load the dataset from the SSD (this is the slow step)
print('dataset loading三(　ﾟ∀ﾟ)...')
ds = make_tensor()

lr = ds['WRF_1km'].to_numpy()
hr = ds['WRF_300m'].values
amgsd = ds['amgsd'].values
timeline = ds['WRF_300m'].t
timecode = np.arange(len(timeline))

# Normalisation statistics come from the high-resolution field and are applied
# to all three fields below. Presumably they must match the statistics used
# when the checkpoint was trained — TODO confirm.
hr_mean = hr.mean()
hr_std = hr.std()
assert np.isfinite(hr_mean) and np.isfinite(hr_std), 'WRF_300m contains NaN/inf; normalisation statistics are invalid'

print('dataset to tensor(jstammt)...')
lr_tensor = torch.tensor(lr.astype('float'), dtype=torch.float32)
hr_tensor = torch.tensor(hr.astype('float'), dtype=torch.float32)
timecode_tensor = torch.tensor(timecode.astype('float'))
# float32 like lr/hr: without an explicit dtype this tensor would be float64,
# so observation batches would not match the float32 model inputs.
amgsd_tensor = torch.tensor(amgsd.astype('float'), dtype=torch.float32)


trans = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(config.image_size, config.image_size)),
    torchvision.transforms.Normalize((hr_mean), (hr_std))])

class WRFdatasets(torch.utils.data.Dataset):
    """Map-style dataset of (low-res, high-res, observation, time-code) samples.

    The three field tensors get a singleton axis inserted at dim 1, so for
    inputs shaped (T, H, W) — assumed, TODO confirm — each item's fields are
    (1, H, W). If a transform is given, it is applied separately to each of
    the three fields (never to the time code).
    """

    def __init__(self, LR, HR, amd, timeline, transform = None):
        # Add the channel axis to every field tensor in one go.
        self.lr, self.hr, self.amd = (field.unsqueeze(1) for field in (LR, HR, amd))
        # Time codes are stored as int32.
        self.time = timeline.to(torch.int)
        self.transform = transform
        self.datanum = len(timeline)

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        fields = (self.lr[idx], self.hr[idx], self.amd[idx])
        if self.transform:
            fields = tuple(self.transform(field) for field in fields)
        return (*fields, self.time[idx])

print('dataset making...')
dataset = WRFdatasets(lr_tensor, hr_tensor, amgsd_tensor, timecode_tensor, transform=trans)

# Split into training / validation data at a ratio of 8:2.
# NOTE(review): random_split draws from the global torch RNG, so which samples
# end up in test_dataset depends on everything that consumed that RNG since
# torch_fix_seed() — TODO confirm it reproduces the training notebook's split,
# otherwise training samples may leak into the evaluation below.
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
print('kansei')  # "done"

trainloader = DataLoader(train_dataset, batch_size=batch, num_workers=2, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=batch, num_workers=2, shuffle=False)

In [None]:
# Inverse of the input normalisation: x * hr_std + hr_mean, i.e. maps
# model-space tensors back to the units of the original data.
invTrans = transforms.Compose([transforms.Normalize(mean=0.,
                                                    std=1 / hr_std),
                               transforms.Normalize(mean=-hr_mean,
                                                    std=1),
                               ])

# Data-assimilation output on the test split.
# Number of observation points: 300 (see obs_num).
model.eval()
diffuser = Diffuser(config.num_timesteps, device=device)

obs_num = [300]
# NOTE(review): results for every value in obs_num go into the same lists; the
# ground-truth entries are appended next to each result so indices stay aligned.
# The 'WRF_333m' key holds the dataset's 'WRF_300m' variable.
denoised = {'denoised': [], 'WRF_1km': [], 'WRF_333m': [], 'amgsd': [], 'time': []}

with torch.no_grad():
    for low, high, amd, time in testloader:
        # Ablation: remove the low-resolution conditioning.
        # low = torch.tensor(np.zeros(low.shape))

        low = low.to(device)
        high = high.to(device)
        amd = amd.to(device)

        for obs in obs_num:
            denoised_obs, array = diffuser.sample_asim(model, low, amd, gamma=1.0, asim_sample=obs)

            result = denoised_obs.to('cpu')
            images = invTrans(result).numpy()

            denoised['denoised'].append(images)
            denoised['WRF_1km'].append(invTrans(low.to('cpu')).numpy())
            denoised['WRF_333m'].append(invTrans(high.to('cpu')).numpy())
            denoised['amgsd'].append(invTrans(amd.to('cpu')).numpy())
            # .numpy() belongs on the tensor; list.append() returns None.
            denoised['time'].append(time.to('cpu').numpy())

# Stack the per-batch arrays along the sample axis; squeeze() drops every
# singleton axis (e.g. the channel axis) as before.
for key in ('denoised', 'WRF_333m', 'WRF_1km', 'amgsd'):
    denoised[key] = np.concatenate(denoised[key], 0).squeeze()
denoised['time'] = np.concatenate(denoised['time'], 0)