In [None]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms.functional as ft
from torch.utils.data import Dataset
!pip install -U albumentations
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader

In [None]:
# Shell out to nvidia-smi to show which GPU(s) this session can see.
!nvidia-smi

# Create Model

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> ReLU stages.

    Both convolutions use a 3x3 kernel with stride 1 and padding 1, so the
    spatial size of the input is preserved; only the channel count changes
    (``in_channels`` -> ``out_channels`` -> ``out_channels``).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        for stage_in in (in_channels, out_channels):
            # The convolution bias is disabled; the BatchNorm2d that follows
            # carries its own learnable shift.
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))

        # Kept as a single Sequential named ``conv`` so parameter names
        # (conv.0.weight, conv.1.weight, ...) stay stable for checkpoints.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the result."""
        return self.conv(x)


class UNET(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    The encoder applies one ``DoubleConv`` per entry of ``features``, each
    followed by 2x2 max pooling. A bottleneck ``DoubleConv`` doubles the last
    feature width. The decoder then, for every encoder level in reverse,
    upsamples with a stride-2 transposed convolution, concatenates the
    matching encoder output along the channel axis and applies a
    ``DoubleConv``. A final 1x1 convolution maps to ``out_channels``; no
    activation is applied to the output.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the returned map.
        features: Channel widths of the encoder levels, shallowest first.
    """

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per requested width.
        channels = in_channels
        for width in features:
            self.downs.append(DoubleConv(channels, width))
            channels = width

        # Decoder: ``ups`` alternates [upsample, DoubleConv, upsample, ...].
        # The DoubleConv takes ``width * 2`` channels because its input is the
        # upsampled map concatenated with the skip connection.
        for width in reversed(features):
            self.ups.extend(
                [
                    nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2),
                    DoubleConv(width * 2, width),
                ]
            )

        self.bottom = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder on ``x``."""
        encoder_outputs = []
        for encoder in self.downs:
            x = encoder(x)
            encoder_outputs.append(x)
            x = self.pool(x)

        x = self.bottom(x)

        # Walk the skip connections deepest-first, pairing each with its
        # (upsample, DoubleConv) couple stored at positions 2*level, 2*level+1.
        for level, skip in enumerate(reversed(encoder_outputs)):
            x = self.ups[2 * level](x)

            # Pooling floors odd sizes, so the upsampled map can come out
            # smaller than the skip connection; resize it to match.
            if x.shape != skip.shape:
                x = ft.resize(x, size=(skip.shape[2:]))

            x = self.ups[2 * level + 1](torch.concat((skip, x), dim=1))

        return self.final_conv(x)

# Creating Dataset

In [None]:
def random_splits(prob, image_list, seed=None):
    """Randomly partition ``image_list`` into two disjoint lists.

    Args:
        prob: Fraction of the items to place in the first list. Exactly
            ``int(prob * len(image_list))`` items are sampled.
        image_list: Sequence of items to split (file names in this notebook).
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used so the split is reproducible without touching the global
            RNG. When ``None``, the module-level ``random`` generator is used.

    Returns:
        A tuple ``(selected, remaining)``. ``selected`` is in sampled order;
        ``remaining`` holds every item not equal to a selected one, in its
        original order.

    Raises:
        ValueError: Propagated from ``random.sample`` when the computed sample
            size is negative or larger than ``len(image_list)``.
    """
    rng = random if seed is None else random.Random(seed)
    selected = rng.sample(image_list, int(prob * len(image_list)))

    # A set makes the membership test O(1) instead of scanning the sampled
    # list for every item. Fall back to the list for unhashable items.
    try:
        chosen = set(selected)
    except TypeError:
        chosen = selected

    remaining = [item for item in image_list if item not in chosen]
    return selected, remaining

In [None]:
class TreeSegmentDataset(Dataset):
    """Image/mask pairs where each mask is named like its image with
    ``.jpg`` replaced by ``.png``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        images: File names (relative to ``image_dir``) belonging to this split.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and expected to return a mapping with ``"image"`` and ``"mask"`` keys.
    """

    def __init__(self, image_dir, mask_dir, images, transform=None):
        self.image_dir, self.mask_dir = image_dir, mask_dir
        self.images = images
        self.transform = transform

    def __len__(self):
        """Number of images in this split."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the ``index``-th image as RGB and its mask as float32 in [0, 1]."""
        file_name = self.images[index]
        rgb_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name.replace('.jpg', '.png'))

        image = np.array(Image.open(rgb_path).convert("RGB"))
        # 8-bit grayscale mask scaled from 0..255 down to 0..1.
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32) / 255

        if not self.transform:
            return image, mask

        augmented = self.transform(image=image, mask=mask)
        return augmented["image"], augmented["mask"]

