In [None]:
# headers
from fastai.vision import *
from fastai.vision.gan import *
from fastai import *
# Every explicit import stays below the wildcard imports on purpose: names such
# as `io` and `data` must take precedence over anything the wildcards export.
import gc
import os
import random
random.seed(0)
import re
import time
from collections import OrderedDict
import numpy as np
import pandas as pd
import torch
import skimage.io as io
from torch import nn
from torch.backends import cudnn
import torch.utils.data as data
from torch import optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.nn import functional as F
from torch.autograd import Variable as V
from tqdm.notebook import tqdm, trange
import tifffile as tff
from network import SSLI

In [None]:
def is_clear(qa_value):
    """Tell whether a QA value has none of its five lowest bits set.

    Bits 0-4 are treated as the cirrus, cloud, adjacent, shadow and snow
    flags respectively; any higher bit is ignored.

    Args:
        qa_value: a Python int or an integer NumPy array of QA values.

    Returns:
        A boolean (or an element-wise boolean array) that is True where all
        five flags are zero.
    """
    flag_bits = {'cirrus': 0, 'cloud': 1, 'adjacent': 2, 'shadow': 3, 'snow': 4}
    contaminated = 0
    for bit in flag_bits.values():
        contaminated |= 1 << bit
    # One mask test replaces five separate shift-and-compare steps.
    return (qa_value & contaminated) == 0
def date_to_string(target):
    """Format a day number as text, left-padded with zeros to three digits.

    Values of zero or below are reported on stdout and mapped to '000'.
    """
    if target <= 0:
        print('Receiving date 0')
        return '000'
    if target < 10:
        return f'00{target}'
    if target < 100:
        return f'0{target}'
    return str(target)
def search_pair(target, path_L, date_L, path_S, date_S, radius=1):
    """List the available dates of each sensor inside a window around `target`.

    Args:
        target: centre date of the window.
        path_L, path_S: accepted for interface compatibility; not used here.
        date_L, date_S: collections of dates available for each sensor.
        radius: half-width of the window; the window is
            [target - radius, target + radius], both ends included.

    Returns:
        (L_dates, S_dates): two lists in ascending order holding the dates of
        the window found in `date_L` and `date_S` respectively.
    """
    window = range(target - radius, target + radius + 1)
    L_dates = [day for day in window if day in date_L]
    S_dates = [day for day in window if day in date_S]
    return L_dates, S_dates

def predict_full(ts, model):
    """Fill the missing steps of one time series with the model's output.

    Args:
        ts: 2-D NumPy array, one row per time step. Columns 0-6 are read as
            band values and column 7 as the quality flag (1 = observed,
            0 = missing). Assumed to have exactly 8 columns and band values
            already normalised with the module-level `mean`/`std` -- TODO
            confirm against the caller.
        model: network called on a CUDA float tensor of shape (1, T, 9); its
            output is indexed as `y[0][missing]` and written into 7 columns,
            so it presumably returns (1, T, 7) -- TODO confirm.

    Returns:
        Array of shape (T, 7): observed rows keep the input values, missing
        rows take the model output, and everything is rescaled with
        `* std + mean`.
    """
    qa = ts[:, 7]
    missing = qa == 0
    n_steps = ts.shape[0]

    pred = np.zeros((n_steps, 9))
    pred[:, :8] = ts                      # direct copy; no ones_like temporary
    pred[missing, :7] = -1                # placeholder value for missing bands
    pred[qa == 1, 8] = 1                  # observation flag: 1 = observed
    pred[missing, 8] = -1                 #                  -1 = missing
    pred[:, 7] = np.arange(n_steps) / 100.  # time-position feature

    x = torch.from_numpy(pred).float().cuda().unsqueeze(0)
    # Pure inference: disabling autograd avoids building a graph and keeping
    # activations alive; the numerical output is unchanged.
    with torch.no_grad():
        y = model(x).cpu().numpy()

    pred[missing, :7] = y[0][missing]
    return pred[:, :7] * std + mean
# Per-band statistics used to normalise the seven band columns; each has
# shape (1, 7) so it broadcasts over the time axis.
mean = np.array([[546.7921,  649.3897,  967.5513, 1159.1942, 2441.7659, 2432.7739, 1822.7062]])
std = np.array([[447.3489,  519.1541,  668.2482,  904.4891, 1167.2008, 1226.2954, 1155.0743]])


