In [None]:
# Mount Google Drive so the dataset archives under MyDrive become readable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tifffile
import torch
import gdown
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from tqdm.auto import tqdm
import statistics
import importlib
import shutil

In [None]:
import torch
import numpy as np
import tifffile
from torch import nn
from torchvision.models import resnet18

def calculate_centroids(image_label: np.ndarray, grid_size: int = 8) -> np.ndarray:
    """Compute the foreground centroid of each cell of a ``grid_size`` x ``grid_size`` grid.

    Args:
        image_label: 2-D mask indexed as ``[row, col]``; pixels ``> 0`` are foreground.
        grid_size: number of cells along each image axis (default 8).

    Returns:
        float32 array of shape ``(grid_size, grid_size, 3)``. Entry ``[i, j]``
        describes the cell in grid row ``i`` / grid column ``j``:
        ``[x, y, 1]`` where ``x`` (column) and ``y`` (row) are the mean position of
        the cell's foreground pixels in full-image coordinates, or ``[-1, -1, 0]``
        when the cell contains no foreground.

    Pixels beyond the last whole cell (when a side is not divisible by
    ``grid_size``) are not assigned to any cell.
    """
    # numpy shape is (rows, cols) == (height, width). Unpacking it the other way
    # round swaps the cell sizes, which is only harmless for square masks.
    height, width = image_label.shape
    cell_w = width // grid_size
    cell_h = height // grid_size

    result = np.full((grid_size, grid_size, 3), [-1, -1, 0], dtype=np.float32)

    for i in range(grid_size):
        y_start = i * cell_h
        for j in range(grid_size):
            x_start = j * cell_w
            cell = image_label[y_start:y_start + cell_h, x_start:x_start + cell_w]

            rows, cols = np.nonzero(cell > 0)
            if rows.size:
                # Shift the local centroid back into full-image coordinates.
                result[i, j] = [cols.mean() + x_start, rows.mean() + y_start, 1]

    return result

