In [19]:
import numpy as np
import os
import matplotlib.pyplot as plt
from src.sample_tiles import get_triplet_imgs, get_triplet_tiles

import torch
from torch import optim
from time import time

from src.datasets import TileTripletsDataset, GetBands, RandomFlipAndRotate, ClipAndScale, ToFloatTensor, triplet_dataloader
from src.tilenet import make_tilenet
from src.training import prep_triplets, train_triplet_epoch

%load_ext autoreload
%autoreload 2
%matplotlib inline

ModuleNotFoundError: No module named 'gdal'

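# Prepare tiles

Sample anchor/neighbor/distant tile triplets from every EuroSAT class directory and save them to `tile_dir` for training.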

In [17]:
# Root of the EuroSAT multispectral GeoTIFFs; adjust to your local copy.
eurosat_root = "/storage/EuroSATallBands/"
tile_dir = '../data/triplets'  # where you want to save your tiles
triplets_per_dir = 10  # number of triplets sampled from each directory
tile_size = 64

# os.walk yields the root itself first, so drop it with [1:]; sorting makes the
# processing order (and therefore the printed log) deterministic across runs.
img_dir_list = sorted([x[0] for x in os.walk(eurosat_root)][1:])

for img_dir in img_dir_list:
    print(f"Current dir: {img_dir}")
    img_triplets = get_triplet_imgs(img_dir, ".tif", n_triplets=triplets_per_dir)
    print(img_triplets[:5, :])  # preview the first few sampled image pairs

    # NOTE(review): every directory writes into the same tile_dir. If
    # get_triplet_tiles names files only by triplet index (the log shows
    # "#0".."#9"), each directory overwrites the previous one's tiles; a
    # per-directory subfolder may be needed -- TODO confirm.
    tiles = get_triplet_tiles(tile_dir,
                              img_dir,
                              img_triplets,
                              tile_size=tile_size,
                              val_type='float32',
                              bands_only=True,
                              save=True,
                              verbose=True)

Current dir: /storage/EuroSATallBands/SeaLake
[['AnnualCrop_1852.tif' 'AnnualCrop_1876.tif']
 ['AnnualCrop_2660.tif' 'AnnualCrop_2129.tif']
 ['AnnualCrop_2763.tif' 'AnnualCrop_2685.tif']
 ['AnnualCrop_1109.tif' 'AnnualCrop_1567.tif']
 ['AnnualCrop_42.tif' 'AnnualCrop_243.tif']]
Sampling image AnnualCrop_1109.tif from dir
    Saving anchor and neighbor tile #3
    Anchor tile center:(72, 95)
    Neighbor tile center:(77, 53)
Sampling image AnnualCrop_1567.tif from dir
    Saving distant tile #3
    Distant tile center:(78, 94)
Sampling image AnnualCrop_1633.tif from dir
    Saving anchor and neighbor tile #6
    Anchor tile center:(41, 84)
    Neighbor tile center:(76, 72)
Sampling image AnnualCrop_1657.tif from dir
    Saving anchor and neighbor tile #7
    Anchor tile center:(86, 65)
    Neighbor tile center:(34, 89)
Sampling image AnnualCrop_1852.tif from dir
    Saving anchor and neighbor tile #0
    Anchor tile center:(41, 47)
    Neighbor tile center:(91, 92)
Sampling image Annual

# Train

## Set up data

In [18]:
# Expose only the first GPU to this process. The variable only takes effect if
# it is set before CUDA is initialized in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
cuda = torch.cuda.is_available()  # True only if PyTorch can see a usable GPU

# Data-loading parameters -- edit to match your tile directory and setup.
tile_dir = '../data/triplets/'
# NOTE(review): the tiles come from EuroSAT, not NAIP; confirm how
# triplet_dataloader interprets img_type before relying on this value.
img_type = 'naip'
bands = 12
batch_size = 50
num_workers = 4
augment = True
shuffle = True
# NOTE(review): the tile-preparation step samples only 10 triplets per class
# directory; confirm triplet_dataloader copes with n_triplets exceeding that.
n_triplets = 100000

loader_kwargs = dict(
    bands=bands,
    augment=augment,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    n_triplets=n_triplets,
    pairs_only=True,  # semantics defined by src.datasets -- TODO document
)
dataloader = triplet_dataloader(img_type, tile_dir, **loader_kwargs)
print('Dataloader set up complete.')


NameError: name 'torch' is not defined

## Set up TileNet


In [None]:
z_dim = 512  # presumably the size of the embedding TileNet produces -- confirm in src.tilenet
in_channels = bands  # one input channel per band requested from the dataloader

TileNet = make_tilenet(in_channels=in_channels, z_dim=z_dim)
TileNet.train()  # switch the module into training mode
if cuda:
    TileNet.cuda()
print('TileNet set up complete.')

# Adam with beta1 = 0.5 instead of PyTorch's default 0.9.
lr = 1e-3
optimizer = optim.Adam(TileNet.parameters(), lr=lr, betas=(0.5, 0.999))

## Train model

In [None]:
# Training hyperparameters (forwarded to train_triplet_epoch).
epochs = 50
margin = 10  # passed as train_triplet_epoch's `margin`
l2 = 0.01    # passed as train_triplet_epoch's `l2` -- presumably a regularization weight; confirm
print_every = 10000
save_models = True

model_dir = '../models/'
os.makedirs(model_dir, exist_ok=True)
# Per-epoch average losses are logged here as tab-separated lines.
results_fn = os.path.join(model_dir, 'TileNet_results.tsv')

t0 = time()
with open(results_fn, 'w') as results_file:
    results_file.write('epoch\tavg_loss\tavg_l_n\tavg_l_d\tavg_l_nd\n')

    print('Begin training.................')
    for epoch in range(epochs):
        # train_triplet_epoch takes a 1-based epoch number.
        (avg_loss, avg_l_n, avg_l_d, avg_l_nd) = train_triplet_epoch(
            TileNet, cuda, dataloader, optimizer, epoch+1, margin=margin, l2=l2,
            print_every=print_every, t0=t0)
        row = (epoch + 1, avg_loss, avg_l_n, avg_l_d, avg_l_nd)
        results_file.write('\t'.join(str(value) for value in row) + '\n')
        results_file.flush()  # keep the log current if training is interrupted

# Save model after last epoch; the filename reflects the actual epoch count.
if save_models:
    model_fn = os.path.join(model_dir, f'TileNet_epoch{epochs}.ckpt')
    torch.save(TileNet.state_dict(), model_fn)