In [None]:
import numpy as np
import pandas as pd
import geopandas as gpd
from tqdm import tqdm
from random import shuffle

import matplotlib.pyplot as plt
import seaborn as sns

import sys
sys.path.append("..//..")
import config
from utils import compute_frames

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

In [None]:
sns.set_style("darkgrid", {"grid.color": ".6", "grid.linestyle": ":"})

In [None]:
# use the first GPU when CUDA is available, otherwise fall back to the CPU so that
# every later `.to(dev)` call still works on machines without a GPU
dev = "cuda:0" if torch.cuda.is_available() else "cpu"

# History dataset generation

In [None]:
# load frames idx detail
frames_idx = pd.read_csv(config.TR_FRAMES_IDX, index_col=0)

In [None]:
# load frames deforestation area history
deforestation = pd.read_csv(config.TR_DEFORESTATION, index_col=0)
deforestation["date"] = pd.to_datetime(deforestation["date"])

## Create grid

In [None]:
# weekly deforestation grid: one (x, y) layer per entry of config.TIME_STEPS
n_steps = len(config.TIME_STEPS)
x_span = frames_idx["x"].max() - frames_idx["x"].min() + 1
y_span = frames_idx["y"].max() - frames_idx["y"].min() + 1
time_grid = np.zeros((n_steps, x_span, y_span))
for t, dt in tqdm(enumerate(config.TIME_STEPS), total=n_steps):
    # area recorded exactly at this time step, indexed by frame id
    area_at_dt = deforestation.loc[deforestation["date"] == dt].set_index("frame_id")["area"]
    # adding a zero series over every frame id aligns the values to the full frame
    # index; frames without a record become NaN and are then filled with 0
    defor_area = (area_at_dt + pd.Series(0, index=frames_idx.index)).fillna(0).sort_index()
    time_grid[t, :, :] = defor_area.values.reshape((x_span, y_span))

In [None]:
time_grid.shape

In [None]:
sns.lineplot((time_grid>0).mean(axis=(1,2)))

In [None]:
time_grid[time_grid>0].min()

In [None]:
(time_grid == 0).mean()

In [None]:
sns.histplot(time_grid[time_grid>0].flatten())
plt.show()

## Past Deforestation

In [None]:
past_defor = pd.read_csv(config.TR_PAST_DEFOR, index_col=0)
print("Shape:", past_defor.shape)
past_defor.head()

In [None]:
past_defor["date"] = pd.to_datetime(past_defor["date"])

In [None]:
frames_idx.shape

In [None]:
# cumulative past deforestation grid: layer t holds, per frame, the summed area of
# every past_defor record dated on or before config.TIME_STEPS[t]
past_grid = np.zeros((len(config.TIME_STEPS), frames_idx["x"].max() - frames_idx["x"].min() + 1, frames_idx["y"].max() - frames_idx["y"].min() + 1))
for t, dt in tqdm(enumerate(config.TIME_STEPS), total=len(config.TIME_STEPS)):
    past_area = (
        past_defor[
            past_defor["date"] <= dt
        ].groupby("frame_id")["area"].sum() +\
        pd.Series(0, index=frames_idx.index)  # align to every frame id; missing -> NaN -> 0
    ).fillna(0).sort_index()
    # reshape against past_grid itself so this cell does not depend on time_grid
    past_grid[t, :, :] = past_area.values.reshape(past_grid.shape[1:])

In [None]:
# clip on 1.0 (upper bound)
past_grid = past_grid.clip(max=1.0)

In [None]:
past_grid.shape

In [None]:
sns.lineplot((past_grid>0).mean(axis=(1,2)))

## IBAMA 

In [None]:
ibama = pd.read_csv(config.TR_IBAMA)

In [None]:
ibama.columns

In [None]:
cols = ['Access_Minut_Beef_2012', 'Access_Minut_City',
       'Access_Minut_soy', 'Access_Minut_soy_p25', 'Access_Minut_wood_2012',
       'garimpos', 'Multas_upto2019', 'Terras_Devolutas', 'TI_Dist',
       'UCPI_dist', 'UCUS_Dist', 'Pasture_Mapbiomas', 'Soybean_Mapbiomas',
       'UCPI_IO', 'UCUS_IO', 'Terras_Devolutas_IO', 'TI_IO']

# one (x, y) layer per selected IBAMA column
x_span = frames_idx["x"].max() - frames_idx["x"].min() + 1
y_span = frames_idx["y"].max() - frames_idx["y"].min() + 1
ib_array = np.zeros((len(cols), x_span, y_span))
ibama_by_frame = ibama.set_index("frame_id")
for icol, col in tqdm(enumerate(cols), total=len(cols)):
    # align the column to the full frame index (frames absent from ibama become NaN)
    v = (ibama_by_frame[col] + pd.Series(0, index=frames_idx.index)).sort_index()
    # missing frames receive the column mean
    v = v.fillna(v.mean())
    ib_array[icol, :, :] = v.values.reshape((x_span, y_span))

In [None]:
ib_array.shape

