Notebook notes: probably the most important notebook. Takes the model weights from transfer learning and then trains on the 2000+ CAFO images with Microsoft segmentation. Then saves model

In [1]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
from tqdm import tqdm
from torch.nn import BCEWithLogitsLoss
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datetime import date
import sys
sys.path.insert(0, '../utils')
from ground_truth_dataset import groundTruthDataset
from data_functions import splitDataset, returnLoaders
from metrics import returnPreReF
sys.path.insert(0, '../models')
from unet_model import UNet
sys.path.insert(0, '../train')
from training import train_one_epoch, valid_one_epoch 

In [2]:
NUM_CLASSES = 2 # For ground truth data, there's 2 classes of Background, CAFO
batch_size = 8

In [3]:
dataset = groundTruthDataset("../../../../../datadrive/data/raw/ground_truth/", 
                             transform=True, 
                             make_small=True, 
                             ignore_lagoon=True)
#dataset = groundTruthDataset("../../../segmentation_ground_truth", make_small=True)
datasets = splitDataset(dataset)
trainloader, validloader, testloader = returnLoaders(datasets, batch_size, True)
model = UNet(3, NUM_CLASSES) # 3 Channels, 2 Classes (background, CAFO)

# Transfer Learning Part

In [4]:
model_pretrained = UNet(3, 2)
model_pretrained.load_state_dict(torch.load('../../../saved_models/driven/07_24_driven_lr_0.0003_epochs_40_batch_size_4.pth'))

<All keys matched successfully>

In [5]:
# Transferring over the weights in new model
model_dict = model.state_dict()
pretrained_model_dict = model_pretrained.state_dict()
for key in pretrained_model_dict:
    if (key != 'outc.conv.weight') & (key != 'outc.conv.bias'):
        model_dict[key] = pretrained_model_dict[key]

# Regular Training

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
train_num_batches = len(trainloader)
valid_num_batches = len(validloader)
train_num_examples = len(trainloader.dataset)
valid_num_examples = len(validloader.dataset)

In [8]:
#Set model to either cpu or gpu
model.to(device)            

#Define loss function
#Weight due to class imbalance
pos_weight = torch.tensor([1, 30]) #23 is good when doing 3 class
pos_weight = torch.reshape(pos_weight,(1,2,1,1)).to(device)
criterion = BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3,
                                     weight_decay = 1e-7)

In [9]:
today = date.today()
date_prefix = today.strftime("%m_%d")
log_dir_suffix = f"{date_prefix}_groundtruth_lr_{3e-4}_epochs_{10}_batch_size_{4}"
log_dir = "../logs/groundtruth/" + log_dir_suffix
writer = SummaryWriter(log_dir=log_dir)

In [10]:
torch.manual_seed(0)
np.random.seed(0)
model.zero_grad()
class_list = [0, 1]
for epoch in range(8):
    ### TRAINING ###
    print("Beginning Training in Epoch " + str(epoch))
    with tqdm(total = train_num_batches) as epoch_pbar:
        model.train()
        train_loss, train_correct, \
            train_IoU = train_one_epoch(epoch, train_num_batches, model, 
                                        device, trainloader, epoch_pbar, 
                                        optimizer, writer, criterion)

    ### VALIDATION ###
    print("Beginning Validation in Epoch " + str(epoch))
    valid_loss = []
    valid_correct = 0

    conf_matrix = np.zeros((2, 2))

    with tqdm(total = valid_num_batches) as epoch_pbar:
        model.eval()                           
        valid_loss, valid_correct, \
            conf_matrix, valid_IoU = valid_one_epoch(epoch, valid_num_batches, model, 
                                                     device, validloader, epoch_pbar, 
                                                     optimizer, writer, criterion,
                                                     conf_matrix, class_list)

  0%|          | 0/193 [00:00<?, ?it/s]

Beginning Training in Epoch 0


Epoch 0 - loss 0.3962 - acc 0.9837 - Mean IoU 0.1090: 100%|██████████| 193/193 [01:35<00:00,  2.01it/s]
  0%|          | 0/42 [00:00<?, ?it/s]

