In [57]:
import sys
import os
import torch
from torch import optim
from time import time

In [58]:
# Make the tile2vec project package (`src.*`) importable: first the parent
# directory of this notebook, then the tile2vec checkout location below.
tile2vec_dir = '/atlas/u/swang/software/GitHub/tile2vec'
for extra_path in ('../', tile2vec_dir):
    sys.path.append(extra_path)

In [59]:
from src.datasets import TileTripletsDataset, GetBands, RandomFlipAndRotate, ClipAndScale, ToFloatTensor, triplet_dataloader
from src.tilenet import make_tilenet

In [60]:
from src.training import prep_triplets, train_triplet_epoch

# Step 1. Download triplets from bucket

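Using the download link, download the triplets archive and unzip it into `tile2vec/data/triplets/`. If your tiles are stored elsewhere (this notebook currently reads from `data/nan_removed_tiles`), set `tile_dir` in Step 2 to that directory instead.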

# Step 2. Set up dataloader

In [61]:
# Environment: expose only GPU 0 to PyTorch, then record whether CUDA is usable here.
os.environ.update(CUDA_VISIBLE_DEVICES='0')
cuda = torch.cuda.is_available()


Set up the dataloader for training.

In [62]:
# Dataloader settings -- change these to match your directory and desired parameters.

# Where the tiles come from.
img_type = 'naip'
tile_dir = '/Users/calummcmeekin/Documents/GitHub/MInf-Project/SatelliteSegmentation/tile2vec/data/nan_removed_tiles'
# Previously used location:
# tile_dir = '/Users/calummcmeekin/Downloads/tile2vec-master-original/data/triplets/'
n_triplets = 19  # forwarded to triplet_dataloader(n_triplets=...)

# Image channels per tile; also used as TileNet's in_channels.
bands = 3

# Batching and augmentation.
# NOTE(review): batch_size exceeds n_triplets, so each epoch is a single batch.
batch_size = 50
shuffle = True
num_workers = 4
augment = True

In [63]:
# Build the triplet dataloader from the settings in the previous cell.
# NOTE(review): the exact effect of pairs_only=True is defined in
# src.datasets.triplet_dataloader and is not visible here -- confirm there.
dataloader = triplet_dataloader(
    img_type,
    tile_dir,
    bands=bands,
    augment=augment,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    n_triplets=n_triplets,
    pairs_only=True,
)
print('Dataloader set up complete.')

Dataloader set up complete.


# Step 3. Set up TileNet

In [64]:
# TileNet sizes: one input channel per image band; z_dim is passed to make_tilenet
# (presumably the embedding size -- confirm in src.tilenet).
in_channels, z_dim = bands, 512

In [65]:
# Build the network, switch it to training mode, and move it to the GPU when
# CUDA is available.
TileNet = make_tilenet(in_channels=in_channels, z_dim=z_dim)
TileNet.train()
if cuda:
    TileNet.cuda()
print('TileNet set up complete.')

TileNet set up complete.


Set up the optimizer.

In [66]:
# Adam over all TileNet parameters. beta1 is lowered to 0.5 (PyTorch's default
# is 0.9); beta2 = 0.999 matches the default.
lr = 1e-3
optimizer = optim.Adam(
    TileNet.parameters(),
    lr=lr,
    betas=(0.5, 0.999),
)

# Step 4. Train model!

In [67]:
# Training-loop settings; the first four are forwarded to train_triplet_epoch.
epochs = 1           # number of passes over the dataloader
margin = 10          # train_triplet_epoch(margin=...) -- presumably the triplet-loss margin
l2 = 0.01            # train_triplet_epoch(l2=...) -- presumably an L2 penalty weight
print_every = 1000   # train_triplet_epoch(print_every=...)
save_models = True   # write a checkpoint after the last epoch


Define the directory for saving models.

In [68]:
# Directory where model checkpoints are written; created if it does not exist.
model_dir = '/Users/calummcmeekin/Documents/GitHub/MInf-Project/SatelliteSegmentation/tile2vec/models'
# exist_ok=True removes the check-then-create race of os.path.exists + os.makedirs
# and keeps this cell idempotent when re-run.
os.makedirs(model_dir, exist_ok=True)

In [69]:
# Train TileNet for `epochs` epochs, recording each epoch's average losses to
# a tab-separated file next to the checkpoints.
results_fn = os.path.join(model_dir, 'training_results.tsv')
t0 = time()
with open(results_fn, 'w') as results_file:
    results_file.write('epoch\tavg_loss\tavg_l_n\tavg_l_d\tavg_l_nd\n')

    print('Begin training.................')
    for epoch in range(epochs):
        # Epoch numbers passed to train_triplet_epoch are 1-based
        # (its log prints "Finished epoch 1" for the first epoch).
        (avg_loss, avg_l_n, avg_l_d, avg_l_nd) = train_triplet_epoch(
            TileNet, cuda, dataloader, optimizer, epoch + 1, margin=margin, l2=l2,
            print_every=print_every, t0=t0)

        # float() accepts Python floats, NumPy scalars and one-element tensors alike.
        metrics = (avg_loss, avg_l_n, avg_l_d, avg_l_nd)
        results_file.write('{}\t'.format(epoch + 1)
                           + '\t'.join('{:.6f}'.format(float(m)) for m in metrics)
                           + '\n')
        # Flush so completed epochs are preserved if training is interrupted.
        results_file.flush()

Begin training.................
Finished epoch 1: 29.079s
  Num Batches: 1
  Sum Loss: 13.022397994995117
  Average loss: 13.0224
  Average l_n: 9.3350
  Average l_d: -8.8828
  Average l_nd: 0.4523



In [37]:
# Save the model after the last epoch. The filename is derived from `epochs`
# so it always reflects how many epochs were actually trained.
if save_models:
    model_fn = os.path.join(model_dir, 'TileNet_epoch{}.ckpt'.format(epochs))
    torch.save(TileNet.state_dict(), model_fn)