In [None]:
class BeachDataset(Dataset):
    """Image/mask pairs where each mask file name is the image file name with
    every occurrence of ``"image"`` replaced by ``"mask"``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        images: File names (relative to ``image_dir``) belonging to this split.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and expected to return a mapping with ``"image"`` and ``"mask"`` keys.
            Stored on the instance as ``self.transforms``.
    """

    def __init__(self, image_dir, mask_dir, images, transform=None):
        self.image_dir, self.mask_dir = image_dir, mask_dir
        self.images = images
        self.transforms = transform

    def __len__(self):
        """Number of images in this split."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the ``index``-th image as RGB and its mask as float32 in [0, 1]."""
        file_name = self.images[index]
        rgb_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name.replace("image", "mask"))

        image = np.array(Image.open(rgb_path).convert("RGB"))
        # 8-bit grayscale mask scaled from 0..255 down to 0..1.
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32) / 255

        if not self.transforms:
            return image, mask

        augmented = self.transforms(image=image, mask=mask)
        return augmented["image"], augmented["mask"]

# Hyperparameters

In [None]:
# Reproducibility: seed Python, NumPy and PyTorch before anything random runs,
# so the train/val split drawn at the bottom of this cell is the same after
# every kernel restart.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Optimisation
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
NUM_EPOCHS = 50

# Runtime
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# DataLoader settings
NUM_WORKERS = 4
PIN_MEMORY = True

# Every image/mask pair is resized to this resolution by the transforms.
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512

# When True, the setup cell loads weights from a checkpoint before training.
LOAD_MODEL = True

# Dataset locations (Kaggle input mounts)
TREE_IMAGE_DIR= "/kaggle/input/tree-binary-segmentation/images/images"
TREE_MASK_DIR = "/kaggle/input/tree-binary-segmentation/masks/masks"
COAST_IMAGE_DIR = "/kaggle/input/beach-segmentation/seg_data_new/images"
COAST_MASK_DIR = "/kaggle/input/beach-segmentation/seg_data_new/masks"

# PRE_TRAIN_IMAGES, PRE_VAL_IMAGES = random_splits(0.7, os.listdir(TREE_IMAGE_DIR))

# os.listdir returns entries in arbitrary order, so sort them to make the
# seeded split deterministic. Hidden entries such as ".DS_Store" are skipped
# because they are not images.
COAST_IMAGES = sorted(
    name for name in os.listdir(COAST_IMAGE_DIR) if not name.startswith(".")
)
# 70 % of the images go to training, the rest to validation.
TRAIN_IMAGES, VAL_IMAGES = random_splits(0.7, COAST_IMAGES)

