In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tifffile
import torch
import gdown
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from tqdm.auto import tqdm
import statistics
import importlib


In [2]:
import SidewalkPrompter
importlib.reload(SidewalkPrompter)
from SidewalkPrompter import *

In [3]:
# Google Drive share links for the data used by this notebook.
# The download cell passes each of these to gdown.download(..., fuzzy=True).
# label/train/val are saved as .tar.gz archives and unpacked there;
# the two DVRPC links are saved as .json files.
label_url = 'https://drive.google.com/file/d/1T8RDNBtxuBidm9ttNW9ShauDB49dBjWH/view?usp=drive_link'
train_url = 'https://drive.google.com/file/d/1De6cOV0UtS310-vkILWpmY7hiJZRSU9Y/view?usp=drive_link'
val_url = 'https://drive.google.com/file/d/1MFLm_5c0G6CUGNx2o2wrwGAKZHvUBCTI/view?usp=drive_link'
DVRPC_train_url = 'https://drive.google.com/file/d/1pHzGmjQUvrH1TY4XL1vw8xg72u8K5BuI/view?usp=drive_link'
DVRPC_val_url = 'https://drive.google.com/file/d/1YC5oUmGDa0sO14Qc4d-PM8cn2dbU1BKK/view?usp=drive_link'

In [4]:
# Download the archives / JSON files into ../data and unpack the archives.
# Each download is skipped when its target already exists, so re-running this
# cell does not fetch the data again.
data_path = os.path.join('..', 'data')
os.makedirs(data_path, exist_ok=True)
# Local file names the downloads are saved under.
label_path = os.path.join(data_path, 'label.tar.gz')
train_path = os.path.join(data_path,'train.tar.gz')
val_path = os.path.join(data_path,'val.tar.gz')
DVRPC_train_path = os.path.join(data_path,'DVRPC_train.json')
DVRPC_val_path = os.path.join(data_path,'DVRPC_val.json')

# Training images: the code assumes the archive unpacks to ../data/Train
# (that directory is what the existence check looks for).
train_path_new = os.path.join(data_path, 'Train')
if not os.path.exists(train_path_new):
    gdown.download(train_url, train_path, fuzzy=True)
    !tar -xzf {train_path} -C {data_path}
    # Remove the archive once it has been extracted.
    !rm -rf {train_path}
# From here on train_path names the extracted directory, not the archive.
train_path = train_path_new
# Labels: the code assumes the archive unpacks to ../data/Label.
label_path_new = os.path.join(data_path, 'Label')
if not os.path.exists(label_path_new):
    gdown.download(label_url, label_path, fuzzy=True)
    !tar -xzf {label_path} -C {data_path}
    # The download is large, so delete the archive after extraction; the
    # extracted 'Test2' sub-directory of Label is deleted as well.
    !rm -rf {label_path} {os.path.join(label_path_new, 'Test2')}
# From here on label_path names the extracted directory, not the archive.
label_path = label_path_new
# Validation images: the code assumes the archive unpacks to ../data/Test.
val_path_new = os.path.join(data_path, 'Test')
if not os.path.exists(val_path_new):
    gdown.download(val_url, val_path, fuzzy=True)
    !tar -xzf {val_path} -C {data_path}
    # Remove the archive once it has been extracted.
    !rm -rf {val_path}
# From here on val_path names the extracted directory, not the archive.
val_path = val_path_new
# The DVRPC JSON files are downloaded as-is (no extraction step).
if not os.path.exists(DVRPC_train_path):
    gdown.download(DVRPC_train_url, DVRPC_train_path, fuzzy=True)
if not os.path.exists(DVRPC_val_path):
    gdown.download(DVRPC_val_url, DVRPC_val_path, fuzzy=True)

# Label sub-directories paired with the Train and Test image directories.
train_label_path = os.path.join(label_path, 'Train')
val_label_path = os.path.join(label_path, 'Test')

In [5]:
def _tiles_with_labels(image_dir, label_dir):
    """List the '.tif' tiles in ``image_dir`` whose label raster is not all-zero.

    For every file name ending in '.tif' in ``image_dir``, the file of the same
    name is read from ``label_dir`` with ``tifffile``; the name is kept only when
    the maximum value of that label is greater than zero.

    Args:
        image_dir: Directory whose entries are candidate image tiles.
        label_dir: Directory holding a label file for each '.tif' tile
            (same file name as the image).

    Returns:
        File names (not full paths) in sorted order. ``os.listdir`` gives no
        ordering guarantee, so sorting keeps dataset indices and the unshuffled
        validation order reproducible across runs and machines.
    """
    return [
        name
        for name in sorted(os.listdir(image_dir))
        if name.endswith('.tif')
        and np.max(tifffile.imread(os.path.join(label_dir, name))) > 0
    ]


# Note: this reads every label file once, so it can take a while on large datasets.
train_files = _tiles_with_labels(train_path, train_label_path)
val_files = _tiles_with_labels(val_path, val_label_path)

