In [None]:
import os
from dataclasses import dataclass, field
from typing import List
import random
import copy
import sys
from functools import partial

import cv2
import torch
from matplotlib import pyplot as plt
import segmentation_models_pytorch as smp
import numpy as np
from torch import nn, optim
import json

In [None]:
# Make the project's `src` directory importable. The default is the original
# author's checkout; set the CORN_SRC_DIR environment variable to point at the
# `src` directory of your own checkout instead.
SRC_DIR = os.environ.get('CORN_SRC_DIR', '/home/przemek/Projects/pp/corn-field-damage/src')
# Only append once: re-running this cell used to add a duplicate entry each time.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

import model_training_v2.common.dataset_preparation as dataset_preparation
import model_training_v2.common.corn_dataset as corn_dataset
import model_training_v2.common.model_definition as model_definition
import model_training_v2.common.model_training_results as model_training_results
import model_training_v2.common.model_training as model_training
import model_training_v2.common.plot_graphs as plot_graphs

from importlib import reload

# Re-execute the project modules so edits to their source files are picked up
# without restarting the kernel. importlib.reload re-initialises each module
# object in place, so the names bound above stay valid. Order matches the
# original one-call-per-module sequence.
for _project_module in (
    dataset_preparation,
    corn_dataset,
    model_definition,
    model_training_results,
    model_training,
    plot_graphs,
):
    reload(_project_module)

In [None]:
# Stretch the notebook's content area to the full browser width so wide model
# summaries and figures are not squeezed. `display`/`HTML` are imported from the
# public `IPython.display` module; importing `display` from
# `IPython.core.display` is deprecated.
# NOTE(review): the `.container` selector targets the classic Notebook UI and
# likely has no effect in JupyterLab.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Seed every RNG the pipeline can touch (Python `random`, NumPy, PyTorch) so
# that stochastic steps, e.g. model weight initialisation in the next cell, are
# repeatable between runs. torch.manual_seed seeds the generators of all devices.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

CPU_DEVICE = 'cpu'
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# To force CPU execution, set DEVICE = CPU_DEVICE.
DEVICE = 'cuda' if torch.cuda.is_available() else CPU_DEVICE
DEVICE

In [None]:
# Architecture to train; switch to model_definition.ModelType.UNET to compare
# against the plain U-Net variant.
model_type = model_definition.ModelType.UNET_PLUS_PLUS

# Optional path to a previously saved state_dict to start from instead of
# fresh weights. Leave as None to train from scratch (the default).
# The checkpoint must match `model_type`.
RESUME_CHECKPOINT_PATH = None

model, model_params = model_definition.get_model_with_params(model_type)
if RESUME_CHECKPOINT_PATH is not None:
    # map_location lets a checkpoint written from a GPU session load on a
    # CPU-only machine.
    model.load_state_dict(torch.load(RESUME_CHECKPOINT_PATH, map_location=CPU_DEVICE))

# Container that collects parameters, settings and figures for this run.
res = model_training_results.ModelTrainingResults(model_params=model_params)

print(f'model_params = {model_params}')
print(model)

In [None]:
# Root directory of the image tiles. The default is the original author's data
# location; set the CORN_TILES_BASE_DIR environment variable to use another one.
TILES_BASE_DIR = os.environ.get('CORN_TILES_BASE_DIR', '/media/data/local/corn/new/tiles_stride_768/')
# Dataset split definition handed to the loader helper. The demo split holds
# only a few samples and is meant for quick end-to-end testing.
DATASET_SPLIT_NAME = 'dataset_split_demo.json'

train_loader, valid_loader, test_loader = corn_dataset.get_train_valid_test_loaders(
    dataset_name=DATASET_SPLIT_NAME,
    base_dir_path=TILES_BASE_DIR,
    batch_size=model_params.batch_size,
    # (sic) keyword spelling is dictated by the callee's signature
    mask_scalling=model_params.mask_scalling_factor,
)

In [None]:
# Artefacts of this run (saved model weights and the results object) go into a
# sub-directory named after the model configuration.
OUTPUT_ROOT_DIR = '/tmp/aaa/out/'
run_dir_name = model_params.get_model_file_name()
OUTPUT_DIR = os.path.join(OUTPUT_ROOT_DIR, run_dir_name)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Record input and output locations in the results so the run is self-describing.
res.set(TILES_BASE_DIR=TILES_BASE_DIR)
res.set(OUTPUT_DIR=OUTPUT_DIR)

