# Download datasets and libraries

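* Download the P3M-10k dataset (zip archive) from: https://drive.google.com/uc?export=download&id=1LqUU7BZeiq8I3i5KxApdOJ2haXm-cEv1
* Upload `u2net.py` to the working directory: the model is imported as a local module (`import u2net`), not installed from pip.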

In [2]:
# Download the P3M-10k archive from Google Drive into /content/sample_data.
# NOTE: in the recorded run this failed with "Access denied" (see output below),
# which gdown reports when the file is not shared publicly or has had too many
# accesses; the unzip step further down reads a copy stored in MyDrive/datasets.
!gdown "https://drive.google.com/uc?export=download&id=1LqUU7BZeiq8I3i5KxApdOJ2haXm-cEv1" -O "/content/sample_data/pm_10k.zip"

Access denied with the following error:

 	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses. 

You may still be able to access the file from the browser:

	 https://drive.google.com/uc?export=download&id=1LqUU7BZeiq8I3i5KxApdOJ2haXm-cEv1 



In [None]:
# ! cp "/content/sample_data/pm_10k.zip" "/content/drive/MyDrive/datasets/"

In [None]:
# !unzip "/content/sample_data/pm_10k.zip" -d "/content/sample_data/"

In [4]:
!unzip "/content/drive/MyDrive/datasets/pm_10k.zip" -d "/content/sample_data/"

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/sample_data/P3M-10k/train/mask/p_f0724e0e.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_b3887f38.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_d6794eb2.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_cdc1f351.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_c3762085.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_683e2df2.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_48151e4a.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_c2692f3b.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_3391c451.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_a239672c.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_643eedec.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_bb107b56.png  
  inflating: /content/sample_data/P3M-10k/train/mask/p_a46ef225.png  
  inflating: /content/sam

# Import libraries

In [1]:
import os
from glob import glob
import sys
import numpy as np
import pandas as pd
import cv2
from PIL import Image
from PIL.Image import Image as PILImage
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
import torch.optim.lr_scheduler as lr_scheduler

import u2net # upload from local drive [NOT FROM PIP]

# Pick the compute device: GPU when PyTorch can see one, otherwise CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

cuda


In [2]:
# Dataset locations.
HOME_PATH = "/content/sample_data/P3M-10k/"
TRAIN_PATH = os.path.join(HOME_PATH, "train")
TEST_PATH = os.path.join(HOME_PATH, "validation", "P3M-500-NP")

# Checkpoints are written to model_dir; preview grids saved during training go
# to model_dir/images. makedirs(exist_ok=True) creates whichever of the two is
# missing, so the images folder exists even if model_dir was created earlier.
model_dir = "/content/sample_data/model_checkpoint"
os.makedirs(os.path.join(model_dir, "images"), exist_ok=True)

# hyperparameters
EPOCHS = 20
TRAIN_BATCH_SIZE = 10
VAL_BATCH_SIZE = 3
RESIZE = (320, 320)  # target size passed to transforms.Resize for images and masks
lr = 2e-04

In [3]:
# Per-split lists of sample names (one per line, no header). P3MDataset appends
# ".jpg" / ".png" to each entry to locate the image and mask files.
# dtype=str keeps every name a string, so a numeric-looking stem such as
# "0123" is not parsed into a number (which would break the file path).
train_image_names = pd.read_csv(os.path.join(HOME_PATH, "train_list.txt"), header=None, dtype=str)
train_image_names = train_image_names.values.flatten().tolist()

val_image_names = pd.read_csv(os.path.join(HOME_PATH, "P3M-500-NP_list.txt"), header=None, dtype=str)
val_image_names = val_image_names.values.flatten().tolist()

In [4]:
# Transform that clips tensor values into the range [0, 1]
class NormalizeImage:
    """Clamp every element of a tensor into the closed interval [0, 1].

    Despite the name, this does not rescale or standardise anything; it only
    clips values that fall outside [0, 1]. It is used as the last step of the
    dataset transform pipeline.
    """

    def __call__(self, mask):
        lower, upper = 0, 1
        return mask.clamp(min=lower, max=upper)

class P3MDataset(Dataset):
    """Paired image / mask dataset laid out as P3M-10k split folders.

    For every ``name`` in ``image_names`` it loads
    ``<root_dir>/<image_dir>/<name>.jpg`` and ``<root_dir>/<mask_dir>/<name>.png``.

    Args:
        root_dir: Split directory containing ``image_dir`` and ``mask_dir``.
        image_dir: Sub-folder holding the ``.jpg`` images.
        mask_dir: Sub-folder holding the ``.png`` masks.
        image_names: File stems (no extension) to load. This list alone defines
            the dataset's length and order. ``None`` means an empty dataset.
        resize: Size passed to ``transforms.Resize`` for both image and mask.
    """

    def __init__(self, root_dir, image_dir, mask_dir, image_names=None, resize=(224, 224)) -> None:
        super().__init__()
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Take a private copy, and avoid a shared mutable default argument.
        self.image_names = list(image_names) if image_names is not None else []

        # PIL image -> tensor, resize to `resize`, then clamp values to [0, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(resize, antialias=True),
            NormalizeImage()
        ])

    def __len__(self):
        # Must match what __getitem__ can index: the name list, not the number
        # of files in the directory (which may contain files not in the list).
        return len(self.image_names)

    def __getitem__(self, index):
        """Return ``(image, mask)`` tensors for the sample at ``index``."""
        name = self.image_names[index]
        image_path = os.path.join(self.root_dir, self.image_dir, f'{name}.jpg')
        mask_path = os.path.join(self.root_dir, self.mask_dir, f'{name}.png')

        # `with` closes the file handles; convert() returns a fully loaded copy.
        # Forcing RGB keeps every image 3-channel so samples stack into a batch
        # even if some JPEGs are grayscale; masks are single-channel ("L").
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        with Image.open(mask_path) as msk:
            mask = msk.convert("L")

        return self.transform(image), self.transform(mask)


In [6]:
# Create the training dataset
train_dataset = P3MDataset(TRAIN_PATH, 'blurred_image', 'mask', train_image_names, resize=RESIZE)

# Create the validation dataset
val_dataset = P3MDataset(TEST_PATH, 'original_image', 'mask', val_image_names, resize=RESIZE)

# Create the training dataloader
train_dataloader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)

# Create the validation dataloader
val_dataloader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)

In [7]:
# taking a batch for testing
# Keep one fixed validation batch (images, masks) for visual checks while training.
val_batch = next(iter(val_dataloader))
len(val_batch), val_batch[0].shape

(2, torch.Size([3, 3, 320, 320]))

In [8]:
# Draw a single training batch to check image / mask tensor shapes.
imgs, masks = next(iter(train_dataloader))
print(imgs.shape, masks.shape)

torch.Size([10, 3, 320, 320]) torch.Size([10, 1, 320, 320])


In [9]:
# load model
# Build the U^2-Net model (args presumably in_ch=3, out_ch=1 — matches the
# 3-channel images and 1-channel masks) and place it on `device`.
net = u2net.U2NET(3, 1)
net.to(device)  # nn.Module.to moves parameters in place; no-op when already there

# TensorBoard event files are written under this directory.
logs = "/content/sample_data/logs"
writer = SummaryWriter(logs)

In [10]:
# BCELoss expects probabilities in [0, 1] (not logits); the U2NET outputs are
# assumed to already be sigmoid-activated — TODO confirm against u2net.py.
bce_loss = nn.BCELoss(reduction='mean')
net_optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# Learning rate at epoch e is lr * 0.65**e. LambdaLR scales the *initial* lr by
# lmbda(e). MultiplicativeLR would instead multiply the *current* lr by
# lmbda(e) every epoch, compounding to lr * 0.65**(e*(e+1)/2) (~0 after a few epochs).
lmbda = lambda epoch: 0.65 ** epoch
scheduler = lr_scheduler.LambdaLR(net_optimizer, lr_lambda=lmbda)

# loss function
def multi_BCEloss(s0, s1, s2, s3, s4, s5, s6, actual_mask):
  """Deep-supervision loss: BCE of each of the seven outputs vs. the same mask.

  Every output ``s0``..``s6`` is scored against ``actual_mask`` with the
  module-level ``bce_loss``.

  Returns:
    ``(loss_s0, total)`` — the loss of ``s0`` alone and the unweighted sum of
    all seven losses (``total`` is the value to backpropagate).
  """
  per_output = [bce_loss(pred, actual_mask) for pred in (s0, s1, s2, s3, s4, s5, s6)]
  total = per_output[0]
  for extra in per_output[1:]:
    total = total + extra
  return per_output[0], total


In [None]:
# training
epochs = EPOCHS
batch_interval = 150  # print/log losses and save a preview every N batches


def _save_val_preview(epoch, batch_id):
    """Predict on the fixed ``val_batch`` and show/save a comparison grid.

    The grid concatenates the input images, their ground-truth masks and the
    predicted masks (first model output), laid out with ``nrow=3``. It is
    saved to ``model_dir/images`` and displayed inline.
    """
    # Run the preview in eval mode so this small validation batch does not feed
    # train-time behaviour (e.g. BatchNorm running statistics, if the model has
    # them). Restore train mode even if the forward pass fails.
    net.eval()
    try:
        with torch.no_grad():
            test_img, test_mask = val_batch
            test_img = test_img.to(device)
            test_mask = test_mask.to(device)
            pred = net(test_img)[0]
    finally:
        net.train()

    # Concatenate along the batch dimension; masks are repeated to 3 channels.
    combined = torch.cat((test_img, test_mask.repeat(1, 3, 1, 1), pred.repeat(1, 3, 1, 1)), dim=0)
    grid = torchvision.utils.make_grid(combined, nrow=3)
    grid = grid.cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC for imshow

    fig, ax = plt.subplots()
    ax.set(xticks=[], yticks=[])
    ax.imshow(grid)
    fig.savefig(os.path.join(model_dir, 'images', 'image_at_epoch_{:02d}{:04d}.png'.format(epoch, batch_id)))
    plt.show()
    # One figure is created per preview; close it so they don't pile up in memory.
    plt.close(fig)


net.train()
for epoch in range(epochs):
    for batch_id, (img, mask) in enumerate(train_dataloader):
        # .to(device) is a no-op when the tensor is already on `device`.
        img = img.to(device)
        mask = mask.to(device)
        s0, s1, s2, s3, s4, s5, s6 = net(img)
        loss0, loss = multi_BCEloss(s0, s1, s2, s3, s4, s5, s6, mask)
        net_optimizer.zero_grad()
        loss.backward()
        net_optimizer.step()

        # Occasionally print losses, log them to TensorBoard and save a preview.
        if batch_id % batch_interval == 0:
            print(f"Epoch [{epoch}/{epochs}] Batch {batch_id}/{len(train_dataloader)} "
                  f"Loss : {loss:.4f}, loss 0: {loss0:.4f}")
            global_step = epoch * len(train_dataloader) + batch_id
            writer.add_scalar("train/loss", loss.item(), global_step)
            writer.add_scalar("train/loss0", loss0.item(), global_step)
            _save_val_preview(epoch, batch_id)

        # Drop references to this step's outputs before the next batch.
        del s0, s1, s2, s3, s4, s5, s6, loss0
    scheduler.step()  # per-epoch learning-rate decay

writer.flush()


Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Save a final checkpoint after training. `epoch`, `batch_id` and `loss` are the
# values left over from the last iteration of the training-loop cell.
checkpoint_filename = os.path.join(model_dir, "u2net_model.pt")
checkpoint = {
    'epoch': epoch + 1,
    'batch': batch_id,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': net_optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),  # needed to resume the lr schedule
    # Store a plain float rather than the loss tensor tied to the autograd graph.
    'loss': float(loss),
}
torch.save(checkpoint, checkpoint_filename)

In [None]:
# change mode to evaluation mode
net.eval()
mask_pred = net(imgs)
mask_pred[0].shape

plt.imshow(mask_pred[0][0].detach().cpu().repeat(3,1,1).permute(1,2,0))
plt.show()

In [None]:
# END