# Utility Functions and Transforms

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Train ``model`` for one pass over ``loader`` using CUDA mixed precision.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches.
        model: Network to optimise; it is switched to training mode here.
        optimizer: Optimizer holding ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            loss tensor.
        scaler: Gradient scaler (``torch.amp.GradScaler``) used to scale the
            loss and step the optimizer.

    Returns:
        Mean of the per-batch loss values over the epoch, or ``nan`` when the
        loader produced no batches.
    """
    # Do not rely on the caller having left the model in training mode.
    model.train()

    progress = tqdm(loader)
    total_loss = 0.0
    num_batches = 0

    for data, targets in progress:
        data = data.to(device=DEVICE)
        # Insert a channel axis at dim 1 so the targets line up with the model
        # output (assumes masks are batched as (N, H, W) - TODO confirm).
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # Clear gradients *before* the backward pass so anything accumulated
        # before this call cannot leak into the first optimizer step.
        optimizer.zero_grad()

        # Forward
        with torch.amp.autocast('cuda'):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward: scale the loss, step through the scaler, then let it
        # adjust its scale factor for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        # Updating tqdm loop
        progress.set_postfix(loss=batch_loss)

    return total_loss / num_batches if num_batches else float("nan")

In [None]:
def _resize_step():
    """Fresh Resize transform targeting IMAGE_HEIGHT x IMAGE_WIDTH."""
    return A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)


def _unit_scale_step():
    """Fresh Normalize transform.

    With zero mean, unit std and max_pixel_value=255 this amounts to dividing
    the pixel values by 255.
    """
    return A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
        max_pixel_value=255.0,
    )


# Training pipeline: resize, always rotate by up to +/-90 degrees, flip on
# each axis with probability 0.8, rescale pixels, then convert to tensors.
train_transforms = A.Compose(
    [
        _resize_step(),
        A.Rotate(limit=90, p=1.0),
        A.HorizontalFlip(p=0.8),
        A.VerticalFlip(p=0.8),
        _unit_scale_step(),
        ToTensorV2(),
    ],
)

# Validation pipeline: same resize/rescale/tensor steps, no random augmentation.
val_transforms = A.Compose(
    [
        _resize_step(),
        _unit_scale_step(),
        ToTensorV2(),
    ]
)

In [None]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then write ``state`` to ``filename``.

    Args:
        state: Object to serialise with ``torch.save`` (the training loop
            passes a dict of model and optimizer state dicts).
        filename: Destination path of the checkpoint file.
    """
    print("=> Saving Checkpoint")
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model: torch.nn.Module):
    """Announce the load on stdout, then copy checkpoint weights into ``model``.

    Args:
        checkpoint: Mapping with a ``"state_dict"`` entry holding the weights.
        model: Module that receives the weights via ``load_state_dict``.

    Raises:
        KeyError: If ``checkpoint`` has no ``"state_dict"`` entry.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)


def get_loaders(
        img_dir,
        mask_dir,
        train_images,
        val_images,
        batch_size,
        train_transform,
        val_transform,
        num_workers,
        pin_memory=True,
        beach_dataset:bool = False
):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        img_dir: Directory holding the input images of both splits.
        mask_dir: Directory holding the masks of both splits.
        train_images: File names that make up the training split.
        val_images: File names that make up the validation split.
        batch_size: Batch size used by both loaders.
        train_transform: Transform applied to training samples.
        val_transform: Transform applied to validation samples.
        num_workers: Worker process count for both loaders.
        pin_memory: Forwarded to both loaders.
        beach_dataset: Use ``BeachDataset`` when truthy, otherwise
            ``TreeSegmentDataset``.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    # Both dataset classes take the same constructor arguments, so only the
    # class itself depends on the flag.
    dataset_cls = BeachDataset if beach_dataset else TreeSegmentDataset

    def make_dataset(names, transform):
        return dataset_cls(
            image_dir=img_dir,
            mask_dir=mask_dir,
            images=names,
            transform=transform,
        )

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
        )

    train_ds = make_dataset(train_images, train_transform)
    val_ds = make_dataset(val_images, val_transform)

    train_loader = make_loader(train_ds, True)
    val_loader = make_loader(val_ds, False)
    return train_loader, val_loader


def check_accuracy(loader, model, device, acc_list, dice_list):
    """Evaluate pixel accuracy and Dice score of ``model`` on ``loader``.

    Predictions are ``sigmoid(model(x)) > 0.5``. Targets are thresholded at
    0.5 as well, so a mask value that is not exactly 0 or 1 is still compared
    on a class basis rather than by float equality.

    Args:
        loader: Iterable of ``(x, y)`` batches; ``y`` gets a channel axis
            inserted at dim 1 before comparison.
        model: Network to evaluate. It is put in eval mode for the pass and
            switched back to training mode before returning.
        device: Device the batches are moved to.
        acc_list: List that receives the accuracy in percent, rounded to four
            decimals.
        dice_list: List that receives the Dice score averaged over batches.

    Returns:
        None. Results are printed and appended to ``acc_list`` / ``dice_list``.
        When the loader yields no batches, ``nan`` is appended to both lists
        instead of raising.
    """
    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    num_batches = 0
    model.eval()

    with torch.inference_mode():
        for x, y in loader:
            x, y = x.to(device), y.to(device).unsqueeze(1)
            # Binarise the ground truth (assumes masks are scaled to [0, 1]):
            # an in-between value could never equal a hard 0/1 prediction.
            y = (y > 0.5).float()
            preds = (torch.sigmoid(model(x)) > 0.5).float()

            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            # Epsilon keeps the ratio finite when both masks are empty.
            dice_total += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
            num_batches += 1

    # Hand the model back ready for training, as callers expect.
    model.train()

    if num_batches == 0:
        print("check_accuracy: loader yielded no batches; recording NaN metrics")
        acc_list.append(float("nan"))
        dice_list.append(float("nan"))
        return

    # Convert to plain Python numbers so the printout shows values, not
    # tensor reprs.
    correct = int(num_correct)
    acc = 100.0 * correct / num_pixels
    dice = float(dice_total) / num_batches

    print(f"Got {correct}/{num_pixels} with acc {acc:.4f}")
    print(f"Dice score: {dice}")

    acc_list.append(round(acc, 4))
    dice_list.append(dice)

In [None]:
# Positive-class weight handed to BCEWithLogitsLoss in the setup cell below.
# NOTE(review): 59.6704 is hard-coded; presumably the background-to-foreground
# pixel ratio of the training masks - TODO confirm how it was derived.
pos_weight = torch.tensor(59.6704)

In [None]:
# Show the weight as a plain Python number in the cell output.
pos_weight.item()

# Setting up the model, dataloaders, loss function, optimizer, and learning-rate scheduler

In [None]:
model = UNET(in_channels=3, out_channels=1)
# model = nn.DataParallel(model)
model.to(device=DEVICE)

# pos_weight is already a tensor, so pass it straight through (wrapping it in
# torch.tensor() again triggers a copy-construct warning). Moving the loss
# module puts its pos_weight buffer on the same device as the predictions.
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.1 after epochs 15 and 30.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [15, 30], gamma=0.1)
train_loader, val_loader = get_loaders(
    COAST_IMAGE_DIR,
    COAST_MASK_DIR,
    TRAIN_IMAGES,
    VAL_IMAGES,
    BATCH_SIZE,
    train_transforms,
    val_transforms,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    beach_dataset=True,
)

if LOAD_MODEL:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    load_checkpoint(
        torch.load(
            "/kaggle/input/tree-segment-checkpoint/treesegmentpre-train.tar",
            map_location=DEVICE,
        ),
        model,
    )
scaler = torch.amp.GradScaler('cuda')

# Training the Model

In [None]:
acc_list, dice_list = [], []

# Baseline: score the model on the validation set before any training, so
# index 0 of both lists is the "epoch 0" result.
check_accuracy(val_loader, model, device=DEVICE, acc_list=acc_list, dice_list=dice_list)

for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)
    scheduler.step()

    # Save once, after the final epoch. The condition follows NUM_EPOCHS so a
    # shorter or longer run still writes its checkpoint at the end.
    if epoch == NUM_EPOCHS - 1:
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        save_checkpoint(checkpoint)

    check_accuracy(val_loader, model, device=DEVICE, acc_list=acc_list, dice_list=dice_list)

In [None]:
# Plot the Dice score recorded by each validation pass and save it to disk.
# Index 0 is the evaluation run before training; index k follows epoch k.
fig, ax = plt.subplots(figsize=(10, 10), dpi=200)
evaluation_index = range(len(dice_list))
ax.plot(evaluation_index, dice_list, c='black')
ax.set(xlabel="Epochs", ylabel="Dice Score")
fig.savefig('dice.png')