In [None]:
import os
import ast
import sys
import json
import random
import logging
import argparse
import numpy as np
from tqdm import tqdm
from os.path import dirname as up

import segmentation_models_pytorch as smp

import torch
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

# Absolute path of the directory the notebook is executed from.
path_cur = os.path.abspath(os.getcwd())

# Put that directory on the import path before the imports that follow
# (presumably so local modules such as `unet` resolve -- confirm).
sys.path.append(path_cur)
from unet import UNet
from segmentation_models_pytorch.encoders import get_preprocessing_fn


from vims_dataloader_backbone import GenDEBRIS, RandomRotationTransform , class_distr, gen_weights, bands_mean, bands_std
# from vims_dataloader_lightning import NAIPStructure, RandomRotationTransform , class_distr, gen_weights, bands_mean, bands_std, GeoNAIPDataModule

# Project root: two directory levels above the working directory.
root_path = up(up(path_cur))

# Make <root>/utils importable.
sys.path.append(os.path.join(root_path, 'utils'))

# basicConfig opens the log file as soon as it installs the file handler and
# raises FileNotFoundError when the target directory is missing, so make sure
# <root>/logs exists first.
log_dir = os.path.join(root_path, 'logs')
os.makedirs(log_dir, exist_ok=True)

logging.basicConfig(filename=os.path.join(log_dir, 'log_unet.log'), filemode='a',level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s')
# Separator line marking the start of a new session in the appended log.
logging.info('*'*10)


from pytorch_lightning import Trainer, seed_everything
import pytorch_lightning as pl
from torchgeo.trainers import SemanticSegmentationTask
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
import matplotlib.pyplot as plt


In [None]:
# Display the resolved project root as the cell output (sanity check of the path logic).
root_path

In [None]:

# ---------------------------------------------------------------------------
# Transforms
# ---------------------------------------------------------------------------
# Band-wise normalisation built from the statistics exported by the dataloader module.
standardization = transforms.Normalize(bands_mean, bands_std)

# Validation / test: tensor conversion only, no augmentation.
transform_test = transforms.Compose([transforms.ToTensor()])

# Training: tensor conversion followed by random rotation and random flips.
# NOTE(review): whether the mask receives the same random transform as the
# image is decided inside GenDEBRIS, which is not visible here -- confirm.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomRotation(degrees=(0, 360)),
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
])

# ---------------------------------------------------------------------------
# Datasets (one per split)
# ---------------------------------------------------------------------------
dataset_train = GenDEBRIS('train', transform=transform_train, standardization=standardization, agg_to_water=False)
dataset_val = GenDEBRIS('val', transform=transform_test, standardization=standardization, agg_to_water=False)
dataset_test = GenDEBRIS('test', transform=transform_test, standardization=standardization, agg_to_water=False)

# ---------------------------------------------------------------------------
# Data loaders: only the training loader shuffles; all load in the main process.
# ---------------------------------------------------------------------------
train_dataloader = DataLoader(dataset_train, batch_size=32, shuffle=True, num_workers=0)
val_dataloader = DataLoader(dataset_val, batch_size=32, shuffle=False, num_workers=0)
test_dataloader = DataLoader(dataset_test, batch_size=32, shuffle=False, num_workers=0)

In [None]:

# Visualisation of input images and their masks.
# Class names for the mask values: the position in the list is the class index
# used for the colour bar ticks.
labels = [
    'Background',
    'Bulkhead Or Sea Wall',
    'Rip Rap',
    'Groin',
    'Breakwater',
]

def vis(batch, labels):
    """Show every sample of a batch side by side with its ground-truth mask.

    Parameters
    ----------
    batch : dict
        Must contain "image" and "mask" entries that can be iterated in
        parallel. Each image is transposed from (C, H, W) to (H, W, C) and only
        its first three channels are drawn; each mask is squeezed before
        plotting.
    labels : list of str
        Class names; the name at index ``i`` labels mask value ``i``.

    Notes
    -----
    One figure is created and shown per sample. The colour scale of the mask
    is pinned to the full class range so a given class always gets the same
    colour, whichever classes happen to be present in a particular mask.
    """
    n_classes = len(labels)

    # Tick formatter: class index -> class name. ``.get`` keeps the formatter
    # from raising if matplotlib ever asks for a value that is not a class index.
    lookup = dict(enumerate(labels))
    formatter = plt.FuncFormatter(lambda val, loc: lookup.get(val, ""))

    # One discrete colour per class. pyplot.get_cmap replaces the deprecated
    # matplotlib.cm.get_cmap used previously.
    cmap = plt.get_cmap('rainbow', n_classes)

    for image, gt_mask in zip(batch["image"], batch["mask"]):

        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        # convert CHW -> HWC and keep the first three channels for display
        image_hwc = image.detach().cpu().numpy().transpose(1, 2, 0)[:, :, 0:3].astype(float)
        plt.imshow(image_hwc)
        plt.title("Image")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        # Squeeze singleton dimensions so the mask is a 2-D array of class indices.
        # vmin/vmax centre each integer class index inside its own colour bin;
        # without them the scale is refit to every mask's own min/max.
        plt.imshow(gt_mask.detach().cpu().numpy().squeeze(),
                   cmap=cmap, vmin=-0.5, vmax=n_classes - 0.5)
        plt.title("Ground truth")
        plt.colorbar(ticks=range(n_classes), format=formatter)
        plt.axis("off")

        plt.show()

