In [2]:
%%writefile ddp.py

import warnings
warnings.filterwarnings('ignore')

import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
import PIL
from tqdm import tqdm
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.nn.utils.spectral_norm as spectral_norm
import torch.nn.functional as F
import functools
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.utils import save_image


# Image geometry: Split reshapes every sample to (n_channels, img_size, img_size).
img_size = 256
n_channels = 1

# Length of the latent vector fed to the generator, and the batch size used by
# each DataLoader (under DDP every process builds its own loader with this size).
latent_size = 128
batch_size = 9

# When True, fit() re-seeds the distributed sampler every epoch and the
# single-process DataLoader below is not built.
is_parallel = True

# Base channel width passed as step_channels to Generator and Discriminator.
step_conv_channels=32

# Module-level default device. NOTE: under DDP each process works on the
# device given by its own rank, which is not necessarily this one.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Adam betas shared by both optimizers created in fit().
beta1 = 0.0 
beta2 = 0.999

gamma = 0.1 # discriminator reg constant
# When True, fit() shrinks the regularization weight as epochs progress.
gamma_decay = False

# Per-network learning rates, looked up by key in fit().
lr={
    'generator': 0.000025,
    'discriminator': 0.000025
}


# NOTE: num_workers is only consumed by the non-parallel DataLoader; main()
# builds the parallel loader without passing it, and world_size is assigned
# again right before the entry point at the bottom of the script.
if is_parallel:
    num_workers = 0
    world_size = 2
else:
    num_workers = 4

# Root folder handed to ImageFolder ('...' is a placeholder to fill in).
DATA_PATH = '...'
# Epoch to resume from; fit() loads the checkpoints below only when this is non-zero.
EPOCH_START = 0
LOAD_FILENAME_PATH_GENERATOR = ('weights/generator_epoch_%d.pth' % EPOCH_START)
LOAD_FILENAME_PATH_DISCRIMINATOR = ('weights/discriminator_epoch_%d.pth' % EPOCH_START)

#print('Learning rate:', lr)
#print('device:',device)
#print('device count:', torch.cuda.device_count())

class Split(object):
    """Transform that keeps a single channel of an image tensor.

    Takes channel index 1 of the incoming tensor, reshapes it to
    (n_channels, img_size, img_size) and passes it through
    ``transforms.Grayscale`` configured for ``n_channels`` output channels.
    """

    def __call__(self, image):
        to_gray = transforms.Grayscale(num_output_channels=n_channels)
        # The view fails unless the selected channel holds exactly
        # n_channels * img_size * img_size values.
        single_channel = image[1, :, :].view(n_channels, img_size, img_size)
        return to_gray(single_channel)

# Training set: every image under DATA_PATH is resized, converted to a tensor,
# normalized per channel with mean 0.5 / std 0.5, and finally reduced to a
# single channel by Split.
# NOTE(review): Split reshapes to exactly img_size x img_size, so this
# presumably expects square source images — confirm against the data.
dataset = ImageFolder(DATA_PATH, transform=transforms.Compose([
        transforms.Resize(img_size,interpolation=transforms.InterpolationMode.BICUBIC),
        #transforms.RandomHorizontalFlip(p=0.5),
        #transforms.Resize(upsample_transform, interpolation=transforms.InterpolationMode.BICUBIC),
        #transforms.RandomCrop((IMG_WIDTH,IMG_HEIGHT)),
        transforms.ToTensor(),
        # Three means/stds are given, so three input channels are assumed here.
        transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)),
        Split()
        ]))

# for parallelism
from torch.utils.data.distributed import DistributedSampler