In [None]:
ib_array.mean(axis=(1, 2))

## Counties data

In [None]:
# counties
frames_county = pd.read_csv(config.TR_COUNTIES, index_col=0)

In [None]:
# county layers on the frame grid: channel 0 = "populacao", channel 1 = "densidade"
county_cols = ["populacao", "densidade"]
county_data = np.zeros((len(county_cols), frames_idx["x"].max() - frames_idx["x"].min() + 1, frames_idx["y"].max() - frames_idx["y"].min() + 1))
county_by_frame = frames_county.set_index("frame_id")
for ichannel, county_col in enumerate(county_cols):
    county_data[ichannel, :, :] = (
        county_by_frame[county_col] +\
        pd.Series(0, index=frames_idx.index)  # align to every frame id; missing -> NaN -> 0
    ).fillna(0).\
        sort_index().\
        values.reshape(county_data.shape[1:])
    # sort_index() matches the other grids (time_grid, past_grid, ib_array): the
    # reshape is positional, so every layer must share the same frame-id ordering

In [None]:
county_data.shape

In [None]:
county_data.mean(axis=(1,2))

# Compute frame patches

A patch is a square set of unit frames. Patches are built by sliding a window over the full image along both axes and collecting the frames that fall inside the window at each step.

In [None]:
config.INPUT_BOXES_SIZE = 128

In [None]:
out_condition = "both"  # deforestation | borders | both
if out_condition not in ("deforestation", "borders", "both"):
    # an unknown condition would otherwise silently produce an empty patch list
    raise ValueError(f"unknown out_condition: {out_condition!r}")

bundle_step = 64
patches = []
patches_full = []
# ids of frames with at least one deforestation record; built once, not per patch
defor_frame_ids = set(deforestation["frame_id"].values)
for ix in tqdm(list(range(frames_idx["x"].min(), frames_idx["x"].max()+1, bundle_step))):
    fx = ix + config.INPUT_BOXES_SIZE
    for iy in range(frames_idx["y"].min(), frames_idx["y"].max()+1, bundle_step):
        fy = iy + config.INPUT_BOXES_SIZE

        # frames inside the window [ix, fx) x [iy, fy)
        iframes = frames_idx[
            (frames_idx["x"] >= ix) & 
            (frames_idx["x"] < fx) &
            (frames_idx["y"] >= iy) &
            (frames_idx["y"] < fy)
        ]

        patches_full.append(iframes.index)

        if out_condition == "borders":
            # condition: bundle has to be at least half inside borders
            keep_patch = iframes["in_borders"].mean() >= 0.5

        elif out_condition == "deforestation":
            # condition: bundle has to contain some deforestation
            frames_without_defor = len(set(iframes.index) - defor_frame_ids)
            keep_patch = frames_without_defor < len(iframes)

        else:  # "both"
            frames_without_defor = len(set(iframes.index) - defor_frame_ids)
            keep_patch = (frames_without_defor < len(iframes)) and (iframes["in_borders"].mean() >= 0.5)

        if keep_patch:
            patches.append(iframes.index)

In [None]:
len(patches[0])

In [None]:
time_grid.shape

In [None]:
# remove patches that represent reduced regions
patches = [b for b in patches if (len(b)==config.INPUT_BOXES_SIZE**2)]

In [None]:
len(patches)

In [None]:
patches[0]

In [None]:
# remove not used anymore dataframes
del deforestation
del past_defor
del ibama
del frames_county

# Train test split

In [None]:
time_grid.shape

In [None]:
# time-step positions of each split along the first axis of time_grid
train_time_idx = range(156)
val_time_idx = range(104, 206)
test_time_idx = range(154, 258)

train_data, val_data, test_data = (
    time_grid[split_idx, :, :]
    for split_idx in (train_time_idx, val_time_idx, test_time_idx)
)

In [None]:
print(f"""
train : {config.TIME_STEPS[train_time_idx[52]].date()} -> {config.TIME_STEPS[train_time_idx[-1]].date()}
val   : {config.TIME_STEPS[val_time_idx[52]].date()} -> {config.TIME_STEPS[val_time_idx[-1]].date()}
test  : {config.TIME_STEPS[test_time_idx[52]].date()} -> {config.TIME_STEPS[test_time_idx[-1]].date()}
""")

In [None]:
config.TIME_STEPS[-1]

# Data Normalization

In [None]:
time_grid.shape

In [None]:
n1 = (train_data <= 1e-18).sum() - ((~frames_idx["in_borders"]).sum() * len(train_time_idx))
n1

In [None]:
plt.hist(train_data[(train_data > 1e-12)], bins=10)
plt.show()

In [None]:
n2 = (train_data > 1e-12).sum()
n2

In [None]:
imbalance_factor = n1 / n2
imbalance_factor

In [None]:
# standardise every IBAMA layer to zero mean and unit variance
for i in range(ib_array.shape[0]):
    layer = ib_array[i, :, :]
    layer_std = layer.std()
    # a constant layer has zero std: only centre it, otherwise 0/0 would fill the
    # whole layer with NaN and propagate into the model inputs
    ib_array[i, :, :] = (layer - layer.mean()) / (layer_std if layer_std > 0 else 1.0)

