# Google Research - Identify Contrails to Reduce Global Warming
This is the notebook used to train the UNET model for the "Google Research - Identify Contrails to Reduce Global Warming" competition on Kaggle. It ranked 765/954 with dice score 0.59090.

The dataset used for training was made by kaggler Shashwat Raman.

# Imports

In [150]:
from pathlib import Path
import os
import random
import math
from collections import defaultdict
import cv2
import skimage

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from PIL import Image
from tqdm.notebook import tqdm
from transformers import get_cosine_schedule_with_warmup

torch.__version__

'2.0.0'

In [151]:
!pip install segmentation-models-pytorch
import segmentation_models_pytorch as smp

[0m

# Settings and Paths

In [152]:
class Config:
    train=True
    
    num_epochs=20
    
    num_classes=1
    batch_size=32
    seed=42
    
    encoder = 'efficientnet-b0'
    pretrained = True
    weights = 'imagenet'
    
    classes = ['contrail']
    activation = None
    in_chans = 3
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    image_size = 256
    lr = 3e-4
    warmup = 0
    
class Paths:
    train_data_csv = '/kaggle/input/contrails-images-ash-color/train_df.csv'
    valid_data_csv = '/kaggle/input/contrails-images-ash-color/valid_df.csv'
    contrails = '/kaggle/input/contrails-images-ash-color/contrails/'

In [153]:
def set_seed(seed=1234):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

# Dataset

In [154]:
class ContrailsDataset(torch.utils.data.Dataset):
    def __init__(self, df, train):
        self.df=df
        self.train=train
        
    def __getitem__(self, index):
        # Accesses sample and label here (according to index)
        row = self.df.iloc[index]
        con_path = row.path
        con = np.load(str(con_path))
        
        # Selects all dimensions before last one. In last dimension, selects all
        # elements, excluding last one (':-1' means slice to but not include last 
        # one).
        # All dimensions excluding last element of last one makes up the sample image.
        img = con[..., :-1]
        
        # Selects all dimensions before last one + last element of last dimension
        # All dimensions + last element of last one make up label
        label = con[..., -1]
        
        img = torch.tensor(img)
        label = torch.tensor(label)
        
        img = img.permute(2, 0, 1)
    
        # Returns tuple !!! (sample, label)
        # Indexing into ContrailDataset returns both sample and label!
        # img = C x H x W (3 x 256 x 256)
        # label = H x W (256 x 256)
        return img.float(), label.float()
    
    def __len__(self):
        return len(self.df)

In [155]:
train_df = pd.read_csv(Paths.train_data_csv)
valid_df = pd.read_csv(Paths.valid_data_csv)

display(train_df)
display(valid_df)

train_df['path'] = Paths.contrails+train_df['record_id'].astype(str)+'.npy'
valid_df['path'] = Paths.contrails+valid_df['record_id'].astype(str)+'.npy'

display(train_df)
display(valid_df)

train_ds = ContrailsDataset(train_df, train=True)
valid_ds = ContrailsDataset(valid_df, train=False)

display(train_ds[0])

train_dl = DataLoader(train_ds, batch_size=Config.batch_size, shuffle=True, num_workers=2)
valid_dl = DataLoader(valid_ds, batch_size=Config.batch_size, shuffle=False, num_workers=2)

Unnamed: 0,record_id,train
0,1284412112608546821,train
1,7457695218848685981,train
2,836236084461732921,train
3,7829917977180135058,train
4,5319255125658459358,train
...,...,...
20524,8443915190215904823,train
20525,8495643844280686935,train
20526,856381910009426679,train
20527,3751790308836191485,train


Unnamed: 0,record_id,train
0,3687499407028137410,valid
1,6558861185867890815,valid
2,7355354609194882312,valid
3,7547747455642200110,valid
4,5456834089979970017,valid
...,...,...
1851,922629314296188212,valid
1852,3319793057592206418,valid
1853,5640456394563366318,valid
1854,6742201885695641013,valid


Unnamed: 0,record_id,train,path
0,1284412112608546821,train,/kaggle/input/contrails-images-ash-color/contr...
1,7457695218848685981,train,/kaggle/input/contrails-images-ash-color/contr...
2,836236084461732921,train,/kaggle/input/contrails-images-ash-color/contr...
3,7829917977180135058,train,/kaggle/input/contrails-images-ash-color/contr...
4,5319255125658459358,train,/kaggle/input/contrails-images-ash-color/contr...
...,...,...,...
20524,8443915190215904823,train,/kaggle/input/contrails-images-ash-color/contr...
20525,8495643844280686935,train,/kaggle/input/contrails-images-ash-color/contr...
20526,856381910009426679,train,/kaggle/input/contrails-images-ash-color/contr...
20527,3751790308836191485,train,/kaggle/input/contrails-images-ash-color/contr...


Unnamed: 0,record_id,train,path
0,3687499407028137410,valid,/kaggle/input/contrails-images-ash-color/contr...
1,6558861185867890815,valid,/kaggle/input/contrails-images-ash-color/contr...
2,7355354609194882312,valid,/kaggle/input/contrails-images-ash-color/contr...
3,7547747455642200110,valid,/kaggle/input/contrails-images-ash-color/contr...
4,5456834089979970017,valid,/kaggle/input/contrails-images-ash-color/contr...
...,...,...,...
1851,922629314296188212,valid,/kaggle/input/contrails-images-ash-color/contr...
1852,3319793057592206418,valid,/kaggle/input/contrails-images-ash-color/contr...
1853,5640456394563366318,valid,/kaggle/input/contrails-images-ash-color/contr...
1854,6742201885695641013,valid,/kaggle/input/contrails-images-ash-color/contr...


(tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 5.8563e-02,
           2.8351e-02, 0.0000e+00],
          [2.6169e-02, 6.3438e-03, 1.7853e-02,  ..., 2.5330e-02,
           3.0472e-02, 1.0271e-03],
          [1.3269e-01, 1.7322e-01, 2.4255e-01,  ..., 2.1347e-02,
           1.9592e-02, 5.0664e-06],
          ...,
          [4.1821e-01, 4.3457e-01, 4.5190e-01,  ..., 4.8364e-01,
           5.1270e-01, 5.0977e-01],
          [4.2090e-01, 4.3213e-01, 4.3335e-01,  ..., 5.0098e-01,
           5.0977e-01, 5.1611e-01],
          [4.7070e-01, 4.2310e-01, 3.6890e-01,  ..., 5.3027e-01,
           5.2246e-01, 5.0928e-01]],
 
         [[5.8008e-01, 5.0049e-01, 4.6729e-01,  ..., 6.4844e-01,
           6.4893e-01, 6.2109e-01],
          [6.2354e-01, 5.6445e-01, 5.1807e-01,  ..., 6.2646e-01,
           6.3379e-01, 6.2402e-01],
          [5.9814e-01, 5.8203e-01, 5.8691e-01,  ..., 6.2109e-01,
           6.3281e-01, 6.2793e-01],
          ...,
          [3.2251e-01, 3.2202e-01, 3.4790e-01,  ..., 4.675

# Training

In [156]:
def dice_coef(y_true, y_pred, thr=0.5, epsilon=0.001):
    
    # If not flattened, will do 2D matrix multiplication in "(y_true * y_pred)" which is crazy. But all we 
    # want to do is multiply each pair of elements that are in the same position across both Tensors.
    y_true = y_true.flatten()
    
    # "(y_pred>thr)" is basically [(i>thr?1:0) for i in y_pred]. Basically maps >thr conditional 
    # statement on each element of y_pred, so each element becomes 1 or 0.
    y_pred = (y_pred>thr).astype(np.float32).flatten()
    
    # Sums the intersection of 1s ONLY (however mathematically, the entire dice coefficient will account
    # for 0s as well, in all 4 cases of 0 0, 1 0, 0 1, 1 1, where left is ground right is pred).
    inter = (y_true*y_pred).sum()
    inter2 = sum([y_true[i]==y_pred[i] for i in range(len(y_pred))])
    
    # Sum of the sum of the set of 1s in ground truth and of the sum of the set of 1s in prediction
    den = y_true.sum() + y_pred.sum()
    
    # ** Notice how we add epsilon = 0.001 to the numerator and denominator? This is to prevent 0 
    # division error, as in the situation that the ground truth has no masks and the prediction has no 
    # masks, "den" would equal 0. Thus, epsilon is used to prevent 0 division error, while maintaining
    # (relatively) proportionality between numerator and denominator.
    dice = ((2*inter+epsilon)/(den+epsilon))
    return dice

In [157]:
class UNet(nn.Module):
    def __init__(self, cfg):
        super(UNet, self).__init__()
        
        self.cfg = cfg
        self.training = True
        
        self.model = smp.Unet(
            encoder_name = cfg.encoder,
            encoder_weights=cfg.weights, 
            decoder_use_batchnorm=True,
            classes=len(cfg.classes), 
            activation=cfg.activation,
        )
        
        self.loss_fn = smp.losses.DiceLoss(mode='binary')
        
    def forward(self, x, y):
        '''
        x = sample (B x C x H x W image)
        y = label (B x H x W ground truth)
        '''
        
        # pre-activation function output from CNN
        
        logits = self.model(x)
        
        loss = self.loss_fn(logits, y)
        
        # probabilities are given after passing logits through sigmoid function (sigmoid(x)=1/(1+e^(-x))).
        # They can be interpreted as percentage probabilities.
        # Note: The probabilities across different classes are not required to sum to 1, so percentage may
        # not be the best interpretation.
        probabilities = logits.sigmoid()
        return {"loss": loss, "probabilities": probabilities, "logits": logits, "target": y}

In [158]:
def train_step(model, dataloader, optimizer, device):
    
    # Put model in training mode
    model.train()
    
    # Array containing train loss values
    train_losses = []
    
    # Progress bar
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc="Training")
    
    for idx, (x, y) in progress_bar:
        
        # Move tensors to other device for efficiency
        x = x.to(device)
        y = y.to(device)
        
        torch.set_grad_enabled(True)
        
        output_dict = model(x, y)
        
        loss = output_dict["loss"]
        
        train_losses.append(loss.item())
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if scheduler is not None:
            scheduler.step()
            
    train_loss = np.sum(train_losses)
        
    return train_loss

In [159]:
def test_step(model, dataloader, device):
    model.eval()
    
    torch.set_grad_enabled(False)
    
    val_data = defaultdict(list)
    
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc="Validating")
    
    for idx, (x, y) in progress_bar:
        x = x.to(device)
        y = y.to(device)
        
        output_dict = model(x, y)
        
        for i, j in output_dict.items():
            val_data[i]+=[output_dict[i]]
        
    val_data['loss'] = torch.stack(val_data['loss'])
    val_data['target'] = torch.cat(val_data['target'], dim=0).cpu().detach().numpy()
    val_data['logits'] = torch.cat(val_data['logits'], dim=0).cpu().detach().numpy()
    
    
    val_losses = val_data["loss"].cpu().numpy()
    val_loss = np.sum(val_losses)
    val_dice = dice_coef(val_data['target'], val_data['logits'])
        
    return val_loss, val_dice

