In [1]:
from dataset import *
from model import *
from tqdm import tqdm
import time
import datetime
import torch.optim as optim
import torch
import torch.nn as nn
import torchvision
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from omegaconf import OmegaConf

In [2]:
cfg = OmegaConf.load("conf/config.yaml")

In [3]:
resolution = cfg.model.resolution
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
if cfg.dataset.clevr:
    train_set = CLEVR(cfg.dataset.clevr_path, 'train')
    test_set = CLEVR(cfg.dataset.clevr_path, 'test')
    val_set = CLEVR(cfg.dataset.clevr_path, "val")
else:
    MY = MyDataset(cfg.dataset.my_path, cfg.dataset.train_rate, cfg.dataset.test_rate)
    test_set = MY.get_split_data("test")
    train_set = MY.get_split_data("train")
    val_set = MY.get_split_data("val")

In [5]:
def renormalize(x):
    return x / 2. + 0.5


def get_prediction(model2, batch, idx=0):
    recon_combined, recons, masks, slots = model2(batch["image"].to(device))
    image = renormalize(batch["image"].to(device))[idx]
    recon_combined = renormalize(recon_combined)[idx]
    recons = renormalize(recons)[idx]
    masks = masks[idx]
    return image, recon_combined, recons, masks, slots

In [6]:
model = SlotAttentionAutoEncoder(cfg.model.resolution, cfg.slot.num_slots, cfg.slot.num_iterations, cfg.model.hid_dim).to(device)

criterion = nn.MSELoss()

params = [{'params': model.parameters()}]

train_dataloader = torch.utils.data.DataLoader(train_set,
                                               batch_size=cfg.dataset.batch_size,
                                               shuffle=True,
                                               num_workers=cfg.dataset.num_workers)

test_dataloader = torch.utils.data.DataLoader(test_set,
                                              batch_size=cfg.dataset.batch_size,
                                              shuffle=True,
                                              num_workers=cfg.dataset.num_workers)

optimizer = optim.Adam(params, lr=cfg.optim.learning_rate)


In [None]:
start = time.time()
writer2 = SummaryWriter(cfg.tensorboard.logs_path)
for epoch in range(cfg.param.num_epochs):
    model.train()

    total_loss = 0

    for i, sample in tqdm(enumerate(train_dataloader, 1),
                          total=len(train_dataloader)):
        
        if i < cfg.param.warmup_steps:
            #学習率のWarmup
            learning_rate = cfg.optim.learning_rate * (i / cfg.param.warmup_steps)
        else:
            # ウォームアップ後の学習率
            learning_rate = cfg.optim.learning_rate
        # 指数的な学習率減衰
        learning_rate = learning_rate * (cfg.param.decay_rate ** (
            i / cfg.param.decay_steps))

        optimizer.param_groups[0]['lr'] = learning_rate
        image = sample['image'].to(device)
        recon_combined, recons, masks, slots = model(image)
        loss = criterion(recon_combined, image)
        total_loss += loss.item()

        del recons, masks, slots

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss /= len(train_dataloader)

    print("Epoch: {}, Loss: {}, Time: {}". format(epoch, total_loss, datetime.timedelta(seconds=time. time() - start)))
    writer2.add_scalar("loss", total_loss, epoch)
    writer2.add_scalar("lr", learning_rate, epoch)
    if not epoch % 10:
        torch.save({
            'model_state_dict': model.state_dict(),
            }, cfg.model.model_dir)
        writer = SummaryWriter(cfg.tensorboard.img_path+str(epoch))
        dataiter = iter(test_dataloader)
        images = next(dataiter)
        model.eval()
        image, recon_combined, recons, masks, slots = get_prediction(model, images)
        num_slots = len(masks)
        image = torchvision.utils.make_grid(image.to("cpu"), nrow=1)
        writer.add_image('Slot_images', image)
        writer.add_image('Recon', recon_combined.to("cpu").detach().numpy())
        for i in range(num_slots):
            s_m = np.clip((recons[i].to("cpu").detach().numpy()*masks[i].to("cpu").detach().numpy())+(1-masks[i].to("cpu").detach().numpy()),0, 1)
            writer.add_image('Slot '+str(i+1), torch.from_numpy(s_m.astype(np.float32)).permute(2,0,1))

