In [1]:
import os
import data_processing_tool as dpt
from datetime import timedelta, date, datetime
import argparse

import torch,os,torchvision
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms
from PIL import Image
import time
import xarray as xr
from sklearn.model_selection import StratifiedShuffleSplit

# NCI locations of the raw ACCESS-S1 hindcast precipitation and the BARRA-R analyses.
file_ACCESS_dir = "/g/data/ub7/access-s1/hc/raw_model/atmos/pr/daily/"
file_BARRA_dir = "/g/data/ma05/BARRA_R/analysis/acum_proc"

# Ensemble members to use (the full set is 'e01' ... 'e11').
ensemble = ['e01', 'e02']

leading_time = 217          # presumably the number of lead days in each ACCESS file; unused in the visible cells
leading_time_we_use = 31    # lead days per ACCESS file that are turned into samples

init_date = date(1970, 1, 1)
start_date = date(1990, 1, 1)
end_date = date(1990, 12, 31)  # if 9am-to-9am ("929") accumulation is used, one day should be subtracted

# Every calendar day from start_date to end_date, both inclusive.
dates = [start_date + timedelta(days=offset) for offset in range((end_date - start_date).days + 1)]


In [2]:
parser = argparse.ArgumentParser(description='BARRA_R and ACCESS-S!')
def set_parser():
    parser.add_argument('--debug', action='store_true',
                        help='Enables debug mode')
    parser.add_argument('--template', default='.',
                        help='You can set various templates in option.py')

    # Hardware specifications
    parser.add_argument('--n_threads', type=int, default=3,
                        help='number of threads for data loading')
    parser.add_argument('--cpu', action='store_true',
                        help='use cpu only')
    parser.add_argument('--n_GPUs', type=int, default=2,
                        help='number of GPUs')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed')

    # Data specifications
    parser.add_argument('--file_ACCESS_dir', type=str, 
                        default="F:/climate/access-s1/pr/daily/",
    
                        help='dataset directory')
    parser.add_argument('--file_BARRA_dir', type=str, 
                        default="C:/Users/JIA059/barra/",
                        help='dataset directory')
    parser.add_argument('--nine2nine', type=bool, 
                        default=True,
                        help='whether rainfall acculate from 9am to 9am')
    parser.add_argument('--date_minus_one', type=int, 
                        default=1,
                        help='whether rainfall acculate from yesterday(1)/today(0) 9am to tody/tomorrow 9am')
    
    
    parser.add_argument('--dir_demo', type=str, default='../test',
                        help='demo image directory')