In [None]:
# centre each county layer on its median and divide by a fixed scale
pop_layer = county_data[0, :, :]
den_layer = county_data[1, :, :]
norm_pop = (pop_layer - np.median(pop_layer)) / 1e5
norm_den = (den_layer - np.median(den_layer)) / 30

county_data[0, :, :], county_data[1, :, :] = norm_pop, norm_den

# Loss function

In [None]:
# focal loss
from segmentation_models_pytorch.losses import FocalLoss
loss = FocalLoss("binary", gamma=3, reduction="mean").to(dev)

# Dataset & Dataloaders

In [None]:
FUTURE_WINDOW_PRED = 4

In [None]:
class CustomDataset(Dataset):
    """Patch / time-window samples for deforestation segmentation.

    Every item combines one spatial patch with one time window of ``X``.

    Input channels are stacked on axis 0 in this order:
      1. ``past_defor_data`` at the prediction time step (only when provided),
      2. five lagged sums of ``X``, one per consecutive pair of ``LAG_WEEKS``
         boundaries, each clipped to at most 1.0,
      3. every layer of ``county_data`` (only when provided),
      4. every layer of ``ibama_data`` (only when provided).

    The label is a binary map: 1 where the sum of ``X`` over the next
    ``future_window`` time steps exceeds 1e-18, 0 elsewhere.

    Parameters
    ----------
    X : array indexed as [time, x, y].
    patches : sequence of frame-id collections; each one is looked up in
        ``frames_idx`` with ``.loc``.
    frames_idx : DataFrame with "x" and "y" columns, indexed by frame id.
    county_data, ibama_data : optional arrays indexed as [channel, x, y].
    past_defor_data : optional array indexed as [time, x, y]. It is read at the
        same time positions as ``X``, so its first layer must correspond to the
        first layer of ``X``.
    future_window : optional number of future steps used for the label;
        defaults to the notebook-level ``FUTURE_WINDOW_PRED``.
    device : optional torch device for the returned tensors; defaults to the
        notebook-level ``dev``.
    """

    # window boundaries, in time steps before the prediction step, of the lag channels
    LAG_WEEKS = (52, 24, 12, 4, 2, 0)

    def __init__(
        self, 
        X, 
        patches, 
        frames_idx,
        county_data=None,
        ibama_data=None,
        past_defor_data=None,
        future_window=None,
        device=None
    ):
        super(CustomDataset, self).__init__()

        self.patches = patches
        self.frames_idx = frames_idx
        self.X = X
        self.county_data = county_data
        self.ibama_data = ibama_data
        self.past_defor_data = past_defor_data

        self.autor_window = 52
        # the globals are only looked up when no explicit value is given
        self.future_window = FUTURE_WINDOW_PRED if future_window is None else future_window
        self.device = dev if device is None else device
        # grid origin: frame coordinates are shifted by it to obtain array positions
        self.ix = frames_idx["x"].min()
        self.iy = frames_idx["y"].min()

    def _n_windows(self):
        """Number of time windows per patch (0 when X is too short for a single one)."""
        return max(0, self.X.shape[0] - self.autor_window - self.future_window + 1)

    def _patch_slices(self, idx_patch):
        """Return the (x, y) array slices covering the frames of patch `idx_patch`."""
        idx_frames = self.frames_idx.loc[self.patches[idx_patch]]
        x_slice = slice(idx_frames["x"].min() - self.ix, idx_frames["x"].max() - self.ix + 1)
        y_slice = slice(idx_frames["y"].min() - self.iy, idx_frames["y"].max() - self.iy + 1)
        return x_slice, y_slice

    def __len__(self):
        return len(self.patches) * self._n_windows()

    def __getitem__(self, index):
        """Return `(data, labels)` tensors for the flat sample position `index`."""
        n_windows = self._n_windows()
        if n_windows == 0:
            raise IndexError("dataset is empty: X is shorter than one input + future window")

        # flat index -> (patch, window start); windows of one patch are contiguous
        idx_patch, idx_time = divmod(index, n_windows)
        # frame bounds are computed once per item and reused by every channel
        x_slice, y_slice = self._patch_slices(idx_patch)
        # first time step of the label window
        t_pred = idx_time + self.autor_window

        channels = []

        # full past deforestation (skipped when no past data was given)
        if self.past_defor_data is not None:
            channels.append(self.past_defor_data[t_pred, x_slice, y_slice][np.newaxis])

        # some 'delta weeks' of deforestation: one summed channel per lag interval
        for older, newer in zip(self.LAG_WEEKS[:-1], self.LAG_WEEKS[1:]):
            lag_sum = self.X[t_pred-older:t_pred-newer, x_slice, y_slice].sum(axis=0)
            channels.append(lag_sum.clip(max=1.0)[np.newaxis])

        if self.county_data is not None:
            channels.append(self.county_data[:, x_slice, y_slice])

        if self.ibama_data is not None:
            channels.append(self.ibama_data[:, x_slice, y_slice])

        data = torch.tensor(np.concatenate(channels)).float().to(self.device)

        # get output: 1 wherever any deforestation occurs inside the future window
        future_area = self.X[t_pred:t_pred+self.future_window, x_slice, y_slice].sum(axis=0)
        labels = torch.tensor((future_area > 1e-18).astype(float)).float().to(self.device)

        return data, labels

