In [None]:
# Colab folder setup: mount Google Drive, then work from the project folder
# so that the relative paths used by later cells resolve.
import os
from google.colab import drive

drive.mount('/content/drive')

project_dir = '/content/drive/My Drive/Colab Notebooks/water_segmentation'
os.chdir(project_dir)

# Show the contents of the project folder as the cell output.
os.listdir('.')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


['models',
 'metrics',
 'datasets',
 'utils',
 'dataset',
 'lightning_logs',
 'dataset2',
 'water_segmentation_iou.ipynb',
 'water_segmentation_all.ipynb',
 'water_segmentation_final.ipynb']

In [None]:
!pip install pytorch-lightning
!pip install segmentation-models-pytorch



In [None]:
import sys

# Make the project's sub-folders importable by appending each one to the
# module search path (in this order: models, datasets, metrics).
_project_root = '/content/drive/My Drive/Colab Notebooks/water_segmentation'
for _subfolder in ('models', 'datasets', 'metrics'):
    sys.path.append(_project_root + '/' + _subfolder)

In [None]:
# Part1. Setting ################################

from models.unet import UNETModule
import torch
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import cv2

from datasets.water_bodies_dataset import SimpleWaterBodiesDataset
from datasets.water_bodies_dataset import PredictionWaterBodiesDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
import random
from pytorch_lightning import seed_everything

import argparse

# Reproducibility: drive every random number generator from a single seed.
myseed = 6666
seed_everything(myseed, workers=True)
os.environ['PYTHONHASHSEED'] = str(myseed)

# Seed the Python, NumPy and PyTorch generators explicitly as well.
random.seed(myseed)
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

# Ask cuDNN for deterministic kernels and switch off its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Start from a clean slate: release cached CUDA memory, then run the
# Python garbage collector.
import gc
torch.cuda.empty_cache()
gc.collect()

INFO:lightning_fabric.utilities.seed:Seed set to 6666


In [None]:
# Part2. Dataset ################################

# ToTensorV2 is the only transform in the pipeline; no augmentation or
# normalisation step is configured here.
transform = A.Compose([
    ToTensorV2()
])

root = "dataset/"
# NOTE(review): the training and validation datasets are built with identical
# arguments (same root, mode="all", same transform), so validation metrics are
# measured on the very images used for training — confirm this is intended.
train_dataset = SimpleWaterBodiesDataset(root, mode="all", transform=transform)
val_dataset = SimpleWaterBodiesDataset(root, mode="all", transform=transform)

print("Train dataset length:", len(train_dataset))
print("Val dataset length:", len(val_dataset))

# os.cpu_count() returns None when the number of CPUs cannot be determined,
# and DataLoader does not accept None for num_workers. Fall back to 0, which
# makes the DataLoader load batches in the main process.
n_cpu = os.cpu_count() or 0
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=n_cpu)
valid_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=n_cpu)


Train dataset length: 80
Val dataset length: 80


In [None]:
# Part3. Training ###############################

# Train on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"

# Project-defined module wrapping the segmentation network and its loss.
# encoder_weights=None presumably means the encoder starts from random
# weights rather than a pretrained checkpoint — TODO confirm in UNETModule.
unet_module = UNETModule(
    loss_fn="crossentropy",
    encoder_weights=None,
    encoder="efficientnet-b4",
    model="unetpp",
)

# Trainer configuration, gathered in one place for readability.
trainer_options = dict(
    accelerator=accelerator,
    devices=1,
    max_epochs=114,
    log_every_n_steps=5,
    deterministic=True,
    enable_checkpointing=False,  # no checkpoints are written during training
)
unet_trainer = pl.Trainer(**trainer_options)

# Fit with the training loader and evaluate on the validation loader.
unet_trainer.fit(unet_module, train_dataloader, valid_dataloader)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name           | Type            | Params | Mode 
-----------------------------------------------------------
0 | model          | UnetPlusPlus    | 20.8 M | train
1 | loss_fn        | BCELoss         | 0      | train
2 | validation_iou | SegmentationIOU | 0      | train
3 | training_iou   | SegmentationIOU | 0      | train
4 | test_iou       | SegmentationIOU | 0      | train
-----------------------------------------------------------
20.8 M    Trainable params
0         Non-trainable params
20.8 M    Total params
83.252    Total estimated model params size (MB)
641       Modules in train mode
0         Modules in eval mode

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


In [None]:
# Part4. Predict ################################

