### Imports

In [1]:
import torch
import random
import numpy as np

from functions.download_data import (
    get_patchs_labels,
    normalization_params,
    get_golden_paths,
    pooled_std_dev,
)

from functions.filter import filter_indices_from_labels

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from functions.instanciators import get_dataset, get_lightning_module, get_trainer
from functions.instanciators import get_model
from torch.utils.data import random_split, DataLoader
from torch import Generator
from config.module import module_dict

import mlflow

import gc

seed = 12345 
torch.manual_seed(seed)
random.seed(0)
np.random.seed(0)

#from torchvision.models.resnet import resnet50
import torch.nn as nn

### Variables

In [2]:
remote_server_uri = "https://projet-slums-detection-128833.user.lab.sspcloud.fr"
experiment_name = "test-dev"
run_name =  "stagiosessaye"
task = "segmentation"
source = "PLEIADES"
deps =  ["MARTINIQUE"]
years = ["2022"]
tiles_size = 250
augment_size = 250
type_labeler = "BDTOPO"
n_bands = 3
logits = 1
freeze_encoder = 0
epochs = 10
batch_size = 8
test_batch_size = 8
num_sanity_val_steps = 1
accumulate_batch = 8
module_name = "PSPNet"
loss_name =  "cross_entropy_weighted"
building_class_weight = 1
label_smoothing = 0.0
lr = 0.00005
momentum = float
scheduler_name = "one_cycle"
scheduler_patience = 3
patience = 200
from_s3 = 0
seed = 12345 
cuda = 0
cuda = cuda and torch.cuda.is_available()
kwargs = {"num_workers": os.cpu_count(), "pin_memory": True} if cuda else {}

dep, year  = "MARTINIQUE", "2022"

### Données

In [3]:
patches, labels = get_patchs_labels(
        from_s3, task, source, dep, year, tiles_size, type_labeler, train=True
    )

train_patches = []
train_labels = []
test_patches = []
test_labels = []
normalization_means = []
normalization_stds = []
weights = []

patches.sort()
labels.sort()
indices = filter_indices_from_labels(labels, -1.0, 2.0)
train_patches += [patches[idx] for idx in indices]
train_labels += [labels[idx] for idx in indices]

module_name


patches, labels = get_patchs_labels(
    from_s3, task, source, dep, year, tiles_size, type_labeler, train=False
)

patches.sort()
labels.sort()
test_patches += list(patches)
test_labels += list(labels)

normalization_mean, normalization_std = normalization_params(
    task, source, dep, year, tiles_size, type_labeler
)
normalization_means.append(normalization_mean)
normalization_stds.append(normalization_std)
weights.append(len(indices))

# Golden test
golden_patches, golden_labels = get_golden_paths(
    from_s3, task, source, "MAYOTTE_CLEAN", "2022", tiles_size
)

golden_patches.sort()
golden_labels.sort()


`s3/projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0524_8587_U38S_8Bits_0019.jp2` -> `data/data-preprocessed/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0524_8587_U38S_8Bits_0019.jp2`
`s3/projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0524_8587_U38S_8Bits_0028.jp2` -> `data/data-preprocessed/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0524_8587_U38S_8Bits_0028.jp2`
`s3/projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0512_8592_U38S_8Bits_0005.jp2` -> `data/data-preprocessed/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0512_8592_U38S_8Bits_0005.jp2`
`s3/projet-slums-detection/golden-test/patchs/segmentation/PLEIADES/MAYOTTE_CLEAN/2022/250/ORT_976_2022_0513_8593_U38S_8Bits_0034.jp2` -> `data/data-preprocessed/golden-test/patchs/segmentat

In [4]:

normalization_mean = np.average(
    [mean[:n_bands] for mean in normalization_means], weights=weights, axis=0
)
normalization_std = [
    pooled_std_dev(
        weights,
        [mean[i] for mean in normalization_means],
        [std[i] for std in normalization_stds],
    )
    for i in range(n_bands)
]

transform_list = [
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.Normalize(
        max_pixel_value=1.0,
        mean=normalization_mean,
        std=normalization_std,
    ),
    ToTensorV2(),
]

if augment_size != tiles_size:
    transform_list.insert(0, A.Resize(augment_size, augment_size))
transform = A.Compose(transform_list)

test_transform_list = [
    A.Normalize(
        max_pixel_value=1.0,
        mean=normalization_mean,
        std=normalization_std,
    ),
    ToTensorV2(),
]
if augment_size != tiles_size:
    test_transform_list.insert(0, A.Resize(augment_size, augment_size))
test_transform = A.Compose(test_transform_list)

### Dataset

In [5]:
dataset = get_dataset(task, train_patches, train_labels, n_bands, from_s3, transform)
dataset = get_dataset(task, train_patches[:40], train_labels[:40], n_bands, from_s3, transform)
test_dataset = get_dataset(task, test_patches, test_labels, n_bands, from_s3, test_transform)
golden_dataset = get_dataset(
    task, golden_patches, golden_labels, n_bands, from_s3, test_transform
)

train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=Generator())

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs
)
val_loader = DataLoader(
    val_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True, **kwargs
)
test_loader = DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True, **kwargs
)
golden_loader = DataLoader(
    golden_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True, **kwargs
)

In [6]:
model = get_model(module_name,3, True, False)

batch = next(iter(train_loader))
labels = batch["labels"]
images = batch["pixel_values"]

output = model(images)
print(output)




Stage output shape: torch.Size([8, 512, 8, 8])
Stage output shape: torch.Size([8, 512, 8, 8])
Stage output shape: torch.Size([8, 512, 8, 8])
Stage output shape: torch.Size([8, 512, 8, 8])
Concatenated shape: torch.Size([8, 4096, 8, 8])
tensor([[[[ 0.3415,  0.3415,  0.3415,  ...,  0.1319,  0.1319,  0.1319],
          [ 0.3415,  0.3415,  0.3415,  ...,  0.1319,  0.1319,  0.1319],
          [ 0.3415,  0.3415,  0.3415,  ...,  0.1319,  0.1319,  0.1319],
          ...,
          [ 0.4882,  0.4882,  0.4882,  ..., -0.0425, -0.0425, -0.0425],
          [ 0.4882,  0.4882,  0.4882,  ..., -0.0425, -0.0425, -0.0425],
          [ 0.4882,  0.4882,  0.4882,  ..., -0.0425, -0.0425, -0.0425]]],


        [[[-0.1728, -0.1728, -0.1728,  ..., -0.1293, -0.1293, -0.1293],
          [-0.1728, -0.1728, -0.1728,  ..., -0.1293, -0.1293, -0.1293],
          [-0.1728, -0.1728, -0.1728,  ..., -0.1293, -0.1293, -0.1293],
          ...,
          [-0.5636, -0.5636, -0.5636,  ...,  0.2874,  0.2874,  0.2874],
          