In [None]:
train_data.shape, val_data.shape, test_data.shape

In [None]:
import random

# seed every random generator the notebook draws from: NumPy, Python's `random`
# (used by `shuffle`) and torch (used by the shuffling DataLoaders)
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

In [None]:
patches_sample_train = patches
patches_sample_val = patches

# rand_patches_idx = np.random.choice(range(len(patches)), 60)
# patches_sample_train = [patches[idx] for idx in rand_patches_idx]
# patches_sample_val = [patches[idx] for idx in rand_patches_idx]

In [None]:
# past_grid covers every time step, while each split's X starts at its own offset.
# The dataset reads past_defor_data at positions relative to X, so the grid is cut
# with the same time indices as the split to keep both on one time axis.
trainloader = torch.utils.data.DataLoader(
    CustomDataset(
        train_data, 
        patches_sample_train, 
        frames_idx,
        past_defor_data=past_grid[train_time_idx, :, :],
        county_data=county_data,
        ibama_data=ib_array
    ),
    batch_size=128,
    shuffle=True
)

valloader = torch.utils.data.DataLoader(
    CustomDataset(
        val_data, 
        patches_sample_val, 
        frames_idx,
        past_defor_data=past_grid[val_time_idx, :, :],
        county_data=county_data,
        ibama_data=ib_array
    ),
    batch_size=128,
    shuffle=True
)

In [None]:
trainloader.__len__(), valloader.__len__()

# Baseline Model

Evaluate error without any model

In [None]:
# baseline: all zero
base_train_err = 0
for inputs, labels in tqdm(trainloader):
    y_pred = torch.tensor(-10*np.ones(labels.shape)).to(dev)
    base_train_err += loss(y_pred=y_pred, y_true=labels)
base_train_err = base_train_err / len(trainloader)

print(f"Baseline Error (Train) = {base_train_err:.6f}")

In [None]:
# same all-negative baseline, measured on the validation set
base_val_err = 0
for inputs, labels in tqdm(valloader):
    # constant logit of -10 everywhere (sigmoid(-10) is about 4.5e-5)
    all_negative_logits = -10*np.ones(labels.shape)
    y_pred = torch.tensor(all_negative_logits).to(dev)
    base_val_err = base_val_err + loss(y_pred=y_pred, y_true=labels)
n_val_batches = len(valloader)
base_val_err = base_val_err / n_val_batches
print(f"Baseline Error (Validation) = {base_val_err:.6f}")

# Model Init

In [None]:
import torch.optim as optim
from torch import nn
from resunet import ResUnet

In [None]:
inputs.shape

In [None]:
# number of input channels, read from one dataset sample so it always matches the
# feature stack actually produced by the dataset
in_channels = trainloader.dataset[0][0].shape[0]
in_channels

In [None]:
model = ResUnet(
    channel=in_channels,
    output_dim=1,
    filters=[16, 32, 64, 128]
).to(dev)
optimizer = optim.Adam(model.parameters())#, lr=1e-4)

In [None]:
sum([p.flatten().shape[0] for p in model.parameters()])

# Train loop

In [None]:
# train loop
model.epoch = 0
model.errs = []

In [None]:
def evaluate_model(model, dataloader):
    """Return the mean per-batch loss of `model` over `dataloader`.

    Uses the notebook-level `loss` criterion. Gradients are disabled while
    iterating, so no autograd graph is built for the forward passes.

    NOTE(review): the model's train/eval mode is left untouched, as elsewhere in
    this notebook; confirm whether `model.eval()` is wanted for evaluation.
    """
    err = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            err += loss(model(inputs), labels)
    err = err / len(dataloader)

    return err

In [None]:
def run_epoch():
    """Train the notebook-level `model` for one pass over `trainloader`.

    Increments `model.epoch`, performs one optimizer step per batch and returns
    the mean of the per-batch losses.
    """
    model.epoch += 1
    print(f"\nEpoch {model.epoch}")

    batch_losses = []
    for inputs, labels in tqdm(trainloader):
        batch_loss = loss(model(inputs), labels)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # keep only the detached value so no graph is retained
        batch_losses.append(batch_loss.detach())

    return sum(batch_losses) / len(trainloader)


def train(n_epochs):
    """Run `n_epochs` training epochs, logging train and validation loss after each.

    Every epoch appends `[train_loss, validation_loss]` to `model.errs`.
    """
    for _ in range(n_epochs):
        # one pass over the training data, then score the validation data
        epoch_train_err = run_epoch()
        epoch_val_err = evaluate_model(model, valloader)

        model.errs.append([epoch_train_err, epoch_val_err])

        print(f"Epoch {model.epoch}: Train Loss = {epoch_train_err:.6f} | Validation Loss = {epoch_val_err:.6f}")