Beginning Validation in Epoch 0


Epoch 0 - loss 0.2152 - acc 0.9914 - Mean IoU 0.2014 - Precision 0.2737 - Recall 0.4327: 100%|██████████| 42/42 [00:30<00:00,  1.38it/s]
Epoch 1 - loss 0.1976 - acc 0.9928 - Mean IoU 0.2179:   1%|          | 1/193 [00:00<00:32,  5.85it/s]

Beginning Training in Epoch 1


Epoch 1 - loss 0.1558 - acc 0.9914 - Mean IoU 0.1888: 100%|██████████| 193/193 [01:35<00:00,  2.01it/s]
  0%|          | 0/42 [00:00<?, ?it/s]

Beginning Validation in Epoch 1


Epoch 1 - loss 0.1658 - acc 0.9946 - Mean IoU 0.2052 - Precision 0.4335 - Recall 0.2805: 100%|██████████| 42/42 [00:30<00:00,  1.38it/s]
Epoch 2 - loss 0.0962 - acc 0.9940 - Mean IoU 0.2157:   1%|          | 1/193 [00:00<00:31,  6.07it/s]

Beginning Training in Epoch 2


Epoch 2 - loss 0.1056 - acc 0.9928 - Mean IoU 0.2096: 100%|██████████| 193/193 [01:36<00:00,  2.00it/s]
  0%|          | 0/42 [00:00<?, ?it/s]

Beginning Validation in Epoch 2


Epoch 2 - loss 0.1023 - acc 0.9909 - Mean IoU 0.2039 - Precision 0.2655 - Recall 0.4677: 100%|██████████| 42/42 [00:30<00:00,  1.37it/s]
Epoch 3 - loss 0.0795 - acc 0.9918 - Mean IoU 0.2682:   1%|          | 1/193 [00:00<00:30,  6.35it/s]

Beginning Training in Epoch 3


Epoch 3 - loss 0.0859 - acc 0.9924 - Mean IoU 0.2113: 100%|██████████| 193/193 [01:36<00:00,  2.01it/s]
  0%|          | 0/42 [00:00<?, ?it/s]

Beginning Validation in Epoch 3


Epoch 3 - loss 0.1099 - acc 0.9947 - Mean IoU 0.1876 - Precision 0.4488 - Recall 0.2437: 100%|██████████| 42/42 [00:30<00:00,  1.38it/s]
Epoch 4 - loss 0.0608 - acc 0.9944 - Mean IoU 0.2434:   1%|          | 1/193 [00:00<00:31,  6.04it/s]

Beginning Training in Epoch 4


Epoch 4 - loss 0.0789 - acc 0.9928 - Mean IoU 0.2217: 100%|██████████| 193/193 [01:36<00:00,  2.00it/s]
  0%|          | 0/42 [00:00<?, ?it/s]

Beginning Validation in Epoch 4


Epoch 4 - loss 0.0753 - acc 0.9926 - Mean IoU 0.2455 - Precision 0.3316 - Recall 0.4859: 100%|██████████| 42/42 [00:30<00:00,  1.38it/s]
Epoch 5 - loss 0.0495 - acc 0.9923 - Mean IoU 0.1664:   1%|          | 1/193 [00:00<00:33,  5.78it/s]

Beginning Training in Epoch 5


Epoch 5 - loss 0.0754 - acc 0.9926 - Mean IoU 0.2220: 100%|██████████| 193/193 [01:37<00:00,  1.99it/s]
  0%|          | 0/42 [00:00<?, ?it/s]

Beginning Validation in Epoch 5


Epoch 5 - loss 0.0731 - acc 0.9899 - Mean IoU 0.2198 - Precision 0.2639 - Recall 0.5679: 100%|██████████| 42/42 [00:30<00:00,  1.38it/s]
Epoch 6 - loss 0.0474 - acc 0.9943 - Mean IoU 0.2184:   1%|          | 1/193 [00:00<00:32,  5.92it/s]

Beginning Training in Epoch 6


