PyTorch implementation of U-Net: https://arxiv.org/pdf/1505.04597.pdf

In [None]:
import wandb
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import requests
from sklearn.utils import shuffle

Fetch the URLs of the training images and their labels (segmentation masks) from GitHub: https://github.com/anuraglamsal/Nepali-Currency-Images-and-Segmentation-Masks

In [None]:
def get_github_folder_contents(owner, repo, path, timeout=30):
    """Return the GitHub "contents" API listing for ``path`` in ``owner/repo``.

    Args:
        owner: GitHub user or organisation that owns the repository.
        repo: Repository name.
        path: Folder path inside the repository.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON body on HTTP 200 (for a folder, a list of entry
        dicts), otherwise ``None`` after printing a diagnostic.

    NOTE(review): GitHub documents a cap on how many entries this endpoint
    returns for one directory; very large folders may need the Git Trees API.
    Confirm if the dataset grows.
    """
    url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
    # Without a timeout, requests can block indefinitely on a stalled connection.
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        return response.json()
    print(f"Failed to fetch folder contents from {url} (HTTP {response.status_code}).")
    return None


def get_images(owner, repo, path):
    """Return the raw download URLs of the files in a GitHub folder.

    Entries without a ``download_url`` (e.g. sub-directories, whose value is
    null) are skipped. The order is the order returned by the API. Callers
    pair images and labels by position, so both folders are assumed to list
    matching files in the same order.

    Raises:
        RuntimeError: if the folder listing could not be fetched. Previously
            this surfaced as an opaque ``TypeError`` from iterating ``None``.
    """
    contents = get_github_folder_contents(owner, repo, path)
    if contents is None:
        raise RuntimeError(f"Could not list '{path}' in {owner}/{repo}.")
    return [item['download_url'] for item in contents if item.get('download_url')]

# Source repository and its two parallel folders: inputs and segmentation masks.
owner, repo = "anuraglamsal", "training_images"
path_1, path_2 = "images", "labels"

# sklearn's shuffle applies one seeded permutation to both lists together,
# so the i-th image URL still pairs with the i-th label URL afterwards.
images, labels = shuffle(
    get_images(owner, repo, path_1),
    get_images(owner, repo, path_2),
    random_state=0,
)

Define the dataset class and split the data into training, validation, and test sets.

In [None]:
class CurrencyDataset(Dataset):
    """Map-style dataset of (image, mask) pairs downloaded lazily from URLs.

    Args:
        label_urls: URLs of the segmentation masks.
        image_urls: URLs of the input images; ``image_urls[i]`` is paired
            with ``label_urls[i]``.
        transform: Optional callable applied to each downloaded image.
        target_transform: Optional callable applied to each downloaded mask.
        timeout: Seconds to wait for each HTTP request (new, optional).
    """

    def __init__(self, label_urls, image_urls, transform=None, target_transform=None, timeout=30):
        self.label_urls = label_urls
        self.image_urls = image_urls
        self.transform = transform
        self.target_transform = target_transform
        self.timeout = timeout

    def __len__(self):
        # The dataset size is defined by the number of masks.
        return len(self.label_urls)

    def _fetch_image(self, url):
        """Download ``url`` and decode it as a PIL image.

        Raises ``requests.HTTPError`` on a non-2xx response, so an error page
        is never handed to PIL as image bytes.
        """
        import io  # local import keeps this notebook cell self-contained

        # A non-streamed request reads the (content-decoded) body fully and
        # releases the connection, unlike the previous unclosed ``stream=True`` + ``.raw``.
        response = requests.get(url, timeout=self.timeout)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))

    def __getitem__(self, idx):
        image = self._fetch_image(self.image_urls[idx])
        label = self._fetch_image(self.label_urls[idx])

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label


# Split the shuffled pairs 80% / 10% / 10% into train / test / validation.
split_idx_train = int(0.8 * len(labels))
split_idx_test = split_idx_train + int(0.1 * len(labels))

# Images: resize to 512x512, convert to one channel, then to a float tensor.
transform = transforms.Compose([transforms.Resize((512, 512)), transforms.Grayscale(), transforms.ToTensor()])

# Masks get the same geometry but nearest-neighbour resizing. The default
# (bilinear) interpolation blends pixel values along object boundaries, so the
# resized mask would no longer contain only its original label values.
mask_transform = transforms.Compose([
    transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.Grayscale(),
    transforms.ToTensor(),
])

