In [None]:
# Imports
import sys
sys.path.append('../')
from config import *

In [None]:
# VGG16 Model Architecture

class vgg16_conv_block(nn.Module):
    """One VGG unit: 3x3 conv -> BatchNorm -> ReLU, optionally followed by dropout.

    The convolution uses stride 1 and padding 1, so height and width are preserved;
    only the channel count changes from ``input_channels`` to ``out_channels``.

    Args:
        input_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        dropout_rate: probability used by the trailing ``nn.Dropout``.
        drop: when False the dropout layer is kept (for state_dict parity) but skipped.
    """

    def __init__(self, input_channels, out_channels, dropout_rate=0.3, drop=True):
        super().__init__()
        # Registration order (conv, bn, relu, dropout) fixes parameter init order and state_dict keys.
        self.conv = nn.Conv2d(input_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)  # negatives -> 0, computed in place
        self.dropout = nn.Dropout(dropout_rate)
        self.drop = drop

    def forward(self, x):
        activated = self.relu(self.bn(self.conv(x)))
        return self.dropout(activated) if self.drop else activated

def vgg16_layer(input_channels, out_channels, num_blocks, dropout_rate=0.3):
    """Build one VGG stage: ``num_blocks`` conv blocks followed by a 2x2 max-pool.

    The first block maps ``input_channels`` -> ``out_channels``; every later block keeps
    ``out_channels``. The final ``MaxPool2d(2, 2)`` halves height and width.

    Returns:
        nn.Sequential of the conv blocks and the pooling layer.
    """
    conv_blocks = [
        vgg16_conv_block(input_channels if k == 0 else out_channels, out_channels, dropout_rate)
        for k in range(num_blocks)
    ]
    return nn.Sequential(*conv_blocks, nn.MaxPool2d(2, 2))

class VGG16(nn.Module):
    """VGG-style network: stacked conv stages followed by a fully connected head.

    Every entry of ``convlayers`` becomes one stage of two ``vgg16_conv_block``s plus a
    2x2 max-pool (built by ``vgg16_layer``). An ``input_size`` x ``input_size`` input is
    therefore reduced to ``input_size // 2**len(convlayers)`` pixels per side. The head is
    Dropout -> Flatten -> (Linear -> ReLU) ... -> Linear and returns raw logits.

    Args:
        input_channels: number of input channels (bands).
        num_classes: length of the output logit vector.
        dropout_rates: (dropout inside conv blocks, dropout before the head).
        convlayers: output channels of each conv stage (list or tuple).
        netlayers: hidden widths of the fully connected head (list or tuple).
        input_size: square input side length in pixels; defaults to ``context_width``.

    Raises:
        ValueError: if ``input_size`` is too small for the number of pooling stages.

    Note:
        The ReLU layers between hidden Linear layers shift the ``net.*`` indices. State
        dicts saved from a head without activations will therefore not load into this module.
    """

    def __init__(self, input_channels = 3, num_classes=10, dropout_rates=(0.3, 0.4), convlayers = (16, 32, 64, 128, 256, 512, 512), netlayers = (4096, 4096), input_size = context_width):
        super(VGG16, self).__init__()

        # Tuple defaults avoid shared mutable defaults; list() accepts lists or tuples.
        channels = [input_channels] + list(convlayers)
        n_stages = len(channels) - 1

        final_size = input_size // (2 ** n_stages)  # every stage halves H and W
        if final_size < 1:
            raise ValueError(
                f'input_size={input_size} is too small for {n_stages} pooling stages'
            )

        self.conv = nn.Sequential(
            *(vgg16_layer(channels[i], channels[i + 1], 2, dropout_rates[0]) for i in range(n_stages))
        )

        # flattened feature size -> hidden widths -> num_classes logits
        widths = [channels[-1] * final_size * final_size] + list(netlayers) + [num_classes]
        head = [nn.Dropout(dropout_rates[1]), nn.Flatten()]
        for k in range(len(widths) - 1):
            head.append(nn.Linear(widths[k], widths[k + 1], bias=True))
            # Without a non-linearity, consecutive Linear layers collapse into a single
            # linear map. No activation follows the final (logit) layer.
            if k < len(widths) - 2:
                head.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*head)

    def forward(self, x):
        """Return logits of shape (batch, num_classes)."""
        return self.net(self.conv(x))