In [None]:
# Visual check of a few test samples before training starts.
test_samples_fig = plot_graphs.plot_images_from_dataloader(test_loader)
# NOTE(review): the keyword below is presumably the name the figure is stored
# under, so its original spelling ("imgages") is kept to avoid renaming saved
# artefacts. Verify before correcting it.
res.add_fig(fig_example_imgages_testloader=test_samples_fig)



In [None]:
# Bundle everything the project's training helper needs: the network, the
# target device, its hyper-parameters and the results collector.
trainer_config = dict(
    model=model,
    device=DEVICE,
    model_params=model_params,
    res=res,
)
model_trainer = model_training.ModelTrainer(**trainer_config)

In [None]:
# Piecewise-constant learning-rate schedule: each rate in LR_STEPS is repeated
# EPOCHS_PER_LR times, in order. The whole list is handed to the trainer;
# presumably one entry per epoch. Raise EPOCHS_PER_LR (e.g. to 9) for longer runs.
EPOCHS_PER_LR = 3
LR_STEPS = [1e-4, 3e-5, 1e-5, 3e-6, 1e-6, 1e-7]
lrs = [lr for lr in LR_STEPS for _ in range(EPOCHS_PER_LR)]
# Run the full training loop over the learning-rate schedule built above.
model_trainer.train(
    train_loader=train_loader,
    valid_loader=valid_loader,
    device=DEVICE,
    cpu_device=CPU_DEVICE,
    lrs=lrs,
)


In [None]:
# Peek at the first recorded validation log (presumably the first epoch's
# metrics) as a quick sanity check of what the trainer records.
first_valid_log = model_trainer.valid_logs_vec[0]
first_valid_log

In [None]:
# for log in model_trainer.valid_logs_vec:
#     print(log['dice_loss'] + log['fscore'])

In [None]:
# Keep a handle on the model object used for training before `model` is rebound
# to the trainer's best model in a later cell.
# NOTE(review): this is an alias, not a copy. It reflects whatever state the
# trainer left this object in. Use copy.deepcopy(model) if an independent
# snapshot is required.
last_model = model

In [None]:
# Plot the metrics gathered during training. The helper returns a mapping of
# figure name -> figure, which is attached to the results object.
training_metric_figures = plot_graphs.plot_training_metrics(model_trainer=model_trainer)
res.add_fig(**training_metric_figures)

In [None]:



# Evaluate on the held-out test split and display whatever the trainer reports.
test_results = model_trainer.run_test(test_loader=test_loader, device=DEVICE)
test_results

In [None]:
# Persist the weights of the best model found during training. The module is
# moved to the CPU first (nn.Module.to works in place and returns the module),
# so the saved tensors load without a GPU.
model_file_path = os.path.join(OUTPUT_DIR, 'model')
model = model_trainer.best_model.to(CPU_DEVICE)
torch.save(model.state_dict(), model_file_path)

# To restore later:
#   model.load_state_dict(torch.load(model_file_path))
#   model.eval()

In [None]:
# model = model.to('cpu')

In [None]:
# Side-by-side predictions of the best model on a handful of test images.
prediction_plot_args = {
    'model': model,
    'model_params': model_params,
    'test_loader': test_loader,
    'number_of_images': 3,
}
example_prediction_figs = plot_graphs.plot_example_predictions(**prediction_plot_args)

res.add_fig(**example_prediction_figs)

In [None]:
# Sanity check: run the project's 'manual' prediction test on the CPU.
# Presumably this recomputes test results outside the trainer to cross-check
# the numbers reported above.
model = model.to(CPU_DEVICE)
manual_test_args = {
    'test_loader': test_loader,
    'model': model,
    'device': CPU_DEVICE,
    'model_params': model_params,
}
model_training.manual_prediction_test(**manual_test_args)


In [None]:
# Write the collected results (model parameters, recorded paths and the figures
# added in earlier cells; presumably all of them) into this run's OUTPUT_DIR.
res.save(dir_path=OUTPUT_DIR)