In [1]:
import numpy as np
import matplotlib.pyplot as plt
from rlxutils import subplots
import pandas as pd
import sys
sys.path.append("..")
import torch
from torch import nn

from lib import data
from lib import sar
from lib import io

from torch import nn  
from loguru import logger
from lib.datamodules.components import scatterdataset
from lib.datamodules import scatterdatamodule
from lib.models import scattermodel
from omegaconf import OmegaConf
import hydra
import pprint
from importlib import reload
# Re-execute scatterdataset.py so edits to it are picked up without a kernel restart.
# NOTE(review): reload() only rebinds the `scatterdataset` module object itself;
# modules imported above (e.g. scatterdatamodule) may still hold references to
# the previously loaded classes — consider `%autoreload 2` instead. TODO confirm.
reload(scatterdataset)

<module 'lib.datamodules.components.scatterdataset' from '/home/ubuntu/sar-recovery/notebooks/../lib/datamodules/components/scatterdataset.py'>

In [2]:
# Load the experiment configuration and show it as a plain nested dict,
# preserving the key order from the YAML file (no sorting).
conf = OmegaConf.load("../configs/scatter2coherence.yaml")
pprint.pprint(OmegaConf.to_container(conf), sort_dicts=False)

{'dataloader': {'_target_': 'lib.datamodules.scatterdatamodule.ScatterCoherencePatchesDataModule',
                'base_path': '/home/rlx/data/sar-recovery',
                'date_train': 20100520,
                'date_test': 20120720,
                'date_val': 20120720,
                'scatter_elems': ['Shh', 'Shv'],
                'coherence_elems': ['Shh2'],
                'patch_size': 60,
                'avg_window_size': 5,
                'splitmask_fn_src': 'lambda h, w: '
                                    'data.cv_splitpixels_spatial(h, w, '
                                    'pixels_train = 70, pixels_test = 15, '
                                    'pixels_val = 15, angle = np.pi/4)'},
 'model': {'_target_': 'lib.models.scattermodel.Scatter2Coherence',
           'in_channels': 2}}


In [3]:
# Build the data module from the `dataloader` section of the config: hydra
# imports the class named by `_target_` and passes the remaining keys as
# constructor arguments.
# NOTE(review): the log printed by this cell reports patch split proportions
# train=1.000 test=0.000 val=0.000 although the pixel split is 70/15/15, i.e.
# no 60x60 patch fell entirely inside the test/val regions — test/val loaders
# are presumably empty with this config. TODO confirm / revisit split params.
dl = hydra.utils.instantiate(conf["dataloader"])

