In [1]:
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import glob

from clearbg.utils.common import read_yaml, create_directories
from clearbg.constants import PROJECT_ROOT
from clearbg.config.configuration import ConfigurationManager
from clearbg.utils.utils import RescaleT, RandomCrop, ToTensorLab, SalObjDataset
from clearbg.model.u2net import U2NET

In [15]:
# Build the project configuration manager and pull out the two sections this
# notebook relies on: where the dataset lives and the training settings
# (epochs, batch size, learning rate, image size are read further down).
# The tuple is evaluated left to right, so the data-ingestion getter still runs first.
config_manager = ConfigurationManager()
data_ingestion_config, training_config = (
    config_manager.get_data_ingestion_config(),
    config_manager.get_training_config(),
)

# ------- 1. Define loss function --------
# Shared mean-reduced BCE criterion. nn.BCELoss applies no sigmoid, so its
# inputs must already be probabilities in [0, 1].
bce_loss = nn.BCELoss(reduction='mean')


def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v, verbose=True):
    """Sum the BCE losses of all seven network outputs against one target mask.

    Each prediction is scored against ``labels_v`` with the module-level
    ``bce_loss`` (mean-reduced ``nn.BCELoss``). Every ``d*`` must therefore hold
    probabilities in [0, 1] and have the same shape as ``labels_v``.

    Args:
        d0: Prediction whose loss is also returned on its own. The training loop
            logs it as the "tar" loss. Presumably U2NET's fused output -- TODO confirm.
        d1, d2, d3, d4, d5, d6: The remaining predictions (presumably side outputs).
        labels_v: Target mask tensor.
        verbose: If True (default, preserving the original behavior), print the
            seven individual losses to stdout.

    Returns:
        Tuple ``(loss0, loss)``. ``loss0`` is the BCE of ``d0`` alone. ``loss`` is
        the sum of all seven losses and is the value to backpropagate.
    """
    predictions = (d0, d1, d2, d3, d4, d5, d6)
    losses = [bce_loss(pred, labels_v) for pred in predictions]

    loss0 = losses[0]
    # Start from loss0 and add loss1..loss6 left to right, i.e. the same
    # summation order as loss0 + loss1 + ... + loss6.
    loss = sum(losses[1:], loss0)

    if verbose:
        # "%.3f" prints three decimals. The earlier "%3f" only set a minimum
        # field width of 3 and still printed six decimals.
        print(", ".join("l%d: %.3f" % (idx, part.item())
                        for idx, part in enumerate(losses)) + "\n")

    return loss0, loss

# ------- 2. Set the directory of training dataset --------
data_dir = data_ingestion_config.root_dir
tra_image_dir = os.path.join('DUTS-TR', 'DUTS-TR-Image' + os.sep)
tra_label_dir = os.path.join('DUTS-TR', 'DUTS-TR-Mask' + os.sep)
image_ext = '.jpg'
label_ext = '.png'

# Set model save directory based on training configuration
model_dir = PROJECT_ROOT / training_config.root_dir / 'saved_models' / 'u2net'
create_directories([model_dir])  # Create the model directory if it doesn't exist

epoch_num = training_config.epochs
batch_size_train = training_config.batch_size

# glob order is filesystem-dependent; sort so dataset indices are reproducible.
tra_img_name_list = sorted(glob.glob(os.path.join(data_dir, tra_image_dir, '*' + image_ext)))
if not tra_img_name_list:
    # Without this check an empty list gives an empty DataLoader, and the
    # training loop finishes instantly without training or saving anything.
    raise FileNotFoundError(
        "No training images matching '*%s' found in %s -- has the DUTS-TR archive "
        "been downloaded and extracted?" % (image_ext, os.path.join(data_dir, tra_image_dir)))

# Each mask shares its image's file stem: <stem>.jpg -> <stem>.png.
# basename/splitext accept both '/' and '\\' on Windows, unlike split(os.sep).
tra_lbl_name_list = [
    os.path.join(data_dir, tra_label_dir,
                 os.path.splitext(os.path.basename(img_path))[0] + label_ext)
    for img_path in tra_img_name_list
]

# Fail now rather than partway through an epoch when a mask is absent.
missing_labels = [p for p in tra_lbl_name_list if not os.path.isfile(p)]
if missing_labels:
    raise FileNotFoundError(
        "%d of %d training masks are missing, e.g. %s" % (
            len(missing_labels), len(tra_lbl_name_list), missing_labels[0]))

print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")

train_num = len(tra_img_name_list)

# Side length of the random crop applied after RescaleT. Presumably
# training_config.image_size must be >= this value -- TODO confirm against RescaleT.
crop_size = 288

salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
        RescaleT(training_config.image_size),
        RandomCrop(crop_size),
        ToTensorLab(flag=0)]))
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)

# ------- 3. Define model --------
# Build the network. The two positional arguments are presumably
# (input channels, output channels) -- TODO confirm against U2NET's signature.
# Note: model_name only affects checkpoint file names in the training loop.
# Switching to 'u2netp' also requires constructing the matching network class
# instead of U2NET.
model_name = 'u2net'  # Change to 'u2netp' if needed
net = U2NET(3, 1)
# nn.Module.cuda() moves the parameters in place and returns the module itself.
net = net.cuda() if torch.cuda.is_available() else net

# ------- 4. Define optimizer --------
print("---define optimizer...")
# Only the learning rate comes from the training config. The other Adam
# hyperparameters are pinned explicitly here.
adam_kwargs = dict(
    lr=training_config.learning_rate,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
)
optimizer = optim.Adam(net.parameters(), **adam_kwargs)

# ------- 5. Training process --------
print("---start training...")
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 2000  # Save the model every 2000 iterations

# Train on whichever device the model-definition step placed the network on.
device = next(net.parameters()).device


def _save_checkpoint(iteration, avg_loss, avg_tar_loss):
    """Write net's state_dict to model_dir and return the checkpoint path.

    The filename encodes the iteration and the running-average total/"tar"
    losses. Its template is kept unchanged so existing checkpoint naming is preserved.
    """
    ckpt_path = os.path.join(model_dir, model_name + "_bce_itr_%d_train_%3f_tar_%3f.pth" % (
        iteration, avg_loss, avg_tar_loss))
    torch.save(net.state_dict(), ckpt_path)
    return ckpt_path


for epoch in range(epoch_num):
    net.train()

    for i, data in enumerate(salobj_dataloader):
        ite_num += 1
        ite_num4val += 1

        # Plain tensors replace the deprecated autograd.Variable wrapper. Data
        # tensors do not require grad by default.
        inputs_v = data['image'].to(device=device, dtype=torch.float32)
        labels_v = data['label'].to(device=device, dtype=torch.float32)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
        loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)

        loss.backward()
        optimizer.step()

        # Accumulate running totals (averaged over iterations since the last save)
        running_loss += loss.item()
        running_tar_loss += loss2.item()

        # Drop references to this batch's outputs and losses before the next batch
        del d0, d1, d2, d3, d4, d5, d6, loss2, loss

        # The last batch may be partial, so cap the sample counter at the dataset size.
        samples_seen = min((i + 1) * batch_size_train, train_num)
        print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %.3f, tar: %.3f " % (
            epoch + 1, epoch_num, samples_seen, train_num, ite_num,
            running_loss / ite_num4val, running_tar_loss / ite_num4val))

        if ite_num % save_frq == 0:
            _save_checkpoint(ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)
            running_loss = 0.0
            running_tar_loss = 0.0
            ite_num4val = 0

# Also persist the final weights. Without this, a run with fewer than save_frq
# total iterations would never be saved, and any iterations after the last
# multiple of save_frq would be lost.
if ite_num4val > 0:
    final_ckpt = _save_checkpoint(ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)
    print("---saved final model to", final_ckpt)


[2024-10-11 00:50:50,120: INFO: common: yaml file: C:\dev\ml\projects\MLOPS_Project\config\config.yaml loaded successfully]
[2024-10-11 00:50:50,127: INFO: common: yaml file: C:\dev\ml\projects\MLOPS_Project\params.yaml loaded successfully]
[2024-10-11 00:50:50,131: INFO: common: created directory at: artifacts]
[2024-10-11 00:50:50,134: INFO: common: created directory at: C:\dev\ml\projects\MLOPS_Project\artifacts\data_ingestion]
DataIngestionConfig(root_dir=WindowsPath('C:/dev/ml/projects/MLOPS_Project/artifacts/data_ingestion'), train_source_url='http://saliencydetection.net/duts/download/DUTS-TR.zip', test_source_url='http://saliencydetection.net/duts/download/DUTS-TE.zip', train_local_zipped_path=WindowsPath('C:/dev/ml/projects/MLOPS_Project/artifacts/data_ingestion/DUTS-TR.zip'), test_local_zipped_path=WindowsPath('C:/dev/ml/projects/MLOPS_Project/artifacts/data_ingestion/DUTS-TE.zip'))
{'root_dir': 'artifacts/training'}
[2024-10-11 00:50:50,141: INFO: common: created directory a

KeyboardInterrupt: 