In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

from utilities import create_split_dataloaders, PASTIS
from models import UNet
from utilities import get_rgb

In [12]:
# VARIABLES TO SET
EPOCHS = 20
batch_size = 10
best_vloss = 1_000_000.

In [4]:
pastis = PASTIS('./data/PASTIS', 'DATA_S2', 'ANNOTATIONS', rgb_only=False, no_time=True)
print('Dataset Size:', len(pastis))

Dataset Size: 115222


In [15]:
train_loader, val_loader, test_loader = create_split_dataloaders(pastis, (0.8, 0.2), batch_size=batch_size)

In [16]:
model = UNet(
    enc_chs=(10, 64, 128, 256),
    dec_chs=(256, 128, 64),
    retain_dim=True,
    out_sz=(128, 128), 
    num_class=20
)

In [17]:
# Loss function (Need to cast the output to float after argmax)
loss_fn = torch.nn.CrossEntropyLoss()

# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [18]:
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/satellite_class_{}'.format(timestamp))
epoch_number = 0

In [19]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(train_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.requires_grad = True
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % batch_size == batch_size-1:
            last_loss = running_loss / batch_size # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

In [20]:
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # We don't need gradients on to do reporting
    model.train(False)

    running_vloss = 0.0
    for i, vdata in enumerate(val_loader):
        vinputs, vlabels = vdata
        voutputs = model(vinputs)
        vloss = loss_fn(voutputs, vlabels)
        running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
  batch 10 loss: 2885.53681640625
  batch 20 loss: 3011.4302490234377
  batch 30 loss: 2821.4328735351564
  batch 40 loss: 2813.116418457031
  batch 50 loss: 2936.4901611328123
  batch 60 loss: 2956.5807861328126
  batch 70 loss: 3131.52236328125
  batch 80 loss: 2885.5777221679687
  batch 90 loss: 3128.7815185546874
  batch 100 loss: 3214.8813720703124
  batch 110 loss: 2965.0197631835936
  batch 120 loss: 2672.0143310546873
  batch 130 loss: 2772.2175048828126
  batch 140 loss: 3219.1356201171875
  batch 150 loss: 3152.1183837890626
  batch 160 loss: 3108.233447265625
  batch 170 loss: 2928.0587646484373
  batch 180 loss: 3145.9421142578126
  batch 190 loss: 3140.9084716796874
  batch 200 loss: 3266.9286865234376
  batch 210 loss: 2966.866552734375
  batch 220 loss: 2923.515478515625
  batch 230 loss: 2911.730810546875
  batch 240 loss: 2974.380615234375
  batch 250 loss: 3017.3000732421874
  batch 260 loss: 3014.6599853515627
  batch 270 loss: 2814.90546875
  batch 280 loss

KeyboardInterrupt: 

In [None]:
def get_rgb_simple(im):
    if isinstance(im, torch.Tensor): im = im.cpu().numpy()
    mx = im.max(axis=(1,2))
    mi = im.min(axis=(1,2))   
    im = (im - mi[:,None,None])/(mx - mi)[:,None,None]
    im = im.swapaxes(0,2).swapaxes(0,1)
    im = np.clip(im, a_max=1, a_min=0)
    return im

color_map = [
    (0, 0, 0),
    (0.6823529411764706, 0.7803921568627451, 0.9098039215686274),
    (1.0, 0.4980392156862745, 0.054901960784313725),
    (1.0, 0.7333333333333333, 0.47058823529411764),
    (0.17254901960784313, 0.6274509803921569, 0.17254901960784313),
    (0.596078431372549, 0.8745098039215686, 0.5411764705882353),
    (0.8392156862745098, 0.15294117647058825, 0.1568627450980392),
    (1.0, 0.596078431372549, 0.5882352941176471),
    (0.5803921568627451, 0.403921568627451, 0.7411764705882353),
    (0.7725490196078432, 0.6901960784313725, 0.8352941176470589),
    (0.5490196078431373, 0.33725490196078434, 0.29411764705882354),
    (0.7686274509803922, 0.611764705882353, 0.5803921568627451),
    (0.8901960784313725, 0.4666666666666667, 0.7607843137254902),
    (0.9686274509803922, 0.7137254901960784, 0.8235294117647058),
    (0.4980392156862745, 0.4980392156862745, 0.4980392156862745),
    (0.7803921568627451, 0.7803921568627451, 0.7803921568627451),
    (0.7372549019607844, 0.7411764705882353, 0.13333333333333333),
    (0.8588235294117647, 0.8588235294117647, 0.5529411764705883),
    (0.09019607843137255, 0.7450980392156863, 0.8117647058823529),
    (1, 1, 1)
]

In [None]:
import matplotlib.pyplot as plt
import numpy as np
model.train(False)

test_x, test_y = test_loader.__iter__().__next__()

y_hat = model(test_x)
y_hat.shape


In [None]:
test_img = test_x[0, :]
test_img = get_rgb_simple(test_img)

label_img = test_y[0, :].type(torch.int).numpy()

out_img = y_hat[0, :]

In [None]:
from matplotlib.colors import ListedColormap

crop_map = ListedColormap(color_map)
crop_map

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].imshow(test_img)
axs[0].set_title('Test Image')
axs[1].imshow(label_img, cmap=crop_map)
axs[1].set_title('Label Image')
axs[2].imshow(out_img, cmap=crop_map)
axs[2].set_title('Predicted Image')
plt.show()


In [None]:
out_img