# Contrail Detection - Part II - Binary Classification
This is Part II of a short series of notebooks tackling the Kaggle competition on predicting the presence of contrails in infrared image bands. **Please see Part I of this series, where I introduce the problem and go over using a UNet.**

## Introduction
In Part I, we introduced the Kaggle competition of predicting contrails in the sky given infrared image bands. We explained why this detection is important and went over the data in detail.

We converted the infrared image bands into human-interpretable false color images. This is an **image segmentation** task, where a single image is classified at the pixel level. In our case, each pixel is either part of a contrail or it isn't.

To handle this task, we used a neural network architecture called a **UNet**, which consists of an encoder and a decoder. The encoder processes the input image directly and encodes it in a smaller latent space. The encoder's layers are usually taken from a pre-trained image classification model. In our case, we chose MobileNetV2 as our backbone. Meanwhile, the decoder upsamples the encoded image until it is the same size as the input image. To achieve this, we use pix2pix-style upsampling layers, which primarily consist of transposed 2D convolutions.
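
The upsampling step can be checked in isolation. The following is a minimal sketch, assuming the same transposed-convolution settings as the decoder layers in this series (kernel size 4, stride 2, padding 1). The channel counts and input size are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

# One transposed convolution with the decoder's settings.
# Channel counts (16 -> 8) are arbitrary, for illustration only.
up = nn.ConvTranspose2d(in_channels=16, out_channels=8,
                        kernel_size=4, stride=2, padding=1, bias=False)

x = torch.randn(1, 16, 8, 8)   # a batch with one 8 x 8 feature map
y = up(x)

print(y.shape)  # torch.Size([1, 8, 16, 16])
```

With these settings the output height is `(H - 1) * 2 - 2 * 1 + 4 = 2H`, so each layer doubles the spatial resolution.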

While the performance was satisfactory given enough epochs, we proposed an optimization to this network in the final section. As part of our exploratory data analysis, we examined what percentage of images and pixels actually contained contrails. We discovered that **approximately 70%** of images _do not_ contain contrails at all. Therefore, it could be feasible to implement a two-stage model: The first stage is a traditional binary classification at the image level. This model will predict whether the input image contains contrails **anywhere in the image.** The second stage will be our UNet that was trained in Part I.

The combination of these two should remove false positives in images that the first stage correctly identifies as contrail-free, and subsequently improve our Dice coefficient score.
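
As a rough illustration of how the two stages could be combined at inference time, here is a minimal sketch. The function `two_stage_predict`, the 0.5 threshold, and the two toy stand-in modules are all assumptions made for this example; they are not the models trained in this series.

```python
import torch
import torch.nn as nn

def two_stage_predict(classifier: nn.Module, segmenter: nn.Module,
                      images: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Return one binary mask per image, emptied when stage one says 'no contrail'."""
    with torch.no_grad():
        # Stage 1: one logit per image -> does the image contain any contrail?
        has_contrail = torch.sigmoid(classifier(images)).flatten() > threshold
        # Stage 2: one logit per pixel -> binary mask
        masks = (torch.sigmoid(segmenter(images)) > threshold).float()
    # Zero out the masks of images rejected by stage 1
    return masks * has_contrail.view(-1, 1, 1, 1).float()

# Toy stand-ins (hypothetical) so the sketch runs without trained weights.
class MeanBrightnessClassifier(nn.Module):
    def forward(self, x):
        # Positive logit when the mean pixel value is above 0.5
        return (x.mean(dim=(1, 2, 3)) - 0.5) * 10

class AllOnesSegmenter(nn.Module):
    def forward(self, x):
        # Always predicts "contrail" for every pixel
        return torch.full((x.shape[0], 1, x.shape[2], x.shape[3]), 5.0)

images = torch.stack([torch.zeros(3, 4, 4), torch.ones(3, 4, 4)])
out = two_stage_predict(MeanBrightnessClassifier(), AllOnesSegmenter(), images)
print(out[0].sum().item(), out[1].sum().item())  # 0.0 16.0
```

The segmenter here marks every pixel as a contrail, yet the first image still ends up with an empty mask because the classifier rejected it.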

We will primarily go over the implementation and integration of the regular binary classification model in this notebook. For details on the UNet and its implementation, please refer to Part I.

## Import Packages
All packages that we used earlier are applicable here. There are no additions or changes necessary. Additionally, we will keep the seed consistent so the dataset is split the same way.

In [1]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
from typing import Optional
from tqdm import tqdm

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchinfo

from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from torchvision.models.feature_extraction import create_feature_extractor

from torchmetrics import Dice, MeanMetric
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.functional import dice

import pytorch_lightning as pl
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS

pl.seed_everything(8128)

ROOT = '../'
DATA_DIR = os.path.join(ROOT, 'data', 'google-research-identify-contrails-reduce-global-warming', 'validation')
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Seed set to 8128
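
The seed matters because `train_test_split` is later called without a `random_state`, so it draws from NumPy's global random state, which `pl.seed_everything` seeds. The sketch below shows the effect using `np.random.seed` directly as a stand-in. The sample IDs and the helper `split_with_seed` are hypothetical.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical sample IDs standing in for the real directory names
sample_ids = [f"sample_{i}" for i in range(10)]

def split_with_seed(seed):
    # Seed NumPy's global RNG, then split without passing random_state
    np.random.seed(seed)
    return train_test_split(sample_ids, test_size=0.2)

train_a, val_a = split_with_seed(8128)
train_b, val_b = split_with_seed(8128)

print(train_a == train_b and val_a == val_b)  # True
```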


## UNet Definitions
Next, we'll fully define the UNet, which will make up the second stage of our model. We will reuse everything from Part I other than the final `LightningModule` definition, which has to be adjusted to accommodate the binary classification model.

In [2]:
# Function to get the segmented images
def get_contrail_image_mask(sample_id):
    # Read the 3 bands we need, and extract the target time step
    band_11 = np.load(os.path.join(DATA_DIR, sample_id, 'band_11.npy'))[:, :, 5]
    band_14 = np.load(os.path.join(DATA_DIR, sample_id, 'band_14.npy'))[:, :, 5]
    band_15 = np.load(os.path.join(DATA_DIR, sample_id, 'band_15.npy'))[:, :, 5]
    
    # Let's save the image size, will be useful later
    IMAGE_SIZE = band_11.shape[0]
    
    # Calculate R, G, and B channels, with the scaling.
    # Clip to between 0 and 1 so that we don't get invalid values
    red = ((band_15 - band_14 + 4) / (2 + 4)).clip(0, 1)
    green = ((band_14 - band_11 + 4) / (5 + 4)).clip(0, 1)
    blue = ((band_11 - 243) / (303 - 243)).clip(0, 1)
    # Stack them, then transpose so that the channels are last
    image = np.stack((red, green, blue), axis=0).transpose((1, 2, 0))
    
    # Now read the mask, it has an extra singleton channel dimension at the end,
    # so get rid of that.
    mask = np.load(os.path.join(DATA_DIR, sample_id, 'human_pixel_masks.npy')).squeeze()
    
    return image, mask

# pix2pix upsample layer
class Pix2PixUpsample(nn.Module):
    def __init__(self, in_chan, out_chan, kernel_size):
        super().__init__()
        
        self.conv2d = nn.ConvTranspose2d(in_chan, out_chan, kernel_size, stride=2, padding=1, bias=False)
        # Initialize weights with mean 0 and standard deviation 0.02
        nn.init.normal_(self.conv2d.weight, mean=0, std=0.02)
        
        self.model = nn.Sequential(
            self.conv2d, 
            nn.BatchNorm2d(out_chan),
            nn.Dropout(0.5),
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.model(x)

# False color dataset class
class GRContrailsFalseColorDataset(Dataset):
    def __init__(self, image_dir, sample_ids=None, test=False):
        """
        If sample_ids is None, then all the samples in image_dir will be read.
        :param image_dir: The directory with all the images.
        :param sample_ids: The list of sample IDs to use. Default None
        :param test: Whether this is test data or not. If true, then does not return the mask. Default False.
        """
        super().__init__()
        self.image_dir = image_dir
        if sample_ids is None:
            # Get a list of all the subdirectories in image_dir.
            # The first element is the directory itself, so index it out.
            self.sample_ids = [os.path.basename(subdir) for subdir, _, _ in os.walk(self.image_dir)][1:]
        else:
            self.sample_ids = sample_ids
        self.test = test
    
    def __len__(self):
        # Just return the length of the sample IDs
        return len(self.sample_ids)
    
    def __getitem__(self, idx):
        sample_id = self.sample_ids[idx]
        # Read in bands 11, 14, and 15 at the target time stamp
        band_11 = np.load(os.path.join(self.image_dir, sample_id, 'band_11.npy'))[:, :, 5]
        band_14 = np.load(os.path.join(self.image_dir, sample_id, 'band_14.npy'))[:, :, 5]
        band_15 = np.load(os.path.join(self.image_dir, sample_id, 'band_15.npy'))[:, :, 5]
        # Calculate R, G, and B channels
        red = ((band_15 - band_14 + 4) / (2 + 4)).clip(0, 1)
        green = ((band_14 - band_11 + 4) / (5 + 4)).clip(0, 1)
        blue = ((band_11 - 243) / (303 - 243)).clip(0, 1)
        # Concatenate them to create a false color image.
        # Do CHANNELS FIRST ordering (axis=0), the default for PyTorch.
        image = torch.from_numpy(np.stack((red, green, blue), axis=0))
        # Read in the mask, unless this is for testing.
        if not self.test:
            mask = np.load(os.path.join(self.image_dir, sample_id, 'human_pixel_masks.npy'))
            # Mask is 256 x 256 x 1, do a transpose so both input image and mask are the same shape.
            # Also convert to float.
            mask = torch.from_numpy(np.transpose(mask, (2, 0, 1))).to(torch.float)
            return image, mask
        else:
            return image

class GRContrailDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 128, num_workers: int = 4, 
                 pin_memory: bool = True, validation_split: float = 0.2):
        super().__init__()
        # This method allows all parameters to be in self.hparams
        # without defining each one individually.
        self.save_hyperparameters()
        # Define all dataset objects, initially setting to None
        self.train_dataset: Optional[Dataset] = None
        self.val_dataset: Optional[Dataset] = None
    
    def prepare_data(self) -> None:
        # Download data here, but we've already done so.
        pass
    
    def setup(self, stage: str) -> None:
        # Assign dataset objects here by invoking Dataset class with correct parameters
        if not self.train_dataset and not self.val_dataset:
            # Apply a train test split on the list of all sample IDs present in validation
            all_files = [os.path.basename(subdir) for subdir, _, _ in os.walk(DATA_DIR)][1:]
            train_files, val_files = train_test_split(all_files, test_size=self.hparams.validation_split)
            # Create the two Dataset objects using each of the file lists
            self.train_dataset = GRContrailsFalseColorDataset(DATA_DIR, sample_ids=train_files)
            self.val_dataset = GRContrailsFalseColorDataset(DATA_DIR, sample_ids=val_files)
    
    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True,
                          num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory)
    
    def val_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, shuffle=False,
                          num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory)

class UMobileNet(nn.Module):
    def __init__(self, image_size):
        super().__init__()
        mobilenet = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
        layers = {
            'features.2.conv.0': 'block_1',    # 128 x 128
            'features.4.conv.0': 'block_3',    # 64 x 64
            'features.7.conv.0': 'block_6',    # 32 x 32
            'features.14.conv.0': 'block_13',  # 16 x 16
            'features.17.conv.2': 'block_16'   # 8 x 8
        }
        encoder = create_feature_extractor(mobilenet, return_nodes=layers)
        self.encoder = encoder
        
        self.up_stack = nn.ModuleList([
            Pix2PixUpsample(320, 512, 4),        # 8 x 8 ==> 16 x 16
            Pix2PixUpsample(576 + 512, 256, 4),  # 16 x 16 ==> 32 x 32
            Pix2PixUpsample(192 + 256, 128, 4),  # 32 x 32 ==> 64 x 64
            Pix2PixUpsample(144 + 128, 64, 4)    # 64 x 64 ==> 128 x 128
        ])


        # The final layer is just a simple transposed convolution, with a single output channel.
        self.last_conv = nn.ConvTranspose2d(in_channels=96 + 64, out_channels=1, kernel_size=4, stride=2, padding=1)
    
    def forward(self, x: torch.Tensor):
        # Push through encoder
        skips = list(self.encoder(x).values())
        x = skips[-1]
        skips = skips[::-1][1:]
        
        for up, skip_connection in zip(self.up_stack, skips):
            x = up(x)
            x = torch.cat([x, skip_connection], dim=1)
        x = self.last_conv(x)
        return x
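
The false color conversion used above can be sanity-checked on synthetic data. This is a minimal sketch: the helper name `false_color` and the constant band values are assumptions for illustration, while the scaling constants are the ones used in the cell above.

```python
import numpy as np

def false_color(band_11, band_14, band_15):
    """Build a channels-first false color image from three infrared bands."""
    red = ((band_15 - band_14 + 4) / (2 + 4)).clip(0, 1)
    green = ((band_14 - band_11 + 4) / (5 + 4)).clip(0, 1)
    blue = ((band_11 - 243) / (303 - 243)).clip(0, 1)
    return np.stack((red, green, blue), axis=0)

# Synthetic 4 x 4 bands with constant values (made up for this example)
band_11 = np.full((4, 4), 273.0)
band_14 = np.full((4, 4), 272.0)
band_15 = np.full((4, 4), 271.0)

image = false_color(band_11, band_14, band_15)
print(image.shape)     # (3, 4, 4)
print(image[:, 0, 0])  # red = 0.5, green = 1/3, blue = 0.5
```

Values that fall outside the scaling range are clipped, so every channel stays within `[0, 1]`.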

## Dataset and DataModule changes
Because we need to carry forward information about whether the image contains contrails at all, we need to add a binary flag to the label alongside the original mask. This is simply an extra output in the return statement. I've marked the main change below.

In [3]:
# False color dataset class
class GRContrailsFalseColorDataset(Dataset):
    def __init__(self, image_dir, sample_ids=None, test=False):
        """
        If sample_ids is None, then all the samples in image_dir will be read.
        :param image_dir: The directory with all the images.
        :param sample_ids: The list of sample IDs to use. Default None
        :param test: Whether this is test data or not. If true, then does not return the mask. Default False.
        """
        super().__init__()
        self.image_dir = image_dir
        if sample_ids is None:
            # Get a list of all the subdirectories in image_dir.
            # The first element is the directory itself, so index it out.
            self.sample_ids = [os.path.basename(subdir) for subdir, _, _ in os.walk(self.image_dir)][1:]
        else:
            self.sample_ids = sample_ids
        self.test = test
    
    def __len__(self):
        # Just return the length of the sample IDs
        return len(self.sample_ids)
    
    def __getitem__(self, idx):
        sample_id = self.sample_ids[idx]
        # Read in bands 11, 14, and 15 at the target time stamp
        band_11 = np.load(os.path.join(self.image_dir, sample_id, 'band_11.npy'))[:, :, 5]
        band_14 = np.load(os.path.join(self.image_dir, sample_id, 'band_14.npy'))[:, :, 5]
        band_15 = np.load(os.path.join(self.image_dir, sample_id, 'band_15.npy'))[:, :, 5]
        # Calculate R, G, and B channels
        red = ((band_15 - band_14 + 4) / (2 + 4)).clip(0, 1)
        green = ((band_14 - band_11 + 4) / (5 + 4)).clip(0, 1)
        blue = ((band_11 - 243) / (303 - 243)).clip(0, 1)
        # Concatenate them to create a false color image.
        # Do CHANNELS FIRST ordering (axis=0), the default for PyTorch.
        image = torch.from_numpy(np.stack((red, green, blue), axis=0))
        # Read in the mask, unless this is for testing.
        if not self.test:
            mask = np.load(os.path.join(self.image_dir, sample_id, 'human_pixel_masks.npy'))
            # Mask is 256 x 256 x 1, do a transpose so both input image and mask are the same shape.
            # Also convert to float.
            mask = torch.from_numpy(np.transpose(mask, (2, 0, 1))).to(torch.float)
            return image, int(torch.any(mask)), mask   #### CHANGE HERE!!
        else:
            return image

The `DataModule` itself won't change, because it is mostly agnostic to what the `Dataset` returns. What **will** need to change is the `LightningModule`, because we have an extra output to manage.
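
To see what the extra output looks like once it is batched, here is a minimal sketch using a toy dataset. `ToyContrailDataset` is a hypothetical stand-in that returns the same `(image, flag, mask)` triple as the dataset class, but with synthetic tensors instead of files on disk.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyContrailDataset(Dataset):
    """Hypothetical stand-in: odd indices contain a single 'contrail' pixel."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        image = torch.zeros(3, 8, 8)
        mask = torch.zeros(1, 8, 8)
        if idx % 2 == 1:
            mask[0, 0, 0] = 1.0
        # Same return structure as the dataset class: image, binary flag, mask
        return image, int(torch.any(mask)), mask

loader = DataLoader(ToyContrailDataset(), batch_size=4, shuffle=False)
images, flags, masks = next(iter(loader))

print(images.shape, masks.shape)  # torch.Size([4, 3, 8, 8]) torch.Size([4, 1, 8, 8])
print(flags)                      # tensor([0, 1, 0, 1])
```

The default collate function turns the Python `int` flags into a single integer tensor of shape `(batch_size,)`. A loss such as `nn.BCEWithLogitsLoss` would likely need these cast to float first.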

But before we get to the `LightningModule`, we need to define our binary classification model. This is just a standard `nn.Module`.