<a href="https://colab.research.google.com/github/Nat-D/varuna-hackathon/blob/main/UNET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# UNET Model

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Stride 1 with padding 1 preserves the spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        # First stage consumes in_channels, second stage consumes out_channels.
        for source_width in (in_channels, out_channels):
            stages.extend([
                # Convolution bias is disabled; BatchNorm follows immediately.
                nn.Conv2d(source_width, out_channels,
                          kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        activated = self.conv(x)
        return activated



class UNET(nn.Module):
    """U-Net style encoder/decoder for dense per-pixel prediction.

    The encoder applies one ``DoubleConv`` per entry in ``features``, each
    followed by 2x2 max pooling. The decoder mirrors it with a transposed
    convolution plus a ``DoubleConv`` per level, concatenating the matching
    encoder output (skip connection) on the channel axis. A final 1x1
    convolution maps to ``out_channels``; no activation is applied to it.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        features: Channel widths of the encoder levels, shallowest first.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels=12, out_channels=1, features=(64, 128, 256, 512)):
        super(UNET, self).__init__()

        # An immutable default avoids sharing one list object across instances;
        # copying to a list keeps any caller-supplied sequence untouched.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.out_channels = out_channels

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part: entries alternate [ConvTranspose2d, DoubleConv, ...].
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # Input is feature*2 because the skip connection is concatenated.
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def standardize_per_channel(self, x):
        """Standardize every (sample, channel) plane to zero mean / unit std.

        Args:
            x: Tensor whose first two axes are sample and channel; the
                remaining axes are flattened for the statistics.

        Returns:
            Tensor of the same shape as ``x``.
        """
        N = x.shape[0]
        C = x.shape[1]
        # reshape (not view) so non-contiguous inputs are accepted too.
        x_flat = x.reshape(N, C, -1)
        x_mean = torch.mean(x_flat, dim=2).view(N, C, 1, 1)
        # The epsilon keeps the division finite for constant planes.
        x_std = 1e-5 + torch.std(x_flat, dim=2).view(N, C, 1, 1)

        return (x - x_mean) / x_std

    def standardize_per_sample(self, x):
        """Standardize each sample to zero mean / unit std over all its values.

        Args:
            x: 4-D tensor whose first axis is the sample axis.

        Returns:
            Tensor of the same shape as ``x``.
        """
        N = x.shape[0]
        # reshape (not view) so non-contiguous inputs are accepted too.
        x_flat = x.reshape(N, -1)
        x_mean = torch.mean(x_flat, dim=1).view(N, 1, 1, 1)
        x_std = 1e-5 + torch.std(x_flat, dim=1).view(N, 1, 1, 1)

        return (x - x_mean) / x_std

    def forward(self, x):
        """Run the network on a batch laid out as [b, c, h, w].

        The input is standardized per sample before entering the encoder.

        Returns:
            The output of the final 1x1 convolution (``out_channels`` channels).
        """
        # standardization
        x = self.standardize_per_sample(x)

        # Encoder: remember each level's output before pooling.
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        # Decoder consumes the skips deepest-first.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # the ConvTranspose2d
            skip_connection = skip_connections[idx // 2]

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip tensor; resize it to match before concatenating.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)  # [b,c,h,w]
            x = self.ups[idx + 1](concat_skip)  # the DoubleConv

        return self.final_conv(x)




def test():
    """Smoke test: a 1-channel UNET must return a tensor shaped like its input."""
    sample_batch = torch.randn((3, 1, 160, 160))
    net = UNET(in_channels=1, out_channels=1)
    output = net(sample_batch)

    assert output.shape == sample_batch.shape

if __name__ == "__main__":
    test()

# Dataset

In [None]:
import os 
from PIL import Image
from torch.utils.data import Dataset
import numpy as np


class Sentinel2(Dataset):
    """Dataset of ``.npy`` image arrays paired with ``<name>_label.npy`` masks.

    Args:
        image_dir: Directory containing the image ``.npy`` files.
        mask_dir: Directory containing the masks; the mask for ``foo.npy`` is
            looked up as ``foo_label.npy``.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
    """

    _IMAGE_SUFFIX = ".npy"
    _MASK_SUFFIX = "_label.npy"

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary order, so sort to keep the
        # index -> file mapping reproducible. Entries that are not .npy files
        # (e.g. notebook checkpoint folders) are skipped because they could
        # never be loaded as samples.
        self.images = sorted(
            name for name in os.listdir(image_dir)
            if name.endswith(self._IMAGE_SUFFIX)
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load and return the ``(image, mask)`` pair at ``index``.

        The pair is passed through ``transform`` when one was supplied.
        """
        image_name = self.images[index]
        # Swap only the trailing extension; str.replace would also rewrite
        # any ".npy" that happens to appear earlier in the file name.
        mask_name = image_name[:-len(self._IMAGE_SUFFIX)] + self._MASK_SUFFIX

        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, mask_name)

        # NOTE(review): float16 cannot represent integers above 2048 exactly
        # and overflows past 65504; confirm this lossy cast is intended.
        image = np.load(img_path).astype(np.float16)
        mask = np.load(mask_path)

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

def test_sentinel_data():
    """Smoke test: build the training dataset and load its first sample."""
    dataset = Sentinel2(
        "/content/drive/MyDrive/data/train/img/",
        "/content/drive/MyDrive/data/train/mask/",
    )
    # Indexing triggers the actual file reads for sample 0.
    first_image, first_label = dataset[0]

if __name__ == "__main__":
    test_sentinel_data()

# Utils

In [None]:
import torch 
import torch.nn.functional as F
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then serialize ``state`` to ``filename``."""
    announcement = "=> Saving checkpoint"
    print(announcement)
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model):
    """Announce the load on stdout, then copy ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    saved_weights = checkpoint["state_dict"]
    model.load_state_dict(saved_weights)


def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
    ):
    """Build the training and validation ``DataLoader`` objects.

    Both loaders wrap a ``Sentinel2`` dataset and share ``batch_size``,
    ``num_workers`` and ``pin_memory``; only the training loader shuffles.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    # Options common to both loaders.
    shared_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    training_set = Sentinel2(
        image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform
    )
    train_loader = DataLoader(training_set, shuffle=True, **shared_options)

    validation_set = Sentinel2(
        image_dir=val_dir, mask_dir=val_maskdir, transform=val_transform
    )
    val_loader = DataLoader(validation_set, shuffle=False, **shared_options)

    return train_loader, val_loader

# RGB lookup table with one row per class id (12 rows). Logger uses it to turn
# integer class maps into colour images by indexing it with the predicted and
# ground-truth masks, then dividing by 255 before writing to TensorBoard.
color_group = np.array([[0,0,0],  # class 0: no data
                       [244,35,232], 
                       [250,170,160],
                       [230,150,140],
                       [0, 255, 0],   # class 4: vegetation
                       [102,102,156],
                       [190,153,153],
                       [180,165,180],
                       [150,100,100],
                       [150,120, 90],
                       [153,153,153],
                       [1  , 1  ,1 ]
                       ])

class Logger():
    """TensorBoard helper that tracks training loss and runs validation.

    Attributes set in ``__init__``:
        writer: ``SummaryWriter`` writing to ``log_dir``.
        device: Device that validation batches are moved to.
        accumulate_training_loss: Sum of the losses passed to ``log_step``
            since the last ``log_epoch``.
        training_step: Total number of ``log_step`` calls; used as the global
            step for every scalar and image written.
        epoch_num_step: Number of ``log_step`` calls since the last
            ``log_epoch``.
        loss_fn: ``nn.CrossEntropyLoss`` used for the validation loss.
    """

    def __init__(self, device="cuda", log_dir='runs'):
        self.writer = SummaryWriter(log_dir)

        self.device = device
        self.accumulate_training_loss = 0.0
        self.training_step = 0
        self.epoch_num_step = 0
        self.loss_fn = nn.CrossEntropyLoss()

    def compute_precision(self, true_pos, false_pos, false_neg):
        """Return TP / (TP + FP); ``false_neg`` is accepted but unused."""
        return true_pos / (true_pos + false_pos + 1e-5)

    def compute_recall(self, true_pos, false_pos, false_neg):
        """Return TP / (TP + FN); ``false_pos`` is accepted but unused."""
        return true_pos / (false_neg + true_pos + 1e-5)

    def compute_iou(self, true_pos, false_pos, false_neg):
        """Return TP / (TP + FP + FN); the epsilon avoids division by zero."""
        return true_pos / (true_pos + false_pos + false_neg + 1e-5)

    @staticmethod
    def compute_weighted_iou():
        """Placeholder that returns ``None``.

        Declared static so it can be called on an instance as well as on the
        class (it takes no arguments).
        """
        return

    def validation(self, loader, model):
        """Evaluate ``model`` on ``loader`` and log loss and per-class IoU.

        The model is switched to eval mode for the pass and back to train mode
        afterwards. ``model.out_channels`` is used as the number of classes.

        Returns:
            The validation loss averaged over batches.
        """
        model.eval()

        num_classes = model.out_channels
        total_loss = 0
        num_step = 0
        num_samples = 0
        total_iou_for_each_class = 0

        with torch.no_grad():
            for x, y in loader:
                x = x.to(self.device)
                y = y.long().to(self.device)

                preds = model(x)
                total_loss += self.loss_fn(preds, y)
                num_step += 1
                # Count real samples: the last batch may be smaller than
                # loader.batch_size, so num_step * batch_size over-counts.
                num_samples += x.shape[0]

                pred_labels = torch.argmax(preds, dim=1, keepdim=False)  # [B, H, W]
                pred_hot = F.one_hot(pred_labels, num_classes=num_classes).bool()  # [B, H, W, C]
                true_hot = F.one_hot(y, num_classes=num_classes).bool()  # [B, H, W, C]

                # Per-sample, per-class pixel counts -> [B, C]
                true_pos = torch.sum(pred_hot & true_hot, dim=(1, 2))
                false_pos = torch.sum(pred_hot & ~true_hot, dim=(1, 2))
                false_neg = torch.sum(~pred_hot & true_hot, dim=(1, 2))

                iou = self.compute_iou(true_pos, false_pos, false_neg)
                total_iou_for_each_class += torch.sum(iou, dim=0)

        average_loss = total_loss / num_step

        self.writer.add_scalar("Loss/Average_validation_loss",
                               average_loss,
                               self.training_step)

        for cls in range(num_classes):
            self.writer.add_scalar(f"IoU/Average_iou_class_{cls}",
                                   total_iou_for_each_class[cls] / num_samples,
                                   self.training_step)

        model.train()
        return average_loss

    def save_predictions_as_img(self, loader, model):
        """Write colourised predictions and targets for the first batch.

        Class maps are converted to RGB through the module-level
        ``color_group`` table and logged as HWC images.
        """
        model.eval()

        x, y = next(iter(loader))
        x = x.to(device=self.device)
        with torch.no_grad():
            preds = torch.argmax(model(x), dim=1, keepdim=False)

        preds_np = preds.cpu().numpy()
        # Index the NumPy palette with an integer NumPy array rather than a
        # torch tensor, so the lookup does not depend on the mask's dtype.
        targets_np = y.long().cpu().numpy()

        for idx, (pred_mask, target_mask) in enumerate(zip(preds_np, targets_np)):
            rgb_mask = color_group[pred_mask]
            self.writer.add_image(f'predict/{idx}', rgb_mask/255., self.training_step, dataformats="HWC")

            rgb_mask_groundtruth = color_group[target_mask]
            self.writer.add_image(f'target/{idx}_mask', rgb_mask_groundtruth/255., self.training_step, dataformats="HWC")

        model.train()

    def log_step(self, loss):
        """Accumulate one training-step loss and advance the step counters."""
        self.accumulate_training_loss += loss
        self.epoch_num_step += 1
        self.training_step += 1

    def log_epoch(self, val_loader, model, optimizer):
        """Log the epoch's mean training loss, then validate.

        Args:
            val_loader: Loader used for validation and example images.
            model: Model to evaluate.
            optimizer: Unused; kept so existing callers keep working.

        Returns:
            The validation loss for this epoch.
        """
        self.writer.add_scalar('Loss/Average_training_loss',
                               self.accumulate_training_loss / self.epoch_num_step,
                               self.training_step)

        self.accumulate_training_loss = 0
        self.epoch_num_step = 0

        # Example images are only written when the global step is a multiple of 5.
        if self.training_step % 5 == 0:
            self.save_predictions_as_img(val_loader, model)

        # Always validate: the caller needs a loss every epoch, and returning
        # a value assigned only inside the branch above raised
        # UnboundLocalError whenever that branch was skipped.
        return self.validation(val_loader, model)

# Train

In [None]:
!pip install albumentations==0.4.6

In [None]:
import torch 
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim


# Experiment log: directory handed to Logger for the TensorBoard event files.
LOG_DIR = "/content/runs/standardize_per_sample_augmentation/"

# Hyper parameters
LEARNING_RATE = 1e-4  # learning rate given to the Adam optimizer in main()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when CUDA is unavailable
BATCH_SIZE = 4  # batch size for both data loaders
NUM_EPOCHS = 1000  # number of passes over the training loader in main()
NUM_WORKERS = 2  # DataLoader worker processes
IMAGE_HEIGHT = 256  # output height of the crop / resize transforms
IMAGE_WIDTH = 256  # output width of the crop / resize transforms
PIN_MEMORY = True  # forwarded to the DataLoaders
LOAD_MODEL= False  # when True, main() loads "my_checkpoint.pth.tar" before training


# Dataset locations: image and mask directories for the train / val splits.
TRAIN_IMG_DIR = "/content/drive/MyDrive/data/train/img/"
TRAIN_MASK_DIR = "/content/drive/MyDrive/data/train/mask/"
VAL_IMG_DIR = "/content/drive/MyDrive/data/val/img/"
VAL_MASK_DIR = "/content/drive/MyDrive/data/val/mask" 



def train_fn(epoch, loader, model, optimizer, loss_fn, scaler, logger):
    """Train ``model`` for one pass over ``loader`` with mixed precision.

    Every batch is moved to ``DEVICE``, run forward under CUDA autocast, and
    optimized through the gradient ``scaler``. The step loss is shown on the
    tqdm progress bar and handed to ``logger.log_step``.
    """
    # Wrap the loader so iterating it drives a progress bar.
    progress = tqdm(loader)

    for inputs, labels in progress:
        inputs = inputs.to(device=DEVICE)
        labels = labels.long().to(device=DEVICE)

        # Forward pass under autocast (float16 where supported).
        with torch.cuda.amp.autocast():
            logits = model(inputs)
            loss = loss_fn(logits, labels)

        # Backward pass and optimizer step, both routed through the scaler.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        step_loss = loss.item()

        # Refresh the progress bar, then record the step.
        progress.set_postfix(epoch=epoch, loss=step_loss)
        logger.log_step(step_loss)

def main():
    """Build the transforms, model, optimizer and loaders, then train.

    Each epoch runs ``train_fn`` followed by ``logger.log_epoch``; the loss
    returned by ``log_epoch`` drives the ``ReduceLROnPlateau`` scheduler.
    """
    train_transform = A.Compose([
            A.ToFloat(max_value=65535.0), # support uint16
            # A.RandomResizedCrop(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, p=1.0),
            # A.RandomCrop(height=300, width=300, p=0.5),
            # A.Rotate(limit= 90, p=0.5), 
            # A.RandomBrightnessContrast(p=0.5),
            # A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.RandomSizedCrop(min_max_height=(100, 400), height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Transpose(p=0.5),
            A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
            A.GridDistortion(p=0.5),
            A.OpticalDistortion(distort_limit=2, shift_limit=0.5, p=0.5),
            ToTensorV2()
        ])

    # Validation only rescales and resizes; no random augmentation.
    val_transforms = A.Compose([
            A.ToFloat(max_value=65535.0),
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            ToTensorV2()
        ])

    model = UNET(in_channels=12, out_channels=12).to(DEVICE)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # verbose takes a bool; the string "True" only worked because any
    # non-empty string is truthy ("False" would have enabled it as well).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=30, threshold=0.001, verbose=True)
    scaler = torch.cuda.amp.GradScaler()

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY
        )

    if LOAD_MODEL:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)

    logger = Logger(device=DEVICE, log_dir=LOG_DIR)

    for epoch in range(NUM_EPOCHS):
        train_fn(epoch, train_loader, model, optimizer, loss_fn, scaler, logger)
        val_loss = logger.log_epoch(val_loader, model, optimizer)
        # The scheduler lowers the learning rate when this loss plateaus.
        scheduler.step(val_loss)


if __name__ == "__main__":
    main()