### Imports

In [13]:
import torch
import random
import numpy as np

from functions.download_data import (
    get_patchs_labels,
    normalization_params,
    get_golden_paths,
    pooled_std_dev,
)

from functions.filter import filter_indices_from_labels

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from functions.instanciators import get_dataset, get_lightning_module, get_trainer
from torch.utils.data import random_split, DataLoader
from torch import Generator

import mlflow

import gc

seed = 12345 
torch.manual_seed(seed)
random.seed(0)
np.random.seed(0)

from torchvision.models.resnet import resnet50
import torch.nn as nn

### Variables

In [3]:
remote_server_uri = "https://projet-slums-detection-128833.user.lab.sspcloud.fr"
experiment_name = "test-dev"
run_name =  "kikito_stagios"
task = "segmentation"
source = "PLEIADES"
deps =  ["MARTINIQUE"]
years = ["2022"]
tiles_size = 250
augment_size = 512
type_labeler = "BDTOPO"
n_bands = 3
logits = 1
freeze_encoder = 0
epochs = 10
batch_size = 8
test_batch_size = 8
num_sanity_val_steps = 1
accumulate_batch = 8
module_name = "segformer-b5"
loss_name =  "cross_entropy_weighted"
building_class_weight = 1
label_smoothing = 0.0
lr = 0.00005
momentum = float
scheduler_name = "one_cycle"
scheduler_patience = 3
patience = 200
from_s3 = 0
seed = 12345 
cuda = 0
cuda = cuda and torch.cuda.is_available()

dep, year  = "MARTINIQUE", "2022"

### Données

In [None]:
patches, labels = get_patchs_labels(
        from_s3, task, source, dep, year, tiles_size, type_labeler, train=True
    )

train_patches = []
train_labels = []
test_patches = []
test_labels = []
normalization_means = []
normalization_stds = []
weights = []

patches.sort()
labels.sort()
indices = filter_indices_from_labels(labels, -1.0, 2.0)
train_patches += [patches[idx] for idx in indices]
train_labels += [labels[idx] for idx in indices]

patches, labels = get_patchs_labels(
    from_s3, task, source, dep, year, tiles_size, type_labeler, train=False
)

patches.sort()
labels.sort()
test_patches += list(patches)
test_labels += list(labels)

normalization_mean, normalization_std = normalization_params(
    task, source, dep, year, tiles_size, type_labeler
)
normalization_means.append(normalization_mean)
normalization_stds.append(normalization_std)
weights.append(len(indices))

# Golden test
golden_patches, golden_labels = get_golden_paths(
    from_s3, task, source, "MAYOTTE_CLEAN", "2022", tiles_size
)

golden_patches.sort()
golden_labels.sort()


In [5]:
n_bands = 3
normalization_mean = np.average(
    [mean[:n_bands] for mean in normalization_means], weights=weights, axis=0
)
normalization_std = [
    pooled_std_dev(
        weights,
        [mean[i] for mean in normalization_means],
        [std[i] for std in normalization_stds],
    )
    for i in range(n_bands)
]

transform_list = [
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.Normalize(
        max_pixel_value=1.0,
        mean=normalization_mean,
        std=normalization_std,
    ),
    ToTensorV2(),
]

augment_size = 250
if augment_size != tiles_size:
    transform_list.insert(0, A.Resize(augment_size, augment_size))
transform = A.Compose(transform_list)

test_transform_list = [
    A.Normalize(
        max_pixel_value=1.0,
        mean=normalization_mean,
        std=normalization_std,
    ),
    ToTensorV2(),
]
if augment_size != tiles_size:
    test_transform_list.insert(0, A.Resize(augment_size, augment_size))
test_transform = A.Compose(test_transform_list)

### Dataset

In [10]:
dataset = get_dataset(task, train_patches, train_labels, n_bands, from_s3, transform)
dataset = get_dataset(task, train_patches[:40], train_labels[:40], n_bands, from_s3, transform)
test_dataset = get_dataset(task, test_patches, test_labels, n_bands, from_s3, test_transform)
golden_dataset = get_dataset(
    task, golden_patches, golden_labels, n_bands, from_s3, test_transform
)

train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=Generator())

batch_size = 8
test_batch_size = 8
cuda = 0
kwargs = {"num_workers": os.cpu_count(), "pin_memory": True} if cuda else {}

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs
)
val_loader = DataLoader(
    val_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True, **kwargs
)
test_loader = DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True, **kwargs
)
golden_loader = DataLoader(
    golden_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True, **kwargs
)

In [25]:
%load_ext autoreload
%autoreload 2

from modeles_gaetan import DeepLabV3
from config.module import module_dict
from functions.instanciators import get_model


model = get_model(module_name,3, True, False)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Some weights of SemanticSegmentationSegformer were not initialized from the model checkpoint at nvidia/mit-b5 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
batch = next(iter(train_loader))
labels = batch["labels"]
images = batch["pixel_values"]

output = model(images)
print(output)



tensor([[[[-1.6894e-01, -1.7306e-01, -1.4554e-01,  ..., -1.0657e-01,
           -1.3741e-01, -1.4959e-01],
          [-1.5876e-01, -1.5202e-01, -1.2576e-01,  ..., -1.0948e-01,
           -1.2602e-01, -1.2634e-01],
          [-1.6172e-01, -1.3103e-01, -1.8607e-01,  ..., -6.1894e-02,
           -8.8633e-02, -1.1977e-01],
          ...,
          [-6.8017e-02, -6.8532e-02, -5.4099e-02,  ...,  8.6875e-03,
            2.2609e-02,  5.8906e-02],
          [-5.1421e-02, -5.3451e-02, -5.0281e-02,  ...,  3.2626e-02,
            4.3051e-02,  3.7620e-02],
          [-6.8483e-02, -5.5347e-02, -3.1371e-02,  ...,  2.2855e-02,
            6.4174e-02,  4.8061e-02]],

         [[ 6.7517e-02,  9.0774e-02,  5.7544e-02,  ...,  7.5267e-02,
            8.8852e-02,  8.9217e-02],
          [ 7.5122e-02,  9.8914e-02,  7.1640e-02,  ...,  5.4765e-02,
            8.7357e-02,  8.0986e-02],
          [ 7.7997e-02,  7.2922e-02,  7.0272e-02,  ...,  6.3148e-02,
            1.0268e-01,  1.2803e-01],
          ...,
     