In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt 
import torch 
import torch.nn as nn
from PIL import Image 
from pathlib import Path
import cv2
import torch.optim
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader, sampler
import time
import torchvision.models as models 
from torchinfo import summary
from tqdm import tqdm
import glob
from IPython.display import clear_output

# READING DATASET

In [None]:
# Root of the 38-Cloud training split. Every band directory below is derived
# from it, so relocating the dataset only requires changing this one path.
base_path = Path('/kaggle/input/38cloud-cloud-segmentation-in-satellite-images/38-Cloud_training/')

# One directory per spectral band, plus the ground-truth masks.
red_dir = base_path / 'train_red'
green_dir = base_path / 'train_green'
blue_dir = base_path / 'train_blue'
nir_dir = base_path / 'train_nir'
gt_dir = base_path / 'train_gt'

# CREATING A CUSTOM PYTORCH DATASET CLASS

In [None]:
class CloudDataset(Dataset):
    """Dataset of image patches stored as one single-band file per channel.

    Every file found directly inside ``red_dir`` defines one sample. The
    green, blue, NIR and ground-truth files are looked up in their own
    directories under the same file name with the leading ``red`` replaced
    by ``green``, ``blue``, ``nir`` and ``gt`` respectively.

    Args:
        red_dir, green_dir, blue_dir, nir_dir, gt_dir: ``pathlib.Path``
            directories holding the individual bands and the masks.
        pytorch: when True, ``__getitem__`` returns images channel-first
            (C, H, W); otherwise channel-last (H, W, C).
    """

    def __init__(self, red_dir, green_dir, blue_dir, nir_dir, gt_dir, pytorch=True):
        super().__init__()
        # iterdir() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> patch mapping stable between runs.
        red_files = sorted(f for f in red_dir.iterdir() if not f.is_dir())
        self.files = [self.combine_files(f, green_dir, blue_dir, nir_dir, gt_dir)
                      for f in red_files]
        self.pytorch = pytorch

    def combine_files(self, r_file: Path, green_dir, blue_dir, nir_dir, gt_dir):
        """Return a dict with the paths of all bands belonging to ``r_file``.

        Only the first occurrence of ``red`` in the file name is substituted,
        so a name that happens to contain "red" elsewhere is not corrupted.
        """
        def sibling(directory, band):
            return directory / r_file.name.replace('red', band, 1)

        return {'red': r_file,
                'green': sibling(green_dir, 'green'),
                'blue': sibling(blue_dir, 'blue'),
                'nir': sibling(nir_dir, 'nir'),
                'gt': sibling(gt_dir, 'gt')}

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _read_band(path):
        """Load one image file into a NumPy array and close the file handle."""
        with Image.open(path) as img:
            return np.array(img)

    def open_as_array(self, idx, invert=False, include_nir=False):
        """Stack the bands of sample ``idx`` and scale them to [0, 1].

        Args:
            idx: index into ``self.files``.
            invert: if True, return channel-first (C, H, W) instead of (H, W, C).
            include_nir: if True, append the NIR band as a fourth channel.

        Returns:
            Float array; values are divided by the maximum of the stored
            integer dtype (``np.iinfo(dtype).max``).
        """
        entry = self.files[idx]
        band_names = ['red', 'green', 'blue']
        if include_nir:
            band_names.append('nir')

        raw = np.stack([self._read_band(entry[name]) for name in band_names], axis=2)

        if invert:
            raw = raw.transpose((2, 0, 1))

        # normalize by the largest value the integer dtype can hold
        return raw / np.iinfo(raw.dtype).max

    def open_mask(self, idx, add_dims=False):
        """Return the mask of sample ``idx`` as 0/1 integers.

        Pixels equal to 255 become 1, everything else 0. With ``add_dims``
        a leading axis of size 1 is added.
        """
        raw_mask = self._read_band(self.files[idx]['gt'])
        raw_mask = np.where(raw_mask == 255, 1, 0)

        return np.expand_dims(raw_mask, 0) if add_dims else raw_mask

    def __getitem__(self, idx):
        """Return ``(image, mask)``: float32 4-band image and int64 mask."""
        x = torch.tensor(self.open_as_array(idx, invert=self.pytorch, include_nir=True),
                         dtype=torch.float32)
        y = torch.tensor(self.open_mask(idx, add_dims=False), dtype=torch.int64)

        return x, y

    def open_as_pil(self, idx):
        """Return the RGB bands of sample ``idx`` as an 8-bit PIL image."""
        # open_as_array() yields values in [0, 1]; scale by 255 (not 256) so
        # that a full-scale pixel maps to 255 instead of wrapping around to 0
        # when cast to uint8.
        arr = 255 * self.open_as_array(idx)

        return Image.fromarray(arr.astype(np.uint8), 'RGB')

    def __repr__(self):
        return 'Dataset class with {} files'.format(self.__len__())