# Pull the first batch produced by the training loader and plot each of its samples.
batch = next(iter(train_dataloader))
vis(batch, labels)

In [None]:
from vims_segmentation import SemanticSegmentation

In [None]:
import itertools


# Hyperparameter options
# Each list is one axis of the sweep; one experiment is trained and tested per
# element of their Cartesian product.

training_set_options = ["4class_16batch"]
model_options = ["unet"] #, "unet"
encoder_options = ["resnet50"] #"resnet18",
lr_options = [1e-3] #1e-4
loss_options = ["ce"] #, "ce", "jaccard" # normalized focal loss: leads to faster convergence and better accuracy (https://arxiv.org/pdf/1909.07829.pdf)
weight_init_options = ["imagenet"] # "swsl", "ssl"
# NOTE(review): other cells of this notebook build their models with
# in_channels=4 -- confirm which value matches the data.
in_channel = 3
out_channel = 5 # including background
class_whts = ['CWFalse'] #'CWFalse'
# Per-class loss weights, used only when the 'CWTrue' flag is selected.
class_weight_values = [0, 0.1343008 , 0.21761741, 0.31898171, 0.32910008]
class_weights = torch.tensor(class_weight_values)
betas_lst=[(0.9, 0.999)] #(0.85,0.999), (0.80, 0.999), (0.83,0.999), (0.84,0.999)

# Output directory shared by all experiments (loop-invariant, so built once).
experiment_dir = os.path.join(root_path, "torchlightning_results_batch32")

# The architecture name is bound to `model_name` so that it is not shadowed by
# the task object assigned to `model` inside the loop.
for (train_state, model_name, encoder, lr, loss, weight_init, class_wht, betas) in itertools.product(
        training_set_options,
        model_options,
        encoder_options,
        lr_options,
        loss_options,
        weight_init_options,
        class_whts,
        betas_lst):

    experiment_name = f"{train_state}_{model_name}_{encoder}_{lr}_{loss}_{weight_init}_{class_wht}_{betas}"
    print(experiment_name)

    # Translate the class-weight flag into the value handed to the task.
    if class_wht == 'CWTrue':
        class_weights = torch.tensor(class_weight_values)
    elif class_wht == 'CWFalse':
        class_weights = None
    else:
        # Previously an unknown flag silently reused whatever the preceding
        # iteration had left in `class_weights`.
        raise ValueError(
            "Unknown class-weight flag {!r}; expected 'CWTrue' or 'CWFalse'".format(class_wht))

    model = SemanticSegmentation(
        segmentation_model=model_name,
        encoder_name=encoder,
        encoder_weights=weight_init,
        learning_rate=lr,
        in_channels=in_channel,
        num_classes=out_channel,
        learning_rate_schedule_patience=6,
        ignore_zeros=False,
        loss=loss,
        imagenet_pretraining=True,
        class_weights=class_weights,
        betas=betas) # None or a tensor of weights

    csv_name = "{}_csv".format(experiment_name)
    tb_name = "{}_tb".format(experiment_name)

    # Callbacks: keep the checkpoint with the lowest val_loss (plus the last
    # one) and stop once val_loss has not improved for 50 validation checks.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=experiment_dir,
        save_top_k=1,
        filename=experiment_name+"-{epoch:02d}-{val_loss:.2f}",
        save_last=True)

    early_stopping_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        patience=50
    )

    csv_logger = CSVLogger(
        save_dir=experiment_dir,
        name=csv_name
    )

    tb_logger = TensorBoardLogger(
        save_dir=experiment_dir,
        name=tb_name)

    trainer = pl.Trainer(
        callbacks=[checkpoint_callback, early_stopping_callback],
        # The CSV logger used to be created but never attached, so no
        # metrics.csv was written. TensorBoard stays first in the list.
        logger=[tb_logger, csv_logger],
        default_root_dir=experiment_dir,
        min_epochs=50,
        max_epochs=500,
        gpus=[5]
        #gpus=[4, 5, 6, 7]
    )

    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

    # Evaluate on the test split using the "best" checkpoint of this run.
    trainer.test(ckpt_path="best", dataloaders=test_dataloader)



In [None]:

# The bare string below is a reference configuration kept for orientation; it
# is not read by any code. Several of its values (encoder_name, encoder_weights,
# learning_rate, num_classes) differ from the arguments actually used in the
# constructor call that follows.
"""
    loss: "ce"
    segmentation_model: "unet"
    encoder_name: "resnet18"
    encoder_weights: null
    encoder_output_stride: 16
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    in_channels: 4
    num_classes: 7
    num_filters: 256
    ignore_zeros: False
    imagenet_pretraining: True

"""

