In [300]:
#!pip install albumentations==0.4.6

In [301]:
#from google.colab import drive
#drive.mount('/content/drive')

In [302]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.functional as TF

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader

import numpy as np

#augmentation 
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm.auto import tqdm
import torch.optim as optim

from numpy import float32



# Utils

In [303]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
  """Announce the save, then write ``state`` to ``filename`` with ``torch.save``.

  Args:
    state: any object ``torch.save`` can serialize (here a dict holding the
      model and optimizer state dicts).
    filename: destination path.
  """
  banner = "=> Saving checkpoint"
  print(banner)
  torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
  """Announce the load, then copy ``checkpoint["state_dict"]`` into ``model``.

  Args:
    checkpoint: mapping with a ``"state_dict"`` entry (as written by
      ``save_checkpoint``'s caller).
    model: module whose parameters/buffers are overwritten in place.
  """
  print("=> Loading checkpoint")
  weights = checkpoint["state_dict"]
  model.load_state_dict(weights)

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
  """Build the training and validation DataLoaders over ``CarvanaDataset``.

  Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory``; only
  the training loader shuffles.

  Returns:
    ``(train_loader, val_loader)``
  """
  def build(image_dir, mask_dir, transform, shuffle):
    # One dataset + loader pair; the split only differs in paths, transform
    # and whether the sample order is shuffled.
    dataset = CarvanaDataset(image_dir=image_dir, mask_dir=mask_dir, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=shuffle,
    )

  train_loader = build(train_dir, train_maskdir, train_transform, shuffle=True)
  val_loader = build(val_dir, val_maskdir, val_transform, shuffle=False)
  return train_loader, val_loader

def check_accuracy(loader, model, device="cuda", num_classes=None):
  """Evaluate ``model`` on ``loader`` and print pixel accuracy and Dice score.

  With a single output channel the logits go through a sigmoid and are
  thresholded at 0.5. They are then compared with the mask, which arrives
  without a channel dim, so one is added. With more channels the raw outputs
  are compared with the targets as-is. Pixel accuracy is then not computed,
  and only Dice is printed.

  Args:
    loader: iterable of ``(x, y)`` batches.
    model: module returning logits / raw values for ``x``.
    device: device the batches are moved to.
    num_classes: output channels of ``model``; ``None`` falls back to the
      notebook-level ``NUM_CLASSES`` constant.

  Returns:
    dict with ``"accuracy"`` (percent, ``None`` when no pixels were scored)
    and ``"dice"`` (mean per-batch Dice, ``None`` for an empty loader).
  """
  if num_classes is None:
    num_classes = NUM_CLASSES

  num_correct = 0
  num_pixels = 0
  # Dice is a more informative overlap metric than plain accuracy.
  dice_total = 0.0
  num_batches = 0

  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for x, y in loader:
        x = x.to(device)
        # The model returns logits, so binary outputs need an activation.
        if num_classes > 1:
          y = y.to(device)
          preds = model(x)
        else:
          # Grayscale masks have no channel dim; add one to match (N, 1, H, W).
          y = y.to(device).unsqueeze(1)
          preds = (torch.sigmoid(model(x)) > 0.5).float()
          num_correct += int((preds == y).sum().item())
          # Accumulate over all batches, not just the last one.
          num_pixels += preds.numel()

        dice_total += float(2 * (preds * y).sum() / ((preds + y).sum() + 1e-8))
        num_batches += 1
  finally:
    # Restore whatever mode the caller had the model in.
    model.train(was_training)

  accuracy = num_correct / num_pixels * 100 if num_pixels else None
  dice = dice_total / num_batches if num_batches else None

  if accuracy is not None:
    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.3f}")
  print(f"Dice score: {dice}")
  return {"accuracy": accuracy, "dice": dice}

def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda", num_classes=None):
  """Write each batch's prediction and target to ``folder`` as PNG images.

  Files are named ``pred_{idx}.png`` / ``y_{idx}.png`` by batch index. For a
  single output channel the logits are converted to a hard 0/1 mask
  (sigmoid > 0.5), the same rule ``check_accuracy`` uses. For more channels
  the raw outputs are saved, and ``save_image`` clamps them to the displayable
  range.

  Args:
    loader: iterable of ``(x, y)`` batches.
    model: module returning logits / raw values for ``x``.
    folder: output directory; created if missing.
    device: device the inputs are moved to.
    num_classes: output channels of ``model``; ``None`` falls back to the
      notebook-level ``NUM_CLASSES`` constant.
  """
  if num_classes is None:
    num_classes = NUM_CLASSES
  os.makedirs(folder, exist_ok=True)

  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        preds = model(x)
        if num_classes > 1:
          target = y
        else:
          preds = (torch.sigmoid(preds) > 0.5).float()
          # Grayscale masks have no channel dim; save_image wants (N, C, H, W).
          target = y.unsqueeze(1)
        torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
        torchvision.utils.save_image(target, os.path.join(folder, f"y_{idx}.png"))
  finally:
    model.train(was_training)