dataset_train = CurrencyDataset(labels[0:split_idx_train], images[0:split_idx_train], transform, mask_transform)
dataset_test = CurrencyDataset(labels[split_idx_train:split_idx_test], images[split_idx_train:split_idx_test], transform, mask_transform)
# Images and labels must be sliced with the same bounds. The images previously
# started at split_idx_train, which paired every validation mask with the wrong image.
dataset_val = CurrencyDataset(labels[split_idx_test:], images[split_idx_test:], transform, mask_transform)

* init_channel = number of channels in your input image.
* next_channel = the number of channels that the result of the first convolution should have.

In [None]:
class UNet(nn.Module):
    """U-Net (Ronneberger et al., 2015) producing a single-channel probability map.

    Args:
        init_channel: Number of channels in the input image.
        next_channel: Channels produced by the first convolution. Each deeper
            encoder level doubles it, up to ``16 * next_channel`` at the bottleneck.

    ``forward`` returns sigmoid probabilities of shape ``(N, 1, H, W)`` with the
    same ``H`` and ``W`` as the input, for any input size (square or not).

    The attribute names and registration order of the layers are unchanged,
    so previously saved model and optimizer checkpoints still load.
    """

    def __init__(self, init_channel, next_channel):
        super().__init__()
        c = next_channel

        # Contracting path: pairs of 3x3 convs; each level doubles the channels.
        self.conv1 = nn.Conv2d(init_channel, c, 3, padding='same')
        self.ReLU = nn.ReLU()
        self.conv2 = nn.Conv2d(c, c, 3, padding='same')
        # ceil_mode=True keeps the last partial row/column for odd sizes.
        self.maxPool = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.conv3 = nn.Conv2d(c, 2 * c, 3, padding='same')
        self.conv4 = nn.Conv2d(2 * c, 2 * c, 3, padding='same')
        self.conv5 = nn.Conv2d(2 * c, 4 * c, 3, padding='same')
        self.conv6 = nn.Conv2d(4 * c, 4 * c, 3, padding='same')
        self.conv7 = nn.Conv2d(4 * c, 8 * c, 3, padding='same')
        self.conv8 = nn.Conv2d(8 * c, 8 * c, 3, padding='same')
        self.conv9 = nn.Conv2d(8 * c, 16 * c, 3, padding='same')
        self.conv10 = nn.Conv2d(16 * c, 16 * c, 3, padding='same')

        # Expanding path: "up-convolution" = 2x upsample followed by a 2x2 conv
        # that halves the channels, as described in the paper.
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv11 = nn.Conv2d(16 * c, 8 * c, 2, padding='same')

        # ConvTranspose2d would do roughly the same job as upsample + conv, but
        # transposed convolutions can produce "checkerboard artifacts":
        # https://distill.pub/2016/deconv-checkerboard/
        # self.conv11 = nn.ConvTranspose2d(16 * next_channel, 8 * next_channel, 2, 2)

        self.conv12 = nn.Conv2d(16 * c, 8 * c, 3, padding='same')
        self.conv13 = nn.Conv2d(8 * c, 8 * c, 3, padding='same')
        self.conv14 = nn.Conv2d(8 * c, 4 * c, 2, padding='same')
        self.conv15 = nn.Conv2d(8 * c, 4 * c, 3, padding='same')
        self.conv16 = nn.Conv2d(4 * c, 4 * c, 3, padding='same')
        self.conv17 = nn.Conv2d(4 * c, 2 * c, 2, padding='same')
        self.conv18 = nn.Conv2d(4 * c, 2 * c, 3, padding='same')
        self.conv19 = nn.Conv2d(2 * c, 2 * c, 3, padding='same')
        self.conv20 = nn.Conv2d(2 * c, c, 2, padding='same')
        self.conv21 = nn.Conv2d(2 * c, c, 3, padding='same')
        self.conv22 = nn.Conv2d(c, c, 3, padding='same')
        # 1x1 conv to a single output channel.
        self.conv23 = nn.Conv2d(c, 1, 1)

        self.activation = nn.Sigmoid()

    def _double_conv(self, x, first, second):
        """Apply ``first`` -> ReLU -> ``second`` -> ReLU."""
        return self.ReLU(second(self.ReLU(first(x))))

    def _up_and_concat(self, deep, skip, up_conv):
        """Up-convolve ``deep``, trim it to ``skip``'s H x W, and concat along channels.

        With ceil-mode pooling an upsampled map is either equal to its skip
        map or exactly one pixel larger along the bottom/right edge (the
        partial pooling window). Trimming that excess keeps both height and
        width aligned. The old ``CenterCrop(height)`` on the skip map failed
        for non-square inputs and zero-padded it for odd sizes.
        """
        up = up_conv(self.upsample(deep))
        up = up[..., :skip.shape[-2], :skip.shape[-1]]
        return torch.cat((skip, up), 1)

    def forward(self, x):
        block_1 = self._double_conv(x, self.conv1, self.conv2)
        block_2 = self._double_conv(self.maxPool(block_1), self.conv3, self.conv4)
        block_3 = self._double_conv(self.maxPool(block_2), self.conv5, self.conv6)
        block_4 = self._double_conv(self.maxPool(block_3), self.conv7, self.conv8)
        block_5 = self._double_conv(self.maxPool(block_4), self.conv9, self.conv10)  # bottom-most block

        block_6 = self._double_conv(self._up_and_concat(block_5, block_4, self.conv11), self.conv12, self.conv13)
        block_7 = self._double_conv(self._up_and_concat(block_6, block_3, self.conv14), self.conv15, self.conv16)
        block_8 = self._double_conv(self._up_and_concat(block_7, block_2, self.conv17), self.conv18, self.conv19)
        block_9 = self._double_conv(self._up_and_concat(block_8, block_1, self.conv20), self.conv21, self.conv22)

        # Per-pixel probability of the positive class.
        return self.activation(self.conv23(block_9))