In [None]:
%%time

train(1)

In [None]:
train(1)

In [None]:
train(1)

In [None]:
train(1)

In [None]:
train(1)

Better than baseline?

In [None]:
(
    float((model.errs[-1][0] - base_train_err) / base_train_err), 
    float((model.errs[-1][1] - base_val_err) / base_val_err)
)

Learning curve

In [None]:
import matplotlib.pyplot as plt

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
x = np.array(range(len(model.errs))) + 1
sns.lineplot(x=x, y=[float(e[0]) for e in model.errs], label="Train")
sns.lineplot(x=x, y=[float(e[1]) for e in model.errs], label="Validation")
ax.set_title("Loss")
ax.set_xlabel("Epochs")
plt.show()

# Test DataLoader

In [None]:
# past_grid covers every time step, while test_data starts at its own offset. The
# dataset reads past_defor_data at positions relative to X, so the grid is cut with
# the same time indices as the split to keep both on one time axis.
testloader = torch.utils.data.DataLoader(
    CustomDataset(
        test_data, 
        patches_sample_val, 
        frames_idx,
        county_data=county_data,
        ibama_data=ib_array,
        past_defor_data=past_grid[test_time_idx, :, :]
    ),
    batch_size=64,
    shuffle=True,
)

# Prediction example

In [None]:
# TODO: error by true value

In [None]:
def get_sample(min_area=0.1, dataloader=None):
    """Return a batch plus the position of one sample with enough positive area.

    Batches are scanned in loader order; inside a batch the samples are visited
    in random order, and the first one whose label mean is at least `min_area`
    is selected.

    Parameters
    ----------
    min_area : minimum mean of a sample's label map.
    dataloader : iterable of `(inputs, labels)` batches. Defaults to the
        notebook-level `trainloader`, resolved when the function is called.

    Returns
    -------
    (input_, truth, idx_batch) : the whole batch and the index of the selected
        sample inside it.

    Raises
    ------
    ValueError : if no sample in the loader reaches `min_area`.
    """
    if dataloader is None:
        dataloader = trainloader
    for input_, truth in dataloader:
        candidates = list(range(truth.shape[0]))
        shuffle(candidates)
        for idx_batch in candidates:
            if truth[idx_batch, :, :].mean() >= min_area:
                return input_, truth, idx_batch
    # previously the function fell through and returned None, which only failed
    # later when the caller unpacked the result
    raise ValueError(f"no sample with label mean >= {min_area} found in dataloader")

In [None]:
sigmoid_fun = nn.Sigmoid()

In [None]:
input_, truth, idx_batch = get_sample(1e-3, testloader)

pred = sigmoid_fun(model(input_))[idx_batch, 0, :, :]
label = truth[idx_batch, :, :].cpu()

In [None]:
# convert once and reuse for the three panels
pred_np = pred.detach().cpu().numpy()
label_np = label.cpu().numpy()

fig, ax = plt.subplots(1, 3, figsize=(12, 6))
m1 = ax[0].matshow(label_np, cmap="Reds", vmin=0, vmax=1)
m2 = ax[1].matshow(pred_np, cmap="Reds", vmin=0, vmax=1)
m3 = ax[2].matshow(pred_np - label_np, cmap="seismic", vmin=-1, vmax=1)
# bind each colorbar to the panel it describes
fig.colorbar(m1, ax=ax[0])
fig.colorbar(m2, ax=ax[1])
fig.colorbar(m3, ax=ax[2])
ax[0].set_title("True")
ax[1].set_title("Prediction")
ax[2].set_title("Error")
# plt.show() as in the other plotting cells
plt.show()

In [None]:
fig, ax = plt.subplots(5, 5, figsize=(24, 8))
for i in range(5):
    for j in range(5):
        ax[j, i].matshow(input_[idx_batch, i+5*j, :, :].cpu() > 0, cmap="Reds", vmin=truth.min(), vmax=truth.max())
plt.show()

# Threshold Selection

Use the validation dataset to select the threshold that maximizes the F1-score

In [None]:
# confusion matrix per probability threshold (in percent), accumulated over the
# validation set; layout is val_p_cm[threshold, true class, predicted class]
pt_values = range(0, 101, 1)
val_p_cm = np.zeros((len(pt_values), 2, 2)).astype(int)
with torch.no_grad():  # inference only: no autograd graph needed
    for inputs, labels in tqdm(valloader):
        # compute prediction (probability)
        p_hat = sigmoid_fun(model(inputs))[:, 0, :, :].cpu().numpy().flatten()
        vals = labels[:, :, :].cpu().numpy().flatten()
        # the split by true class does not depend on the threshold: do it once per batch
        p_hat_by_class = (p_hat[vals == 0], p_hat[vals == 1])
        for i, ptresh in enumerate(pt_values):
            # for each threshold compute confusion matrix
            for true_class, p_class in enumerate(p_hat_by_class):
                n_pred_pos = np.count_nonzero(p_class >= ptresh / 100)
                val_p_cm[i, true_class, 1] += n_pred_pos
                val_p_cm[i, true_class, 0] += p_class.size - n_pred_pos