In [160]:
from tqdm.auto import tqdm

In [161]:
def train(model, train_dl, val_dl, optimizer, epochs, device):
    
    train_data = {'train_loss':[], 'val_loss':[], 'val_dice':[]}
    
    for epoch in range(Config.num_epochs):
        set_seed(Config.seed+epoch)
        
        train_loss = train_step(model, train_dl, optimizer, device)
        
        val_loss, val_dice = test_step(model, val_dl, device)
        
        train_loss = train_loss/len(train_ds)
        val_loss = val_loss/len(valid_ds)
        
        train_data['train_loss'].append(train_loss)
        train_data['val_loss'].append(val_loss)
        train_data['val_dice'].append(val_dice)
        
        epoch_path = f"epoch-{epoch}.pth"
        torch.save(model.state_dict(), epoch_path)
        
    return train_data

In [162]:
def get_optimizer(lr, params):
    model_optimizer = torch.optim.Adam(
        filter(lambda parameter: parameter.requires_grad, params),
        lr, weight_decay=0
    )
    return model_optimizer

In [163]:
def get_scheduler(cfg, optimizer, total_steps):
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        
        # Note: We must do (total_steps//cfg.batch_size) because data is being passed into model in 
        # batches. I.e. every time a prediction is made, the learning rate is adjusted.
        num_warmup_steps= cfg.warmup * (total_steps // cfg.batch_size),
        
        # cfg.num_epochs * (total_steps//cfg.batch_size) is the total number of predictions
        num_training_steps= cfg.num_epochs * (total_steps // cfg.batch_size)
    )
    
    return scheduler

In [164]:
display(len(train_dl))
for i in train_dl:
    display(len(i))
    display(i[0].shape)
    display(i[1].shape)
    display(i[0].type)
    display(i[1].type)
    break
    
display(len(valid_dl))
for i in valid_dl:
    display(len(i))
    display(i[0].shape)
    display(i[1].shape)
    display(i[0].type)
    display(i[1].type)
    break



model = UNet(Config).to(Config.device)

optimizer = get_optimizer(Config.lr, model.parameters())
scheduler = get_scheduler(Config, optimizer, len(train_dl))
epochs = Config.num_epochs

from timeit import default_timer as timer
start_time = timer()

model_results = train(model, train_dl, valid_dl, optimizer, epochs, Config.device)

end_time = timer()

print(f'Total Training time: {end_time-start_time:.3f} seconds')

642

2

torch.Size([32, 3, 256, 256])

torch.Size([32, 256, 256])

<function Tensor.type>

<function Tensor.type>

58

2

torch.Size([32, 3, 256, 256])

torch.Size([32, 256, 256])

<function Tensor.type>

<function Tensor.type>

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Training:   0%|          | 0/642 [00:00<?, ?it/s]

Validating:   0%|          | 0/58 [00:00<?, ?it/s]

Total Training time: 5224.133 seconds
