In [1]:
from config import *
from crnet_encoder_model import CRNETENCODER
from crnet_encoder_dataset import CRNETENCODERDataset
from crnet_encoder_loss import CRNETENCODERLoss

In [2]:
# Training-image directory and the ground-truth map directories consumed by CRNETENCODERDataset.
# NOTE(review): the gt_dirs order (Tf, Tc, Xoffset, Yoffset) presumably matters to the dataset — keep it stable.
image_dir = "/autodl-fs/data/Images/Train"
gt_dirs = [
    "/autodl-fs/data/totaltext/Tf",
    "/autodl-fs/data/totaltext/Tc",
    "/autodl-fs/data/totaltext/Xoffset",
    "/autodl-fs/data/totaltext/Yoffset",
]

In [3]:
# Show the kernel's current working directory; any relative file path resolves against it.
print(f"{os.getcwd()}")

/autodl-fs/data/crnet2


In [8]:
def _lr_for_epoch(base_lr, epoch, decay_start=10, decay_factor=0.5):
    """Return the learning rate for 0-based ``epoch``.

    The rate stays at ``base_lr`` up to and including epoch ``decay_start``, then is
    multiplied by ``decay_factor`` once for every further epoch. Deriving it from the
    epoch index (rather than mutating a running value) reproduces the same schedule on
    a fresh run and gives the correct rate when training resumes mid-schedule.
    """
    return base_lr * decay_factor ** max(0, epoch - decay_start)


def _report_vanishing_gradients(model, step, threshold=1e-6):
    """Print a warning for each non-bias parameter whose gradient L2 norm is below ``threshold``.

    Parameters with no gradient are skipped. Returns the number of flagged parameters.
    """
    flagged = 0
    for name, param in model.named_parameters():
        if param.grad is None or 'bias' in name:
            continue
        if torch.norm(param.grad).item() < threshold:
            print(f'Step {step} Warning: Gradient vanishing detected in layer {name}')
            flagged += 1
    return flagged


def train_model(num_epochs=100, learning_rate=0.0001, batch_size=4, model_num=8,
                checkpoint_dir='/autodl-fs/data/crnet2', save_every=5,
                lr_decay_start=10, lr_decay_factor=0.5,
                grad_vanish_threshold=1e-6, max_grad_norm=None):
    """Train (or resume training) a CRNETENCODER model, saving periodic checkpoints.

    Args:
        num_epochs: Total epoch count; runs 0-based epochs ``model_num .. num_epochs - 1``.
        learning_rate: Base learning rate before decay (see ``_lr_for_epoch``).
        batch_size: DataLoader batch size.
        model_num: 0 builds a fresh model. Otherwise the whole pickled model is loaded from
            ``<checkpoint_dir>/CE_model{model_num}.pth`` and training resumes at that epoch.
            Only the model is checkpointed, so optimizer state starts fresh on resume.
        checkpoint_dir: Directory checkpoints are read from and written to.
        save_every: Save ``CE_model{epoch}.pth`` every this many epochs (must be > 0);
            the final epoch is always saved.
        lr_decay_start, lr_decay_factor: Decay schedule parameters for ``_lr_for_epoch``.
        grad_vanish_threshold: Gradient-norm threshold below which a warning is printed.
        max_grad_norm: If not None, clip the total gradient norm to this value before each step.
    """
    import os

    # Dataset. target_size=(160, 160) was intended for loading CRNet ground truth to
    # compute a loss; that path is currently unused.
    dataset = CRNETENCODERDataset(image_dir, gt_dirs, target_size=(128, 128))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    num_batches = len(dataloader)

    # Model: fresh, or resumed from a full-model checkpoint mapped onto `device`.
    if model_num == 0:
        model = CRNETENCODER().to(device)
    else:
        checkpoint_path = os.path.join(checkpoint_dir, f'CE_model{model_num}.pth')
        model = torch.load(checkpoint_path, map_location=device)

    # Weights for the loss terms; terms not listed here are ignored in the total.
    weight_dict = {'loss_all': 2, 'loss_crnet': 0.5}
    loss_fn = CRNETENCODERLoss()
    loss_fn.to(device)

    # Build the optimizer once so Adam's moment estimates persist across epochs;
    # the per-epoch learning rate is written into its param groups below.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(model_num, num_epochs):
        # NOTE(review): with the defaults this halves the LR on every epoch after 10
        # (~1e-10 by epoch 30), so later epochs barely update the model — confirm intended.
        epoch_lr = _lr_for_epoch(learning_rate, epoch, lr_decay_start, lr_decay_factor)
        for group in optimizer.param_groups:
            group['lr'] = epoch_lr
        model.train()

        total_loss = 0
        start_time = time.time()
        print('***********************************************************************')
        print(f'start training epoch{epoch + 1}')
        for batch_idx, (images, gts) in enumerate(dataloader):
            images = images.to(device)
            optimizer.zero_grad()

            # Forward pass.
            outputs = model(images)

            # Weighted total loss. `gts` is passed as-is; presumably the loss module
            # moves it to the right device — confirm.
            loss_dict = loss_fn(outputs, gts)
            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
            total_loss += losses.item()

            # Backward pass, diagnostics, optional clipping, update.
            losses.backward()
            _report_vanishing_gradients(model, batch_idx + 1, grad_vanish_threshold)
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            if (batch_idx + 1) % 50 == 0:
                elapsed = time.time() - start_time
                print(f'Loss_dict : {loss_dict}')
                print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{batch_idx + 1}/{num_batches}], Loss: {losses.item():.4f} , Use Time: {elapsed}s")

        epoch_time = time.time() - start_time
        print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {total_loss / num_batches:.4f}, 1 Epoch Time {epoch_time} s")
        if (epoch + 1) % save_every == 0 or epoch + 1 == num_epochs:
            torch.save(model, os.path.join(checkpoint_dir, f'CE_model{epoch+1}.pth'))
            print(f'Model {epoch+1} Already Saved')
        print('***********************************************************************')

    print("Training complete.")

In [9]:
# Launch training with train_model's built-in settings (by default it resumes from CE_model8.pth).
train_model()

***********************************************************************
start training epoch9
Loss_dict : {'loss_all': tensor(0.4799, device='cuda:0', grad_fn=<AddBackward0>), 'loss_crnet': tensor(0.4605, device='cuda:0', grad_fn=<AddBackward0>)}
Epoch [9/100], Step [50/314], Loss: 1.1900 , Use Time: 36.915157318115234s
Loss_dict : {'loss_all': tensor(0.4072, device='cuda:0', grad_fn=<AddBackward0>), 'loss_crnet': tensor(0.3878, device='cuda:0', grad_fn=<AddBackward0>)}
Epoch [9/100], Step [100/314], Loss: 1.0084 , Use Time: 75.78193616867065s
Loss_dict : {'loss_all': tensor(0.9011, device='cuda:0', grad_fn=<AddBackward0>), 'loss_crnet': tensor(0.8873, device='cuda:0', grad_fn=<AddBackward0>)}
Epoch [9/100], Step [150/314], Loss: 2.2458 , Use Time: 111.70574760437012s
Loss_dict : {'loss_all': tensor(0.5991, device='cuda:0', grad_fn=<AddBackward0>), 'loss_crnet': tensor(0.5134, device='cuda:0', grad_fn=<AddBackward0>)}
Epoch [9/100], Step [200/314], Loss: 1.4550 , Use Time: 146.92118525