In [None]:
# Build the dataset from the per-band directories defined above.
data = CloudDataset(
    red_dir=red_dir,
    green_dir=green_dir,
    blue_dir=blue_dir,
    nir_dir=nir_dir,
    gt_dir=gt_dir,
)

In [None]:
# Number of samples, i.e. files found directly inside the red-band directory.
len(data)

In [None]:
# Print the dataset's one-line summary (falls back to CloudDataset.__repr__).
print(data)

# Plotting some stacked images 

In [None]:
# Fetch a single (image, mask) pair through __getitem__ and inspect the
# tensor shapes it produces.
x ,y =data[2000]
x.shape ,y.shape

In [None]:
# Show patch 150 (RGB bands) next to its cloud mask.
fig, ax = plt.subplots(1, 2, figsize=(10, 9))
panels = (
    ('image', data.open_as_array(150)),
    ('mask', data.open_mask(150)),
)
for axis, (title, pixels) in zip(ax, panels):
    axis.imshow(pixels)
    axis.set_title(title)
plt.show()

In [None]:
# Preview 15 randomly chosen patches next to their masks.
# The index tensor is called `sample_ids` rather than `id` so that the
# built-in id() is not shadowed for the rest of the notebook.
sample_ids = torch.randint(0, len(data), (15,))
id_list = sample_ids.tolist()
for ids in id_list:
    plt.subplot(1, 2, 1)
    plt.imshow(data.open_as_array(ids))
    plt.subplot(1, 2, 2)
    plt.imshow(data.open_mask(ids))
    plt.show()

In [None]:
# Hold out 1000 samples for validation and train on the rest. The training
# size is derived from len(data) so the split does not raise when the number
# of files differs from 8400 (random_split requires the lengths to sum to
# len(data)).
n_valid = 1000
n_train = len(data) - n_valid

# Fixed seed so the same samples land in each subset on every run.
split_rng = torch.Generator().manual_seed(42)
train_ds, valid_ds = torch.utils.data.random_split(data, (n_train, n_valid), generator=split_rng)

train_dl = DataLoader(train_ds, batch_size=12, shuffle=True)
# Validation metrics do not depend on sample order, so no shuffling is needed.
valid_dl = DataLoader(valid_ds, batch_size=12, shuffle=False)

In [None]:
# Draw one batch from the training loader for a quick sanity check.
xb ,yb =next(iter(train_dl))

In [None]:
# Shapes of the image batch and of the mask batch.
xb.shape ,yb.shape

# Define Jaccard function as a performance metric

In [None]:
def jaccard(img1, img2):
    """Mean Jaccard index (intersection over union) over a batch of masks.

    Both inputs are converted to boolean arrays (non-zero means True). The
    first axis is treated as the batch axis and all remaining axes are
    flattened, so one score is computed per batch entry. An entry whose
    union is empty scores 0. The per-entry scores are averaged.
    """
    mask_a = np.array(img1).astype(bool)
    mask_b = np.array(img2).astype(bool)

    def _coverage(mask):
        # fraction of True pixels per batch entry
        return mask.reshape(mask.shape[0], -1).mean(axis=-1)

    overlap = _coverage(np.logical_and(mask_a, mask_b))
    combined = _coverage(np.logical_or(mask_a, mask_b))

    # An empty union implies an empty intersection; dividing 0 by a dummy 1
    # yields the score 0 for that entry without a division by zero.
    safe_combined = np.where(combined == 0, 1.0, combined)
    scores = overlap / safe_combined

    return scores.mean(0)

# Models

# VGGnet

