In [2]:
import numpy as np
import sys
np.set_printoptions(threshold=sys.maxsize)
import os
import matplotlib.pyplot as plt
from src.sample_tiles import get_triplet_imgs, get_triplet_tiles, get_triplet_imgs_with_dirs

import sys
sys.path.append("..")
import torch
from torch import optim
from time import time

from src.datasets import TileTripletsDataset, GetBands, RandomFlipAndRotate, ClipAndScale, ToFloatTensor, triplet_dataloader
from src.tilenet import make_tilenet
from src.training import prep_triplets, train_triplet_epoch

%load_ext autoreload
%autoreload 2
%matplotlib inline

# Prepare tiles

In [3]:
dir_of_img_dirs = '/storage/EuroSATallBands/'
img_dir_list = [x[0] for x in os.walk(dir_of_img_dirs)][1:]
print(img_dir_list)

tile_dir = '../data/triplets' # where you want to save your tiles
n_triplets = 500

# TODO - zmienic by brało obrazki tylko z train.csv
img_triplets = get_triplet_imgs_with_dirs(img_dir_list, n_triplets=n_triplets)
# print(img_triplets)

print(len(img_triplets))

tiles = get_triplet_tiles(tile_dir,
                        dir_of_img_dirs, 
                        img_triplets, 
                        tile_size=64,
                        val_type='float32',
                        bands_only=False,
                        save=True,
                        verbose=True)

['/storage/EuroSATallBands/AnnualCrop', '/storage/EuroSATallBands/Highway', '/storage/EuroSATallBands/HerbaceousVegetation', '/storage/EuroSATallBands/Industrial', '/storage/EuroSATallBands/River', '/storage/EuroSATallBands/Pasture', '/storage/EuroSATallBands/Residential', '/storage/EuroSATallBands/PermanentCrop', '/storage/EuroSATallBands/Forest', '/storage/EuroSATallBands/SeaLake']
500
Sampling image AnnualCrop/AnnualCrop_1012.tif from dir
    Saving distant tile #17
    Distant tile center:(50, 71)
Sampling image AnnualCrop/AnnualCrop_1048.tif from dir
    Saving anchor and neighbor tile #108
    Anchor tile center:(56, 81)
    Neighbor tile center:(53, 83)
Sampling image AnnualCrop/AnnualCrop_1055.tif from dir
    Saving distant tile #437
    Distant tile center:(70, 41)
Sampling image AnnualCrop/AnnualCrop_1072.tif from dir
    Saving distant tile #32
    Distant tile center:(42, 41)
Sampling image AnnualCrop/AnnualCrop_1082.tif from dir




    Saving distant tile #78
    Distant tile center:(33, 67)
Sampling image AnnualCrop/AnnualCrop_1123.tif from dir
    Saving distant tile #260
    Distant tile center:(43, 36)
Sampling image AnnualCrop/AnnualCrop_1172.tif from dir
    Saving anchor and neighbor tile #162
    Anchor tile center:(35, 54)
    Neighbor tile center:(95, 74)
Sampling image AnnualCrop/AnnualCrop_1182.tif from dir
    Saving anchor and neighbor tile #451
    Anchor tile center:(67, 93)
    Neighbor tile center:(48, 90)
Sampling image AnnualCrop/AnnualCrop_1203.tif from dir
    Saving anchor and neighbor tile #341
    Anchor tile center:(69, 76)
    Neighbor tile center:(56, 34)
Sampling image AnnualCrop/AnnualCrop_1209.tif from dir
    Saving distant tile #303
    Distant tile center:(35, 72)
Sampling image AnnualCrop/AnnualCrop_1219.tif from dir
    Saving distant tile #83
    Distant tile center:(93, 65)
Sampling image AnnualCrop/AnnualCrop_1261.tif from dir
    Saving distant tile #90
    Distant tile cen

# Train

## Setup data

In [4]:
# Environment stuff
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
cuda = torch.cuda.is_available()

# Change these arguments to match your directory and desired parameters
img_type = 'naip'
# tile_dir = '../data/triplets/'
bands = 12
augment = True
batch_size = 50
shuffle = True
num_workers = 4
# n_triplets = 10 # trzeba ustawić na tyle ile jest w data/triplets

dataloader = triplet_dataloader(img_type, tile_dir, bands=bands, augment=augment,
                                batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, 
                                n_triplets=n_triplets, pairs_only=True)
print('Dataloader set up complete.')

Dataloader set up complete.


## Setup tilenet


In [5]:
in_channels = bands
z_dim = 512

TileNet = make_tilenet(in_channels=in_channels, z_dim=z_dim)
TileNet.train()
if cuda: TileNet.cuda()
print('TileNet set up complete.')

lr = 1e-3
optimizer = optim.Adam(TileNet.parameters(), lr=lr, betas=(0.5, 0.999))

TileNet set up complete.


# Train model

In [6]:
epochs = 50
margin = 10
l2 = 0.01
print_every = 10000
save_models = True

model_dir = '../models/'
if not os.path.exists(model_dir): os.makedirs(model_dir)

results_fn = 'result_fn'

t0 = time()
with open(results_fn, 'w') as file:

    print('Begin training.................')
    for epoch in range(0, epochs):
        (avg_loss, avg_l_n, avg_l_d, avg_l_nd) = train_triplet_epoch(
            TileNet, cuda, dataloader, optimizer, epoch+1, margin=margin, l2=l2,
            print_every=print_every, t0=t0)

# Save model after last epoch
if save_models:
    model_fn = os.path.join(model_dir, 'TileNet_epoch50.ckpt')
    torch.save(TileNet.state_dict(), model_fn)

Begin training.................
self.tile_dir: ../data/triplets/234anchor.npy
self.tile_dir: ../data/triplets/249anchor.npy


self.tile_dir: ../data/triplets/76anchor.npy

self.tile_dir: ../data/triplets/56anchor.npy

self.tile_dir: ../data/triplets/333anchor.npy

self.tile_dir: ../data/triplets/464anchor.npy

self.tile_dir: ../data/triplets/210anchor.npy

self.tile_dir: ../data/triplets/479anchor.npy
self.tile_dir: ../data/triplets/330anchor.npy


self.tile_dir: ../data/triplets/66anchor.npy

self.tile_dir: ../data/triplets/6anchor.npy
self.tile_dir: ../data/triplets/289anchor.npy


self.tile_dir: ../data/triplets/327anchor.npy
self.tile_dir: ../data/triplets/392anchor.npy


self.tile_dir: ../data/triplets/423anchor.npy
self.tile_dir: ../data/triplets/118anchor.npy
self.tile_dir: ../data/triplets/412anchor.npy



self.tile_dir: ../data/triplets/4anchor.npy

self.tile_dir: ../data/triplets/150anchor.npy
self.tile_dir: ../data/triplets/280anchor.npy


self.tile_dir: ../data/triplets/238