## Importing Libraries

In [1]:
import pandas as pd
import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import torchvision.models as models
import torch.nn as nn
from torch.utils.data import DataLoader

## Setting the device: GPU if available (for faster training), otherwise CPU

In [2]:
# Use the CUDA device when PyTorch reports one; fall back to the CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Custom Dataset For Training Images

In [3]:
class SatelliteImagesTrain(Dataset):
    """Dataset of (satellite image, mask image) pairs for one split of a metadata CSV.

    The CSV must contain the columns ``split``, ``sat_image_path`` and
    ``mask_path``. Only rows whose ``split`` value equals the requested split
    are exposed, and samples are addressed by position (``0 .. len - 1``)
    within that split, wherever those rows sit in the CSV.

    Args:
        csv_file: Path of the metadata CSV (anything ``pandas.read_csv`` accepts).
        split: Value of the ``split`` column to keep, e.g. ``"train"``.
        transform: Optional callable applied to the satellite image and,
            in a separate call, to the mask image.
    """

    def __init__(self, csv_file, split, transform=None):
        self.csv_file = pd.read_csv(csv_file)
        self.transform = transform
        self.split = split
        # Filter once here instead of on every __len__/__getitem__ call, and
        # renumber the rows so that position i is the i-th row of this split.
        in_split = self.csv_file["split"] == self.split
        self._split_rows = self.csv_file.loc[in_split].reset_index(drop=True)

    def __len__(self):
        """Return the number of CSV rows that belong to the selected split."""
        return len(self._split_rows)

    def __getitem__(self, index):
        """Load the ``index``-th sample of the split.

        Returns:
            A ``(sat_img, mask_img)`` tuple; both images are opened as RGB and
            passed through ``transform`` when one was given.

        Raises:
            IndexError: If ``index`` is outside the split. Raising IndexError
                (rather than KeyError) is what lets ``for x in dataset`` stop
                at the end of the data.
        """
        # .iloc is purely positional, so the original CSV row labels never leak in.
        sat_img_path = self._split_rows["sat_image_path"].iloc[index]
        mask_img_path = self._split_rows["mask_path"].iloc[index]
        sat_img = Image.open(sat_img_path).convert("RGB")
        mask_img = Image.open(mask_img_path).convert("RGB")

        if self.transform is not None:
            # NOTE: the transform is invoked independently for image and mask;
            # a transform with random behaviour would not be applied
            # identically to the two.
            sat_img = self.transform(sat_img)
            mask_img = self.transform(mask_img)

        return (sat_img, mask_img)

## Custom Dataset for validation and test set