#     parser.add_argument('--data_train', type=str, default='BARRA_R',
#                         help='train dataset name')
#     parser.add_argument('--data_test', type=str, default='DIV2K',
#                         help='test dataset name')
    parser.add_argument('--benchmark_noise', action='store_true',
                        help='use noisy benchmark sets')
    parser.add_argument('--n_train', type=int, default=800,
                        help='number of training set')
    parser.add_argument('--n_val', type=int, default=10,
                        help='number of validation set')
    parser.add_argument('--offset_val', type=int, default=800,
                        help='validation index offest')
    parser.add_argument('--ext', type=str, default='sep',
                        help='dataset file extension')
    parser.add_argument('--scale', default='4',
                        help='super resolution scale')
    parser.add_argument('--patch_size', type=int, default=96,
                        help='output patch size')
    #??????????????????????????????????????????????????
    parser.add_argument('--rgb_range', type=int, default=300,
                        help='maximum value of RGB')
    parser.add_argument('--n_colors', type=int, default=1,
                        help='number of color channels to use')
    parser.add_argument('--noise', type=str, default='.',
                        help='Gaussian noise std.')
    parser.add_argument('--chop', action='store_true',
                        help='enable memory-efficient forward')

    # Model specifications
    parser.add_argument('--model', default='RCAN',
                        help='model name')

    parser.add_argument('--act', type=str, default='relu',
                        help='activation function')
    parser.add_argument('--pre_train', type=str, default='.',
                        help='pre-trained model directory')
    parser.add_argument('--extend', type=str, default='.',
                        help='pre-trained model directory')
    parser.add_argument('--n_resblocks', type=int, default=16,
                        help='number of residual blocks')
    parser.add_argument('--n_feats', type=int, default=64,
                        help='number of feature maps')
    parser.add_argument('--res_scale', type=float, default=1,
                        help='residual scaling')
    parser.add_argument('--shift_mean', default=True,
                        help='subtract pixel mean from the input')
    parser.add_argument('--precision', type=str, default='single',
                        choices=('single', 'half'),
                        help='FP precision for test (single | half)')

    # Training specifications
    parser.add_argument('--reset', action='store_true',
                        help='reset the training')
    parser.add_argument('--test_every', type=int, default=1000,
                        help='do test per every N batches')
    parser.add_argument('--epochs', type=int, default=3000,
                        help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='input batch size for training')
    parser.add_argument('--split_batch', type=int, default=1,
                        help='split the batch into smaller chunks')
    parser.add_argument('--self_ensemble', action='store_true',
                        help='use self-ensemble method for test')
    parser.add_argument('--test_only', action='store_true',
                        help='set this option to test the model')
    parser.add_argument('--gan_k', type=int, default=1,
                        help='k value for adversarial loss')

    # Optimization specifications
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--lr_decay', type=int, default=200,
                        help='learning rate decay per N epochs')
    parser.add_argument('--decay_type', type=str, default='step',
                        help='learning rate decay type')
    parser.add_argument('--gamma', type=float, default=0.5,
                        help='learning rate decay factor for step decay')
    parser.add_argument('--optimizer', default='ADAM',
                        choices=('SGD', 'ADAM', 'RMSprop'),
                        help='optimizer to use (SGD | ADAM | RMSprop)')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='SGD momentum')
    parser.add_argument('--beta1', type=float, default=0.9,
                        help='ADAM beta1')
    parser.add_argument('--beta2', type=float, default=0.999,
                        help='ADAM beta2')
    parser.add_argument('--epsilon', type=float, default=1e-8,
                        help='ADAM epsilon for numerical stability')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay')

    # Loss specifications
    parser.add_argument('--loss', type=str, default='1*L1',
                        help='loss function configuration')
    parser.add_argument('--skip_threshold', type=float, default='1e6',
                        help='skipping batch that has large error')

    # Log specifications
    parser.add_argument('--save', type=str, default='RCAN',
                        help='file name to save')
    parser.add_argument('--load', type=str, default='.',
                        help='file name to load')
    parser.add_argument('--resume', type=int, default=0,
                        help='resume from specific checkpoint')
    parser.add_argument('--print_model', action='store_true',
                        help='print model')
    parser.add_argument('--save_models', action='store_true',
                        help='save all intermediate models')
    parser.add_argument('--print_every', type=int, default=100,
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save_results', action='store_true',
                        help='save output results')

    # New options
    parser.add_argument('--n_resgroups', type=int, default=10,
                        help='number of residual groups')
    parser.add_argument('--reduction', type=int, default=16,
                        help='number of feature maps reduction')
    parser.add_argument('--testpath', type=str, default='../test/DIV2K_val_LR_our',
                        help='dataset directory for testing')
    parser.add_argument('--testset', type=str, default='Set5',
                        help='dataset name for testing')
    parser.add_argument('--degradation', type=str, default='BI',
                        help='degradation model: BI, BD')

    args = parser.parse_args(args=[])
    # args = parser.parse_args()
#     template.set_template(args)

    args.scale = list(map(lambda x: int(x), args.scale.split('+')))
    
    if args.epochs == 0:
        args.epochs = 1e8

    for arg in vars(args):
        if vars(args)[arg] == 'True':
            vars(args)[arg] = True
        elif vars(args)[arg] == 'False':
            vars(args)[arg] = False
    return args

args=set_parser()
# args.template.find("DDBPN")


In [6]:

# Local (F: drive) copies of the data; these override the NCI paths set in the first cell.
# NCI ACCESS path: "/g/data/ub7/access-s1/hc/raw_model/atmos/pr/daily/"
# NOTE: ACCESS_BARRA_v1 reads its directories from `args`, not from these globals.
# AUS domain for reference: [111.975, 156.275, -44.525, -9.975]
file_ACCESS_dir, file_BARRA_dir = "F:/climate/access-s1/pr/daily/", "F:/climate/barra/"