In [None]:
# Layer layouts: an int is the output-channel count of a 3x3 convolution,
# 'M' is a 2x2 max-pool with stride 2.
VGG_types = {
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG_net(nn.Module):
    """VGG-style network: conv feature extractor followed by three linear layers.

    Args:
        in_channels: number of channels of the input image.
        out_channels: size of the final linear layer's output.
        architecture: key into ``VGG_types`` selecting the conv layout
            (defaults to ``'VGG19'``, the layout used so far).

    Raises:
        ValueError: if ``architecture`` is not a key of ``VGG_types``.
    """

    def __init__(self, in_channels, out_channels, architecture='VGG19'):
        super().__init__()
        if architecture not in VGG_types:
            raise ValueError(
                'Unknown architecture {!r}; expected one of {}'.format(architecture, sorted(VGG_types)))

        self.in_channels = in_channels
        self.conv_layers = self.create_conv_layers(VGG_types[architecture])

        # Five 2x2 poolings reduce a 384x384 input to 512x12x12. The adaptive
        # pool is the identity for that size and resamples any other input
        # resolution to 12x12, so the first linear layer always receives
        # 512*12*12 = 73728 features.
        self.avgpool = nn.AdaptiveAvgPool2d((12, 12))

        self.fcs = nn.Sequential(
            nn.Linear(512 * 12 * 12, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, out_channels),
        )

    def forward(self, x):
        """Run conv layers, pool to 12x12, flatten and apply the linear head."""
        x = self.conv_layers(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch axis, flatten everything else
        return self.fcs(x)

    def create_conv_layers(self, architecture):
        """Build the conv stack described by ``architecture``.

        Each int adds Conv2d(3x3, padding 1) -> BatchNorm2d -> ReLU with that
        many output channels; each ``'M'`` adds a 2x2 max-pool.

        Raises:
            ValueError: on an entry that is neither an int nor ``'M'``.
        """
        layers = []
        in_channels = self.in_channels

        for x in architecture:
            if isinstance(x, int):
                out_channels = x
                layers += [nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                           nn.BatchNorm2d(out_channels),
                           nn.ReLU()]
                in_channels = out_channels
            elif x == 'M':
                layers += [nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))]
            else:
                # Previously such entries were skipped silently, hiding typos
                # in the layout definition.
                raise ValueError('Unsupported layer specification: {!r}'.format(x))

        return nn.Sequential(*layers)

In [None]:
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#VGG_NET = VGG_net(4 ,2).to(device)
VGG_NET = VGG_net(in_channels=4, out_channels=2)

# Smoke test on one random 4-channel 384x384 input. Only the output shape is
# of interest, so skip building the autograd graph.
x = torch.randn(1, 4, 384, 384)
with torch.no_grad():
    print(VGG_NET(x).shape)

In [None]:
#import hiddenlayer as hl

#transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

#graph = hl.build_graph(VGG_NET, torch.zeros([12, 4, 384, 384]), transforms=transforms)
#graph.theme = hl.graph.THEMES['blue'].copy()
#graph.save('rnn_hiddenlayer', format='png')

In [None]:
# Display the module tree of the VGG network.
VGG_NET

In [None]:
# torchinfo summary for an input of shape (8, 4, 384, 384).
summary(VGG_NET, input_size=(8, 4, 384, 384))

In [None]:
#testing one pass
xb, yb = next(iter(train_dl))

# Shape check only: no gradients are needed, and disabling them avoids
# keeping the activations of a whole batch in memory.
with torch.no_grad():
    pred = VGG_NET(xb)

# A cell only displays its last expression, so show all three shapes together.
xb.shape, yb.shape, pred.shape

# UNET

