In [1]:
import os, sys
from tools.config import Config, read_conf
from train import generator
# from tools.dataset import NibDataset

In [2]:
import os

from tqdm import tqdm
import numpy as np
import itertools

from torchvision.utils import save_image
from torch.autograd import Variable

import torch.nn as nn
import torch

from torch.utils.tensorboard import SummaryWriter

In [3]:
config = read_conf(f'/root/cAAE/config/config_test.json')
config.train.batch_size = 2
dataset, img_shape = generator(config)

In [4]:
config.transforms += {'img_shape': img_shape}
config


_______
'path': '/root/HCP'
'result': '/root/result'
'local': 'T2w_restore'
'slice': ''
'cuda': True
'transforms': 
_______
'norm': True
'resize': True
'img_size': 32
'to_tensor': True
'img_shape': [32
32
311]
_______
'train': 
_______
'batch_size': 2
'latent_dim': 128
'n_epochs': 200
'lr': 0.0002
'b1': 0.5
'b2': 0.999
'sample_interval': 400
'max_batch': 42
_______
'root': '/root/cAAE/model'
_______

In [5]:
img_shape

[32, 32, 311]

In [6]:
class AAE:
    def __init__(self, config):
        self.config = config.train
        self.output = config.result
        #TODO: а если я не буду ресайзить?
        self.img_shape = config.transforms.img_shape
        self.img_shape[2] *= config.train.batch_size
        self.cuda = config.cuda and torch.cuda.is_available()
        print(f'\033[3{2 if self.cuda else 1}m[Cuda: {self.cuda}]\033[0m') 
        self.Tensor = torch.cuda.FloatTensor if self.cuda else torch.FloatTensor
        self.config += {'Tensor': self.Tensor}

        # Use binary cross-entropy loss
        self.adversarial_loss = torch.nn.BCELoss()
        self.pixelwise_loss = torch.nn.L1Loss()

#         # Initialize generator and discriminator
        self.encoder = Encoder(self.config, self.img_shape)
        self.decoder = Decoder(self.config, self.img_shape)
        self.discriminator = Discriminator(self.config, self.img_shape)

        if self.cuda:
            self.encoder.cuda()
            self.decoder.cuda()
            self.discriminator.cuda()
            self.adversarial_loss.cuda()
            self.pixelwise_loss.cuda()
        
        # Optimizers
        self.optimizer_G = torch.optim.Adam(
            itertools.chain(
                self.encoder.parameters(), 
                self.decoder.parameters()
            ), 
            lr=self.config.lr, 
            betas=(self.config.b1, self.config.b2)
        )
        self.optimizer_D = torch.optim.Adam(
            self.discriminator.parameters(), 
            lr=self.config.lr, 
            betas=(self.config.b1, self.config.b2))
        
        #tensorboard callback
        self.writer = SummaryWriter(os.path.join(self.output, 'log'))

    def __repr__(self):
        return f'cuda: {self.cuda}\n' + \
            f'config: {self.config}'
    
    def __str__(self):
        return f'{self.__repr__()}\n' + \
            f'{self.encoder}\n{self.decoder}\n{self.discriminator}'
    #check
    def sample_image(self, n_row, batches_done):
#         assert False, 'check this'
        """Saves a grid of generated digits"""
        # Sample noise
        z = Variable(self.Tensor(np.random.normal(0, 1, (n_row ** 2, self.config.latent_dim))))
        gen_imgs = self.decoder(z)
        save_image(gen_imgs.data, os.path.join(self.output, f"{batches_done}.png"), nrow=n_row, normalize=True)

    def train(self, dataloader):
        for epoch in tqdm(range(self.config.n_epochs), total=self.config.n_epochs, desc='Epoch', leave=True):
