In [1]:
import os
import torch 
import numpy as np
from PIL import Image
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

In [2]:
BASE_DATA_PATH = "/scratch/ssd002/datasets/cv_project/spacenet"
EXAMPLE_DATA_PATH = os.path.join(BASE_DATA_PATH, "AOI_4_Shanghai_Train_processed")
IMG_PATH = os.path.join(BASE_DATA_PATH, "AOI_4_Shanghai_Train_processed", "RGB-PanSharpen")
MASK_PATH = os.path.join(BASE_DATA_PATH, "AOI_4_Shanghai_Train_processed", "masks")

In [3]:
EPOCHS = 50
BATCH_SIZE = 32
LEARNING_RATE = 2e-4
IMG_DIM = (256, 256)
DEVICE = device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Data Exploration and Visualization

In [4]:
print(os.listdir(BASE_DATA_PATH))
print(os.listdir(EXAMPLE_DATA_PATH))

['AOI_4_Shanghai_Train', 'AOI_3_Paris_Test_public', 'AOI_5_Khartoum_Test_public', 'AOI_4_Shanghai_Test_public', 'AOI_5_Khartoum_Test_public_processed', 'AOI_3_Paris_Train', 'AOI_4_Shanghai_Train_processed', 'AOI_3_Paris_Train_processed', 'AOI_2_Vegas_Train_processed', 'AOI_4_Shanghai_Test_public_processed', 'AOI_2_Vegas_Test_public_processed', 'AOI_2_Vegas_Train', 'AOI_5_Khartoum_Train_processed', 'AOI_5_Khartoum_Train', 'AOI_3_Paris_Test_public_processed', 'AOI_2_Vegas_Test_public']
['masks', 'extras', 'RGB-PanSharpen']


In [5]:
print(len(os.listdir(EXAMPLE_DATA_PATH + "/RGB-PanSharpen")))
print(len(os.listdir(EXAMPLE_DATA_PATH + "/masks")))

4582
4582


In [6]:
print(sorted(os.listdir(EXAMPLE_DATA_PATH + "/RGB-PanSharpen"))[:15])
print(sorted(os.listdir(EXAMPLE_DATA_PATH + "/masks"))[:15])

['RGB-PanSharpen_AOI_4_Shanghai_img1001.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1002.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1003.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1005.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1007.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1008.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1009.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1010.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1012.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1013.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1014.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1015.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1016.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1017.png', 'RGB-PanSharpen_AOI_4_Shanghai_img1018.png']
['AOI_4_Shanghai_img1001_mask.png', 'AOI_4_Shanghai_img1002_mask.png', 'AOI_4_Shanghai_img1003_mask.png', 'AOI_4_Shanghai_img1005_mask.png', 'AOI_4_Shanghai_img1007_mask.png', 'AOI_4_Shanghai_img1008_mask.png', 'AOI_4_Shanghai_img1009_mask.png', 'AOI_4_Shanghai_img1010_mask.png', 'AOI_4_Shanghai_img1012_mask.png', 'AOI_4_S

In [7]:
ex_img_path = os.path.join(EXAMPLE_DATA_PATH, "RGB-PanSharpen", "RGB-PanSharpen_AOI_4_Shanghai_img2070.png")
ex_img = np.array(Image.open(ex_img_path))
print(ex_img.shape)

(650, 650, 3)


In [8]:
ex_mask_path = os.path.join(EXAMPLE_DATA_PATH, "masks", "AOI_4_Shanghai_img2070_mask.png")
ex_mask = np.array(Image.open(ex_mask_path))
print(ex_mask.shape)

(650, 650)


In [9]:
fig, axarr = plt.subplots(1, 2, figsize=(20, 20))
axarr[0].imshow(ex_img)
axarr[1].imshow(ex_mask, cmap="gray")

<matplotlib.image.AxesImage at 0x7f53b35d1898>

## Custom Dataset Defintion

In [10]:
class SpaceNet_Dataset(Dataset): 
    def __init__(self, img_dir, mask_dir, transform = None):
        self.img_dir = img_dir 
        self.mask_dir = mask_dir
        self.images = os.listdir(img_dir)
        self.masks = os.listdir(mask_dir)
        self.transform = transform
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        img_path = os.path.join(self.img_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.masks[index])
        
        img = Image.open(img_path)
        mask = Image.open(mask_path)
        
        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)
        
        return img, mask
        

## Model Defintion