In [None]:
# per-threshold metrics from val_p_cm[threshold, true class, predicted class]
true_pos = val_p_cm[:, 1, 1]
acc = (val_p_cm[:, 0, 0] + true_pos) / val_p_cm.sum(axis=(1,2))
prc = true_pos / val_p_cm[:, :, 1].sum(axis=1)
# thresholds with no positive prediction give 0/0: replace that NaN with a tiny precision
prc[np.isnan(prc)] = 1e-8
rcl = true_pos / val_p_cm[:, 1, :].sum(axis=1)
f1 = 2 * prc * rcl / (prc + rcl)

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(18, 9), sharex=True)
sns.lineplot(x=pt_values, y=acc, ax=ax[0, 0])
sns.lineplot(x=pt_values, y=prc, ax=ax[0, 1])
sns.lineplot(x=pt_values, y=rcl, ax=ax[1, 0])
sns.lineplot(x=pt_values, y=f1, ax=ax[1, 1])
ax[0,0].set_title("Accuracy")
ax[0,1].set_title("Precision")
ax[1,0].set_title("Recall")
ax[1,1].set_title("F1-Score")
sns.lineplot(x=[0,100], y=[0,1], ax=ax[0,0])
plt.show()

In [None]:
# take the argmax over the full f1 array so the position is valid for pt_values,
# prc and val_p_cm (an argmax over the filtered f1[f1>0] indexes the filtered
# array instead); NaN entries are ignored
i_treshold = int(np.nanargmax(f1))
ptreshold = pt_values[i_treshold]
ptreshold, f1[i_treshold], prc[i_treshold]

In [None]:
val_p_cm[i_treshold, :, :]

# Test Score

Apply the chosen threshold and compute scores on the test dataset

In [None]:
# confusion matrix on the test set at the selected threshold;
# layout is test_cm[true class, predicted class]
test_cm = np.zeros((2,2)).astype(int)
with torch.no_grad():  # inference only: no autograd graph needed
    for inputs, labels in tqdm(testloader):
        pred = (sigmoid_fun(model(inputs)) >= ptreshold / 100).cpu().numpy().flatten()
        vals = labels.cpu().numpy().flatten()
        for true_class in (0, 1):
            pred_of_class = pred[vals == true_class]
            n_pred_pos = np.count_nonzero(pred_of_class)
            test_cm[true_class, 1] += n_pred_pos
            test_cm[true_class, 0] += pred_of_class.size - n_pred_pos

In [None]:
acc_test = (test_cm[0, 0] + test_cm[1, 1]) / test_cm.sum()
prc_test = test_cm[1, 1] / test_cm[:, 1].sum()
rcl_test = test_cm[1, 1] / test_cm[1, :].sum()
f1_test = 2 * prc_test * rcl_test / (prc_test + rcl_test)

In [None]:
test_cm

In [None]:
acc_test, prc_test, rcl_test, f1_test

# Lower spatial granularity

Apply a spatial reduction to the data, select a new threshold (on validation) and evaluate on test.

In [None]:
import numpy as np
from scipy.ndimage import convolve

def apply_distance_threshold(arr, threshold):
    """Flag every cell whose square neighbourhood contains a positive total.

    A cell is set to 1 when the sum of `arr` over the
    (2 * threshold + 1) x (2 * threshold + 1) window centred on it is greater
    than zero; cells outside the array count as 0.

    Parameters
    ----------
    arr : 2-D array-like (numeric or boolean).
    threshold : non-negative int, half-width of the window in cells.

    Returns
    -------
    Integer array of 0/1 with the shape of `arr`.
    """
    # Define a kernel that considers neighboring cells within the threshold distance
    kernel_size = 2 * threshold + 1
    kernel = np.ones((kernel_size, kernel_size))

    # convolve() returns an array of the input dtype by default; cast to float first
    # so boolean or small-integer masks are summed without truncation or overflow
    neighbourhood_sum = convolve(np.asarray(arr, dtype=float), kernel, mode='constant', cval=0)

    # Apply threshold to the convolved array
    return (neighbourhood_sum > 0).astype(int)

def apply_distance_threshold_3d(arr_3d, threshold):
    """Apply `apply_distance_threshold` to every 2-D slice along the first axis.

    The result has the same shape and dtype as `arr_3d`.
    """
    out = np.zeros_like(arr_3d)
    for layer_idx in range(len(arr_3d)):
        out[layer_idx] = apply_distance_threshold(arr_3d[layer_idx], threshold)
    return out

# Example usage
original_array_3d = np.array([[[0, 0, 0, 0, 0],
                                [0, 0, 0, 0, 0],
                                [0, 0, 0, 0, 0]],
                               
                               [[0, 0, 0, 1, 0],
                                [0, 0, 0, 1, 0],
                                [0, 0, 0, 0, 0]]])

