In [18]:
import numpy as np
import pandas as pd
import sys
# np.set_printoptions(threshold=sys.maxsize)
import os
import matplotlib.pyplot as plt
from osgeo import gdal
from pathlib import Path

import torch
from torch import optim
from torch.autograd import Variable
from time import time

from src.datasets import TileTripletsDataset, GetBands, RandomFlipAndRotate, ClipAndScale, ToFloatTensor, triplet_dataloader
from src.tilenet import make_tilenet
from src.training import prep_triplets, train_triplet_epoch
from src.sample_tiles import get_triplet_imgs, get_triplet_tiles, get_triplet_imgs_with_dirs, get_triplet_imgs_from_df
from src.resnet import ResNet18

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Prepare tiles

In [6]:
n_triplets = 10000

sample_tiles_train_path = "/storage/EuroSATallBands/train.csv"
train_paths = pd.read_csv(sample_tiles_train_path)

base_eurosat_dir = "/storage/EuroSATallBands"

img_triplets = get_triplet_imgs_from_df(train_paths, n_triplets=n_triplets)
img_triplets
# print(img_triplets)

print(len(img_triplets))

                                                Filename  Label  \
0                   PermanentCrop/PermanentCrop_2401.tif      6   
1                   PermanentCrop/PermanentCrop_1006.tif      6   
2      HerbaceousVegetation/HerbaceousVegetation_1025...      2   
3                               SeaLake/SeaLake_1439.tif      9   
4                                   River/River_1052.tif      8   
...                                                  ...    ...   
19312  HerbaceousVegetation/HerbaceousVegetation_2292...      2   
19313                     AnnualCrop/AnnualCrop_1226.tif      0   
19314                           SeaLake/SeaLake_2010.tif      9   
19315                           SeaLake/SeaLake_2291.tif      9   
19316                               River/River_1323.tif      8   

                  ClassName  
0             PermanentCrop  
1             PermanentCrop  
2      HerbaceousVegetation  
3                   SeaLake  
4                     River  
...            

In [8]:
tile_dir = '/storage/tile2vec/tiles2' # where you want to save your tiles
# tile_dir = '../data/triplets'

tiles = get_triplet_tiles(tile_dir,
                        base_eurosat_dir, 
                        img_triplets, 
                        tile_size=64,
                        val_type='float32',
                        bands_only=False,
                        save=True,
                        verbose=False)

Sampling image AnnualCrop/AnnualCrop_10.tif from dir
Sampling image AnnualCrop/AnnualCrop_100.tif from dir
Sampling image AnnualCrop/AnnualCrop_1000.tif from dir
Sampling image AnnualCrop/AnnualCrop_1001.tif from dir
Sampling image AnnualCrop/AnnualCrop_1004.tif from dir
Sampling image AnnualCrop/AnnualCrop_1008.tif from dir
Sampling image AnnualCrop/AnnualCrop_1009.tif from dir




Sampling image AnnualCrop/AnnualCrop_1010.tif from dir
Sampling image AnnualCrop/AnnualCrop_1011.tif from dir
Sampling image AnnualCrop/AnnualCrop_1012.tif from dir
Sampling image AnnualCrop/AnnualCrop_1013.tif from dir
Sampling image AnnualCrop/AnnualCrop_1018.tif from dir
Sampling image AnnualCrop/AnnualCrop_1019.tif from dir
Sampling image AnnualCrop/AnnualCrop_1023.tif from dir
Sampling image AnnualCrop/AnnualCrop_1024.tif from dir
Sampling image AnnualCrop/AnnualCrop_1025.tif from dir
Sampling image AnnualCrop/AnnualCrop_1028.tif from dir
Sampling image AnnualCrop/AnnualCrop_103.tif from dir
Sampling image AnnualCrop/AnnualCrop_1030.tif from dir
Sampling image AnnualCrop/AnnualCrop_1033.tif from dir
Sampling image AnnualCrop/AnnualCrop_1036.tif from dir
Sampling image AnnualCrop/AnnualCrop_1037.tif from dir
Sampling image AnnualCrop/AnnualCrop_1040.tif from dir
Sampling image AnnualCrop/AnnualCrop_1045.tif from dir
Sampling image AnnualCrop/AnnualCrop_1046.tif from dir
Sampling im

# Train

## Setup data

In [9]:
# Environment stuff
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
cuda = torch.cuda.is_available()

# Change these arguments to match your directory and desired parameters
img_type = 'naip'
bands = 12
augment = True
batch_size = 50
shuffle = True
num_workers = 4

dataloader = triplet_dataloader(img_type, tile_dir, bands=bands, augment=augment,
                                batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, 
                                n_triplets=n_triplets, pairs_only=True)
print('Dataloader set up complete.')

Dataloader set up complete.


## Setup tilenet


In [10]:
in_channels = bands
z_dim = 512

TileNet = make_tilenet(in_channels=in_channels, z_dim=z_dim)
TileNet.train()
if cuda: TileNet.cuda()
print('TileNet set up complete.')

lr = 1e-3
optimizer = optim.Adam(TileNet.parameters(), lr=lr, betas=(0.5, 0.999))

TileNet set up complete.


# Train model

In [12]:
epochs = 60
margin = 10
l2 = 0.01
print_every = 10000
save_models = True

model_dir = '/storage/tile2vec/models/'
if not os.path.exists(model_dir): os.makedirs(model_dir)

results_fn = 'result_fn'