Epoch 6 - loss 0.0708 - acc 0.9926 - Mean IoU 0.2288: 100%|██████████| 193/193 [01:36<00:00,  1.99it/s]
  0%|          | 0/42 [00:00<?, ?it/s]

Beginning Validation in Epoch 6


Epoch 6 - loss 0.0806 - acc 0.9948 - Mean IoU 0.2488 - Precision 0.4720 - Recall 0.3448: 100%|██████████| 42/42 [00:30<00:00,  1.39it/s]
Epoch 7 - loss 0.0596 - acc 0.9933 - Mean IoU 0.2840:   1%|          | 1/193 [00:00<00:32,  5.93it/s]

Beginning Training in Epoch 7


Epoch 7 - loss 0.0697 - acc 0.9928 - Mean IoU 0.2305: 100%|██████████| 193/193 [01:36<00:00,  2.00it/s]
  0%|          | 0/42 [00:00<?, ?it/s]

Beginning Validation in Epoch 7


Epoch 7 - loss 0.0686 - acc 0.9921 - Mean IoU 0.2471 - Precision 0.3191 - Recall 0.5225: 100%|██████████| 42/42 [00:30<00:00,  1.37it/s]


In [None]:
returnPreReF(conf_matrix, 1)

In [None]:
# Testing it out on validation
val_example = next(iter(trainloader))
inputs = val_example[0].to(device)
labels = val_example[1].to(device)
with torch.no_grad():
    outputs = model(inputs)
_, predictions = torch.max(outputs, 1)
predictions_one_hot = torch.nn.functional.one_hot(predictions).permute(0, 3, 1, 2)
y_pred = predictions.flatten().cpu().numpy()
_, blah = torch.max(labels, 1)
y_true = blah.flatten().cpu().numpy()

In [None]:
torch.sum((predictions[0] == 1).int())

# CRF attempt

In [None]:
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import compute_unary, unary_from_softmax

In [None]:
def returnCRFmask(img_tensor, prediction, num_classes):
    """
    img_tensor : Tensor (Channel x Width x Length)
        Tensor of the original image that is gpu attached
    prediction : Tensor (Width x Length)
        Softmax output of the model
    num_classes : Int
        Number of prediction classes
    
    Returns
    -------
    A post-processed prediction mask for the image tensor
    
    """        
    # Changes input image into 255 format
    changedInput = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype('uint8') 
    
    # Get unary energy of the prediction image
    feat_first = prediction.reshape((num_classes, - 1)).cpu().numpy()
    unary = unary_from_softmax(feat_first)
    unary = np.ascontiguousarray(unary)        
    d = dcrf.DenseCRF2D(img_tensor.shape[2], img_tensor.shape[1], num_classes) # Create CRF filter
    d.setUnaryEnergy(unary)
    
    # Add original image to CRF
    d.addPairwiseGaussian(sxy=(3, 3), compat=5, kernel=dcrf.DIAG_KERNEL,
                          normalization=dcrf.NORMALIZE_SYMMETRIC)
    d.addPairwiseBilateral(sxy=(3, 3), srgb=(3, 3, 3), rgbim=np.ascontiguousarray(changedInput),
                       compat=10,
                       kernel=dcrf.DIAG_KERNEL,
                       normalization=dcrf.NORMALIZE_SYMMETRIC)
    
    Q = d.inference(5)
    res = np.argmax(Q, axis=0).reshape((img_tensor.shape[1], img_tensor.shape[2])) # Get the new mask    
    res = torch.nn.functional.one_hot(torch.Tensor(res).to(torch.int64)).permute(2, 0, 1) # Make it one hot
    
    return res
    

In [None]:
# New plotting function including CRF
kwarg_dict = { 'BACKGROUND':{'cmap':'prism', 'alpha': 0.5},
                  'Lagoon': {'cmap':'cool', 'alpha': 0.5},
                  'CAFO Shed': {'cmap':'hot', 'alpha': 0.0}}