Weight initialization.

In [None]:
def initialize_weights(model):
    """Re-initialise, in place, every ``nn.Conv2d`` found in ``model``.

    Conv weights are drawn with Kaiming-normal initialisation (``fan_in``
    mode, ReLU gain), and conv biases, where present, are set to zero.
    Modules of any other type keep their current parameters.
    """
    conv_layers = (module for module in model.modules() if isinstance(module, nn.Conv2d))
    for conv in conv_layers:
        nn.init.kaiming_normal_(conv.weight, mode='fan_in', nonlinearity='relu')
        if conv.bias is None:
            continue
        nn.init.constant_(conv.bias, 0)

Wandb init

In [None]:
# Start a new Weights & Biases run; the training loop logs its metrics here.
wandb.init(project="Currency Segmentation")

Training staging area: the model, loss, optimizer, scheduler, and data loaders are set up here. Feel free to experiment with these settings.

In [None]:
# Model, loss, optimiser and learning-rate schedule.
model = UNet(1, 64).to(torch.device('cuda'))
# Summed (not averaged) BCE; the training loop divides by the pixel count itself.
loss_fn = nn.BCELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.99)
# mode='max': the value passed to scheduler.step() in the training loop is the
# validation score returned by test(), where higher is better.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, min_lr=1e-6)

# Conv weights are re-initialised in place, so the optimizer created above
# already references the same parameter tensors.
initialize_weights(model)

batch_size = 1

# Only the training set is reshuffled each epoch; evaluation order is fixed.
train_loader = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=dataset_val, batch_size=batch_size)
test_loader = DataLoader(dataset=dataset_test, batch_size=batch_size)

# Placeholder: checkpoints are read from and written to this path.
model_path = "/path/to/model"

Testing and validation code.

In [None]:
def test(loader, model):
  model.eval()

  total_ones_target = 0
  total_ones_intersection = 0

  with torch.no_grad():

    for data, target in loader:
      data, target = data.to(torch.device('cuda')), target.to(torch.device('cuda'))

      output = model(data)

      pred = (output > 0.5).float() # binarizing prediction
      target = (target > 0.5).float() # binarizing actual

      intersection = pred * target

      total_ones_target += target.sum().item()
      total_ones_intersection += intersection.sum().item()

  accuracy = 100.0 * total_ones_intersection / total_ones_target

  print(f"Average IOU basically: {accuracy}")

  return accuracy

Training code.

Apparently, padded convolutions are fine; the main reason not to use them is basically the extra computation. See: https://stackoverflow.com/questions/44014534/why-could-u-net-mask-image-with-smaller-mask

In [None]:
# Resume from the last checkpoint when one exists. `_epoch` is the last
# completed epoch; -1 means training starts at epoch 0.
_epoch = -1