class ACCESS_BARRA_v1(Dataset):
    '''
    Dataset pairing ACCESS-S1 precipitation files with BARRA-R data.

    scale is size(hr)=size(lr)*scale
    version_1_documention: the data we use is raw data that store at NCI

    Each sample is indexed by ``[access_path, barra_date, lead]``; see
    ``get_filename_with_time_order``.
    '''
    def __init__(self, start_date=date(1990, 1, 1), end_date=date(1990, 12, 31), regin="AUS",
                 transform=None, args=args, ensemble_members=None, lead_days=None):
        '''
        Args:
            start_date, end_date: inclusive range of ACCESS file dates to index.
            regin: region name; "AUS" selects the Australian shape and domain.
            transform: stored on the instance; not applied in __getitem__ yet.
            args: parsed options. The file directories, ``scale`` and the
                ``nine2nine`` / ``date_minus_one`` flags are read from it.
            ensemble_members: ensemble member names such as 'e01'. Defaults
                to the module-level ``ensemble``.
            lead_days: number of lead days turned into samples per ACCESS file.
                Defaults to the module-level ``leading_time_we_use``.
        '''
        self.args = args
        self.file_BARRA_dir = args.file_BARRA_dir
        self.file_ACCESS_dir = args.file_ACCESS_dir

        self.transform = transform
        self.start_date = start_date
        self.end_date = end_date

        self.scale = args.scale[0]
        self.regin = regin
        self.ensemble_members = list(ensemble if ensemble_members is None else ensemble_members)
        self.lead_days = leading_time_we_use if lead_days is None else lead_days

        if regin == "AUS":
            self.shape = (314, 403, 1, 1)
            self.domain = [111.975, 156.275, -44.525, -9.975]
        else:
            self.shape = (768, 1200, 1, 1)
            self.domain = None  # previously left undefined for non-AUS regions

        self.dates = self.date_range(start_date, end_date)

        self.filename_list = self.get_filename_with_time_order(args.file_ACCESS_dir)

    def __len__(self):
        return len(self.filename_list)

    def date_range(self, start_date, end_date):
        """This function takes a start date and an end date as datetime date objects.
        It returns a list of dates for each date in order starting at the first date and ending with the last date"""
        return [start_date + timedelta(x) for x in range((end_date - start_date).days + 1)]

    def get_filename_with_no_time_order(self, rootdir):
        '''Recursively collect every path under ``rootdir`` that ends in ".nc".

        Paths are returned in directory-listing order, not in time order.
        '''
        _files = []
        for entry in os.listdir(rootdir):  # every file and directory in the folder
            path = os.path.join(rootdir, entry)
            if os.path.isdir(path):
                _files.extend(self.get_filename_with_no_time_order(path))
            elif os.path.isfile(path) and path.endswith(".nc"):
                _files.append(path)
        return _files

    def get_filename_with_time_order(self, rootdir):
        '''Build the sample index in (ensemble member, file date, lead day) order.

        The ACCESS file ``<rootdir><member>/da_pr_<YYYYMMDD>_<member>.nc`` is
        checked for each member and each date in ``self.dates``. Missing files
        are skipped silently. Each existing file yields one entry
        ``[access_path, barra_date, lead]`` per ``lead`` in ``range(self.lead_days)``,
        where ``barra_date`` is the file date plus ``lead`` days.
        '''
        _files = []
        for en in self.ensemble_members:
            for day in self.dates:
                access_path = rootdir + en + "/" + "da_pr_" + day.strftime("%Y%m%d") + "_" + en + ".nc"
                if not os.path.exists(access_path):
                    continue
                for lead in range(self.lead_days):
                    _files.append([access_path, day + timedelta(lead), lead])

        # Finally drop the first entry (shuffling is expected to happen elsewhere).
        # NOTE(review): presumably the first sample has no matching BARRA
        # accumulation when rain is accumulated from the previous day's 9am -- confirm.
        if self.args.nine2nine and self.args.date_minus_one == 1 and _files:
            del _files[0]
        return _files

    def __getitem__(self, idx):
        '''
        Load the ACCESS data for sample ``idx`` and the BARRA data paired with it.

        NOTE(review): this is still a shape-inspection stub. It returns
        ``(train_data.shape, label.shape)``, not the (lr, hr) data.
        '''
        access_filename, date_for_BARRA, time_leading = self.filename_list[idx]

        # NOTE(review): `idx` is the dataset index, yet `time_leading` (the lead
        # day within the file) is unpacked and never used. Confirm which one
        # dpt.read_access_data expects.
        data_low = dpt.read_access_data(access_filename, idx=idx)
        train_data = dpt.map_aust(data_low)

        # Crop BARRA to the lon/lat extent covered by the mapped ACCESS field.
        domain = [train_data.lon.data.min(), train_data.lon.data.max(),
                  train_data.lat.data.min(), train_data.lat.data.max()]
        print(domain)  # debugging output

        # NOTE(review): nine2nine is hard-coded to False here, although
        # self.args.nine2nine defaults to True -- confirm this is intended.
        data_high = dpt.read_barra_data_fc(self.file_BARRA_dir, date_for_BARRA, nine2nine=False)
        label = dpt.map_aust(data_high, domain=domain)

        return train_data.shape, label.shape

    

# Smoke test: build the dataset with its default 1990 date range, then print
# the number of samples and the (lr_shape, hr_shape) pair returned for sample 0.
data_set = ACCESS_BARRA_v1(args=args)
n_samples = len(data_set)
print(n_samples)
first_sample = data_set[0]
print(first_sample)


2789
[112.08333, 156.25, -44.166664, -10.277771]
((62, 54), (308, 402))