In [4]:
class SatelliteImagesTestValid(Dataset):
    """Dataset of satellite images (no masks) for one split of a metadata CSV.

    The CSV must contain the columns ``split`` and ``sat_image_path``. Only
    rows whose ``split`` value equals the requested split are exposed, and
    samples are addressed by position (``0 .. len - 1``) within that split.

    Args:
        csv_file: Path of the metadata CSV (anything ``pandas.read_csv`` accepts).
        split: Value of the ``split`` column to keep, e.g. ``"valid"`` or ``"test"``.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, csv_file, split, transform=None):
        self.csv_file = pd.read_csv(csv_file)
        self.transform = transform
        self.split = split
        # Select the split's image paths once; re-filtering and re-indexing the
        # whole CSV on every access is wasted work.
        in_split = self.csv_file["split"] == self.split
        self._sat_paths = self.csv_file.loc[in_split, "sat_image_path"].reset_index(drop=True)

    def __len__(self):
        """Return the number of CSV rows that belong to the selected split."""
        return len(self._sat_paths)

    def __getitem__(self, index):
        """Load the ``index``-th satellite image of the split.

        Returns:
            The image opened as RGB, passed through ``transform`` when one
            was given.

        Raises:
            IndexError: If ``index`` is outside the split. Raising IndexError
                (rather than KeyError) is what lets ``for x in dataset`` stop
                at the end of the data.
        """
        # Positional lookup; an out-of-range position raises IndexError.
        sat_img_path = self._sat_paths.iloc[index]
        sat_img = Image.open(sat_img_path).convert("RGB")

        if self.transform is not None:
            sat_img = self.transform(sat_img)

        return sat_img

In [5]:
# Image/mask pairs from the rows of metadata.csv whose split is "train".
train_dataset = SatelliteImagesTrain(
    csv_file="metadata.csv",
    split="train",
    transform=transforms.ToTensor(),
)

In [6]:
# Satellite images only, from the rows of metadata.csv whose split is "valid".
valid_dataset = SatelliteImagesTestValid(
    csv_file="metadata.csv",
    split="valid",
    transform=transforms.ToTensor(),
)

In [7]:
# Satellite images only, from the rows of metadata.csv whose split is "test".
test_dataset = SatelliteImagesTestValid(
    csv_file="metadata.csv",
    split="test",
    transform=transforms.ToTensor(),
)

In [8]:
# Number of samples in the "train" split.
len(train_dataset)

803

In [9]:
# Number of samples in the "valid" split.
len(valid_dataset)

171

In [10]:
# Number of samples in the "test" split.
len(test_dataset)

172

## Testing by printing image sizes and the corresponding tensors

In [11]:
# Sanity check: inspect only the first (satellite, mask) pair, then stop.
for sat_tensor, mask_tensor in train_dataset:
    print(sat_tensor.shape)   # satellite image size
    print(mask_tensor.shape)  # mask image size
    print(mask_tensor)        # mask image tensor
    print(sat_tensor)         # satellite image tensor
    break

torch.Size([3, 2448, 2448])
torch.Size([3, 2448, 2448])
tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
tensor([[[0.5059, 0.5059, 0.4980,  ..., 0.4078, 0.4078, 0.4078],
         [0.5020, 0.4980, 0.4902,  ..., 0.4039, 0.4118, 0.4275],
         [0.4902, 0.4941, 0.

In [12]:
# Sanity check: inspect only the first validation and test images, then stop.
for valid_tensor, test_tensor in zip(valid_dataset, test_dataset):
    print(valid_tensor.shape)  # validation satellite image size
    print(test_tensor.shape)   # test satellite image size
    print(test_tensor)         # test satellite image tensor
    print(valid_tensor)        # validation satellite image tensor
    break

torch.Size([3, 2448, 2448])
torch.Size([3, 2448, 2448])
tensor([[[0.3725, 0.3569, 0.3647,  ..., 0.3059, 0.2471, 0.3216],
         [0.3647, 0.3529, 0.3490,  ..., 0.3373, 0.3216, 0.3059],
         [0.3490, 0.3490, 0.3451,  ..., 0.3686, 0.3765, 0.3412],
         ...,
         [0.2902, 0.2745, 0.2667,  ..., 0.1686, 0.1961, 0.1608],
         [0.3059, 0.3059, 0.2902,  ..., 0.1490, 0.1529, 0.1412],
         [0.2902, 0.3216, 0.3098,  ..., 0.1373, 0.1686, 0.1765]],

        [[0.3529, 0.3373, 0.3451,  ..., 0.3059, 0.2392, 0.3137],
         [0.3451, 0.3333, 0.3294,  ..., 0.3451, 0.3216, 0.3059],
         [0.3373, 0.3373, 0.3255,  ..., 0.3922, 0.4000, 0.3569],
         ...,
         [0.2824, 0.2667, 0.2588,  ..., 0.1569, 0.1882, 0.1529],
         [0.2980, 0.2980, 0.2902,  ..., 0.1373, 0.1412, 0.1333],
         [0.2824, 0.3137, 0.3098,  ..., 0.1137, 0.1569, 0.1647]],

        [[0.2353, 0.2196, 0.2196,  ..., 0.1569, 0.0902, 0.1608],
         [0.2275, 0.2157, 0.2118,  ..., 0.2000, 0.1804, 0.1569],
  

## Model

In [None]:
# class CNN(nn.Module)

## Hyperparameters

In [13]:
# Optimisation settings
learning_rate = 0.001
num_epochs = 10

# DataLoader settings
batch_size = 32
num_workers = 1
pin_memory = True
shuffle = True

## Creating dataloaders for the train, validation and test sets from the custom datasets

In [14]:
# Training batches follow the `shuffle` hyperparameter so each epoch can see
# the samples in a different order.
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=shuffle,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Evaluation loaders keep the dataset order: shuffling adds nothing when no
# gradient step is taken, and a fixed order keeps each output aligned with
# its row in the split.
validation_loader = DataLoader(
    dataset=valid_dataset,
    shuffle=False,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

## Setting the loss function, optimizer and learning rate

In [16]:
# model = CNN().to(device)
# NOTE(review): SatelliteImagesTrain returns masks converted to RGB and passed
# through the same transform as the image (3-channel float tensors with
# ToTensor). nn.CrossEntropyLoss expects class-index or class-probability
# targets -- confirm the mask encoding matches before enabling training.
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## Helper Function to check accuracy

In [20]:
# pending

## Training Loop

In [17]:
# epoch_loss_list = []
# loop = tqdm(range(num_epochs),total = num_epochs)
# for epoch in loop:
#     epoch_loss = 0
#     for sat_img, mask_img in train_loader:
#         # Forward Propagation:
#         sat_img = sat_img.to(device)
#         mask_img = mask_img.to(device)
#         output = model(sat_img)
#         # Backward Propagation:
#         loss = criterion(output, mask_img)
#         epoch_loss+=loss.item()
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#     epoch_loss_list.append(epoch_loss)
#     loop.set_postfix({"Loss":epoch_loss})

## Epoch Loss Plot

In [18]:
# ## Plot Epoch Loss:
# plt.figure(dpi=150)
# plt.plot([epoch for epoch in range(num_epochs)], epoch_loss_list)
# plt.xlabel("Epoch Number")
# plt.ylabel("Epoch Train Loss")
# plt.title("CNN Train Loss plot Per Epoch")
# plt.show()