try:
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    _epoch = checkpoint['epoch']
except FileNotFoundError:
    print(f"No checkpoint found at {model_path}; starting from epoch 0.")
except Exception as e:
    # Best-effort resume: report the problem and keep going without resuming the epoch counter.
    print(e)

# Set the fine-tuning learning rate on the *existing* optimizer. Replacing it
# with a new SGD instance (as before) left `scheduler` attached to the discarded
# optimizer, so its LR reductions never reached training. It also discarded the
# momentum state just restored from the checkpoint.
for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-5

device = torch.device('cuda')

for epoch in range(_epoch + 1, 100):
    model.train()  # Set the model to training mode

    total_epoch_loss = 0.0
    total_num_of_pixels = 0

    # Named `targets`, not `labels`, so the global list of label URLs is not overwritten.
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        total_num_of_pixels += outputs.numel()

        # Summed BCE over the batch; normalised per pixel only for reporting.
        loss = loss_fn(outputs, targets)
        total_epoch_loss += loss.item()

        # Backward pass, with gradient-norm clipping before the update.
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        if batch_idx % 100 == 0:
            print(total_epoch_loss / total_num_of_pixels)

    avg_loss_per_pixel_train = total_epoch_loss / total_num_of_pixels
    acc = test(val_loader, model)  # validation score in percent

    # 'max'-mode plateau scheduler; int() coarsens the metric to whole percentage points.
    scheduler.step(int(acc))

    wandb.log({"avg_loss_per_pixel_train": avg_loss_per_pixel_train,
               "avg_IOU": acc})

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, model_path)

print('\n')
print('Test set metrics: ')
_ = test(test_loader, model)

wandb.finish()

Rough code to understand cross-entropy loss in torch. Note that torch's cross-entropy applies log-softmax internally, so you can pass the raw (unnormalized) outputs directly. The log used is the natural log; base 2 could also be used, and it shouldn't really matter: https://stats.stackexchange.com/questions/295174/difference-in-log-base-for-cross-entropy-calcuation

In [None]:
# Raw (unnormalised) model scores for one image with 2 classes and 2x2 pixels.
# unsqueeze(0) adds the leading batch dimension, giving (batch, class, height, width).
x = torch.tensor([[[0.41, 0.56], [0.69, 0.84]], [[0.57, 0.37], [0.29, 0.90]]]).unsqueeze(0)

# Target, form 1: per-class probabilities laid out exactly like x.
# Here every pixel is one-hot (exactly one channel holds 1).
y = torch.tensor([[[0., 1.], [1., 1.]], [[1., 0.], [0., 0.]]]).unsqueeze(0)

# Target, form 2: one integer class index per pixel (class 0, class 1, ...),
# shaped (batch, height, width). Prefer this when targets are hard 0/1 labels:
# the loss then only has to gather the log-probability of the true class,
# since classes with target 0 contribute nothing to the cross-entropy sum.
# See the "NOTE" in https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
y_class_index = torch.tensor([[1, 0], [0, 0]]).unsqueeze(0)

# reduction='none' keeps one loss value per pixel.
entropy_1 = nn.CrossEntropyLoss(reduction='none')
# reduction='sum' adds them up into one value for the whole image.
entropy_2 = nn.CrossEntropyLoss(reduction='sum')

# Because y is one-hot, both target forms give identical losses.
print(entropy_1(x, y))
print(entropy_1(x, y_class_index))

print(entropy_2(x, y))
print(entropy_2(x, y_class_index))

# Nice how directly PyTorch's (batch, channel, height, width) layout maps onto
# per-pixel image losses.

torch.Size([1, 2, 2, 2]) torch.Size([1, 2, 2, 2])
tensor([[[0.6163, 0.6027],
         [0.5130, 0.7236]]])
tensor([[[0.6163, 0.6027],
         [0.5130, 0.7236]]])
tensor(2.4556)
tensor(2.4556)


Rough code to understand how softmax works in torch.

In [None]:
# Random tensor laid out as (batch, channels, height, width).
x = torch.rand(2, 2, 2, 2)

# nn.Softmax(dim=d) normalises along axis d. For every fixed index of the other
# axes, the entries along d are exponentiated and divided by their sum.
soft_1 = nn.Softmax(dim=0)  # across the batch: same pixel, same channel, different images
soft_2 = nn.Softmax(dim=1)  # across channels: same pixel of the same image
soft_3 = nn.Softmax(dim=2)  # along the height of each channel
soft_4 = nn.Softmax(dim=3)  # along the width of each channel

