## Libraries

In [None]:
import pickle as pkl
import builtins

In [None]:
# Update torch and torchvision: pin torch 1.6.0 / torchvision 0.7.0 "+cu101" wheels
# from the official PyTorch wheel index (fastai 1.0.61 is installed in a later cell).
!pip install -q torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch

In [None]:
# Pin fastai to 1.0.61: the rest of the notebook uses the fastai v1 API
# (fastai.basic_train, fastai.vision.*, LearnerCallback, ...).
!pip install fastai==1.0.61

In [None]:
import fastai
# Sanity check that the pinned version is the one actually imported.
print("fastai version: ", fastai.__version__) # expected: 1.0.61

In [None]:
from fastai.basic_train import *
from fastai.vision.data import *
from fastai.vision.image import *
from fastai.vision.transform import *
from fastai.vision.models import *
from fastai.vision.learner import *
from fastai.vision import *

## Mounting Google Drive

In [None]:
# Mount Google Drive so the data and model paths under
# '/content/gdrive/My Drive/...' used below are reachable (Colab only).
from google.colab import drive
drive.mount('/content/gdrive')

## Paths

In [None]:
output_path   = Path('/content/gdrive/My Drive/Colab Notebooks/Slovenia/output_9_9')

In [None]:
train_path = output_path/'train_patchlets_64_64'

In [None]:
test_path = output_path/'test_patchlets_64_64'

## fastai v1

### Create custom ItemList and LabelList classes to define data loading and display

In [None]:
class SegmentationPklLabelList(SegmentationLabelList):
    "`SegmentationLabelList` whose masks are stored as pickled numpy arrays."

    def open(self, fn):
        """Load the pickled mask at `fn` and wrap it as an `ImageSegment`.

        A leading axis is added, the array is cast to float32, and axis 3 is squeezed
        away. For a pickled (H, W, 1) array this yields a (1, H, W) tensor; that
        on-disk shape is assumed from the squeeze (TODO confirm against the targets).
        """
        # `builtins.open` guarantees the built-in file opener even if a star import
        # above rebinds `open`. The `with` block closes the handle, which the
        # previous one-liner leaked.
        # NOTE(review): pickle can execute arbitrary code; only load trusted files.
        with builtins.open(str(fn), 'rb') as f:
            x = pkl.load(f)[None, ...].astype(np.float32)
        return ImageSegment(torch.tensor(np.squeeze(x, 3)))

In [None]:
class SegmentationPklList(SegmentationItemList):
    "`SegmentationItemList` whose inputs are pickled numpy arrays, labelled by `SegmentationPklLabelList`."
    _label_cls,_square_show_res = SegmentationPklLabelList,False

    def open(self, fn):
        """Load the pickled feature array at `fn` and wrap it as an `Image`.

        The array is transposed with axes [0, 3, 1, 2] and cast to float32. For this
        dataset the result is 31*9*64*64, presumably (time, band, H, W); TODO confirm.
        """
        # The `with` block closes the file handle, which the previous one-liner leaked.
        # NOTE(review): pickle can execute arbitrary code; only load trusted files.
        with builtins.open(str(fn), 'rb') as f:
            x = pkl.load(f)
        x = x.transpose([0, 3, 1, 2]).astype(np.float32)  # 31*9*64*64
        return Image(torch.tensor(x))

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show the `xs` (inputs) and `ys` (targets) on a figure of `figsize`."
        rows = int(np.ceil(math.sqrt(len(xs))))
        axs = subplots(rows, rows, imgsize=imgsize, figsize=figsize)
        for x, y, ax in zip(xs, ys, axs.flatten()):
            # Items produced by `open` are 4-D. Slicing channels 0:3 of the whole item
            # would still be 4-D, which cannot be drawn, so preview the first slice
            # of the leading axis as a 3-D (C, H, W) image.
            frame = x.data[0] if x.data.dim() == 4 else x.data
            # First three channels, brightened x3.5 and clipped to [0, 1] for display.
            Image(torch.clamp(frame[0:3, :, :] * 3.5, 0, 1)).show(ax=ax, y=y, alpha=0.4, **kwargs)
        for ax in axs.flatten()[len(xs):]: ax.axis('off')
        plt.tight_layout()

In [None]:
# Class names in label order (presumably mask value i -> classes[i]). 'No Data' at
# index 0 is the value excluded by the loss (ignore_index=0) and by pixel_acc.
classes = [
    'No Data', 'Cultivated_land', 'Forest', 'Grassland', 'Shrubland',
    'Water', 'Wetland', 'Artificial_surface', 'Bareland',
]

