In [None]:
import os
import cv2
import numpy as np
from glob import glob
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from torch.amp import GradScaler
from tqdm import tqdm
from torch import autocast
from torchmetrics import JaccardIndex, F1Score

In [None]:
# --- Dataset locations (relative to the notebook's working directory) ---
Image_dir_train = "Datasets/Combined_Augmented/train/images"
Mask_dir_train = "Datasets/Combined_Augmented/train/labels"
Image_dir_val = "Datasets/Combined_Augmented/val/images"
Mask_dir_val = "Datasets/Combined_Augmented/val/labels"
Image_dir_test = "Datasets/CITY_OSM/test/images"
Mask_dir_test = "Datasets/CITY_OSM/test/labels"

# --- Training hyper-parameters ---
batch_size = 6
patience = 20  # early-stopping budget, in epochs
epochs = 65    # maximum number of training epochs

# `is_available` must be *called*: the bare function object is always truthy,
# which would select 'cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Input-normalisation function matching the pretrained encoder used by the model.
preprocess_input = smp.encoders.get_preprocessing_fn(
    encoder_name='tu-maxvit_large_tf_512',
    pretrained='imagenet',
)

# Dataset class: loads image/mask pairs with OpenCV (no augmentation is applied here)
class SegmentationDataset(Dataset):
    """Paired image / label-mask dataset for semantic segmentation.

    Images and masks are matched by position: ``img_paths[i]`` goes with
    ``mask_paths[i]``.

    Args:
        img_paths: Paths of the input images.
        mask_paths: Paths of the single-channel label masks, in the same order.

    Raises:
        ValueError: If the two lists differ in length.
    """

    def __init__(self, img_paths: list, mask_paths: list):
        # Positional pairing silently misaligns samples when the counts differ.
        if len(img_paths) != len(mask_paths):
            raise ValueError(
                f"Got {len(img_paths)} images but {len(mask_paths)} masks; "
                "every image needs exactly one mask."
            )
        self.img_paths = img_paths
        self.mask_paths = mask_paths

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Load, resize and normalise the pair at ``index``.

        Returns:
            tuple: ``(image, mask)`` where ``image`` is a channels-first
            512x512 float tensor normalised by ``preprocess_input`` and
            ``mask`` is a ``(1, 512, 512)`` int8 tensor of label values.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the mask.
        """
        img_path = self.img_paths[index]
        mask_path = self.mask_paths[index]

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (512, 512))
        image = image.astype('float32') / 255.0

        image = preprocess_input(image)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        # Nearest-neighbour keeps label ids intact; the default bilinear
        # interpolation would blend neighbouring classes into invalid ids.
        mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
        # NOTE(review): int8 wraps for label values above 127 - confirm masks
        # only hold small class ids.
        mask = mask.astype('int8')
        mask = np.expand_dims(mask, axis=0)  # (1, 512, 512)

        image = torch.tensor(image).permute(2, 0, 1)  # HWC -> CHW
        mask = torch.tensor(mask)

        return image, mask


def _list_pairs(image_dir, mask_dir):
    """Return the sorted ``.png`` image and mask path lists for one split.

    Images and masks are paired by position after sorting, so both folders
    must hold the same number of files.

    Raises:
        ValueError: If the two folders contain a different number of PNG files.
    """
    # "*.png" (with the dot) so only real .png files are matched.
    images = sorted(glob(os.path.join(image_dir, "*.png")))
    masks = sorted(glob(os.path.join(mask_dir, "*.png")))
    if len(images) != len(masks):
        raise ValueError(
            f"{image_dir} has {len(images)} images but {mask_dir} has {len(masks)} masks"
        )
    return images, masks


# List the file paths of every split
train_images, train_mask = _list_pairs(Image_dir_train, Mask_dir_train)
val_images, val_mask = _list_pairs(Image_dir_val, Mask_dir_val)
test_images, test_mask = _list_pairs(Image_dir_test, Mask_dir_test)

# Build one dataset per split
train_dataset = SegmentationDataset(train_images, train_mask)
val_dataset   = SegmentationDataset(val_images, val_mask)
test_dataset  = SegmentationDataset(test_images, test_mask)

# Only the training split is shuffled; validation and test keep file order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Loss terms: pixel-wise cross-entropy plus multiclass Dice
criterion_ce = torch.nn.CrossEntropyLoss()
criterion_dice = smp.losses.DiceLoss(mode="multiclass")

# U-Net with a MaxViT-Large encoder (ImageNet weights), RGB input, 3 output classes
model = smp.Unet(
    encoder_name='tu-maxvit_large_tf_512',
    encoder_weights='imagenet',
    in_channels=3,
    classes=3,
)
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Polynomial learning-rate decay spread over every training iteration
decay_steps = len(train_loader) * epochs  # total number of optimiser steps
scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer,
    total_iters=decay_steps,
    power=2.0,
)

In [None]:
# Per-epoch metric history: one independent empty list per tracked quantity
history = {
    metric: []
    for metric in (
        'train_loss',
        'val_loss',
        'train_iou',
        'val_iou',
        'train_f1',
        'val_f1',
    )
}

In [None]:
# Training loop: mixed-precision training, per-epoch validation, checkpointing
# of the best model (lowest validation loss) and early stopping.
best_loss = float("inf")
counter = 0  # consecutive epochs without a validation-loss improvement

# Defined up front so the names exist even if the loop body never runs.
train_loss = train_iou = train_f1 = 0.0
val_loss = val_iou = val_f1 = 0.0

scaler = GradScaler()

for epoch in range(epochs):
    # Reset the running sums every epoch; otherwise the previous epoch's
    # averages leak into this epoch's totals and skew every reported metric.
    train_loss = train_iou = train_f1 = 0.0
    val_loss = val_iou = val_f1 = 0.0

    # ------------------------------ training ------------------------------
    model.train()

    for images, mask in tqdm(train_loader, desc=f"Época {epoch+1}/{epochs}"):
        images = images.to(device).float()
        # Drop the channel axis and cast to int64, as the losses expect.
        mask = mask.to(device).squeeze(1).long()
        mask = torch.clamp(mask, 0, 2)  # force labels into the class range [0, 2]

        optimizer.zero_grad()
        with autocast('cuda'):  # reduced-precision forward pass
            output = model(images)
            loss_ce = criterion_ce(output, mask)
            loss_dice = criterion_dice(output, mask)
            loss = 0.5 * loss_ce + 0.5 * loss_dice  # equal-weight blend

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

        scheduler.step()  # the learning rate is updated after every batch

        # argmax over the class axis yields predicted class ids (not probabilities)
        preds = torch.argmax(output, dim=1)
        tp, fp, fn, tn = smp.metrics.get_stats(preds, mask, mode='multiclass', num_classes=3)
        # IoU and F1 (Dice) with "micro" reduction
        batch_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
        batch_f1 = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
        train_iou += batch_iou.item()
        train_f1 += batch_f1.item()

    # ----------------------------- validation -----------------------------
    model.eval()

    with torch.no_grad():
        for images, mask in val_loader:
            images = images.to(device).float()
            mask = mask.to(device).squeeze(1).long()
            mask = torch.clamp(mask, 0, 2)  # same label clamping as in training
            output = model(images)

            loss_ce = criterion_ce(output, mask)
            loss_dice = criterion_dice(output, mask)
            loss = 0.5 * loss_ce + 0.5 * loss_dice

            val_loss += loss.item()

            preds = torch.argmax(output, dim=1)
            tp, fp, fn, tn = smp.metrics.get_stats(preds, mask, mode='multiclass', num_classes=3)
            batch_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
            batch_f1 = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
            val_iou += batch_iou.item()
            val_f1 += batch_f1.item()

    # Turn the per-batch sums into per-epoch means.
    train_loss /= len(train_loader)
    val_loss /= len(val_loader)
    train_iou /= len(train_loader)
    train_f1 /= len(train_loader)
    val_iou /= len(val_loader)
    val_f1 /= len(val_loader)

    # Store the epoch values for later plotting
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_iou'].append(train_iou)
    history['val_iou'].append(val_iou)
    history['train_f1'].append(train_f1)
    history['val_f1'].append(val_f1)

    # Report the 1-based epoch number so it matches the progress-bar label.
    print(f"Epoch: {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train IoU: {train_iou:.4f}, Train F1: {train_f1:.4f}, Scheduler: {scheduler.get_last_lr()}")
    print(f"             Val Loss: {val_loss:.4f}, Val IoU: {val_iou:.4f}, Val F1: {val_f1:.4f}")

    if val_loss < best_loss:  # keep the checkpoint with the lowest validation loss
        best_loss = val_loss
        counter = 0  # an improvement restarts the patience window
        torch.save(model.state_dict(), "UNetMaxVitLCombinedAug.pth")
        print(f"saveing best model with val_loss {val_loss} at UNetMaxVitLCombinedAug.pth")
    else:  # stop after `patience` consecutive epochs without improvement
        counter += 1
        print(f"EarlyStopping: {counter}/{patience}")
        if counter >= patience:
            print(f"Early stopping ativado após {epoch+1} épocas")
            break

In [None]:
# Loss curves.
# Local names are chosen so the global `epochs` (int) and `val_loss` (float)
# are not rebound; overwriting them breaks re-running the scheduler/training cells.
epoch_axis = range(1, len(history['train_loss']) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_axis, history['train_loss'], "bo", label="Training loss")
ax.plot(epoch_axis, history['val_loss'], "b", label="Validation loss")
ax.set(title="Training and validation loss", xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()

In [None]:
# IoU curves.
# Local names are chosen so the global `epochs` (int) and `val_iou` (float)
# are not rebound; overwriting them breaks re-running the scheduler/training cells.
epoch_axis = range(1, len(history["val_iou"]) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_axis, history["train_iou"], "bo", label="Training IoU")
ax.plot(epoch_axis, history["val_iou"], "b", label="Validation IoU")
ax.set(title="Training and validation IoU", xlabel="Epoch", ylabel="IoU")
ax.legend()
plt.show()

In [None]:
# F1-score curves.
# Local names are chosen so the global `epochs` (int) and `val_f1` (float)
# are not rebound; overwriting them breaks re-running the scheduler/training cells.
epoch_axis = range(1, len(history["val_f1"]) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_axis, history["train_f1"], "bo", label="Training F1-Score")
ax.plot(epoch_axis, history["val_f1"], "b", label="Validation F1-Score")
ax.set(title="Training and validation F1-Score", xlabel="Epoch", ylabel="F1-Score")
ax.legend()
plt.show()

In [None]:
def infer_and_visualize(model,image_path,device):
    """Segment a single image and plot the input, the prediction and an overlay.

    The image is read with OpenCV, converted to RGB, resized to 512x512 and
    normalised with ``preprocess_input`` before being fed to the model.

    Args:
        model: Segmentation network returning per-class logits of shape
            ``(1, C, 512, 512)`` for a ``(1, 3, 512, 512)`` input.
        image_path: Path of the image file to segment.
        device: Torch device the input tensor is moved to; the model is
            expected to live on the same device.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    # The second argument of cv2.imread is a *read flag*, not a colour
    # conversion code, so the BGR -> RGB swap needs an explicit cvtColor.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image,(512, 512))

    image_tensor = image.astype('float32') / 255.0  # (512, 512, 3) in [0, 1]
    image_tensor = preprocess_input(image_tensor)
    image_tensor = torch.tensor(image_tensor).permute(2,0,1).unsqueeze(0).to(device).float()  # (1, 3, 512, 512)

    # Run in eval mode, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(image_tensor)
    finally:
        model.train(was_training)

    # argmax over logits equals argmax over softmax, so no softmax is needed.
    prediction = logits.argmax(dim=1).cpu().squeeze().numpy()

    plt.figure(figsize=(10,5))

    plt.subplot(1,3,1)
    plt.title("Original Image")
    plt.imshow(image)

    plt.subplot(1,3,2)
    plt.title("Mask")
    plt.imshow(prediction)

    plt.subplot(1,3,3)
    plt.title("overlap")
    plt.imshow(image)
    plt.imshow(prediction,cmap='jet',alpha=0.5)

    plt.show()

