In [15]:
import os
import torch
from copy import deepcopy
import numpy as np
import xarray as xr
import pandas as pd
import torch.nn as nn
import random
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

In [16]:
# Set the random seeds
def set_seed(seed=427):
    """Seed the random number generators used by this notebook.

    Seeds Python's `random` module, NumPy's global generator and PyTorch via
    `torch.manual_seed`, and stores the seed in the PYTHONHASHSEED
    environment variable.

    Args:
        seed: Integer seed applied to every generator (default 427).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # The three seeding calls are independent of each other, so order is irrelevant.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

In [17]:
# Open the CMIP label file and extract its 'nino' variable as a NumPy array.
# NOTE(review): the label dataset is bound to the name `train` here; a later
# cell rebinds `train` to the CMIP_train.nc dataset.
train = xr.open_dataset('../tcdata/enso_round1_train_20210201/CMIP_label.nc')
train_nino= train['nino'].values

In [18]:
# Display the dimensions of the 'nino' label array.
train_nino.shape

(4645, 36)

(55740, 24, 72)

In [56]:
# Open the CMIP feature and label files, then cut the SST field into 12-step windows.
train = xr.open_dataset('../tcdata/enso_round1_train_20210201/CMIP_train.nc')
# NOTE(review): `label` is opened but not used anywhere in this cell.
label = xr.open_dataset('../tcdata/enso_round1_train_20210201/CMIP_label.nc')    
train_sst = train['sst'].values
# Keep samples [0, 151*5), [151*9, 151*12) and [151*13, end); the ranges in between are dropped.
# NOTE(review): presumably each run of 151 samples is one CMIP model — TODO confirm.
train_sst= np.concatenate((train_sst[:151*5],train_sst[151*9:151*12],train_sst[151*13:]))  
# NOTE(review): split_month is defined in a cell further down the notebook, so
# this line only works if that cell has already been executed.
train_sst=split_month(train_sst,1000)

In [57]:
# Display the current dimensions of train_sst.
train_sst.shape

(11976, 12, 24, 72)

In [64]:
# Window the first 100 label rows with split_month_label.
# NOTE(review): `train_label` is not assigned at notebook level in any visible
# cell (it only appears as a local inside load_data2), and split_month_label is
# defined further down the notebook — this cell relies on earlier kernel state.
a=split_month_label(train_label,100)

In [70]:
# Display the dimensions of train_label.
train_label.shape

(3890, 36)

In [114]:
# Scratch reproduction of the first two steps of split_month_label: take the
# first 12 columns of the first `size` label rows and flatten them into one
# 1-D series of size*12 values.
size=100
temp=train_label[:size,0:12]
temp=temp.reshape(size*12)

In [115]:
# Display the length of the flattened label series.
temp.shape

(1200,)

In [121]:
# Collect the slices temp[i+12:i+36] for every start offset i.
# Stopping at size*12-40 keeps each slice end (i+36 <= size*12-5) inside
# `temp`, so every slice has the full 24 elements.
aa=[temp[i+12:i+36] for i in range(size*12-40)]

In [122]:
# Convert the list of slices to an array and display its dimensions.
np.array(aa).shape

(1160, 24)

In [117]:
# Display the number of slices in `aa`.
# NOTE(review): this cell's execution count (In [117]) is lower than that of
# the cell defining `aa` (In [121]), so the saved output may come from an
# earlier definition of `aa`.
len(aa)

1176

(13,)

In [63]:
# Run split_month on train_sst with size 1000 and display the result's dimensions.
# NOTE(review): the windowed array is built only to read its shape and is then discarded.
split_month(train_sst,1000).shape

(11976, 12, 24, 72)

In [87]:
def split_month(array,size):
    """Cut the first `size` samples of `array` into overlapping 12-step windows.

    The leading 12 time steps of each of the first `size` samples are laid end
    to end into a single series of ``size*12`` steps, and a 12-step window is
    taken at every start offset ``0 .. size*12-25``.

    NOTE(review): laying the samples end to end only gives a continuous series
    if sample k+1 starts 12 steps after sample k — presumably true for this
    data set; confirm against the data description.

    Args:
        array: Array indexed (sample, time step, ...) with at least `size`
            samples and at least 12 time steps. Any trailing axes are carried
            through unchanged (the notebook passes (N, 36, 24, 72) arrays).
        size: Number of leading samples to use.

    Returns:
        Array of shape ``(size*12-24, 12, ...)``. When ``size*12-24 <= 0`` an
        empty array is returned.

    Raises:
        ValueError: If `array` has fewer than `size` samples or fewer than 12
            time steps.
    """
    array = np.asarray(array)
    steps_per_sample = 12
    head = array[:size, 0:steps_per_sample]
    if head.shape[:2] != (size, steps_per_sample):
        raise ValueError(
            'split_month needs at least {} samples with {} time steps, got shape {}'.format(
                size, steps_per_sample, array.shape))
    # Trailing (spatial) axes are taken from the input instead of being fixed to 24 x 72.
    series = head.reshape(size * steps_per_sample, *array.shape[2:])
    n_windows = size * steps_per_sample - 24
    return np.array([series[start:start + steps_per_sample] for start in range(n_windows)])

def split_month_label(array,size):
    """Build the 24-step label window that follows each 12-step input window.

    The first 12 columns of the first `size` rows are laid end to end into one
    series, followed by the remaining columns (index 12 onward) of row
    ``size-1``. For every start offset ``i`` in ``0 .. size*12-25`` the window
    ``series[i+12 : i+36]`` is returned.

    Without the appended tail the series is only ``size*12`` long, so the last
    11 windows would be cut short (down to 13 elements) and the result would
    be ragged; the tail guarantees every window has 24 elements.

    NOTE(review): laying the rows end to end only gives a continuous series if
    row k+1 starts 12 steps after row k — presumably true for this data set;
    confirm against the data description.

    Args:
        array: 2-D array indexed (sample, time step) with at least `size` rows
            and at least 23 columns (the notebook passes (N, 36) arrays).
        size: Number of leading rows to use.

    Returns:
        Array of shape ``(size*12-24, 24)``. When ``size*12-24 <= 0`` an empty
        array is returned.

    Raises:
        ValueError: If `array` is not 2-D, has fewer than `size` rows, or has
            too few columns to fill every window.
    """
    array = np.asarray(array)
    steps_per_sample = 12
    horizon = 24
    head = array[:size, 0:steps_per_sample]
    if head.shape != (size, steps_per_sample):
        raise ValueError(
            'split_month_label needs a 2-D array with at least {} rows and {} columns, got shape {}'.format(
                size, steps_per_sample, array.shape))
    n_windows = size * steps_per_sample - horizon
    if n_windows <= 0:
        return np.array([])
    # Extend the series with the rest of the last used row so the final windows are complete.
    series = np.concatenate((head.reshape(size * steps_per_sample), array[size - 1, steps_per_sample:]))
    last_end = (n_windows - 1) + steps_per_sample + horizon
    if series.shape[0] < last_end:
        raise ValueError(
            'split_month_label needs at least 23 columns to fill every window, got shape {}'.format(array.shape))
    return np.array([
        series[start + steps_per_sample:start + steps_per_sample + horizon]
        for start in range(n_windows)
    ])
    
def _keep_selected_cmip_samples(values):
    """Return `values` without the sample ranges [151*5, 151*9) and [151*12, 151*13).

    NOTE(review): presumably each run of 151 samples belongs to one CMIP model
    and the skipped ranges are models the author chose to exclude — TODO confirm.
    """
    return np.concatenate((values[:151*5], values[151*9:151*12], values[151*13:]))


def _load_windowed_split(train_path, label_path, size, select=None):
    """Read one feature/label file pair and window it into a dict of arrays.

    Args:
        train_path: NetCDF file holding the 'sst', 't300', 'ua' and 'va' variables.
        label_path: NetCDF file holding the 'nino' variable.
        size: Number of leading samples handed to split_month / split_month_label.
        select: Optional callable applied to every raw array before windowing.

    Returns:
        Dict with keys 'sst', 't300', 'ua', 'va' and 'label'.
    """
    fields = {}
    # `.values` loads the data into memory, so the files can be closed right after reading.
    with xr.open_dataset(train_path) as features:
        for name in ('sst', 't300', 'ua', 'va'):
            values = features[name].values
            if select is not None:
                values = select(values)
            fields[name] = split_month(values, size)
    with xr.open_dataset(label_path) as targets:
        nino = targets['nino'].values
    if select is not None:
        nino = select(nino)
    fields['label'] = split_month_label(nino, size)
    return fields


def load_data2(cmip_size=500, soda_size=100, data_dir='../tcdata/enso_round1_train_20210201'):
    """Build the training (CMIP) and validation (SODA) datasets.

    Each split reads the 'sst', 't300', 'ua' and 'va' variables plus the 'nino'
    label, windows them with split_month / split_month_label, and wraps the
    result in an EarthDataSet. For CMIP, part of the samples is dropped first
    (see _keep_selected_cmip_samples). The sample counts of both splits are
    printed.

    Args:
        cmip_size: Number of leading CMIP samples to window (default 500).
        soda_size: Number of leading SODA samples to window (default 100).
        data_dir: Directory containing CMIP_train.nc, CMIP_label.nc,
            SODA_train.nc and SODA_label.nc.

    Returns:
        Tuple ``(train_dataset, valid_dataset)`` of EarthDataSet objects.
    """
    # CMIP data
    dict_train = _load_windowed_split(
        os.path.join(data_dir, 'CMIP_train.nc'),
        os.path.join(data_dir, 'CMIP_label.nc'),
        cmip_size,
        select=_keep_selected_cmip_samples)

    # Missing values are left as they are; filling them with zeros via
    # np.nan_to_num (for ua, va, t300 and sst) was tried here and is disabled.

    # SODA data
    dict_valid = _load_windowed_split(
        os.path.join(data_dir, 'SODA_train.nc'),
        os.path.join(data_dir, 'SODA_label.nc'),
        soda_size)

    print('Train samples: {}, Valid samples: {}'.format(len(dict_train['label']), len(dict_valid['label'])))

    train_dataset = EarthDataSet(dict_train)
    valid_dataset = EarthDataSet(dict_valid)
    return train_dataset, valid_dataset

class EarthDataSet(Dataset):
    """Dataset view over a dict of parallel, index-aligned sequences.

    `data` must provide the keys 'sst', 't300', 'ua', 'va' and 'label'. Item
    `idx` is the pair ``((sst, t300, ua, va), label)`` taken at that index.
    """

    def __init__(self, data):
        # The dict is stored as-is; nothing is copied or converted.
        self.data = data

    def __len__(self):
        # The 'sst' entry defines the number of samples.
        sst = self.data['sst']
        return len(sst)

    def __getitem__(self, idx):
        fields = self.data
        inputs = tuple(fields[key][idx] for key in ('sst', 't300', 'ua', 'va'))
        return inputs, fields['label'][idx]

In [66]:
# Re-open the CMIP files and rebuild the raw (un-windowed) SST array.
# NOTE(review): this repeats the loading steps of an earlier cell without the
# split_month call.
train = xr.open_dataset('../tcdata/enso_round1_train_20210201/CMIP_train.nc')
# NOTE(review): `label` is opened but not used in this cell.
label = xr.open_dataset('../tcdata/enso_round1_train_20210201/CMIP_label.nc')    
train_sst = train['sst'].values
# Keep samples [0, 151*5), [151*9, 151*12) and [151*13, end); the ranges in between are dropped.
train_sst= np.concatenate((train_sst[:151*5],train_sst[151*9:151*12],train_sst[151*13:])) 

In [91]:
# Take the first 10 samples of train_sst.
# NOTE(review): this rebinds `a`, which an earlier cell assigns from split_month_label.
a=train_sst[:10]

In [92]:
# Display the dimensions of `a`.
a.shape

(10, 36, 24, 72)

In [83]:
#a=a.reshape(1,10,36,24,72)

In [85]:
#a.shape

(1, 10, 36, 24, 72)

In [96]:
# Stack `a` with itself along a new leading axis and display the result's dimensions.
b=np.array([a,a])
b.shape

(2, 10, 36, 24, 72)

In [59]:
# Build both datasets; load_data2 returns (train_dataset, valid_dataset), so
# res1 is the training set and res2 the validation set.
res1,res2=load_data2()

Train samples: 5976, Valid samples: 1176
