In [None]:
from fastai.vision import *
from fastai.vision.gan import *
from fastai import *
import os
import random
import time
import numpy as np
import torch
from torch import nn
from torch.backends import cudnn
import torch.utils.data as data
from torch import optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.nn import functional as F
from tqdm import tqdm, trange
from network import SSLI
import gc

In [None]:
# Column schema of the sample table: 122 indexed slots (000..121, presumably
# time steps -- confirm), each carrying seven spectral bands followed by a QA
# column, in exactly this order.
band_names = ['coastal_blue', 'blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'qa']
lst = [f'{i:03d}.{band}' for i in range(122) for band in band_names]

# All numeric columns are parsed as int16; the two label columns as strings.
col_val = ['col', 'row', 'image_year'] + lst + ['total_n']
dtype_dict = {name: np.int16 for name in col_val}
dtype_dict.update({name: 'string' for name in ['tile_id', 'TVT']})
# NOTE(review): `pd` is not imported explicitly in this notebook; it is
# presumably provided by one of the fastai star imports -- confirm.
data_2021 = pd.read_csv('SSLI samples.csv', compression='gzip', dtype=dtype_dict)

# Names of the per-slot QA columns plus the 'total_n' column.
qa_lst = [f'{i:03d}.qa' for i in range(122)] + ['total_n']

# Split the table on the TVT label, then release the full frame.
train_df_2021 = data_2021.loc[data_2021['TVT'] == 'train']
valid_df_2021 = data_2021.loc[data_2021['TVT'] == 'valid']
del data_2021
gc.collect()
# Samples whose QA channel sums to less than this are discarded.
MIN_DOY = 2

# Reshape the flat columns into a (samples, 122, 8) float64 cube:
# channels 0-6 are the bands, channel 7 is the QA flag.
train_mat = train_df_2021[lst].values.reshape(-1, 122, 8).astype(np.float64)
del train_df_2021
train_doy = train_mat[:, :, 7].sum(axis=-1)
train_mat = train_mat[train_doy >= MIN_DOY]

# Per-band mean / std (shape (1, 7), float32) computed only over entries whose
# QA flag equals 1. These module-level tensors are reused for the validation
# data and by the loss.
clear_train = train_mat[:, :, 7] == 1
raw_train = torch.from_numpy(train_mat[clear_train][:, :7]).float()
mean = raw_train.mean(dim=0, keepdim=True)
std = raw_train.std(dim=0, keepdim=True)

# Standardise the QA==1 entries in place; all other entries keep their raw values.
train_mat[clear_train, :7] = (train_mat[clear_train, :7] - mean.numpy()) / std.numpy()
gc.collect()

# Shuffled list of training sample indices.
# NOTE(review): random.seed(0) is only called further down, so this shuffle
# is not seeded.
total_list = list(range(train_mat.shape[0]))
random.shuffle(total_list)

# Validation cube with the same (samples, 122, 8) layout as the training cube.
valid_mat = valid_df_2021[lst].values.reshape(-1, 122, 8).astype(np.float64)
del valid_df_2021

# Standardise QA==1 entries with the statistics computed on the training set.
x, y = np.nonzero(valid_mat[:, :, 7] == 1)
valid_mat[x, y, :7] = (valid_mat[x, y, :7] - mean.numpy()) / std.numpy()

# Drop samples whose QA channel sums to less than MIN_DOY.
valid_doy = valid_mat[:, :, 7].sum(axis=-1)
valid_mat = valid_mat[valid_doy >= MIN_DOY]

# Seeded shuffle of the validation sample indices.
random.seed(0)
test_list = list(range(valid_mat.shape[0]))
random.shuffle(test_list)

gc.collect()

# Build the validation item table: one row (sample index, slot index) for every
# entry of `valid_mat` whose QA flag is non-zero, visiting samples in
# `test_list` order and slots in ascending order within each sample.
#
# This is a single vectorised pass instead of two Python-level loops over every
# observation. np.nonzero returns indices in row-major order, which is exactly
# the order the nested loops produced.
test_order = np.asarray(test_list, dtype=int)
clear_valid_qa = valid_mat[:, :, 7][test_order] != 0   # (len(test_list), 122)
sample_pos, slot_idx = np.nonzero(clear_valid_qa)

test_num = int(sample_pos.shape[0])
items = np.zeros((test_num, 2), dtype=int)
items[:, 0] = test_order[sample_pos]   # map position in test_list back to sample index
items[:, 1] = slot_idx

In [None]:
class Loader_train(data.Dataset):
    """Training dataset that hides a random subset of valid observations.

    Args:
        nts: array indexed as ``nts[sample, slot, channel]`` with at least 8
            channels; channels 0-6 are the band values, channel 7 is the QA
            flag (non-zero means the slot is eligible for masking).
        index_list: sequence of sample indices into ``nts``.
        ratio: fraction of the eligible slots to mask per sample. At least one
            slot and at most all eligible slots are masked.
    """

    def __init__(self, nts, index_list, ratio=0.3):
        self.nts = nts
        self.index_list = index_list
        self.ratio = ratio

    def __getitem__(self, index):
        """Return ``(ts, target)`` for one sample.

        ``ts`` is a float32 tensor of shape (slots, 8): band values in
        channels 0-6 (zeroed where the QA flag is 0 or the slot was masked)
        and, in channel 7, a 1 wherever the remaining QA flag equals 1.
        ``target`` is a float32 tensor of shape (slots, 14): the untouched
        band values followed by a 7-channel mask that is 1 on masked slots.

        Raises:
            ValueError: if the sample has no slot with a non-zero QA flag.
        """
        sample = self.nts[self.index_list[index]]
        qa = sample[:, 7].astype(np.float64)  # astype copies, so nts is never modified
        clr_inx = np.argwhere(qa != 0)[:, 0]
        n_clear = clr_inx.shape[0]
        if n_clear == 0:
            raise ValueError('sample has no clear observation to mask')

        # Clamp to [1, n_clear] so any ratio yields a valid sample size.
        clr_num = min(max(int(n_clear * self.ratio), 1), n_clear)
        # Draw from torch's RNG rather than NumPy's global RNG: torch seeds its
        # generator separately in every DataLoader worker, whereas forked
        # workers can inherit an identical NumPy state (PyTorch < 1.9) and
        # would then all produce the same masks.
        picked = torch.randperm(n_clear)[:clr_num].numpy()
        mask_inx = clr_inx[picked]
        qa[mask_inx] = 0

        ts = np.zeros((self.nts.shape[1], 8))
        ts[:, :7] = sample[:, :7]
        ts[qa == 0, :7] = 0   # hides both originally-invalid and freshly masked slots
        ts[qa == 1, 7] = 1
        ts = torch.from_numpy(ts).float()

        gt = torch.from_numpy(sample[:, :7]).float()
        mask = torch.zeros_like(gt)
        mask[torch.from_numpy(mask_inx)] = 1
        return ts, torch.cat((gt, mask), dim=1)

    def __len__(self):
        """Number of samples, i.e. the length of ``index_list``."""
        return len(self.index_list)
class Loader_test(data.Dataset):
    """Evaluation dataset that hides exactly one given slot per item.

    ``nts`` is indexed as ``nts[sample, slot, channel]`` (channels 0-6 bands,
    channel 7 the QA flag). ``index_list`` is a 2-D array whose rows are
    ``(sample index, slot index to hide)``.
    """

    def __init__(self, nts, index_list):
        self.nts = nts
        self.index_list = index_list

    def __len__(self):
        """Number of (sample, slot) items."""
        return self.index_list.shape[0]

    def __getitem__(self, index):
        """Return ``(ts, target)`` with the item's slot hidden from ``ts``.

        ``ts`` is float32 of shape (slots, 8); ``target`` is float32 of shape
        (slots, 14): the untouched band values followed by a 7-channel mask
        that is 1 on the hidden slot only.
        """
        entry = self.index_list[index]
        sample_idx, hidden_slot = entry[0], entry[1]
        series = self.nts[sample_idx]

        # Work on a copy of the QA channel and mark the hidden slot as invalid.
        flags = series[:, 7].astype(np.float64)
        flags[hidden_slot] = 0

        feats = np.zeros((self.nts.shape[1], 8))
        # Band values survive only where the flag is non-zero ...
        feats[:, :7] = np.where((flags == 0)[:, None], 0.0, series[:, :7])
        # ... and the flag channel is set only where the flag equals 1.
        feats[:, 7] = flags == 1
        ts = torch.from_numpy(feats).float()

        gt = torch.from_numpy(self.nts[sample_idx, :, :7]).float()
        mask = torch.zeros_like(gt)
        mask[hidden_slot] = 1
        target = torch.cat((gt, mask), dim=1)
        return ts, target.float()

In [None]:
class mask_loss(nn.Module):  # used for training the CNN
    """Masked L1 reconstruction loss plus a second-difference smoothness term.

    Args:
        batch: stored on the instance; not used by the loss computation.
        norm_mean: per-band mean used to undo the input standardisation.
            Defaults to the module-level ``mean``.
        norm_std: per-band std used to undo the input standardisation.
            Defaults to the module-level ``std``.
    """

    def __init__(self, batch=True, norm_mean=None, norm_std=None):
        super().__init__()
        self.batch = batch
        self.loss = nn.L1Loss(reduction='mean')
        self.mean = mean if norm_mean is None else norm_mean
        self.std = std if norm_std is None else norm_std

    def forward(self, pred, target):
        """Compute the loss.

        Args:
            pred: model output, indexed as (batch, slot, band) with 7 bands.
            target: tensor whose last axis holds the 7 ground-truth bands
                followed by the mask channels (1 marks a masked entry).

        Returns:
            Scalar tensor: mean absolute error over masked entries plus 0.2
            times the mean absolute second difference of the prediction along
            axis 1.
        """
        y, mask = target[:, :, :7], target[:, :, 7:]
        # Follow pred's device directly; the statistics no longer take a
        # detour through the default CUDA device, so CPU tensors work too.
        mean = self.mean.to(pred.device)
        std = self.std.to(pred.device)
        # Undo the standardisation and divide by 10000 (presumably the
        # reflectance scale factor -- confirm). Only the prediction goes
        # through a sigmoid.
        pred = torch.sigmoid((pred * std + mean) / 10000.)
        y = (y * std + mean) / 10000.
        loss = self.loss(pred[mask == 1.0], y[mask == 1.0])
        # Penalise curvature along axis 1 (difference of consecutive differences).
        change = torch.diff(pred, dim=1)
        smooth_loss = torch.abs(torch.diff(change, dim=1)).mean()
        return loss + smooth_loss * 0.2

In [None]:
# --- Training configuration -------------------------------------------------
ratio = 0.3          # fraction of valid slots masked per training sample
batch_size = 800
num_workers = 40     # NOTE(review): hard-coded; confirm the host has this many cores

# --- Data loaders -------------------------------------------------------------
train_dataset = Loader_train(train_mat, total_list,ratio)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=False,
    num_workers=num_workers)
