In [1]:
import numpy as np
import os
import torch
from time import time
from torch.autograd import Variable
import random
import sys
sys.path.append("..")

from src.sample_tiles import (
    extract_tile,
    sample_distant_same,
    sample_neighbor,
    load_img,
    sample_anchor,
    sample_distant_diff,
)
from tqdm import tqdm
import pandas as pd

import pandas as pd
from pathlib import Path
from osgeo import gdal
img_type = "landsat"  # images are in float - this parameter specifies that there is a need for normalization of floats
tile_dir = Path("/storage/tile2vec/distant_diff")
base_eurosat_dir = Path("/storage/EuroSATallBands")
bands = 13
num_workers = 4
n_triplets = 50000 # number of triplets to sample

train_path = Path("/storage/EuroSATallBands/train.csv")
train_df = pd.read_csv(train_path)

random.seed(44)
np.random.seed(44)

In [2]:
def split_path(path):
    folder, file = path.split('/', 1)
    return folder, file

def get_modified_triplet_imgs(img_df, n_triplets=1000):
    """
    Returns a numpy array of dimension (n_triplets, 2). First column is
    the img name of anchor/neighbor tiles and second column is img name
    of distant tiles.
    """

    img_df[['Folder', 'File']] = img_df['Filename'].apply(split_path).apply(pd.Series)

    # Grupuj pliki według folderów
    grouped = img_df.groupby('Folder')['File'].apply(list).to_dict()

    pairs = []
    folders = list(grouped.keys())

    print("Sampling tiles")
    for _ in tqdm(range(n_triplets)):
        # Losuj dwa różne foldery
        folder1, folder2 = random.sample(folders, 2)
        file1 = random.choice(grouped[folder1])
        file2 = random.choice(grouped[folder2])
        pairs.append([f"{folder1}/{file1}", f"{folder2}/{file2}"])
    pairs = np.vstack(pairs)

    return pairs

img_triplets = get_modified_triplet_imgs(train_df, n_triplets)
print("finished generating triplet sources")
print(img_triplets)

Sampling tiles


100%|██████████| 50000/50000 [00:00<00:00, 160025.03it/s]


finished generating triplet sources
[['PermanentCrop/PermanentCrop_1995.tif' 'River/River_38.tif']
 ['Forest/Forest_659.tif'
  'HerbaceousVegetation/HerbaceousVegetation_1455.tif']
 ['Industrial/Industrial_290.tif' 'AnnualCrop/AnnualCrop_2758.tif']
 ...
 ['AnnualCrop/AnnualCrop_45.tif' 'SeaLake/SeaLake_2438.tif']
 ['PermanentCrop/PermanentCrop_2493.tif' 'River/River_2420.tif']
 ['AnnualCrop/AnnualCrop_1855.tif' 'Pasture/Pasture_764.tif']]


In [3]:
def get_triplet_tiles(
    tile_dir,
    img_dir,
    img_triplets,
    tile_size=50,
    neighborhood=100,
    val_type="uint8",
    bands_only=False,
    save=True,
    verbose=False,
):
    print("Loading and preprocessing triplets")
    if not os.path.exists(tile_dir):
        os.makedirs(tile_dir)
    size_even = tile_size % 2 == 0
    tile_radius = tile_size // 2

    n_triplets = img_triplets.shape[0]
    unique_imgs = np.unique(img_triplets)
    tiles = np.zeros((n_triplets, 3, 2), dtype=np.int16)

    for img_name in tqdm(unique_imgs):
        if img_name[-3:] == "npy":
            img = np.load(os.path.join(img_dir, img_name))
        else:
            img = load_img(
                os.path.join(img_dir, img_name),
                val_type=val_type,
                bands_only=bands_only,
            )
        img_padded = np.pad(
            img,
            pad_width=[(tile_radius, tile_radius), (tile_radius, tile_radius), (0, 0)],
            mode="reflect",
        )
        img_shape = img_padded.shape

        for idx, row in enumerate(img_triplets):
            if row[0] == img_name:
                xa, ya = sample_anchor(img_shape, tile_radius)
                xn, yn = sample_neighbor(img_shape, xa, ya, neighborhood, tile_radius)

                if verbose:
                    print("    Saving anchor and neighbor tile #{}".format(idx))
                    print("    Anchor tile center:{}".format((xa, ya)))
                    print("    Neighbor tile center:{}".format((xn, yn)))
                if save:
                    tile_anchor = extract_tile(img_padded, xa, ya, tile_radius)
                    tile_neighbor = extract_tile(img_padded, xn, yn, tile_radius)
                    if size_even:
                        tile_anchor = tile_anchor[:-1, :-1]
                        tile_neighbor = tile_neighbor[:-1, :-1]
                    np.save(
                        os.path.join(tile_dir, "{}anchor.npy".format(idx)), tile_anchor
                    )
                    np.save(
                        os.path.join(tile_dir, "{}neighbor.npy".format(idx)),
                        tile_neighbor,
                    )

                tiles[idx, 0, :] = xa - tile_radius, ya - tile_radius
                tiles[idx, 1, :] = xn - tile_radius, yn - tile_radius

                if row[1] == img_name:
                    # distant image is same as anchor/neighbor image
                    try:
                        xd, yd = sample_distant_same(
                            img_shape, xa, ya, neighborhood, tile_radius
                        )
                    except ValueError:
                        print("Could not sample from the same image")
                        print("Image name ", img_name)
                        print("Exiting...")
                        exit(0)

                    if verbose:
                        print("    Saving distant tile #{}".format(idx))
                        print("    Distant tile center:{}".format((xd, yd)))
                    if save:
                        tile_distant = extract_tile(img_padded, xd, yd, tile_radius)
                        if size_even:
                            tile_distant = tile_distant[:-1, :-1]
                        np.save(
                            os.path.join(tile_dir, "{}distant.npy".format(idx)),
                            tile_distant,
                        )
                    tiles[idx, 2, :] = xd - tile_radius, yd - tile_radius

            elif row[1] == img_name:
                # distant image is different from anchor/neighbor image
                xd, yd = sample_distant_diff(img_shape, tile_radius)
                if verbose:
                    print("    Saving distant tile #{}".format(idx))
                    print("    Distant tile center:{}".format((xd, yd)))
                if save:
                    tile_distant = extract_tile(img_padded, xd, yd, tile_radius)
                    if size_even:
                        tile_distant = tile_distant[:-1, :-1]
                    np.save(
                        os.path.join(tile_dir, "{}distant.npy".format(idx)),
                        tile_distant,
                    )
                tiles[idx, 2, :] = xd - tile_radius, yd - tile_radius

    return tiles

