SN7

In [1]:
import os
import numpy as np 
import rasterio
import torch
import cv2
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from glob import glob
import geopandas as gpd
from rasterio.features import rasterize
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.transforms.functional as TF
import random
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
import copy


Funzione per calcolare media e deviazione standard

In [2]:
def compute_mean_std(image_paths):
    means = []  
    stds = []  
    
    for img_path in image_paths:
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)  # Carica l'immagine inclusi tutti i canali
        #img_rgb = img[:, :, :3]  # Prendo i primi tre canali (RGB)
        #permute mi serve perchè img la leggo con opencv che mi da HxWxC e pytorch vuole CxHxW
        img_tensor = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)  # Converti in tensor e permuta le dimensioni a (Canali, Altezza, Larghezza)

        means.append(img_tensor.mean(dim=(1, 2)))  # Calcola e aggiungi la media per ogni canale (RGB) alla lista delle medie (1=h, 2=w)
        stds.append(img_tensor.std(dim=(1, 2)))    # Calcola e aggiungi la deviazione standard per ogni canale (RGB) alla lista delle deviazioni standard
    
    mean = torch.stack(means).mean(dim=0)  # Combina tutte le medie delle immagini e calcola la media finale per ogni canale
    std = torch.stack(stds).mean(dim=0)    # Combina tutte le deviazioni standard delle immagini e calcola la deviazione standard finale per ogni canale
    
    return mean.tolist(), std.tolist()  # Converte i tensori risultanti in liste e li restituisce

In [3]:
# Path alle immagini nella cartella "training/images"
image_paths = glob('training/images/*.tif')  

# Esegui la computazione su GPU se disponibile
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

mean, std = compute_mean_std(image_paths)

print(f'Media dei canali RGBM: {mean}')
print(f'Deviazione standard dei canali RGBM: {std}')
    

Media dei canali RGBM: [82.15846252441406, 110.30465698242188, 126.20237731933594, 254.79959106445312]
Deviazione standard dei canali RGBM: [27.31522560119629, 30.465126037597656, 38.47288513183594, 3.201913833618164]


Elaborazione label:
La gestione dell'eccezione: ValueError è dovuta alle immagini e label native del dataset, in pratica in alcune immagini satellitari, c'è la presenza di nuvole, oppure rumore di qualche altro tipo. Per gestire questo rumore è stato creato un 4 canale, che effettua una maschera coprente  sulle zone dove c'è rumore/nuvole. Ci sono anche alcune labels apposite per le maschere. In casi eccezzionali l'intera immagine ha la maschera, dove avviene il mascheramento, le labels ordinarie sono prive di dati geojson, quindi se l'intera immagine è mascherata alcune label sono vuote e quindi cioò comporta quell'errore, che vista l'assenza di label, ho deciso di gestire la problematica semplicemente creando una maschera nera. 

In [4]:
# Definisce una funzione per creare una maschera binaria da geometrie
def create_mask(geojson, img_shape, transform):
    shapes = [(geom, 1) for geom in geojson.geometry if geom.is_valid and not geom.is_empty]
    if not shapes:
        raise ValueError('No valid geometry objects found for rasterize')
    mask = rasterize(shapes=shapes, out_shape=img_shape, transform=transform, fill=0, dtype=np.uint8)
    return mask

# Salva la maschera binaria normalizzata per visualizzazione
def save_normalized_mask(mask, output_path, crs, transform):
    mask_normalized = (mask * 255).astype(np.uint8)
    with rasterio.open(
        output_path,
        'w',
        driver='GTiff',
        height=mask_normalized.shape[0],
        width=mask_normalized.shape[1],
        count=1,
        dtype=mask_normalized.dtype,
        crs=crs,
        transform=transform,
    ) as dst:
        dst.write(mask_normalized, 1)