# Model 

In [304]:
class DoubleConv(nn.Module):
  """Two stacked ``Conv2d(3x3) -> BatchNorm2d -> ReLU`` stages.

  Stride 1 and padding 1 keep H and W unchanged; only the channel count goes
  from ``in_channels`` to ``out_channels``. BatchNorm is an addition on top of
  the original U-Net paper's model.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    def conv_bn_relu(channels_in):
      # bias=False, presumably because the BatchNorm right after makes a conv
      # bias redundant.
      return [
          nn.Conv2d(channels_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=True),
      ]

    # The attribute name `conv` and this layer order define the state_dict
    # keys (conv.0.weight, conv.1.running_mean, ...) stored in checkpoints.
    self.conv = nn.Sequential(*conv_bn_relu(in_channels), *conv_bn_relu(out_channels))

  def forward(self, x):
    return self.conv(x)
      

In [305]:
class UNET(nn.Module):
  """U-Net encoder/decoder with a DoubleConv at every level.

  Layout:
    * Encoder: one DoubleConv per entry of ``features``, each followed by 2x2
      max-pooling.
    * Bottleneck: a DoubleConv that doubles the deepest width.
    * Decoder: mirrors the encoder. Each level runs a 2x2 transposed
      convolution, concatenates the matching encoder output (skip
      connection), then applies a DoubleConv.
    * Head: a final 1x1 convolution maps to ``out_channels``.

  No activation is applied, so the output is raw logits / values.

  Args:
    in_channels: channels of the input batch ``(N, in_channels, H, W)``.
    out_channels: channels of the output map.
    features: encoder widths from shallowest to deepest (the default is the
      original paper's 64-128-256-512). Each entry is expected to be twice the
      previous one, because every decoder ConvTranspose2d takes
      ``feature * 2`` input channels.

  Raises:
    ValueError: if ``features`` is empty.
  """

  def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
    super().__init__()
    # Immutable default + local copy: a caller's list is never shared.
    features = list(features)
    if not features:
      raise ValueError("features must contain at least one channel width")

    self.downs = nn.ModuleList()  # encoder (contracting path)
    self.ups = nn.ModuleList()    # decoder: alternating ConvTranspose2d, DoubleConv
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    for feature in features:
      self.downs.append(DoubleConv(in_channels, feature))
      in_channels = feature

    for feature in reversed(features):
      # feature * 2 input channels: e.g. the 1024-channel bottleneck output
      # for the deepest level with the default widths.
      self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
      # After concatenation with the skip connection: feature + feature channels.
      self.ups.append(DoubleConv(feature * 2, feature))

    self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

  def forward(self, x):
    skip_connections = []
    for down in self.downs:
      x = down(x)
      skip_connections.append(x)  # saved before pooling, reused by the decoder
      x = self.pool(x)

    # Bottom of the "U".
    x = self.bottleneck(x)

    # self.ups holds (upsample, DoubleConv) pairs, deepest level first, so the
    # skip connections are consumed in reverse order.
    decoder = list(self.ups)
    for up, double_conv, skip in zip(decoder[0::2], decoder[1::2], reversed(skip_connections)):
      x = up(x)
      # Max-pooling floors odd sizes, so the upsampled map can be smaller
      # than its skip; resize H and W (batch/channels already match).
      if x.shape[2:] != skip.shape[2:]:
        x = TF.resize(x, size=list(skip.shape[2:]))
      x = double_conv(torch.cat((skip, x), dim=1))

    return self.final_conv(x)

In [306]:
def test():
  """Smoke-test UNET: a (batch=3, channels=3, 160, 160) input must come back with the same shape."""
  sample = torch.randn((3, 3, 160, 160))  # batch, channels, H, W
  net = UNET(in_channels=3, out_channels=3)
  output = net(sample)
  for shape in (output.shape, sample.shape):
    print(shape)

  assert output.shape == sample.shape

test()

torch.Size([3, 3, 160, 160])
torch.Size([3, 3, 160, 160])


# Train

In [307]:
class CarvanaDataset(Dataset):
  """Pairs every image file in ``image_dir`` with its mask in ``mask_dir``.

  The mask file name is the image file name with ``.jpg`` replaced by
  ``_mask.gif``. Images are loaded as RGB numpy arrays. Masks are loaded as
  RGB arrays when there is more than one output channel, otherwise as
  grayscale ``float32`` arrays. If a ``transform`` is given, it is called as
  ``transform(image=..., mask=...)`` (albumentations style).

  Args:
    image_dir: directory of input images.
    mask_dir: directory of masks.
    transform: optional joint image/mask transform.
    num_classes: output channels; ``None`` falls back to the notebook-level
      ``NUM_CLASSES`` constant, read when an item is fetched.
  """

  def __init__(self, image_dir, mask_dir, transform=None, num_classes=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    self.num_classes = num_classes
    # os.listdir order is arbitrary; sort so indices (and files saved per
    # batch index) are reproducible. Skip sub-directories, which Image.open
    # cannot read.
    self.images = sorted(
        name for name in os.listdir(image_dir)
        if os.path.isfile(os.path.join(image_dir, name))
    )

  def __len__(self):
    return len(self.images)

  def __getitem__(self, index):
    name = self.images[index]
    img_path = os.path.join(self.image_dir, name)
    mask_path = os.path.join(self.mask_dir, name.replace(".jpg", "_mask.gif"))
    num_classes = NUM_CLASSES if self.num_classes is None else self.num_classes

    # The input is RGB; the mask is RGB or grayscale depending on the task.
    # `with` closes the file handles (GIF masks stay open otherwise).
    with Image.open(img_path) as img:
      image = np.array(img.convert("RGB"))
    with Image.open(mask_path) as raw_mask:
      if num_classes > 1:
        mask = np.array(raw_mask.convert("RGB"))
      else:
        mask = np.array(raw_mask.convert("L"), dtype=float32)

    # NOTE(review): grayscale masks are left in their stored pixel range. The
    # binary path trains with BCEWithLogitsLoss, which expects 0/1 targets, and
    # `mask[mask == 255.0] = 1.0` was commented out here. Confirm intent.

    if self.transform is not None:
      augmented = self.transform(image=image, mask=mask)
      image, mask = augmented["image"], augmented["mask"]
    return image, mask

# Hyperparameters

In [308]:
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1
NUM_EPOCHS = 5
NUM_WORKERS = 0
NUM_CHANNELS = 3   # channels of the input images
NUM_CLASSES = 3    # model output channels; > 1 selects the RGB-mask / MSELoss path
IMAGE_HEIGHT = 160 # originally 1280
IMAGE_WIDTH = 240  # originally 1918
# Pinned host memory only speeds up host -> CUDA copies; on a CPU-only run
# DataLoader ignores it and warns.
PIN_MEMORY = DEVICE == "cuda"
LOAD_MODEL = False

# Dataset root: set the PLANTS_DATA_ROOT environment variable instead of
# editing this cell. Other roots used with this notebook:
#   Colab:   /content/drive/MyDrive/plants
#   Carvana: G:/Mi unidad/cavana
DATA_ROOT = os.environ.get("PLANTS_DATA_ROOT", "G:/Mi unidad/plants")
TRAIN_IMG_DIR = f"{DATA_ROOT}/train"
TRAIN_MASK_DIR = f"{DATA_ROOT}/train_masks"
VAL_IMG_DIR = f"{DATA_ROOT}/val"
VAL_MASK_DIR = f"{DATA_ROOT}/val_masks"

print(DEVICE)

cpu


# Training 

In [309]:
# Train the model for a single epoch.
def train_fn(loader, model, optimizer, loss_fn, scaler, device=None, num_classes=None):
  """Run one training epoch over ``loader`` and return the mean batch loss.

  Mixed precision (autocast + ``scaler``) is only enabled on CUDA. On CPU the
  forward pass runs in float32 without the autocast warning.

  Args:
    loader: iterable of ``(data, targets)`` batches.
    model: module returning logits / raw values.
    optimizer: optimizer over ``model``'s parameters.
    loss_fn: criterion called as ``loss_fn(predictions, targets)``.
    scaler: ``GradScaler`` used for the backward/step.
    device: target device; ``None`` uses the notebook-level ``DEVICE``.
    num_classes: output channels; ``None`` uses the notebook-level ``NUM_CLASSES``.

  Returns:
    Mean loss over the epoch's batches (``0.0`` for an empty loader).
  """
  device = DEVICE if device is None else device
  num_classes = NUM_CLASSES if num_classes is None else num_classes
  device_type = "cuda" if str(device).startswith("cuda") else "cpu"
  use_amp = device_type == "cuda"

  loop = tqdm(loader)  # progress bar
  total_loss = 0.0
  num_batches = 0

  for data, targets in loop:
    data = data.to(device=device)
    # BCE / MSE need float targets.
    targets = targets.float()
    if num_classes <= 1:
      # Grayscale masks have no channel dim; add one to match (N, 1, H, W).
      targets = targets.unsqueeze(1)
    targets = targets.to(device=device)

    # Forward pass, in float16 where safe when running on CUDA.
    with torch.autocast(device_type=device_type, enabled=use_amp):
      predictions = model(data).float()
      loss = loss_fn(predictions, targets)

    # Backward pass through the (possibly disabled) gradient scaler.
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    batch_loss = loss.item()
    total_loss += batch_loss
    num_batches += 1
    loop.set_postfix(loss=batch_loss)

  return total_loss / num_batches if num_batches else 0.0

# Main

## Transforms

In [310]:
def _to_normalized_tensor():
  """Shared tail of both pipelines: divide pixels by 255 (mean 0, std 1) and convert to a tensor."""
  return [
      A.Normalize(
          mean=[0.0, 0.0, 0.0],
          std=[1.0, 1.0, 1.0],
          max_pixel_value=255.0,
      ),
      ToTensorV2(),
  ]

# RGB masks (NUM_CLASSES > 1) are routed through the image pipeline. They then
# get the same Normalize and HWC -> CHW conversion as the image, matching the
# model's multi-channel output that MSELoss compares against.
# Grayscale masks keep albumentations' native mask handling:
#   * no Normalize, whose 3-value mean does not fit a 2-D array;
#   * nearest-neighbour resampling;
#   * an (H, W) tensor that train_fn / check_accuracy unsqueeze to (N, 1, H, W).
MASK_TARGETS = {"mask": "image"} if NUM_CLASSES > 1 else {}

train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.1),
        *_to_normalized_tensor(),
    ],
    additional_targets=MASK_TARGETS,
)

val_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        *_to_normalized_tensor(),
    ],
    additional_targets=MASK_TARGETS,
)

## Create model

In [311]:
model = UNET(in_channels=NUM_CHANNELS, out_channels=NUM_CLASSES).to(DEVICE)

# The model has no final activation:
#  - NUM_CLASSES > 1: regress the RGB mask with MSE;
#  - single channel: BCEWithLogitsLoss applies the sigmoid itself.
loss_fn = nn.MSELoss() if NUM_CLASSES > 1 else nn.BCEWithLogitsLoss()

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    VAL_IMG_DIR,
    VAL_MASK_DIR,
    BATCH_SIZE,
    train_transform,
    val_transform,
    NUM_WORKERS,
    PIN_MEMORY,
)

if LOAD_MODEL:
  # map_location lets a checkpoint written on a GPU be restored on CPU.
  load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)

# Loss scaling only matters on CUDA; disabling it explicitly avoids the
# "CUDA not available" warning on CPU runs.
scaler = torch.cuda.amp.GradScaler(enabled=DEVICE == "cuda")
os.makedirs("saved_images", exist_ok=True)

for epoch in range(NUM_EPOCHS):
  train_fn(train_loader, model, optimizer, loss_fn, scaler)

  checkpoint = {
      "state_dict": model.state_dict(),
      "optimizer": optimizer.state_dict(),
  }
  save_checkpoint(checkpoint)

  #check_accuracy(val_loader, model, device=DEVICE)

  save_predictions_as_imgs(val_loader, model, folder="saved_images", device=DEVICE)

#Loader



100%|██████████| 71/71 [01:13<00:00,  1.04s/it, loss=0.0159]
=> Saving checkpoint
100%|██████████| 71/71 [01:13<00:00,  1.03s/it, loss=0.00426]
=> Saving checkpoint
100%|██████████| 71/71 [01:14<00:00,  1.05s/it, loss=0.00308]
=> Saving checkpoint
100%|██████████| 71/71 [01:13<00:00,  1.03s/it, loss=0.00267]
=> Saving checkpoint
100%|██████████| 71/71 [01:12<00:00,  1.02s/it, loss=0.00227]
=> Saving checkpoint


In [312]:
# Scratch check: permute(0, 3, 1, 2) turns channels-last (N, H, W, C)
# into channels-first (N, C, H, W), the layout the model consumes.
x = torch.randn(1, 160, 180, 3)
print(x.shape)

channels_first = x.permute(0, 3, 1, 2)
print(channels_first.shape)

torch.Size([1, 160, 180, 3])
torch.Size([1, 3, 160, 180])
