# Train
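Simple notebook to train a fastai U-Net (ResNet-34 encoder) with the default settings for 12 fine-tuning epochs, then save validation-set predictions, the exported model, and the list of validation files.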

In [None]:
from fastai.vision.all import *

In [None]:
ice_path = Path("../arctic_images_original_2")

def seginput2segmap(f):
    """Return the segmentation-map PNG under `ice_path/segmaps` whose stem matches input file `f`."""
    return ice_path/"segmaps"/f"{f.stem}.png"

seg_codes = ["water", "sky", "ice", "other"]
seginput_files = get_image_files(ice_path/"seginput")
# A named function rather than a lambda is passed as the label function:
# pickling doesn't like lambda expressions.
# The fixed seed keeps the train/valid split identical between runs.
dls = SegmentationDataLoaders.from_label_func(
    ice_path, seginput_files, seginput2segmap,
    codes=seg_codes, bs=8, seed=47,
)

In [None]:
FINE_TUNE_EPOCHS = 12

# U-Net with a ResNet-34 encoder, tracking multi-class Dice as the metric.
learn = unet_learner(dls, resnet34, metrics=[DiceMulti()])
learn.fine_tune(epochs=FINE_TUNE_EPOCHS)  # fastai's default fine-tuning schedule
learn.show_results()

Make and save predictions for the entire validation set (the original image, the ground-truth overlay, and the predicted overlay for each item):

In [None]:
import time
output_path = Path("../inferred/basictrain"+time.strftime("_%Y-%m-%dT%H-%M-%S"))
output_path.mkdir(parents=True)  # parents=True: also create ../inferred if it doesn't exist yet

def remove_whitespace(fig=None):
  """Add an axis-less Axes covering the whole of `fig` and return it.

  `fig` defaults to the current figure. The Axes spans [0, 0, 1, 1] with its
  axis turned off, so no ticks, labels or margins are drawn.
  """
  fig = plt.gcf() if fig is None else fig
  ax = fig.add_axes([0, 0, 1, 1])
  ax.set_axis_off()
  return ax

def plt_superimposed(base, mask, ax=None):
  """Draw `base` on `ax` (default: current axes) with `mask` overlaid at 25% opacity.

  Both vmin and vmax are pinned so a given class index always maps to the same
  tab20 colour. With only vmax fixed, a mask that lacks class 0 would
  autoscale vmin and shift every colour in that image.
  """
  ax = plt.gca() if ax is None else ax
  ax.imshow(base)
  ax.imshow(mask, alpha=0.25, cmap="tab20", vmin=0, vmax=4)

def _save_image(path, base, mask=None):
  """Save `base` (with `mask` overlaid, if given) to `path` without margins, then close the figure."""
  fig = plt.figure()
  ax = remove_whitespace(fig)
  if mask is None:
    ax.imshow(base)
  else:
    plt_superimposed(base, mask, ax=ax)
  # `bbox_inches` (previously misspelled `bbinches`) with zero padding crops the
  # file to the image; used for all three outputs so they share the same size.
  fig.savefig(path, bbox_inches="tight", pad_inches=0)
  # Close every figure. Previously each call stacked another Axes onto one
  # shared figure that kept growing for the whole loop.
  plt.close(fig)

n_valid = len(dls.valid_ds)
for i, (img, actual) in enumerate(dls.valid_ds):
  pred = learn.predict(img)[0]  # first element of learn.predict's result: the decoded mask
  _save_image(output_path/f"orig_{i:03d}.png", img)
  _save_image(output_path/f"true_{i:03d}.png", img, actual)
  _save_image(output_path/f"pred_{i:03d}.png", img, pred)
  print(f"{i+1}/{n_valid}")


Save the model itself for later use:

In [None]:
# learn.export writes to learn.path/fname (here learn.path is ice_path). Resolving
# to an absolute path makes the destination exactly ../saved_models. Otherwise it
# only works because "../arctic_images_original_2/../saved_models" happens to
# collapse to the same place.
export_path = Path("../saved_models/export"+time.strftime("_%Y-%m-%dT%H-%M-%S")+".pkl").resolve()
export_path.parent.mkdir(parents=True, exist_ok=True)  # export does not create missing directories
learn.export(export_path)

Save the list of validation-set filenames so we can confirm later that we're getting the same split:

In [None]:
# One filename per line, in dls.valid order, so a later run can diff its split against this one.
validlist_path = Path("../saved_models/validlist"+time.strftime("_%Y-%m-%dT%H-%M-%S")+".txt")
validlist_path.parent.mkdir(parents=True, exist_ok=True)  # open() won't create the directory
# Explicit encoding so non-ASCII filenames round-trip regardless of platform locale.
with open(validlist_path, 'w', encoding="utf-8") as validlist:
  validlist.writelines(f.name+"\n" for f in dls.valid.items)

## To load from a saved model:

In [None]:
# Uncomment and run in a fresh kernel to reuse a previously exported learner
# instead of retraining. Point load_learner at the export_*.pkl you want.
# The DataLoaders are rebuilt with the same path, codes and seed (47) as the
# training cell, presumably giving the same validation split. Diff dls.valid's
# filenames against the saved validlist_*.txt to confirm.
# NOTE: the export is a pickle; only load files you trust.
# from fastai.vision.all import *

# ice_path = Path("../arctic_images_original_2")
# def seginput2segmap(f): return ice_path/"segmaps"/f"{f.stem}.png"
# dls = SegmentationDataLoaders.from_label_func(
#     ice_path, get_image_files(ice_path/"seginput"),
#     seginput2segmap,  # Pickling doesn't like lambda expressions
#     codes = ["water", "sky", "ice", "other"],
#     bs = 8,
#     seed = 47
# )
# learn = load_learner("../saved_models/export_2022-08-01T20-46-50.pkl")