def prepare_dataloader_for_paralellism(rank, world_size, batch_size, pin_memory = False, num_workers = 0):
    """Build the DataLoader used by one distributed training process.

    The module-level ``dataset`` is wrapped in a ``DistributedSampler`` created
    with ``num_replicas=world_size`` and ``rank=rank``. Shuffling is delegated
    to the sampler, so the loader itself is created with ``shuffle=False``.

    Args:
        rank: index of the calling process.
        world_size: total number of processes.
        batch_size: batch size of this process's loader.
        pin_memory: forwarded to ``DataLoader``.
        num_workers: forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    shard_sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        drop_last=False,
    )
    loader_options = dict(
        batch_size=batch_size,
        pin_memory=pin_memory,
        num_workers=num_workers,
        drop_last=False,
        shuffle=False,
        sampler=shard_sampler,
    )
    return DataLoader(dataset, **loader_options)

# Single-process training: build one shuffled loader at import time.
# In parallel mode each process builds its own loader inside main() instead.
if not is_parallel:
    dataloader = DataLoader(dataset, batch_size, shuffle=True,num_workers=num_workers)

    
class Self_Attention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Queries and keys are produced by 1x1 convolutions with ``in_dim // 8``
    output channels, values by a 1x1 convolution with ``in_dim`` channels.
    The attended features are scaled by the learnable scalar ``gamma``
    (initialized to zero) and added back to the input.
    """
    def __init__(self,in_dim):
        super().__init__()
        self.chanel_in = in_dim

        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self,x):
        """
            inputs :
                x : input feature maps (B x C x W x H)
            returns :
                out : gamma * attended features + input, same shape as x
                      (the attention map itself is not returned)
        """
        batch, channels, width, height = x.size()
        positions = width * height

        queries = self.query_conv(x).view(batch, -1, positions).transpose(1, 2)  # B x N x C'
        keys = self.key_conv(x).view(batch, -1, positions)  # B x C' x N
        attention = self.softmax(torch.bmm(queries, keys))  # B x N x N, rows sum to 1
        values = self.value_conv(x).view(batch, -1, positions)  # B x C x N

        attended = torch.bmm(values, attention.transpose(1, 2))
        attended = attended.view(batch, channels, width, height)

        return self.gamma * attended + x
    
class GBlock(nn.Module):
    """Residual bottleneck block used by the generator.

    The main path is BN-ReLU-conv repeated four times (1x1, 3x3, 3x3, 1x1)
    with ``in_channels // channel_ratio`` hidden channels. The skip path keeps
    the first ``out_channels`` channels of the input when the channel counts
    differ. If ``upsample`` is given it is applied to both paths.
    """
    def __init__(self, in_channels, out_channels, upsample=None, channel_ratio=4):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        bottleneck = in_channels // channel_ratio
        self.upsample = upsample

        self.activation = nn.ReLU()

        self.conv1 = spectral_norm(nn.Conv2d(in_channels, bottleneck, kernel_size=1, padding=0))
        self.conv2 = spectral_norm(nn.Conv2d(bottleneck, bottleneck, kernel_size=3, padding=1))
        self.conv3 = spectral_norm(nn.Conv2d(bottleneck, bottleneck, kernel_size=3, padding=1))
        self.conv4 = spectral_norm(nn.Conv2d(bottleneck, out_channels, kernel_size=1, padding=0))

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(bottleneck)
        self.bn3 = nn.BatchNorm2d(bottleneck)
        self.bn4 = nn.BatchNorm2d(bottleneck)

    def forward(self, x):
        """Return skip(x) + residual(x); spatial size grows only via ``upsample``."""
        act = self.activation

        residual = self.conv1(act(self.bn1(x)))
        residual = act(self.bn2(residual))

        # Skip connection: drop surplus channels instead of projecting them.
        skip = x
        if self.in_channels != self.out_channels:
            skip = skip[:, :self.out_channels]

        if self.upsample:
            residual = self.upsample(residual)
            skip = self.upsample(skip)

        residual = self.conv2(residual)
        residual = self.conv3(act(self.bn3(residual)))
        residual = self.conv4(act(self.bn4(residual)))

        return skip + residual
    
class Generator(nn.Module):
    """Generator mapping a latent vector to a single-channel image in [-1, 1].

    A spectrally normalized linear layer produces a 4x4 map with
    ``16 * step_channels`` channels. Thirteen stages follow: twelve GBlocks
    (six of which double the resolution) and one Self_Attention layer placed
    after the fourth upsampling. BN-ReLU-conv-tanh yields the final image.
    """
    def __init__(self, encoding_dims = 128, step_channels = 128):
        super().__init__()

        self.linear = spectral_norm(nn.Linear(encoding_dims, 4 * 4 * 16 * step_channels))

        # (input width, output width, doubles resolution?) in units of
        # step_channels; None marks the position of the attention layer.
        layout = (
            (16, 16, False),
            (16, 16, True),
            (16, 16, False),
            (16, 8, True),
            (8, 8, False),
            (8, 8, True),
            (8, 8, False),
            (8, 4, True),
            None,
            (4, 4, False),
            (4, 2, True),
            (2, 2, False),
            (2, 1, True),
        )
        stages = []
        for spec in layout:
            if spec is None:
                stages.append(Self_Attention(4 * step_channels))
                continue
            width_in, width_out, doubles = spec
            upsampler = functools.partial(F.interpolate, scale_factor=2) if doubles else None
            stages.append(GBlock(width_in * step_channels, width_out * step_channels, upsample = upsampler))
        self.blocks = nn.Sequential(*stages)

        self.output_layer = nn.Sequential(
            nn.BatchNorm2d(step_channels),
            nn.ReLU(),
            spectral_norm(nn.Conv2d(step_channels, 1, kernel_size = 3, padding = 1)),
            nn.Tanh()
        )

    def forward(self, z):
        """Flatten ``z`` per sample, project to 4x4 and decode to an image."""
        flat = z.view(z.size(0), -1)
        seed = self.linear(flat)
        seed = seed.view(seed.size(0), -1, 4, 4)
        return self.output_layer(self.blocks(seed))
    