In [None]:
# Load Dataset
# Read each training tile from pathX and the label raster with the same file name from pathY.
# X[i] is an (H, W, bands) array. y[i] is a flat int8 one-hot vector of length
# decision_width * decision_width * 7: 7 consecutive entries per cell, with cells in
# flattened (row-major) order.

n_classes = 7

X = []
y = []

distribution = np.zeros(n_classes, dtype=np.int64)  # per-class cell counts, normalised below

pathX = '../data/dataset/X'
pathY = '../data/dataset/y'
# sorted(): os.listdir order is filesystem dependent. Sorting keeps the sample order,
# and hence the train/test split, identical across machines.
files = sorted(f for f in os.listdir(pathX) if f.endswith('.tif'))
if not files:
    raise FileNotFoundError(f'No .tif tiles found in {pathX}')

n_cells = decision_width * decision_width
one_hot_table = np.eye(n_classes, dtype=np.int8)  # row k is the one-hot code of class k

for file in files:
    # `with` closes every raster handle after reading.
    with rio.open(os.path.join(pathX, file)) as src:
        X.append(np.moveaxis(src.read(), 0, -1))  # (bands, H, W) -> (H, W, bands)
    with rio.open(os.path.join(pathY, file)) as src:
        cell_labels = np.moveaxis(src.read(), 0, -1).flatten()

    # A short label raster would otherwise produce all-zero one-hot cells (silently class 0).
    if cell_labels.size != n_cells:
        raise ValueError(f'{file}: expected {n_cells} label cells, got {cell_labels.size}')
    if cell_labels.min() < 0 or cell_labels.max() >= n_classes:
        raise ValueError(f'{file}: class ids must lie in [0, {n_classes - 1}]')

    distribution += np.bincount(cell_labels, minlength=n_classes)
    y.append(one_hot_table[cell_labels].reshape(-1))

distribution = distribution / distribution.sum()

# Report shapes without copying the whole dataset into one array.
print('X shape:', (len(X),) + X[0].shape)
print('y shape:', (len(y),) + y[0].shape)

In [None]:
# Training Parameters

batch_size = 32

# Input pipeline: ToTensor only (HWC numpy -> CHW tensor; uint8 arrays are also
# rescaled to [0, 1]). No augmentation is applied.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# 8 input bands in. Out: 7 class logits for each of the decision_width x decision_width cells.
model = VGG16(
    input_channels=8,
    num_classes=decision_width * decision_width * 7,
    convlayers=convlayers,
    netlayers=netlayers,
)
model = model.to(device)

# Adam with learning rate 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

class CustomDataset(Dataset):
    """Map-style dataset pairing inputs ``X[i]`` with targets ``y[i]``.

    The optional ``transform`` is applied to the input only (when truthy); the
    target is returned unchanged.
    """

    def __init__(self, X, y, transform=None):
        self.X, self.y, self.transform = X, y, transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample, target = self.X[idx], self.y[idx]
        return (self.transform(sample) if self.transform else sample), target

dataset = CustomDataset(X, y, transform=transform)

# 80/20 train/test split. The fixed-seed generator makes the split reproducible across
# kernel restarts, so the saved checkpoint and reported test metrics always refer to the
# same held-out tiles.
split_seed = 42
train_size = int(0.8 * len(dataset))  # 80% of the data for training.
test_size = len(dataset) - train_size  # remaining ~20% for testing.

train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)


In [None]:
# Loss function and Class Weighting

# Inverse-frequency class weights. A class with zero frequency gets weight 0 instead of inf:
# it never appears as a target, so its weight is unused, but inf triggers a numpy
# divide-by-zero warning and poisons anything that aggregates the weights.
class_weights = np.divide(
    1.0,
    distribution,
    out=np.zeros_like(distribution, dtype=float),
    where=distribution > 0,
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)  # move to GPU/CPU device

# reduction='none' keeps one (weighted) loss value per element; loss_fn averages them.
cross_entropy = torch.nn.CrossEntropyLoss(reduction='none', weight=class_weights)

# Custom loss function to weight individual cross entropies.
def loss_fn(outputs, labels, error_margin = 1, criterion=None, n_classes=7):
    """Mean class-weighted cross entropy over all cells of a flat prediction.

    ``outputs`` holds ``n_classes`` logits per cell, laid out consecutively (for one
    sample: decision_width * decision_width * n_classes values). ``labels`` is the
    one-hot target in the same layout. All cells are scored in a single vectorised
    criterion call instead of one call per cell. A batch of such rows also works and
    yields the mean over every cell of every sample.

    Args:
        outputs: logits tensor whose element count is a multiple of ``n_classes``.
        labels: one-hot tensor with the same number of elements as ``outputs``.
        error_margin: currently unused. It was intended for a "relaxed" error that accepts
            a prediction matching a true class within this many cells; kept so existing
            call sites stay valid.
        criterion: per-element loss (``reduction='none'``). Defaults to the module-level
            ``cross_entropy`` (inverse-frequency class weights).
        n_classes: number of classes per cell.

    Returns:
        0-d tensor: mean of the per-cell losses.
    """
    if criterion is None:
        criterion = cross_entropy
    cell_logits = outputs.reshape(-1, n_classes)
    cell_targets = labels.reshape(-1, n_classes).argmax(dim=1)
    return criterion(cell_logits, cell_targets).mean()

# Optional alternative: oversample tiles by the summed class weight of their cells.
# weights = torch.zeros(len(train_dataset), dtype=torch.float32)
# cpu_class_weights = class_weights.cpu()
# for i in range(len(train_dataset)):
#     cell_classes = torch.as_tensor(train_dataset[i][1]).reshape(decision_width*decision_width, 7).argmax(dim=1)
#     weights[i] = cpu_class_weights[cell_classes].sum()
# sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
# train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)

# shuffle=True reshuffles the training tiles every epoch. Without it, every epoch sees
# identical mini-batches in identical order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)  # fixed order for evaluation



In [None]:
# Training loop
# Train for num_epochs, keeping the checkpoint with the lowest summed training loss.
# NOTE(review): selection uses *training* loss; a separate validation split would be a
# sounder criterion.
model.train()

n_classes = 7
checkpoint_path = 'vgg16_trained.pth'
best_epoch = 0
min_loss = float('inf')
global_start_time = time.time()

for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    epoch_start_time = time.time()

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        # loss_fn scores one sample; average the per-sample losses over the batch.
        loss = torch.stack([loss_fn(out, label) for out, label in zip(outputs, labels)]).mean()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Count cells from the real batch size: the last batch is usually smaller than
        # batch_size, and counting batch_size there under-reports accuracy.
        n_samples = labels.size(0)
        total += n_samples * decision_width * decision_width
        with torch.no_grad():
            # per-cell accuracy: argmax over each consecutive group of n_classes values
            predicted = outputs.reshape(n_samples, -1, n_classes).argmax(dim=2)
            target = labels.reshape(n_samples, -1, n_classes).argmax(dim=2)
            correct += (predicted == target).sum().item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%, Time: {time.time() - epoch_start_time:.2f}s')

    # save best model
    if min_loss >= running_loss:
        min_loss = running_loss
        best_epoch = epoch + 1
        torch.save(model.state_dict(), checkpoint_path)

# load best model (map_location lets a GPU-saved checkpoint load on CPU and vice versa)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
total_training_time = time.time() - global_start_time
print(f'\nBest Loss: {min_loss/len(train_loader):.4f}, Epoch {best_epoch} model saved.')
print(f'Total Training Time: {total_training_time:.2f}s, Average Time per Epoch: {total_training_time/max(num_epochs, 1):.2f}s')


In [None]:
# Testing
# Evaluate on the held-out split: mean batch loss and per-cell accuracy.
model.eval()

correct = 0
total = 0
running_loss = 0.0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)

        # mean of the per-sample losses for this batch
        loss = torch.stack([loss_fn(out, label) for out, label in zip(outputs, labels)]).mean()
        running_loss += loss.item()

        n_samples = labels.size(0)
        total += n_samples * decision_width * decision_width  # number of scored cells
        # a cell is correct when the argmax over its 7 logits matches its one-hot label
        hits = outputs.reshape(n_samples, -1, 7).argmax(dim=2) == labels.reshape(n_samples, -1, 7).argmax(dim=2)
        correct += int(hits.sum().item())

accuracy = 100 * correct / total

print(f'Test Loss: {running_loss/len(test_loader):.4f}, Test Accuracy: {accuracy:.2f}%')

In [None]:
# Visualize Model Predictions
# Show every input band of three random test tiles, next to the label and predicted class maps.
import itertools  # only this cell needs it: pick one batch without materialising the loader

model.eval()  # no dropout / batch-stat BN, even if the Testing cell was skipped

# list(test_loader) would load every test batch into memory just to choose one of them.
batch_index = random.randrange(len(test_loader))
inputs, labels = next(itertools.islice(iter(test_loader), batch_index, None))
inputs, labels = inputs.to(device), labels.to(device)

with torch.no_grad():
    outputs = model(inputs)
outputs = outputs.cpu().numpy()
labels = labels.cpu().numpy()

n_bands = inputs.shape[1]
n_panels = n_bands + 2  # every band, then label, then prediction

for _ in range(3):
    ind = random.randint(0, len(inputs)-1)
    image, label, output = inputs[ind], labels[ind], outputs[ind]

    plt.figure(figsize=(2 * n_panels, 2.5))

    # show input bands
    for band in range(n_bands):
        plt.subplot(1, n_panels, band + 1)
        plt.imshow(image[band].cpu().numpy(), cmap='gray')
        plt.title('Band {}'.format(band+1))
        plt.axis('off')

    # show label and output on the same fixed 0..6 class scale, so equal classes share a gray level
    plt.subplot(1, n_panels, n_bands + 1)
    plt.imshow(label.reshape(decision_width, decision_width, 7).argmax(axis=2), cmap='gray', vmin=0, vmax=6)
    plt.title('Label')
    plt.axis('off')

    plt.subplot(1, n_panels, n_bands + 2)
    plt.imshow(output.reshape(decision_width, decision_width, 7).argmax(axis=2), cmap='gray', vmin=0, vmax=6)
    plt.title('Output')
    plt.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
# Class Metrics

# Per-class TP/FP/FN, precision, recall and F1, computed on the held-out test split only.
# Iterating over all of X / y would mix training tiles into the evaluation and inflate
# every metric. test_loader also applies the training `transform`, so inputs are
# preprocessed exactly as during training.
n_classes = 7
labelcount = {i: 0 for i in range(n_classes)}
outcount = {i: 0 for i in range(n_classes)}
tp = [0] * n_classes
fp = [0] * n_classes
fn = [0] * n_classes

model.eval()  # no dropout / batch-stat BN, even if the Testing cell was skipped
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs.to(device))
        # one row of n_classes values per cell
        predicted = outputs.reshape(-1, n_classes).argmax(dim=1).cpu()
        target = labels.reshape(-1, n_classes).argmax(dim=1)
        hits = target[predicted == target]

        pred_counts = torch.bincount(predicted, minlength=n_classes).tolist()
        true_counts = torch.bincount(target, minlength=n_classes).tolist()
        hit_counts = torch.bincount(hits, minlength=n_classes).tolist()
        for c in range(n_classes):
            labelcount[c] += true_counts[c]
            outcount[c] += pred_counts[c]
            tp[c] += hit_counts[c]
            fp[c] += pred_counts[c] - hit_counts[c]  # predicted c, truth was another class
            fn[c] += true_counts[c] - hit_counts[c]  # truth c, predicted another class

# Undefined ratios (empty denominators) are reported as 0.
precisions = [tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0 for c in range(n_classes)]
recalls = [tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0 for c in range(n_classes)]
f1s = [2 * p * r / (p + r) if p + r else 0 for p, r in zip(precisions, recalls)]

print("True Positives:", tp)
print("False Positives:", fp)
print("False Negatives:", fn)
print("F1:", [round(v, 3) for v in f1s])

class_labels = list(classNames.values())
class_colors = list(classColorsNormalized.values())

# plot deviation bar diagram using the % error of the counts from the model and the counts from the labels
# (higher value -> model overpredicts, lower value -> model underpredicts)
plt.figure(figsize=(6, 4))
plt.bar(class_labels, [(outcount[i] - labelcount[i]) / (labelcount[i] + 1e-6) for i in labelcount.keys()], color=class_colors, alpha=0.7)
plt.title('Percent Error of Model Prediction Distribution')
plt.xlabel('Class')
plt.ylabel('Error')
plt.xticks(rotation=45)
plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))
plt.grid(axis='y')
plt.tight_layout()
plt.show()

# plot precision and recall histograms in subplot
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.bar(class_labels, precisions, color=class_colors, alpha=0.7)
plt.title('Precision')
plt.xlabel('Class')
plt.ylabel('Precision')
plt.ylim(0, 1)
plt.xticks(rotation=45)
plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))
plt.grid(axis='y')

plt.subplot(1, 2, 2)
plt.bar(class_labels, recalls, color=class_colors, alpha=0.7)
plt.title('Recall')
plt.xlabel('Class')
plt.ylabel('Recall')
plt.ylim(0, 1)
plt.xticks(rotation=45)
plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.0%}'))
plt.grid(axis='y')
plt.show()

In [None]:
# Full Image Prediction
# Segment every raw labelled image tile by tile, display the colored annotation and the
# prediction, and write the class mask as a GeoTIFF with the source image's CRS/transform.

model.eval()  # no dropout / batch-stat BN, so predictions do not depend on cell run order
n_classes = 7
inference_batch = 64  # tiles per forward pass; lower this if the GPU runs out of memory

generation_start_time = time.time()  # start time for full image prediction
path = '../data/raw_labeled_data/images/'
pathY = '../data/raw_labeled_data/annotations/'
output_dir = '../masksVGG/training/'
os.makedirs(output_dir, exist_ok=True)
files = sorted(f for f in os.listdir(path) if f.endswith('.tif'))

# Context margin around each decision tile. When (context_width - decision_width) is odd,
# the extra pixel goes after the tile, so every window is exactly context_width wide.
padding_width = (context_width - decision_width) // 2
padding_after = context_width - decision_width - padding_width

for file in files:
    image_start_time = time.time()

    with rio.open(os.path.join(path, file)) as src:
        raw = src.read()  # (bands, rows, cols)
        crs, geo_transform = src.crs, src.transform
    height, width = raw.shape[1], raw.shape[2]

    # Pad bottom/right up to a multiple of decision_width. Loop bounds of
    # `range(0, padded - context_width, decision_width)` always skip the final tile
    # row/column; padding lets every pixel, including the image edge, be predicted.
    extra_h = -height % decision_width
    extra_w = -width % decision_width
    imgX = np.pad(raw, ((0, 0), (padding_width, padding_after + extra_h), (padding_width, padding_after + extra_w)), mode='edge')
    imgX = np.moveaxis(imgX, 0, -1)  # Move the channel dimension to the last position

    with rio.open(os.path.join(pathY, file)) as src:
        imgY = src.read()[0]

    # color and display label
    colored_imgY = np.zeros((imgY.shape[0], imgY.shape[1], 3), dtype=np.uint8)
    for class_id, color in classColors.items():
        colored_imgY[imgY == class_id] = color

    plt.imshow(colored_imgY)
    plt.title('Colored Label')
    plt.axis('off')
    plt.show()

    # predict full image, a batch of tiles per forward pass
    mask = np.zeros((height + extra_h, width + extra_w), dtype=np.float32)
    row_starts = range(0, height + extra_h, decision_width)
    col_starts = list(range(0, width + extra_w, decision_width))
    with torch.no_grad():
        for row_number, i in enumerate(row_starts):
            for chunk_start in range(0, len(col_starts), inference_batch):
                chunk = col_starts[chunk_start:chunk_start + inference_batch]
                # Same `transform` as the training Dataset, so inference inputs are preprocessed
                # identically (ToTensor rescales uint8 data to [0, 1]).
                batch = torch.stack([transform(imgX[i:i + context_width, j:j + context_width, :]) for j in chunk])
                logits = model(batch.float().to(device))
                cell_classes = logits.reshape(len(chunk), decision_width, decision_width, n_classes).argmax(dim=3).cpu().numpy()
                for j, tile_classes in zip(chunk, cell_classes):
                    mask[i:i + decision_width, j:j + decision_width] = tile_classes

            if row_number % 5 == 0:
                print(f"{row_number / len(row_starts) * 100:.2f}% done")

    # crop the alignment padding; keep the (rows, cols, 1) layout
    out = np.ascontiguousarray(mask[:height, :width])[:, :, np.newaxis]

    output_path = os.path.join(output_dir, file)
    with rio.open(
        output_path,
        'w',
        driver='GTiff',
        height=out.shape[0],
        width=out.shape[1],
        count=1,
        dtype=out.dtype,
        crs=crs,
        transform=geo_transform,
    ) as dst:
        dst.write(out[:, :, 0], 1)

    colored_out = np.zeros((out.shape[0], out.shape[1], 3), dtype=np.uint8)

    for class_id, color in classColors.items():
        colored_out[out[:, :, 0] == class_id] = color

    plt.imshow(colored_out)
    plt.title('Colored Output')
    plt.axis('off')
    plt.show()
    print(f'Image Segmentation Time: {time.time() - image_start_time:.2f}s')

print(f'Total Segmentation Time: {time.time() - generation_start_time:.2f}s')

In [None]:
# Full Image Prediction
# Segment the images in test_data/3, show each colored reference PNG for visual comparison,
# and write the class mask as a GeoTIFF with the source image's CRS/transform.

model.eval()  # no dropout / batch-stat BN, so predictions do not depend on cell run order
n_classes = 7
inference_batch = 64  # tiles per forward pass; lower this if the GPU runs out of memory

generation_start_time = time.time()  # start time for full image prediction
path = '../data/test_data/3/images/'
pathY = '../data/test_data/3/colored/'
output_dir = '../masksVGG/3/'
os.makedirs(output_dir, exist_ok=True)
# Images and references are paired by position, so both listings must be sorted:
# os.listdir order is arbitrary and can differ between the two directories.
files = sorted(f for f in os.listdir(path) if f.endswith('.tif'))
filesy = sorted(f for f in os.listdir(pathY) if f.endswith('.png'))
if len(files) != len(filesy):
    print(f'Warning: {len(files)} images but {len(filesy)} colored references; unmatched files are skipped')

# Context margin around each decision tile. When (context_width - decision_width) is odd,
# the extra pixel goes after the tile, so every window is exactly context_width wide.
padding_width = (context_width - decision_width) // 2
padding_after = context_width - decision_width - padding_width

