# U²-Net Fine-tuning for Clothing Background Removal
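Run the cells below in order to fine-tune U²-Net on your own dataset of clothing images and matching grayscale masks (white = foreground). Images and masks are paired by sorted filename, so give each mask the same base name as its image.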

In [None]:
# Report the attached GPU; on a CPU-only runtime nvidia-smi fails and the echo runs instead.
!nvidia-smi || echo 'No GPU runtime detected'
# Install the Python dependencies quietly (-q).
!pip -q install torch torchvision opencv-python pillow matplotlib tqdm

In [None]:
# Fetch the upstream U-2-Net code and make it the working directory so that
# `from model import U2NET` in the next cell resolves against this checkout.
# NOTE: re-running this cell fails: `git clone` refuses an existing U-2-Net
# directory, and `%cd U-2-Net` is relative to the current (already changed) cwd.
!git clone https://github.com/xuebinqin/U-2-Net.git
%cd U-2-Net

In [None]:
import os, glob, random, numpy as np, torch, torchvision as tv
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from model import U2NET
import torch.nn.functional as F


In [None]:
# Mount Google Drive when running in Colab. Outside Colab the import fails, so
# the mount is skipped instead of crashing the notebook.
try:
  from google.colab import drive
except ImportError:  # not running in Colab
  drive = None
if drive is not None:
  drive.mount('/content/drive')

# Dataset root containing `images/` and `masks/`; set here or via U2NET_DATA_ROOT.
DATA_ROOT = os.environ.get('U2NET_DATA_ROOT', '/content/drive/MyDrive/u2net_clothing_dataset')
IMG_DIR = os.path.join(DATA_ROOT, 'images')
MSK_DIR = os.path.join(DATA_ROOT, 'masks')
# Where the fine-tuned state_dict is written; override via U2NET_SAVE_PATH.
SAVE_PATH = os.environ.get('U2NET_SAVE_PATH', '/content/drive/MyDrive/u2net_clothing.pth')


In [None]:
class PairDataset(Dataset):
  """Paired (image, mask) dataset for binary segmentation.

  Files are paired by position after sorting each directory listing: the i-th
  image (by sorted path) is matched with the i-th mask. Nothing checks that the
  names correspond, so keep image and mask filenames aligned (e.g. identical
  base names) or pairs will be silently mismatched.

  Args:
    img_dir: directory of input images; each is converted to RGB.
    msk_dir: directory of masks; each is converted to 8-bit grayscale, so a
      pixel value of 255 becomes a target of 1.0.
    size: both image and mask are resized to (size, size).

  Raises:
    ValueError: if no image files are found or the image and mask counts differ.
  """

  def __init__(self, img_dir, msk_dir, size=320):
    self.img_paths = self._list_files(img_dir)
    self.msk_paths = self._list_files(msk_dir)
    # Raise rather than assert: asserts vanish under `python -O`.
    if not self.img_paths:
      raise ValueError(f'No image files found in {img_dir!r}')
    if len(self.img_paths) != len(self.msk_paths):
      raise ValueError(
          f'Found {len(self.img_paths)} images in {img_dir!r} but '
          f'{len(self.msk_paths)} masks in {msk_dir!r}; they must pair 1:1')
    self.size = size

  @staticmethod
  def _list_files(directory):
    """Return the sorted paths of regular files in `directory`, skipping subdirectories."""
    # A stray subdirectory would otherwise shift the positional pairing.
    return sorted(p for p in glob.glob(os.path.join(directory, '*')) if os.path.isfile(p))

  def __len__(self):
    return len(self.img_paths)

  def __getitem__(self, idx):
    """Return (image, mask) float32 tensors of shape (3, size, size) and (1, size, size), scaled to [0, 1]."""
    with Image.open(self.img_paths[idx]) as src:
      img = src.convert('RGB').resize((self.size, self.size))
    with Image.open(self.msk_paths[idx]) as src:
      msk = src.convert('L').resize((self.size, self.size))
    img = np.asarray(img, dtype=np.float32) / 255.0
    img = np.transpose(img, (2, 0, 1))  # HWC -> CHW, as PyTorch conv layers expect
    msk = np.asarray(msk, dtype=np.float32) / 255.0
    msk = np.expand_dims(msk, 0)  # add the single channel axis
    return torch.from_numpy(img), torch.from_numpy(msk)


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_ds = PairDataset(IMG_DIR, MSK_DIR, size=320)
# Pinned host memory only helps host->GPU copies; on CPU it just triggers a warning.
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2,
                      pin_memory=device.type == 'cuda')
model = U2NET(3, 1).to(device)

# Fine-tuning needs a starting checkpoint; without one the network trains from
# scratch. Path is relative to the U-2-Net checkout (the cwd after `%cd`).
# NOTE(review): the upstream README describes downloading u2net.pth into
# saved_models/u2net/ -- confirm the location for your checkpoint.
PRETRAINED_PATH = os.path.join('saved_models', 'u2net', 'u2net.pth')
if os.path.isfile(PRETRAINED_PATH):
  # weights_only=True avoids unpickling arbitrary objects from a downloaded file.
  state = torch.load(PRETRAINED_PATH, map_location=device, weights_only=True)
  model.load_state_dict(state)
  print('Loaded pretrained weights from', PRETRAINED_PATH)
else:
  print('WARNING: no pretrained weights at', PRETRAINED_PATH, '- training from scratch')

# Consider a smaller learning rate (e.g. 1e-4) when starting from pretrained weights.
LR = 1e-3
opt = torch.optim.Adam(model.parameters(), lr=LR)


In [None]:
def bce_loss(pred, target):
  """Mean binary cross-entropy between predicted probabilities and a target mask.

  `pred` must already hold probabilities in [0, 1]. U2NET.forward in the cloned
  repo (model/u2net.py) passes every returned map through a sigmoid.
  `binary_cross_entropy_with_logits` would therefore apply a second sigmoid,
  confining predictions to (0.5, 0.73) and starving training of gradient.

  Args:
    pred: predicted foreground probabilities, same shape as `target`.
    target: ground-truth mask values in [0, 1].

  Returns:
    Scalar tensor with the mean BCE over all elements.
  """
  return F.binary_cross_entropy(pred, target)

EPOCHS = 10
for epoch in range(EPOCHS):
  model.train()
  pbar = tqdm(train_dl, desc=f'Epoch {epoch+1}/{EPOCHS}')
  total = 0.0
  for x, y in pbar:
    x, y = x.to(device), y.to(device)
    opt.zero_grad()
    outs = model(x)
    # When the model returns several maps (outs[0] being the main prediction),
    # supervise every one of them (deep supervision), not just the first.
    preds = outs if isinstance(outs, (list, tuple)) else (outs,)
    loss = sum(bce_loss(p, y) for p in preds)
    loss.backward()
    opt.step()
    # Weight by batch size so the epoch mean is correct when the last batch is short.
    total += loss.item() * x.size(0)
    pbar.set_postfix(loss=loss.item())
  # Reported loss is summed over all supervised maps.
  print('epoch_loss', total/len(train_ds))
  # Checkpoint every epoch so a dropped session doesn't lose all progress.
  torch.save(model.state_dict(), SAVE_PATH)
print('Saved to', SAVE_PATH)