class DBlock(nn.Module):
    """Residual bottleneck block used by the discriminator.

    The main path is ReLU-conv repeated four times (1x1, 3x3, 3x3, 1x1) with
    ``out_channels // channel_ratio`` hidden channels; ``downsample`` (if any)
    is applied before the last convolution. When the channel counts differ,
    the shortcut concatenates the input with a 1x1 convolution that supplies
    the ``out_channels - in_channels`` missing channels.
    """
    def __init__(self, in_channels, out_channels, downsample = None, channel_ratio = 4):
        super().__init__()

        bottleneck = out_channels // channel_ratio

        self.downsample = downsample
        self.activation = nn.ReLU()

        self.conv1 = spectral_norm(nn.Conv2d(in_channels, bottleneck, kernel_size=1, padding=0))
        self.conv2 = spectral_norm(nn.Conv2d(bottleneck, bottleneck, kernel_size=3, padding=1))
        self.conv3 = spectral_norm(nn.Conv2d(bottleneck, bottleneck, kernel_size=3, padding=1))
        self.conv4 = spectral_norm(nn.Conv2d(bottleneck, out_channels, kernel_size=1, padding=0))

        self.learnable_sc = in_channels != out_channels
        if self.learnable_sc:
            extra_channels = out_channels - in_channels
            self.conv_sc = spectral_norm(nn.Conv2d(in_channels, extra_channels, kernel_size=1, padding=0))

    def shortcut(self, x):
        """Skip path: optional downsampling, then optional channel extension."""
        skip = self.downsample(x) if self.downsample else x
        if not self.learnable_sc:
            return skip
        return torch.cat([skip, self.conv_sc(skip)], 1)

    def forward(self, x):
        """Return residual(x) + shortcut(x)."""
        residual = x
        for conv in (self.conv1, self.conv2, self.conv3):
            residual = conv(self.activation(residual))
        residual = self.activation(residual)

        if self.downsample:
            residual = self.downsample(residual)

        residual = self.conv4(residual)

        return residual + self.shortcut(x)
    
class Discriminator(nn.Module):
    """Discriminator scoring single-channel images with one unbounded logit.

    A spectrally normalized 3x3 convolution lifts the image to
    ``step_channels`` channels. Thirteen stages follow: twelve DBlocks (six of
    which halve the resolution with 2x2 average pooling) and one
    Self_Attention layer after the fourth block. The result is passed through
    ReLU, summed over both spatial axes and mapped to a scalar by a linear layer.
    """
    def __init__(self, step_channels):
        super().__init__()

        self.input_conv = spectral_norm(nn.Conv2d(1, step_channels, kernel_size=3, padding = 1))
        self.activation = nn.ReLU()

        # (input width, output width, halves resolution?) in units of
        # step_channels; None marks the position of the attention layer.
        layout = (
            (1, 2, True),
            (2, 2, False),
            (2, 4, True),
            (4, 4, False),
            None,
            (4, 8, True),
            (8, 8, False),
            (8, 8, True),
            (8, 8, False),
            (8, 16, True),
            (16, 16, False),
            (16, 16, True),
            (16, 16, False),
        )
        stages = []
        for spec in layout:
            if spec is None:
                stages.append(Self_Attention(4 * step_channels))
                continue
            width_in, width_out, halves = spec
            pooling = nn.AvgPool2d(2) if halves else None
            stages.append(DBlock(width_in * step_channels, width_out * step_channels, downsample = pooling))
        self.blocks = nn.Sequential(*stages)

        self.linear = nn.Linear(16 * step_channels, 1)

    def forward(self, x):
        """Return one logit per sample, shape (batch, 1)."""
        features = self.blocks(self.input_conv(x))
        pooled = torch.sum(self.activation(features), [2, 3])
        return self.linear(pooled)
    