def process_dataset(image_dir, label_dir, output_dir):
    image_paths = sorted(glob(os.path.join(image_dir, '*.tif')))
    label_paths = sorted(glob(os.path.join(label_dir, '*.geojson')))
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    processed_count = 0

    for img_path, label_path in zip(image_paths, label_paths):
        labels_match = gpd.read_file(label_path)

        with rasterio.open(img_path) as src:
            img_shape = (src.height, src.width)
            transform = src.transform
            img_crs = src.crs

        try:
            mask_labels_match = create_mask(labels_match, img_shape, transform)
        except ValueError as e:
            print(f"Empty or invalid label, noise on image: {img_path} due to error: {e}")
            mask_labels_match = np.zeros(img_shape, dtype=np.uint8)  # Crea una maschera nera in caso di errore
        
        img_name = os.path.basename(img_path)
        mask_name = img_name.replace('.tif', '_label_mask.tif')
        mask_output_path = os.path.join(output_dir, mask_name)

        save_normalized_mask(mask_labels_match, mask_output_path, img_crs, transform)
        
        processed_count += 1

        if processed_count % 100 == 0:
            print(f'Processed: {processed_count} masks')

Path

In [5]:
image_dir_train = 'training/images_masked'
label_dir_train = 'training/labels/labels_match'
mask_dir_train = 'training/mask_from_label'

image_dir_valid = 'validation/images_masked'
label_dir_valid = 'validation/labels/labels_match'
mask_dir_valid = 'validation/mask_from_label'

image_dir_test = 'test/images_masked'
label_dir_test = 'test/labels/labels_match'
mask_dir_test = 'test/mask_from_label'

In [6]:
#process_dataset(image_dir_train, label_dir_train, mask_dir_train)

Validation

In [7]:
#process_dataset(image_dir_valid, label_dir_valid, mask_dir_valid)

Test

In [8]:
#process_dataset(image_dir_test, label_dir_test, mask_dir_test)

In [9]:
# Output file paths
train_txt_path = os.path.join('training', 'train.txt')
val_txt_path = os.path.join('validation', 'val.txt')
test_txt_path = os.path.join('test', 'test.txt')

# Function to create txt files with image and mask paths
def create_txt_file(image_dir, mask_dir, output_txt_path):
    image_paths = sorted(glob(os.path.join(image_dir, '*.tif')))
    mask_paths = sorted(glob(os.path.join(mask_dir, '*.tif')))

    if len(image_paths) != len(mask_paths):
        raise ValueError(f"The number of images and masks do not match in {image_dir} and {mask_dir}")

    with open(output_txt_path, 'w') as f:
        for img_path, mask_path in zip(image_paths, mask_paths):
            f.write(f"{img_path} {mask_path}\n")

# Create train.txt, val.txt, and test.txt
create_txt_file(image_dir_train, mask_dir_train, train_txt_path)
create_txt_file(image_dir_valid, mask_dir_valid, val_txt_path)
create_txt_file(image_dir_test, mask_dir_test, test_txt_path)

print(f"Created {train_txt_path}, {val_txt_path}, and {test_txt_path} with image and mask references.")


Created training/train.txt, validation/val.txt, and test/test.txt with image and mask references.


In [10]:
with open(train_txt_path, 'r') as f:
    train_data = [line.split() for line in f.read().splitlines()]

train_image_paths = [line[0] for line in train_data]
train_label_paths = [line[1] for line in train_data]

with open(val_txt_path, 'r') as f:
    val_data = [line.split() for line in f.read().splitlines()]

val_image_paths = [line[0] for line in val_data]
val_label_paths = [line[1] for line in val_data]

with open(test_txt_path, 'r') as f:
    test_data = [line.split() for line in f.read().splitlines()]

test_image_paths = [line[0] for line in test_data]
test_label_paths = [line[1] for line in test_data]


SARDataset

In [11]:
class SARDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        with rasterio.open(image_path) as src:
            image = src.read().transpose((1, 2, 0))
            image = image.astype(np.float32)
            image = (image - mean) / std

        with rasterio.open(mask_path) as src:
            mask = src.read(1).astype(np.float32)
            mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

DoubleConv

In [12]:
class DoubleConv(nn.Module):
  def __init__(self, in_channels, out_channels):
      super(DoubleConv, self).__init__()
      self.conv = nn.Sequential(
          nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=True),
          nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=True),
      )

  def forward(self, x):
      return self.conv(x)

UNET

