In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

In [None]:
class DoubleConv(nn.Module):
  """Two stacked (3x3 conv -> batch norm -> ReLU) stages.

  The convolutions use stride 1 and padding 1, so height and width are
  unchanged; only the channel count goes from `in_channels` to `out_channels`.
  The convolutions carry no bias term.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    stages = []
    # The first stage maps in_channels -> out_channels, the second keeps out_channels.
    for stage_inputs in (in_channels, out_channels):
      stages.append(
          nn.Conv2d(stage_inputs, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
      )
      stages.append(nn.BatchNorm2d(out_channels))
      stages.append(nn.ReLU(inplace=True))
    self.conv = nn.Sequential(*stages)

  def forward(self, x):
    """Run `x` through both conv/norm/activation stages."""
    return self.conv(x)

In [None]:
class UNET(nn.Module):
  """U-Net style encoder/decoder built from DoubleConv blocks.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by the final 1x1 convolution.
    features: channel widths of the encoder levels, shallowest first. The
      decoder mirrors them in reverse and the bottleneck uses twice the last
      width. Any iterable is accepted.

  `forward` returns raw scores (no activation is applied) with the same
  spatial size as the input.
  """

  def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
    super(UNET, self).__init__()
    # An immutable default plus a private copy avoids sharing one list object
    # between every instance.
    features = list(features)

    self.ups = nn.ModuleList()
    self.downs = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # Down part of UNET
    for feature in features:
      self.downs.append(DoubleConv(in_channels, feature))
      in_channels = feature

    # Up part of UNET: `ups` alternates [upsample, DoubleConv, upsample, DoubleConv, ...]
    for feature in reversed(features):
      self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
      # Input is feature*2 because the skip connection is concatenated on the channel axis.
      self.ups.append(DoubleConv(feature*2, feature))

    self.bottleneck = DoubleConv(features[-1], features[-1]*2)

    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

  def forward(self, x):
    """Encode `x`, pass it through the bottleneck, then decode with skip connections."""
    skip_connections = []

    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)
    # Deepest skip first, matching the order the decoder consumes them.
    skip_connections = skip_connections[::-1]

    for idx in range(0, len(self.ups), 2):
      x = self.ups[idx](x)
      skip_connection = skip_connections[idx//2]

      # Pooling floors odd sizes, so the upsampled map can be smaller than the
      # skip map. Only the spatial dims ([2:], i.e. excluding batch and
      # channels) are relevant for deciding whether to resize.
      if x.shape[2:] != skip_connection.shape[2:]:
        x = torch.nn.functional.interpolate(x, size=skip_connection.shape[2:])

      concat_skip = torch.cat((skip_connection, x), dim=1)
      x = self.ups[idx+1](concat_skip)

    return self.final_conv(x)


In [None]:
def test():
  """Smoke test: a single-channel UNET must return a tensor shaped like its input."""
  batch = torch.randn((3, 1, 160, 160))
  net = UNET(in_channels=1, out_channels=1)
  output = net(batch)
  # Print prediction shape first, then input shape.
  for shape in (output.shape, batch.shape):
    print(shape)
  assert output.shape == batch.shape

In [None]:
# Run the UNET shape smoke test; it prints the prediction and input shapes.
test()

torch.Size([3, 1, 160, 160])
torch.Size([3, 1, 160, 160])


In [None]:
# Colab only: mount Google Drive at /content/drive, which is where the
# dataset directory constants used for training point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class CarvanaDataset(Dataset):
  """Image/mask pairs read from two directories.

  Every entry of `image_dir` is treated as an image; its mask is looked up in
  `mask_dir` under the same name with ".jpg" replaced by "_mask.gif".

  Args:
    image_dir: directory containing the input images.
    mask_dir: directory containing the matching masks.
    transform: optional callable invoked as `transform(image=..., mask=...)`
      that returns a mapping with "image" and "mask" entries.
  """

  def __init__(self, image_dir, mask_dir, transform=None):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    # os.listdir returns entries in arbitrary order; sorting makes the
    # index -> file mapping reproducible across runs and machines.
    self.images = sorted(os.listdir(image_dir))

  def __len__(self):
    """Number of entries found in `image_dir`."""
    return len(self.images)

  def __getitem__(self, index):
    """Return `(image, mask)` for `index`, transformed if a transform was given.

    Without a transform, `image` is the RGB image as a numpy array and `mask`
    is a float32 numpy array in which the value 255 has been replaced by 1.
    """
    image_name = self.images[index]
    img_path = os.path.join(self.image_dir, image_name)
    mask_path = os.path.join(self.mask_dir, image_name.replace(".jpg", "_mask.gif"))

    # Context managers close the underlying files deterministically instead of
    # relying on garbage collection.
    with Image.open(img_path) as image_file:
      image = np.array(image_file.convert("RGB"))
    with Image.open(mask_path) as mask_file:
      mask = np.array(mask_file.convert("L"), dtype=np.float32)
    # Map white (255) to 1.0 so the mask can be used as a binary target.
    mask[mask == 255.0] = 1.0

    if self.transform is not None:
      augmentations = self.transform(image=image, mask=mask)
      image = augmentations['image']
      mask = augmentations['mask']

    return image, mask

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

# Hyperparameters
# Module-level configuration read by the training code in this notebook.
LEARNING_RATE = 1e-4
# Train on the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32
NUM_EPOCHS = 100
NUM_WORKERS = 2
# Images are resized to this resolution by the Albumentations pipelines.
IMAGE_HEIGHT = 160 # 1280 originally
IMAGE_WIDTH = 240 # 1918 originally
PIN_MEMORY = True
LOAD_MODEL = False
# Dataset locations on the mounted Google Drive (note the space in "UNET images").
TRAIN_IMG_DIR = "/content/drive/MyDrive/UNET images/train_images/"
TRAIN_MASK_DIR = "/content/drive/MyDrive/UNET images/train_masks/"
VAL_IMG_DIR = "/content/drive/MyDrive/UNET images/val_images/"
VAL_MASK_DIR = "/content/drive/MyDrive/UNET images/val_masks/"

def train_fn(loader, model, optimizer, loss_fn, scaler):
  """Run one training epoch over `loader`.

  Args:
    loader: iterable yielding `(data, targets)` batches; targets are given a
      channel axis at dim 1 before the loss is computed.
    model: network being trained.
    optimizer: optimizer stepping the parameters of `model`.
    loss_fn: callable invoked as `loss_fn(predictions, targets)`.
    scaler: gradient scaler (e.g. `torch.cuda.amp.GradScaler`) used for the
      backward pass and optimizer step.

  The running loss of the latest batch is shown in the tqdm progress bar.
  """
  loop = tqdm(loader)
  for data, targets in loop:
    data = data.to(device=DEVICE)
    targets = targets.float().unsqueeze(1).to(device=DEVICE)

    # forward
    # Gradient scaling only pays off when the forward pass actually runs in
    # mixed precision, so run it under autocast. It is enabled exactly when the
    # scaler is, which keeps CPU-only runs in full precision.
    with torch.autocast(device_type="cuda", enabled=scaler.is_enabled()):
      predictions = model(data)
      loss = loss_fn(predictions, targets)

    # backward
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # update tqdm loop
    loop.set_postfix(loss=loss.item())


def main():
  """Train a UNET on the configured dataset, evaluating after every epoch.

  Builds the augmentation pipelines, model, loss, optimizer and data loaders
  from the module-level constants, then for each of NUM_EPOCHS epochs: trains,
  saves a checkpoint, reports validation accuracy, and writes example
  predictions to "saved_images/".
  """
  train_transform = A.Compose(
      [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        # Was p=1.0, which turned every training image upside down while the
        # validation images stay upright. Use it as an occasional augmentation.
        A.VerticalFlip(p=0.1),
        # mean 0 / std 1 with max_pixel_value 255 simply rescales pixels to [0, 1].
        A.Normalize(
            mean = [0.0, 0.0, 0.0],
            std = [1.0, 1.0, 1.0],
            max_pixel_value = 255.0
        ),
        ToTensorV2()
      ]
  )
  val_transforms = A.Compose(
      [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean = [0.0, 0.0, 0.0],
            std = [1.0, 1.0, 1.0],
            max_pixel_value = 255.0
        ),
        ToTensorV2()
      ]
  )

  model = UNET(in_channels=3, out_channels=1).to(DEVICE)
  loss_fn = nn.BCEWithLogitsLoss() # Cross Entropy Loss for multiple classes
  optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
  train_loader, val_loader = get_loaders(
      TRAIN_IMG_DIR,
      TRAIN_MASK_DIR,
      VAL_IMG_DIR,
      VAL_MASK_DIR,
      BATCH_SIZE,
      train_transform,
      val_transforms,
      NUM_WORKERS,
      PIN_MEMORY
  )

  # Honour the LOAD_MODEL switch: resume from the checkpoint file that
  # save_checkpoint writes by default.
  if LOAD_MODEL:
    load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)

  scaler = torch.cuda.amp.GradScaler()

  for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    # check accuracy
    check_accuracy(val_loader, model, device = DEVICE)

    # print some examples to a folder
    save_predictions_as_imgs(
        val_loader, model, folder="saved_images/", device=DEVICE
    )

In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
  """Announce the save on stdout, then serialize `state` to `filename` via torch.save."""
  print("=> Saving Checkpoint")
  torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
  """Restore `model` weights from the "state_dict" entry of `checkpoint`.

  `checkpoint` is an already-loaded mapping, not a file path.
  """
  print("=> Loading checkpoint")
  weights = checkpoint["state_dict"]
  model.load_state_dict(weights)

def get_loaders(
      train_dir,
      train_maskdir,
      val_dir,
      val_maskdir,
      batch_size,
      train_transform,
      val_transform,
      num_workers=4,
      pin_memory = True
  ):
    """Build the training and validation DataLoaders.

    Both loaders wrap a CarvanaDataset and share batch size, worker count and
    pin-memory setting; only the training loader shuffles.

    Returns:
      `(train_loader, val_loader)`.
    """
    # Options common to both loaders.
    shared_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    train_ds = CarvanaDataset(image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform)
    train_loader = DataLoader(train_ds, shuffle=True, **shared_options)

    val_ds = CarvanaDataset(image_dir=val_dir, mask_dir=val_maskdir, transform=val_transform)
    val_loader = DataLoader(val_ds, shuffle=False, **shared_options)

    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
  """Print pixel accuracy and mean per-batch Dice score of `model` over `loader`.

  Args:
    loader: iterable yielding `(x, y)` batches. `y` may or may not carry a
      channel axis; one is added at dim 1 when it has one dimension fewer than
      the predictions.
    model: network producing logits; it is put in eval mode for the duration
      of the check and switched back to train mode afterwards.
    device: device the batches are moved to.

  Predictions are `sigmoid(logits) > 0.5`.
  """
  num_correct = 0
  num_pixels = 0
  dice_score = 0.0
  num_batches = 0
  model.eval()
  with torch.no_grad():
    for x, y in loader:
      x = x.to(device)
      y = y.to(device)
      preds = torch.sigmoid(model(x))
      preds = (preds > 0.5).float()
      # Give the target the same rank as the predictions. Comparing
      # (N, 1, H, W) with (N, H, W) would broadcast to (N, N, H, W) and count
      # far more "correct" pixels than exist.
      if y.dim() == preds.dim() - 1:
        y = y.unsqueeze(1)
      num_correct += (preds == y).sum().item()
      num_pixels += torch.numel(preds)
      # The small epsilon avoids 0/0 when both prediction and target are empty.
      dice_score += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()
      num_batches += 1

  # Guard against an empty loader instead of dividing by zero.
  accuracy = (num_correct / num_pixels) * 100 if num_pixels else 0.0
  mean_dice = dice_score / num_batches if num_batches else 0.0
  print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
  print(f"Dice Score : {mean_dice}")
  model.train()

def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
  """Write thresholded predictions and their ground-truth masks as PNG grids.

  For batch number `idx`, the prediction goes to `pred_{idx}.png` and the
  target to `target_{idx}.png` inside `folder`. The folder is created if it
  does not exist.

  Args:
    loader: iterable yielding `(x, y)` batches.
    model: network producing logits; put in eval mode while saving and switched
      back to train mode afterwards.
    folder: output directory.
    device: device the inputs are moved to.
  """
  import os  # local import keeps this cell runnable on its own

  # save_image does not create missing directories.
  os.makedirs(folder, exist_ok=True)

  model.eval()
  for idx, (x, y) in enumerate(loader):
    x = x.to(device = device)
    with torch.no_grad():
      preds = torch.sigmoid(model(x))
      preds = (preds > 0.5).float()
    # Add the channel axis only when the target lacks one.
    if y.dim() == preds.dim() - 1:
      y = y.unsqueeze(1)
    torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
    # Previously the target was written to the same "pred_{idx}.png" path and
    # overwrote the prediction saved just above.
    torchvision.utils.save_image(y, os.path.join(folder, f"target_{idx}.png"))

  model.train()

In [None]:
# Start training; each epoch prints a tqdm progress bar, a checkpoint notice,
# and the validation accuracy / Dice score.
main()

100%|██████████| 145/145 [04:21<00:00,  1.80s/it, loss=0.144]


=> Saving Checkpoint
Got 452584382/17203200 with acc 2630.82
Dice Score : 0.6564556360244751


100%|██████████| 145/145 [04:17<00:00,  1.78s/it, loss=0.109]


=> Saving Checkpoint
Got 428784896/17203200 with acc 2492.47
Dice Score : 0.6018999814987183


100%|██████████| 145/145 [04:16<00:00,  1.77s/it, loss=0.0815]


=> Saving Checkpoint
Got 455013808/17203200 with acc 2644.94
Dice Score : 0.6535269021987915


100%|██████████| 145/145 [04:14<00:00,  1.75s/it, loss=0.0683]


=> Saving Checkpoint
Got 451048000/17203200 with acc 2621.88
Dice Score : 0.6451285481452942


100%|██████████| 145/145 [04:14<00:00,  1.76s/it, loss=0.053]


=> Saving Checkpoint
Got 456704930/17203200 with acc 2654.77
Dice Score : 0.6556177139282227


100%|██████████| 145/145 [04:13<00:00,  1.75s/it, loss=0.0447]


=> Saving Checkpoint
Got 454977400/17203200 with acc 2644.73
Dice Score : 0.654357373714447


100%|██████████| 145/145 [04:15<00:00,  1.76s/it, loss=0.0374]


=> Saving Checkpoint
Got 448555000/17203200 with acc 2607.39
Dice Score : 0.6385640501976013


100%|██████████| 145/145 [04:20<00:00,  1.80s/it, loss=0.0339]


=> Saving Checkpoint
Got 458077750/17203200 with acc 2662.75
Dice Score : 0.6574647426605225


100%|██████████| 145/145 [04:16<00:00,  1.77s/it, loss=0.0285]


=> Saving Checkpoint
Got 442785984/17203200 with acc 2573.86
Dice Score : 0.6268711090087891


100%|██████████| 145/145 [04:16<00:00,  1.77s/it, loss=0.0258]


=> Saving Checkpoint
Got 446676920/17203200 with acc 2596.48
Dice Score : 0.6357702612876892


  6%|▌         | 8/145 [00:16<04:49,  2.11s/it, loss=0.027]


KeyboardInterrupt: ignored

In [None]:
try:
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves
except:
    # Get the going_modular scripts
    print("[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.")
    !git clone https://github.com/mrdbourke/pytorch-deep-learning
    !mv pytorch-deep-learning/going_modular .
    !mv pytorch-deep-learning/helper_functions.py . # get the helper_functions.py script
    !rm -rf pytorch-deep-learning

[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.
Cloning into 'pytorch-deep-learning'...
remote: Enumerating objects: 3435, done.[K
remote: Counting objects: 100% (133/133), done.[K
remote: Compressing objects: 100% (87/87), done.[K
remote: Total 3435 (delta 55), reused 97 (delta 41), pack-reused 3302[K
Receiving objects: 100% (3435/3435), 643.58 MiB | 16.62 MiB/s, done.
Resolving deltas: 100% (1962/1962), done.
Updating files: 100% (222/222), done.


In [None]:
# Colab only: trigger a browser download of UNET.zip.
# NOTE(review): the cell that creates UNET.zip comes after this one in the
# notebook, and an identical download cell follows it — on a top-to-bottom run
# this cell would execute before the archive exists. Confirm the intended order.
from google.colab import files
files.download('UNET.zip')

In [None]:
# Archive everything under models/ into UNET.zip (written one level up),
# excluding bytecode, notebooks, __pycache__ and checkpoint folders.
!cd models/ && zip -r ../UNET.zip * -x "*.pyc" "*.ipynb" "*__pycache__*" "*ipynb_checkpoints*"

  adding: UNET.pth (deflated 8%)


In [None]:
# Colab only: trigger a browser download of the UNET.zip archive.
from google.colab import files
files.download('UNET.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Save the model under models/UNET.pth using the going_modular helper.
# NOTE(review): `model` must already exist in the kernel namespace; the cells
# that assign a global `model` appear later in the notebook, so this cell does
# not work on a fresh top-to-bottom run. Confirm the intended order.
from going_modular.going_modular import utils

utils.save_model(
    model = model,
    target_dir = "models",
    model_name = "UNET.pth")

[INFO] Saving model to: models/UNET.pth


In [None]:
# Fresh, untrained 3-channel-in / 1-channel-out UNET placed on DEVICE.
model = UNET(in_channels=3, out_channels=1).to(DEVICE)

In [None]:
# The file written by save_checkpoint holds a dict with "state_dict" and
# "optimizer" entries, so the model weights must be taken from the
# "state_dict" entry; passing the whole dict to load_state_dict raises a
# RuntimeError. map_location lets a GPU-saved checkpoint load on CPU as well.
checkpoint = torch.load('my_checkpoint.pth.tar', map_location=DEVICE)
model.load_state_dict(checkpoint["state_dict"])

RuntimeError: ignored

In [None]:
# Script-level training run. It keeps `model`, `optimizer`, the loaders and the
# transforms as notebook globals so that other cells can use them.
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        # Was p=1.0, which turned every training image upside down while the
        # validation images stay upright. Use it as an occasional augmentation.
        A.VerticalFlip(p=0.1),
        # mean 0 / std 1 with max_pixel_value 255 simply rescales pixels to [0, 1].
        A.Normalize(
            mean = [0.0, 0.0, 0.0],
            std = [1.0, 1.0, 1.0],
            max_pixel_value = 255.0
        ),
        ToTensorV2()
    ]
)
val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean = [0.0, 0.0, 0.0],
            std = [1.0, 1.0, 1.0],
            max_pixel_value = 255.0
        ),
        ToTensorV2()
    ]
)

model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss() # Cross Entropy Loss for multiple classes
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    VAL_IMG_DIR,
    VAL_MASK_DIR,
    BATCH_SIZE,
    train_transform,
    val_transforms,
    NUM_WORKERS,
    PIN_MEMORY
)

scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    # check accuracy
    check_accuracy(val_loader, model, device = DEVICE)

    # print some examples to a folder
    save_predictions_as_imgs(
        val_loader, model, folder="saved_images/", device=DEVICE
    )

  0%|          | 0/145 [00:00<?, ?it/s]