In [None]:
def valid_patch(fn, i=6):
    """Return True if `fn` belongs to patch number `i` ('patch_6' but not 'patch_60').

    Intended for `split_by_valid_func`, to hold out one whole patch for validation.
    """
    import re
    # A bare substring test would also match 'patch_60', 'patch_61', ...; require
    # that the patch number is not followed by another digit.
    return re.search(rf'patch_{re.escape(str(i))}(?!\d)', str(fn)) is not None


def get_mask(fn):
    "Return the mask path paired with feature file `fn` (every 'feat' becomes 'targ')."
    return str(fn).replace('feat', 'targ')


def exclude_masks(fn):
    "Return True for feature files, i.e. files whose name does not contain 'targ'."
    return 'targ' not in str(fn.name)

In [None]:
# Batch size for the training DataBunch (32 was used previously).
bs = 2
# bs =32

In [None]:
# Labelled item lists: every .pkl under train_path (recursively), minus the mask
# files, split at random into 90% train / 10% valid (seed 42), with each feature
# file labelled by the mask file that get_mask derives from its path.
src = (SegmentationPklList.from_folder(train_path, extensions=['.pkl'], recurse=True, convert_mode='L')
      .filter_by_func(exclude_masks)
      .split_by_rand_pct(0.1, seed=42)
      #.split_by_valid_func(valid_patch)  # alternative: hold out one whole patch
      .label_from_func(get_mask, classes=classes))
src

In [None]:
# stats_data = src.databunch(bs=128)
stats_data = src.databunch(bs=16)

In [None]:
stats_data.batch_stats

In [None]:
x,y = stats_data.one_batch()

In [None]:
x.shape

In [None]:
y.shape

In [None]:
means = x.mean(dim=[0,3,4])
stds = x.std(dim=[0,3,4])
# means = x.mean(dim=[0,2,3])
# stds = x.std(dim=[0,2,3])

### Define data augmentation and get databunch

In [None]:
# Augmentations: horizontal and vertical flips, rotation up to 20 degrees, zoom up
# to 1.1, warp 0.2 (affine transforms applied with p=0.75), no lighting changes,
# plus cutout (5-10 holes, length 3-8, p=0.75) applied to the input only (use_on_y=False).
tfms = get_transforms(
    do_flip = True,
    flip_vert = True,
    max_rotate = 20,
    max_zoom = 1.1,
    max_lighting = 0.,
    max_warp = 0.2,
    p_affine = 0.75,
    p_lighting = 0.,
    xtra_tfms = [cutout(n_holes=(5,10), length=(3, 8), p=0.75, use_on_y=False)]
)
tfms

In [None]:
# Drop the first training transform and all validation transforms: these are the
# resize/crop transformations, which don't work with the target mask.
tfms = [tfms[0][1:],[]]
tfms

In [None]:
data = (src
        .transform(tfms,
                  tfm_y=True)
        .add_test(test_set,tfms=None, tfm_y=False)
        .databunch(bs=bs, num_workers = 0)
        .normalize(stats=(means,stds))
        )

In [None]:
data.batch_stats

## Loss Functions

