In [2]:
import sys
import os
import time
import torch.nn as nn
from collections import OrderedDict
from dataloader import create_dataloader
from options.base_options import Options
from trainers.train_manager import TrainManager
from util.epoch_counter import EpochCounter
from util.visualizer import Visualizer
from util.util import preprocess_train_data
from models.SpadeGAN import SpadeGAN
from models.networks.loss import GANLoss, KLDLoss

In [3]:
# Parse the experiment configuration; every helper created below is built from it.
opt = Options()

# Expose only the configured GPUs to CUDA (an empty gpu_ids list yields "" and hides
# every GPU). CUDA reads this variable once, when it is first initialised in the
# process, so re-running this cell in a kernel that has already touched CUDA does
# not change the visible devices. Inside the process the visible devices are
# renumbered starting from 0.
os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(map(str, opt.gpu_ids))

# Training data source.
dataloader = create_dataloader(opt)

# Orchestrates the generator/discriminator steps, checkpointing and LR updates.
trainer = TrainManager(opt)

# Iteration/epoch bookkeeping (may resume from a saved iteration record).
epoch_counter = EpochCounter(opt)

# Logging of losses and saving of sample images.
visualizer = Visualizer(opt)

dataset ADE20K was created
Could not load iteration record at /chenxiao/FinalDesign/checkpoints/ade20k/iter.txt.
 Starting from beginning.


In [4]:
# Build the SPADE GAN together with its two training losses.
spade_gan = SpadeGAN(opt)
gan_loss = GANLoss()
kld_loss = KLDLoss()

if len(opt.gpu_ids) > 0:
    # The losses are wrapped in DataParallel as well as the model, presumably so
    # loss computation is split across the GPUs instead of concentrating on
    # device 0 — see https://www.zhihu.com/question/67726969/answer/389980788
    # NOTE(review): a DataParallel-wrapped module whose forward returns a scalar
    # gives back one value per replica (gathered along dim 0); presumably
    # TrainManager reduces these (e.g. with .mean()) — confirm.
    spade_gan, gan_loss, kld_loss = (
        nn.DataParallel(module).cuda()
        for module in (spade_gan, gan_loss, kld_loss)
    )

# Generator and discriminator optimizers, built from the (possibly wrapped) model.
optG, optD = trainer.create_optimizers(opt, spade_gan)

In [5]:
# Save sample images every DISPLAY_FREQ batches; batch 0 of each epoch is included.
DISPLAY_FREQ = 200

# Main training loop: one generator step and one discriminator step per batch,
# followed by per-epoch checkpointing and learning-rate scheduling.
for epoch in epoch_counter.training_epochs():
    epoch_counter.record_epoch_start(epoch)
    for batch_id, data_i in enumerate(dataloader):
        iter_start_time = time.time()

        # Previously `labels` and `real_imgs` were never defined (NameError at the
        # first image dump). They are taken from the raw batch before
        # preprocessing rebinds data_i.
        # NOTE(review): assumes the dataloader yields dicts with 'label' and
        # 'image' keys (as in the upstream SPADE dataloader) — confirm against
        # create_dataloader.
        labels, real_imgs = data_i['label'], data_i['image']

        data_i = preprocess_train_data(data_i, opt)
        trainer.run_generator_one_step(data_i, spade_gan, gan_loss, kld_loss, optG)
        trainer.run_discriminator_one_step(data_i, spade_gan, gan_loss, optD)

        running_time = time.time() - iter_start_time

        # Previously `losses` was never defined (NameError on the first batch).
        # NOTE(review): assumes TrainManager exposes get_latest_losses() alongside
        # get_latest_generated() — confirm against trainers/train_manager.py.
        losses = trainer.get_latest_losses()
        visualizer.print_current_errors(epoch, batch_id, running_time, losses)

        if batch_id % DISPLAY_FREQ == 0:
            generated_imgs = trainer.get_latest_generated()
            visualizer.save_images(epoch, batch_id, labels, real_imgs, generated_imgs)

    epoch_counter.record_epoch_end()
    trainer.save(epoch)
    trainer.update_learning_rate(epoch)

print('Training was successfully finished.', flush=True)

  0%|          | 0/200 [00:01<?, ?it/s]