# torchgeo segmentation task: U-Net with a ResNet-50 encoder, 4 input channels
# and 5 output classes, trained with cross-entropy loss.
model = SemanticSegmentationTask(
    loss="ce",
    segmentation_model="unet",
    encoder_name="resnet50",
    encoder_weights="imagenet",
    imagenet_pretraining=True,
    in_channels=4,
    num_classes=5,
    ignore_zeros=False,
    learning_rate=1e-4,
    learning_rate_schedule_patience=10,
)

In [None]:
# Directory receiving checkpoints and logger output for this run.
experiment_dir = os.path.join(root_path, "torchlightning_results")

# Keep the single best checkpoint according to val_loss, plus the last one.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=experiment_dir,
    save_top_k=1,
    save_last=True,
)

# Stop when val_loss has not improved for 50 consecutive validation checks.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=50,
)

# Writes <experiment_dir>/tutorial_logs/version_N/metrics.csv, the file the
# metrics-parsing cell further down reads.
csv_logger = CSVLogger(
    save_dir=experiment_dir,
    name="tutorial_logs"
)

tb_logger = TensorBoardLogger(
    save_dir=experiment_dir,
    name="tb_model")

trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    # The CSV logger was previously constructed but not passed to the Trainer,
    # so this run produced no metrics.csv. TensorBoard stays first in the list.
    logger=[tb_logger, csv_logger],
    default_root_dir=experiment_dir,
    min_epochs=1,
    max_epochs=1000,
    gpus=[7]
    #gpus=[4, 5, 6, 7]
)

In [None]:
# Train `model` on the training loader, validating on the validation loader.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

In [None]:
import csv

# Set to True to skip reading and plotting the logged metrics.
in_tests = False


def _to_float_or_none(value):
    """Convert one CSV cell to float.

    Returns None when the cell is blank, not numeric, or missing from a short
    row (csv.DictReader fills missing trailing fields with None).
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


if not in_tests:
    # NOTE: despite the "_rmse" suffix these lists hold accuracy values; the
    # names are kept because the plotting cell reads them.
    train_steps = []
    train_rmse = []

    val_steps = []
    val_rmse = []

    # NOTE(review): the logger version directory is hard-coded -- confirm that
    # "version_2" is the run you intend to plot.
    metrics_path = os.path.join(experiment_dir, "tutorial_logs", "version_2", "metrics.csv")
    # newline="" is the mode the csv module documents for files it reads.
    with open(metrics_path, "r", newline="") as f:
        csv_reader = csv.DictReader(f, delimiter=',')
        for i, row in enumerate(csv_reader):
            # A row may carry only one of the two metrics; rows where a
            # metric's cell is empty are skipped for that metric.
            train_acc = _to_float_or_none(row["train_Accuracy"])
            if train_acc is not None:
                train_rmse.append(train_acc)
                train_steps.append(i)

            val_acc = _to_float_or_none(row["val_Accuracy"])
            if val_acc is not None:
                val_rmse.append(val_acc)
                val_steps.append(i)

In [None]:
import matplotlib.pyplot as plt

# Plot the training and validation accuracy series collected from metrics.csv.
if not in_tests:
    fig, ax = plt.subplots()
    ax.plot(train_steps, train_rmse, label="Train Accuracy")
    ax.plot(val_steps, val_rmse, label="Validation Accuracy")
    ax.legend(fontsize=15)
    ax.set_xlabel("Batches", fontsize=15)
    ax.set_ylabel("Accuracy", fontsize=15)
    plt.show()
    # Release the figure once it has been rendered.
    plt.close(fig)

In [None]:
# Evaluate on the test loader. ckpt_path="best" requests the best checkpoint of
# the preceding fit (presumably the one tracked by checkpoint_callback on
# val_loss -- confirm against the installed Lightning version).
trainer.test(ckpt_path="best", dataloaders=test_dataloader)

In [None]:
from vims_segmentation import SemanticSegmentation

# Same architecture and data settings as the torchgeo task built earlier, but
# using the project's own SemanticSegmentation class and a scheduler patience of 6.
model2 = SemanticSegmentation(
    loss="ce",
    segmentation_model="unet",
    encoder_name="resnet50",
    encoder_weights="imagenet",
    imagenet_pretraining=True,
    in_channels=4,
    num_classes=5,
    ignore_zeros=False,
    learning_rate=1e-4,
    learning_rate_schedule_patience=6,
)

In [None]:
# Train model2 with the same `trainer` object that was already used to fit `model`.
# NOTE(review): the Trainer and its checkpoint / early-stopping callbacks are
# reused from the earlier run, so their state may carry over into this fit --
# confirm, or build a fresh Trainer for model2.
trainer.fit(model2, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)