In [12]:
import os
import torch
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from typing import Any, Dict, Type, cast
from sharp_dataloader import GenMARSH, RandomRotation, Resize, gen_weights
from torch.utils.data import DataLoader
from os.path import dirname as up
import pandas as pd
import smp_metrics
from torchmetrics import Accuracy, JaccardIndex, MetricCollection
from omegaconf import OmegaConf

import numpy as np
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from MarshModel import MarshModel
from torchvision import transforms
import itertools 
import tqdm
import time

import matplotlib.pyplot as plt

path_cur = os.path.abspath(os.getcwd())
root_path = up(path_cur)
data_path = os.path.join(root_path, 'data/HL_NAIP/HL_transferlearning.csv')

In [8]:
import sys
sys.path.append(root_path)
from utils import metrics


In [9]:
import warnings
from typing import Any, Dict, cast

import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, JaccardIndex, MetricCollection
from sharp_trainer import SemanticSegmentationTask

In [10]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.loggers import TensorBoardLogger

In [11]:
root_path

'/rapids/notebooks/sciclone/geograd/Miranda/github/MarshMapping'

In [None]:
transform_train = transforms.Compose([transforms.ToTensor(),
                                RandomRotation([-90, 0, 90, 180]),
                                Resize(512)])

# load data
dataset = GenMARSH(data_path, transform=transform_train, normalization=True, ndvi=True, datasource='NAIP')

train_size = int(0.7 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size], generator=torch.Generator().manual_seed(0))

trainloader = DataLoader(train_dataset, 
                batch_size = 8, 
                shuffle = True,
                num_workers = 0,
                pin_memory = False)

testloader = DataLoader(test_dataset, 
                batch_size = 8, 
                shuffle = False,
                num_workers = 0,
                pin_memory = False)



In [None]:


def display_image(image, label):
    
    # reshape data
    ndata = np.moveaxis(np.array(image), 0, -1)  # ndata is of shape (40, 40, 13)
    # select RGB (or any other combination)
    rgb = ndata[...,[2,1,0]]
    
    ndvi = ndata[..., -1]
    
    f = plt.figure()
    f.add_subplot(1,3,1)
    plt.imshow(rgb)
    
    f.add_subplot(1,3,2)
    plt.imshow(label)
    
    f.add_subplot(1,3,3)
    plt.imshow(ndvi)
    plt.show(block=True)


for i in range(20):
    
    sample = dataset[i]
    display_image(**sample)

    if i == 10:
        break

In [16]:
monitor_options = ['val_loss']
model_options = ['unet', 'deeplabv3'] #'unet', 'fpn', 'pspnet', 'deeplabv3', 'deeplabv3plus', 'pan', 'manet', 'linknet'
encoder_options = ['resnet50']
lr_options = [1e-4]
loss_options = ["ce"]
weight_init_options = ["imagenet"]
in_channel = 5
out_channel = 3

# class_distr = torch.Tensor([0.02515424979261073, 0.9748457502073893])
class_distr = torch.Tensor([0.744150945415795, 0.0064356910282143895, 0.24941336355599059]) # background, high marsh, low marsh
weight = gen_weights(class_distr, c=1.03)


for (model, encoder, lr, loss, weight_init, monitor_state) in itertools.product(
        model_options,
        encoder_options,
        lr_options,
        loss_options,
        weight_init_options,
        monitor_options):
    
    experiment_name = f"HLSegmentation_{monitor_state}_{model}_{encoder}_{lr}_{loss}_{weight_init}"
    
    print(experiment_name)

    experiment_dir = os.path.join(root_path, experiment_name)
    logger = TensorBoardLogger(experiment_dir, name="models")
    
    if monitor_state == 'val_loss':
        tracking_mode = 'min'
    elif monitor_state == 'val_JaccardIndex':
        tracking_mode = 'max'
    
    checkpoint_callback = ModelCheckpoint(
        monitor=monitor_state, dirpath=experiment_dir, save_top_k=1, save_last=True, mode=tracking_mode)

    early_stopping_callback = EarlyStopping(monitor=monitor_state, min_delta=0.00, patience=5, mode=tracking_mode)


    model = SemanticSegmentationTask(
                    segmentation_model=model,
                    encoder_name=encoder,
                    encoder_weights=weight_init,
                    ignore_index = 0,
                    c_weights = weight,
                    monitor_state = monitor_state,
                    learning_rate=lr,
                    in_channels=in_channel,
                    num_classes=out_channel,
                    learning_rate_schedule_patience=6,
                    loss=loss,
                    imagenet_pretraining=True)

#     trainer = pl.Trainer(
#                 callbacks=[checkpoint_callback, early_stopping_callback],
#                 logger=logger,
#                 default_root_dir=experiment_dir,
#                 min_epochs=1,
#                 max_epochs=100,
#                 accelerator="gpu",
#                 devices=[2])

#     trainer.fit(model, trainloader, testloader)



HLSegmentation_val_loss_unet_resnet50_0.0001_ce_imagenet
HLSegmentation_val_loss_deeplabv3_resnet50_0.0001_ce_imagenet


In [19]:
# for count,child in enumerate(model.children()):
#     print(" Child ", count , "is -")
#     print(child)

In [46]:

for count_parents, parents in enumerate(model.children()):
    
    if count_parents == 0:
        
        for count_child, child in enumerate(parents.children()):
            
            if count_child == 2:
                break
            
            for param in child.parameters():
                param.requires_grad=False
                print("Child ",count_child," is frozen now")
                print(count_child)
        
# #         if count==2:
# #             break
        
# #         for param in child.parameters():
# #             param.requires_grad=False
# #             print("Child ",count," is frozen now")
# #             print(child)

Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is frozen now
0
Child  0  is

In [52]:
model_unet = smp.DeepLabV3("resnet50")
for parents in model_unet.children():
    for count, child in enumerate(parents.children()):
        if count==2:
            break
        
        for param in child.parameters():
            param.requires_grad=False
        
        print("Child ",count," is frozen now")
        print(child)

Child  0  is frozen now
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Child  1  is frozen now
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Child  0  is frozen now
ASPP(
  (convs): ModuleList(
    (0): Sequential(
      (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): ASPPConv(
      (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): ASPPConv(
      (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(24, 24), dilation=(24, 24), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): ASPPConv(
      (0): Con