In [13]:
class UNET(nn.Module):
  def __init__(
      self, in_channels=4, out_channels=1, features=[64, 128, 256, 512],
  ):
    super(UNET, self).__init__()
    self.ups = nn.ModuleList()
    self.downs = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    for feature in features:
      self.downs.append(DoubleConv(in_channels, feature))
      in_channels = feature

    for feature in reversed(features):
      self.ups.append(
          nn.ConvTranspose2d(
              feature*2, feature, kernel_size=2, stride=2,
          )
      )
      self.ups.append(DoubleConv(feature*2, feature))

    self.bottleneck = DoubleConv(features[-1], features[-1]*2)
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
  def forward(self, x):
    skip_connections = []

    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)
    skip_connections = skip_connections[::-1]

    for idx in range(0, len(self.ups), 2):
      x = self.ups[idx](x)
      skip_connection = skip_connections[idx//2]

      if x.shape != skip_connection.shape:
        x = TF.resize(x, size=skip_connection.shape[2:])

      concat_skip = torch.cat((skip_connection, x), dim=1)
      x = self.ups[idx+1](concat_skip)

    return self.final_conv(x)

seed

In [14]:
def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)

Funzione per il training

In [15]:
def train_fn(loader, model, optimizer, loss_fn, scaler, txtfile):
    loop = tqdm(loader)
    tot_loss = 0.0

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=device)
        targets = targets.float().to(device=device)

        with torch.cuda.amp.autocast():
            predictions = model(data).squeeze(1)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        tot_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    avg_loss = tot_loss / len(loader)
    with open(txtfile, "a") as f:
        f.write(f"Train Loss: {avg_loss:.4f}\n")

Funzioni di mantenimento addestramento

In [16]:
def save_checkpoint(state, filename="/NN_register/checkpoint.pth.tar"):
  print("=> Saving checkpoint")
  torch.save(state, filename)

def load_checkpoint(checkpoint, model):
  print("=> Loading checkpoint")
  model.load_state_dict(checkpoint["state_dict"])

Evaluation Function

In [17]:
def eval_fn(loader, model, loss_fn, txtfile, device="cuda"):
    model.eval()
    num_correct = 0
    num_pixels = 0
    total_dice_score = 0
    total_precision = 0
    total_recall = 0
    total_f1 = 0
    total_loss = 0

    with torch.no_grad():
        for num_it, (x, y) in enumerate(loader, start=1):
            print(num_it)
            x = x.to(device)
            y = y.to(device)
            out = model(x).squeeze(1)
            loss = loss_fn(out, y)
            total_loss += loss.item()
            preds = torch.sigmoid(out)
            preds = (preds > 0.5).float()

            num_correct += (preds == y).sum().item()
            num_pixels += torch.numel(preds)

            dice_score = (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
            total_dice_score += dice_score.item()

            y_np = y.cpu().numpy().flatten()
            preds_np = preds.cpu().numpy().flatten()
            total_precision += precision_score(y_np, preds_np, zero_division=1)
            total_recall += recall_score(y_np, preds_np, zero_division=1)
            total_f1 += f1_score(y_np, preds_np, zero_division=1)

    avg_loss = total_loss / len(loader)
    avg_dice_score = total_dice_score / len(loader)
    avg_precision = total_precision / len(loader)
    avg_recall = total_recall / len(loader)
    avg_f1 = total_f1 / len(loader)
    accuracy = num_correct / num_pixels * 100

    print(f"Validation Loss: {avg_loss:.4f}")
    print(f"Accuracy: {accuracy:.2f}")
    print(f"Dice score: {avg_dice_score:.4f}")
    print(f"Precision: {avg_precision:.4f}")
    print(f"Recall: {avg_recall:.4f}")
    print(f"F1 Score: {avg_f1:.4f}")

    with open(txtfile, "a") as f:
        f.write(f"Validation Loss: {avg_loss:.4f}\n")
        f.write(f"Accuracy: {accuracy:.2f}\n")
        f.write(f"Dice score: {avg_dice_score:.4f}\n")
        f.write(f"Precision: {avg_precision:.4f}\n")
        f.write(f"Recall: {avg_recall:.4f}\n")
        f.write(f"F1 Score: {avg_f1:.4f}\n")
        f.write("\n")

    model.train()
    return avg_loss, accuracy, avg_dice_score, avg_precision, avg_recall, avg_f1

In [18]:
def save_predictions_as_imgs(loader, model, folder, device="cuda"):
    os.makedirs(folder, exist_ok=True)
    model.eval()
    it = 1
    for idx, (x, y) in enumerate(tqdm(loader, desc="Saving predictions")):
        print(it)
        it += 1
        x = x.to(device=device)
        y = y.to(device=device).unsqueeze(1)

        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()

        combined = torch.cat((y, preds), dim=1)  # Concatenate along height (dim=1)
        torchvision.utils.save_image(combined, f"{folder}/comparison_{idx}.tif")

    model.train()

Training baseline

In [19]:
set_seed(42)

transform = A.Compose([
    A.Resize(height=360, width=360),
    # A.CenterCrop(height=896, width=896),
    # A.RandomCrop(height=320, width=320),
    # A.HorizontalFlip(p=0.5),
    # A.VerticalFlip(p=0.5),
    # A.RandomRotate90(p=1),
    # A.Rotate(limit=90, p=0.5),
    A.Normalize(mean=mean, std=std),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(height=360, width=360),
    A.Normalize(mean=mean, std=std),
    ToTensorV2(),
])

train_dataset = SARDataset(train_image_paths, train_label_paths, transform=transform)
val_dataset = SARDataset(val_image_paths, val_label_paths, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)
load_model = True
lr = 1e-5

base_dir = os.getcwd()
model_dir = os.path.join(base_dir, "NN_register")
output_dir = os.path.join(model_dir, "output")
image_output_dir = os.path.join(model_dir, "images")

# Assicurati che le directory di output esistano
os.makedirs(output_dir, exist_ok=True)
os.makedirs(image_output_dir, exist_ok=True)

step_size = 18

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNET(in_channels=4, out_channels=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=step_size)
num_epochs = 20
scaler = torch.cuda.amp.GradScaler()

checkpoint_path = os.path.join(model_dir, "checkpoint.pth.tar")
if load_model and os.path.exists(checkpoint_path):
    load_checkpoint(torch.load(checkpoint_path), model)
else:
    print("Checkpoint file not found, starting training from scratch.")



Checkpoint file not found, starting training from scratch.


In [22]:
best_f1_score = 0.0

for epoch in range(num_epochs):
    print(f"epoch {epoch}")
    train_fn(train_loader, model, optimizer, criterion, scaler, os.path.join(output_dir, "output.txt"))

    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint, os.path.join(model_dir, "checkpoint.pth.tar"))

    val_loss, val_accuracy, val_dice_score, val_precision, val_recall, val_f1 = eval_fn(val_loader, model, criterion, os.path.join(output_dir, "output.txt"), device=device)

    if val_f1 > best_f1_score:
        best_f1_score = val_f1
        save_checkpoint(checkpoint, filename=os.path.join(model_dir, "best_model.pth.tar"))
        print(f"New best F1 score: {best_f1_score:.4f}")

    if (epoch + 1) % 5 == 0:
        save_predictions_as_imgs(val_loader, model, image_output_dir, device=device)
    scheduler.step()


epoch 0


100%|██████████| 63/63 [00:19<00:00,  3.27it/s, loss=0.456]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.4566
Accuracy: 91.79
Dice score: 0.0099
Precision: 0.0699
Recall: 0.0054
F1 Score: 0.0099
=> Saving checkpoint
New best F1 score: 0.0099
epoch 1


100%|██████████| 63/63 [00:19<00:00,  3.22it/s, loss=0.408]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.4265
Accuracy: 92.16
Dice score: 0.0029
Precision: 0.0727
Recall: 0.0015
F1 Score: 0.0029
epoch 2


100%|██████████| 63/63 [00:19<00:00,  3.24it/s, loss=0.427]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.4076
Accuracy: 92.29
Dice score: 0.0003
Precision: 0.0954
Recall: 0.0001
F1 Score: 0.0003
epoch 3