def discriminator_regularizer(D1_logits, D1_input, D2_logits, D2_input):
    """Gradient-norm penalty on the discriminator logits.

    For each sample the squared L2 norm of d(logits)/d(input) is weighted by
    ``(1 - sigmoid(D1_logits))**2`` for the first pair and by
    ``sigmoid(D2_logits)**2`` for the second pair; the two terms are added and
    averaged. Gradients are taken with ``create_graph=True`` so the penalty
    itself can be back-propagated.

    Args:
        D1_logits, D1_input: logits and the input they were computed from
            (fit() passes the real batch here).
        D2_logits, D2_input: same for the second batch (the generated one in fit()).

    Returns:
        Scalar tensor holding the mean penalty.
    """
    def per_sample_grad_norm(logits, inputs):
        # One norm per sample, kept as a column so it broadcasts against the logits.
        grads = torch.autograd.grad(logits, inputs, torch.ones_like(logits),
                                    create_graph=True, retain_graph=True,
                                    only_inputs=True, allow_unused=True)[0]
        return torch.norm(grads.view(inputs.shape[0], -1), dim=1, keepdim=True)

    prob_first = torch.sigmoid(D1_logits)
    prob_second = torch.sigmoid(D2_logits)
    norm_first = per_sample_grad_norm(D1_logits, D1_input)
    norm_second = per_sample_grad_norm(D2_logits, D2_input)

    penalty_first = (1.0 - prob_first) ** 2 * norm_first ** 2
    penalty_second = prob_second ** 2 * norm_second ** 2

    return torch.mean(penalty_first + penalty_second)

def save_samples(samples, epoch):
    """Write every image of ``samples`` to ``images/images_<epoch>_epochs/<i>.png``.

    Args:
        samples: batch of images; each one is shifted from [-1, 1] to [0, 1]
            via ``(x + 1) / 2`` before being saved.
        epoch: number used in the output directory name.
    """
    print('Saving samples')
    out_dir = 'images/images_{}_epochs'.format(epoch)
    # Creates 'images' and the per-epoch folder in one call and tolerates
    # directories that already exist (no check-then-create race).
    os.makedirs(out_dir, exist_ok=True)
    for i in range(samples.shape[0]):
        save_image((samples[i]+1)/2. , '{}/{}.png'.format(out_dir, i))
    print('Saving samples complete')

def save_logs(logs, epoch=0):
    """Dump the eight training-log series to text files under ``logs/``.

    Args:
        logs: sequence of eight iterables, in the order generator loss per
            epoch / per batch, discriminator loss per epoch / per batch, real
            score per epoch / per batch, fake score per epoch / per batch.
        epoch: number embedded in every file name.

    Each file is named ``logs/<series>_<epoch>_epoch.txt`` and holds the values
    separated (and terminated) by a single space.
    """
    txts=['losses_g','loss_g_per_batch', 'losses_d','loss_d_per_batch','real_scores','real_score_per_batch','fake_scores','fake_score_per_batch']
    # The directory is not guaranteed to exist on a fresh checkout; without
    # this the first open() below raises FileNotFoundError.
    os.makedirs('logs', exist_ok=True)
    for i, txt in enumerate(txts):
        with open('logs/'+txt+'_{}_epoch.txt'.format(epoch), 'w') as f:
            f.write(''.join(str(e)+' ' for e in logs[i]))

    print('Logs saved')

