In [1]:
import torch
from torch import nn
import torchvision.transforms.functional


class DoubleConvolution(nn.Module):
  """Two stacked 3x3 convolutions, each followed by a ReLU.

  Both convolutions use ``padding=1`` with the default stride of 1, so the
  spatial size of the input is preserved; only the channel count changes
  (``in_channels`` -> ``out_channels`` -> ``out_channels``).
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()

    # Kernel size 3 with padding 1 (stride left at its default of 1) keeps the
    # output height/width equal to the input's.
    same_size = dict(kernel_size=3, padding=1)
    self.first = nn.Conv2d(in_channels, out_channels, **same_size)
    self.act1 = nn.ReLU()
    self.second = nn.Conv2d(out_channels, out_channels, **same_size)
    self.act2 = nn.ReLU()

  def forward(self, x):
    """Return ``ReLU(conv2(ReLU(conv1(x))))``."""
    hidden = self.act1(self.first(x))
    return self.act2(self.second(hidden))


class DownSample(nn.Module):
  """Down-sampling step: 2x2 max pooling."""

  def __init__(self):
    super().__init__()

    # kernel_size=2; MaxPool2d's stride defaults to the kernel size, so each
    # spatial dimension is halved.
    self.pool = nn.MaxPool2d(kernel_size=2)

  def forward(self, x):
    """Return ``x`` after 2x2 max pooling."""
    pooled = self.pool(x)
    return pooled


class UpSample(nn.Module):
  """Up-sampling step: a 2x2 transposed convolution with stride 2.

  Doubles the spatial size of its input and maps ``in_channels`` to
  ``out_channels``.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()

    # Kernel 2 / stride 2: every input pixel expands into its own 2x2 patch.
    self.up = nn.ConvTranspose2d(
      in_channels,
      out_channels,
      kernel_size=2,
      stride=2,
    )

  def forward(self, x):
    """Return the up-sampled feature map."""
    upsampled = self.up(x)
    return upsampled


class CropAndConcat(nn.Module):
  """Merge a skip connection into the current feature map.

  The skip tensor ``contracting_x`` is center-cropped to the height and width
  of ``x`` (dimensions 2 and 3) and the two are concatenated along the channel
  dimension (dim 1), with ``x`` first.
  """

  def forward(self, x, contracting_x):
    """Return ``cat([x, center_crop(contracting_x)], dim=1)``."""
    target_size = [x.shape[2], x.shape[3]]
    cropped = torchvision.transforms.functional.center_crop(contracting_x, target_size)
    return torch.cat([x, cropped], dim=1)