In [4]:
in_channels = bands
z_dim = 512

img_triplets = get_modified_triplet_imgs(train_df, n_triplets)
print("finished generating triplet sources")

tiles = get_triplet_tiles(tile_dir, base_eurosat_dir, img_triplets, tile_size=60)

Sampling tiles


100%|██████████| 50000/50000 [00:00<00:00, 336881.61it/s]


finished generating triplet sources
Loading and preprocessing triplets


100%|██████████| 19162/19162 [20:18<00:00, 15.72it/s]


# Train

In [5]:
from src.datasets import TileTripletsDataset, GetBands, RandomFlipAndRotate, ClipAndScale, ToFloatTensor, triplet_dataloader
from src.tilenet import make_tilenet
from src.training import prep_triplets, train_triplet_epoch
from torch import optim

import matplotlib.pyplot as plt
import pickle 

import os
import torch
from time import time

# script to train the tile2vec model

# values to change during training
model_name = 'TileNet_Distant_Diff.ckpt'
img_type = "landsat" # images are in float - this parameter specifies that there is a need for normalization of floats
bands = 13
augment = True
batch_size = 50
shuffle = True
num_workers = 16
n_triplets = 50000
z_dim = 512


# initialize GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
cuda = torch.cuda.is_available()

print("Cuda device: ", cuda)


# dataloader in shor loades data for the model
dataloader = triplet_dataloader(img_type, tile_dir, bands=bands, augment=augment,
                                batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, 
                                n_triplets=n_triplets, pairs_only=True)
print('Dataloader set up complete.')



# function to initialize model structure
TileNet = make_tilenet(in_channels=bands, z_dim=z_dim)
TileNet.train()
if cuda: 
    TileNet.cuda()
print('TileNet set up complete.')

# set up the learning rate and select the optimizer
lr = 1e-3
optimizer = optim.Adam(TileNet.parameters(), lr=lr, betas=(0.5, 0.999))


# training-level parameters 
epochs = 50
margin = 10
l2 = 0.01
print_every = 1000 # how often model will produce the information about the loss
save_models = True

# create model directory
model_dir = '/storage/tile2vec/models'
if not os.path.exists(model_dir): 
    os.makedirs(model_dir)
    

results_fn = "/storage/tile2vec/results_fn"

# avg_losses = []
# avg_l_ns = []
# avg_l_ds = []
# avg_l_nds = []

t0 = time()
with open(results_fn, 'w') as file:
    print('Begin training.................')
    for epoch in range(0, epochs):
        (avg_loss, avg_l_n, avg_l_d, avg_l_nd) = train_triplet_epoch(
            TileNet, cuda, dataloader, optimizer, epoch+1, margin=margin, l2=l2,
            print_every=print_every, t0=t0)
        
        # avg_losses.append(avg_loss)
        # avg_l_ns.append(avg_l_n)
        # avg_l_ds.append(avg_l_d)
        # avg_l_nds.append(avg_l_nd)

# Save model after last epoch
if save_models:
    print("saving model")
    model_fn = os.path.join(model_dir, model_name)
    torch.save(TileNet.state_dict(), model_fn)
 
# with open(model_name + ".pkl", "wb") as f:
#     avg = {"losses": avg_losses, "l_n": avg_l_ns, "l_d": avg_l_ds, "l_nd": avg_l_nds}
#     pickle.dump(avg, f)


Cuda device:  True
Dataloader set up complete.
TileNet set up complete.
Begin training.................
Finished epoch 1: 124.739s
  Average loss: 5.4123
  Average l_n: 4.3698
  Average l_d: -13.6826
  Average l_nd: -9.3128

Finished epoch 2: 249.613s
  Average loss: 3.8473
  Average l_n: 3.1863
  Average l_d: -13.9259
  Average l_nd: -10.7396

Finished epoch 3: 375.221s
  Average loss: 3.4781
  Average l_n: 2.7171
  Average l_d: -13.3664
  Average l_nd: -10.6493

Finished epoch 4: 501.294s
  Average loss: 3.3505
  Average l_n: 2.5990
  Average l_d: -13.1654
  Average l_nd: -10.5664

Finished epoch 5: 627.822s
  Average loss: 3.1103
  Average l_n: 2.4066
  Average l_d: -13.0676
  Average l_nd: -10.6610

Finished epoch 6: 755.458s
  Average loss: 3.0793
  Average l_n: 2.3560
  Average l_d: -12.9717
  Average l_nd: -10.6156

Finished epoch 7: 881.721s
  Average loss: 3.1340
  Average l_n: 2.4941
  Average l_d: -13.1563
  Average l_nd: -10.6622

Finished epoch 8: 1007.507s
  Average loss: