## Imports

In [None]:
import os
import argparse
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
import torch.optim as optim
from skimage import filters

from dataloader import ForestDataset
from utils import get_train_weights, get_train_weights_binary
from utils import get_acc_seg, get_acc_seg_weighted, get_acc_nzero, get_acc_class, get_acc_binseg
from models import UNet
from eq_models import UNet_eq


## Parameters which should be the same as those used to train the model

In [None]:
# These must match the settings used when the model was trained.
model = 'unet_eq'              # architecture to build: 'unet' or 'unet_eq'
device = 'cuda'
savedir = 'UNET_eq_code_test'  # sub-folder of Outputs/ with checkpoints and training.txt
load_epoch = 10                # loads Outputs/<savedir>/model_<load_epoch>.pt
batch_size = 32
lr = 0.001                     # only needed to construct the optimiser before its state is restored

## Create datasets, dataloaders and model

In [None]:
# Datasets for each split, read from the ForestNetDataset/ folder.
train_dataset = ForestDataset(csv_file='ForestNetDataset/train.csv',
                              root_dir='ForestNetDataset')
val_dataset = ForestDataset(csv_file='ForestNetDataset/val.csv',
                            root_dir='ForestNetDataset')
test_dataset = ForestDataset(csv_file='ForestNetDataset/test.csv',
                             root_dir='ForestNetDataset')

## Get weights for re-weighting optimisation due to imbalance in dataset
train_weights = get_train_weights(train_dataset)
train_weights_bin_seg = get_train_weights_binary(train_dataset)

# Only the training loader is shuffled; val/test keep a fixed order so batch
# indices (e.g. break_idx used for plotting) refer to the same samples each run.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Both architectures take 3 input channels and produce 1 segmentation channel.
if model == 'unet':
    net = UNet(3, 1)
elif model == 'unet_eq':
    net = UNet_eq(3, 1)
else:
    # Fail clearly instead of with a NameError on `net` below.
    raise ValueError(f"Unknown model {model!r}; expected 'unet' or 'unet_eq'")
net = net.to(device)

# get_train_weights output is converted with from_numpy, so it is a NumPy array;
# train_weights_bin_seg is moved with .to(), so it is already a tensor.
train_weights = torch.from_numpy(train_weights).type(torch.float).to(device)
# The two criteria are not referenced by the later cells of this notebook.
criterion_class = nn.CrossEntropyLoss(weight=train_weights)
criterion_seg = nn.BCEWithLogitsLoss(pos_weight=train_weights_bin_seg.to(device))

# The optimiser is built so its saved state can be restored from the checkpoint.
optimizer = optim.AdamW(net.parameters(), lr=lr)

## Load saved model

In [None]:
# Restore model weights and optimiser state saved at epoch `load_epoch`.
# map_location places the tensors on `device` whatever device they were saved from.
checkpoint = torch.load(f'Outputs/{savedir}/model_{load_epoch}.pt', map_location=device)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Switch to evaluation mode for all the analysis below.
net.eval()

## Plot the loss and accuracies

In [None]:
# Training log: whitespace-separated rows read as strings. Columns used below:
# 4 = loss, 7 = segmentation accuracy, 10 = classification accuracy.
with open(f'Outputs/{savedir}/training.txt', 'r') as log_file:
    # ndmin=2 keeps the array 2-D even if the log contains a single row.
    lines = np.loadtxt(log_file, delimiter=' ', dtype=str, ndmin=2)


def _plot_train_val(values, ylabel, filename):
    """Plot alternating train/val entries of `values`, save to Outputs/<savedir>/<filename>.

    Rows of the log alternate train, val, train, val, ... .
    Returns the (train, val) slices.
    """
    train_values = values[::2]
    val_values = values[1::2]
    # Start a fresh figure so each PNG only holds its own curves (with a
    # non-interactive backend plt.show() does not clear the current figure).
    plt.figure()
    plt.plot(train_values, label='train')
    plt.plot(val_values, label='val')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(f'Outputs/{savedir}/{filename}')
    plt.show()
    return train_values, val_values


loss = lines[:, 4].astype('float')
train_loss, val_loss = _plot_train_val(loss, 'loss', 'loss.png')

accseg = lines[:, 7].astype('float')
train_accseg, val_accseg = _plot_train_val(accseg, 'segmentation accuracy', 'seg_acc.png')

accclass = lines[:, 10].astype('float')
train_accclass, val_accclass = _plot_train_val(accclass, 'classification accuracy', 'class_Acc.png')


## Check the number of trainable parameters in the model

In [None]:
# Count only the parameters the optimiser updates (those with requires_grad=True).
trainable = (param for param in net.parameters() if param.requires_grad)
pytorch_total_params = sum(param.numel() for param in trainable)
print(f'Number of trainable parameters in the model : {pytorch_total_params}')

## Create some plots to analyse the model's segmentation

In [None]:
# Index of the test batch (in unshuffled testloader order) to visualise.
break_idx = 6
# Index of the sample inside that batch to visualise.
img_idx = 1

In [None]:
# Fetch the `break_idx`-th batch of the unshuffled test loader. If the loader has
# fewer batches, the loop runs out and the last batch is used instead.
for i, data in enumerate(testloader):
    if i == break_idx:
        break

inputs, segs, labels = data
inputs = inputs.to(device)
segs = segs.to(device)
# Shift class labels down by one (presumably 1-based in the dataset -- TODO confirm).
labels = labels.to(device)-1
# Mask scaled by each sample's label; converted below but not used by the plotting cells.
segs_labelled = torch.mul(segs, labels.view(-1,1,1))

# The same batch rotated by 90 degrees in the last two (spatial) dims, to
# compare the model's output on rotated inputs with that on the originals.
inputs_rot = torch.rot90(inputs, k=1, dims=[-2,-1])
segs_rot = torch.rot90(segs, k=1, dims=[-2,-1])

# Inference only: no autograd graph, and only the selected batch is run.
with torch.no_grad():
    outputs, out_class = net(inputs)
    outputs_rot, out_class_rot = net(inputs_rot)

# Threshold sigmoid probabilities at 0.5 into a 0/1 mask, drop the channel
# dim (dim=1) and move to NumPy for plotting.
outputs = torch.sigmoid(outputs)
outputs = (outputs >= 0.5).to(outputs.dtype)
outputs = torch.squeeze(outputs, dim=1).cpu().numpy()

outputs_rot = torch.sigmoid(outputs_rot)
outputs_rot = (outputs_rot >= 0.5).to(outputs_rot.dtype)
outputs_rot = torch.squeeze(outputs_rot, dim=1).cpu().numpy()

inputs = inputs.cpu().numpy()
inputs_rot = inputs_rot.cpu().numpy()
out_class = out_class.cpu().numpy()
out_class_rot = out_class_rot.cpu().numpy()
segs = segs.cpu().numpy()
segs_rot = segs_rot.cpu().numpy()
segs_labelled = segs_labelled.cpu().numpy()
labels = labels.cpu().numpy()

In [None]:
# 2x3 comparison for sample `img_idx` of the selected batch.
# Top row: image / ground-truth mask / prediction.
# Bottom row: the same for the image rotated by 90 degrees.
print(f'Predicted label {np.argmax(out_class[img_idx])} - True label {labels[img_idx]}')
print(f'Predicted label {np.argmax(out_class_rot[img_idx])} - True label {labels[img_idx]}')


def _mask_to_rgba(mask, rgba):
    """Colour a 0/1 mask as an (H, W, 4) int image for overlaying.

    Pixels equal to 1 get colour `rgba`. All other pixels stay (0, 0, 0, 0),
    i.e. fully transparent.
    """
    overlay = np.repeat(np.expand_dims(mask, axis=2), 4, axis=2).astype(int)
    for channel, value in enumerate(rgba):
        overlay[:, :, channel][overlay[:, :, channel] == 1] = value
    return overlay


def _edge_overlay(mask, rgba=(255, 255, 0, 255)):
    """Outline of a binary mask (non-zero Sobel response) as a transparent RGBA overlay."""
    edges = (filters.sobel(np.asarray(mask, dtype=float)) > 0).astype(int)
    return _mask_to_rgba(edges, rgba)


# NOTE(review): edge_sobel_segs / edge_sobel_segs_rot were used here but never
# defined in the notebook (NameError). Reconstructed as the outline of the
# ground-truth mask, drawn over the prediction; the colour is a guess.
edge_sobel_segs = _edge_overlay(segs[img_idx])
edge_sobel_segs_rot = _edge_overlay(segs_rot[img_idx])

# The +100 display shift is kept as-is; presumably the inputs are normalised
# pixel values -- TODO confirm the intended display range.
img = np.transpose(inputs[img_idx].astype(int)+100, (1,2,0))
img_rot = np.transpose(inputs_rot[img_idx].astype(int)+100, (1,2,0))

plt.figure(figsize=(21,14))

plt.subplot(2,3,1)
plt.imshow(img)
plt.title('Original Image')

plt.subplot(2,3,2)
plt.imshow(img)
segs_plot = _mask_to_rgba(segs[img_idx], (220, 50, 32, 200))
plt.imshow(segs_plot)
plt.title('Original Image with Segmentation Map')

plt.subplot(2,3,3)
plt.imshow(img)
outputs_plot = _mask_to_rgba(outputs[img_idx], (0, 90, 181, 255))
im = plt.imshow(outputs_plot)
plt.imshow(edge_sobel_segs)
plt.title('Original Segmentation Prediction')

plt.subplot(2,3,4)
plt.imshow(img_rot)
plt.title('Image Rotated 90deg')

plt.subplot(2,3,5)
plt.imshow(img_rot)
segs_rot_plot = _mask_to_rgba(segs_rot[img_idx], (220, 50, 32, 200))
plt.imshow(segs_rot_plot)
plt.title('Original Image with Segmentation Map Rotated 90deg')

plt.subplot(2,3,6)
# Background must be the rotated image: outputs_rot is the prediction for inputs_rot.
plt.imshow(img_rot)
outputs_rot_plot = _mask_to_rgba(outputs_rot[img_idx], (0, 90, 181, 255))
im = plt.imshow(outputs_rot_plot)
plt.imshow(edge_sobel_segs_rot)
plt.title('Segmentation Prediction with Rotated 90deg Image')

plt.savefig(f'Outputs/{savedir}/seg_{break_idx}_{img_idx}.png')
plt.show()

In [None]:
# Export each comparison panel for sample `img_idx` as its own borderless PNG.


def _mask_to_rgba(mask, rgba):
    """Colour a 0/1 mask as an (H, W, 4) int image for overlaying.

    Pixels equal to 1 get colour `rgba`. All other pixels stay (0, 0, 0, 0),
    i.e. fully transparent.
    """
    overlay = np.repeat(np.expand_dims(mask, axis=2), 4, axis=2).astype(int)
    for channel, value in enumerate(rgba):
        overlay[:, :, channel][overlay[:, :, channel] == 1] = value
    return overlay


def _edge_overlay(mask, rgba=(255, 255, 0, 255)):
    """Outline of a binary mask (non-zero Sobel response) as a transparent RGBA overlay."""
    edges = (filters.sobel(np.asarray(mask, dtype=float)) > 0).astype(int)
    return _mask_to_rgba(edges, rgba)


def _save_panel(background, suffix, overlays=()):
    """Draw `background`, then each (array, imshow-kwargs) overlay in order.

    Hides the axes and saves to Outputs/<savedir>/seg_<break_idx>_<img_idx>_<suffix>.png.
    """
    plt.figure()
    plt.imshow(background)
    for overlay, kwargs in overlays:
        plt.imshow(overlay, **kwargs)
    plt.axis('off')
    plt.savefig(f'Outputs/{savedir}/seg_{break_idx}_{img_idx}_{suffix}.png', bbox_inches='tight', pad_inches=0)
    plt.show()


# NOTE(review): edge_sobel_segs / edge_sobel_segs_rot were used here but never
# defined in the notebook (NameError). Reconstructed as the outline of the
# ground-truth mask; the colour is a guess.
edge_sobel_segs = _edge_overlay(segs[img_idx])
edge_sobel_segs_rot = _edge_overlay(segs_rot[img_idx])

# The +100 display shift is kept as-is (TODO confirm the intended display range).
img = np.transpose(inputs[img_idx].astype(int)+100, (1,2,0))
img_rot = np.transpose(inputs_rot[img_idx].astype(int)+100, (1,2,0))

outputs_plot = _mask_to_rgba(outputs[img_idx], (136, 204, 238, 255))
outputs_rot_plot = _mask_to_rgba(outputs_rot[img_idx], (136, 204, 238, 255))
grey_mask = {'cmap': 'Greys', 'alpha': 0.5}

_save_panel(img, 'orig_img')
_save_panel(img, 'orig_img_orig_seg', [(segs[img_idx], grey_mask)])
_save_panel(img, 'orig_img_pred_seg_v2', [(outputs_plot, {}), (edge_sobel_segs, {})])

_save_panel(img_rot, 'rot_img')
_save_panel(img_rot, 'rot_img_orig_seg', [(segs_rot[img_idx], grey_mask)])
# Background is the rotated image: outputs_rot is the prediction for inputs_rot.
_save_panel(img_rot, 'rot_img_pred_seg_v2', [(outputs_rot_plot, {}), (edge_sobel_segs_rot, {})])

## Print classification accuracies

In [None]:
# Classification accuracy on the training split.
# torch.no_grad(): evaluation only, so no autograd graph is built (saves memory/time).
running_acc_class = 0.0
with torch.no_grad():
    for inputs, segs, labels in trainloader:
        inputs = inputs.to(device)
        # Shift class labels down by one (presumably 1-based in the dataset -- TODO confirm).
        labels = labels.to(device) - 1

        outputs, out_class = net(inputs)

        running_acc_class += get_acc_class(out_class, labels).item()

# NOTE(review): this is the mean of per-batch values; if get_acc_class returns a
# per-batch fraction, a smaller final batch is over-weighted.
print(f'Train Classification Accuracy {running_acc_class/len(trainloader)}')


In [None]:
# Classification accuracy on the validation split.
# torch.no_grad(): evaluation only, so no autograd graph is built (saves memory/time).
running_acc_class = 0.0
with torch.no_grad():
    for inputs, segs, labels in valloader:
        inputs = inputs.to(device)
        # Shift class labels down by one (presumably 1-based in the dataset -- TODO confirm).
        labels = labels.to(device) - 1

        outputs, out_class = net(inputs)

        running_acc_class += get_acc_class(out_class, labels).item()

# NOTE(review): mean of per-batch values; a smaller final batch is over-weighted
# if get_acc_class returns a per-batch fraction.
print(f'Validation Classification Accuracy {running_acc_class/len(valloader)}')


In [None]:
# Classification accuracy on the test split.
# torch.no_grad(): evaluation only, so no autograd graph is built (saves memory/time).
running_acc_class = 0.0
with torch.no_grad():
    for inputs, segs, labels in testloader:
        inputs = inputs.to(device)
        # Shift class labels down by one (presumably 1-based in the dataset -- TODO confirm).
        labels = labels.to(device) - 1

        outputs, out_class = net(inputs)

        running_acc_class += get_acc_class(out_class, labels).item()

# NOTE(review): mean of per-batch values; a smaller final batch is over-weighted
# if get_acc_class returns a per-batch fraction.
print(f'Test Classification Accuracy {running_acc_class/len(testloader)}')


### Test classification accuracy with rotated images

In [None]:
# Test-set classification accuracy with every image rotated by 90 degrees;
# the same (unrotated) class labels are used as targets.
rot_k = 1  # number of 90-degree rotations passed to torch.rot90
running_acc_class = 0.0
with torch.no_grad():
    for inputs, segs, labels in testloader:
        inputs = inputs.to(device)
        # Shift class labels down by one (presumably 1-based in the dataset -- TODO confirm).
        labels = labels.to(device) - 1
        inputs_rot = torch.rot90(inputs, k=rot_k, dims=[-2, -1])

        outputs, out_class = net(inputs_rot)

        running_acc_class += get_acc_class(out_class, labels).item()

