In [None]:
import os
import sys
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Hyperparameters and data-layout constants shared by the model, the
# patch-extraction helpers and the training script below.

# Number of input channels of the first convolution layer.
NUM_CHANNELS = 3  # RGB images
# NOTE(review): presumably the maximum value of an 8-bit pixel, meant for
# scaling pixel values to [0, 1] -- confirm.
PIXEL_DEPTH = 255
# Output size of the classifier; labels are one-hot pairs (background, road).
NUM_LABELS = 2
# Highest image index the loaders try to read (satImage_001 .. satImage_020).
TRAINING_SIZE = 20
# Side length, in pixels, of the square patches cut from every image.
IMG_PATCH_SIZE = 16
# Batch size handed to the training DataLoader.
BATCH_SIZE = 16
# Number of passes over the training set.
NUM_EPOCHS = 100
# Arguments forwarded to optim.SGD.
LEARNING_RATE = 0.01
MOMENTUM = 0.0
WEIGHT_DECAY = 5e-4
# NOTE(review): no code in view reads this flag -- confirm whether model
# restoring is implemented elsewhere.
RESTORE_MODEL = False

# Patch classifier: two conv/pool stages followed by a two-layer dense head
class SimpleCNN(nn.Module):
    """Small CNN mapping an image patch to ``NUM_LABELS`` class scores.

    Both convolutions use a 5x5 kernel with padding 2 (spatial size is kept)
    and each is followed by 2x2 max pooling, so an
    ``IMG_PATCH_SIZE`` x ``IMG_PATCH_SIZE`` patch reaches the dense head as
    64 feature maps of side ``IMG_PATCH_SIZE // 4``.
    """

    def __init__(self):
        super().__init__()
        reduced_side = IMG_PATCH_SIZE // 4  # two 2x2 poolings halve the side twice
        self.conv1 = nn.Conv2d(NUM_CHANNELS, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * reduced_side * reduced_side, 512)
        self.fc2 = nn.Linear(512, NUM_LABELS)

    def forward(self, x):
        """Return raw class scores (logits) for a batch of patches.

        Args:
            x: float tensor in channel-first layout,
                ``(batch, NUM_CHANNELS, IMG_PATCH_SIZE, IMG_PATCH_SIZE)``.

        Returns:
            Tensor of shape ``(batch, NUM_LABELS)``; no softmax is applied.
        """
        # The functional API is reached through the already-imported
        # ``torch.nn``: the visible imports define no ``F`` alias, so the
        # previous ``F.relu`` calls raised NameError on the first forward pass.
        relu = nn.functional.relu
        flat_features = 64 * (IMG_PATCH_SIZE // 4) * (IMG_PATCH_SIZE // 4)

        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = x.view(-1, flat_features)
        x = relu(self.fc1(x))
        return self.fc2(x)

# Dataset wrapper exposing NumPy samples and labels as float32 tensors
class CustomDataset(Dataset):
    """Pair a NumPy array of samples with a NumPy array of labels.

    Both arrays are converted to ``torch.float32`` tensors once, at
    construction time; indexing returns ``(sample, label)`` taken at the same
    position along the first axis.
    """

    def __init__(self, data, labels):
        sample_tensor = torch.from_numpy(data).to(torch.float32)
        label_tensor = torch.from_numpy(labels).to(torch.float32)
        self.data = sample_tensor
        self.labels = label_tensor

    def __len__(self):
        # Number of samples = length of the first axis of the data tensor.
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# Cut an image array into a list of w-by-h tiles
def img_crop(im, w, h):
    """Split ``im`` into consecutive tiles and return them as a list.

    The first axis of ``im`` is walked in steps of ``w`` and the second axis
    in steps of ``h``. Tiles are emitted with the second-axis offset as the
    outer loop and the first-axis offset as the inner loop. Arrays with three
    or more axes keep their trailing axes untouched. When an axis length is
    not a multiple of the step, the last tile along that axis is smaller.
    """
    first_axis_len = im.shape[0]
    second_axis_len = im.shape[1]
    two_dimensional = len(im.shape) < 3

    return [
        im[a : a + w, b : b + h] if two_dimensional else im[a : a + w, b : b + h, :]
        for b in range(0, second_axis_len, h)
        for a in range(0, first_axis_len, w)
    ]

# Load the training images from disk and cut them into square patches
def extract_data(filename, num_images):
    """Return the patches of every ``satImage_NNN.png`` found in a directory.

    Args:
        filename: Path of the directory holding the images (despite the
            name, this is a folder, not a single file).
        num_images: Highest image index to try; files ``satImage_001.png``
            through ``satImage_<num_images>.png`` are looked up. Missing
            files are reported on stdout and skipped.

    Returns:
        ``np.ndarray`` stacking all ``IMG_PATCH_SIZE`` x ``IMG_PATCH_SIZE``
        patches, image after image, in the order produced by ``img_crop``.
        Pixel values and the channel-last layout are those produced by
        ``np.array`` on the PIL image; no scaling is applied.

    Raises:
        FileNotFoundError: If none of the expected files exists (previously
            this surfaced as a bare ``IndexError``).
    """
    imgs = []
    for i in range(1, num_images + 1):
        image_filename = os.path.join(filename, "satImage_%.3d" % i + ".png")
        if not os.path.isfile(image_filename):
            print("File " + image_filename + " does not exist")
            continue
        print("Loading " + image_filename)
        # PIL opens files lazily; the context manager closes the handle once
        # the pixels have been copied into the NumPy array.
        with Image.open(image_filename) as img:
            imgs.append(np.array(img))

    if not imgs:
        raise FileNotFoundError(
            "No satImage_NNN.png files found in " + repr(filename)
        )

    # NOTE(review): stacking assumes all patches share one shape, i.e. image
    # sides are multiples of IMG_PATCH_SIZE -- confirm for new datasets.
    return np.asarray(
        [
            patch
            for img in imgs
            for patch in img_crop(img, IMG_PATCH_SIZE, IMG_PATCH_SIZE)
        ]
    )

def extract_labels(filename, num_images):
    """Return one-hot patch labels computed from ground-truth mask images.

    Each mask is cut into ``IMG_PATCH_SIZE`` x ``IMG_PATCH_SIZE`` patches and
    every patch is classified by ``value_to_class`` from its mean intensity.

    Args:
        filename: Path of the directory holding the ground-truth images.
        num_images: Highest image index to try; files ``satImage_001.png``
            through ``satImage_<num_images>.png`` are looked up. Missing
            files are reported on stdout and skipped.

    Returns:
        ``np.float32`` array of shape ``(n_patches, 2)`` with rows
        ``[1, 0]`` (background) or ``[0, 1]`` (road), image after image, in
        the order produced by ``img_crop``.

    Raises:
        FileNotFoundError: If none of the expected files exists.
    """
    gt_imgs = []
    for i in range(1, num_images + 1):
        image_filename = os.path.join(filename, "satImage_%.3d" % i + ".png")
        if not os.path.isfile(image_filename):
            print("File " + image_filename + " does not exist")
            continue
        print("Loading " + image_filename)
        # Close the file handle as soon as the pixels are copied out.
        with Image.open(image_filename) as img:
            gt = np.array(img)
        # PIL yields integer masks in 0..PIXEL_DEPTH, while value_to_class
        # compares the patch mean against a fraction (0.25). Without scaling,
        # a single bright pixel in a patch was enough to label it "road".
        # Boolean and float masks are already in [0, 1] and are left as-is.
        if np.issubdtype(gt.dtype, np.integer):
            gt = gt.astype(np.float32) / PIXEL_DEPTH
        gt_imgs.append(gt)

    if not gt_imgs:
        raise FileNotFoundError(
            "No satImage_NNN.png files found in " + repr(filename)
        )

    labels = [
        value_to_class(np.mean(patch))
        for gt in gt_imgs
        for patch in img_crop(gt, IMG_PATCH_SIZE, IMG_PATCH_SIZE)
    ]
    return np.asarray(labels).astype(np.float32)

# Map a patch statistic v to a one-hot class label
def value_to_class(v):
    """Return ``[0, 1]`` (road) if the sum of ``v`` exceeds 0.25, else ``[1, 0]``.

    ``v`` may be a scalar or an array; arrays are reduced with ``np.sum``
    before the comparison. A value exactly equal to the threshold counts as
    background.
    """
    foreground_threshold = 0.25
    total = np.sum(v)
    is_road = total > foreground_threshold
    return [0, 1] if is_road else [1, 0]  # road / bgrd

# Mini-batch training loop
def train(model, train_loader, criterion, optimizer, num_epochs, device=None):
    """Optimise ``model`` for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: The ``nn.Module`` to train; updated in place.
        train_loader: Iterable yielding ``(data, labels)`` tensor pairs.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-free model), so the
            function no longer relies on a global ``device`` that exists only
            when the file is run as a script.

    Returns:
        None. After every epoch one line is printed with the mean of the
        per-batch losses of that epoch (``nan`` if the loader was empty).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        batch_losses = []
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        # Average over the epoch rather than echoing only the last batch;
        # this also keeps an empty loader from hitting an unbound ``loss``.
        if batch_losses:
            epoch_loss = sum(batch_losses) / len(batch_losses)
        else:
            epoch_loss = float("nan")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

# Script entry point: load the patches, balance the classes, train, save.
if __name__ == "__main__":
    data_dir = "training/"
    train_data_filename = os.path.join(data_dir, "images/")
    train_labels_filename = os.path.join(data_dir, "groundtruth/")

    train_data = extract_data(train_data_filename, TRAINING_SIZE)
    train_labels = extract_labels(train_labels_filename, TRAINING_SIZE)

    # The two loaders skip missing files independently, so a file present in
    # only one folder would silently pair patches with the wrong labels.
    if len(train_data) != len(train_labels):
        raise ValueError(
            f"Got {len(train_data)} image patches but {len(train_labels)} labels"
        )

    # Balancing training data: keep the first min(c0, c1) patches per class.
    idx0 = np.where(train_labels[:, 0] == 1)[0]
    idx1 = np.where(train_labels[:, 1] == 1)[0]
    c0, c1 = len(idx0), len(idx1)
    print(f"Number of data points per class: c0 = {c0}, c1 = {c1}")

    min_c = min(c0, c1)
    new_indices = np.concatenate((idx0[:min_c], idx1[:min_c]))

    train_data = train_data[new_indices, :, :, :]
    train_labels = train_labels[new_indices]

    # Patches come out channel-last (N, H, W, C) but nn.Conv2d expects
    # channel-first (N, C, H, W); without this the first forward pass fails
    # with a channel-count mismatch.
    train_data = np.ascontiguousarray(np.transpose(train_data, (0, 3, 1, 2)))
    # NOTE(review): pixel values reach the network unscaled; confirm whether
    # they should be divided by PIXEL_DEPTH before training.

    train_dataset = CustomDataset(train_data, train_labels)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SimpleCNN().to(device)
    # Labels are one-hot float vectors, which CrossEntropyLoss accepts as
    # class-probability targets.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY
    )

    train(model, train_loader, criterion, optimizer, NUM_EPOCHS)

    # Save the model
    torch.save(model.state_dict(), "pytorch_model.pth")

    print("Training complete.")
