<a href="https://colab.research.google.com/github/Adeeshdiwan/Coderguy/blob/master/Buildings_from_Satellite_Images.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
# Print the full path of every file under /kaggle/input, in os.walk order.
for path in (os.path.join(root, name)
             for root, _subdirs, names in os.walk('/kaggle/input')
             for name in names):
    print(path)


In [None]:
import numpy as np
import numba as nb
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import gdal
from tqdm import tqdm
from PIL import Image
import cv2
import random

import os
import gc

# Root of the contest dataset; later cells append sub-folders such as "/train_data" and "/test_image".
DATAPATH = "/kaggle/input/sysu-rs-contest-2021/"

In [None]:
class myDataset(data.Dataset):
    """Train/validation split of the building-segmentation training images.

    Images are collected from ``<data_path>/train/{rural_CN,urban_CN,urban_US}/image``.
    Each label path is the image path with every ``"image"`` substring replaced by
    ``"label"``. A fixed fraction ``val_rate`` of all samples, drawn with seed 0, forms
    the validation split. The remaining samples form the training split.

    Args:
        data_path: Directory that contains the ``train`` folder.
        set_name: ``"train"`` or ``"val"``.
        val_rate: Fraction of all samples assigned to the validation split.

    Raises:
        ValueError: If ``set_name`` is neither ``"train"`` nor ``"val"``.
    """

    UNIFIED_SIZE = 256  # side length of training patches (random crop / zero pad)

    def __init__(self, data_path, set_name, val_rate=0.1):
        if set_name not in ("train", "val"):
            raise ValueError("set_name can only be 'train' or 'val'")
        self.set_name = set_name

        self.imageList = []
        folderList = [(data_path + "/train/" + i) for i in ["rural_CN", "urban_CN", "urban_US"]]
        for fd in folderList:
            self.imageList.extend([i.path for i in os.scandir(fd + "/image")])
        # NOTE: replaces every "image" substring in the path, not only the folder name.
        self.labelList = [i.replace("image", "label") for i in self.imageList]

        total_num = len(self.imageList)
        val_num = int(total_num * val_rate)

        # A private RNG seeded with 0 reproduces the former np.random.seed(0) split
        # without resetting the global NumPy random state for the rest of the program.
        rng = np.random.RandomState(0)
        val_index = rng.choice(np.arange(total_num), size=val_num, replace=False).astype(np.int64)
        train_index = np.setdiff1d(np.arange(total_num, dtype=np.int64), val_index)

        idx_used = train_index if set_name == "train" else val_index
        self.imgList_used = np.array(self.imageList)[idx_used]
        self.lbList_used = np.array(self.labelList)[idx_used]

    def __len__(self):
        return len(self.imgList_used)

    def __getitem__(self, index):
        """Return ``(image, label)`` tensors; training samples are cropped/padded to UNIFIED_SIZE."""
        img_path, lb_path = self.imgList_used[index], self.lbList_used[index]

        # GeoTIFFs are read through GDAL (band-first array); other formats via PIL (HWC -> CHW).
        if img_path.split('.')[-1] == "tif":
            img_ds = gdal.Open(img_path)
            img = img_ds.ReadAsArray()
            del img_ds
        else:
            img = np.array(Image.open(img_path)).transpose(2, 0, 1)
        lb = np.array(Image.open(lb_path))

        # Scale to [0, 1]; an all-zero image would otherwise turn into NaNs (0 / 0).
        peak = img.max()
        img = img / peak if peak > 0 else img.astype(np.float64)

        if self.set_name == "train":
            img, lb = self._to_unified_size(img, lb)

        return torch.as_tensor(img, dtype=torch.float), torch.as_tensor(lb, dtype=torch.int64)

    def _to_unified_size(self, img, lb):
        """Zero-pad (bottom/right), then randomly crop ``img`` (C, H, W) and ``lb`` (H, W, ...) to UNIFIED_SIZE."""
        size = self.UNIFIED_SIZE

        # Pad each spatial axis independently so non-square tiles are handled too.
        pad_y = max(size - img.shape[1], 0)
        pad_x = max(size - img.shape[2], 0)
        if pad_y or pad_x:
            img = np.pad(img, ((0, 0), (0, pad_y), (0, pad_x)), mode="constant")
            lb = np.pad(lb, ((0, pad_y), (0, pad_x)) + ((0, 0),) * (lb.ndim - 2), mode="constant")

        height, width = img.shape[1], img.shape[2]
        if height > size or width > size:
            # Draw y before x, matching the previous sampling order.
            uly, ulx = random.randint(0, height - size), random.randint(0, width - size)
            img = img[:, uly:uly + size, ulx:ulx + size]
            lb = lb[uly:uly + size, ulx:ulx + size]
        return img, lb