100%|███████████████████████████████████████| 4375/4375 [10:48<00:00,  6.75it/s]


Epoch: 0, Loss: 1.2774966580308282e-06, Time: 0:10:48.686520


100%|███████████████████████████████████████| 4375/4375 [10:46<00:00,  6.76it/s]

Epoch: 1, Loss: 9.516004337700822e-07, Time: 0:21:37.495470



100%|███████████████████████████████████████| 4375/4375 [10:46<00:00,  6.76it/s]

Epoch: 2, Loss: 5.249217398295097e-07, Time: 0:32:24.558441



100%|███████████████████████████████████████| 4375/4375 [10:46<00:00,  6.77it/s]

Epoch: 3, Loss: 5.466615756732831e-07, Time: 0:43:11.288498



100%|███████████████████████████████████████| 4375/4375 [10:46<00:00,  6.77it/s]

Epoch: 4, Loss: 5.397559888750917e-07, Time: 0:53:57.825908



100%|███████████████████████████████████████| 4375/4375 [10:46<00:00,  6.77it/s]

Epoch: 5, Loss: 5.518607120426999e-07, Time: 1:04:44.190604



100%|███████████████████████████████████████| 4375/4375 [10:45<00:00,  6.77it/s]

Epoch: 6, Loss: 4.423916929122698e-07, Time: 1:15:30.270784



100%|███████████████████████████████████████| 4375/4375 [10:45<00:00,  6.77it/s]

Epoch: 7, Loss: 4.444095597799197e-07, Time: 1:26:16.392487



100%|███████████████████████████████████████| 4375/4375 [10:45<00:00,  6.78it/s]

Epoch: 8, Loss: 5.799007680411204e-07, Time: 1:37:02.212963



100%|███████████████████████████████████████| 4375/4375 [10:45<00:00,  6.78it/s]

Epoch: 9, Loss: 3.7857332921934764e-07, Time: 1:47:47.920225



100%|███████████████████████████████████████| 4375/4375 [10:45<00:00,  6.78it/s]


Epoch: 10, Loss: 3.056296122111416e-07, Time: 1:58:33.551721


100%|███████████████████████████████████████| 4375/4375 [10:45<00:00,  6.78it/s]

Epoch: 11, Loss: 2.844614772873598e-07, Time: 2:09:21.154374



100%|███████████████████████████████████████| 4375/4375 [10:45<00:00,  6.78it/s]

Epoch: 12, Loss: 4.2413528275932784e-07, Time: 2:20:06.660335



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.78it/s]

Epoch: 13, Loss: 3.214089901404727e-07, Time: 2:30:51.826728



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.78it/s]

Epoch: 14, Loss: 3.7804890757278224e-07, Time: 2:41:36.994688



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 15, Loss: 4.700899743020361e-07, Time: 2:52:21.881604



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 16, Loss: 2.9751378246691106e-07, Time: 3:03:06.606550



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 17, Loss: 3.086757526069257e-07, Time: 3:13:51.395877



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 18, Loss: 3.3068326709516234e-07, Time: 3:24:36.005747



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 19, Loss: 3.5583861994792723e-07, Time: 3:35:20.438764



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]


Epoch: 20, Loss: 2.7393936063200176e-07, Time: 3:46:04.858970


100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 21, Loss: 2.6099202087207027e-07, Time: 3:56:51.914613



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 22, Loss: 2.0516937880931438e-07, Time: 4:07:36.232115



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 23, Loss: 2.649304846577937e-07, Time: 4:18:20.601394



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 24, Loss: 2.828700889703788e-07, Time: 4:29:04.951268



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 25, Loss: 3.239697686522734e-07, Time: 4:39:49.311702



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 26, Loss: 3.0232180588289146e-07, Time: 4:50:33.293007



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.79it/s]

Epoch: 27, Loss: 3.0003308343202855e-07, Time: 5:01:17.404605



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.79it/s]

