
Memory easily blows up during actual runs #4

Closed
zcong17huang opened this issue Jul 28, 2020 · 5 comments

Comments

@zcong17huang

Thanks for your excellent work; the speed improvement is indeed significant. However, I ran into a problem in actual use. I inserted your SSIM loss into my network, and every time, partway through a run (it starts fine), host memory fills up completely and the process crashes. This happens on both the server and my own machine, and it still occurred after I adjusted all kinds of dataloader parameters. When I later switched the SSIM loss to the loss3 in your code, the problem disappeared. I don't know the internals of your code well, but I did hit this in actual runs, so I'm reporting it; I'm not sure where the cause lies.

@One-sixth
Owner

One-sixth commented Aug 7, 2020

@zcong17huang This looks like a memory leak. While host memory grows until it blows up, how does GPU memory behave: does it grow along with host memory, or does it stay roughly constant? Could you also provide your PyTorch version, OS version, CUDA version, and cuDNN version? If possible, please also share the relevant training code.
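
For reference, a minimal sketch of how host RAM and GPU memory could be logged side by side during training to answer this (assumes the psutil package; not part of this repo):

import psutil
import torch

def log_memory(step):
    # Compare host RSS growth against GPU allocation to spot a host-side leak.
    rss_mb = psutil.Process().memory_info().rss / 2**20
    gpu_mb = torch.cuda.memory_allocated() / 2**20 if torch.cuda.is_available() else 0.0
    print('step %d: host RSS = %.1f MiB, GPU allocated = %.1f MiB' % (step, rss_mb, gpu_mb))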

@zcong17huang
Author

GPU memory doesn't change; it's the host memory that eventually blows up.
I tried both Windows and the server. On Windows my versions were Python 3.5.6, PyTorch 1.2.0, CUDA 10.1, cuDNN 7.6.0.
The training code is as follows:

start_full_time = time.time()
model.train()
last_loss = 1000000

for epoch in range(epoch_start, args.epochs+1):
    log.info('This is %d-th epoch, learning rate : %f '%(epoch, scheduler.get_lr()[0]))
    total_train_loss = 0
    length_loader = len(TrainImgLoader)
    start_time = time.time()

    ## training ##
    for batch_idx, (data_img, data_label, data_gt, data_input) in enumerate(TrainImgLoader):
        # --------------------------------------------------
        batch_time = time.time()
        optimizer.zero_grad()

        data_img = data_img.float().to(device)
        data_label = data_label.float().to(device)
        data_gt = data_gt.float().to(device)
        data_input = data_input.float().to(device)
        # print(data_img.shape, data_label.shape, data_gt.shape, data_input.shape)

        _, output1, output2, output3 = model(data_img, data_label, data_input)

        data_gt = data_gt*255.0         # scale to the 0-255 depth range
        output1 = output1*255.0
        output2 = output2*255.0
        output3 = output3*255.0

        loss1 = nn.SmoothL1Loss()
        loss2 = losses.SSIM(data_range=255., channel=1).to(device)   # this is your SSIM loss
        loss3 = losses.SemanticBoundaryLoss(device)

        loss_1 = loss1(data_gt, output1) \
               + 0.1 * (1-loss2(data_gt, output1)) \
               + 0.1 * loss3(data_label, output1)
        loss_2 = loss1(data_gt, output2) \
               + 0.1 * (1-loss2(data_gt, output2)) \
               + 0.1 * loss3(data_label, output2)
        loss_3 = loss1(data_gt, output3) \
                 + 0.1 * (1 - loss2(data_gt, output3)) \
                 + 0.1 * loss3(data_label, output3)

        loss = 0.6*loss_1 + 0.8*loss_2 + loss_3

        loss.backward()
        optimizer.step()
        loss = loss.item()

        EPE_error = torch.mean(torch.abs(data_gt-output3))  # end-point-error

        writer_name = args.logpath.split('/')[-1]
        writer.add_scalar(writer_name, loss, (batch_idx + (epoch * length_loader)))
        writer.close()
        # -----------------------------------------------
        total_train_loss += loss

    log.info('epoch %d total training loss = %.5f, time = %.4f Hours' %(epoch, total_train_loss/length_loader, (time.time() - start_time)/3600))
    scheduler.step()

My loss is already reduced to a plain value with .item(), so memory shouldn't blow up. I can't tell where the problem is; please take a look.
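
A minimal sketch of the point being made here: accumulating the raw loss tensor would keep every iteration's autograd graph alive, whereas .item() extracts a plain Python float so each graph can be freed:

import torch

w = torch.randn(100, requires_grad=True)
total_train_loss = 0.0
for _ in range(1000):
    loss = (w * torch.randn(100)).sum()
    loss.backward()
    total_train_loss += loss.item()  # plain float: this iteration's graph is freed
    # total_train_loss += loss       # tensor: would retain every iteration's graph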

@One-sixth
Owner

@zcong17huang I've looked at the code, and I suspect the jit module in PyTorch 1.2 may leak memory when it is released. Have you tried the latest PyTorch version? If you can only use PyTorch 1.2, you can initialize the SSIM module just once outside the training loop and then reuse it, which avoids the problem. Here is the modified code; see whether the issue still occurs.

start_full_time = time.time()
model.train()
last_loss = 1000000

# Initialize this once out here; it can be reused.
loss2 = losses.SSIM(data_range=255., channel=1).to(device)   # this is your SSIM loss

for epoch in range(epoch_start, args.epochs+1):
    log.info('This is %d-th epoch, learning rate : %f '%(epoch, scheduler.get_lr()[0]))
    total_train_loss = 0
    length_loader = len(TrainImgLoader)
    start_time = time.time()

    ## training ##
    for batch_idx, (data_img, data_label, data_gt, data_input) in enumerate(TrainImgLoader):
        # --------------------------------------------------
        batch_time = time.time()
        optimizer.zero_grad()

        data_img = data_img.float().to(device)
        data_label = data_label.float().to(device)
        data_gt = data_gt.float().to(device)
        data_input = data_input.float().to(device)
        # print(data_img.shape, data_label.shape, data_gt.shape, data_input.shape)

        _, output1, output2, output3 = model(data_img, data_label, data_input)

        data_gt = data_gt*255.0         # scale to the 0-255 depth range
        output1 = output1*255.0
        output2 = output2*255.0
        output3 = output3*255.0

        loss1 = nn.SmoothL1Loss()
        # Do not repeatedly construct new modules inside the training loop.
        # loss2 = losses.SSIM(data_range=255., channel=1).to(device)   # this is your SSIM loss
        loss3 = losses.SemanticBoundaryLoss(device)

        loss_1 = loss1(data_gt, output1) \
               + 0.1 * (1-loss2(data_gt, output1)) \
               + 0.1 * loss3(data_label, output1)
        loss_2 = loss1(data_gt, output2) \
               + 0.1 * (1-loss2(data_gt, output2)) \
               + 0.1 * loss3(data_label, output2)
        loss_3 = loss1(data_gt, output3) \
                 + 0.1 * (1 - loss2(data_gt, output3)) \
                 + 0.1 * loss3(data_label, output3)

        loss = 0.6*loss_1 + 0.8*loss_2 + loss_3

        loss.backward()
        optimizer.step()
        loss = loss.item()

        EPE_error = torch.mean(torch.abs(data_gt-output3))  # end-point-error

        writer_name = args.logpath.split('/')[-1]
        writer.add_scalar(writer_name, loss, (batch_idx + (epoch * length_loader)))
        writer.close()
        # -----------------------------------------------
        total_train_loss += loss

    log.info('epoch %d total training loss = %.5f, time = %.4f Hours' %(epoch, total_train_loss/length_loader, (time.time() - start_time)/3600))
    scheduler.step()

@zcong17huang
Author

The machine's environment at the time was PyTorch 1.2, and I haven't re-run the experiment with a different version. I don't really understand how the jit mechanism works, but the difference between your code and other SSIM implementations does seem to be exactly the jit part, so it's probably an incompatibility between the jit module and the PyTorch version I was using. If I run more experiments later, I'll report back promptly. Thanks for the quick answer!
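
A minimal illustration of the jit mechanism in question (the blend function here is hypothetical, not from the repo): torch.jit.script compiles a function or module into a TorchScript graph, so constructing a jit-backed module inside the training loop repeats that compilation on every iteration:

import torch

def blend(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return 0.5 * (x + y)

# Compile once and reuse; re-scripting inside the loop would rebuild the
# graph every iteration, which is what recreating the SSIM module did.
scripted_blend = torch.jit.script(blend)

print(scripted_blend(torch.ones(4), torch.zeros(4)))  # tensor([0.5000, 0.5000, 0.5000, 0.5000])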

@One-sixth
Owner

@zcong17huang I'll close this issue for now; reopen it when you have more progress.