class myDataset_test(data.Dataset):
    """Unlabelled test images from a flat directory, in sorted path order.

    Args:
        data_path: Directory that contains the test images.
    """

    def __init__(self, data_path):
        # Regular files only: sub-directories cannot be decoded as images.
        self.imgList = sorted(i.path for i in os.scandir(data_path)
                              if i.is_file() and i.name.split('.')[-1])
        self.imgNameList = [os.path.basename(i) for i in self.imgList]

    def __len__(self):
        return len(self.imgList)

    def __getitem__(self, index):
        """Return the image at ``index`` as a float tensor scaled to [0, 1]."""
        img_path = self.imgList[index]
        # Compare the extension only; GeoTIFFs are read through GDAL, other formats via PIL.
        if img_path.split('.')[-1] == "tif":
            img_ds = gdal.Open(img_path)
            img = img_ds.ReadAsArray()
            del img_ds
        else:
            img = np.array(Image.open(img_path)).transpose(2, 0, 1)

        # Guard against an all-zero image, which would otherwise become NaNs (0 / 0).
        peak = img.max()
        img = img / peak if peak > 0 else img.astype(np.float64)
        return torch.as_tensor(img, dtype=torch.float)


def getDataLoader(dataPath, setName, shuffle=True, BSize=4, nWorkers=4, pinMem=True):
    """Build a DataLoader for the ``"train"``, ``"val"`` or ``"test"`` split.

    ``shuffle`` applies to train/val only; the test loader is never shuffled so that
    its order matches ``dataset.imgNameList``.

    Raises:
        ValueError: For any other ``setName``.
    """
    if setName not in ("train", "val", "test"):
        raise ValueError("setName can only be 'train', 'val' or 'test'")

    if setName == "test":
        dataset, do_shuffle = myDataset_test(dataPath), False
    else:
        dataset, do_shuffle = myDataset(dataPath, setName), shuffle

    return data.DataLoader(dataset, batch_size=BSize, shuffle=do_shuffle,
                           num_workers=nWorkers, pin_memory=pinMem)