def plotLabelPredictCRF(inputs, label, predictions_one_hot, softmax, 
                        num_classes, kwarg_dict):
    """
    inputs : Tensor (Channel x Width x Length)
        Tensor of a CAFO image
    label : Tensor (Num Classes x Width x Length)
        Tensor of the labels for original CAFO image
    predictions_one_hot : Tensor (Num Classes x Width x Length)
        One-hot tensor of predictions from model
    softmax : Tensor (Width x Length)
        Softmax output of the model
    num_classes : Int
        Number of classes
    kwarg_dict : Dict
        Dictionary for plotting
    
    Returns
    -------
    Plots the original image, prediction mask, and crf mask
    """
    crf_mask = returnCRFmask(inputs, softmax, num_classes)
    changedInput = (inputs.permute(1, 2, 0).cpu().numpy() * 255).astype('uint8') 
    
    f = plt.figure()
    a= f.add_subplot(1, 3, 1)    
    plt.imshow(changedInput)
    if label != None:
        plt.imshow(label[0].cpu().numpy(), **kwarg_dict['BACKGROUND'])
        plt.imshow(label[1].cpu().numpy(), **kwarg_dict['CAFO Shed']) 
    plt.axis('off')

    a = f.add_subplot(1, 3, 2)
    plt.imshow(changedInput)
    plt.imshow(predictions_one_hot[0].cpu().numpy(), **kwarg_dict['BACKGROUND'])
    plt.imshow(predictions_one_hot[1].cpu().numpy(), **kwarg_dict['CAFO Shed'])    
    plt.axis('off')

    a = f.add_subplot(1, 3, 3)
    plt.imshow(changedInput)
    plt.imshow(crf_mask[0].numpy(), **kwarg_dict['BACKGROUND'])
    plt.imshow(crf_mask[1].numpy(), **kwarg_dict['CAFO Shed'])    
    plt.axis('off')
    plt.show()
    

In [None]:
for i in range(4):
    plotLabelPredictCRF(inputs[i], labels[i], predictions_one_hot[i], outputs[i], NUM_CLASSES, kwarg_dict)

# Now trying on Planet Data

In [None]:
import rasterio
from PIL import Image
import glob
from torchvision.transforms import ToTensor
planet_images_dir = '../../../../../datadrive/data/raw/planet_images_il-2019-07/'
#pic_list = ['planet_loc_103-date_2019-07-01.tif', 'planet_loc_107-date_2019-07-01.tif', 'planet_loc_110-date_2019-07-01.tif',
#           'planet_loc_112-date_2019-07-01.tif']
#image_list = [planet_images_dir + i for i in pic_list]
image_list = glob.glob(planet_images_dir + '*.tif')

In [None]:
def planet_images(picture, model, device, kwarg_dict, num_classes):
    
    with rasterio.open(picture) as src:
        b, g, r, n = src.read()
    rgb = np.stack((r,g,b), axis=0)   
    example = Image.fromarray((np.rollaxis(rgb/rgb.max(), 0, 3)*255).astype(np.uint8))    
    example = ToTensor()(example)
    with torch.no_grad():
        output = model((example.unsqueeze(0)).to(device))    
    _, predictions = torch.max(output, 1)
    predictions_one_hot = torch.nn.functional.one_hot(predictions, num_classes=num_classes).permute(0, 3, 1, 2).squeeze(0)
        
    plotLabelPredictCRF(torch.Tensor(rgb/rgb.max()), None, predictions_one_hot, 
                         output, num_classes, kwarg_dict)    

In [None]:
for i in range(10):
    planet_images(image_list[i], model, device, kwarg_dict, NUM_CLASSES)

## Comparing before and after

In [None]:
dir_2019 = '../../../../../datadrive/data/raw/planet_images_il-2019-07/'
dir_2020 = '../../../../../datadrive/data/raw/planet_images_il-2020-07/'
image_num_list = ['035', '045', '048', '081', '160', '175', '177', '227']