def fit(rank, model, dataloader, criterion, epochs, lr, epochs_start=0, uploaded=False):
    """Run the adversarial training loop on the device identified by ``rank``.

    Every batch performs one discriminator update (its criterion plus the
    gradient-norm regularizer weighted by ``gamma_reg / 2``) followed by one
    generator update on a fresh latent batch. After each epoch, rank 0 prints
    the epoch means, writes the logs, saves the last generated batch as
    images and checkpoints both networks.

    Args:
        rank: device index of this process; only rank 0 logs and saves.
        model: dict with the modules under 'generator' and 'discriminator'.
        dataloader: yields ``(images, _)`` batches. When ``is_parallel`` is
            set, its sampler must provide ``set_epoch``.
        criterion: dict with callables ``'generator'(fake_preds)`` and
            ``'discriminator'(real_preds, fake_preds)``.
        epochs: end of the epoch range (exclusive); training runs over
            ``range(epochs_start, epochs)``.
        lr: dict with learning rates under 'generator' and 'discriminator'.
        epochs_start: first epoch index. A non-zero value loads the
            checkpoints named by the LOAD_FILENAME_PATH_* constants.
        uploaded: pass True when the weights are already loaded, to skip the
            checkpoint loading above.

    Returns:
        List of eight lists: generator loss per epoch / per batch,
        discriminator loss per epoch / per batch, real score per epoch / per
        batch, fake score per epoch / per batch. On rank 0 these lists are
        written to disk and cleared after every epoch, so they come back
        empty there.
    """
    if epochs_start!=0 and not uploaded:
        # Remap tensors stored for cuda:0 onto this process's own device.
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        model['discriminator'].load_state_dict(torch.load(LOAD_FILENAME_PATH_DISCRIMINATOR, map_location=map_location))
        model['generator'].load_state_dict(torch.load(LOAD_FILENAME_PATH_GENERATOR, map_location=map_location))
        print('Model uploaded')

    model["discriminator"].train()
    model["generator"].train()
    torch.cuda.empty_cache()

    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    loss_g_per_batch = []
    loss_d_per_batch = []
    real_score_per_batch = []
    fake_score_per_batch = []

    optimizer = {
        "discriminator": torch.optim.Adam(model["discriminator"].parameters(),
                                          lr=lr['discriminator'], betas=(beta1, beta2)),
        "generator": torch.optim.Adam(model["generator"].parameters(),
                                      lr=lr['generator'], betas=(beta1, beta2))
    }

    gamma_reg = gamma

    for epoch in tqdm(range(epochs_start, epochs)):
        loss_d_per_epoch = []
        loss_g_per_epoch = []
        real_score_per_epoch = []
        fake_score_per_epoch = []

        if gamma_decay:
            # Anneal from the base value. Multiplying the previous gamma_reg
            # instead would compound the factor and collapse the weight to
            # ~0 after a few epochs; it would also be wrong when resuming.
            gamma_reg = gamma * 0.01**(epoch/epochs)

        if is_parallel:
            # Re-seed the distributed sampler so each epoch sees a new shuffle.
            dataloader.sampler.set_epoch(epoch)

        for real_images, _ in dataloader:

            # discriminator step
            # The regularizer differentiates w.r.t. the inputs, hence requires_grad_.
            real_images = real_images.to(rank).requires_grad_()
            optimizer["discriminator"].zero_grad()

            # real images to discriminator
            real_preds = model["discriminator"](real_images)

            # generating images; the latent must live on this process's own
            # device (rank), not on the module-level default device
            latent = torch.randn(real_images.size(0), latent_size, 1, 1, device=rank)
            fake_images = model["generator"](latent)

            # generated images to discriminator
            fake_preds = model["discriminator"](fake_images)

            # logs
            cur_fake_score = torch.mean(fake_preds).item()
            cur_real_score = torch.mean(real_preds).item()
            real_score_per_epoch.append(cur_real_score)
            real_score_per_batch.append(cur_real_score)
            fake_score_per_epoch.append(cur_fake_score)
            fake_score_per_batch.append(cur_fake_score)

            # backward pass
            loss_d = criterion['discriminator'](real_preds,fake_preds)
            loss_d += gamma_reg / 2. * discriminator_regularizer(real_preds, real_images, fake_preds, fake_images)
            loss_d.backward()
            optimizer["discriminator"].step()

            # logs
            loss_d_per_epoch.append(loss_d.item())
            loss_d_per_batch.append(loss_d.item())

            # generator step
            optimizer["generator"].zero_grad()

            # generating images
            latent = torch.randn(real_images.size(0), latent_size, 1, 1, device=rank)
            fake_images = model["generator"](latent)

            # generated images to discriminator
            preds = model["discriminator"](fake_images)
            loss_g = criterion["generator"](preds)

            # backward pass
            loss_g.backward()
            optimizer["generator"].step()

            # logs
            loss_g_per_epoch.append(loss_g.item())
            loss_g_per_batch.append(loss_g.item())

        # logs
        losses_g.append(np.mean(loss_g_per_epoch))
        losses_d.append(np.mean(loss_d_per_epoch))
        real_scores.append(np.mean(real_score_per_epoch))
        fake_scores.append(np.mean(fake_score_per_epoch))

        if rank==0:
            print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
                epoch+1, epochs,
                losses_g[-1], losses_d[-1], real_scores[-1], fake_scores[-1]))
            losses = [losses_g, loss_g_per_batch, losses_d, loss_d_per_batch, real_scores, real_score_per_batch, fake_scores, fake_score_per_batch]
            save_logs(losses, epoch = epoch + 1)
            # Each log file covers one epoch only, so start the next from scratch.
            for l in losses:
                l.clear()

            # examples: the last generated batch of the epoch
            save_samples(fake_images, epoch+1)

            # checkpoint after every epoch
            os.makedirs('weights', exist_ok=True)
            torch.save(model['generator'].state_dict(),'weights/generator_epoch_%d.pth' % (epoch+1))
            torch.save(model['discriminator'].state_dict(),'weights/discriminator_epoch_%d.pth' % (epoch+1))
            # `epoch` already starts at epochs_start, so report the same
            # number that is used in the checkpoint file names.
            print('Model Saved! Epoch: %d' % (epoch+1))


    return [losses_g, loss_g_per_batch, losses_d, loss_d_per_batch, real_scores, real_score_per_batch, fake_scores, fake_score_per_batch]
    