#             self.running_loss_g = 0
#             self.running_loss_d = 0
            for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Bath'):
                if i >= self.config.max_batch:
                    break
                imgs = batch.permute(0, 3, 1, 2).reshape(-1, self.img_shape[0], self.img_shape[1])
                # Adversarial ground truths
                valid = Variable(self.Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
                fake = Variable(self.Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

                # Configure input
                real_imgs = Variable(imgs.type(self.Tensor))
                
                # -----------------
                #  Train Generator
                # -----------------

                self.optimizer_G.zero_grad()

                encoded_imgs = self.encoder(real_imgs)
                decoded_imgs = self.decoder(encoded_imgs)

                # Loss measures generator's ability to fool the discriminator
                g_loss = \
                    0.001 * self.adversarial_loss(self.discriminator(encoded_imgs), valid) + \
                    0.999 * self.pixelwise_loss(decoded_imgs, real_imgs)
                g_loss.backward()
                self.optimizer_G.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------

                self.optimizer_D.zero_grad()

                # Sample noise as discriminator ground truth
                z = Variable(self.Tensor(np.random.normal(0, 1, (imgs.shape[0], self.config.latent_dim))))

                # Measure discriminator's ability to classify real from generated samples
                real_loss = self.adversarial_loss(self.discriminator(z), valid)
                fake_loss = self.adversarial_loss(self.discriminator(encoded_imgs.detach()), fake)
                d_loss = 0.5 * (real_loss + fake_loss)

                d_loss.backward()
                self.optimizer_D.step()

                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, self.config.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
                )

#                 batches_done = epoch * len(dataloader) + i
#                 if batches_done % self.config.sample_interval == 0:
#                     self.sample_image(n_row=10, batches_done=batches_done)
    
class Encoder(nn.Module):
    def __init__(self, config, img_shape):
        super(Encoder, self).__init__()
        self.config = config
        self.img_shape = img_shape[:2]

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.mu = nn.Linear(512, self.config.latent_dim)
        self.logvar = nn.Linear(512, self.config.latent_dim)

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        x = self.model(img_flat)
        mu = self.mu(x)
        logvar = self.logvar(x)
        z = self.reparameterization(mu, logvar)
        return z
    
    def reparameterization(self, mu, logvar):
        std = torch.exp(logvar / 2)
        sampled_z = Variable(self.config.Tensor(np.random.normal(0, 1, (mu.size(0), self.config.latent_dim))))
        z = sampled_z * std + mu
        return z


class Decoder(nn.Module):
    def __init__(self, config, img_shape):
        super(Decoder, self).__init__()
        self.config = config
        self.img_shape = img_shape[:2]

        self.model = nn.Sequential(
            nn.Linear(self.config.latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img_flat = self.model(z)
        img = img_flat.view(img_flat.shape[0], *self.img_shape)
        return img
    
class Discriminator(nn.Module):
    def __init__(self,  config, img_shape):
        super(Discriminator, self).__init__()
        self.config = config
        self.img_shape = img_shape[:2]
    
        self.model = nn.Sequential(
            nn.Linear(self.config.latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        validity = self.model(z)
        return validity


In [7]:
model = AAE(config)
model

[32m[Cuda: True][0m


cuda: True
config: 
_______
'batch_size': 2
'latent_dim': 128
'n_epochs': 200
'lr': 0.0002
'b1': 0.5
'b2': 0.999
'sample_interval': 400
'max_batch': 42
'Tensor': <class 'torch.cuda.FloatTensor'>
_______

In [8]:
model.train(dataset)

Epoch:   0%|          | 0/200 [00:00<?, ?it/s]
Bath:   0%|          | 0/500 [00:00<?, ?it/s][A
Bath:   0%|          | 1/500 [00:05<47:42,  5.74s/it][A

[Epoch 0/200] [Batch 0/500] [D loss: 0.690247] [G loss: 0.326890]



Bath:   0%|          | 2/500 [00:09<43:32,  5.25s/it][A

[Epoch 0/200] [Batch 1/500] [D loss: 0.682596] [G loss: 0.319934]



Bath:   1%|          | 3/500 [00:14<42:47,  5.17s/it][A

[Epoch 0/200] [Batch 2/500] [D loss: 0.674826] [G loss: 0.305022]



Bath:   1%|          | 4/500 [00:18<39:54,  4.83s/it][A

[Epoch 0/200] [Batch 3/500] [D loss: 0.665154] [G loss: 0.292123]



Bath:   1%|          | 5/500 [00:23<38:16,  4.64s/it][A

[Epoch 0/200] [Batch 4/500] [D loss: 0.658450] [G loss: 0.284579]



Bath:   1%|          | 6/500 [00:27<36:54,  4.48s/it][A

[Epoch 0/200] [Batch 5/500] [D loss: 0.650679] [G loss: 0.274603]



Bath:   1%|▏         | 7/500 [00:31<36:14,  4.41s/it][A

[Epoch 0/200] [Batch 6/500] [D loss: 0.642958] [G loss: 0.258813]



Bath:   2%|▏         | 8/500 [00:35<35:51,  4.37s/it][A

[Epoch 0/200] [Batch 7/500] [D loss: 0.629487] [G loss: 0.251455]



Bath:   2%|▏         | 9/500 [00:39<35:06,  4.29s/it][A

[Epoch 0/200] [Batch 8/500] [D loss: 0.612848] [G loss: 0.240373]



Bath:   2%|▏         | 10/500 [00:44<34:53,  4.27s/it][A

[Epoch 0/200] [Batch 9/500] [D loss: 0.602617] [G loss: 0.224631]



Bath:   2%|▏         | 11/500 [00:48<34:56,  4.29s/it][A

[Epoch 0/200] [Batch 10/500] [D loss: 0.582546] [G loss: 0.209981]



Bath:   2%|▏         | 12/500 [00:53<36:10,  4.45s/it][A

[Epoch 0/200] [Batch 11/500] [D loss: 0.558536] [G loss: 0.189801]



Bath:   3%|▎         | 13/500 [00:57<35:33,  4.38s/it][A

[Epoch 0/200] [Batch 12/500] [D loss: 0.543964] [G loss: 0.173560]



Bath:   3%|▎         | 14/500 [01:01<35:01,  4.32s/it][A

[Epoch 0/200] [Batch 13/500] [D loss: 0.521954] [G loss: 0.159954]



Bath:   3%|▎         | 15/500 [01:05<34:34,  4.28s/it][A

[Epoch 0/200] [Batch 14/500] [D loss: 0.491670] [G loss: 0.138815]



Bath:   3%|▎         | 16/500 [01:10<34:40,  4.30s/it][A

[Epoch 0/200] [Batch 15/500] [D loss: 0.470431] [G loss: 0.130059]



Bath:   3%|▎         | 17/500 [01:14<34:34,  4.29s/it][A

[Epoch 0/200] [Batch 16/500] [D loss: 0.442813] [G loss: 0.118261]



Bath:   4%|▎         | 18/500 [01:19<37:02,  4.61s/it][A

[Epoch 0/200] [Batch 17/500] [D loss: 0.422669] [G loss: 0.107615]



Bath:   4%|▍         | 19/500 [01:24<36:13,  4.52s/it][A

[Epoch 0/200] [Batch 18/500] [D loss: 0.385914] [G loss: 0.100076]


KeyboardInterrupt: 