In [None]:
import os
import sys

import cv2
import torch
from matplotlib import pyplot as plt
import segmentation_models_pytorch as smp
import numpy as np
from torch import nn, optim

In [None]:
# Device used for CPU-side work (e.g. the trainer's cpu_device argument).
CPU_DEVICE = 'cpu'
# Train on the GPU whenever torch can see one; otherwise fall back to the CPU.
DEVICE = CPU_DEVICE if not torch.cuda.is_available() else 'cuda'
# DEVICE = 'cpu'
DEVICE

In [None]:
# Make the project's `src` directory importable. The default is the original
# author's checkout; set the CORN_SRC_DIR environment variable to point at a
# different checkout without editing the notebook.
SRC_DIR = os.environ.get('CORN_SRC_DIR', '/home/przemek/Projects/pp/corn-field-damage/src')
# Appending unconditionally would add a duplicate sys.path entry on every re-run of this cell.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

import model_training_v2.common.dataset_preparation as dataset_preparation
import model_training_v2.common.corn_dataset as corn_dataset
import model_training_v2.common.model_definition as model_definition
import model_training_v2.common.model_training_results as model_training_results
import model_training_v2.common.model_training as model_training
import model_training_v2.common.plot_graphs as plot_graphs

from importlib import reload

# Re-execute the project modules (in the same order as before) so edits to their
# source files are picked up without restarting the kernel.
for _module in (dataset_preparation, corn_dataset, model_definition,
                model_training_results, model_training, plot_graphs):
    reload(_module)

# Fixed seed, applied through the project's init_seeds helper, for reproducible runs.
RANDOM_SEED = 778
dataset_preparation.init_seeds(RANDOM_SEED)

In [None]:
# Stretch the notebook's `.container` element to the full browser width.
# `IPython.display` is the public home of these helpers; importing them from
# `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Input tiles. This run uses the NDVI tile set; the non-NDVI alternative lives at
# '/media/data/local/corn/new/tiles_stride_768/'.
TILES_BASE_DIR = '/media/data/local/corn/new/tiles_ndvi_stride_768'

# Root directory; a sub-directory named after the model is created under it later.
BASE_OUTPUT_DIR = '/media/data/local/corn/out/15_01_22'
# BASE_OUTPUT_DIR = '/tmp/aaa/out/'

In [None]:
# Architecture to train (another option tried: model_definition.ModelType.PAN).
model_type = model_definition.ModelType.UNET_PLUS_PLUS__EFFICIENT_NET_B3

# Tiles carrying an NDVI layer are recognised by 'ndvi' in their directory name.
is_ndvi = 'ndvi' in TILES_BASE_DIR.lower()
# is_ndvi = False

# NDVI tiles feed one extra input channel on top of the 3 regular image channels.
image_channels = 3 + int(is_ndvi)

# Build the network together with its training hyper-parameters, and start the
# results container that later cells attach paths and figures to.
model, model_params = model_definition.get_model_with_params(model_type, in_channels=image_channels)
res = model_training_results.ModelTrainingResults(model_params=model_params)
# model_params.batch_size = 1

print(f'model_params = {model_params}')
print(model)

# Optionally start from previously saved weights (disabled):
# model.load_state_dict(torch.load('/media/data/local/corn/processed_stride768_v2/segfromer/model_cpu_segformer_epoch20__after_first_training'))
# model.eval()
# model.to(DEVICE)

In [None]:
# Build the train/valid/test loaders from the tile directory, using the batch size
# and mask scaling factor taken from the chosen model's parameters.
train_loader, valid_loader, test_loader = corn_dataset.get_train_valid_test_loaders(
    base_dir_path=TILES_BASE_DIR,
    batch_size=model_params.batch_size,
    mask_scalling=model_params.mask_scalling_factor,
    is_ndvi=is_ndvi,
    # dataset_name='dataset_split_demo.json',  # only a few samples for demo testing
)

# Outputs of this run go into a sub-directory named by model_params.get_model_file_name().
OUTPUT_DIR = os.path.join(BASE_OUTPUT_DIR, model_params.get_model_file_name())
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Record the input and output locations alongside the results.
res.set(TILES_BASE_DIR=TILES_BASE_DIR)
res.set(OUTPUT_DIR=OUTPUT_DIR)