In [None]:
class Resblock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand, plus a shortcut.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels. The bottleneck width is ``out_ch // 4`` (at least 1).
        stride: Stride of the 3x3 conv. When it is not 1, the shortcut is projected with
            the same stride so that both branches have the same spatial size.
    """
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.ch_asc = (in_ch != out_ch)
        # A 1x1 projection is needed whenever channels or spatial size change.
        self.use_projection = self.ch_asc or stride != 1
        mid_ch = max(out_ch // 4, 1)
        self.conv = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1, 1, padding=0), nn.BatchNorm2d(mid_ch), nn.ReLU(),
                                  nn.Conv2d(mid_ch, mid_ch, 3, stride, padding=1), nn.BatchNorm2d(mid_ch), nn.ReLU(),
                                  nn.Conv2d(mid_ch, out_ch, 1, 1, padding=0), nn.BatchNorm2d(out_ch))
        if self.use_projection:
            self.shortcut = nn.Sequential(nn.Conv2d(in_ch, out_ch, 1, stride, 0), nn.BatchNorm2d(out_ch))

    def forward(self, x_in):
        x = self.conv(x_in)
        identity = self.shortcut(x_in) if self.use_projection else x_in
        return F.relu(x + identity)

class myModel(nn.Module):
    """Encoder-decoder segmentation network without skip connections.

    The encoder reduces resolution by 8 (a stride-2 stem conv and two max-pools). The
    decoder restores it with three stride-2 transposed convs and emits ``n_classes``
    logit maps. Inputs whose height and width are multiples of 8 therefore give
    outputs of the same spatial size.
    """

    def __init__(self, in_ch=3, n_classes=2):
        super().__init__()
        self.n_classes = n_classes

        def cbr(conv):
            """Follow ``conv`` with BatchNorm and ReLU."""
            return [conv, nn.BatchNorm2d(conv.out_channels), nn.ReLU()]

        # Stem: the first conv halves the resolution.
        self.conv1 = nn.Sequential(*cbr(nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)),
                                   *cbr(nn.Conv2d(64, 64, 3, stride=1, padding=1)))
        # Two pooled residual stages (each pool halves the resolution).
        self.conv2 = nn.Sequential(nn.MaxPool2d(3, stride=2, padding=1),
                                   Resblock(in_ch=64, out_ch=128),
                                   Resblock(in_ch=128, out_ch=128))
        self.conv3 = nn.Sequential(nn.MaxPool2d(2, stride=2, padding=0),
                                   Resblock(in_ch=128, out_ch=128),
                                   Resblock(in_ch=128, out_ch=256))
        # Full-resolution (relative to 1/8) residual stages widening to 512 channels.
        self.conv4 = nn.Sequential(Resblock(in_ch=256, out_ch=256),
                                   Resblock(in_ch=256, out_ch=512))
        self.conv5 = nn.Sequential(Resblock(in_ch=512, out_ch=512),
                                   Resblock(in_ch=512, out_ch=512))
        # Decoder: 3x upsample-by-2, refine, then per-pixel class logits.
        self.up = nn.Sequential(*cbr(nn.ConvTranspose2d(512, 256, 2, stride=2, padding=0)),
                                *cbr(nn.ConvTranspose2d(256, 128, 2, stride=2, padding=0)),
                                *cbr(nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)),
                                *cbr(nn.Conv2d(64, 32, 3, stride=1, padding=1)),
                                nn.Conv2d(32, n_classes, 1, stride=1, padding=0))

    def forward(self, inputs):
        features = inputs
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            features = stage(features)
        return self.up(features)

In [None]:
def dice_coef(pred, label):
    """Per-sample Dice coefficient between a prediction and a label tensor.

    Both tensors are expected to hold 0/1 values and to share the shape ``(batch, ...)``.
    Every sample is flattened and scored as
    ``2 * sum(pred * label) / (sum(pred) + sum(label) + 1e-8)``.
    A sample where both tensors are empty therefore scores 0.

    Returns:
        Float tensor of shape ``(batch,)``.

    Raises:
        ValueError: If the two shapes differ.
    """
    if pred.shape != label.shape:
        raise ValueError("pred and label shapes differ: %s vs %s"
                         % (tuple(pred.shape), tuple(label.shape)))
    batch_size = pred.shape[0]
    # reshape (unlike view) also accepts non-contiguous inputs such as permuted/sliced tensors.
    pred, label = pred.reshape(batch_size, -1), label.reshape(batch_size, -1)
    tp = torch.sum(pred * label, dim=1).float()
    return 2 * tp / (torch.sum(pred, dim=1).float() + torch.sum(label, dim=1).float() + 0.00000001)

def encodePixel(binaryMap):
    """Run-length encode a 2-D mask as space-separated ``start length`` pairs.

    Pixels are numbered from 1 in row-major (C) order of ``binaryMap``. Pixels equal
    to 1 are foreground and every other value is background. Each maximal foreground
    run becomes one ``start length`` pair; for example, ``[[0, 1, 1]]`` encodes to ``"2 2"``.
    An all-background mask encodes to ``""``. A run that reaches the very last pixel
    is counted in full.

    Raises:
        ValueError: If ``binaryMap`` is not 2-D.
    """
    binaryMap = np.asarray(binaryMap)
    if binaryMap.ndim != 2:
        raise ValueError("binaryMap must be 2-D, got shape %s" % (binaryMap.shape,))
    foreground = (binaryMap.reshape(-1) == 1).astype(np.int8)
    # Bracketing with background guarantees every run has a rising and a falling edge,
    # including runs touching the first or last pixel.
    edges = np.flatnonzero(np.diff(np.concatenate(([0], foreground, [0]))))
    starts, ends = edges[0::2], edges[1::2]
    return " ".join("%d %d" % (s + 1, e - s) for s, e in zip(starts, ends))


def _evaluate_dice(model, loader, device):
    """Return the mean per-sample Dice of ``model`` over ``loader`` as a float, or NaN if ``loader`` is empty.

    Leaves ``model`` in eval mode; the caller switches it back to train mode.
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for img_val, lb_val in loader:
            img_val, lb_val = img_val.to(device), lb_val.to(device)
            # argmax over logits equals argmax over their softmax.
            pred = torch.argmax(model(img_val), dim=1)
            scores.append(dice_coef(pred, lb_val))
    if not scores:
        return float("nan")
    return torch.mean(torch.cat(scores, dim=0)).item()