Epoch: 28, Loss: 2.562719231768952e-07, Time: 5:12:01.567157



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 29, Loss: 2.5732854770411785e-07, Time: 5:22:45.811787



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]


Epoch: 30, Loss: 4.899761419349347e-07, Time: 5:33:30.109163


100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.79it/s]

Epoch: 31, Loss: 2.3479893823609481e-07, Time: 5:44:17.155084



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 32, Loss: 2.7732936493875103e-07, Time: 5:55:01.128663



100%|███████████████████████████████████████| 4375/4375 [10:44<00:00,  6.79it/s]

Epoch: 33, Loss: 2.6972498084242855e-07, Time: 6:05:45.357135



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 34, Loss: 2.5916665802815454e-07, Time: 6:16:29.209535



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 35, Loss: 2.1639592059107857e-07, Time: 6:27:13.077522



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 36, Loss: 2.3728171605973961e-07, Time: 6:37:57.169584



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 37, Loss: 2.552734480899691e-07, Time: 6:48:41.119636



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 38, Loss: 2.4650919165424733e-07, Time: 6:59:25.067071



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.79it/s]

Epoch: 39, Loss: 2.523721033884323e-07, Time: 7:10:09.139818



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]


Epoch: 40, Loss: 2.413705793540796e-07, Time: 7:20:52.933437


100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 41, Loss: 2.484467404650256e-07, Time: 7:31:38.461882



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 42, Loss: 3.134096648830857e-07, Time: 7:42:22.056506



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 43, Loss: 2.6585907963441536e-07, Time: 7:53:05.802538



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 44, Loss: 2.3720468019257285e-07, Time: 8:03:49.357177



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 45, Loss: 3.2020264290799346e-07, Time: 8:14:33.063773



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 46, Loss: 2.1204647486023024e-07, Time: 8:25:16.660650



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 47, Loss: 2.526607452747274e-07, Time: 8:36:00.323829



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 48, Loss: 3.884043984830495e-07, Time: 8:46:43.899647



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 49, Loss: 2.902341982953537e-07, Time: 8:57:27.517787



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]


Epoch: 50, Loss: 2.4314053156243397e-07, Time: 9:08:11.064798


100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 51, Loss: 1.926954810335253e-07, Time: 9:18:56.954073



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 52, Loss: 1.7586543931354388e-07, Time: 9:29:40.410250



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 53, Loss: 2.350906565695793e-07, Time: 9:40:23.826104



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 54, Loss: 3.297301566749191e-07, Time: 9:51:07.293274



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 55, Loss: 2.6083336203851724e-07, Time: 10:01:50.803868



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 56, Loss: 1.8251504280510247e-07, Time: 10:12:34.404357



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 57, Loss: 2.074412583013273e-07, Time: 10:23:17.925549



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 58, Loss: 2.8835624937269357e-07, Time: 10:34:01.515049



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 59, Loss: 2.600143143330045e-07, Time: 10:44:45.117117



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]


Epoch: 60, Loss: 3.4599724095823715e-07, Time: 10:55:28.470087


100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 61, Loss: 2.5595016235208567e-07, Time: 11:06:14.200975



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 62, Loss: 2.2788705842076137e-07, Time: 11:16:57.722028



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 63, Loss: 1.9391193849088387e-07, Time: 11:27:41.236904



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 64, Loss: 2.2675515996121372e-07, Time: 11:38:24.690821



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 65, Loss: 3.0038297294284034e-07, Time: 11:49:07.960386



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 66, Loss: 2.556795414636817e-07, Time: 11:59:51.386515



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 67, Loss: 2.7940739916188205e-07, Time: 12:10:34.828976



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 68, Loss: 2.4545200610543013e-07, Time: 12:21:18.337697



100%|███████████████████████████████████████| 4375/4375 [10:43<00:00,  6.80it/s]

Epoch: 69, Loss: 2.1377653620490035e-07, Time: 12:32:01.723006



 50%|███████████████████▌                   | 2199/4375 [05:23<05:20,  6.78it/s]