# The label names the rotation so this line is distinguishable from the unrotated result.
print(f'Test Classification Accuracy (rotated {90 * rot_k}deg) {running_acc_class/len(testloader)}')


## Print segmentation accuracies

In [None]:
# Binary-segmentation accuracy on the training split.
# Judging by the printed labels, get_acc_binseg presumably returns
# (accuracy on 1-pixels, accuracy on 0-pixels, mean) -- TODO confirm.
running_acc_mean = 0.0
running_acc_mean_1 = 0.0
running_acc_mean_0 = 0.0
with torch.no_grad():  # evaluation only: no autograd graph
    for inputs, segs, labels in trainloader:
        inputs = inputs.to(device)
        segs = segs.to(device)

        outputs, out_class = net(inputs)

        # Add a channel dim (dim=1) to the float mask before comparing with outputs.
        m1, m0, m_avg = get_acc_binseg(outputs, torch.unsqueeze(segs.float(), dim=1))
        running_acc_mean += m_avg.item()
        running_acc_mean_1 += m1.item()
        running_acc_mean_0 += m0.item()

n_batches = len(trainloader)
print(f'Train Seg Accuracy Mean {running_acc_mean/n_batches}')
print(f'Train Seg Accuracy Mean 1s {running_acc_mean_1/n_batches}')
print(f'Train Seg Accuracy Mean 0s {running_acc_mean_0/n_batches}')


In [None]:
# Binary-segmentation accuracy on the validation split.
# get_acc_binseg presumably returns (acc on 1-pixels, acc on 0-pixels, mean) -- TODO confirm.
running_acc_mean = 0.0
running_acc_mean_1 = 0.0
running_acc_mean_0 = 0.0
with torch.no_grad():  # evaluation only: no autograd graph
    for inputs, segs, labels in valloader:
        inputs = inputs.to(device)
        segs = segs.to(device)

        outputs, out_class = net(inputs)

        # Add a channel dim (dim=1) to the float mask before comparing with outputs.
        m1, m0, m_avg = get_acc_binseg(outputs, torch.unsqueeze(segs.float(), dim=1))
        running_acc_mean += m_avg.item()
        running_acc_mean_1 += m1.item()
        running_acc_mean_0 += m0.item()

n_batches = len(valloader)
print(f'Val Seg Accuracy Mean {running_acc_mean/n_batches}')
print(f'Val Seg Accuracy Mean 1s {running_acc_mean_1/n_batches}')
print(f'Val Seg Accuracy Mean 0s {running_acc_mean_0/n_batches}')


In [None]:
# Binary-segmentation accuracy on the test split.
# get_acc_binseg presumably returns (acc on 1-pixels, acc on 0-pixels, mean) -- TODO confirm.
running_acc_mean = 0.0
running_acc_mean_1 = 0.0
running_acc_mean_0 = 0.0
with torch.no_grad():  # evaluation only: no autograd graph
    for inputs, segs, labels in testloader:
        inputs = inputs.to(device)
        segs = segs.to(device)

        outputs, out_class = net(inputs)

        # Add a channel dim (dim=1) to the float mask before comparing with outputs.
        m1, m0, m_avg = get_acc_binseg(outputs, torch.unsqueeze(segs.float(), dim=1))
        running_acc_mean += m_avg.item()
        running_acc_mean_1 += m1.item()
        running_acc_mean_0 += m0.item()

n_batches = len(testloader)
print(f'Test Seg Accuracy Mean {running_acc_mean/n_batches}')
print(f'Test Seg Accuracy Mean 1s {running_acc_mean_1/n_batches}')
print(f'Test Seg Accuracy Mean 0s {running_acc_mean_0/n_batches}')


### Segmentation accuracy with rotated images

