### Imports

In [1]:
import gc
import os
import random

import albumentations as A
import mlflow
import numpy as np
import torch
from albumentations.pytorch.transforms import ToTensorV2
from torch import Generator
from torch.utils.data import random_split, DataLoader

from config.module import module_dict
from functions.download_data import (
    get_patchs_labels,
    normalization_params,
    get_golden_paths,
    pooled_std_dev,
)
from functions.filter import filter_indices_from_labels
from functions.instanciators import get_dataset, get_lightning_module, get_trainer
from functions.instanciators import get_model

# Pin every random number generator used in this notebook so that a
# Restart & Run All reproduces the same results.
# NOTE(review): only torch receives `seed`; `random` and NumPy are pinned to 0.
# Confirm whether all three are meant to share the same value.
seed = 12345
np.random.seed(0)
random.seed(0)
torch.manual_seed(seed)

#from torchvision.models.resnet import resnet50
import torch.nn as nn

### Variables

In [5]:
# Configuration cell: every value a reader may want to tune lives here.

# Tracking server and run naming.
remote_server_uri = "https://projet-slums-detection-128833.user.lab.sspcloud.fr"
experiment_name = "test-dev"
run_name = "stagiosessaye"

# Data selection.
task = "segmentation"
source = "PLEIADES"
deps = ["MARTINIQUE"]
years = ["2022"]
tiles_size = 250
augment_size = 250  # a resize step is added to the transforms only when this differs from tiles_size
type_labeler = "BDTOPO"
n_bands = 3
from_s3 = 0

# Model.
module_name = "DeepLabGaetan"
logits = 1
freeze_encoder = 0

# Optimisation.
epochs = 10
batch_size = 8
test_batch_size = 8
num_sanity_val_steps = 1
accumulate_batch = 8
loss_name = "cross_entropy_weighted"
building_class_weight = 1
label_smoothing = 0.0
lr = 0.00005
# The original cell bound `momentum` to the `float` type itself instead of a
# number. NOTE(review): 0.9 is a conventional default, not a value taken from
# this project — confirm the intended momentum.
momentum = 0.9
scheduler_name = "one_cycle"
scheduler_patience = 3
patience = 200

# Reproducibility: same value as the one passed to torch.manual_seed in the
# imports cell.
seed = 12345

# Hardware: `cuda` stays falsy unless it is requested here AND a CUDA device is
# actually available.
cuda = 0
cuda = cuda and torch.cuda.is_available()
# Extra DataLoader keyword arguments, only populated when running on GPU.
kwargs = {"num_workers": os.cpu_count(), "pin_memory": True} if cuda else {}

# Single department/year handled by cells that do not iterate over
# `deps`/`years`; derived from the lists so the two can never disagree.
dep, year = deps[0], years[0]

### Données

In [3]:
# Collect train/test patch and label paths plus normalization statistics for
# every (department, year) pair listed in the configuration cell.
train_patches = []
train_labels = []
test_patches = []
test_labels = []
normalization_means = []
normalization_stds = []
weights = []

for dep, year in zip(deps, years):
    # Training split.
    patches, labels = get_patchs_labels(
        from_s3, task, source, dep, year, tiles_size, type_labeler, train=True
    )
    # Patches and labels are sorted independently.
    # NOTE(review): this assumes both lists then line up index by index —
    # confirm against the file naming.
    patches.sort()
    labels.sort()
    # NOTE(review): the (-1.0, 2.0) bounds presumably keep every patch —
    # confirm against filter_indices_from_labels.
    indices = filter_indices_from_labels(labels, -1.0, 2.0)
    train_patches += [patches[idx] for idx in indices]
    train_labels += [labels[idx] for idx in indices]

    # Test split (no filtering applied).
    patches, labels = get_patchs_labels(
        from_s3, task, source, dep, year, tiles_size, type_labeler, train=False
    )
    patches.sort()
    labels.sort()
    test_patches += list(patches)
    test_labels += list(labels)

    # Per-department normalization statistics; the number of retained training
    # patches is stored as the weight used when the statistics are combined.
    normalization_mean, normalization_std = normalization_params(
        task, source, dep, year, tiles_size, type_labeler
    )
    normalization_means.append(normalization_mean)
    normalization_stds.append(normalization_std)
    weights.append(len(indices))

# Golden test
golden_patches, golden_labels = get_golden_paths(
    from_s3, task, source, "MAYOTTE_CLEAN", "2022", tiles_size
)

golden_patches.sort()
golden_labels.sort()