for file, yfile in zip(files, filesy):
    if os.path.splitext(file)[0] != os.path.splitext(yfile)[0]:
        print(f'Warning: pairing {file} with {yfile} (file names differ)')
    image_start_time = time.time()

    with rio.open(os.path.join(path, file)) as src:
        raw = src.read()  # (bands, rows, cols)
        crs, geo_transform = src.crs, src.transform
    height, width = raw.shape[1], raw.shape[2]

    # Pad bottom/right up to a multiple of decision_width so the final tile row/column
    # (and any remainder) is predicted instead of left as class 0.
    extra_h = -height % decision_width
    extra_w = -width % decision_width
    imgX = np.pad(raw, ((0, 0), (padding_width, padding_after + extra_h), (padding_width, padding_after + extra_w)), mode='edge')
    imgX = np.moveaxis(imgX, 0, -1)  # Move the channel dimension to the last position

    with rio.open(os.path.join(pathY, yfile)) as src:
        imgY = np.moveaxis(src.read(), 0, -1)  # Move the channel dimension to the last position

    print(file)

    # display colored reference
    plt.imshow(imgY)
    plt.title('image')
    plt.axis('off')
    plt.show()

    # predict full image, a batch of tiles per forward pass. The mask is sized from the
    # georeferenced tif (not the PNG), because it is written with the tif's transform.
    mask = np.zeros((height + extra_h, width + extra_w), dtype=np.float32)
    row_starts = range(0, height + extra_h, decision_width)
    col_starts = list(range(0, width + extra_w, decision_width))
    with torch.no_grad():
        for row_number, i in enumerate(row_starts):
            for chunk_start in range(0, len(col_starts), inference_batch):
                chunk = col_starts[chunk_start:chunk_start + inference_batch]
                # Same `transform` as the training Dataset, so inference inputs are preprocessed
                # identically (ToTensor rescales uint8 data to [0, 1]).
                batch = torch.stack([transform(imgX[i:i + context_width, j:j + context_width, :]) for j in chunk])
                logits = model(batch.float().to(device))
                cell_classes = logits.reshape(len(chunk), decision_width, decision_width, n_classes).argmax(dim=3).cpu().numpy()
                for j, tile_classes in zip(chunk, cell_classes):
                    mask[i:i + decision_width, j:j + decision_width] = tile_classes

            if row_number % 5 == 0:
                print(f"{row_number / len(row_starts) * 100:.2f}% done")

    # crop the alignment padding; keep the (rows, cols, 1) layout
    out = np.ascontiguousarray(mask[:height, :width])[:, :, np.newaxis]

    output_path = os.path.join(output_dir, file)
    with rio.open(
        output_path,
        'w',
        driver='GTiff',
        height=out.shape[0],
        width=out.shape[1],
        count=1,
        dtype=out.dtype,
        crs=crs,
        transform=geo_transform,
    ) as dst:
        dst.write(out[:, :, 0], 1)

    colored_out = np.zeros((out.shape[0], out.shape[1], 3), dtype=np.uint8)

    for class_id, color in classColors.items():
        colored_out[out[:, :, 0] == class_id] = color

    plt.imshow(colored_out)
    plt.title('Colored Output')
    plt.axis('off')
    plt.show()
    print(f'Image Segmentation Time: {time.time() - image_start_time:.2f}s')

print(f'Total Segmentation Time: {time.time() - generation_start_time:.2f}s')

In [None]:
# Full Image Prediction
# Segment the images in test_data/1, show each colored reference PNG for visual comparison,
# and write the class mask as a GeoTIFF with the source image's CRS/transform.

model.eval()  # no dropout / batch-stat BN, so predictions do not depend on cell run order
n_classes = 7
inference_batch = 64  # tiles per forward pass; lower this if the GPU runs out of memory

generation_start_time = time.time()  # start time for full image prediction
path = '../data/test_data/1/images/'
pathY = '../data/test_data/1/colored/'
output_dir = '../masksVGG/1/'
os.makedirs(output_dir, exist_ok=True)
# Images and references are paired by position, so both listings must be sorted:
# os.listdir order is arbitrary and can differ between the two directories.
files = sorted(f for f in os.listdir(path) if f.endswith('.tif'))
filesy = sorted(f for f in os.listdir(pathY) if f.endswith('.png'))
if len(files) != len(filesy):
    print(f'Warning: {len(files)} images but {len(filesy)} colored references; unmatched files are skipped')