In [None]:
# Test-set segmentation accuracy with images and masks both rotated by 90*rot_k degrees.
rot_k = 1  # number of 90-degree rotations passed to torch.rot90
running_acc_mean = 0.0
running_acc_mean_1 = 0.0
running_acc_mean_0 = 0.0
with torch.no_grad():  # evaluation only: no autograd graph
    for inputs, segs, labels in testloader:
        inputs = inputs.to(device)
        segs = segs.to(device)

        # Rotate input and target identically so the target stays aligned.
        inputs_rot = torch.rot90(inputs, k=rot_k, dims=[-2, -1])
        segs_rot = torch.rot90(segs, k=rot_k, dims=[-2, -1])

        outputs, out_class = net(inputs_rot)

        m1, m0, m_avg = get_acc_binseg(outputs, torch.unsqueeze(segs_rot.float(), dim=1))
        running_acc_mean += m_avg.item()
        running_acc_mean_1 += m1.item()
        running_acc_mean_0 += m0.item()

# Labels name the rotation so these lines are distinguishable from the unrotated results.
n_batches = len(testloader)
print(f'Test Seg Accuracy Mean (rotated {90 * rot_k}deg) {running_acc_mean/n_batches}')
print(f'Test Seg Accuracy Mean 1s (rotated {90 * rot_k}deg) {running_acc_mean_1/n_batches}')
print(f'Test Seg Accuracy Mean 0s (rotated {90 * rot_k}deg) {running_acc_mean_0/n_batches}')


In [None]:
# Test-set segmentation accuracy with images and masks both rotated by 90*rot_k degrees.
rot_k = 2  # number of 90-degree rotations passed to torch.rot90
running_acc_mean = 0.0
running_acc_mean_1 = 0.0
running_acc_mean_0 = 0.0
with torch.no_grad():  # evaluation only: no autograd graph
    for inputs, segs, labels in testloader:
        inputs = inputs.to(device)
        segs = segs.to(device)

        # Rotate input and target identically so the target stays aligned.
        inputs_rot = torch.rot90(inputs, k=rot_k, dims=[-2, -1])
        segs_rot = torch.rot90(segs, k=rot_k, dims=[-2, -1])

        outputs, out_class = net(inputs_rot)

        m1, m0, m_avg = get_acc_binseg(outputs, torch.unsqueeze(segs_rot.float(), dim=1))
        running_acc_mean += m_avg.item()
        running_acc_mean_1 += m1.item()
        running_acc_mean_0 += m0.item()

# Labels name the rotation so these lines are distinguishable from the unrotated results.
n_batches = len(testloader)
print(f'Test Seg Accuracy Mean (rotated {90 * rot_k}deg) {running_acc_mean/n_batches}')
print(f'Test Seg Accuracy Mean 1s (rotated {90 * rot_k}deg) {running_acc_mean_1/n_batches}')
print(f'Test Seg Accuracy Mean 0s (rotated {90 * rot_k}deg) {running_acc_mean_0/n_batches}')


In [None]:
# Test-set segmentation accuracy with images and masks both rotated by 90*rot_k degrees.
rot_k = 3  # number of 90-degree rotations passed to torch.rot90
running_acc_mean = 0.0
running_acc_mean_1 = 0.0
running_acc_mean_0 = 0.0
with torch.no_grad():  # evaluation only: no autograd graph
    for inputs, segs, labels in testloader:
        inputs = inputs.to(device)
        segs = segs.to(device)

        # Rotate input and target identically so the target stays aligned.
        inputs_rot = torch.rot90(inputs, k=rot_k, dims=[-2, -1])
        segs_rot = torch.rot90(segs, k=rot_k, dims=[-2, -1])

        outputs, out_class = net(inputs_rot)

        m1, m0, m_avg = get_acc_binseg(outputs, torch.unsqueeze(segs_rot.float(), dim=1))
        running_acc_mean += m_avg.item()
        running_acc_mean_1 += m1.item()
        running_acc_mean_0 += m0.item()

# Labels name the rotation so these lines are distinguishable from the unrotated results.
n_batches = len(testloader)
print(f'Test Seg Accuracy Mean (rotated {90 * rot_k}deg) {running_acc_mean/n_batches}')
print(f'Test Seg Accuracy Mean 1s (rotated {90 * rot_k}deg) {running_acc_mean_1/n_batches}')
print(f'Test Seg Accuracy Mean 0s (rotated {90 * rot_k}deg) {running_acc_mean_0/n_batches}')