`s3/projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0513_8568_U38S_8Bits_0026.jp2` -> `data/data-preprocessed/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0513_8568_U38S_8Bits_0026.jp2`
`s3/projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0513_8568_U38S_8Bits_0025.jp2` -> `data/data-preprocessed/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0513_8568_U38S_8Bits_0025.jp2`
`s3/projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0512_8592_U38S_8Bits_0005.jp2` -> `data/data-preprocessed/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0512_8592_U38S_8Bits_0005.jp2`
`s3/projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0513_8593_U38S_8Bits_0034.jp2` -> `data/data-preprocessed/golden-test/patchs/segmentat

In [6]:
# Build the train and test albumentations pipelines.
# `n_bands`, `augment_size` and `tiles_size` come from the configuration cell;
# they are deliberately NOT re-assigned here so that editing the configuration
# cell is enough to change them.

# Weighted average of the per-department channel means, restricted to the
# first `n_bands` channels.
normalization_mean = np.average(
    [mean[:n_bands] for mean in normalization_means], weights=weights, axis=0
)
# One pooled standard deviation per channel, computed by pooled_std_dev from
# the per-department weights, means and standard deviations.
normalization_std = [
    pooled_std_dev(
        weights,
        [mean[i] for mean in normalization_means],
        [std[i] for std in normalization_stds],
    )
    for i in range(n_bands)
]


def _normalize_and_tensorize():
    """Return fresh normalization + tensor-conversion steps shared by both pipelines."""
    return [
        A.Normalize(
            max_pixel_value=1.0,
            mean=normalization_mean,
            std=normalization_std,
        ),
        ToTensorV2(),
    ]


def _with_optional_resize(steps):
    """Prepend a resize step when the augmentation size differs from the tile size."""
    if augment_size != tiles_size:
        return [A.Resize(augment_size, augment_size), *steps]
    return list(steps)


# Training pipeline: random flips, then normalization and tensor conversion.
transform_list = _with_optional_resize(
    [A.HorizontalFlip(), A.VerticalFlip(), *_normalize_and_tensorize()]
)
transform = A.Compose(transform_list)

# Evaluation pipeline: same steps without the random flips.
test_transform_list = _with_optional_resize(_normalize_and_tensorize())
test_transform = A.Compose(test_transform_list)

### Dataset

In [7]:
# Number of training patches kept for this quick development run.
# Set to None to train on every patch (slicing with None keeps the whole list).
# The original cell first built the full dataset and immediately overwrote it
# with this subset; only the subset is built now.
debug_subset_size = 40

dataset = get_dataset(
    task,
    train_patches[:debug_subset_size],
    train_labels[:debug_subset_size],
    n_bands,
    from_s3,
    transform,
)
test_dataset = get_dataset(task, test_patches, test_labels, n_bands, from_s3, test_transform)
golden_dataset = get_dataset(
    task, golden_patches, golden_labels, n_bands, from_s3, test_transform
)

# 80/20 train/validation split drawn from a dedicated generator.
train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=Generator())

# NOTE(review): drop_last=True is also set on the evaluation loaders, which
# discards the trailing incomplete batch from validation/test/golden metrics —
# confirm this is intended.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs
)
val_loader = DataLoader(
    val_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True, **kwargs
)
test_loader = DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True, **kwargs
)
golden_loader = DataLoader(
    golden_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True, **kwargs
)

In [12]:
# Smoke test: instantiate the model and run one training batch through it.
# The second argument was the literal 3; it now uses `n_bands` from the
# configuration cell (assumes that argument is the band count — TODO confirm).
# NOTE(review): the trailing True/False mirror `logits = 1` and
# `freeze_encoder = 0` from the configuration cell — confirm against the
# get_model signature before wiring those variables in.
model = get_model(module_name, n_bands, True, False)

batch = next(iter(train_loader))
# `labels` rebinds the name used for label paths in the data-loading cell;
# it is kept as-is in case later cells rely on it.
labels = batch["labels"]
images = batch["pixel_values"]

output = model(images)
print(output)

tensor([[[[-3.9403e-01, -3.9878e-01, -4.0352e-01,  ..., -2.0273e-01,
           -1.7037e-01, -1.3800e-01],
          [-4.0697e-01, -4.0597e-01, -4.0496e-01,  ..., -1.9086e-01,
           -1.6400e-01, -1.3714e-01],
          [-4.1991e-01, -4.1316e-01, -4.0641e-01,  ..., -1.7899e-01,
           -1.5763e-01, -1.3628e-01],
          ...,
          [-1.7058e-01, -1.9111e-01, -2.1163e-01,  ..., -2.7728e-01,
           -2.0851e-01, -1.3974e-01],
          [-1.5920e-01, -1.8332e-01, -2.0743e-01,  ..., -2.7220e-01,
           -2.0073e-01, -1.2926e-01],
          [-1.4782e-01, -1.7552e-01, -2.0323e-01,  ..., -2.6712e-01,
           -1.9295e-01, -1.1877e-01]],

         [[ 2.1710e-01,  1.5233e-01,  8.7554e-02,  ...,  3.1144e-01,
            3.4694e-01,  3.8243e-01],
          [ 2.0432e-01,  1.4618e-01,  8.8035e-02,  ...,  2.7922e-01,
            3.0129e-01,  3.2336e-01],
          [ 1.9153e-01,  1.4003e-01,  8.8517e-02,  ...,  2.4700e-01,
            2.5565e-01,  2.6430e-01],
          ...,
     