f = plt.figure(figsize=(8,17)) 
f.suptitle("2019 vs 2020 CAFO Prediction on Planet Satellite Images", y = .92, fontsize=14)
for i, number in enumerate(image_num_list):
    im2019 = dir_2019 + "planet_loc_"+number+"-date_2019-07-01.tif"
    im2020 = dir_2020 + "planet_loc_"+number+"-date_2020-07-01.tif"
    
    with rasterio.open(im2019) as src:
        b1, g1, r1, n1 = src.read()
    rgb1 = np.stack((r1,g1,b1), axis=0)   
    im1 = Image.fromarray((np.rollaxis(rgb1/rgb1.max(), 0, 3)*255).astype(np.uint8))    
    im1 = ToTensor()(im1)
    with torch.no_grad():
        output1 = model((im1.unsqueeze(0)).to(device))    
    _, predictions1 = torch.max(output1, 1)
    predictions_one_hot1 = torch.nn.functional.one_hot(predictions1).permute(0, 3, 1, 2)
    crf_mask1 = returnCRFmask(torch.Tensor(rgb1/rgb1.max()), output1, NUM_CLASSES)
    
    with rasterio.open(im2020) as src:
        b2, g2, r2, n2 = src.read()
    rgb2 = np.stack((r2,g2,b2), axis=0)   
    im2 = Image.fromarray((np.rollaxis(rgb2/rgb2.max(), 0, 3)*255).astype(np.uint8))    
    im2 = ToTensor()(im2)
    
    with torch.no_grad():
        output2 = model((im2.unsqueeze(0)).to(device))    
    _, predictions2 = torch.max(output2, 1)
    predictions_one_hot2 = torch.nn.functional.one_hot(predictions2).permute(0, 3, 1, 2)
    crf_mask2 = returnCRFmask(torch.Tensor(rgb2/rgb2.max()), output2, NUM_CLASSES)
    
    # BEFORE IMAGE
    
    a1 = f.add_subplot(len(image_num_list), 6, 6*i + 1)
    plt.imshow(im1.permute(1, 2, 0).cpu().numpy())    
    plt.axis('off')
    
    a2 = f.add_subplot(len(image_num_list), 6, 6*i + 2)
    plt.imshow(im1.permute(1, 2, 0).cpu().numpy())
    plt.imshow(predictions_one_hot1[0][0].cpu().numpy(), **kwarg_dict['BACKGROUND'])    
    plt.imshow(predictions_one_hot1[0][1].cpu().numpy(), **kwarg_dict['CAFO Shed'])        
    a2.set_title(f'#CAFO: {torch.sum(predictions1 == 0).item()}')
    plt.axis('off')    
        
    a3 = f.add_subplot(len(image_num_list), 6, 6*i + 3)    
    plt.imshow(im1.permute(1, 2, 0).cpu().numpy())
    plt.imshow(crf_mask1[0].numpy(), **kwarg_dict['BACKGROUND'])
    plt.imshow(crf_mask1[1].numpy(), **kwarg_dict['CAFO Shed'])    
    plt.axis('off')    
    
    # AFTER IMAGE
    
    a4 = f.add_subplot(len(image_num_list), 6, 6*i + 4)
    plt.imshow(im2.permute(1, 2, 0).cpu().numpy())    
    plt.axis('off')    
    
    a5 = f.add_subplot(len(image_num_list), 6, 6*i + 5)
    plt.imshow(im2.permute(1, 2, 0).cpu().numpy())
    plt.imshow(predictions_one_hot2[0][0].cpu().numpy(), **kwarg_dict['BACKGROUND'])    
    plt.imshow(predictions_one_hot2[0][1].cpu().numpy(), **kwarg_dict['CAFO Shed'])        
    a5.set_title(f'#CAFO: {torch.sum(predictions2 == 0).item()}')
    plt.axis('off')
    
    a6 = f.add_subplot(len(image_num_list), 6, 6*i + 6)    
    plt.imshow(im2.permute(1, 2, 0).cpu().numpy())
    plt.imshow(crf_mask2[0].numpy(), **kwarg_dict['BACKGROUND'])
    plt.imshow(crf_mask2[1].numpy(), **kwarg_dict['CAFO Shed'])    
    plt.axis('off')    
plt.show()

# Saving the model

In [None]:
torch.save(model.state_dict(), f"../../../saved_models/finished/model8_10_ia_data.pth")