In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [51]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [52]:
#The double conv Class: Double conv2d each followed by a batchnorm2d and Relu activation.
class DoubleConv(nn.Module):
  def __init__(self, in_channels,out_channels):
    super(DoubleConv, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels,out_channels,kernel_size = 3, stride = 1, padding = 1, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace = True),
        nn.Conv2d(out_channels,out_channels,kernel_size = 3, stride = 1, padding = 1, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace = True)
    )
  
  def forward(self,x):
    return self.conv(x)

In [53]:
#Architecture Class
class UNET(nn.Module):
  def __init__(self, in_channels = 3, out_channels = 2, features = [64, 128, 256, 512]):
    super(UNET, self).__init__()
    self.downs = nn.ModuleList()
    self.ups = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)

    for feature in features:
      self.downs.append(DoubleConv(in_channels, feature))
      in_channels = feature

    for feature in reversed(features):
      self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size = 2, stride = 2))
      self.ups.append(DoubleConv(feature*2, feature))
    
    self.bottleneck = DoubleConv(features[-1], features[-1]*2)
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size = 1)
  
  def forward(self,x):
    skip_connections = []
    
    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)
    
    x = self.bottleneck(x) # bottom layer
    skip_connections = skip_connections[::-1]

    for idx in range(0, len(self.ups), 2):
      x = self.ups[idx](x)
      skip_connection = skip_connections[idx//2]
      concat_skip = torch.cat((skip_connection,x), dim = 1) #concatenating Skip connection layer
      x = self.ups[idx + 1](concat_skip)

    return self.final_conv(x)


In [54]:
#Testing the architecture
def test():
  x = torch.randn((3,1,160,160))
  model = UNET(in_channels = 1, out_channels = 2,)
  preds = model(x)
  print(preds.shape)

test()

torch.Size([3, 2, 160, 160])


In [55]:
#Dataset
class ForestAreaDataset(Dataset):
  def __init__(self, image_dir, mask_dir, transform = None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    self.images = os.listdir(image_dir)

  def __len__(self):
    return len(self.images)

  def __getitem__(self, index):
    image_path = os.path.join(self.image_dir, self.images[index])
    mask_path = os.path.join(self.mask_dir, self.images[index].replace("sat", "mask"))
    image = np.array(Image.open(image_path).convert("RGB"))
    mask = np.array(Image.open(mask_path).convert("L"), dtype = np.float32) # L is for single channel
    image = image/255.0
    mask = mask/255.0


    if self.transform is not None:
      augmentations = self.transform(image = image,mask = mask)
      image = augmentations["image"]
      mask = augmentations["mask"]
    
    return image, mask

In [56]:
# HYPERPARAMETERS
LEARNING_RATE = 1e-5
BATCH_SIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_EPOCHS = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
PIN_MEMORY = True
LOAD_MODEL = True
TRAIN_IMG_DIR = "/content/drive/MyDrive/Forest area dataset/train_image/"
TRAIN_MASK_DIR = "/content/drive/MyDrive/Forest area dataset/train_mask/"
TEST_IMG_DIR = "/content/drive/MyDrive/Forest area dataset/test_image/"
TEST_MASK_DIR = "/content/drive/MyDrive/Forest area dataset/test_mask/"

In [57]:
# To check accuracy, using every pixel
def check_accuracy(loader, model, device = DEVICE):
  num_correct = 0
  num_pixels = 0
  model.eval()

  with torch.no_grad():
    for x, y in loader:
      x = x.to(device).float() #dtype error without float
      y = y.to(device).unsqueeze(1).float() #unsqueeze to get the correct dims
      preds = torch.sigmoid(model(x))
      preds = (preds > 0.5).float()
      num_correct += (preds == y).sum()
      num_pixels += torch.numel(preds) 

  print(
      f"Got {num_correct}/{num_pixels} with % accuracy {(num_correct/num_pixels)*100:.2f} "
  )
  model.train()

In [58]:
#Train and test loader using DataLoader
def get_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    TEST_IMG_DIR,
    TEST_MASK_DIR,
    BATCH_SIZE,
    train_transform,
    test_transform,
    NUM_WORKERS = 2,
    PIN_MEMORY = True
):
  train_ds = ForestAreaDataset(
      image_dir =TRAIN_IMG_DIR,
      mask_dir = TRAIN_MASK_DIR,
      transform = train_transform
  )

  train_loader = DataLoader(
      train_ds,
      batch_size = BATCH_SIZE,
      num_workers = NUM_WORKERS,
      pin_memory = PIN_MEMORY,
      shuffle = True
  )
  test_ds = ForestAreaDataset(
      image_dir =TEST_IMG_DIR,
      mask_dir = TEST_MASK_DIR,
      transform = test_transform
  )

  test_loader = DataLoader(
      test_ds,
      batch_size = BATCH_SIZE,
      num_workers = NUM_WORKERS,
      pin_memory = PIN_MEMORY,
      shuffle = False
  )
  
  return train_loader, test_loader

In [61]:
def save_predictions_as_imgs(
    loader, model, folder = '/content/drive/MyDrive/Check',device="cuda"
):
  model.eval()
  for idx, (x,y) in enumerate(loader):
    x = x.to(device=DEVICE).float()
    with torch.no_grad():
      preds = torch.sigmoid(model(x))
      preds = (preds >0.5).float()
    torchvision.utils.save_image(preds, f"{folder}/pred_{idx}.png")

In [62]:
#train loop. tqdm for progress bar
def train_fn(loader, model, optimizer, loss_fn, scaler):
  loop = tqdm(loader)

  for batch_idx, (data, targets) in enumerate(loop):
    data = data.to(device = DEVICE).float() #dtype error without .float
    targets = targets.to(device = DEVICE).float().unsqueeze(1)

    with torch.cuda.amp.autocast():
      predictions = model(data)
      loss = loss_fn(predictions, targets)

    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    loop.set_postfix(loss = loss.item())


train_transform = A.Compose(
      [
        A.Resize(height = IMAGE_HEIGHT,width = IMAGE_WIDTH),
        A.Rotate(limit = 30, p = 0.4),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.4),  
        ToTensorV2(),
      ],
    )