[32m2024-04-19 19:21:22.376[0m | [1mINFO    [0m | [36mlib.datamodules.components.scatterdataset[0m:[36m__init__[0m:[36m44[0m - [1mloading scatter matrix[0m
[32m2024-04-19 19:21:22.526[0m | [1mINFO    [0m | [36mlib.datamodules.components.scatterdataset[0m:[36m__init__[0m:[36m47[0m - [1mcomputing coherence matrix[0m
ScatterCoherencePatchesDataset  patch_size=60  splitmask_dims=(4402, 1602)
original split proportions: train=0.700  test=0.150  val=0.150
patch split proportions:    train=1.000  test=0.000  val=0.000
        [0m
[32m2024-04-19 19:21:26.062[0m | [1mINFO    [0m | [36mlib.datamodules.components.scatterdataset[0m:[36m__init__[0m:[36m58[0m - [1mscatter   matrix shape is (4402, 1602, 2, 2), retrieving elems [('Shh', [0, 0]), ('Shv', [0, 1])][0m
[32m2024-04-19 19:21:26.062[0m | [1mINFO    [0m | [36mlib.datamodules.components.scatterdataset[0m:[36m__init__[0m:[36m59[0m - [1mcoherence matrix shape is (4402, 1602, 3, 3), retrieving elems

In [5]:
# Pull a single batch from the training loader for shape inspection.
# NOTE(review): raises StopIteration if the training loader yields nothing.
b = next(iter(dl.train_dataloader()))

In [6]:
# Show which fields each batch dictionary provides.
b.keys()

dict_keys(['scatter_patch', 'coherence_patch', 'patch_coords', 'avg_coherence_patch'])

In [7]:
# Compare input and target shapes. The 60x60 scatter patches vs the 12x12
# averaged coherence target presumably reflect avg_window_size=5 (60 / 5 = 12).
b['scatter_patch'].shape, b['avg_coherence_patch'].shape

(torch.Size([16, 2, 60, 60]), torch.Size([16, 1, 12, 12]))

In [67]:
# Short local aliases for helpers from the project's model module, so the
# notebook's experimental model class can refer to them directly.
# NOTE(review): `smap` is not used in the visible cells.
ComplexActivation, smap = scattermodel.ComplexActivation, scattermodel.smap
from complexPyTorch.complexFunctions import complex_relu
from complexPyTorch.complexLayers import ComplexBatchNorm2d

class Scatter2Coherence(nn.Module):
    """
    Complex-valued CNN mapping scatter-matrix patches to a coherence map.

    Input:  complex-float tensor of shape [batch_size, in_channels, h, w]
            (e.g. [16, 2, 60, 60] for the Shh/Shv scatter patches used here).
            It must be complex because every Conv2d is built with
            dtype=torch.cfloat.
    Output: complex tensor of shape [batch_size, 1, h', w'] with
            h' = (h - 5) // 5 + 1 and w' = (w - 5) // 5 + 1. In the visible
            layer stack, only the first convolution (kernel 5, stride 5, no
            padding) downsamples; the other convolutions use stride 1 with
            padding='same'. This assumes ComplexActivation and
            ComplexBatchNorm2d preserve shape, which is consistent with the
            60x60 -> 12x12 result observed in this notebook.

    Parameters
    ----------
    in_channels : int
        Number of complex input channels (one per requested scatter element).
    """
    def __init__(self, in_channels):
        super().__init__()

        self.in_channels = in_channels

        self.layers = nn.Sequential(
                 # the only downsampling layer: non-overlapping 5x5 windows
                 nn.Conv2d(in_channels=self.in_channels, out_channels=20,
                           kernel_size=5, stride=5, padding=0,
                           dtype=torch.cfloat
                           ),
                 ComplexActivation(complex_relu),
                 ComplexBatchNorm2d(20),

                 nn.Conv2d(in_channels=20, out_channels=20,
                           kernel_size=3, stride=1, padding='same',
                           dtype=torch.cfloat,
                           ),
                 ComplexActivation(complex_relu),
                 ComplexBatchNorm2d(20),

                 nn.Conv2d(in_channels=20, out_channels=10,
                           kernel_size=5, stride=1, padding='same',
                           dtype=torch.cfloat,
                           ),
                 ComplexActivation(complex_relu),
                 ComplexBatchNorm2d(10),

                 # final projection to a single complex channel, no activation
                 nn.Conv2d(in_channels=10, out_channels=1,
                           kernel_size=7, stride=1, padding='same',
                           dtype=torch.cfloat,
                           ),
        )

    def get_output_shape(self, input_shape):
        """
        Return the spatial (h, w) shape produced for an input of spatial shape
        `input_shape` by running a single random dummy sample through the model.

        The probe must not disturb the model. It runs without autograd and with
        every submodule temporarily in eval mode, so batch-norm layers do not
        fold the random dummy batch into their running statistics. Each
        submodule's original train/eval flag is restored afterwards, even if
        the forward pass raises. The dummy input is created on the same device
        as the model's parameters, falling back to CPU if there are none.

        Parameters
        ----------
        input_shape : sequence of int
            Spatial input shape, e.g. (60, 60).

        Returns
        -------
        torch.Size
            The last two dimensions of the model output.
        """
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        x = torch.rand((1, self.in_channels, *input_shape), device=device).type(torch.cfloat)

        # remember per-module flags so mixed train/eval setups are restored exactly
        saved_modes = [(mod, mod.training) for mod in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return self(x).shape[-2:]
        finally:
            for mod, was_training in saved_modes:
                mod.training = was_training

    def forward(self, x):
        """Apply the layer stack to a complex [batch, in_channels, h, w] tensor."""
        x = self.layers(x)

        return x

In [68]:
# Fresh, untrained model taking the two complex scatter channels (Shh, Shv).
m = Scatter2Coherence(2)

In [70]:
# Smoke-test the forward pass on one real batch: the 60x60 scatter patches
# should map onto the same spatial grid as the `avg_coherence_patch` target.
x = b['scatter_patch']
# Inspection only: run without autograd so `o` does not keep a computation
# graph (and its activations) alive across cells.
# NOTE(review): `m` is still in training mode (nn.Module default), so its
# batch-norm layers use batch statistics and presumably update their running
# stats from this batch — call m.eval() first if that matters.
with torch.no_grad():
    o = m(x)
assert o.shape[-2:] == b['avg_coherence_patch'].shape[-2:], (o.shape, b['avg_coherence_patch'].shape)
x.shape, o.shape


(torch.Size([16, 2, 60, 60]), torch.Size([16, 1, 12, 12]))

In [55]:
(torch.Size([16, 2, 60, 60]), torch.Size([16, 1, 12, 12]))


(torch.Size([16, 2, 60, 60]), torch.Size([16, 1, 12, 12]))

In [71]:
# Inspect the raw complex output of the freshly initialised model (not trained in the visible cells).
o

tensor([[[[-6.8209e-01-3.5772e-02j,  1.0259e-01+2.2092e+00j,
           -2.2120e+00+8.2161e-01j,  ...,
           -2.3616e-01+2.0221e+00j, -2.2577e-01+5.3397e-01j,
           -8.8225e-01+1.7272e-01j],
          [ 2.2600e+00-5.4800e-01j, -9.5448e-01+2.4369e+00j,
           -2.6471e+00+9.1205e-01j,  ...,
           -3.0981e-01-1.7500e-01j, -1.8977e-01+3.0341e+00j,
           -1.2427e+00-5.4042e-01j],
          [ 7.7692e-02-2.7153e+00j, -2.4675e+00+5.4201e+00j,
            4.2657e-01+3.5719e+00j,  ...,
           -2.2209e+00+9.7791e-01j, -1.5946e+00+9.8416e-01j,
           -9.2323e-01+1.0586e+00j],
          ...,
          [ 9.3165e-01-3.9186e-01j, -1.2047e+00-7.7419e-01j,
           -2.4030e+00+3.8348e-01j,  ...,
           -4.7482e-01-1.3344e+00j, -1.2582e-01-6.0045e-01j,
           -9.7294e-01-8.0606e-01j],
          [-2.0956e-01-4.8668e-01j, -1.3964e+00-5.2854e-01j,
           -3.8091e-01-2.0907e-01j,  ...,
            1.3627e+00+5.2097e-01j, -1.4209e+00+6.5290e-01j,
           -1.840