In [None]:
# Debug snippet (disabled): inspect the first test sample, e.g. the value range of
# channel index 3 (presumably the NDVI channel when is_ndvi is True — confirm).
# dataset = test_loader.dataset
# image, mask = dataset[0]

# np.max(image[3, :, :]), np.min(image[3, :, :])
# image.shape

In [None]:
# Preview 8 test-set images; seed=345 presumably fixes which images are drawn.
fig_example_imgages_testloader = plot_graphs.plot_images_from_dataloader(
    test_loader,
    seed=345,
    is_ndvi=is_ndvi,
    number_of_images=8,
)
res.add_fig(**{'fig_example_imgages_testloader': fig_example_imgages_testloader})



In [None]:
# Wrap the model with its device and hyper-parameters; the shared `res` object is handed over too.
model_trainer = model_training.ModelTrainer(model=model, device=DEVICE,
                                            model_params=model_params, res=res)

In [None]:
# Step-wise learning-rate schedule: every rate below is repeated NUM times in a row
# (presumably one entry per training epoch).
#
# Earlier schedules, kept for reference:
# lrs = [0.0005] * 3 + [0.0001] * 4  + [0.00005] * 11 +  [0.00001] * 11 + [0.000005] * 11 + [0.000005] * 25
# NUM = 7
# lrs = [0.0001] * NUM + [0.00003] * NUM + [0.00001] * NUM + [0.000003] * NUM + [0.000001] * NUM #  + [0.0000001] * NUM
# Schedule previously used for non-NDVI tiles:
#     NUM = 9
#     lrs = [0.00001] * NUM + [0.000003] * NUM + [0.000001] * NUM #  + [0.0000001] * NUM
NUM = 11
lrs = [rate for rate in (0.00005, 0.00001, 0.000001) for _ in range(NUM)]  # optionally append 0.0000001


# Run training over the learning-rate schedule defined above.
model_trainer.train(
    train_loader=train_loader,
    valid_loader=valid_loader,
    device=DEVICE,
    cpu_device=CPU_DEVICE,
    lrs=lrs,
)
# Alias for the model object that was handed to the trainer.
last_model = model

In [None]:
# Plot the trainer's metrics against the learning-rate schedule and attach the figures to the results.
figures = plot_graphs.plot_training_metrics(
    model_trainer=model_trainer,
    lrs=lrs,
)
res.add_fig(**figures)

In [None]:
# Evaluate on the held-out test split.
model_trainer.run_test(
    test_loader=test_loader,
    device=DEVICE,
)

In [None]:
# Persist the trainer's best model into OUTPUT_DIR.
model = model_trainer.best_model
model_file_path = os.path.join(OUTPUT_DIR, 'model')
# Move to the CPU (same constant the trainer receives) before saving, so the
# state_dict does not reference GPU tensors and loads on machines without CUDA.
# Note: nn.Module.to() moves the module in place, so best_model now lives on the CPU too.
model = model.to(CPU_DEVICE)
torch.save(model.state_dict(), model_file_path)

# To restore previously saved weights:
# model.load_state_dict(torch.load('/media/data/local/corn/processed_stride768/model_cpu'))
# model.eval()
# best_model = model

In [None]:
# Render example predictions of `model` on 32 test images and attach them to the results.
figs = plot_graphs.plot_example_predictions(model=model, model_params=model_params,
                                            test_loader=test_loader,
                                            number_of_images=32, is_ndvi=is_ndvi)
res.add_fig(**figs)

In [None]:
# # 'manual' calculations as sanity check
# Disabled by default: uncomment to re-run predictions over the test set on the CPU
# via model_training.manual_prediction_test (presumably to cross-check run_test's metrics).

# model = model.to(CPU_DEVICE)
# model_training.manual_prediction_test(
#     test_loader=test_loader, 
#     model=model, 
#     device=CPU_DEVICE,
#     model_params=model_params,
#     )


In [None]:
# Write the collected training results (`res`) into this run's output directory.
res.save(
    dir_path=OUTPUT_DIR,
)