test_transform = A.Compose(
      [
          A.Resize(height = IMAGE_HEIGHT,width = IMAGE_WIDTH),
          ToTensorV2(),
      ],
    )

model = UNET(in_channels = 3, out_channels = 1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss() 
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)

train_loader, test_loader = get_loaders(
      TRAIN_IMG_DIR,
      TRAIN_MASK_DIR,
      TEST_IMG_DIR,
      TEST_MASK_DIR,
      BATCH_SIZE,
      train_transform,
      test_transform,
      NUM_WORKERS,
      PIN_MEMORY,
)
scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
  train_fn(train_loader, model, optimizer, loss_fn, scaler)

  check_accuracy(test_loader, model, device = DEVICE)
  model.eval()

save_predictions_as_imgs(test_loader, model, folder = '/content/drive/MyDrive/Check',device="cuda")

100%|██████████| 246/246 [01:36<00:00,  2.55it/s, loss=0.44]


Got 60853000/76873728 with % accuracy 79.16 


100%|██████████| 246/246 [01:32<00:00,  2.65it/s, loss=0.517]


Got 61438972/76873728 with % accuracy 79.92 


100%|██████████| 246/246 [01:32<00:00,  2.65it/s, loss=0.418]


Got 61814295/76873728 with % accuracy 80.41 


100%|██████████| 246/246 [01:32<00:00,  2.65it/s, loss=0.429]


Got 61887566/76873728 with % accuracy 80.51 


100%|██████████| 246/246 [01:32<00:00,  2.65it/s, loss=0.299]


Got 61288584/76873728 with % accuracy 79.73 


100%|██████████| 246/246 [01:32<00:00,  2.66it/s, loss=0.343]


Got 62364893/76873728 with % accuracy 81.13 


100%|██████████| 246/246 [01:32<00:00,  2.67it/s, loss=0.419]


Got 62347674/76873728 with % accuracy 81.10 


100%|██████████| 246/246 [01:32<00:00,  2.65it/s, loss=0.33]


Got 62651410/76873728 with % accuracy 81.50 


100%|██████████| 246/246 [01:32<00:00,  2.66it/s, loss=0.464]


Got 62332782/76873728 with % accuracy 81.08 


100%|██████████| 246/246 [01:32<00:00,  2.66it/s, loss=0.376]


Got 62437900/76873728 with % accuracy 81.22 