t0 = time()
with open(results_fn, 'w') as file:

    print('Begin training.................')
    for epoch in range(0, epochs):
        (avg_loss, avg_l_n, avg_l_d, avg_l_nd) = train_triplet_epoch(
            TileNet, cuda, dataloader, optimizer, epoch+1, margin=margin, l2=l2,
            print_every=print_every, t0=t0)

# Save model after last epoch
if save_models:
    model_fn = os.path.join(model_dir, 'TileNet_epoch500_10k.ckpt')
    torch.save(TileNet.state_dict(), model_fn)

Begin training.................
Finished epoch 1: 57.517s
  Average loss: 2.5561
  Average l_n: 1.5396
  Average l_d: -12.3868
  Average l_nd: -10.8472

Finished epoch 2: 114.969s
  Average loss: 2.5669
  Average l_n: 1.5392
  Average l_d: -12.3697
  Average l_nd: -10.8305

Finished epoch 3: 170.060s
  Average loss: 2.5340
  Average l_n: 1.5542
  Average l_d: -12.3711
  Average l_nd: -10.8169

Finished epoch 4: 196.813s
  Average loss: 2.5175
  Average l_n: 1.5198
  Average l_d: -12.3075
  Average l_nd: -10.7877

Finished epoch 5: 223.719s
  Average loss: 2.5283
  Average l_n: 1.5544
  Average l_d: -12.3266
  Average l_nd: -10.7722

Finished epoch 6: 250.608s
  Average loss: 2.5114
  Average l_n: 1.5360
  Average l_d: -12.3507
  Average l_nd: -10.8147

Finished epoch 7: 277.838s
  Average loss: 2.5229
  Average l_n: 1.5561
  Average l_d: -12.3609
  Average l_nd: -10.8048

Finished epoch 8: 304.841s
  Average loss: 2.5024
  Average l_n: 1.5286
  Average l_d: -12.3402
  Average l_nd: -10

# Evaluate model

In [30]:
# Get data
val_csv = '/storage/EuroSATallBands/validation.csv'
test_csv = '/storage/EuroSATallBands/test.csv'
# test_csv = '/storage/EuroSATallBands/train.csv'

val_df = pd.read_csv(val_csv)
test_df = pd.read_csv(test_csv)

In [None]:
# get model if needed

model_from_file = False

if model_from_file:
    # Setting up model
    cuda = torch.cuda.is_available()
    # Use old model for now
    TileNet = ResNet18()
    if cuda: TileNet.cuda()

    # Load parameters
    model_fn = '/storage/models/.ckpt'
    checkpoint = torch.load(model_fn)
    TileNet.load_state_dict(checkpoint)
    TileNet.eval()

In [43]:
# encode test data

base_dir = Path(base_eurosat_dir)

# val_and_test_df = pd.concat([val_df, test_df])
# val_and_test_df.reset_index(drop=True, inplace=True)
# n_tiles = len(val_and_test_df)

n_tiles = len(val_df)
X = np.zeros((n_tiles, z_dim))

t0 = time()
# this solution to iterate over examples is very suboptimal, one should use torch dataset
for index, row in val_df.iterrows():
    # read the tile from provided filepath
    
    tile_filepath = base_dir / row["Filename"]
    obj = gdal.Open(tile_filepath)
    img = obj.ReadAsArray().astype(np.float32)
    img = np.moveaxis(img, 0, -1)

    tile = img[:, :, :bands] # bands - for our model 12

    tile = np.moveaxis(tile, -1, 0)
    tile = np.expand_dims(tile, axis=0)

    tile = tile / 255
    # Embed tile
    tile = torch.from_numpy(tile).float()
    tile = Variable(tile)
    if cuda: tile = tile.cuda()
    z = TileNet.encode(tile)
    if cuda: z = z.cpu()
    z = z.data.numpy()
    X[index,:] = z


    if index % 100 == 0:
        print(f"embedded {index+1} images")

t1 = time()
print('Embedded {} tiles: {:0.3f}s'.format(n_tiles, t1-t0))

embedded 1 images
embedded 101 images
embedded 201 images
embedded 301 images
embedded 401 images
embedded 501 images
embedded 601 images
embedded 701 images
embedded 801 images
embedded 901 images
embedded 1001 images
embedded 1101 images
embedded 1201 images
embedded 1301 images
embedded 1401 images
embedded 1501 images
embedded 1601 images
embedded 1701 images
embedded 1801 images
embedded 1901 images
embedded 2001 images
embedded 2101 images
embedded 2201 images
embedded 2301 images
embedded 2401 images
embedded 2501 images
embedded 2601 images
embedded 2701 images
embedded 2801 images
embedded 2901 images
embedded 3001 images
embedded 3101 images
embedded 3201 images
embedded 3301 images
embedded 3401 images
embedded 3501 images
embedded 3601 images
embedded 3701 images
embedded 3801 images
embedded 3901 images
embedded 4001 images
embedded 4101 images
embedded 4201 images
embedded 4301 images
embedded 4401 images
embedded 4501 images
embedded 4601 images
embedded 4701 images
embe

In [45]:
# y = val_and_test_df["Label"].to_numpy()
y = val_df["Label"].to_numpy()

from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2)
rf = RandomForestClassifier()
rf.fit(X_tr, y_tr)
acc = rf.score(X_te, y_te)
print(f"Model accuracy: {acc}")

Model accuracy: 0.3079710144927536