100%|██████████| 63/63 [00:19<00:00,  3.18it/s, loss=0.416]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.3952
Accuracy: 92.29
Dice score: 0.0001
Precision: 0.1034
Recall: 0.0000
F1 Score: 0.0001
epoch 4


100%|██████████| 63/63 [00:19<00:00,  3.22it/s, loss=0.412]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.3856
Accuracy: 92.29
Dice score: 0.0001
Precision: 0.1760
Recall: 0.0001
F1 Score: 0.0001


Saving predictions:   0%|          | 0/13 [00:00<?, ?it/s]

1


Saving predictions:  15%|█▌        | 2/13 [00:00<00:04,  2.36it/s]

2
3


Saving predictions:  31%|███       | 4/13 [00:01<00:02,  3.86it/s]

4
5


Saving predictions:  46%|████▌     | 6/13 [00:01<00:01,  4.72it/s]

6
7


Saving predictions:  62%|██████▏   | 8/13 [00:02<00:00,  5.25it/s]

8
9


Saving predictions:  77%|███████▋  | 10/13 [00:02<00:00,  5.56it/s]

10
11


Saving predictions:  92%|█████████▏| 12/13 [00:02<00:00,  5.78it/s]

12
13


Saving predictions: 100%|██████████| 13/13 [00:02<00:00,  4.52it/s]


epoch 5


100%|██████████| 63/63 [00:19<00:00,  3.23it/s, loss=0.401]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.3782
Accuracy: 92.29
Dice score: 0.0001
Precision: 0.1473
Recall: 0.0000
F1 Score: 0.0001
epoch 6


100%|██████████| 63/63 [00:19<00:00,  3.22it/s, loss=0.346]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.3700
Accuracy: 92.29
Dice score: 0.0000
Precision: 1.0000
Recall: 0.0000
F1 Score: 0.0000
epoch 7


100%|██████████| 63/63 [00:19<00:00,  3.20it/s, loss=0.377]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.3562
Accuracy: 92.29
Dice score: 0.0000
Precision: 0.0673
Recall: 0.0000
F1 Score: 0.0000
epoch 8


100%|██████████| 63/63 [00:19<00:00,  3.20it/s, loss=0.362]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.3495
Accuracy: 92.29
Dice score: 0.0000
Precision: 0.0673
Recall: 0.0000
F1 Score: 0.0000
epoch 9


100%|██████████| 63/63 [00:19<00:00,  3.18it/s, loss=0.34] 


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.3425
Accuracy: 92.29
Dice score: 0.0000
Precision: 1.0000
Recall: 0.0000
F1 Score: 0.0000


Saving predictions:   8%|▊         | 1/13 [00:00<00:09,  1.30it/s]

1
2


Saving predictions:  23%|██▎       | 3/13 [00:01<00:03,  3.19it/s]

3
4


Saving predictions:  38%|███▊      | 5/13 [00:01<00:01,  4.20it/s]

5
6


Saving predictions:  54%|█████▍    | 7/13 [00:01<00:01,  4.98it/s]

7
8


Saving predictions:  69%|██████▉   | 9/13 [00:02<00:00,  5.41it/s]

9
10


Saving predictions:  85%|████████▍ | 11/13 [00:02<00:00,  5.65it/s]

11
12


Saving predictions: 100%|██████████| 13/13 [00:02<00:00,  4.48it/s]


13
epoch 10


100%|██████████| 63/63 [00:19<00:00,  3.19it/s, loss=0.361]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.3367
Accuracy: 92.29
Dice score: 0.0000
Precision: 1.0000
Recall: 0.0000
F1 Score: 0.0000
epoch 11


100%|██████████| 63/63 [00:19<00:00,  3.22it/s, loss=0.329]


=> Saving checkpoint
1
2
3
4
5
6
7
8
9
10
11
12
13
Validation Loss: 0.3350
Accuracy: 92.29
Dice score: 0.0000
Precision: 1.0000
Recall: 0.0000
F1 Score: 0.0000
epoch 12


100%|██████████| 63/63 [00:19<00:00,  3.20it/s, loss=0.317]


=> Saving checkpoint
1


KeyboardInterrupt: 