In [1]:
import pandas as pd 
import numpy as np 

import neptune.new as neptune

from tqdm.notebook import tqdm 

import torch
import torch.nn as nn

import torchvision.transforms as T

import warnings
warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2


In [None]:
# Load the chip-level metadata table; `scene_start` is parsed into datetimes on read.
metadata = pd.read_csv("../data/raw/metadata.csv", parse_dates=["scene_start"])

# A bare last expression is rendered as the cell output, so the shape is shown
# once here instead of being printed and displayed twice.
metadata.shape

In [None]:
# Preview the first rows of the metadata table (5 is the pandas default).
metadata.head(n=5)

In [None]:
# Dataset normalization: masked per-channel statistics.
# Previously computed results, presumably listed as "channel mean std" (confirm):
# VV -10.4163 4.0595093
# VH -17.3540 4.3767033

def masked_std(x, mask):
    """Per-channel standard deviation of ``x`` over the pixels where ``mask > 0``.

    Args:
        x: array indexed as (sample, channel, height, width).
        mask: array indexed as (sample, height, width); strictly positive
            entries mark the pixels that take part in the statistic.

    Returns:
        1-D array holding one standard deviation per channel of ``x``.
    """
    x = np.asarray(x)
    # Take the channel count from the data instead of assuming three channels.
    n_channels = x.shape[1]
    # The same pixel mask applies to every channel: replicate it along a new
    # leading channel axis, then flatten to (channel, pixels).
    valid = np.repeat(
        np.expand_dims(np.asarray(mask) > 0, axis=0), n_channels, axis=0
    ).reshape(n_channels, -1)
    # Bring channels to the front so each row holds every pixel of one channel.
    per_channel = np.transpose(x, [1, 0, 2, 3]).reshape(n_channels, -1)
    return np.std(per_channel, axis=1, where=valid)

def masked_mean(x, mask):
    """Per-channel mean of ``x`` over the pixels where ``mask > 0``.

    Args:
        x: array indexed as (sample, channel, height, width).
        mask: array indexed as (sample, height, width); strictly positive
            entries mark the pixels that take part in the statistic.

    Returns:
        1-D array holding one mean per channel of ``x``.
    """
    x = np.asarray(x)
    # Take the channel count from the data instead of assuming three channels.
    n_channels = x.shape[1]
    # The same pixel mask applies to every channel: replicate it along a new
    # leading channel axis, then flatten to (channel, pixels).
    valid = np.repeat(
        np.expand_dims(np.asarray(mask) > 0, axis=0), n_channels, axis=0
    ).reshape(n_channels, -1)
    # Bring channels to the front so each row holds every pixel of one channel.
    per_channel = np.transpose(x, [1, 0, 2, 3]).reshape(n_channels, -1)
    return np.mean(per_channel, axis=1, where=valid)

"""
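NOTE: stale scratch snippet, kept for reference only. masked_mean / masked_std
take exactly (x, mask) and return one value per channel, so the former third
positional argument was dropped. IterChip.__getitem__ does not currently return
a 'mask' entry, so this cannot run as written.

loader = DataLoader(train_dataset, batch_size=len(train_dataset), num_workers=0)
data = next(iter(loader))

print(masked_mean(data['x'], data['mask']), masked_std(data['x'], data['mask']))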
"""

In [None]:
import rasterio
from torch.utils.data import Dataset
import albumentations as A 

# Where the raster chips live and how their files are named.
DATA_PATH = '../data/raw/'
CHANNELS = ['_vv', '_vh']
EXT = '.tif'

# Tensor dtype handed to the model, and dtype of the arrays cached in memory.
DTYPE = torch.float32
MEM_DTYPE = np.float32

# Y_NAN_VALUE is passed to the losses as `ignore_index`; X_NAN_VALUE is only
# referenced from commented-out code in this notebook.
X_NAN_VALUE = -255
Y_NAN_VALUE = 255


def _always_on_flips():
    """Build a fresh list of always-applied (p=1) transpose/flip transforms."""
    return [
        A.Transpose(p=1),
        A.VerticalFlip(p=1),
        A.HorizontalFlip(p=1),
    ]


# Two independent lists so the test and train pipelines never share instances.
TEST_AUGMENTATIONS = _always_on_flips()
TRAIN_AUGMENTATIONS = _always_on_flips()

class IterChip(Dataset):
    """In-memory dataset of two-band chips (VV, VH, optionally |VV - VH|) with labels.

    Every chip is read from disk in ``__init__``, normalised per channel with the
    masked mean/std of this dataset, and kept as numpy arrays.

    Args:
        chip_ids: iterable of chip id strings; must also support ``len``.
        augment: when True, ``__getitem__`` returns the chip plus a copy rotated
            by 180 degrees (image and label alike).
        test_augment: when True, ``__getitem__`` returns the chip followed by one
            transformed copy per entry of ``TEST_AUGMENTATIONS`` (test-time
            augmentation). ``converge_aug_inference`` maps the predictions for
            those copies back and averages them. Takes precedence over
            ``augment`` so the view order always matches ``TEST_AUGMENTATIONS``.
    """

    def __init__(self, chip_ids, augment=True, test_augment=False):
        super().__init__()

        self.chip_ids = chip_ids

        self.augment = augment
        self.test_augment = test_augment
        self.TRAIN_AUGMENTATIONS = TRAIN_AUGMENTATIONS
        self.TEST_AUGMENTATIONS = TEST_AUGMENTATIONS

        # When True, |VV - VH| is appended as a third input channel.
        self.ABS_CHANNEL = True

        x, mask, y = [], [], []

        print('creating dataset')
        print('reading chips')
        for id_ in tqdm(self.chip_ids):
            # The VV band and the validity mask come from the same file, so it
            # is opened once for both.
            with rasterio.open(DATA_PATH + '/images/' + id_ + '_vv' + EXT) as img:
                vv = img.read(1)
                mask_ = img.dataset_mask() / 255
            with rasterio.open(DATA_PATH + '/images/' + id_ + '_vh' + EXT) as img:
                vh = img.read(1)

            x_ = np.stack([vv, vh], axis=0)
            if self.ABS_CHANNEL:
                va = np.abs(np.diff(x_, axis=0))
                x_ = np.concatenate([x_, va], axis=0)

            with rasterio.open(DATA_PATH + '/labels/' + id_ + EXT) as img:
                y_ = img.read(1)

            x.append(x_)
            mask.append(mask_)
            y.append(y_)

        self.x = np.array(x, dtype=MEM_DTYPE)
        self.mask = np.array(mask, dtype=MEM_DTYPE)
        self.y = np.array(y, dtype=MEM_DTYPE)

        print('calculating norms - mean and std')
        self.t_mean = masked_mean(self.x, self.mask)
        self.t_std = masked_std(self.x, self.mask)

        # VV -10.4163 4.0595093
        # VH -17.3540 4.3767033
        print('NORMALIZE: ', self.t_mean, self.t_std)
        print('normalizing...')
        # Standardise each channel in place with its own masked statistics.
        for dim, (_mean, _std) in enumerate(zip(self.t_mean, self.t_std)):
            self.x[:, dim] = (self.x[:, dim] - _mean) / _std

        print('DATASET CREATED', self.x.shape)

    def __len__(self):
        """Number of chips held by the dataset."""
        return len(self.chip_ids)

    def __getitem__(self, index, test_augment=False):
        """Return the views of chip ``index`` as tensors.

        Args:
            index: position of the chip in the dataset.
            test_augment: per-call override that forces test-time augmentation
                even when the instance was built without it.

        Returns:
            dict with ``'x'`` of shape (views, channels, H, W) and ``'label'`` of
            shape (label_views, H, W). In test-augmentation mode only the
            untransformed label is returned (label_views == 1).
        """
        x_ = np.expand_dims(self.x[index], axis=0)
        y_ = np.expand_dims(self.y[index], axis=0)

        # getattr keeps instances created without __init__ (e.g. unpickled from
        # an older version of the class) working.
        use_tta = test_augment or getattr(self, 'test_augment', False)

        # Train augmentation: append a 180-degree rotation of image and label.
        if self.augment and not use_tta:
            x_ = np.concatenate([x_, np.rot90(x_, k=2, axes=(2, 3))], axis=0)
            y_ = np.concatenate([y_, np.rot90(y_, k=2, axes=(1, 2))], axis=0)

        # Test augmentation: original first, then one view per transform, in
        # the order converge_aug_inference expects.
        if use_tta:
            # albumentations works on channel-last images.
            base = np.ascontiguousarray(np.transpose(self.x[index], [1, 2, 0]))
            views = []
            for augmentation in self.TEST_AUGMENTATIONS:
                aug = A.Compose([
                    augmentation
                ])
                views.append(np.transpose(aug(image=base)['image'], [2, 0, 1]))
            if views:
                # Stacking requires square chips, since Transpose swaps H and W.
                x_ = np.concatenate([x_, np.stack(views, axis=0)], axis=0)

        x_ = torch.tensor(x_).to(DTYPE)
        y_ = torch.tensor(y_).to(DTYPE)

        return {
            'x': x_,
            'label': y_
        }

    def converge_aug_inference(self, preds):
        """Merge predictions made on the test-time-augmented views of one chip.

        Args:
            preds: CPU tensor of shape (views, H, W). View 0 is the untouched
                chip; view i + 1 was produced by ``TEST_AUGMENTATIONS[i]``.

        Returns:
            Tensor of shape (1, H, W): the mean over all views after each
            augmented view has been mapped back to the original orientation.
        """
        # Channel-last layout (views, H, W, 1) for albumentations.
        stacked = np.transpose(preds.unsqueeze(1).numpy(), [0, 2, 3, 1])

        restored = [stacked[0]]
        # Re-applying each transform undoes it; this relies on every entry of
        # TEST_AUGMENTATIONS being its own inverse (transpose and flips are).
        for augmentation, view in zip(self.TEST_AUGMENTATIONS, stacked[1:]):
            aug = A.Compose([
                augmentation
            ])
            restored.append(aug(image=view)['image'])

        merged = np.mean(np.stack(restored, axis=0), axis=0)

        return torch.tensor(merged[..., 0]).to(DTYPE).unsqueeze(0)

In [None]:
#542
from sklearn.model_selection import train_test_split

# Hold out 33% of the chip ids for validation; the seed is fixed so the split
# is reproducible.
train, val = train_test_split(
    metadata.chip_id,
    test_size=0.33,
    random_state=42,
)

# Training chips are built with augmentation enabled, validation chips without.
train_dataset = IterChip(chip_ids=train, augment=True)
val_dataset = IterChip(chip_ids=val, augment=False)

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

BS = 3  # items per batch
NW = 0  # DataLoader worker processes (0 = load in the main process)

# Options shared by both loaders.
_loader_options = dict(batch_size=BS, num_workers=NW, pin_memory=True)

# Training batches are drawn in random order ...
train_dataloader = DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    **_loader_options,
)

# ... validation batches in dataset order.
val_dataloader = DataLoader(
    val_dataset,
    sampler=SequentialSampler(val_dataset),
    **_loader_options,
)

In [None]:
import sys
sys.path.append('..')
from src.models.swin_unet import SwinTransformerSys
import segmentation_models_pytorch as smp
from src.models.repvgg.repvgg import repvgg_model_convert, create_RepVGG_A0
from src.models.brr import BRR


"""
model = SwinTransformerSys(img_size=512, patch_size=4, in_chans=3, num_classes=1,
            embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[2, 2, 2, 2], num_heads=[4, 4, 4, 4],
            window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0.0,
            norm_layer=nn.LayerNorm, ape=True, patch_norm=True,
            use_checkpoint=False, final_upsample="expand_first")

torch.backends.cudnn.benchmark = True
"""
# IterChip feeds three channels per chip (VV, VH and |VV - VH| while its
# ABS_CHANNEL flag is True), so the encoder must accept three input channels;
# `in_channels=2` fails on the first forward pass.
model = smp.Unet(
    encoder_name="resnext50_32x4d",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
)

#model = BRR()

#model = create_RepVGG_A0(deploy=False)
#model = nn.Conv2d(2,1,(7,7),stride=(1,1),padding=3)
#model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=2, out_channels=1, init_features=64, pretrained=False)

model

In [None]:
# Credentials must not live in the notebook. With `api_token` omitted, the
# Neptune client reads the token from the NEPTUNE_API_TOKEN environment variable.
# NOTE(review): the token that used to be hard-coded in this cell has been
# exposed and should be revoked and regenerated.
run = neptune.init(
    project='victorcallejas/FBSim',
)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
from segmentation_models_pytorch.losses.soft_bce import SoftBCEWithLogitsLoss
from segmentation_models_pytorch.losses import DiceLoss, LovaszLoss

def jaccard_coeff(preds, true, ignore_index=255):
    """Jaccard index (intersection over union) of thresholded logits vs. labels.

    Args:
        preds: tensor of raw logits; a pixel is predicted positive when
            ``sigmoid(logit) > 0.5``.
        true: label tensor with the same shape as ``preds``; non-zero means
            positive. Pixels equal to ``ignore_index`` are excluded.
        ignore_index: label value marking pixels to leave out of the score.

    Returns:
        0-dim float tensor ``intersection / union`` over the valid pixels. The
        result is NaN when the union is empty.
    """
    predicted = torch.sigmoid(preds) > 0.5

    # Drop the pixels without a usable label before counting.
    valid_pixel_mask = true.ne(ignore_index)
    actual = true.masked_select(valid_pixel_mask) != 0
    predicted = predicted.masked_select(valid_pixel_mask)

    # Pure torch ops: works on any device, no numpy round-trip.
    intersection = torch.logical_and(actual, predicted).sum()
    union = torch.logical_or(actual, predicted).sum()
    return intersection / union

def BCE1_DICE(preds, true):
    """Equal-weight blend of a smoothed BCE-with-logits loss and a Dice loss.

    Both terms receive raw logits and skip pixels labelled ``Y_NAN_VALUE``.
    The Dice term is built with ``log_loss=True``.

    Args:
        preds: tensor of raw logits.
        true: label tensor matching ``preds``.

    Returns:
        ``0.5 * bce + 0.5 * dice`` as returned by the two loss modules.
    """
    bce_term = SoftBCEWithLogitsLoss(
        smooth_factor=0.05,
        ignore_index=Y_NAN_VALUE,
    )(preds, true)
    dice_term = DiceLoss(
        'binary',
        from_logits=True,
        log_loss=True,
        smooth=0.05,
        eps=1e-07,
        ignore_index=Y_NAN_VALUE,
    )(preds, true)
    return 0.5 * bce_term + 0.5 * dice_term

In [None]:
optimizer = torch.optim.AdamW(
                model.parameters(),
                lr = 5e-4
            )

criterion = BCE1_DICE
#criterion = SoftBCEWithLogitsLoss(ignore_index = Y_NAN_VALUE,smooth_factor = 0.1).to(device)
#criterion = DiceLoss('binary', log_loss=False, from_logits=True, smooth=0.2, ignore_index=Y_NAN_VALUE, eps=1e-07)
#criterion = LovaszLoss('binary',from_logits=True,ignore_index=Y_NAN_VALUE).to(device)
#jaccard = JaccardLoss(from_logits = True, mode = 'binary')

# Mixed precision switch. The gradient scaler follows the same flag: with
# fp16 off its scale/unscale/step/update calls become pass-throughs, so
# full-precision training is not subjected to loss scaling.
fp16 = False
scaler = torch.cuda.amp.GradScaler(enabled=fp16)

# Number of batches whose gradients are summed before each optimizer step.
iters_to_accumulate = 2

# Validation options: watershed post-processing score and prediction plots.
VAL_WATERSHED = False
VAL_PLOT =True

In [None]:
from src.utils.plot import display_preds
%matplotlib agg

from src.utils.post import post_watershed

def valid(model,val_dataloader, plot = True, watershed = False):
    """Evaluate ``model`` on ``val_dataloader`` and log the metrics to Neptune.

    Relies on the notebook globals ``device``, ``fp16``, ``criterion`` and ``run``.

    Args:
        model: network returning logits of shape (N, 1, H, W).
        val_dataloader: loader yielding dicts with ``'x'`` of shape
            (B, views, C, H, W) and ``'label'`` of shape (B, views, H, W).
        plot: when True, upload a figure of sample predictions.
        watershed: when True, also log the Jaccard score of the
            watershed-post-processed predictions.

    Returns:
        None; results are logged and printed.
    """
    model.eval()

    # Collect per-batch results and concatenate once at the end; concatenating
    # inside the loop re-copies everything gathered so far on every step.
    batch_preds, batch_labels = [], []

    with torch.no_grad():
        for batch in tqdm(val_dataloader, total=len(val_dataloader), leave=False):

            # Merge the per-item view axis into the batch axis.
            x = batch['x'].to(device,non_blocking=True).flatten(0,1)
            targets = batch['label'].to(device,non_blocking=True).flatten(0,1)

            with torch.cuda.amp.autocast(enabled=fp16):
                b_preds = model(x).squeeze(1)
                loss = criterion(b_preds, targets)

            run["dev/batch_loss"].log(loss.item())

            batch_preds.append(b_preds.detach().float().cpu())
            batch_labels.append(targets.detach().float().cpu())

    preds = torch.cat(batch_preds, dim = 0) if batch_preds else torch.tensor([])
    labels = torch.cat(batch_labels, dim = 0) if batch_labels else torch.tensor([])

    # Log plain floats rather than tensors.
    epoch_loss = float(criterion(preds,labels))
    run["dev/loss"].log(epoch_loss)

    jac = float(jaccard_coeff(preds,labels))
    run["dev/jaccard"].log(jac)

    if watershed:
        preds_ws = post_watershed(preds)
        jac_w = float(jaccard_coeff(preds_ws,labels))
        run["dev/jaccard_ws"].log(jac_w)

    print('Validation: ', epoch_loss, jac)

    if plot:
        # Plot from the dataset behind the loader that was actually evaluated,
        # not from a notebook global.
        fig = display_preds(val_dataloader.dataset ,preds,labels,2)
        run['validation/plt'].upload(fig)

                

In [None]:
model = model.to(device)

steps = 0

for epoch in tqdm(range(1,1000)):

    print('Epoch: ',epoch)

    model.train()
    optimizer.zero_grad(set_to_none=True)

    # Collect per-batch results and concatenate once per epoch; concatenating
    # inside the loop re-copies everything gathered so far on every step.
    epoch_preds, epoch_labels = [], []
    n_batches = len(train_dataloader)

    for step, batch in tqdm(enumerate(train_dataloader),total=n_batches,leave=False):

        # Every dataset item is a stack of views (original + augmented), so the
        # batch arrives as (B, views, C, H, W). Merge the view axis into the
        # batch axis, exactly as valid() does, before calling the model.
        x = batch['x'].to(device,non_blocking=True).flatten(0,1)
        targets = batch['label'].to(device,non_blocking=True).flatten(0,1)

        with torch.cuda.amp.autocast(enabled=fp16):
            b_preds = model(x).squeeze(1)
            loss = criterion(b_preds, targets)
        scaler.scale(loss).backward()

        run["train/batch_loss"].log(loss.item())

        # Step every `iters_to_accumulate` batches, and also on the last batch
        # so gradients of a trailing partial window are applied instead of
        # being wiped by the zero_grad at the start of the next epoch.
        is_last_batch = (step + 1) == n_batches
        if (step + 1) % iters_to_accumulate == 0 or is_last_batch:
            # Gradients must be unscaled before clipping by norm.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(),5.0)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        epoch_preds.append(b_preds.detach().float().cpu())
        epoch_labels.append(targets.detach().float().cpu())

    preds = torch.cat(epoch_preds, dim = 0) if epoch_preds else torch.tensor([])
    labels = torch.cat(epoch_labels, dim = 0) if epoch_labels else torch.tensor([])

    # Log plain floats rather than tensors.
    epoch_loss = float(criterion(preds,labels))
    run["train/loss"].log(epoch_loss)

    jac = float(jaccard_coeff(preds,labels))
    run["train/jaccard"].log(jac)

    print('Train: ', epoch_loss, jac)
    valid(model,val_dataloader,VAL_PLOT, VAL_WATERSHED)



In [None]:
# Persist the whole model object (not just its state_dict) to the artifacts folder.
torch.save(obj=model, f='../artifacts/model.pt')

In [None]:
# Restore the full model object onto the active device.
# NOTE(review): this unpickles arbitrary objects; only load trusted artifacts.
model = torch.load(f='../artifacts/model.pt', map_location=device)

## AUGMENTED INFERENCE

In [None]:
# NOTE(review): the third positional argument looks like a test-time-augmentation
# flag; confirm that IterChip.__init__ accepts it.
val_aug_dataset = IterChip(val, False, True)

# One chip per batch, visited in dataset order.
val_aug_dataloader = DataLoader(
    dataset=val_aug_dataset,
    batch_size=1,
    sampler=SequentialSampler(val_aug_dataset),
    num_workers=NW,
    pin_memory=True,
)

In [None]:
model.eval()

# Collect per-chip results and concatenate once at the end; concatenating
# inside the loop re-copies everything gathered so far on every step.
aug_preds, aug_labels = [], []

with torch.no_grad():
    for batch in tqdm(val_aug_dataloader, total=len(val_aug_dataloader), leave=False):

        # The batch size is always 1 here: drop the batch axis so the views of
        # the single chip form the model batch.
        x = batch['x'].to(device,non_blocking=True).squeeze(0)
        # Only the label of the untransformed view is scored.
        targets = batch['label'][0][0].unsqueeze(0)

        with torch.cuda.amp.autocast(enabled=fp16):
            b_preds = model(x).squeeze(1)

        # Merge the per-view predictions into a single prediction for the chip.
        b_preds = val_aug_dataloader.dataset.converge_aug_inference(b_preds.detach().cpu())

        aug_preds.append(b_preds)
        aug_labels.append(targets.detach())

preds = torch.cat(aug_preds, dim = 0) if aug_preds else torch.tensor([])
labels = torch.cat(aug_labels, dim = 0) if aug_labels else torch.tensor([])

# Log plain floats rather than tensors.
epoch_loss = float(criterion(preds,labels))
run["dev/aug_loss"].log(epoch_loss)

jac = float(jaccard_coeff(preds,labels))
run["dev/aug_jaccard"].log(jac)

if VAL_WATERSHED:
    preds_ws = post_watershed(preds)
    jac_w = float(jaccard_coeff(preds_ws,labels))
    # Own series, so augmented scores do not mix with the plain
    # "dev/jaccard_ws" values logged by valid().
    run["dev/aug_jaccard_ws"].log(jac_w)

print('Augmented Validation: ', epoch_loss, jac)
valid(model,val_dataloader)

In [None]:
from skimage.segmentation import watershed
from skimage.feature import peak_local_max

from scipy import ndimage as ndi

from multiprocessing import Pool

N_WORKERS = 10

def w_watershed(img):
    """Run a distance-transform-seeded watershed on one binary image.

    Args:
        img: 2-D integer array; non-zero pixels are foreground.

    Returns:
        The array produced by ``skimage.segmentation.watershed`` restricted to
        the foreground of ``img``.
    """
    # Distance of every foreground pixel to the nearest background pixel.
    edt = ndi.distance_transform_edt(img)

    # Local maxima of the distance map become the watershed seeds.
    peaks = peak_local_max(edt, footprint=np.ones((3, 3)), labels=img)
    seeds = np.zeros(edt.shape, dtype=bool)
    seeds[tuple(peaks.T)] = True
    markers = ndi.label(seeds)[0]

    # Flood the inverted distance map from the seeds, staying inside the mask.
    return watershed(-edt, markers, mask=img)

def post_watershed(x):
    """Apply ``w_watershed`` to every thresholded prediction, in parallel.

    Args:
        x: CPU tensor of logits whose first axis indexes images. Each image is
            binarised with ``sigmoid(logit) > 0.5`` before segmentation.

    Returns:
        Tensor stacking the per-image ``w_watershed`` results, or an empty
        tensor when ``x`` holds no images.
    """
    binary = (torch.sigmoid(x) > 0.5).numpy().astype(np.int32)

    # Nothing to segment: skip spawning workers altogether.
    if len(binary) == 0:
        return torch.tensor([])

    segmented = []
    # The context manager shuts the workers down when the block exits, so
    # repeated calls no longer leak N_WORKERS processes each. All results are
    # consumed inside the block, before the pool is terminated.
    with Pool(processes=N_WORKERS) as pool:
        for s, result in enumerate(pool.imap(func=w_watershed, iterable=binary)):
            print(s)
            segmented.append(result)

    # Serial fallback, for debugging without multiprocessing:
    #   segmented = [w_watershed(img) for img in binary]

    # Stack once in numpy; building a tensor from a list of arrays is slow.
    return torch.from_numpy(np.stack(segmented))

In [None]:
# Re-run validation with watershed post-processing. The flag has to be passed
# explicitly: valid() defaults to watershed=False and does not read the global.
VAL_WATERSHED = True
valid(model, val_dataloader, VAL_PLOT, VAL_WATERSHED)