print(soft_1(x))
print(soft_2(x))
print(soft_3(x))
print(soft_4(x))

# After softmax, the values along the chosen dimension sum to 1.

tensor([[[[0.4452, 0.3417],
          [0.5520, 0.3510]],

         [[0.4771, 0.5547],
          [0.5214, 0.5568]]],


        [[[0.5548, 0.6583],
          [0.4480, 0.6490]],

         [[0.5229, 0.4453],
          [0.4786, 0.4432]]]])
tensor([[[[0.4998, 0.3139],
          [0.5768, 0.4049]],

         [[0.5002, 0.6861],
          [0.4232, 0.5951]]],


        [[[0.5319, 0.5233],
          [0.5465, 0.6124]],

         [[0.4681, 0.4767],
          [0.4535, 0.3876]]]])
tensor([[[[0.4083, 0.4500],
          [0.5917, 0.5500]],

         [[0.4848, 0.5488],
          [0.5152, 0.4512]]],


        [[[0.5144, 0.4601],
          [0.4856, 0.5399]],

         [[0.5291, 0.5509],
          [0.4709, 0.4491]]]])
tensor([[[[0.4966, 0.5034],
          [0.5391, 0.4609]],

         [[0.3112, 0.6888],
          [0.3687, 0.6313]]],


        [[[0.3896, 0.6104],
          [0.3393, 0.6607]],

         [[0.3814, 0.6186],
          [0.4023, 0.5977]]]])


Rough code to understand how the upsample class works in torch.

In [None]:
# Download a sample photo and convert it into a float tensor of shape (C, H, W).
img = Image.open(requests.get("https://c8.alamy.com/zooms/9/7df650603bfe4193bab024ee29da5461/2btr9xg.jpg", stream=True).raw)
tensor_trans = transforms.ToTensor()
# Add a batch dimension: (C, H, W) -> (1, C, H, W).
img = tensor_trans(img).unsqueeze(0)
# Why the batch dimension: nn.Upsample reads a 3-D input as (N, C, W), i.e. it
# would treat our channels as the batch, our height as channels, and scale the
# width only. A 4-D input is read as (N, C, H, W), which is what we want. A
# batch of size 1 costs no generality when upsampling a single image.
# Details: https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html

print(img.shape)
# mode='nearest' is the default. scale_factor multiplies height and width:
# new_height_width = old_height_width * scale_factor.
upscale = nn.Upsample(scale_factor=2, mode='nearest')
img = upscale(img)
print(img.shape)

torch.Size([1, 3, 447, 640])
torch.Size([1, 3, 894, 1280])


Rough code to understand cropping in torch.

In [None]:
# The dotted lines in Fig. 1 of the paper suggest a centre crop of the skip
# feature map before concatenation.
crop = transforms.CenterCrop(56)
# Fake feature map: (batch, channels, height, width) -> cropped to 56x56.
x = crop(torch.rand(1, 512, 64, 64))
print(x.shape)


torch.Size([1, 512, 56, 56])


Rough code to understand how to do concat as required by the u-net paper in torch.

In [None]:
# Two tiny "image batches" shaped (batch_size, num_of_channels, height, width).
x = torch.ones(1, 2, 4, 4)
y = torch.zeros(1, 2, 4, 4)

# torch.cat's dim argument picks which of those four axes grows.

# dim=0 (batch): the result holds two images, x followed by y.
print(torch.cat((x, y), 0))
print(torch.cat((x, y), 0).shape)

# dim=1 (channels): y's channels are appended after x's, giving one image with
# channels_x + channels_y channels. This is the concat U-Net uses for skip connections.
print(torch.cat((x, y), 1))
print(torch.cat((x, y), 1).shape)

# dim=2 (height): within each channel, y's rows are stacked below x's rows.
print(torch.cat((x,y), 2))
print(torch.cat((x, y), 2).shape)

# dim=3 (width): within each channel, every row of x is extended to the right
# by the matching row of y.
print(torch.cat((x, y), 3))
print(torch.cat((x, y), 3).shape)

# Inspecting the printed tensors helps, especially knowing what each level of
# nested [] represents (batch -> channel -> row -> column).



tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]],


        [[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]])
torch.Size([2, 2, 4, 4])
tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]]]])
torch.Size([1, 4, 4, 4])
tensor([[[[1., 1., 1., 1.],
      