In [None]:
import os
import zipfile

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

In [None]:
class UNET(nn.Module):
  """U-Net style encoder/decoder that maps an image batch to per-pixel logits.

  The encoder applies one DoubleConv per entry of `features`, halving the
  spatial size after each with 2x2 max pooling. The decoder mirrors it with
  stride-2 transposed convolutions, concatenating the matching encoder output
  (skip connection) before each DoubleConv. A final 1x1 convolution maps
  `features[0]` channels to `out_channels`. No activation is applied to the
  output.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels of the returned tensor.
    features: channel width of each encoder level, shallowest first.
  """

  class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    padding = 1 with a 3x3 kernel and stride 1 keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
      super().__init__()
      # bias is disabled on the convolutions because each one is followed
      # directly by BatchNorm2d, which has its own learnable shift.
      self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size = (3, 3), stride = 1, padding = 1, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace = True),
        nn.Conv2d(out_channels, out_channels, kernel_size = (3, 3), stride = 1, padding = 1, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace = True),
      )

    def forward(self, x):
      return self.conv(x)

  def __init__(self, in_channels = 3, out_channels = 1, features = (64, 128, 256, 512)):
    super().__init__()
    self.downs = nn.ModuleList()
    self.ups = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)

    # Encoder: one DoubleConv per level, widening the channels as we go down.
    for feature in features:
      self.downs.append(self.DoubleConv(in_channels, feature))
      in_channels = feature

    # Decoder: `ups` alternates [upsample, DoubleConv, upsample, DoubleConv, ...].
    # The DoubleConv takes feature * 2 channels because the upsampled tensor is
    # concatenated with the skip connection of the same width.
    for feature in reversed(features):
      self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size = 2, stride = 2))
      self.ups.append(self.DoubleConv(feature * 2, feature))

    self.bottleneck = self.DoubleConv(features[-1], features[-1] * 2)
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size = 1)

  def forward(self, x):
    """Return a tensor with `out_channels` channels computed from `x`."""
    skip_connections = []
    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)

    # The decoder starts at the deepest level, so walk the skips deepest-first.
    skip_connections = skip_connections[::-1]

    for idx in range(0, len(self.ups), 2):
      x = self.ups[idx](x)
      skip_connection = skip_connections[idx // 2]

      # Max pooling floors odd sizes, so the upsampled map can be smaller than
      # the skip connection; resize so the two can be concatenated.
      if x.shape[2:] != skip_connection.shape[2:]:
        x = TF.resize(x, size = list(skip_connection.shape[2:]))

      concat_skip = torch.cat((skip_connection, x), dim = 1)
      x = self.ups[idx + 1](concat_skip)

    return self.final_conv(x)







In [None]:
class CarvanaDataset(Dataset):
  """Image/mask pairs read from two directories.

  Every file name found in `image_dir` is treated as one sample. The mask for
  `<name>.jpg` is looked up in `mask_dir` as `<name>_mask.gif`.

  Args:
    image_dir: directory containing the input images.
    mask_dir: directory containing the mask files.
    transform: optional callable invoked as `transform(image = ..., mask = ...)`.
      It may return either a mapping with "image" and "mask" keys or an
      `(image, mask)` pair.
    mode: accepted for interface compatibility; currently unused.
  """

  def __init__(self, image_dir, mask_dir, transform = None, mode = 'train'):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    # os.listdir returns entries in arbitrary order; sort so that an index
    # always refers to the same file across runs.
    self.images = sorted(os.listdir(image_dir))

  def __len__(self):
    return len(self.images)

  def __getitem__(self, index):
    """Return the (image, mask) pair at `index`, transformed if a transform is set."""
    image_name = self.images[index]
    img_path = os.path.join(self.image_dir, image_name)
    mask_path = os.path.join(self.mask_dir, image_name.replace('.jpg', '_mask.gif'))

    image = np.array(Image.open(img_path).convert("RGB"))
    mask = np.array(Image.open(mask_path).convert("L"), dtype = np.float32)
    # Map the 255-valued pixels to 1.0 so the mask can serve as a binary target.
    mask[mask == 255.0] = 1.0

    if self.transform is not None:
      transformed = self.transform(image = image, mask = mask)
      if isinstance(transformed, dict):
        # Keyword-style transform pipelines return a dict keyed by argument name.
        image, mask = transformed["image"], transformed["mask"]
      else:
        image, mask = transformed

    return image, mask



In [None]:
# Hyperparameters and paths
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32
NUM_EPOCHS = 10
# DataLoader rejects negative worker counts (0 = load in the main process).
NUM_WORKERS = 2
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "/content/train"
TRAIN_MASK_DIR = "/content/train_masks"

In [None]:
# Extract each dataset archive into the directory it lives in. The archive is
# validated first so that a missing or partially uploaded file stops the
# notebook here with a clear message instead of failing in a later cell.
for archive_path in ("/content/train.zip", "/content/train_masks.zip"):
  if not zipfile.is_zipfile(archive_path):
    raise zipfile.BadZipFile(
      f"{archive_path} is missing or is not a complete zip archive; upload it again before continuing."
    )
  with zipfile.ZipFile(archive_path) as archive:
    archive.extractall(os.path.dirname(archive_path))

Archive:  /content/train.zip
  End-of-central-directory signature not found.  Either this file is not
  a zipfile, or it constitutes one disk of a multi-part archive.  In the
  latter case the central directory and zipfile comment will be found on
  the last disk(s) of this archive.
unzip:  cannot find zipfile directory in one of /content/train.zip or
        /content/train.zip.zip, and cannot find /content/train.zip.ZIP, period.


In [None]:
# Show only a small, sorted sample instead of dumping the whole directory listing.
sorted(os.listdir(TRAIN_IMG_DIR))[:5]

FileNotFoundError: [Errno 2] No such file or directory: '/content/train'

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
  """Run one pass over `loader`, updating `model` after every batch.

  Args:
    loader: iterable yielding `(data, targets)` batches.
    model: module called on each data batch to produce predictions.
    optimizer: optimizer stepped once per batch through `scaler`.
    loss_fn: callable taking `(predictions, targets)` and returning a scalar loss.
    scaler: gradient scaler used to scale the loss and step the optimizer.
  """
  loop = tqdm(loader)

  for data, targets in loop:
    data = data.to(device = DEVICE)
    # unsqueeze(1) inserts a dimension at index 1
    # NOTE(review): presumably so (N, H, W) masks match (N, 1, H, W) logits — confirm
    targets = targets.float().unsqueeze(1).to(device = DEVICE)

    # forward
    with torch.cuda.amp.autocast():
      predictions = model(data)
      loss = loss_fn(predictions, targets)

    # backward
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # show the latest batch loss on the progress bar
    loop.set_postfix(loss = loss.item())


In [None]:
# Training-time preprocessing/augmentation pipeline.
# NOTE(review): `transforms` is bound to torchvision.transforms in this notebook,
# and running this cell raises a TypeError on Resize(height = ...). The keyword
# arguments used here (height/width, limit, p, max_pixel_value) and ToTensorV2
# look like the Albumentations API rather than torchvision — confirm which
# library was intended before relying on this cell.
train_transform = transforms.Compose(
    [transforms.Resize(height = IMAGE_HEIGHT, width = IMAGE_WIDTH),
     # presumably a random rotation limited to 35 degrees, always applied (p = 1.0) — verify
     transforms.Rotate(limit = 35, p = 1.0),
     transforms.HorizontalFlip(p = 0.5),
     transforms.VerticalFlip(p = 0.1),
     # presumably only rescales pixel values by 1/255 (zero mean, unit std) — verify
     transforms.Normalize(mean = [0.0, 0.0, 0.0], std = [1.0, 1.0, 1.0],
                 max_pixel_value = 255.0),
     transforms.ToTensorV2(),],
)

# Validation-time pipeline: the same resize / normalize / tensor conversion
# steps as above, without the rotation and flip entries.
val_transform = transforms.Compose(
    [transforms.Resize(height = IMAGE_HEIGHT, width = IMAGE_WIDTH),
     transforms.Normalize(mean = [0.0, 0.0, 0.0], std = [1.0, 1.0, 1.0],
                 max_pixel_value = 255.0),
     transforms.ToTensorV2(),],
)

TypeError: Resize.__init__() got an unexpected keyword argument 'height'

In [None]:
# Loss on raw logits: BCEWithLogitsLoss applies the sigmoid itself.
loss_fn = nn.BCEWithLogitsLoss()

# Three input channels in, a single output channel out, placed on the selected device.
model = UNET(in_channels = 3, out_channels = 1)
model = model.to(DEVICE)

optimizer = optim.Adam(params = model.parameters(), lr = LEARNING_RATE)

In [None]:
# NOTE(review): get_loaders is defined in a later cell of this notebook; that
# cell has to run before this one (or be moved above it) for Run All to work.
train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    train_transform,
    val_transform,
    num_workers = NUM_WORKERS,
    pin_memory = PIN_MEMORY,
)