def train(model, train_loader, val_loader, lr_base, epoch, device=torch.device('cpu:0'), comment="xxx",
          ckpt_path="./CKPT/", from_scrach=False):
    """Train ``model`` with Adam and cross-entropy, tracking Dice and checkpointing every epoch.

    The checkpoint file is ``<ckpt_path>/<comment>/ckpt.pth``. Unless ``from_scrach`` is
    true, an existing checkpoint is resumed (model, optimizer, scheduler, epoch,
    iteration). Validation Dice is computed at the first iteration of the run, at the
    middle batch and at the last batch of every epoch. The latest value drives a
    ReduceLROnPlateau scheduler once per epoch.

    Args:
        model: Network producing ``(N, n_classes, H, W)`` logits.
        train_loader / val_loader: Loaders yielding ``(image, label)`` batches.
        lr_base: Initial Adam learning rate.
        epoch: Index of the last epoch to train (epochs are 1-based).
        device: Device to train on.
        comment: Run name; used as the checkpoint sub-directory.
        ckpt_path: Parent directory of the checkpoint sub-directory.
        from_scrach: Ignore any existing checkpoint when true.
    """
    import math

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_base)
    # Dice is higher-is-better, so the LR must drop when it stops *increasing*.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='max', factor=0.8, patience=1,
                                                           min_lr=0.000001, threshold=0.0001)
    criteria = nn.CrossEntropyLoss()
    ckpt_dir = os.path.join(ckpt_path, comment)
    ckpt_file = os.path.join(ckpt_dir, "ckpt.pth")

    if os.path.exists(ckpt_file) and (not from_scrach):
        ckpt = torch.load(ckpt_file, map_location=torch.device('cpu'))
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        ep_start = ckpt["epoch"]
        iter_start = ckpt["iteration"]
    else:
        ep_start = 0
        iter_start = 0

    n_batches = len(train_loader)
    eval_batches = {n_batches - 1, n_batches // 2}
    dice_val = float("nan")
    for ep in range(ep_start + 1, epoch + 1):
        with tqdm(enumerate(train_loader), desc="epoch %3d/%-3d" % (ep, epoch), total=n_batches) as t:
            for batch_idx, (img_train, lb_train) in t:
                img_train, lb_train = img_train.to(device), lb_train.to(device)
                model.train()
                logits = model(img_train)
                loss_train = criteria(input=logits, target=lb_train)
                optimizer.zero_grad()
                loss_train.backward()
                optimizer.step()

                with torch.no_grad():
                    pred = torch.argmax(logits, dim=1)
                    dice_train = torch.mean(dice_coef(pred, lb_train)).item()

                n_iter = (ep - 1) * n_batches + batch_idx
                if batch_idx in eval_batches or n_iter == iter_start:
                    dice_val = _evaluate_dice(model, val_loader, device)

                t.set_postfix_str("loss(train): %.4f, dice(train): %.4f, dice(val): %.4f, lr: %f"
                                  % (loss_train.item(), dice_train, dice_val,
                                     optimizer.param_groups[0]['lr']))

        # An empty validation loader yields NaN, which must not count as a "bad" epoch.
        if not math.isnan(dice_val):
            scheduler.step(dice_val)

        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save({'epoch': ep,
                    'iteration': ep * n_batches,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    }, ckpt_file)


def _predict_padded(model, img, multiple=8):
    """Return argmax class maps ``(N, H, W)`` for ``img`` ``(N, C, H, W)``, zero-padding H/W up to ``multiple``."""
    height, width = img.shape[2], img.shape[3]
    pad_h, pad_w = (-height) % multiple, (-width) % multiple
    if pad_h or pad_w:
        # Bottom/right zero padding, like the training-time padding; F.pad order is (left, right, top, bottom).
        img = F.pad(img, (0, pad_w, 0, pad_h))
    logits = model(img)
    return torch.argmax(logits, dim=1)[:, :height, :width]


def inference(model, data_loader, ckpt_path, comment, save_path, device=torch.device('cpu:0')):
    """Predict every image of ``data_loader`` and write an ``ID,Prediction`` RLE CSV to ``save_path``.

    Loads weights from ``<ckpt_path>/<comment>/ckpt.pth``. If that file is missing,
    it prints a message and returns. Image names come from
    ``data_loader.dataset.imgNameList`` in loader order, so the loader must not shuffle.
    Any batch size is supported.

    Images whose height or width is not a multiple of 8 (the model's total downsampling
    factor) are zero-padded on the bottom/right. The prediction is then cropped back
    to the original size.
    """
    ckpt_file = os.path.join(ckpt_path, comment, "ckpt.pth")
    if not os.path.exists(ckpt_file):
        print("Checkpoint not found")
        return

    ckpt = torch.load(ckpt_file, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.eval()

    names = data_loader.dataset.imgNameList
    fnList, encodedPixelList = ["ID"], ["Prediction"]
    sample_idx = 0
    with torch.no_grad(), tqdm(data_loader, desc="Inferencing", total=len(data_loader)) as t:
        for img in t:
            pred = _predict_padded(model, img.to(device), multiple=8).cpu().numpy()
            # Index names by running sample count, not by batch index, so BSize > 1 works.
            for sample_pred in pred:
                fnList.append(names[sample_idx])
                encodedPixelList.append(encodePixel(sample_pred))
                sample_idx += 1

    pred2submit = np.array(list(zip(fnList, encodedPixelList)))
    np.savetxt(save_path, pred2submit, delimiter=",", fmt="%s")
    print("prediction saved to %s" % save_path)

In [None]:
# Build the model and loaders, then train for up to 20 epochs on the GPU
# (train() resumes from /kaggle/working/foo/ckpt.pth when that file exists).
model = myModel(n_classes=2)
loader_kwargs = dict(dataPath=DATAPATH + "/train_data", BSize=16, nWorkers=2, pinMem=True)
trainLoader = getDataLoader(setName="train", shuffle=True, **loader_kwargs)
valLoader = getDataLoader(setName="val", shuffle=False, **loader_kwargs)
train(model, trainLoader, valLoader,
      lr_base=0.001, epoch=20, device=torch.device('cuda:0'),
      comment="foo", ckpt_path="/kaggle/working/", from_scrach=False)

In [None]:
# Visual sanity check: every 50th test image next to the model's prediction.
# The loop breaks after the first index above 500.
testImageList = [i.path for i in os.scandir(DATAPATH+"/test_image")]
model = model.to(torch.device("cuda:0"))
model.eval()

with torch.no_grad():
    for i in range(0, len(testImageList), 50):
        img = cv2.imread(testImageList[i], -1)
        # OpenCV decodes colour images as BGR(A), but training used RGB (PIL/GDAL).
        # Take channels 2, 1, 0 to get RGB (dropping any alpha), then go to (C, H, W).
        img = np.ascontiguousarray(img[:, :, 2::-1]).transpose(2, 0, 1)
        peak = img.max()
        img = img / peak if peak > 0 else img.astype(np.float64)
        img = torch.as_tensor(img[None], dtype=torch.float).cuda()
        pred = torch.argmax(model(img), dim=1).cpu().numpy()[0, :, :]
        plt.figure(figsize=(10, 20))
        plt.subplot(1, 2, 1)
        plt.imshow(img.cpu().numpy()[0].transpose(1, 2, 0))
        plt.subplot(1, 2, 2)
        plt.imshow(pred)
        plt.show()
        if i > 500:
            break

In [None]:
# Encode predictions for all test images with the "foo" checkpoint and save the submission CSV.
testLoader = getDataLoader(
    dataPath=DATAPATH + "/test_image",
    setName="test",
    shuffle=False,
    BSize=1,
    nWorkers=2,
    pinMem=True,
)
inference(
    model,
    data_loader=testLoader,
    ckpt_path="/kaggle/working/",
    comment="foo",
    save_path="/kaggle/working/prediction.csv",
    device=torch.device('cuda:0'),
)

In [None]:
# Reload the submission file, report its row count and preview the first rows.
df = pd.read_csv(r"/kaggle/working/prediction.csv")
# `"" % len(df)` raised TypeError (no placeholder for the argument).
print("number of predictions: %d" % len(df))
df.head(10)