def show_samples(model, amount=16):
    """Plot a grid of ``amount`` images sampled from ``model['generator']``.

    The generator is switched to eval mode (and left there), fed standard
    normal latent vectors of length ``latent_size`` on the module-level
    ``device``, and the outputs are drawn as a normalized image grid.
    """
    generator = model['generator']
    generator.eval()
    with torch.no_grad():
        latents = np.array([np.random.normal(0, 1, latent_size) for _ in range(amount)])
        latents = torch.FloatTensor(latents).to(device)
        output = generator(latents)

    plt.figure(figsize=(12, 12))
    plt.axis("off")
    plt.title("Generated Images")
    grid = make_grid(output.to(device), padding=2, normalize=True).cpu()
    # channels-first -> channels-last for imshow
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    
# custom weights initialization applied to the generator and the discriminator
def weights_init(m):
    """Re-initialize ``m`` and all of its submodules in place.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02) and
    BatchNorm2d weights from N(1, 0.02); biases are set to zero. Layers built
    without a bias (``bias=False``) or without affine parameters
    (``affine=False``) are skipped for the missing tensors instead of failing.

    Args:
        m: root module. The function walks the whole tree itself through
            ``m.apply``, so it only needs to be called once on the root.

    Returns:
        ``m`` itself.
    """
    def _weights_init(module):
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(module.weight, 0.0, 0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.BatchNorm2d):
            if module.weight is not None:
                nn.init.normal_(module.weight, 1.0, 0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    return m.apply(_weights_init)

# hinge losses 
def generator_loss(out_fake):
    """Generator hinge loss: the negated mean discriminator output on fakes."""
    mean_score = out_fake.mean()
    return mean_score.neg()
    
def discriminator_loss(out_real,out_fake):
    """Discriminator hinge loss: mean(relu(1 + fake)) + mean(relu(1 - real))."""
    fake_term = torch.relu(1.0 + out_fake).mean()
    real_term = torch.relu(1.0 - out_real).mean()
    return fake_term + real_term

# Distributed Data Parallel Setup

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp

def setup(rank, world_size):
    """Join the NCCL process group as process ``rank`` out of ``world_size``.

    The rendezvous address is fixed to localhost:5554 through the
    MASTER_ADDR / MASTER_PORT environment variables.
    """
    rendezvous = {'MASTER_ADDR': 'localhost', 'MASTER_PORT': '5554'}
    os.environ.update(rendezvous)

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    
def cleanup():
    """Tear down the process group created by setup()."""
    dist.destroy_process_group()
    
def main(rank, world_size):
    """Per-process entry point for distributed training.

    Joins the process group, builds this rank's data loader, creates and
    initializes both networks on device ``rank``, wraps them in
    DistributedDataParallel and trains them with fit().

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes.
    """
    setup(rank, world_size)
    # Bind the process to its own GPU so that device-less CUDA calls (for
    # example torch.cuda.empty_cache() in fit) do not touch device 0.
    torch.cuda.set_device(rank)
    if rank==0:
        print('Setup done, starting preparing dataloader')

    dataloader = prepare_dataloader_for_paralellism(rank, world_size, batch_size)
    if rank==0:
        print('Dataloader prepared')

    generator = Generator(encoding_dims = latent_size, step_channels = step_conv_channels)
    # weights_init already walks the whole module tree; routing it through
    # Module.apply would re-initialize every layer once per ancestor.
    weights_init(generator)
    generator = generator.to(rank)
    # Batch-norm statistics are synchronized across processes.
    generator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(generator)
    generator = DDP(generator, device_ids=[rank], output_device=rank, find_unused_parameters=False)

    discriminator = Discriminator(step_channels = step_conv_channels)
    weights_init(discriminator)
    discriminator = discriminator.to(rank)
    discriminator = DDP(discriminator, device_ids=[rank], output_device=rank, find_unused_parameters=False)
    if rank==0:
        print('weights init done')
    model = {
        'generator' : generator,
        'discriminator' : discriminator
    }
    criterion={
        'generator' : generator_loss,
        'discriminator' : discriminator_loss
    }
    if rank==0:
        print('Model prepared, starting learning')

    epochs = 50

    logs = fit(rank, model, dataloader, criterion, epochs, lr, epochs_start = EPOCH_START)

    if rank==0:
        print('Learning done, epochs: %d' %epochs)

    cleanup()

    #if rank==0:
        #save_logs(logs)
        #show_samples(model, amount=16)
    
# Number of training processes; main() uses each process index as its CUDA device index.
world_size = 2
if __name__ == '__main__':
    # Start world_size processes; each one runs main(process_index, world_size).
    mp.spawn(main, args=(world_size,), nprocs=world_size)

Overwriting ddp.py


In [4]:
!CUDA_VISIBLE_DEVICES=0,2 python ddp.py

Setup done, starting preparing dataloader
Dataloader prepared
weights init done
Model prepared, starting learning
  0%|                                                    | 0/50 [00:00<?, ?it/s]Epoch [1/50], loss_g: 0.1249, loss_d: 1.9304, real_score: 0.0027, fake_score: -0.0938
  2%|▋                                    | 1/50 [1:22:58<67:46:07, 4978.93s/it]Logs saved
Saving samples
Saving samples complete
Model Saved! Epoch: 1
  2%|▋                                    | 1/50 [1:22:59<67:46:18, 4979.16s/it]Epoch [2/50], loss_g: 0.1954, loss_d: 1.9770, real_score: -0.1450, fake_score: -0.1744
Logs saved
Saving samples
  4%|█▍                                   | 2/50 [2:45:33<66:11:43, 4964.66s/it]Saving samples complete
Model Saved! Epoch: 2
  4%|█▍                                   | 2/50 [2:45:33<66:11:48, 4964.75s/it]Epoch [3/50], loss_g: 0.0717, loss_d: 1.9861, real_score: -0.0361, fake_score: -0.0539
  6%|██▏                                  | 3/50 [4:07:24<64:29:48, 4940.19s/it]Lo

Saving samples complete
Model Saved! Epoch: 28
 56%|███████████████████▌               | 28/50 [38:15:08<29:46:38, 4872.64s/it]Epoch [29/50], loss_g: 0.4868, loss_d: 1.4186, real_score: 0.3968, fake_score: -0.3908
 58%|████████████████████▎              | 29/50 [39:36:02<28:23:32, 4867.24s/it]Logs saved
Saving samples
Saving samples complete
Model Saved! Epoch: 29
 58%|████████████████████▎              | 29/50 [39:36:02<28:23:31, 4867.24s/it]Epoch [30/50], loss_g: 0.5403, loss_d: 1.3624, real_score: 0.4398, fake_score: -0.4376
 60%|█████████████████████              | 30/50 [40:57:37<27:05:08, 4875.41s/it]Logs saved
Saving samples
Saving samples complete
Model Saved! Epoch: 30
 60%|█████████████████████              | 30/50 [40:57:37<27:05:08, 4875.41s/it]Epoch [31/50], loss_g: 0.5335, loss_d: 1.3710, real_score: 0.4410, fake_score: -0.4307
 62%|█████████████████████▋             | 31/50 [42:18:27<25:41:29, 4867.89s/it]Logs saved
Saving samples
Saving samples complete
Model Saved! Epo