In [None]:
# Gradient scaling only applies on CUDA; create the scaler disabled on CPU-only
# runtimes so it becomes a pass-through instead of warning and disabling itself.
scaler = torch.cuda.amp.GradScaler(enabled = torch.cuda.is_available())


In [None]:
# Train for NUM_EPOCHS full passes over the training loader.
for epoch in range(NUM_EPOCHS):
  train_fn(train_loader, model, optimizer, loss_fn, scaler)



In [None]:
def save_checkpoint(state, filename = "checkpoint.pth.tar"):
  """Announce the save on stdout, then write `state` to `filename` via torch.save."""
  print("=> Saving checkpoint")
  torch.save(obj = state, f = filename)

def load_checkpoint(checkpoint, model):
  """Announce the load on stdout, then copy checkpoint['state_dict'] into `model`."""
  print("=> Loading checkpoint")
  weights = checkpoint['state_dict']
  model.load_state_dict(weights)

def get_loaders(train_dir, train_maskdir, train_transform, val_transforms, num_workers = 0, pin_memory = True, batch_size = None, val_fraction = 0.1, seed = 42):
  """Build training and validation DataLoaders from one image/mask directory pair.

  The signature carries no separate validation directory, so a fraction of the
  samples is held out for validation. Two CarvanaDataset views of the same
  files are created so each split gets its own transform.

  Args:
    train_dir: directory containing the images.
    train_maskdir: directory containing the masks.
    train_transform: transform applied to training samples.
    val_transforms: transform applied to validation samples.
    num_workers: DataLoader worker processes; must be >= 0.
    pin_memory: forwarded to both DataLoaders.
    batch_size: samples per batch; falls back to the notebook's BATCH_SIZE when None.
    val_fraction: share of the samples held out for validation, in [0, 1).
    seed: seed for the train/validation permutation, so the split is repeatable.

  Returns:
    A `(train_loader, val_loader)` pair.

  Raises:
    ValueError: if `val_fraction` is outside [0, 1).
  """
  if batch_size is None:
    batch_size = BATCH_SIZE
  if not 0.0 <= val_fraction < 1.0:
    raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")

  train_dataset = CarvanaDataset(image_dir = train_dir, mask_dir = train_maskdir, transform = train_transform)
  val_dataset = CarvanaDataset(image_dir = train_dir, mask_dir = train_maskdir, transform = val_transforms)
  # Share one file list so an index refers to the same sample in both views.
  val_dataset.images = list(train_dataset.images)

  num_samples = len(train_dataset)
  num_val = int(num_samples * val_fraction)
  generator = torch.Generator().manual_seed(seed)
  order = torch.randperm(num_samples, generator = generator).tolist()
  val_indices, train_indices = order[:num_val], order[num_val:]

  train_loader = DataLoader(
    torch.utils.data.Subset(train_dataset, train_indices),
    batch_size = batch_size,
    num_workers = num_workers,
    pin_memory = pin_memory,
    shuffle = True,
  )
  val_loader = DataLoader(
    torch.utils.data.Subset(val_dataset, val_indices),
    batch_size = batch_size,
    num_workers = num_workers,
    pin_memory = pin_memory,
    shuffle = False,
  )
  return train_loader, val_loader
