PyTorch implementation of U-Net: https://arxiv.org/pdf/1505.04597.pdf <br><br>
Note that the <span style="color:red">training code</span> may have <span style="color:red">errors</span> (like the IoU calculation), but the U-Net implementation itself is, I believe, correct.

In [None]:
import wandb
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import requests
from sklearn.utils import shuffle

Fetching the URLs of the training images and labels from GitHub: https://github.com/anuraglamsal/Nepali-Currency-Images-and-Segmentation-Masks

In [None]:
def get_github_folder_contents(owner, repo, path):
    """List a repository folder through the GitHub contents API.

    Args:
        owner: Account that owns the repository.
        repo: Repository name.
        path: Folder path inside the repository.

    Returns:
        The decoded JSON body when the API answers with HTTP 200; otherwise
        ``None``, after printing a failure message.
    """
    endpoint = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
    reply = requests.get(endpoint)

    # Guard clause: anything other than a plain 200 is reported as a failure.
    if reply.status_code != 200:
        print("Failed to fetch folder contents.")
        return None

    return reply.json()

def get_images(owner, repo, path):
    """Return the download URLs of the entries in a GitHub repository folder.

    Args:
        owner: Account that owns the repository.
        repo: Repository name.
        path: Folder path inside the repository.

    Returns:
        A list of download URLs, in the order the folder listing reports them.

    Raises:
        RuntimeError: If the folder listing could not be fetched.
    """
    contents = get_github_folder_contents(owner, repo, path)
    if contents is None:
        # Without this check the comprehension below fails with an opaque
        # "'NoneType' object is not iterable".
        raise RuntimeError(f"Could not list '{path}' in {owner}/{repo}.")

    # Entries with no download URL cannot be fetched as images, so skip them.
    return [item['download_url'] for item in contents if item.get('download_url')]

# Repository that hosts the training data, and the two folders that are read
# from it.
owner = "anuraglamsal"
repo = "training_images"
path_1 = "images"
path_2 = "labels"

# Both URL lists go through a single shuffle call with a fixed seed, so they
# are permuted consistently and the order is reproducible between runs.
image_urls = get_images(owner, repo, path_1)
label_urls = get_images(owner, repo, path_2)
images, labels = shuffle(image_urls, label_urls, random_state=0)

Writing the dataset class and doing train, validation and test split.

In [None]:
class CurrencyDataset(Dataset):
    """Dataset of (image, label) pairs that are downloaded on access.

    Args:
        label_urls: Sequence of URLs of the label images.
        image_urls: Sequence of URLs of the input images; entry ``i`` is
            returned together with ``label_urls[i]``.
        transform: Optional callable applied to each downloaded image.
        target_transform: Optional callable applied to each downloaded label.

    Note:
        Nothing is cached: every ``__getitem__`` call performs two HTTP
        requests.
    """

    def __init__(self, label_urls, image_urls, transform=None, target_transform=None):
        self.label_urls = label_urls
        self.image_urls = image_urls
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples, i.e. the number of label URLs."""
        return len(self.label_urls)

    @staticmethod
    def _download(url):
        """Download ``url`` and return it as a fully loaded PIL image.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # The context manager releases the streamed connection when done.
        with requests.get(url, stream=True) as response:
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            # The raw stream skips content decoding unless asked to do it.
            response.raw.decode_content = True
            picture = Image.open(response.raw)
            # PIL may read lazily; pull the pixels in before the connection
            # is closed.
            picture.load()
        return picture

    def __getitem__(self, idx):
        """Download and return the ``(image, label)`` pair at index ``idx``."""
        image = self._download(self.image_urls[idx])
        label = self._download(self.label_urls[idx])

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label


# Index boundaries of the split: the first 80% of the samples are used for
# training, the next 10% for testing and the remainder for validation.
split_idx_train = int(0.8 * len(labels))
split_idx_test = split_idx_train + int(0.1 * len(labels))

# transformations to do to the images (the same pipeline is applied to labels)
transform = transforms.Compose([transforms.Resize((512, 512)), transforms.Grayscale(), transforms.ToTensor()])