RuntimeError: CUDA error: all CUDA-capable devices are busy or unavailable (malloc at /pytorch/c10/cuda/CUDACachingAllocator.cpp:291)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7f37c4502193 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1c4cd (0x7f37c47444cd in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x1cd5e (0x7f37c4744d5e in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #3: THCStorage_resize + 0xa3 (0x7f37c90006f3 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #4: at::native::empty_strided_cuda(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::TensorOptions const&) + 0x636 (0x7f37ca5ce856 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x45bcd2a (0x7f37c8f11d2a in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #6: <unknown function> + 0x1f4fc81 (0x7f37c68a4c81 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #7: <unknown function> + 0x3aadfb0 (0x7f37c8402fb0 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #8: <unknown function> + 0x1f4fc81 (0x7f37c68a4c81 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #9: <unknown function> + 0x1cb869e (0x7f37c660d69e in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #10: at::native::to(at::Tensor const&, c10::TensorOptions const&, bool, bool, c10::optional<c10::MemoryFormat>) + 0x245 (0x7f37c660e6f5 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #11: <unknown function> + 0x1ffdb9a (0x7f37c6952b9a in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #12: <unknown function> + 0x3ce3866 (0x7f37c8638866 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #13: <unknown function> + 0x20485e2 (0x7f37c699d5e2 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #14: torch::cuda::scatter(at::Tensor const&, c10::ArrayRef<long>, c10::optional<std::vector<long, std::allocator<long> > > const&, long, c10::optional<std::vector<c10::optional<c10::cuda::CUDAStream>, std::allocator<c10::optional<c10::cuda::CUDAStream> > > > const&) + 0x710 (0x7f37c930bf60 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #15: <unknown function> + 0x9d7203 (0x7f380f8cd203 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #16: <unknown function> + 0x2961c4 (0x7f380f18c1c4 in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #17: _PyCFunction_FastCallDict + 0x154 (0x563de3e184f4 in /usr/local/miniconda3/envs/dl/bin/python)
frame #18: <unknown function> + 0x198dac (0x563de3e9fdac in /usr/local/miniconda3/envs/dl/bin/python)
frame #19: _PyEval_EvalFrameDefault + 0x30a (0x563de3ec266a in /usr/local/miniconda3/envs/dl/bin/python)
frame #20: <unknown function> + 0x192274 (0x563de3e99274 in /usr/local/miniconda3/envs/dl/bin/python)
frame #21: <unknown function> + 0x1930f1 (0x563de3e9a0f1 in /usr/local/miniconda3/envs/dl/bin/python)
frame #22: <unknown function> + 0x198e85 (0x563de3e9fe85 in /usr/local/miniconda3/envs/dl/bin/python)
frame #23: _PyEval_EvalFrameDefault + 0x30a (0x563de3ec266a in /usr/local/miniconda3/envs/dl/bin/python)
frame #24: PyEval_EvalCodeEx + 0x329 (0x563de3e9ac09 in /usr/local/miniconda3/envs/dl/bin/python)
frame #25: <unknown function> + 0x194a24 (0x563de3e9ba24 in /usr/local/miniconda3/envs/dl/bin/python)
frame #26: PyObject_Call + 0x3e (0x563de3e182fe in /usr/local/miniconda3/envs/dl/bin/python)
frame #27: THPFunction_apply(_object*, _object*) + 0xa8f (0x7f380f55c82f in /usr/local/miniconda3/envs/dl/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #28: _PyCFunction_FastCallDict + 0x91 (0x563de3e18431 in /usr/local/miniconda3/envs/dl/bin/python)
frame #29: <unknown function> + 0x198dac (0x563de3e9fdac in /usr/local/miniconda3/envs/dl/bin/python)
frame #30: _PyEval_EvalFrameDefault + 0x30a (0x563de3ec266a in /usr/local/miniconda3/envs/dl/bin/python)
frame #31: <unknown function> + 0x19257e (0x563de3e9957e in /usr/local/miniconda3/envs/dl/bin/python)
frame #32: _PyFunction_FastCallDict + 0x1be (0x563de3e9a5ce in /usr/local/miniconda3/envs/dl/bin/python)
frame #33: _PyObject_FastCallDict + 0x26f (0x563de3e188bf in /usr/local/miniconda3/envs/dl/bin/python)
frame #34: <unknown function> + 0x12cf22 (0x563de3e33f22 in /usr/local/miniconda3/envs/dl/bin/python)
frame #35: PyIter_Next + 0xe (0x563de3e5a4ae in /usr/local/miniconda3/envs/dl/bin/python)
frame #36: PySequence_Tuple + 0xf9 (0x563de3e5eec9 in /usr/local/miniconda3/envs/dl/bin/python)
frame #37: _PyEval_EvalFrameDefault + 0x545f (0x563de3ec77bf in /usr/local/miniconda3/envs/dl/bin/python)
frame #38: <unknown function> + 0x19257e (0x563de3e9957e in /usr/local/miniconda3/envs/dl/bin/python)
frame #39: _PyFunction_FastCallDict + 0x1be (0x563de3e9a5ce in /usr/local/miniconda3/envs/dl/bin/python)
frame #40: _PyObject_FastCallDict + 0x26f (0x563de3e188bf in /usr/local/miniconda3/envs/dl/bin/python)
frame #41: <unknown function> + 0x12cf22 (0x563de3e33f22 in /usr/local/miniconda3/envs/dl/bin/python)
frame #42: PyIter_Next + 0xe (0x563de3e5a4ae in /usr/local/miniconda3/envs/dl/bin/python)
frame #43: PySequence_Tuple + 0xf9 (0x563de3e5eec9 in /usr/local/miniconda3/envs/dl/bin/python)
frame #44: _PyEval_EvalFrameDefault + 0x545f (0x563de3ec77bf in /usr/local/miniconda3/envs/dl/bin/python)
frame #45: <unknown function> + 0x19257e (0x563de3e9957e in /usr/local/miniconda3/envs/dl/bin/python)
frame #46: <unknown function> + 0x1930f1 (0x563de3e9a0f1 in /usr/local/miniconda3/envs/dl/bin/python)
frame #47: <unknown function> + 0x198e85 (0x563de3e9fe85 in /usr/local/miniconda3/envs/dl/bin/python)
frame #48: _PyEval_EvalFrameDefault + 0x30a (0x563de3ec266a in /usr/local/miniconda3/envs/dl/bin/python)
frame #49: <unknown function> + 0x19257e (0x563de3e9957e in /usr/local/miniconda3/envs/dl/bin/python)
frame #50: <unknown function> + 0x1930f1 (0x563de3e9a0f1 in /usr/local/miniconda3/envs/dl/bin/python)
frame #51: <unknown function> + 0x198e85 (0x563de3e9fe85 in /usr/local/miniconda3/envs/dl/bin/python)
frame #52: _PyEval_EvalFrameDefault + 0x30a (0x563de3ec266a in /usr/local/miniconda3/envs/dl/bin/python)
frame #53: <unknown function> + 0x192274 (0x563de3e99274 in /usr/local/miniconda3/envs/dl/bin/python)
frame #54: <unknown function> + 0x1930f1 (0x563de3e9a0f1 in /usr/local/miniconda3/envs/dl/bin/python)
frame #55: <unknown function> + 0x198e85 (0x563de3e9fe85 in /usr/local/miniconda3/envs/dl/bin/python)
frame #56: _PyEval_EvalFrameDefault + 0x10c8 (0x563de3ec3428 in /usr/local/miniconda3/envs/dl/bin/python)
frame #57: <unknown function> + 0x192ebb (0x563de3e99ebb in /usr/local/miniconda3/envs/dl/bin/python)
frame #58: <unknown function> + 0x198e85 (0x563de3e9fe85 in /usr/local/miniconda3/envs/dl/bin/python)
frame #59: _PyEval_EvalFrameDefault + 0x30a (0x563de3ec266a in /usr/local/miniconda3/envs/dl/bin/python)
frame #60: <unknown function> + 0x192274 (0x563de3e99274 in /usr/local/miniconda3/envs/dl/bin/python)
frame #61: _PyFunction_FastCallDict + 0x3d8 (0x563de3e9a7e8 in /usr/local/miniconda3/envs/dl/bin/python)
frame #62: _PyObject_FastCallDict + 0x26f (0x563de3e188bf in /usr/local/miniconda3/envs/dl/bin/python)
frame #63: _PyObject_Call_Prepend + 0x63 (0x563de3e1d313 in /usr/local/miniconda3/envs/dl/bin/python)
