# Model Training


In [None]:
import sys
from os.path import expanduser
# Add the robosat.pink directory under the user's home to the module search
# path so the robosat_pink imports below can be resolved from it.
sys.path.append(expanduser("~/robosat.pink/"))

from robosat_pink.datasets import *
from robosat_pink.tiles import *

from sklearn.model_selection import train_test_split

from robosat_pink.models import albunet
from robosat_pink.tools.train import train, validate
from robosat_pink.losses.lovasz import Lovasz
import robosat_pink

import albumentations as A

from skimage import exposure

from numpy.random import choice
from math import floor

from tqdm import tqdm

from datetime import datetime

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

from os import environ
# NOTE(review): presumably points HTTPS clients at the system CA bundle --
# confirm which library actually needs this.
environ['CURL_CA_BUNDLE']='/etc/ssl/certs/ca-certificates.crt'
# AWS profile name; read back later in this notebook when boto3 sessions are
# created.
environ['AWS_DEFAULT_PROFILE'] = 'esip'

# Reload the loss and training modules and re-import their public names, so
# that edits made to those modules are picked up without restarting the
# kernel.
# importlib.reload is used instead of imp.reload: the imp module is deprecated
# and was removed in Python 3.12.
from importlib import reload
reload(robosat_pink.losses.lovasz)
reload(robosat_pink.tools.train)
from robosat_pink.tools.train import train, validate
from robosat_pink.losses.lovasz import Lovasz

from io import BytesIO

# Checkpoint to resume from. A value starting with "s3://" is read from S3
# through s3fs; anything else is treated as a non-S3 checkpoint.
CHECKPOINT = "s3://planet-snowcover-models/checkpoint-190319-20:47:57"
S3_CHECKPOINT = CHECKPOINT.startswith("s3://")
if S3_CHECKPOINT:
    # Strip the URI scheme, leaving "bucket/key".
    CHECKPOINT = CHECKPOINT[len("s3://"):]
    # NOTE(review): boto3 and s3fs are not imported explicitly in this
    # notebook; presumably the star imports above expose them -- confirm.
    sess = boto3.Session(profile_name=environ['AWS_DEFAULT_PROFILE'])
    fs = s3fs.S3FileSystem(session=sess)
    # No file handle is opened here: the cell that loads the checkpoint opens
    # its own, so a handle created at this point would be left unclosed.

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Tile sources: DATA_DIR holds the input imagery tiles, MASK_DIR the target
# mask tiles. Alternative imagery scene, kept for reference:
# DATA_DIR = "s3://planet-snowcover-imagery/20180601_181450_0f32_3B_AnalyticMS_SR_clip_tiled"
DATA_DIR = "s3://planet-snowcover-imagery/20180601_181448_0f32_3B_AnalyticMS_SR_clip_tiled"
MASK_DIR = "s3://planet-snowcover-snow/ASO_3M_SD_USCASJ_20180601_tiles_02"

# Dataset over every available tile; its .tiles list is what gets split into
# training and held-out ids.
all_tiles = SlippyMapTilesConcatenation(
    path=DATA_DIR,
    target=MASK_DIR,
    aws_profile='esip',
)

In [None]:
# Split the tile ids into a training set and a held-out set (scikit-learn's
# default proportions). The seed is fixed so that re-running the notebook --
# for instance when resuming from CHECKPOINT -- keeps the same tiles held out
# instead of reshuffling previously held-out tiles into the training set.
SPLIT_SEED = 0
train_ids, test_ids = train_test_split(all_tiles.tiles, random_state=SPLIT_SEED)


In [None]:
# Augmentation pipeline handed to the datasets as their joint_transform:
# a horizontal flip and a vertical flip, each applied with probability 0.5.
#
# Steps that were tried and are intentionally left out:
#   - A.RandomRotate90(p = 0.5)  ("these do something bad to the bands")
#   - A.ToFloat(p = 1) / A.ToFloat(p = 1, max_value = np.finfo(np.float64).max)
#   - A.Normalize(mean = mean, std = std, max_pixel_value = 1)
transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ]
)