In [11]:
class UNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        self.maxpool_2x2 = nn.MaxPool2d(kernel_size= 2, stride= 2)
        self.dconv1= double_conv(in_channels,64)
        self.dconv2= double_conv(64,128)
        self.dconv3= double_conv(128,256)
        self.dconv4= double_conv(256,512)
        self.dconv5= double_conv(512,1024)
        self.num_of_classes = out_channels
        
        #Now, the first up convolution is performed followed by a double convolution to alter the number of channels of feature map.
        self.uptrans1= nn.ConvTranspose2d(
            in_channels= 1024,
            out_channels= 512,
            kernel_size= 2,
            stride= 2
        )
        
        self.upconv1= double_conv(1024,512)
        
        self.uptrans2= nn.ConvTranspose2d(
            in_channels= 512,
            out_channels= 256,
            kernel_size= 2,
            stride= 2
        )
        
        self.upconv2= double_conv(512, 256)
        
        self.uptrans3= nn.ConvTranspose2d(
            in_channels= 256,
            out_channels= 128,
            kernel_size= 2,
            stride= 2
        )
        
        self.upconv3= double_conv(256,128)
        
        self.uptrans4= nn.ConvTranspose2d(
            in_channels= 128,
            out_channels= 64,
            kernel_size= 2,
            stride= 2
        )
        
        self.upconv4= double_conv(128,64)
        
        self.out= nn.Conv2d(
            in_channels= 64,
            out_channels= self.num_of_classes,
            kernel_size= 1
        )
    
    def forward(self, image):
        
        #encoder
        enc_x_1= self.dconv1(image)
        enc_x_2= self.maxpool_2x2(enc_x_1)
        enc_x_3= self.dconv2(enc_x_2)
        enc_x_4= self.maxpool_2x2(enc_x_3)
        enc_x_5= self.dconv3(enc_x_4)
        enc_x_6= self.maxpool_2x2(enc_x_5)
        enc_x_7= self.dconv4(enc_x_6)
        enc_x_8= self.maxpool_2x2(enc_x_7)
        enc_x_9= self.dconv5(enc_x_8)
        
        #decoder
        dec_x_1= self.uptrans1(enc_x_9)
        dec_x_2 = self.upconv1(torch.cat([dec_x_1, enc_x_7],1))
        
        dec_x_3= self.uptrans2(dec_x_2)
        dec_x_4= self.upconv2(torch.cat([dec_x_3, enc_x_5],1))
        
        dec_x_5= self.uptrans3(dec_x_4)
        dec_x_6= self.upconv3(torch.cat([dec_x_5, enc_x_3],1))
        
        dec_x_7= self.uptrans4(dec_x_6)
        dec_x_8= self.upconv4(torch.cat([dec_x_7, enc_x_1],1))
        
        dec_x_9= self.out(dec_x_8)
        print(dec_x_9.size())
        return dec_x_9

In [12]:
def double_conv(in_c, out_c):
    
    conv= nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size= 3, padding=(1, 1)),
        nn.ReLU(inplace= False),
        nn.Conv2d(out_c, out_c, kernel_size= 3, padding=(1, 1)),
        nn.ReLU(inplace= False)
    )
    return conv

## Train and Validaiton

In [13]:
def train_fn(loader, model, optimizer, loss_fn):
    loop = tqdm(loader)
    
    for batch_id, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().to(device=DEVICE)
            
        predictions = model(data)
        loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        loop.set_postfix(loss=loss.item())   

In [14]:
def main():
    normalize = transforms.Normalize(mean=[0], std=[1])
        
    data_transform = transforms.Compose([
        transforms.Resize(IMG_DIM),
            transforms.ToTensor(),
            normalize
    ])
    
    train_ds = SpaceNet_Dataset(img_dir=IMG_PATH, 
                                mask_dir=MASK_PATH, 
                                transform=data_transform)
    
    train_loader = DataLoader(train_ds,
                              batch_size=BATCH_SIZE,
                              num_workers=2,
                              pin_memory=True,
                              shuffle=True
                              )
    
    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    for epoch in range(EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn)
    

In [15]:
main()

  0%|          | 0/144 [00:00<?, ?it/s]

torch.Size([32, 1, 256, 256])


  1%|          | 1/144 [00:03<07:34,  3.18s/it, loss=0.668]

torch.Size([32, 1, 256, 256])


  1%|▏         | 2/144 [00:04<06:18,  2.67s/it, loss=0.661]

torch.Size([32, 1, 256, 256])


  2%|▏         | 3/144 [00:06<05:22,  2.29s/it, loss=0.66] 

torch.Size([32, 1, 256, 256])


  3%|▎         | 4/144 [00:07<04:43,  2.02s/it, loss=0.659]

torch.Size([32, 1, 256, 256])


  3%|▎         | 5/144 [00:08<04:16,  1.84s/it, loss=0.651]

torch.Size([32, 1, 256, 256])


  4%|▍         | 6/144 [00:10<03:57,  1.72s/it, loss=0.644]

torch.Size([32, 1, 256, 256])


  5%|▍         | 7/144 [00:11<03:43,  1.63s/it, loss=0.64] 

torch.Size([32, 1, 256, 256])


  6%|▌         | 8/144 [00:13<03:32,  1.57s/it, loss=0.629]

torch.Size([32, 1, 256, 256])


  6%|▋         | 9/144 [00:14<03:25,  1.52s/it, loss=0.636]

torch.Size([32, 1, 256, 256])


  7%|▋         | 10/144 [00:16<03:20,  1.49s/it, loss=0.612]

torch.Size([32, 1, 256, 256])


KeyboardInterrupt: 