# Rebuild the full dataset and run the trainer's validation loop over it
# in batches of 20, without shuffling.
predict_dataset = SimpleWaterBodiesDataset(
    root,
    mode="all",
    transform=transform,
)
test_dataloader = DataLoader(
    predict_dataset,
    num_workers=n_cpu,
    shuffle=False,
    batch_size=20,
)
# The returned list of metric dictionaries is displayed as the cell output.
unet_trainer.validate(unet_module, test_dataloader)


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.08757827430963516, 'val_iou': 0.947890043258667}]

In [None]:
# Load the TensorBoard notebook extension, then open the dashboard on the
# lightning_logs/ directory (a path relative to the current working directory).
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

ERROR: Timed out waiting for TensorBoard to start. It may still be running as pid 18018.

In [None]:
import matplotlib.pyplot as plt

# Predict masks for the first batch of the evaluation loader.
unet_module.eval()  # switch the module to inference mode before the forward pass

batch = next(iter(test_dataloader))
with torch.no_grad():
    # Send the input to whichever device the module currently lives on, so the
    # forward pass works no matter where the trainer left the weights.
    batch_images = batch["image"].float().to(unet_module.device)
    logits = unet_module(batch_images)

# NOTE(review): the model summary printed during training lists the loss as
# BCELoss, which operates on probabilities. If UNETModule.forward already
# applies a sigmoid, the extra sigmoid below would map every pixel above 0.5 —
# confirm that forward returns raw logits.
preds = (logits.sigmoid() > .5).float()

In [None]:
# One figure per sample: input image, ground-truth mask and predicted mask
# side by side.
for image, gt_mask, pr_mask in zip(batch["image"], batch["mask"], preds):
    panels = (
        # convert CHW -> HWC so matplotlib can draw the image
        ("Image", image.detach().cpu().numpy().transpose(1, 2, 0)),
        # masks: just squeeze the classes dim, because there is only one class
        ("Ground truth", gt_mask.squeeze().detach().cpu().numpy()),
        ("Prediction", pr_mask.squeeze().detach().cpu().numpy()),
    )

    plt.figure(figsize=(10, 5))
    for position, (title, pixels) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(pixels)
        plt.title(title)
        plt.axis("off")

    plt.show()

Output hidden; open in https://colab.research.google.com to view.

In [None]:
from PIL import Image

def getimage(image_directory, target_size=(512, 512)):
    """Load an image file and turn it into a tensor for the model.

    Args:
        image_directory: Path of the image file to read (despite the name,
            this is the path of a single file, not of a directory).
        target_size: ``(width, height)`` the image is resized to with
            bilinear interpolation when its size differs. Defaults to
            ``(512, 512)``.

    Returns:
        tuple: ``(image, image_size)`` where ``image`` is the result of
        applying ``ToTensorV2`` to the (possibly resized) RGB uint8 array and
        ``image_size`` is the original ``(width, height)`` reported by PIL
        before any resizing.
    """
    transform_test = A.Compose([
        ToTensorV2(),
    ])

    target_size = tuple(target_size)

    # Open inside a context manager so the file handle is closed promptly;
    # convert() returns an independent image that stays valid afterwards.
    with Image.open(image_directory) as source:
        image = source.convert("RGB")

    # Remember the original size so callers can scale predictions back.
    image_size = image.size
    if image_size != target_size:
        image = image.resize(target_size, Image.BILINEAR)

    transformed = transform_test(image=np.array(image, dtype=np.uint8))
    return transformed["image"], image_size


# Predict a water mask for images 1.jpg .. 20.jpg in dataset2/image/, show each
# result and save the mask as a PNG at the original image resolution.

# Create the output folder up front so saving does not fail when it is missing.
os.makedirs('dataset2/output', exist_ok=True)

# Inference mode only needs to be set once, not on every iteration.
unet_module.eval()

for i in range(1, 21):
    image_test, image_size = getimage('dataset2/image/'+str(i)+'.jpg')
    with torch.no_grad():
        # Add the batch dimension and send the input to the module's device.
        model_input = image_test.unsqueeze(0).float().to(unet_module.device)
        logits = unet_module(model_input)
    # NOTE(review): assumes the module returns raw logits — confirm, since a
    # second sigmoid on probabilities would mark every pixel as water.
    preds = (logits.sigmoid() > .5).float()

    # just squeeze batch and classes dims, because we have only one class
    mask = preds.squeeze().detach().cpu().numpy()

    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(image_test.detach().cpu().numpy().transpose(1, 2, 0))  # convert CHW -> HWC
    plt.title("Image")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(mask)
    plt.title("Prediction")
    plt.axis("off")

    plt.show()

    # Scale the {0, 1} mask to {0, 255} and resize it back to the original
    # image size; nearest-neighbour keeps the mask strictly binary.
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))
    mask_image = mask_image.resize(image_size, Image.NEAREST)
    mask_image.save('dataset2/output/'+str(i)+'.png')