# Context margin around each decision tile. When (context_width - decision_width) is odd,
# the extra pixel goes after the tile, so every window is exactly context_width wide.
padding_width = (context_width - decision_width) // 2
padding_after = context_width - decision_width - padding_width

for file, yfile in zip(files, filesy):
    if os.path.splitext(file)[0] != os.path.splitext(yfile)[0]:
        print(f'Warning: pairing {file} with {yfile} (file names differ)')
    image_start_time = time.time()

    with rio.open(os.path.join(path, file)) as src:
        raw = src.read()  # (bands, rows, cols)
        crs, geo_transform = src.crs, src.transform
    height, width = raw.shape[1], raw.shape[2]

    # Pad bottom/right up to a multiple of decision_width so the final tile row/column
    # (and any remainder) is predicted instead of left as class 0.
    extra_h = -height % decision_width
    extra_w = -width % decision_width
    imgX = np.pad(raw, ((0, 0), (padding_width, padding_after + extra_h), (padding_width, padding_after + extra_w)), mode='edge')
    imgX = np.moveaxis(imgX, 0, -1)  # Move the channel dimension to the last position

    with rio.open(os.path.join(pathY, yfile)) as src:
        imgY = np.moveaxis(src.read(), 0, -1)  # Move the channel dimension to the last position

    print(file)

    # display colored reference
    plt.imshow(imgY)
    plt.title('image')
    plt.axis('off')
    plt.show()

    # predict full image, a batch of tiles per forward pass. The mask is sized from the
    # georeferenced tif (not the PNG), because it is written with the tif's transform.
    mask = np.zeros((height + extra_h, width + extra_w), dtype=np.float32)
    row_starts = range(0, height + extra_h, decision_width)
    col_starts = list(range(0, width + extra_w, decision_width))
    with torch.no_grad():
        for row_number, i in enumerate(row_starts):
            for chunk_start in range(0, len(col_starts), inference_batch):
                chunk = col_starts[chunk_start:chunk_start + inference_batch]
                # Same `transform` as the training Dataset, so inference inputs are preprocessed
                # identically (ToTensor rescales uint8 data to [0, 1]).
                batch = torch.stack([transform(imgX[i:i + context_width, j:j + context_width, :]) for j in chunk])
                logits = model(batch.float().to(device))
                cell_classes = logits.reshape(len(chunk), decision_width, decision_width, n_classes).argmax(dim=3).cpu().numpy()
                for j, tile_classes in zip(chunk, cell_classes):
                    mask[i:i + decision_width, j:j + decision_width] = tile_classes

            if row_number % 5 == 0:
                print(f"{row_number / len(row_starts) * 100:.2f}% done")

    # crop the alignment padding; keep the (rows, cols, 1) layout
    out = np.ascontiguousarray(mask[:height, :width])[:, :, np.newaxis]

    output_path = os.path.join(output_dir, file)
    with rio.open(
        output_path,
        'w',
        driver='GTiff',
        height=out.shape[0],
        width=out.shape[1],
        count=1,
        dtype=out.dtype,
        crs=crs,
        transform=geo_transform,
    ) as dst:
        dst.write(out[:, :, 0], 1)

    colored_out = np.zeros((out.shape[0], out.shape[1], 3), dtype=np.uint8)

    for class_id, color in classColors.items():
        colored_out[out[:, :, 0] == class_id] = color

    plt.imshow(colored_out)
    plt.title('Colored Output')
    plt.axis('off')
    plt.show()
    print(f'Image Segmentation Time: {time.time() - image_start_time:.2f}s')

print(f'Total Segmentation Time: {time.time() - generation_start_time:.2f}s')