threshold_distance = 1

result_array_3d = apply_distance_threshold_3d(original_array_3d, threshold_distance)
print(result_array_3d)

In [None]:
dims = [3, 5, 7, 9, 11]

## Select threshold by dimensionality

In [None]:
pt_values = list(range(15, 46))

In [None]:
# confusion matrix per (window size, threshold), accumulated over the validation set;
# layout is val_pd_cm[dim, threshold, true class, predicted class]
val_pd_cm = np.zeros((len(dims), len(pt_values), 2, 2)).astype(int)
with torch.no_grad():  # inference only: no autograd graph needed
    for inputs, labels in tqdm(valloader):
        # compute prediction (probability)
        p_hat = sigmoid_fun(model(inputs))[:, 0, :, :].cpu().numpy()
        # the label array is the same for every window size: convert once per batch
        labels_np = labels[:, :, :].cpu().numpy()

        for idim, dim in enumerate(dims):
            dist = int(dim / 2)
            vals = apply_distance_threshold_3d(labels_np, dist).flatten()
            # the split by true class does not depend on the threshold
            class_masks = (vals == 0, vals == 1)

            for i, ptresh in enumerate(pt_values):
                # for each threshold compute confusion matrix
                y_hat = p_hat >= ptresh / 100
                y_hat = apply_distance_threshold_3d(y_hat, dist).flatten()
                for true_class, class_mask in enumerate(class_masks):
                    pred_of_class = y_hat[class_mask]
                    n_pred_pos = np.count_nonzero(pred_of_class)
                    val_pd_cm[idim, i, true_class, 1] += n_pred_pos
                    val_pd_cm[idim, i, true_class, 0] += pred_of_class.size - n_pred_pos

In [None]:
acc_pd = (val_pd_cm[:, :, 0, 0] + val_pd_cm[:, :, 1, 1]) / val_pd_cm.sum(axis=(2,3))
prc_pd = val_pd_cm[:, :, 1, 1] / val_pd_cm[:, :, :, 1].sum(axis=2)
prc_pd[np.where(prc_pd!=prc_pd)] = 1e-8
rcl_pd = val_pd_cm[:, :, 1, 1] / val_pd_cm[:, :, 1, :].sum(axis=2)
f1_pd = 2 * prc_pd * rcl_pd / (prc_pd + rcl_pd)
f1_pd[np.where(f1_pd!=f1_pd)] = 1e-8

In [None]:
itresh_by_dim = np.argmax(f1_pd, axis=1)

In [None]:
ptresh_by_dim = [pt_values[i] for i in itresh_by_dim]

In [None]:
for idim in range(len(dims)):
    print(dims[idim])
    print(ptresh_by_dim[idim])
    print(val_pd_cm[idim, itresh_by_dim[idim], :, :])
    print(f1_pd[idim, itresh_by_dim[idim]])
    print()

## Evaluate on test

In [None]:
# confusion matrix per window size on the test set, using the threshold selected
# for that size; layout is test_d_cm[dim, true class, predicted class]
test_d_cm = np.zeros((len(dims), 2, 2)).astype(int)
with torch.no_grad():  # inference only: no autograd graph needed
    for inputs, labels in tqdm(testloader):
        # compute prediction (probability)
        p_hat = sigmoid_fun(model(inputs))[:, 0, :, :].cpu().numpy()
        # the label array is the same for every window size: convert once per batch
        labels_np = labels[:, :, :].cpu().numpy()

        for idim, dim in enumerate(dims):
            dist = int(dim / 2)
            vals = apply_distance_threshold_3d(labels_np, dist).flatten()

            ptresh = ptresh_by_dim[idim]
            y_hat = p_hat >= ptresh / 100
            y_hat = apply_distance_threshold_3d(y_hat, dist).flatten()
            for true_class in (0, 1):
                pred_of_class = y_hat[vals == true_class]
                n_pred_pos = np.count_nonzero(pred_of_class)
                test_d_cm[idim, true_class, 1] += n_pred_pos
                test_d_cm[idim, true_class, 0] += pred_of_class.size - n_pred_pos

In [None]:
acc_d = (test_d_cm[:, 0, 0] + test_d_cm[:, 1, 1]) / test_d_cm.sum(axis=(1,2))
prc_d = test_d_cm[:, 1, 1] / test_d_cm[:, :, 1].sum(axis=1)
prc_d[np.where(prc_d!=prc_d)] = 1e-8
rcl_d = test_d_cm[:, 1, 1] / test_d_cm[:, 1, :].sum(axis=1)
f1_d = 2 * prc_d * rcl_d / (prc_d + rcl_d)

