In [1]:
import os
import torch
import torch.nn as nn
from skimage import io
from torch.utils.data import DataLoader
from utilities import preprocess_image, postprocess_image
from torchvision import transforms
import torch.nn.functional as F

from PIL import Image
import numpy as np

from briarmbg import BriaRMBG

In [8]:
# Masking image. It's just an exapmle for 1 train pic. It should be integrated in CustomDataset class

image_path = 'train_images_output/kinder_original.png'
image = Image.open(image_path)

# Get alpha channel
mask = image.split()[-1]

# Saving mask(not binary yet)
mask.save('train_images_output/kinder.png')

In [9]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, mask_dir, device):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.device = device
        self.images = [img for img in os.listdir(image_dir) if img.endswith('.png')]  # U can remove if u shure that all pics will be in .png
        self.model_input_size = [1024, 1024]

    def preprocess_image(self, im: np.ndarray, model_input_size: list) -> torch.Tensor:
        if len(im.shape) < 3:
            im = im[:, :, np.newaxis]
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)  # CxHxW
        im_tensor = F.interpolate(im_tensor.unsqueeze(0), size=model_input_size, mode='bilinear', align_corners=False)
        im_tensor = im_tensor.squeeze(0).type(torch.float32)
        image = im_tensor / 255.0
        image = transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])(image)
        return image

    def preprocess_mask(self, mask: np.ndarray, model_input_size: list) -> torch.Tensor:
        if len(mask.shape) == 2:
            mask = mask[np.newaxis, :, :]  # Adding channel dimension
        mask_tensor = torch.tensor(mask, dtype=torch.float32)   # CxHxW
        mask_tensor = F.interpolate(mask_tensor.unsqueeze(0), size=model_input_size, mode='nearest').squeeze(0)  # remove align_corners
        mask_tensor = (mask_tensor > 0.5).float()  # binarization
        return mask_tensor


    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])

        orig_im = io.imread(img_path)
        orig_mask = io.imread(mask_path, as_gray=True)

        #print(orig_im.shape)
        #print(orig_mask.shape)

        image = self.preprocess_image(orig_im, self.model_input_size).to(self.device)
        mask = self.preprocess_mask(orig_mask, self.model_input_size).to(self.device)

        return image, mask

# FineTuning function
def finetune_model(image_dir, mask_dir, model_path, epochs, batch_size, learning_rate):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Dataset and dataloader
    dataset = CustomDataset(image_dir=image_dir, mask_dir=mask_dir, device=device)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Load local weights(.pth)
    model = BriaRMBG()
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model = model.to(device)
    
    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCEWithLogitsLoss()

    # train loop
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)[0][0]
            
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()

        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(dataloader)}')

    # Saving model
    torch.save(model.state_dict(), 'finetuned_model.pth')
    del model

In [10]:
# Launch FineTuning
finetune_model(image_dir='train_images_input/',
               mask_dir='train_images_output/',
               model_path='model.pth',
               epochs=9,
               batch_size=4,
               learning_rate=1e-4)

Epoch 1/9, Loss: 0.5617952346801758
Epoch 2/9, Loss: 0.5557054281234741
Epoch 3/9, Loss: 0.5548108816146851
Epoch 4/9, Loss: 0.5534306764602661
Epoch 5/9, Loss: 0.5525507926940918
Epoch 6/9, Loss: 0.5517305135726929
Epoch 7/9, Loss: 0.5513277649879456
Epoch 8/9, Loss: 0.5511763095855713
Epoch 9/9, Loss: 0.5510510206222534