In [None]:
# Load the saved weights onto the active device.
# `map_location` lets a checkpoint saved on GPU load on a CPU-only machine;
# `weights_only=True` restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load("UNetMaxVitLCombinedAug.pth", map_location=device, weights_only=True)
)

In [None]:
# Switch to inference mode. The trailing semicolon stops the notebook from
# echoing the module that `eval()` returns (a very long architecture dump).
model.eval();

In [None]:
# Qualitative check: segment one image and show the prediction.
infer_and_visualize(
    model=model,
    image_path="Datasets/PrivateDataset/test/images/256.png",
    device=device,
)

In [None]:
def test_model(model, test_loader, device, num_classes=3):
    """Evaluate ``model`` on ``test_loader`` and report loss, IoU and F1.

    The loss is the same equal-weight blend of the module-level
    ``criterion_ce`` and ``criterion_dice`` used during training. IoU and F1
    are accumulated over the whole loader and reported per class.

    Args:
        model: Segmentation network returning ``(B, num_classes, H, W)`` logits.
        test_loader: Iterable of ``(images, mask)`` batches with masks shaped
            ``(B, 1, H, W)``.
        device: Torch device used for the model, data and metrics.
        num_classes: Number of segmentation classes (defaults to 3).

    Returns:
        tuple: ``(avg_loss, per_class_iou, mean_iou, per_class_f1, mean_f1)``;
        the per-class entries are tensors with one value per class.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    if len(test_loader) == 0:
        raise ValueError("test_loader is empty; nothing to evaluate.")

    model.to(device)
    model.eval()
    test_loss = 0.0

    # average='none' keeps one score per class instead of a single aggregate
    jaccard = JaccardIndex(num_classes=num_classes, average='none', task="multiclass").to(device)
    f1score = F1Score(num_classes=num_classes, average='none', task="multiclass").to(device)

    with torch.no_grad():
        for images, mask in tqdm(test_loader, desc="Teste"):
            images, mask = images.to(device).float(), mask.to(device)
            mask = mask.squeeze(1).long()  # drop the channel axis
            # Same label clamping as the training/validation loops, so stray
            # label values cannot crash the loss or the metrics.
            mask = torch.clamp(mask, 0, num_classes - 1)

            output = model(images)
            loss_ce = criterion_ce(output, mask)
            loss_dice = criterion_dice(output, mask)
            loss = 0.5 * loss_ce + 0.5 * loss_dice
            test_loss += loss.item()

            # argmax over the class axis yields predicted class ids
            preds = torch.argmax(output, dim=1)

            # Accumulate the metrics for the current batch
            jaccard.update(preds, mask)
            f1score.update(preds, mask)

    avg_loss = test_loss / len(test_loader)
    per_class_iou = jaccard.compute()  # tensor with the IoU of each class
    per_class_f1 = f1score.compute()   # tensor with the F1 score of each class

    mean_iou = per_class_iou.mean()
    mean_f1 = per_class_f1.mean()

    # Print the results
    print(f"Test Loss: {avg_loss:.4f}")
    for i in range(num_classes):
        print(f"Classe {i}: IoU = {per_class_iou[i]:.4f}, F1 = {per_class_f1[i]:.4f}")
    print(f"Mean IoU: {mean_iou:.4f}")
    print(f"Mean F1: {mean_f1:.4f}")

    return avg_loss, per_class_iou, mean_iou, per_class_f1, mean_f1

# Example usage: evaluate `model` on `test_loader` and keep every returned metric.
(
    test_loss,
    per_class_iou,
    mean_iou,
    per_class_f1,
    mean_f1,
) = test_model(model, test_loader, device)