In [None]:
import datetime
import os
import random

import albumentations as alb
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorboard
#import tensorflow as tf
import torch
#from torchsummary import summary
from PIL import Image
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from model import UNet
from utils import plot

In [None]:
# Define constants

# directories
data_dir = 'data'  # change to the directory containing metadata.csv and the images
train_dir = 'train'  # NOTE(review): not referenced by any cell shown here — confirm it is still needed
log_dir = 'runs'  # root directory for TensorBoard run folders

# Reproducibility: seed every global RNG this notebook draws from — the pandas
# `.sample()` call (which falls back to NumPy's global RNG when no random_state
# is given) and the untrained model's weight init. NOTE(review): albumentations
# 1.x draws crop offsets / flip decisions from `random` and NumPy; newer
# releases may need `Compose(seed=...)` instead — confirm for the installed version.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Read the dataset index, show it, then pick one random training row to inspect.
metadata = pd.read_csv(os.path.join(data_dir, 'metadata.csv'))
print(metadata)

is_train = metadata['split'] == 'train'
sample = metadata.loc[is_train].sample(n=1)

In [None]:
# Plot random sample
# `plot` is the project-local helper from utils; NOTE(review): presumably it
# loads the image/mask files referenced by the row's paths (relative to
# data_dir) and displays them — confirm in utils.py.

plot(sample, data_dir)

In [None]:
# Test albumentations: augment one image/mask pair and check they stay aligned

sample_path = os.path.join(data_dir, sample['sat_image_path'].iloc[0])
sample_mask_path = os.path.join(data_dir, sample['mask_path'].iloc[0])


def read_rgb(path):
    """Read an image file with OpenCV and return it as an RGB array.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise surface
            later as an opaque ``cv2.cvtColor`` error.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"cv2.imread could not read {path!r}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # OpenCV loads channels as BGR


# Define augmentation pipeline.
# Compose applies spatial transforms (crop, flip) identically to the `image`
# and `mask` targets out of the box, and pixel-level transforms such as
# RandomBrightnessContrast touch only `image` — so no additional_targets are
# needed to keep the pair aligned.
crop_size = 256
transform = alb.Compose([
    alb.RandomCrop(width=crop_size, height=crop_size),
    alb.HorizontalFlip(p=0.5),
    alb.RandomBrightnessContrast(p=0.2),
])

image = read_rgb(sample_path)
# The mask is read as a 3-channel RGB image as well. NOTE(review): presumably
# classes are color-coded; it will need mapping to class indices for training.
image_mask = read_rgb(sample_mask_path)

# Augment the pair with one shared random draw
transformed = transform(image=image, mask=image_mask)
transformed_image = transformed['image']
transformed_image_mask = transformed['mask']

In [None]:
# Show the augmented image next to its augmented mask to check they line up.
fig, (ax_image, ax_mask) = plt.subplots(1, 2)
ax_image.imshow(transformed_image)
ax_image.set_title('Image')
ax_mask.imshow(transformed_image_mask)
ax_mask.set_title('Mask');

In [None]:
# TensorBoard logging: every execution of this cell opens a writer on a fresh,
# timestamp-named subdirectory of log_dir, so separate runs don't share a folder.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_log_dir = os.path.join(log_dir, current_time)
writer = SummaryWriter(log_dir=run_log_dir)

# Example usage: writer.add_scalar(tag, scalar_value, global_step)

In [None]:
# Training

# configure hyperparameters
epochs = 50

# init data loader/generator
# TODO: build a Dataset/DataLoader over the 'train' split of `metadata`
dataloader = None

# init model, optimizer
# NOTE(review): UNet's signature lives in model.py; judging by the smoke test
# below, 3 looks like the input-channel count and 5 the number of classes.
model = UNet(3, 16, 256, 5)
opt = None  # TODO: e.g. torch.optim.Adam(model.parameters())
loss_func = None  # TODO: e.g. nn.CrossEntropyLoss() on class-index masks

# Smoke test: one forward pass of the (still untrained) model on the augmented
# sample from the albumentations cell. That array is HWC; torch conv layers
# want channels first. permute(2, 0, 1) gives (C, H, W), whereas `.T` reverses
# *all* dims to (C, W, H) — transposing the image — and is deprecated for
# tensors that are not 2-D. A separate name keeps the metadata row `sample`
# intact so earlier cells can be re-run.
sample_tensor = torch.tensor(transformed_image, dtype=torch.float).permute(2, 0, 1)
# NOTE(review): pixel values are still raw 0-255; consider scaling to [0, 1]
# once the training preprocessing is decided.

model.eval()
with torch.no_grad():  # inference only: skip building an autograd graph
    logits = model(sample_tensor)  # unbatched input -> class scores along dim 0

pred = logits.argmax(dim=0).numpy()  # per-pixel class index, (H, W)

print(pred.shape)

fig, (ax_pred, ax_mask) = plt.subplots(1, 2)
ax_pred.imshow(pred)
ax_pred.set_title('Prediction')
ax_mask.imshow(transformed_image_mask)
ax_mask.set_title('Mask');

# for epoch in range(epochs):
#     model.train()  # the smoke test above left the model in eval mode

#     for x, y in dataloader:
#         pred = model(x)
#         loss = loss_func(pred, y)

#         loss.backward()
#         opt.step()
#         opt.zero_grad()