class Loader_pred(data.Dataset):
    """Dataset that serves one pixel time series per item with one step hidden.

    Each item is a float32 tensor of shape (T, 9):
      * columns 0-6: band values normalised with `mean`/`std`, set to zero
        wherever the quality flag is 0 or the step is the hidden one;
      * column 7: time position, i.e. step index divided by 100;
      * column 8: 1 where the quality flag equals 1 and the step is not
        hidden, otherwise 0.
    """

    def __init__(self, nts, index_list, pred_date):
        """Keep the data array, the pixel indices to serve and the hidden step.

        Args:
            nts: array indexed as [pixel, time, channel]; channels 0-6 are
                bands and channel 7 is the quality flag.
            index_list: pixel indices (first axis of `nts`) to iterate over.
            pred_date: index on the time axis that is masked out in each item.
        """
        self.nts = nts
        self.index_list = index_list
        self.mask_index = pred_date

    def __getitem__(self, index):
        """Return `(series, index)` for the `index`-th entry of `index_list`."""
        series = self.nts[self.index_list[index]]
        n_steps = self.nts.shape[1]

        # Quality flags, with the target step forced to "missing".
        valid = series[:, 7].astype(np.float64)
        valid[self.mask_index] = 0

        feats = np.zeros((n_steps, 9))
        feats[:, :7] = (series[:, :7].astype(np.float64) - mean) / std
        feats[:, 7] = np.arange(n_steps) / 100.
        feats[valid == 1, 8] = 1
        # Blank the bands of every missing step; the hidden step is already
        # flagged 0 above, the explicit line just keeps that intent visible.
        feats[valid == 0, :7] = 0
        feats[self.mask_index, :7] = 0

        return torch.from_numpy(feats).float(), index

    def __len__(self):
        """Number of pixels served, i.e. the length of `index_list`."""
        return len(self.index_list)

In [None]:
def fuse_HLS(dates, year, target_dir, radius=1, tile_size=3660):
    """Build a per-date composite cube from the L30 and S30 scenes of one tile.

    For every target date, all scenes of both sensors whose day falls within
    `radius` days are loaded. Pixels flagged clear by `is_clear` are copied
    into the composite; when several scenes are clear at the same pixel, the
    scene with the most clear pixels overall wins.

    Args:
        dates: 1-D sequence of target days.
        year: sub-directory name under `dataset/L30` and `dataset/S30`.
        target_dir: tile path below the year directory (e.g. '11/S/Q/T').
        radius: half-width in days of the search window around each target
            date. Defaults to 1, the value `search_pair` used implicitly.
        tile_size: height and width in pixels of every band image.

    Returns:
        (cube, total_n):
          * cube: int32 view of shape (len(dates), tile_size, tile_size, 8).
            Channels 0-6 hold band values (-9999 where nothing clear was
            found) and channel 7 the clear flag (1 = clear, 0 = not).
          * total_n: array of length tile_size * tile_size with, per pixel,
            the number of dates that have a clear observation.

    Scene folder names are expected to carry a three-digit day number at
    characters 19:22; that slice is parsed with `int`.
    """
    L_bands = ['B01', 'B02','B03','B04','B05','B06','B07','Fmask']
    S_bands = ['B01', 'B02','B03','B04','B8A','B11','B12','Fmask']
    L_root = f'dataset/L30/{year}/{target_dir}'
    S_root = f'dataset/S30/{year}/{target_dir}'
    path_L = sorted(os.listdir(L_root))
    path_S = sorted(os.listdir(S_root))

    # Map day -> folder straight from each folder's own name. A date returned
    # by search_pair therefore always resolves to a real folder, instead of
    # relying on a regex scan that could leave a stale name behind. With
    # several folders for one day, the last in sorted order wins.
    folder_L = {int(name[19:22]): name for name in path_L}
    folder_S = {int(name[19:22]): name for name in path_S}
    date_L = list(folder_L)
    date_S = list(folder_S)

    n_dates = len(dates)
    n_pixels = tile_size * tile_size
    # np.full allocates the cube once; np.ones(...) * -9999 allocated it twice.
    fuse_cube = np.full((n_dates, tile_size, tile_size, 9), -9999, dtype=np.int32)
    fuse_cube[:, :, :, 8] = 0

    for d_i, target_date in enumerate(tqdm(dates)):
        L_dates, S_dates = search_pair(target_date, path_L, date_L, path_S, date_S, radius=radius)
        scenes = [(L_root, folder_L[d], L_bands) for d in L_dates]
        scenes += [(S_root, folder_S[d], S_bands) for d in S_dates]
        if not scenes:
            # No imagery near this date: the slice keeps its fill values.
            continue

        base_img = np.empty((len(scenes), tile_size, tile_size, 9), dtype=np.int32)
        for s_i, (root, folder, bands) in enumerate(scenes):
            for b_i, band in enumerate(bands):
                base_img[s_i, :, :, b_i] = io.imread(f'{root}/{folder}/{folder}.{band}.tif')

        clear = is_clear(base_img[:, :, :, 7])
        base_img[:, :, :, 8] = clear

        # Paint scenes from least to most clear pixels so that, where several
        # scenes are clear, the clearest scene is written last and wins. Only
        # clear pixels are ever copied, so the rest keeps -9999 / flag 0.
        clear_per_scene = clear.reshape(len(scenes), -1).sum(axis=1)
        fused = fuse_cube[d_i]  # view: writes go straight into the cube
        for s_i in np.argsort(clear_per_scene, kind='stable'):
            fused[clear[s_i]] = base_img[s_i, clear[s_i]]

    # Done once after the loop (it used to run on every iteration): replace
    # the raw Fmask channel with the clear flag and count clear dates per pixel.
    fuse_cube[:, :, :, 7] = fuse_cube[:, :, :, 8]
    total_n = fuse_cube[:, :, :, 8].reshape(n_dates, n_pixels).sum(axis=0)
    return fuse_cube[:, :, :, :8], total_n