In [None]:
from torch import nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal loss built on top of an element-wise criterion `crit`.

    With `loss = crit(inputs, targets)` evaluated without reduction, each element
    contributes `alpha * (1 - pt)**gamma * loss` where `pt = exp(-loss)`.

    Args:
        crit: base criterion (here fastai's `CrossEntropyFlat`). If it has a
            `reduction` attribute, that attribute is switched to 'none' for the call
            and restored afterwards. This makes the focal factor act per element
            instead of on an already-averaged loss.
        alpha: global scaling factor.
        gamma: focusing parameter; 0 recovers `alpha * crit`.
        reduction: used when `forward` gets no `reduction`. 'mean', 'sum', or
            'none'/None for the element-wise tensor. Exposing this attribute lets
            `myMixUpLoss` switch it to 'none' directly.

    NOTE(review): if `crit` applies class weights, `exp(-loss)` is computed from the
    weighted loss rather than the true class probability. Elements hit by
    `ignore_index` contribute 0 and still count in the mean. Confirm both are intended.
    """
    _DEFAULT = object()  # sentinel: "no reduction passed, use self.reduction"

    def __init__(self, crit, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.crit = crit
        self.reduction = reduction

    def _elementwise_loss(self, inputs, targets):
        "Evaluate `crit` unreduced, restoring its own `reduction` setting afterwards."
        if not hasattr(self.crit, 'reduction'):
            return self.crit(inputs, targets)
        old_reduction = self.crit.reduction
        self.crit.reduction = 'none'
        try:
            return self.crit(inputs, targets)
        finally:
            self.crit.reduction = old_reduction

    def forward(self, inputs, targets, reduction=_DEFAULT):
        """Return the focal loss of `inputs` against `targets`.

        `reduction` defaults to `self.reduction`. None or 'none' returns the
        element-wise tensor, 'sum' returns the sum, and anything else the mean.
        (It used to be a required argument, and only None skipped the mean.)
        """
        if reduction is self._DEFAULT:
            reduction = self.reduction
        loss = self._elementwise_loss(inputs, targets)
        pt = torch.exp(-loss)
        F_loss = self.alpha * (1-pt)**self.gamma * loss

        if reduction is None or reduction == 'none':
            return F_loss
        if reduction == 'sum':
            return torch.sum(F_loss)
        return torch.mean(F_loss)

In [None]:
inv_freq = np.array([0.125314, 0.016608, 0.049898, 0.586994, 2.257207, 24.71436, 0.107695, 98.45247])

In [None]:
inv_freq = [0.,*inv_freq]
inv_freq

In [None]:
inv_prop = torch.tensor(inv_freq/sum(inv_freq)).float().cuda()
inv_prop

In [None]:
focal_loss = FocalLoss(crit=CrossEntropyFlat(axis=1,weight=inv_prop,ignore_index=0)) # For Fastaiv1

In [None]:
class myMixUpCallback(LearnerCallback):
    """Callback that creates the mixed-up input and target.

    With `stack_y=True` the learner's loss is wrapped in `myMixUpLoss` for the
    duration of training. Targets are then passed as [target, shuffled target,
    lambda] stacked along dim 1, so the wrapper can mix the two losses.
    """
    def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True):
        super().__init__(learn)
        # alpha parameterizes the Beta(alpha, alpha) distribution of mixing weights.
        self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y

    def on_train_begin(self, **kwargs):
        "Wrap the loss so it understands stacked targets (undone in `on_train_end`)."
        if self.stack_y: self.learn.loss_func = myMixUpLoss(self.learn.loss_func)

    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        "Applies mixup to `last_input` and `last_target` if `train`."
        if not train: return
        # One weight per sample from Beta(alpha, alpha), folded to max(l, 1-l) in
        # [0.5, 1] so every mixed sample keeps at least half of its original.
        lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))
        lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
        lambd = last_input.new(lambd)  # tensor of the same type as last_input

        # Random pairing: each sample is mixed with another sample of the same batch.
        shuffle = torch.randperm(last_target.size(0)).to(last_input.device)
        x1, y1 = last_input[shuffle], last_target[shuffle]
        if self.stack_x:
            # Leave the mixing to the model: pass both inputs and the weights.
            new_input = [last_input, last_input[shuffle], lambd]
        else:
            # Reshape lambd to (B, 1, ..., 1) so it broadcasts over every input axis.
            out_shape = [lambd.size(0)] + [1 for _ in range(len(x1.shape) - 1)]
            new_input = (last_input * lambd.view(out_shape) + x1 * (1-lambd).view(out_shape))
        if self.stack_y:

            # lambd[:,None,None,None] assumes 4-D targets, presumably (B, 1, H, W)
            # masks (TODO confirm); broadcast it to the target's full shape.
            new_lambd = torch.distributions.utils.broadcast_all(lambd[:,None,None,None], last_target)[0]

            #new_target = torch.cat([last_target[:,None].float(), y1[:,None].float(), new_lambd[:,None].float()], 1)
            # Stacking along dim 1 adds one axis; myMixUpLoss detects the stacked
            # form by its rank (>= 5 dims).
            new_target = torch.stack([last_target.float(), y1.float(), new_lambd.float()], 1)
        else:
            # Soft targets: blend the two target sets directly (only 2-D targets
            # get lambd reshaped here).
            if len(last_target.shape) == 2:
                lambd = lambd.unsqueeze(1).float()
            new_target = last_target.float() * lambd + y1.float() * (1-lambd)

        return {'last_input': new_input, 'last_target': new_target}

    def on_train_end(self, **kwargs):
        "Restore the unwrapped loss function after training."
        if self.stack_y: self.learn.loss_func = self.learn.loss_func.get_old()

In [None]:
from functools import partial

In [None]:
class myMixUpLoss(nn.Module):
    """Adapt the loss function `crit` to go with mixup.

    `crit` is made to return unreduced losses. If it has a `reduction` attribute,
    that is set to 'none'; otherwise it is called with `reduction='none'`. The
    (mixed) loss is then reduced here according to `reduction`.
    """

    def __init__(self, crit, reduction='mean'):
        super().__init__()
        if hasattr(crit, 'reduction'):
            self.crit = crit
            self.old_red = crit.reduction
            setattr(self.crit, 'reduction', 'none')
        else:
            self.crit = partial(crit, reduction='none')
            self.old_crit = crit
        self.reduction = reduction

    def forward(self, output, target):
        """Return the loss of `output` against `target`, mixed when `target` is stacked.

        A target with 5 or more dims is the stacked form built by the mixup callback:
        target[:, 0] and target[:, 1] are the two label sets and target[:, 2] holds
        the mixing weights. The weights are flattened, so `crit` is expected to
        return a flat per-element loss (as fastai's flattened losses do); TODO
        confirm for other criteria.
        """
        if len(target.size()) >= 5:
            loss1, loss2 = self.crit(output, target[:, 0].long()), self.crit(output, target[:, 1].long())
            lambd = target[:, 2].contiguous().view(-1)
            # Keep the mix element-wise; averaging here made 'sum' and 'none'
            # below receive an already-averaged scalar.
            d = loss1 * lambd + loss2 * (1 - lambd)
        else:
            d = self.crit(output, target)
        if self.reduction == 'mean': return d.mean()
        elif self.reduction == 'sum': return d.sum()
        return d

    def get_old(self):
        "Return the original criterion, restoring its `reduction` if it was changed."
        if hasattr(self, 'old_crit'): return self.old_crit
        elif hasattr(self, 'old_red'):
            setattr(self.crit, 'reduction', self.old_red)
            return self.crit

In [None]:
def pixel_acc(inputs, targs):
    """Pixel accuracy restricted to labelled pixels.

    `inputs` holds per-class scores along dim 1. Their arg-max (given a singleton
    dim 1) is compared with `targs` wherever the target is non-zero, and the fraction
    of matches is returned as a float tensor (NaN if no pixel is labelled).
    """
    preds = inputs.argmax(dim=1).unsqueeze(1)
    labelled = targs != 0
    hits = targs[labelled] == preds[labelled]
    return hits.float().mean()

## Model

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DownsampleDilatedBlock(nn.Module):
    """Encoder block of DilatedSE3DUNET's down path, built as (in_channels, out_channels).

    The implementation is withheld. A class body containing only a comment is a
    SyntaxError that stopped the whole cell from running; this stub keeps the cell
    valid and fails loudly if the block is instantiated.
    """
    # This code is hidden, will be released after the acceptance of our manuscript.
    def __init__(self, *args, **kwargs):
        raise NotImplementedError(
            "DownsampleDilatedBlock is hidden until the manuscript is accepted.")

class DoubleConv(nn.Module):
    """Block used for the bottleneck and each decoder stage, built as (in_channels, out_channels).

    The implementation is withheld. A class body containing only a comment is a
    SyntaxError that stopped the whole cell from running; this stub keeps the cell
    valid and fails loudly if the block is instantiated.
    """
    # This code is hidden, will be released after the acceptance of our manuscript.
    def __init__(self, *args, **kwargs):
        raise NotImplementedError(
            "DoubleConv is hidden until the manuscript is accepted.")

class Interp(nn.Module):
    """First module of each decoder stage, built as (in_channels, out_channels).

    Its output is concatenated with the matching skip connection (presumably an
    upsampling step; TODO confirm once released). The implementation is withheld.
    A comment-only class body is a SyntaxError, so this stub keeps the cell valid
    and fails loudly if instantiated.
    """
    # This code is hidden, will be released after the acceptance of our manuscript.
    def __init__(self, *args, **kwargs):
        raise NotImplementedError(
            "Interp is hidden until the manuscript is accepted.")

class SE(nn.Module):
    """Attention module applied to each concatenated (skip, decoder) tensor, built as SE(channels).

    The implementation is withheld. A comment-only class body is a SyntaxError,
    so this stub keeps the cell valid and fails loudly if instantiated.
    """
    # This code is hidden, will be released after the acceptance of our manuscript.
    def __init__(self, *args, **kwargs):
        raise NotImplementedError(
            "SE is hidden until the manuscript is accepted.")

class SE2d(nn.Module):
    """Attention module applied to the 2-D feature map before the final 1x1 conv, built as SE2d(channels).

    The implementation is withheld. A comment-only class body is a SyntaxError,
    so this stub keeps the cell valid and fails loudly if instantiated.
    """
    # This code is hidden, will be released after the acceptance of our manuscript.
    def __init__(self, *args, **kwargs):
        raise NotImplementedError(
            "SE2d is hidden until the manuscript is accepted.")

class DilatedSE3DUNET(nn.Module):
    """U-Net on 5-D inputs with attention modules on the concatenated skip connections.

    The encoder (`DownsampleDilatedBlock` + 3-D max pooling), bottleneck and decoder
    run on (batch, channels, D, H, W) tensors. The decoder output's channel and D
    axes are then merged into one channel axis. A 2-D attention module and a 1x1
    convolution produce `out_channels` maps of the same H x W.

    Args:
        in_channels: size of axis 1 of the input.
        out_channels: number of output maps (one per class).
        features: encoder widths, shallowest first; each has a mirrored decoder stage.

    NOTE(review): `attn2d`/`final_conv` are hard-wired to 512 merged channels. This
    must equal (decoder output channels) * (decoder output depth); confirm before
    changing `features` or the input depth.
    """
    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256)):
        super(DilatedSE3DUNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attn = nn.ModuleList()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Down part of UNET: one encoder block per entry of `features`.
        for feature in features:
            self.downs.append(DownsampleDilatedBlock(in_channels, feature))
            in_channels = feature

        # Up part of UNET: (Interp, DoubleConv) pairs, deepest stage first.
        for feature in reversed(features):
            self.ups.append(Interp(feature*2, feature))
            self.ups.append(DoubleConv(feature*2, feature))

        # Attention modules, one per decoder stage, applied to the concatenated tensor.
        for feature in reversed(features):
            self.attn.append(SE(feature*2))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.attn2d = SE2d(512)
        self.final_conv = nn.Conv2d(512, out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Encoder: keep each block's output for the matching decoder stage.
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        # Decoder: self.ups alternates (Interp, DoubleConv) for each stage.
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            # Center-crop the skip connection when its (D, H, W) differs from x's.
            if x.shape[2:] != skip_connection.shape[2:]:
                es = x.shape[2:]
                ds = skip_connection.shape[2:]
                skip_connection = skip_connection[
                    :,
                    :,
                    ((ds[0] - es[0]) // 2):((ds[0] + es[0]) // 2),
                    ((ds[1] - es[1]) // 2):((ds[1] + es[1]) // 2),
                    ((ds[2] - es[2]) // 2):((ds[2] + es[2]) // 2),
                ]

            concat_skip = torch.cat((skip_connection, x), dim=1)
            att = self.attn[idx//2](concat_skip)
            x = self.ups[idx+1](att)

        # Merge channels and depth into one channel axis. H and W come from the tensor
        # itself instead of a hard-coded 64x64, so other patch sizes reshape correctly.
        n, c, d, h, w = x.shape
        x = x.reshape(n, c * d, h, w)
        x = self.attn2d(x)
        return self.final_conv(x)

In [None]:
def test():
    x = torch.randn((4, 31, 9, 64, 64))
    model = DilatedSE3DUNET(in_channels=31, out_channels=11)
    #print(model)
    preds = model(x)
    #print(preds)
    print("preds.shape:", preds.shape)

In [None]:
test()

## Training

In [None]:
# Input channels: size of axis 0 of one stored input item (the Conv3d channel axis),
# read from the data so it always matches the 31*9*64*64 item layout.
n_in = data.train_ds.x[0].data.shape[0]
# Output channels: one per class. The weighted CrossEntropy inside focal_loss needs
# exactly one class weight (inv_prop) per output channel.
n_out = len(classes)
assert len(inv_prop) == n_out, "class weights must match the number of output channels"
learn = Learner(data,
                DilatedSE3DUNET(in_channels=n_in, out_channels=n_out),
                loss_func=focal_loss,
                metrics=[pixel_acc],
                callback_fns=[partial(myMixUpCallback,alpha=0.4, stack_y=True)])

In [None]:
# Layer-by-layer summary of the model on one batch.
learn.summary()

In [None]:
# Learning-rate range test; the plot guides the max_lr choices below.
learn.lr_find()
learn.recorder.plot()

In [None]:
# First one-cycle run: 5 epochs, max_lr 1e-2, weight decay 0.3.
learn.fit_one_cycle(5, max_lr = 1e-2, wd = 0.3)

In [None]:
# NOTE(review): the model is built from scratch and nothing in this notebook freezes
# layers, so this call presumably has no effect; confirm before relying on it.
learn.unfreeze()

In [None]:
# Second one-cycle run: 20 more epochs at a 10x lower max_lr.
learn.fit_one_cycle(20, max_lr=1e-3, wd=0.3)

In [None]:
# Save the trained weights under this name.
learn.save('Final_focal_loss_31_9_64_64')

In [None]:
# Export the trained learner to an absolute path on Google Drive.
learn.export('/content/gdrive/My Drive/Colab Notebooks/Slovenia/Final_focal_loss_31_9_64_64.pkl')