In [6]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# (For CPU-only debugging, override with: device = torch.device('cpu'))
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [7]:
class SidewalkDataset(Dataset):
    """Dataset that pairs each image file with the label file of the same name.

    Both files are read with ``tifffile``; the image comes from ``data_path``
    and the label from ``label_path``.
    """

    def __init__(self, data_path: str, label_path: str, files: list, transform=None):
        """Remember where the data lives and which files to serve.

        Args:
            data_path: Directory containing the image files.
            label_path: Directory containing the label files.
            files: File names, one per sample, present in both directories.
            transform: Optional callable invoked as ``transform(img, label)``
                and expected to return the ``(img, label)`` pair to use.
        """
        self.data_path, self.label_path = data_path, label_path
        self.files, self.transform = files, transform

    def __len__(self):
        """Return the number of samples, i.e. the number of file names."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A dict with three entries:
              * 'image': float tensor of the image with its last axis moved first;
              * 'label': the label as read from disk (or as returned by the transform);
              * 'ground_truth': float tensor built from ``calculate_centroids(label)``.
        """
        file_name = self.files[idx]
        image = tifffile.imread(os.path.join(self.data_path, file_name))
        mask = tifffile.imread(os.path.join(self.label_path, file_name))
        # Move the last axis to the front
        # (presumably channels-last -> channels-first; confirm against the data).
        image = np.moveaxis(image, -1, 0)
        if self.transform:
            image, mask = self.transform(image, mask)

        centroids = calculate_centroids(mask)
        return {
            'image': torch.tensor(image).float(),
            'label': mask,
            'ground_truth': torch.tensor(centroids).float(),
        }

In [8]:
# Wrap the filtered file lists in datasets; no transform is passed.
train_dataset = SidewalkDataset(
    data_path=train_path, label_path=train_label_path, files=train_files
)
val_dataset = SidewalkDataset(
    data_path=val_path, label_path=val_label_path, files=val_files
)

# Batches of 32 samples; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=32)

In [12]:
# Build the model on the selected device and optimise all of its parameters with Adam.
model = SidewalkPrompter().to(device)
optimizer = Adam(model.parameters(), lr=1e-4)
# Keep the loss module on the same device as the model. The previous trailing
# .cpu() moved it straight back to the CPU, which would put any tensors held by
# LossFn on a different device than CUDA predictions.
loss = LossFn().to(device)

In [13]:
# Directory (next to the notebook's parent folder) where trained weights are saved.
result_path = os.path.join(os.pardir, 'models')
os.makedirs(name=result_path, exist_ok=True)

In [14]:
num_epochs = 100

# Training mode is set once before the loop.
model.train()
for epoch in range(num_epochs):
    # Per-batch loss values of the current epoch, averaged for the printout below.
    epoch_loss = []
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        img = batch['image'].to(device)
        ground_truth = batch['ground_truth'].to(device)
        pred = model(img)
        batch_loss = loss(pred, ground_truth)
        batch_loss.backward()
        optimizer.step()
        epoch_loss.append(batch_loss.item())
    print(f'Epoch {epoch}: {statistics.mean(epoch_loss)}')

  0%|          | 0/2583 [00:00<?, ?it/s]

Epoch 0: 6088073.529169086


  0%|          | 0/2583 [00:00<?, ?it/s]

Epoch 1: 101827.426511687


  0%|          | 0/2583 [00:00<?, ?it/s]

Epoch 2: 81155.68583042973


  0%|          | 0/2583 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Persist only the model's state dict under the results directory.
checkpoint_path = os.path.join(result_path, 'sidewalk_prompter.pth')
torch.save(model.state_dict(), checkpoint_path)

In [15]:
# Evaluate on the validation set: eval mode, no gradient tracking,
# then report the mean of the per-batch losses.
model.eval()
val_loss = []
with torch.no_grad():
    for batch in tqdm(val_loader):
        img = batch['image'].to(device)
        ground_truth = batch['ground_truth'].to(device)
        pred = model(img)
        batch_loss = loss(pred, ground_truth)
        val_loss.append(batch_loss.item())
print(f'Validation loss: {statistics.mean(val_loss)}')

  0%|          | 0/294 [00:00<?, ?it/s]

Validation loss: 72953.86735358206


In [None]:
def _points_above(points, threshold):
    """Return the first two columns of every row whose third value exceeds ``threshold``.

    ``points`` is flattened to rows over its last axis first, so any number of
    leading axes (batch, grid, ...) is accepted.
    """
    rows = points.reshape(-1, points.shape[-1])
    return rows[rows[:, 2] > threshold, :2]


# Draw ONE random validation sample. The sample must come from this
# batch_size=1 loader: taking it from val_loader (batch_size=32, unshuffled)
# always yields the same first batch and makes squeeze(axis=0) below fail.
random_val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True)
sample = next(iter(random_val_dataloader))
model.eval()
with torch.no_grad():
    img, ground_truth = sample['image'].to(device), sample['ground_truth'].to(device)
    pred = model(img)

# Move both tensors to the CPU before converting to NumPy (the tensors may be
# on CUDA) and drop the batch axis of size 1 from the prediction.
pred = pred.cpu().numpy().squeeze(axis=0)
ground_truth_np = ground_truth.cpu().numpy()

# The last axis is read as (x, y, score): ground-truth entries count when their
# score is > 0, predictions when their score is > 0.5.
# NOTE(review): layout inferred from the indexing used here; confirm against
# calculate_centroids and the model output.
points_ground_truth = _points_above(ground_truth_np, 0)
points_pred = _points_above(pred, 0.5)

label_img = sample['label'].squeeze()
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
# NOTE(review): imshow clips float RGB data to [0, 1]; confirm the image value range.
ax[0].imshow(img.cpu().numpy().squeeze(axis=0).transpose(1, 2, 0))
ax[0].set_title('Input Image')
ax[1].imshow(label_img, cmap='gray')
ax[1].scatter(points_ground_truth[:, 0], points_ground_truth[:, 1], c='r', s=1)
ax[1].set_title('Ground Truth')
ax[2].imshow(label_img, cmap='gray')
ax[2].scatter(points_pred[:, 0], points_pred[:, 1], c='r', s=1)
ax[2].set_title('Prediction')
plt.show()