In [None]:
# Build the network and restore its weights from the checkpoint.
model_name = 'SSLI'
model_id = 'model_49'
model = SSLI(7,256,4,8,7,0).cuda()
# Optional multi-GPU wrapper; the key matching below copes with either form.
# model = nn.DataParallel(model)
os.makedirs(f'models/{model_name}/{model_id}', exist_ok=True)
model_path = f'D:/HLS/models/{model_name}/{model_id}.pth'
state_dict = torch.load(model_path)['model']

# Checkpoints saved from a DataParallel model prefix every key with 'module.',
# while a bare model expects none. Each key is resolved against the names the
# current model actually exposes. 'module.<key>' is tried first because that
# is what this cell produced before; then the key as stored; then the key with
# a leading 'module.' removed. Unmatched keys are passed through unchanged so
# load_state_dict reports them by their checkpoint name.
expected_keys = set(model.state_dict().keys())
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    candidates = [f'module.{k}', k]
    if k.startswith('module.'):
        candidates.append(k[len('module.'):])
    name = next((c for c in candidates if c in expected_keys), k)
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model.eval();

In [None]:
# Target dates run from start_date to end_date (inclusive) in steps of
# 2 * date_radius + 1 days.
start_date = 1
end_date = 365
date_radius = 1
dates = np.arange(start_date, end_date + 1, 2 * date_radius + 1)

# Tile directories to process, and for each one the time-axis position that
# is hidden from the model and then predicted.
test_path = ['11/S/Q/T', '12/T/V/M', '14/T/Q/P', '17/R/N/L', '18/T/W/N']
DOYs = [60, 55, 37, 119, 31]

In [None]:
# Predict the hidden time step for every pixel of one tile and save the result.
tile_id = 0
batch_size = 200
num_workers = 40
tile_name = test_path[tile_id].replace('/', '')   # '11/S/Q/T' -> '11SQT'
target_index = DOYs[tile_id]                      # time-axis position to predict

cube, _ = fuse_HLS(dates, 2021, test_path[tile_id])
# Flatten the two spatial axes and put pixels first: (pixel, time, channel).
# Sizes come from the cube itself rather than hard-coded 122 / 3660*3660.
n_dates, _, _, n_channels = cube.shape
fuse_vector = cube.reshape(n_dates, -1, n_channels).transpose(1, 0, 2)
N = fuse_vector.shape[0]
test_list = list(range(N))
test_dataset = Loader_pred(fuse_vector, test_list, target_index)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=False,
    num_workers=num_workers)

pred = np.zeros((N, 7), dtype=float)
# Inference only: no_grad stops autograd from recording a graph per batch.
with torch.no_grad():
    for batch_in, index in tqdm(test_loader):
        outs = model(batch_in.cuda()).cpu().numpy()
        # Scatter the whole batch at once instead of one sample per Python step.
        pred[index.numpy()] = outs[:, target_index]

with open(f'models/{model_name}/pred_{tile_name}.npy', 'wb') as f:
    np.save(f, pred)

# The dataset and loader hold references to fuse_vector, so they must be
# dropped too or the cube's memory is never released.
del test_loader, test_dataset, fuse_vector, cube
gc.collect();