# Both datasets read from the same imagery/mask locations with the same AWS
# profile and augmentation; they differ only in which tile ids they cover.
_tile_source = dict(
    path=DATA_DIR,
    target=MASK_DIR,
    aws_profile='esip',
    joint_transform=transform,
)
train_tiles = SlippyMapTilesConcatenation(tiles=train_ids, **_tile_source)
valid_tiles = SlippyMapTilesConcatenation(tiles=test_ids, **_tile_source)

## Set up the neural network

In [None]:
# Build the network (1 output class, 4 input channels) and pick the device:
# the GPU when CUDA is available, otherwise the CPU.
net = albunet.Albunet(num_classes = 1, num_channels = 4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model to the device *before* the optimizer is created and before
# any optimizer state is restored, so parameters and optimizer state always
# live on the same device -- also when no checkpoint ends up being loaded.
net = torch.nn.DataParallel(net).to(device)

criterion = Lovasz().to(device)
optimizer = Adam(net.parameters(), lr=0.001)

if CHECKPOINT is not None:
    try:
        if S3_CHECKPOINT:
            # Read the whole S3 object into memory and let torch.load
            # deserialize from that in-memory copy.
            with s3fs.S3File(fs, CHECKPOINT, 'rb') as C:
                state = torch.load(BytesIO(C.read()), map_location=device)
        else:
            # Non-S3 checkpoint: CHECKPOINT is passed to torch.load as a path.
            state = torch.load(CHECKPOINT, map_location=device)
    except FileNotFoundError:
        # A missing checkpoint is not fatal: training starts from scratch.
        print("{} checkpoint not found.".format(CHECKPOINT))
    else:
        # Restore both halves of the checkpoint written by the save cell.
        optimizer.load_state_dict(state['optimizer'])
        net.load_state_dict(state['state_dict'])

In [None]:
# Training batches: 8 tiles each, reshuffled on every pass, with a trailing
# incomplete batch dropped. num_workers=0 loads data in the main process.
train_loader = DataLoader(
    train_tiles,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

In [None]:
# Held-out batches, configured exactly like the training loader: 8 tiles per
# batch, shuffled, trailing incomplete batch dropped, loaded in the main
# process.
valid_loader = DataLoader(
    valid_tiles,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

## Start the training

In [None]:
# Run 10 passes over train_loader, printing the epoch index and the value
# returned by train() after each pass. train_hist is overwritten on every
# iteration, so after the loop it holds only the last epoch's result.
for epoch in range(10):
    print("epoch {}".format(epoch))
    # NOTE(review): the positional 1 is presumably the number of classes --
    # confirm against robosat_pink.tools.train.train.
    train_hist = train(train_loader, 1, device, net, optimizer, criterion)
    print(train_hist)

## Look at the run stats

In [None]:
# Display the value returned by the most recent train() call (the final
# epoch of the loop above).
train_hist

## Save the model and upload to S3

In [None]:
# Name the new checkpoint after the current time, using the same
# "checkpoint-yymmdd-HH:MM:SS" pattern as the CHECKPOINT constant at the top
# of the notebook. (fname was previously never defined, so this cell raised
# NameError.)
fname = datetime.now().strftime("checkpoint-%y%m%d-%H:%M:%S")
sess = boto3.Session(profile_name=environ['AWS_DEFAULT_PROFILE'])
fs = s3fs.S3FileSystem(session=sess)
# Opened for writing here; the following cell writes the checkpoint into this
# handle and closes it.
f = s3fs.S3File(fs, 'planet-snowcover-models/' + fname, 'wb')

In [None]:
# Bundle the model weights and the optimizer state under the keys that the
# checkpoint-loading code reads back ('state_dict' and 'optimizer'), write the
# bundle into the open S3 handle, then close the handle.
checkpoint_state = {
    'state_dict': net.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(checkpoint_state, f)
f.close()

### Checkpoint information

In [None]:
# Print the `key` attribute of the checkpoint file handle that was just
# written, so the checkpoint's location is recorded in the notebook output.
print(f.key)