class UNet(nn.Module):
  """U-Net built from the blocks defined in this notebook.

  Contracting path: four ``DoubleConvolution`` + ``DownSample`` stages
  (``in_channels`` -> 64 -> 128 -> 256 -> 512), then a 512 -> 1024 bottleneck.
  Expansive path: four stages of ``UpSample``, ``CropAndConcat`` with the
  matching contracting-path output, and ``DoubleConvolution``
  (1024 -> 512 -> 256 -> 128 -> 64). A final 1x1 convolution maps the 64
  channels to ``out_channels``.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()

    contracting_channels = [(in_channels, 64), (64, 128), (128, 256), (256, 512)]
    # The same (in, out) pairs serve both the transposed convolutions and the
    # double convolutions that follow them: up-sampling halves the channel
    # count and concatenating the skip connection doubles it again.
    expansive_channels = [(1024, 512), (512, 256), (256, 128), (128, 64)]
    depth = 4

    self.down_conv = nn.ModuleList(
      [DoubleConvolution(c_in, c_out) for c_in, c_out in contracting_channels]
    )
    self.down_sample = nn.ModuleList([DownSample() for _ in range(depth)])
    self.middle_conv = DoubleConvolution(512, 1024)
    self.up_sample = nn.ModuleList(
      [UpSample(c_in, c_out) for c_in, c_out in expansive_channels]
    )
    self.up_conv = nn.ModuleList(
      [DoubleConvolution(c_in, c_out) for c_in, c_out in expansive_channels]
    )
    self.concat = nn.ModuleList([CropAndConcat() for _ in range(depth)])
    self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

  def forward(self, x):
    """Run the full network on ``x`` and return the output of the final 1x1 convolution."""
    # Outputs of the contracting path, consumed last-in-first-out on the way up.
    skips = []
    for conv, pool in zip(self.down_conv, self.down_sample):
      x = conv(x)
      skips.append(x)
      x = pool(x)

    x = self.middle_conv(x)

    for up, merge, conv in zip(self.up_sample, self.concat, self.up_conv):
      x = up(x)
      x = merge(x, skips.pop())
      x = conv(x)

    return self.final_conv(x)

In [6]:
import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.functional
from torch import nn

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_helpers.device import DeviceConfigs
from labml_nn.unet.carvana import CarvanaDataset

In [25]:
class Configs(BaseConfigs):
  """Configuration and training loop for the U-Net Carvana experiment.

  Holds the hyper-parameters as class attributes, builds the dataset, model,
  data loader and optimizer in ``init``, and trains for ``epochs`` epochs in
  ``run``.
  """

  # Device the model and every batch are moved to.
  device: torch.device = DeviceConfigs()
  # The U-Net being trained (created in ``init``).
  model: UNet
  # Number of channels in the input image / in the predicted mask.
  image_channels = 3
  mask_channels = 1
  batch_size = 1
  learning_rate = 2.5e-4
  epochs = 4
  dataset: CarvanaDataset
  data_loader: torch.utils.data.DataLoader
  # BCELoss expects probabilities, so ``sigmoid`` is applied to the model
  # output before the loss is computed (see ``train``).
  loss_func = nn.BCELoss()
  sigmoid = nn.Sigmoid()
  optimizer: torch.optim.Adam

  def init(self):
    """Create the dataset, model, data loader and optimizer."""
    self.dataset = CarvanaDataset(lab.get_data_path()/'carvana'/'train',
                                  lab.get_data_path()/'carvana'/'train_masks')
    self.model = UNet(self.image_channels, self.mask_channels).to(self.device)
    self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
    # Register the 'sample' indicator as an image with the tracker.
    tracker.set_image("sample", True)

  @torch.no_grad()
  def sample(self, idx=-1):
    """Log the model's prediction for one randomly chosen dataset image.

    ``idx`` is accepted because ``train`` passes this method to ``monit.mix``
    together with a list of indices; it is not used, the image is picked at
    random instead.
    """
    x, _ = self.dataset[np.random.randint(len(self.dataset))]
    x = x.to(self.device)
    # Add a batch dimension for the model, then turn logits into probabilities.
    mask = self.sigmoid(self.model(x[None, :]))
    # Crop the image to the spatial size of the predicted mask.
    # (Fixed: the module was misspelled ``torchvision.tranforms``, which raised
    # AttributeError as soon as this method ran.)
    x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
    # Log the image weighted by the predicted mask.
    tracker.save('sample', x*mask)

  def train(self):
    """Run one pass over the data loader, taking one optimizer step per batch."""
    # NOTE(review): monit.mix presumably interleaves 50 calls to self.sample
    # with the training batches -- confirm against the labml documentation.
    for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
      tracker.add_global_step()
      image, mask = image.to(self.device), mask.to(self.device)
      self.optimizer.zero_grad()
      logits = self.model(image)

      # Crop the target mask to the spatial size of the model output.
      mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
      loss = self.loss_func(self.sigmoid(logits), mask)
      loss.backward()
      self.optimizer.step()
      tracker.save('loss', loss)

  def run(self):
    """Train for ``epochs`` epochs, saving a checkpoint after each one."""
    for _ in monit.loop(self.epochs):
      self.train()
      tracker.new_line()
      experiment.save_checkpoint()


In [26]:
def main() -> None:
  """Create the 'unet' experiment, build its configuration and run training."""
  experiment.create(name='unet')

  conf = Configs()
  # No configuration overrides are supplied (empty dict).
  overrides = {}
  experiment.configs(conf, overrides)

  # Build the dataset, model, data loader and optimizer.
  conf.init()

  # Register the model with the experiment.
  models = {'model': conf.model}
  experiment.add_pytorch_models(models)

  with experiment.start():
    conf.run()

In [32]:
# Run the whole experiment: create it, build the configuration and train.
# NOTE(review): the recorded run of this cell ended in a CUDA
# OutOfMemoryError on a 3.81 GiB GPU (see the output below).
main()

OutOfMemoryError: CUDA out of memory. Tried to allocate 232.00 MiB (GPU 0; 3.81 GiB total capacity; 2.57 GiB already allocated; 70.38 MiB free; 2.85 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [31]:
import os

# Set CUDA_VISIBLE_DEVICES to the empty string so no GPU is exposed to this
# process.
# NOTE(review): this presumably only takes effect if it runs before CUDA is
# initialised in the kernel -- confirm; likely needs a kernel restart and to be
# executed before the training cell.
os.environ.update(CUDA_VISIBLE_DEVICES='')