# CurrencyDataset takes the label URLs first and the image URLs second.
dataset_train = CurrencyDataset(labels[0:split_idx_train], images[0:split_idx_train], transform, transform)
dataset_test = CurrencyDataset(labels[split_idx_train:split_idx_test], images[split_idx_train:split_idx_test], transform, transform)
# Images and labels must be sliced from the same start index; otherwise each
# validation label is paired with an image from the test range.
dataset_val = CurrencyDataset(labels[split_idx_test:], images[split_idx_test:], transform, transform)

* init_channel = number of channels in your input image.
* next_channel = the number of channels that the result of the first convolution should have.

In [None]:
class UNet(nn.Module):
    """U-Net (https://arxiv.org/pdf/1505.04597.pdf) built with 'same'-padded convolutions.

    The encoder applies two 3x3 convolutions per level and halves the
    resolution four times with max pooling while doubling the channel count.
    The decoder mirrors it: up-sample, 2x2 convolution, concatenation with the
    matching encoder output, then two 3x3 convolutions. A final 1x1
    convolution followed by a sigmoid yields a single-channel map in [0, 1].

    Args:
        init_channel: Number of channels in the input image.
        next_channel: Number of channels produced by the first convolution.

    Note:
        When the input height and width are multiples of 16 the output has the
        same height and width as the input. Height and width may differ from
        each other.
    """

    def __init__(self, init_channel, next_channel):
        super().__init__()

        # Attribute names and creation order are part of the checkpoint
        # format (state-dict keys, optimizer parameter order): keep them.

        # Contracting path.
        self.conv1 = nn.Conv2d(init_channel, next_channel, 3, padding='same')
        self.ReLU = nn.ReLU()
        self.conv2 = nn.Conv2d(next_channel, next_channel, 3, padding='same')
        # ceil_mode keeps the last row/column when a dimension is odd.
        self.maxPool = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.conv3 = nn.Conv2d(next_channel, 2 * next_channel, 3, padding='same')
        self.conv4 = nn.Conv2d(2 * next_channel, 2 * next_channel, 3, padding='same')
        self.conv5 = nn.Conv2d(2 * next_channel, 4 * next_channel, 3, padding='same')
        self.conv6 = nn.Conv2d(4 * next_channel, 4 * next_channel, 3, padding='same')
        self.conv7 = nn.Conv2d(4 * next_channel, 8 * next_channel, 3, padding='same')
        self.conv8 = nn.Conv2d(8 * next_channel, 8 * next_channel, 3, padding='same')

        # Bottom-most block.
        self.conv9 = nn.Conv2d(8 * next_channel, 16 * next_channel, 3, padding='same')
        self.conv10 = nn.Conv2d(16 * next_channel, 16 * next_channel, 3, padding='same')

        # Expanding path. Each level has a 2x2 "up-convolution" applied after
        # up-sampling, then two 3x3 convolutions applied after concatenation
        # with the skip connection (hence the doubled input channels).
        #
        # nn.ConvTranspose2d does more or less the same as upsample + conv2d,
        # but the paper describes upsample + conv, and transposed convolutions
        # can produce checkerboard artifacts:
        # https://distill.pub/2016/deconv-checkerboard/
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv11 = nn.Conv2d(16 * next_channel, 8 * next_channel, 2, padding='same')
        self.conv12 = nn.Conv2d(16 * next_channel, 8 * next_channel, 3, padding='same')
        self.conv13 = nn.Conv2d(8 * next_channel, 8 * next_channel, 3, padding='same')
        self.conv14 = nn.Conv2d(8 * next_channel, 4 * next_channel, 2, padding='same')
        self.conv15 = nn.Conv2d(8 * next_channel, 4 * next_channel, 3, padding='same')
        self.conv16 = nn.Conv2d(4 * next_channel, 4 * next_channel, 3, padding='same')
        self.conv17 = nn.Conv2d(4 * next_channel, 2 * next_channel, 2, padding='same')
        self.conv18 = nn.Conv2d(4 * next_channel, 2 * next_channel, 3, padding='same')
        self.conv19 = nn.Conv2d(2 * next_channel, 2 * next_channel, 3, padding='same')
        self.conv20 = nn.Conv2d(2 * next_channel, next_channel, 2, padding='same')
        self.conv21 = nn.Conv2d(2 * next_channel, next_channel, 3, padding='same')
        self.conv22 = nn.Conv2d(next_channel, next_channel, 3, padding='same')

        # 1x1 convolution down to a single output channel, then a sigmoid.
        self.conv23 = nn.Conv2d(next_channel, 1, 1)
        self.activation = nn.Sigmoid()

    @staticmethod
    def _merge(skip, upsampled):
        """Concatenate a skip connection with decoder features along channels.

        ``skip`` is first centre-cropped to the height AND width of
        ``upsampled`` (CenterCrop zero-pads instead when ``skip`` is smaller),
        so the two tensors agree spatially even for non-square inputs.
        """
        fit = transforms.CenterCrop(list(upsampled.shape[2:]))
        return torch.cat((fit(skip), upsampled), 1)

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid-activated output."""
        # Contracting path.
        block_1 = self.ReLU(self.conv2(self.ReLU(self.conv1(x))))
        block_2 = self.ReLU(self.conv4(self.ReLU(self.conv3(self.maxPool(block_1)))))
        block_3 = self.ReLU(self.conv6(self.ReLU(self.conv5(self.maxPool(block_2)))))
        block_4 = self.ReLU(self.conv8(self.ReLU(self.conv7(self.maxPool(block_3)))))

        # Bottom-most block.
        block_5 = self.ReLU(self.conv10(self.ReLU(self.conv9(self.maxPool(block_4)))))

        # Expanding path.
        up_conv_1 = self.conv11(self.upsample(block_5))
        block_6 = self.ReLU(self.conv13(self.ReLU(self.conv12(self._merge(block_4, up_conv_1)))))

        up_conv_2 = self.conv14(self.upsample(block_6))
        block_7 = self.ReLU(self.conv16(self.ReLU(self.conv15(self._merge(block_3, up_conv_2)))))

        up_conv_3 = self.conv17(self.upsample(block_7))
        block_8 = self.ReLU(self.conv19(self.ReLU(self.conv18(self._merge(block_2, up_conv_3)))))

        up_conv_4 = self.conv20(self.upsample(block_8))
        block_9 = self.conv23(self.ReLU(self.conv22(self.ReLU(self.conv21(self._merge(block_1, up_conv_4))))))

        output = self.activation(block_9)

        return output

Weight initialization.

In [None]:
def initialize_weights(model):
    """Re-initialise every ``nn.Conv2d`` layer of ``model`` in place.

    Weights are drawn with Kaiming-normal initialisation (fan-in mode, ReLU
    non-linearity); biases, where a layer has one, are set to zero. Modules of
    any other type are left untouched.
    """
    conv_layers = (layer for layer in model.modules() if isinstance(layer, nn.Conv2d))

    for conv in conv_layers:
        nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')

        # Layers created with bias=False have nothing to reset.
        if conv.bias is None:
            continue
        nn.init.constant_(conv.bias, 0)

Wandb init

In [None]:
# Start a new Weights & Biases run (use this cell for new runs).
wandb.init(project="Currency Segmentation")

Training staging area. Can play with things here. 

In [None]:
device = torch.device('cuda')

# One input channel (the transform pipeline converts images to grayscale) and
# 64 channels after the first convolution.
model = UNet(1, 64).to(device)
initialize_weights(model)

# Summed (not averaged) binary cross-entropy over all output pixels.
loss_fn = nn.BCELoss(reduction='sum')

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.99)

# 'max' mode: the learning rate is reduced when the monitored metric has
# stopped increasing for more than `patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    'max',
    patience=3,
    min_lr=1e-6,
)

batch_size = 1

# Only the training data is reshuffled between epochs.
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset_val, batch_size=batch_size)
test_loader = DataLoader(dataset_test, batch_size=batch_size)

# Checkpoint file that the training cell loads from and saves to.
model_path = "/path/to/model"

Testing and validation code.

In [None]:
def test(loader, model):
    """Evaluate ``model`` on ``loader`` and return the IoU as a percentage.

    Predictions and targets are binarised at 0.5. Intersection and union
    pixel counts are accumulated over every batch of the loader, and the
    result is ``100 * intersection / union``.

    Args:
        loader: Iterable yielding ``(data, target)`` tensor pairs.
        model: Module to evaluate; it is switched to eval mode and left there.

    Returns:
        The IoU in percent. When neither the predictions nor the targets
        contain a positive pixel the union is empty and 100.0 is returned.
    """
    model.eval()

    # Evaluate on whichever device the model's parameters live on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_intersection = 0.0
    total_union = 0.0

    with torch.no_grad():

        for data, target in loader:
            if device is not None:
                data, target = data.to(device), target.to(device)

            output = model(data)

            pred = (output > 0.5).float()  # binarizing prediction
            target = (target > 0.5).float()  # binarizing actual

            intersection = (pred * target).sum().item()

            total_intersection += intersection
            # |A or B| = |A| + |B| - |A and B|
            total_union += pred.sum().item() + target.sum().item() - intersection

    if total_union == 0:
        # Nothing predicted and nothing to find: treat as perfect agreement
        # instead of dividing by zero.
        accuracy = 100.0
    else:
        accuracy = 100.0 * total_intersection / total_union

    print(f"Average IOU: {accuracy}")

    return accuracy


# The function name matches pytest's default ``test*`` collection pattern;
# opt out so that importing it into a test module does not make pytest try to
# run it as a test case.
test.__test__ = False

Training code.

Apparently, padded convolutions are fine. The reason the original paper does not use them is basically the heavy computation involved. See: https://stackoverflow.com/questions/44014534/why-could-u-net-mask-image-with-smaller-mask

In [None]:
_epoch = -1

# Resume from the checkpoint when one can be loaded. This is best effort: on
# any failure the reason is printed and training starts from epoch 0.
try:
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    _epoch = checkpoint['epoch']
except Exception as e:
    print(e)

# Override the learning rate in place. Building a new optimizer here would
# throw away the state restored above and leave `scheduler` attached to the
# old optimizer, so its learning-rate reductions would never take effect.
for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-5

for epoch in range(_epoch + 1, 100):
    model.train()  # Set the model to training mode

    total_epoch_loss = 0.0
    total_num_of_pixels = 0

    # `targets` rather than `labels`: the latter is the module-level list of
    # label URLs and must not be overwritten by a batch tensor.
    for batch_idx, (inputs, targets) in enumerate(train_loader):

        inputs, targets = inputs.to(torch.device('cuda')), targets.to(torch.device('cuda'))
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        total_num_of_pixels += outputs.numel()

        # Compute the loss (summed over pixels, see loss_fn)
        loss = loss_fn(outputs, targets)
        total_epoch_loss += loss.item()

        # Backward pass
        loss.backward()

        # Cap the global gradient norm at 1 before the update.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        # Update the weights
        optimizer.step()

        # Running average loss per pixel, reported every 100 batches.
        if batch_idx % 100 == 0:
            print(total_epoch_loss / total_num_of_pixels)

    avg_loss_per_pixel_train = total_epoch_loss / total_num_of_pixels
    acc = test(val_loader, model)

    # The metric is truncated to a whole percent before it is handed to the
    # scheduler, so sub-percent changes do not count as improvement.
    scheduler.step(int(acc))  # can play with this too i guess.

    wandb.log({"avg_loss_per_pixel_train": avg_loss_per_pixel_train,
               "avg_IOU": acc})

    # Overwrite the checkpoint after every epoch so training can be resumed.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, model_path)

print('\n')
print('Test set metrics: ')
_ = test(test_loader, model)

wandb.finish()