In [None]:
from torch import nn
class UNET(nn.Module):
    """Small U-Net with three contracting and three expanding stages.

    Each contracting stage halves the spatial size and each expanding stage
    doubles it, so the output has the same height and width as the input
    when these are divisible by 8.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels (classes) of the output map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Exposed under the names that train_segmentation_model reads.
        self.n_channels = in_channels
        self.n_classes = out_channels

        self.conv1 = self.contract_block(in_channels, 32, 7, 3)
        self.conv2 = self.contract_block(32, 64, 3, 1)
        self.conv3 = self.contract_block(64, 128, 3, 1)

        self.upconv3 = self.expand_block(128, 64, 3, 1)
        # The inputs of the next two stages are concatenated with the skip
        # connection of the matching contracting stage, hence the "* 2".
        self.upconv2 = self.expand_block(64 * 2, 32, 3, 1)
        self.upconv1 = self.expand_block(32 * 2, out_channels, 3, 1)

    def forward(self, x):
        """Return the output map for the image batch ``x``.

        Implemented as ``forward`` (not ``__call__``) so that
        ``nn.Module.__call__`` keeps running registered hooks.
        """
        # downsampling part
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        # upsampling part with skip connections along the channel axis
        upconv3 = self.upconv3(conv3)
        upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))
        upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))

        return upconv1

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Two Conv-BatchNorm-ReLU layers followed by a stride-2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Two Conv-BatchNorm-ReLU layers followed by a stride-2 transposed conv."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
        )

In [None]:
# Put the model on the GPU when one is available; fall back to the CPU so the
# cell does not fail on a runtime without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
unet_model = UNET(4, 2).to(device)

In [None]:
# torchinfo summary for an input of shape (8, 4, 384, 384).
summary(unet_model, input_size=(8, 4, 384, 384))

In [None]:
# One forward pass through the U-Net as a shape check.
xb, yb = next(iter(train_dl))

# The loader yields CPU tensors while the model may live on the GPU; move the
# batch to wherever the model's parameters are to avoid a device mismatch.
model_device = next(unet_model.parameters()).device
with torch.no_grad():  # shape check only, no gradients needed
    pred = unet_model(xb.to(model_device))

xb.shape, yb.shape, pred.shape

In [None]:
# Cross-entropy between the model's per-class output map and the integer masks.
loss_fn = nn.CrossEntropyLoss()

# Adam over all U-Net parameters.
optimizer = torch.optim.Adam(params=unet_model.parameters(), lr=1e-3)

In [None]:
def train(model, train_dl, valid_dl, loss_fn, optimizer, acc_fn, epochs=12):
    """Train ``model`` and evaluate it on ``valid_dl`` after every epoch.

    The model is moved to the GPU when CUDA is available, otherwise it stays
    on the CPU; every batch is moved to the same device.

    Args:
        model: network to optimise.
        train_dl: loader yielding ``(x, y)`` training batches.
        valid_dl: loader yielding ``(x, y)`` validation batches.
        loss_fn: callable ``loss_fn(outputs, y)`` returning a scalar tensor.
        optimizer: optimizer already bound to ``model``'s parameters.
        acc_fn: callable ``acc_fn(outputs, y)`` returning a scalar score.
        epochs: number of passes over the training data.

    Returns:
        ``(train_loss, valid_loss)``: two lists with one Python float per
        epoch, the sample-weighted mean loss of the respective phase.
    """
    start = time.time()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    train_loss, valid_loss = [], []

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-')

        for phase in ['train', 'valid']:
            is_train = phase == 'train'
            model.train(is_train)  # training mode on for 'train', off for 'valid'
            dataloader = train_dl if is_train else valid_dl

            running_loss = 0.0
            running_acc = 0.0
            seen = 0  # number of samples processed in this phase

            # iterate over data
            for step, (x, y) in enumerate(dataloader, start=1):
                x = x.to(device)
                y = y.to(device)
                # The last batch can be smaller than dataloader.batch_size, so
                # weight the running sums by the real number of samples.
                batch_len = x.size(0)

                if is_train:
                    optimizer.zero_grad()
                    outputs = model(x)
                    loss = loss_fn(outputs, y)
                    loss.backward()
                    optimizer.step()
                    # scheduler.step()
                else:
                    with torch.no_grad():
                        outputs = model(x)
                        loss = loss_fn(outputs, y.long())

                # stats - whatever is the phase
                with torch.no_grad():
                    acc = acc_fn(outputs, y)

                # Convert to plain floats before accumulating: summing the
                # tensors themselves would keep references to every step's
                # autograd graph and leave device tensors in the history.
                loss_value = loss.item()
                acc_value = float(acc)
                running_loss += loss_value * batch_len
                running_acc += acc_value * batch_len
                seen += batch_len

                if step % 100 == 0:
                    # clear_output(wait=True)
                    alloc_mb = torch.cuda.memory_allocated() / 1024 / 1024 if device.type == 'cuda' else 0.0
                    print('Current step: {}  Loss: {}  Acc: {}  AllocMem (Mb): {}'.format(step, loss_value, acc_value, alloc_mb))
                    # print(torch.cuda.memory_summary())

            # max(..., 1) guards against an empty loader
            epoch_loss = running_loss / max(seen, 1)
            epoch_acc = running_acc / max(seen, 1)

            clear_output(wait=True)
            print('Epoch {}/{}'.format(epoch, epochs - 1))
            print('-' * 10)
            print('{} Loss: {:.4f} Acc: {}'.format(phase, epoch_loss, epoch_acc))
            print('-' * 10)

            if is_train:
                train_loss.append(epoch_loss)
            else:
                valid_loss.append(epoch_loss)

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return train_loss, valid_loss

def acc_metric(predb, yb):
    """Fraction of positions where the arg-max class equals the target.

    Args:
        predb: per-class scores with the class axis at dim 1.
        yb: integer class labels; moved to ``predb``'s device before comparing,
            so the metric works on CPU and GPU alike.

    Returns:
        Scalar tensor with the mean of the element-wise matches.
    """
    return (predb.argmax(dim=1) == yb.to(predb.device)).float().mean()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.cuda.amp as amp
from torch.utils.data import DataLoader
from tqdm import tqdm

# Hyper-parameters for the optimizer below.
# NOTE(review): the notebook never defined `model`, `learning_rate` or
# `weight_decay`, so this cell raised NameError. The values here reproduce the
# Adam configuration created earlier for `unet_model` (lr=0.001, no weight
# decay) - confirm they are the intended ones.
learning_rate = 1e-3
weight_decay = 0.0
# `amp` is the imported torch.cuda.amp module, not a flag; use an explicit
# boolean to decide whether gradient scaling is enabled.
use_amp = torch.cuda.is_available()

optimizer = optim.Adam(unet_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# 'max' mode: the learning rate is reduced when the monitored score stops increasing.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)
grad_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
criterion = nn.CrossEntropyLoss()
global_step = 0

def train_segmentation_model(model, train_loader, val_loader, epochs, optimizer, criterion,
                             scheduler, gradient_clipping, device, amp=False):
    """Train a segmentation model with optional mixed precision.

    Relies on two helpers that must be defined before this function is
    called: ``dice_loss(pred, target, multiclass=...)`` and
    ``evaluate(model, loader, device, amp)``.

    Args:
        model: network to train. If it has ``n_channels`` the input channel
            count is checked against it; if it has ``n_classes`` that value
            selects the loss branch, otherwise the channel count of the
            model output is used.
        train_loader: loader yielding either ``{'image': ..., 'mask': ...}``
            dicts or ``(image, mask)`` pairs.
        val_loader: loader passed to ``evaluate``.
        epochs: number of passes over ``train_loader``.
        optimizer: optimizer bound to ``model``'s parameters.
        criterion: base loss; the dice loss is added to it.
        scheduler: stepped with the value returned by ``evaluate``.
        gradient_clipping: maximum gradient norm.
        device: ``torch.device`` to run on.
        amp: enable autocast and gradient scaling.

    Returns:
        None. ``model`` is updated in place.
    """
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    global_step = 0

    # Loop invariants, computed once instead of on every step.
    batch_size = train_loader.batch_size or 1
    # Number of optimisation steps between two evaluation rounds (about 5 per epoch).
    division_step = len(train_loader.dataset) // (5 * batch_size)
    autocast_device = device.type if device.type != 'mps' else 'cpu'
    expected_channels = getattr(model, 'n_channels', None)

    for epoch in range(1, epochs + 1):
        model.train()
        with tqdm(total=len(train_loader.dataset), desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                # Accept dict batches as well as plain (image, mask) pairs.
                if isinstance(batch, dict):
                    images, true_masks = batch['image'], batch['mask']
                else:
                    images, true_masks = batch

                if expected_channels is not None:
                    assert images.shape[1] == expected_channels, \
                        f'Network has been defined with {expected_channels} input channels, ' \
                        f'but loaded images have {images.shape[1]} channels. Please check that ' \
                        'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(autocast_device, enabled=amp):
                    masks_pred = model(images)
                    n_classes = getattr(model, 'n_classes', None) or masks_pred.shape[1]
                    if n_classes == 1:
                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
                        loss = loss + dice_loss(torch.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                    else:
                        loss = criterion(masks_pred, true_masks)
                        loss = loss + dice_loss(
                            F.softmax(masks_pred, dim=1).float(),
                            F.one_hot(true_masks, n_classes).permute(0, 3, 1, 2).float(),
                            multiclass=True
                        )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                # Gradients are still multiplied by the loss scale at this
                # point; unscale them first so the clipping threshold applies
                # to their true magnitude.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1

                # Evaluation round (if needed)
                if division_step > 0 and global_step % division_step == 0:
                    val_score = evaluate(model, val_loader, device, amp)
                    scheduler.step(val_score)
                    # evaluate() is defined elsewhere and may leave the model
                    # in eval mode; make sure training continues in train mode.
                    model.train()

# Define other necessary functions (evaluate, dice_loss, etc.) before calling this function.

# Usage example:
# train_segmentation_model(model, train_loader, val_loader, epochs, optimizer, criterion,
#                          scheduler, gradient_clipping, device, amp=False)


In [None]:
# Train the U-Net for 12 epochs; the metric function is `acc_metric`
# (the previous spelling `acc_matric` was undefined and raised NameError).
train_loss, valid_loss = train(unet_model, train_dl, valid_dl, loss_fn, optimizer, acc_metric, epochs=12)

# Save the model

In [None]:
# `UNet` was never defined; the trained network is `unet_model`.
# Save its state_dict (the learned parameters and buffers) rather than the
# pickled object. To restore:
#   model = UNET(4, 2); model.load_state_dict(torch.load('/kaggle/working/trained_model'))
torch.save(unet_model.state_dict(), '/kaggle/working/trained_model')