In [None]:
f1_d

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(18, 9), sharex=True)
sns.barplot(x=dims, y=acc_d, ax=ax[0, 0], color=sns.color_palette()[0])
sns.barplot(x=dims, y=prc_d, ax=ax[0, 1], color=sns.color_palette()[0])
sns.barplot(x=dims, y=rcl_d, ax=ax[1, 0], color=sns.color_palette()[0])
sns.barplot(x=dims, y=f1_d, ax=ax[1, 1], color=sns.color_palette()[0])
ax[0,0].set_title("Accuracy")
ax[0,1].set_title("Precision")
ax[1,0].set_title("Recall")
ax[1,1].set_title("F1-Score")
plt.show()

# Export Predict TIF

In [None]:
# t = 10
# test_dataloader_full = torch.utils.data.DataLoader(
#     CustomDataset(
#         test_data[t:t+54, :, :], 
#         patches, 
#         frames_idx,
#         ibama_data=ib_array
#     ),
#     batch_size=1,
#     shuffle=False
# )

In [None]:
# index = 0
# labels_list = []
# preds_list = []
# for inputs, labels in tqdm(test_dataloader_full):
#     # get index info
#     idx_patch = index // (test_dataloader_full.dataset.X.shape[0] - test_dataloader_full.dataset.autor_window - test_dataloader_full.dataset.future_window + 1)
#     idx_time   = index % (test_dataloader_full.dataset.X.shape[0] - test_dataloader_full.dataset.autor_window - test_dataloader_full.dataset.future_window + 1)
#     idx_frames = test_dataloader_full.dataset.frames_idx.loc[test_dataloader_full.dataset.patches[idx_patch]]
    
#     labels_df = pd.melt(
#         pd.DataFrame(labels.cpu().numpy()[0, :, :], index=idx_frames["x"].unique(), columns=idx_frames["y"].unique()),
#         ignore_index=False
#     ).reset_index().rename(columns={"index":"x", "variable":"y","value":"label"})
    
#     preds_df = pd.melt(
#         pd.DataFrame(model(inputs).detach().cpu().numpy()[0, 0, :, :], index=idx_frames["x"].unique(), columns=idx_frames["y"].unique()),
#         ignore_index=False
#     ).reset_index().rename(columns={"index":"x", "variable":"y","value":"pred"})
    
#     labels_list.append(labels_df)
#     preds_list.append(preds_df)
    
#     index += 1

In [None]:
# result_df = pd.merge(
#     pd.concat(labels_list).groupby(["x", "y"]).max().reset_index(),
#     pd.concat(preds_list).groupby(["x", "y"]).mean().reset_index(),
#     on=["x","y"],
#     how="outer",
#     validate="1:1"
# )

In [None]:
# # apply sigmoid to prediction values
# result_df["pred"] = 1 / (1+(-result_df["pred"]).apply(np.exp))

In [None]:
# %%time
# grid = gpd.read_file(config.TR_FRAMES)

In [None]:
# grid_with_results = grid.merge(result_df, on=["x","y"], how="left", validate="1:1")

## Save it as raster

In [None]:
# am_bounds = gpd.read_file(config.AMAZON_FRONTIER_DATA)

In [None]:
# from geocube.api.core import make_geocube

# # raster data with bands by date
# grid_raster_data = make_geocube(
#     vector_data=grid_with_results.dropna(),
#     resolution=(0.01, 0.01),
#     measurements=["label", "pred"],
#     geom=am_bounds.geometry.item(),
#     fill=-1.0
# )

In [None]:
# grid_raster_data.rio.to_raster("model_pred_example.tif")

# Save model and results

In [None]:
import os
from datetime import datetime

In [None]:
nowtimestr = datetime.strftime(datetime.now(), format="%Y%m%d%H%M%S")
os.mkdir(nowtimestr)

# save model
torch.save(model.state_dict(), os.path.join(nowtimestr, r"resunet_class_0.01.pt"))

In [None]:
result_str = f"""
PREDICTION WINDOW (WEEKS): {FUTURE_WINDOW_PRED}

BEST TRESHOLD: {ptreshold}
CONFUSION MATRIX[0,0]: {test_cm[0,0]}
CONFUSION MATRIX[0,1]: {test_cm[0,1]}
CONFUSION MATRIX[1,0]: {test_cm[1,0]}
CONFUSION MATRIX[1,1]: {test_cm[1,1]}
F1-SCORE (TEST): {f1_test}
AUGMENTED DIMENSION SIZES: {list(dims)}
"""

for idim, dim in enumerate(dims):
    result_str += f"""
    FOR AUG SPATIAL DIM: {dim}
    BEST TRESHOLD: {ptresh_by_dim[idim]}
    CONFUSION MATRIX[0,0]: {test_d_cm[idim, 0,0]}
    CONFUSION MATRIX[0,1]: {test_d_cm[idim, 0,1]}
    CONFUSION MATRIX[1,0]: {test_d_cm[idim, 1,0]}
    CONFUSION MATRIX[1,1]: {test_d_cm[idim, 1,1]}
    F1-SCORE (TEST): {f1_d[idim]}
    """

In [None]:
print(result_str)

In [None]:
with open(os.path.join(nowtimestr, "results.txt"), "w") as file:
    file.write(result_str)