class SidewalkPrompter(nn.Module):
    """ResNet-18 that predicts one ``[x, y, flag]`` triple per grid cell.

    ``forward`` returns a tensor of shape ``(batch, grid_size, grid_size, 3)``.
    ``x`` and ``y`` are raw (unbounded) regression outputs; ``flag`` is passed
    through a sigmoid, so it lies in ``(0, 1)``.

    Args:
        grid_size: cells per side of the output grid (default 8, which gives the
            original 192-unit head).
    """

    def __init__(self, grid_size: int = 8):
        super().__init__()
        self.grid_size = grid_size
        self.resnet = resnet18()
        # Replace the classification head with one output per (cell, channel).
        self.resnet.fc = nn.Linear(
            in_features=self.resnet.fc.in_features,
            out_features=grid_size * grid_size * 3,
            bias=True,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.resnet(x)
        out = out.view(out.size(0), self.grid_size, self.grid_size, 3)
        # Build the result out-of-place rather than writing the sigmoid back into
        # a slice of the network output. This avoids in-place mutation of a
        # tensor that autograd may need.
        coords = out[..., :2]
        flag = self.sigmoid(out[..., 2:])
        return torch.cat((coords, flag), dim=-1)
    
class LossFn(nn.Module):
    r"""Loss function for SidewalkPrompter.

    Details:
    - ``pred`` and ``target`` have ``[x, y, flag]`` in their last dimension,
      e.g. ``(batch, 8, 8, 3)`` as produced by SidewalkPrompter.
    - flag is 1 if there is a centroid in the grid cell, 0 otherwise.
    - loss = MSE([x_hat, y_hat], [x, y]) (only where flag == 1) + \lambda * BCE(flag_hat, flag)
    - \lambda is ``bce_weight`` (default 5).
    - Both terms use ``reduction='sum'``, so the loss grows with batch size.

    Args:
        bce_weight: weight \lambda applied to the BCE term.
    """

    def __init__(self, bce_weight: float = 5.0):
        super().__init__()
        self.bce_weight = bce_weight
        self.mse = nn.MSELoss(reduction='sum')
        self.bce = nn.BCELoss(reduction='sum')

    def forward(self, pred, target):
        x, y, flag = target[..., 0], target[..., 1], target[..., 2]
        x_hat, y_hat, flag_hat = pred[..., 0], pred[..., 1], pred[..., 2]

        # Multiplying both sides by flag zeroes empty cells, so their (-1, -1)
        # placeholder targets contribute nothing to the coordinate loss.
        mse_loss = self.mse(x_hat * flag, x * flag) + self.mse(y_hat * flag, y * flag)

        # nn.BCELoss expects flag_hat to already be a probability in [0, 1].
        bce_loss = self.bce(flag_hat, flag)

        return mse_loss + self.bce_weight * bce_loss

In [None]:
# Compressed (.tar.gz) dataset archives on the mounted Google Drive.
SIDEWALKS_DRIVE_DIR = '/content/drive/MyDrive/sidewalks'
label_zip_path, train_zip_path, val_zip_path = (
    os.path.join(SIDEWALKS_DRIVE_DIR, f'{split}.tar.gz') for split in ('label', 'train', 'val')
)

In [None]:
# Extract the dataset archives into ./data. Each archive is skipped when its
# target folder already exists, so re-running this cell is cheap.
data_path = os.path.join('.', 'data')
os.makedirs(data_path, exist_ok=True)

train_path = os.path.join(data_path, 'Train')
label_path = os.path.join(data_path, 'Label')
val_path = os.path.join(data_path, 'Test')

# shutil (pure Python) instead of `!tar` / `!rm`: no shell interpolation, and
# the cell also runs outside IPython. The format is inferred from `.tar.gz`.
if not os.path.exists(train_path):
    shutil.unpack_archive(train_zip_path, data_path)
if not os.path.exists(label_path):
    shutil.unpack_archive(label_zip_path, data_path)
    # Drop the Test2 labels if present; only Train and Test are used below.
    shutil.rmtree(os.path.join(label_path, 'Test2'), ignore_errors=True)
if not os.path.exists(val_path):
    shutil.unpack_archive(val_zip_path, data_path)

train_label_path = os.path.join(label_path, 'Train')
val_label_path = os.path.join(label_path, 'Test')

In [None]:
def load_split(image_dir: str, label_dir: str) -> tuple:
    """Load every tile of one split whose mask contains foreground.

    Args:
        image_dir: directory with the image ``.tif`` tiles.
        label_dir: directory with the mask ``.tif`` tiles (same file names).

    Returns:
        ``(files, images, centroids)``. ``files`` is a list of kept file names
        in sorted order. ``images`` and ``centroids`` are numpy arrays aligned
        with it; centroids come from ``calculate_centroids`` on each mask.
    """
    files, images, centroids = [], [], []
    # Sort so the sample order (and thus the results) does not depend on the
    # filesystem's listing order.
    for file_name in sorted(os.listdir(image_dir)):
        if not file_name.endswith('.tif'):
            continue
        ground_truth = tifffile.imread(os.path.join(label_dir, file_name))
        # Skip all-background tiles: they have no centroid to learn.
        if np.max(ground_truth) > 0:
            files.append(file_name)
            images.append(tifffile.imread(os.path.join(image_dir, file_name)))
            centroids.append(calculate_centroids(ground_truth))
    return files, np.array(images), np.array(centroids)


train_files, train_imgs, train_centroids = load_split(train_path, train_label_path)
val_files, val_imgs, val_centroids = load_split(val_path, val_label_path)

In [None]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class SidewalkDataset(Dataset):
    """Dataset of (image, centroid grid) pairs for a single split.

    Each item is a dict with keys ``'image'`` (float tensor, channels moved to
    the front), ``'ground_truth'`` (float tensor from ``calculate_centroids``)
    and ``'file_name'``.

    Args:
        data_path: directory containing the image ``.tif`` tiles.
        label_path: directory containing the mask ``.tif`` tiles (same names).
        files: file names to serve, in index order.
        transform: optional callable applied to the image tensor. The centroids
            are NOT transformed, so it must not move pixels spatially.
        images: optional pre-loaded images aligned with ``files``. When omitted,
            each image is read from ``data_path`` on access.
        centroids: optional pre-computed centroid grids aligned with ``files``.
            When omitted, they are computed from the mask in ``label_path`` on access.

    Raises:
        ValueError: if ``images`` or ``centroids`` is given but not aligned with ``files``.
    """

    def __init__(self, data_path: str, label_path: str, files: list, transform=None,
                 images=None, centroids=None):
        self.data_path = data_path
        self.label_path = label_path
        self.files = files
        self.transform = transform
        if images is not None and len(images) != len(files):
            raise ValueError('images must have one entry per file')
        if centroids is not None and len(centroids) != len(files):
            raise ValueError('centroids must have one entry per file')
        self.images = images
        self.centroids = centroids

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_name = self.files[idx]

        # Read this split's own data. The previous version indexed the global
        # training arrays, so the validation set silently returned training samples.
        if self.images is not None:
            img = self.images[idx]
        else:
            img = tifffile.imread(os.path.join(self.data_path, file_name))
        if self.centroids is not None:
            ground_truth = self.centroids[idx]
        else:
            ground_truth = calculate_centroids(
                tifffile.imread(os.path.join(self.label_path, file_name)))

        # Move the last axis to the front. This assumes tiles are stored
        # channels-last (H, W, C) — TODO confirm against the data.
        image = torch.tensor(np.moveaxis(img, -1, 0)).float()
        if self.transform is not None:
            image = self.transform(image)

        return {'image': image, 'ground_truth': torch.tensor(ground_truth).float(), 'file_name': file_name}

In [None]:
# One Dataset per split; only the training loader shuffles.
train_dataset = SidewalkDataset(train_path, train_label_path, train_files)
val_dataset = SidewalkDataset(val_path, val_label_path, val_files)

BATCH_SIZE = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Seed torch's global RNG so that, on a fresh run, weight initialisation and
# later DataLoader shuffling are reproducible.
torch.manual_seed(0)

model = SidewalkPrompter().to(device)
optimizer = Adam(model.parameters(), lr=1e-4)
# Keep the loss on the same device as the model; the trailing `.cpu()` that
# moved it straight back to the CPU is gone.
loss = LossFn().to(device)

In [None]:
# Checkpoints go to a `models` folder one level above the working directory.
# NOTE(review): on Colab the working directory is usually /content, so this
# resolves to /models, which is outside Google Drive and is lost when the
# runtime resets — confirm this is intended.
result_path = os.path.join(os.pardir, 'models')
os.makedirs(name=result_path, exist_ok=True)

In [None]:
num_epochs = 100
checkpoint_every = 10  # epochs between intermediate checkpoints

model.train()
for epoch in range(num_epochs):
    epoch_loss = []
    for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        img, ground_truth = batch['image'].to(device), batch['ground_truth'].to(device)
        optimizer.zero_grad()
        pred = model(img)
        batch_loss = loss(pred, ground_truth)
        batch_loss.backward()
        optimizer.step()
        epoch_loss.append(batch_loss.item())
    # Mean of per-batch summed losses (LossFn uses reduction='sum').
    print(f'Epoch {epoch+1} loss: {statistics.mean(epoch_loss)}')

    if (epoch + 1) % checkpoint_every == 0:
        # torch.save has no `map_location` argument (that belongs to torch.load);
        # passing it raised a TypeError at the first checkpoint.
        torch.save(model.state_dict(), os.path.join(result_path, f'sidewalk_prompter_epoch_{epoch+1:04d}.pt'))

In [None]:
# Save the final weights. torch.save takes no `map_location` (that is a
# torch.load argument); passing it raised a TypeError.
torch.save(model.state_dict(), os.path.join(result_path, 'sidewalk_prompter.pth'))

In [None]:
# Evaluate on the validation split. No gradients are needed, so the whole
# pass runs under a single no_grad context.
model.eval()
val_loss = []
with torch.no_grad():
    for batch in tqdm(val_loader):
        img = batch['image'].to(device)
        ground_truth = batch['ground_truth'].to(device)
        pred = model(img)
        l = loss(pred, ground_truth)
        val_loss.append(l.item())
print(f'Validation loss: {statistics.mean(val_loss)}')