test_dataset = Loader_test(valid_mat, items)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=False,
    num_workers=num_workers)

train_data = ImageDataBunch(train_dl=train_loader, valid_dl=test_loader)
train_data.sanity_check()

# --- Output directory -----------------------------------------------------------
save_name = 'SSLI'
save_path = os.path.join('models', save_name)
# makedirs also creates the parent 'models' directory when it is missing and
# does not fail if the directory already exists.
os.makedirs(save_path, exist_ok=True)

# --- Model and training -------------------------------------------------------
# NOTE(review): the meaning of the positional SSLI arguments is defined in
# network.py and is not visible here -- confirm before changing them.
model = SSLI(7,256,4,8,7,0.2).cuda()
#model = nn.DataParallel(model)
learn = Learner(train_data, model, model_dir=save_path, loss_func=mask_loss())
# 50 one-cycle epochs; the best model by validation loss is saved as
# 'best_val' and the training log is written under save_path.
learn.fit_one_cycle(50,1e-4, callbacks=[
    callbacks.SaveModelCallback(learn, every='improvement', monitor='valid_loss', name='best_val'),
    callbacks.CSVLogger(learn, os.path.join(save_path, 'record'))], wd=1e-3)
# Save the weights as they are after the last epoch (without optimizer state).
learn.save('model_49', with_opt=False)