In [1]:
import os
import data_processing_tool as dpt
from datetime import timedelta, date, datetime
# import args_parameter as args
import torch,torchvision
import numpy as np

from torch.utils.data import Dataset,random_split
from torchvision import datasets, models, transforms

import time
import xarray as xr

class ACCESS_BARRA_v3(Dataset):
    '''
    Paired dataset: ACCESS-S1 forecast fields as input, BARRA-R fields as label.

    scale: size(hr) = size(lr) * scale

    Version 3 notes (changes relative to version 1):
    1. Each entry of ``filename_list`` is
       ``[access_pr_path, access_date, barra_date, time_leading]``.
       The ACCESS file paths of the other variables (zg, psl, ...) are built
       in ``__getitem__`` from ``access_date`` and the ensemble member, which
       makes it easier to read additional variables.
    2. A DEM channel can be appended to the input, and the low-resolution
       grid is fixed to ``self.shape`` so that it matches the DEM.

    ``args`` is an options namespace. Attributes read by this class:
    ``file_BARRA_dir``, ``file_ACCESS_dir``, ``file_DEM_dir`` (only if
    ``dem``), ``scale``, ``leading_time_we_use``, ``ensemble``, ``domain``,
    ``nine2nine``, ``date_minus_one`` and the channel switches ``zg``,
    ``psl``, ``tasmax``, ``dem`` and (optional) ``tasmin``.
    '''

    # Factor applied to the precipitation channel before it is returned.
    # NOTE(review): 86400 is the number of seconds per day; presumably this
    # converts a per-second rate into a daily total -- confirm the units.
    PR_SCALE = 86400

    def __init__(self,start_date=date(1990, 1, 1),end_date=date(1990,12 , 31),regin="AUS",transform=None,train=True,args=None):
        """Index the available ACCESS files and load the static fields.

        Parameters:
            start_date, end_date: first and last ACCESS start date (inclusive).
            regin: region label, stored as ``self.regin``.
            transform: optional callable applied to both input and label.
            train: accepted for interface compatibility; not used here.
            args: options namespace (see the class docstring). Required.

        Raises:
            ValueError: if ``args`` is not given.
            FileNotFoundError: if no ACCESS pr file matches the request.
        """
        if args is None:
            raise ValueError("ACCESS_BARRA_v3 requires an `args` namespace; got None")

        print("=> BARRA_R & ACCESS_S1 loading")
        print("=> from "+start_date.strftime("%Y/%m/%d")+" to "+end_date.strftime("%Y/%m/%d")+"")
        self.file_BARRA_dir = args.file_BARRA_dir
        self.file_ACCESS_dir = args.file_ACCESS_dir
        self.args=args

        self.transform = transform
        self.start_date = start_date
        self.end_date = end_date

        self.scale = args.scale[0]
        self.regin = regin
        self.leading_time=217
        self.leading_time_we_use=args.leading_time_we_use

        self.ensemble_access=['e01','e02','e03','e04','e05','e06','e07','e08','e09','e10','e11']
        # Use the first ``args.ensemble`` members.
        self.ensemble=[self.ensemble_access[i] for i in range(args.ensemble)]

        self.dates = self.date_range(start_date, end_date)

        access_pr_dir=args.file_ACCESS_dir+"pr/daily/"
        if not os.path.exists(access_pr_dir):
            print(access_pr_dir)
            print("no file or no permission")
        self.filename_list=self.get_filename_with_time_order(access_pr_dir)
        if not self.filename_list:
            # Fail with a clear message instead of an IndexError on the unpack below.
            raise FileNotFoundError(
                "no ACCESS pr files found under "+access_pr_dir
                +" for the requested dates and ensemble members")

        _,_,date_for_BARRA,_=self.filename_list[0]
        # Check the configured BARRA directory rather than one hard-coded file.
        if not os.path.exists(self.file_BARRA_dir):
            print(self.file_BARRA_dir)
            print("no file or no permission!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        # Read one BARRA field only to obtain the lat/lon coordinates of the domain.
        data_high=dpt.read_barra_data_fc(self.file_BARRA_dir,date_for_BARRA,nine2nine=True)
        data_exp=dpt.map_aust(data_high,domain=args.domain,xrarray=True)
        self.lat=data_exp["lat"]
        self.lon=data_exp["lon"]
        # Grid every input channel is resampled to.
        self.shape=(81,108)
        if self.args.dem:
            self.dem_data=dpt.interp_tensor_2d(dpt.read_dem(args.file_DEM_dir+"dem-9s1.tif"),self.shape )

    def __len__(self):
        """Number of (ACCESS file, lead time) samples."""
        return len(self.filename_list)

    def date_range(self,start_date, end_date):
        """This function takes a start date and an end date as datetime date objects.
        It returns a list of dates for each date in order starting at the first date and ending with the last date"""
        return [start_date + timedelta(x) for x in range((end_date - start_date).days + 1)]

    def get_filename_with_no_time_order(self,rootdir):
        '''Recursively collect the paths of all ``.nc`` files below ``rootdir``.

        Files are returned in ``os.listdir`` order, i.e. not sorted by time.
        '''
        _files = []
        for name in os.listdir(rootdir):  # every file and directory in rootdir
            path = os.path.join(rootdir,name)
            if os.path.isdir(path):
                _files.extend(self.get_filename_with_no_time_order(path))
            if os.path.isfile(path) and path.endswith(".nc"):
                _files.append(path)
        return _files

    def get_filename_with_time_order(self,rootdir):
        '''Build the sample index, ordered by ensemble member, date and lead time.

        For every member in ``self.ensemble`` and every date in ``self.dates``
        whose file ``<rootdir><en>/da_pr_<YYYYMMDD>_<en>.nc`` exists, one
        entry ``[access_path, access_date, barra_date, lead]`` is added per
        lead time in ``range(self.leading_time_we_use)``, where
        ``barra_date = access_date + lead`` days. For ``self.end_date`` only
        lead 0 is kept.

        ``rootdir`` is expected to end with a path separator.
        '''
        _files = []
        for en in self.ensemble:
            for access_date in self.dates:
                access_path=rootdir+en+"/"+"da_pr_"+access_date.strftime("%Y%m%d")+"_"+en+".nc"
                if not os.path.exists(access_path):
                    continue
                for lead in range(self.leading_time_we_use):
                    if access_date==self.end_date and lead==1:
                        break
                    _files.append([access_path,access_date,access_date+timedelta(lead),lead])

        # Drop the first entry when requested; guarded so an empty list is
        # returned as-is instead of raising IndexError.
        if self.args.nine2nine and self.args.date_minus_one==1 and _files:
            del _files[0]
        return _files

    @staticmethod
    def _ensemble_from_path(access_path):
        '''Return the ensemble member (e.g. ``"e01"``) from ``da_<var>_<date>_<en>.nc``.'''
        stem=os.path.splitext(os.path.basename(access_path))[0]
        return stem.split("_")[-1]

    def _load_access_channel(self,variable,access_date,en,time_leading,reader=None,interp=None):
        '''Read one auxiliary ACCESS variable and resample it to ``self.shape``.

        ``reader`` defaults to ``dpt.read_access_data`` and ``interp`` to
        ``dpt.interp_tensor_2d``; both are resolved at call time. The result
        has a channel axis at position 2 so that it can be concatenated
        with the other inputs.
        '''
        if reader is None:
            reader=dpt.read_access_data
        if interp is None:
            interp=dpt.interp_tensor_2d
        filename=(self.args.file_ACCESS_dir+variable+"/daily/"+en+"/"
                  +"da_"+variable+"_"+access_date.strftime("%Y%m%d")+"_"+en+".nc")
        data=reader(filename,idx=time_leading)
        # Same domain as the precipitation channel so all channels cover the same area.
        data_aus=dpt.map_aust(data,domain=self.args.domain,xrarray=False)
        channel=interp(data_aus,self.shape)
        # NOTE(review): a 2-D result gets a channel axis; anything else is
        # assumed to be (H, W, C) already -- confirm for interp_tensor_3d.
        if channel.ndim==2:
            channel=np.expand_dims(channel,axis=2)
        return channel

    def __getitem__(self, idx):
        '''
        Load sample ``idx``.

        Returns ``(lr, label, barra_date, time_leading)``:
            lr: input stack of shape ``self.shape + (channels,)``. Channel 0
                is precipitation times ``PR_SCALE``; enabled auxiliary
                channels follow in the order zg, psl, tasmax, tasmin, dem.
            label: BARRA field for ``barra_date`` from ``dpt.map_aust``.
            barra_date: ``torch.tensor`` holding the integer ``YYYYMMDD``.
            time_leading: ``torch.tensor`` holding the lead time in days.
        ``self.transform`` (if set) is applied to ``lr`` and ``label``.
        '''
        t=time.time()

        access_filename_pr,access_date,date_for_BARRA,time_leading=self.filename_list[idx]
        # The index stores only the pr path; the member is recovered from its name.
        en=self._ensemble_from_path(access_filename_pr)

        data_low=dpt.read_access_data(access_filename_pr,idx=time_leading)
        lr_raw=dpt.map_aust(data_low,domain=self.args.domain,xrarray=False)
        # Only the precipitation channel is scaled. The original multiplied
        # the whole stack, which would also have scaled auxiliary channels
        # such as the DEM.
        channels=[np.expand_dims(dpt.interp_tensor_2d(lr_raw,self.shape),axis=2)*self.PR_SCALE]

        data_high=dpt.read_barra_data_fc(self.file_BARRA_dir,date_for_BARRA,nine2nine=True)
        label=dpt.map_aust(data_high,domain=self.args.domain,xrarray=False)

        if self.args.zg:
            channels.append(self._load_access_channel(
                "zg",access_date,en,time_leading,
                reader=dpt.read_access_zg,interp=dpt.interp_tensor_3d))
        if self.args.psl:
            channels.append(self._load_access_channel("psl",access_date,en,time_leading))
        if self.args.tasmax:
            channels.append(self._load_access_channel("tasmax",access_date,en,time_leading))
        # Previously gated on ``args.tasmax`` by mistake. ``tasmin`` is read
        # with a default because older option sets may not define it.
        if getattr(self.args,"tasmin",False):
            channels.append(self._load_access_channel("tasmin",access_date,en,time_leading))
        if self.args.dem:
            # Static elevation field prepared once in __init__.
            channels.append(np.expand_dims(self.dem_data,axis=2))

        # np.concatenate takes the arrays as one sequence.
        lr=np.concatenate(channels,axis=2)

        print("end loading one data,time cost %f"%(time.time()-t))

        barra_date_tensor=torch.tensor(int(date_for_BARRA.strftime("%Y%m%d")))
        if self.transform:
            return self.transform(lr),self.transform(label),barra_date_tensor,torch.tensor(time_leading)
        return lr,label,barra_date_tensor,torch.tensor(time_leading)



In [2]:
## Data stack: read the DEM raster and resample it with target shape (81, 108).

dem_raster=dpt.read_dem("../../../DEM/"+"dem-9s1.tif")
dem_data=dpt.interp_tensor_2d(dem_raster,(81,108))



In [3]:
# Read index 0 of one ACCESS-S1 pr file, pass it through dpt.map_aust with an
# explicit domain, and resample it twice with target shape (81, 108).
# lr_3 and lr_2 come from two identical calls; they stand in for two separate
# input channels in the stacking cells below.
sample_domain=[112.9, 154.00, -43.7425, -9.0]
target_shape=(81,108)
data_low=dpt.read_access_data("E:/climate/access-s1/pr/daily/e01/da_pr_19900101_e01.nc",idx=0)
lr_raw=dpt.map_aust(data_low,domain=sample_domain,xrarray=False)
lr_3,lr_2=(dpt.interp_tensor_2d(lr_raw,target_shape) for _ in range(2))

In [5]:

# Both resampled fields should report the same 2-D shape.
for field in (lr_3,lr_2):
    print(field.shape)

(81, 108)
(81, 108)


In [24]:
# Give each 2-D field a trailing channel axis, then join them along that axis.
layers=[np.expand_dims(field,axis=2) for field in (lr_2,dem_data)]
a=np.concatenate(layers,axis=2)
a.shape

(81, 108, 2)

In [28]:

# Append a third channel to the existing stack. np.concatenate joins along the
# existing channel axis; np.stack would instead add a new axis and needs
# equally shaped inputs, so it does not fit here.
lr_3_channel=np.expand_dims(lr_3,axis=2)
b=np.concatenate((a,lr_3_channel),axis=2)

b.shape

(81, 108, 3)