In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
import os
from PIL import Image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter

from tqdm.auto import tqdm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
class MiniBatchStd(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, x):
        batch_size, num_channels, height, width = x.shape
        avg_std = torch.std(x, dim=0).mean().item()
        feature_map = torch.full((batch_size, 1, height, width), avg_std, device=x.device)
        out = torch.cat([x, feature_map], dim=1)

        return out

model = MiniBatchStd().to(device)
x = torch.randn(8, 32, 16, 16).to(device)
output = model(x)
output.shape

torch.Size([8, 33, 16, 16])

In [4]:
class ConvEqualizedLR(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, is_transpose=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) if not is_transpose else nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)

        fan_in_factor = np.sqrt(in_channels * kernel_size * kernel_size)
        gain_factor = np.sqrt(2)
        self.norm_constant = gain_factor / fan_in_factor

        nn.init.normal_(self.conv.weight.data, 0, 1)
        nn.init.constant_(self.conv.bias.data, 0)

    def forward(self, x):
        out = self.conv(x * self.norm_constant)        

        return out
    
x = torch.randn(8, 3, 16, 16).to(device)
model = ConvEqualizedLR(3, 32, 3, 1, 1).to(device)
output = model(x)
output.shape

torch.Size([8, 32, 16, 16])

In [5]:
class PixelwiseNorm(nn.Module):
    def __init__(self, epsilon=1e-8, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon

    def forward(self, x):
        mean_of_squares = torch.mean(x**2, dim=1, keepdim=True)
        factor = torch.sqrt(mean_of_squares + self.epsilon)
        out = x / factor

        return out
    
model = PixelwiseNorm().to(device)
x = torch.randn(8, 3, 16, 16).to(device)
output = model(x)
output.shape

torch.Size([8, 3, 16, 16])

In [6]:
class Conv3x3Block(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        pn = PixelwiseNorm()
        act = nn.LeakyReLU(negative_slope=0.2)

        self.conv_eq1 = nn.Sequential(
            ConvEqualizedLR(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            pn,
            act
        )
        self.conv_eq2 = nn.Sequential(
            ConvEqualizedLR(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            pn,
            act
        )

    
    def forward(self, x):
        out = self.conv_eq1(x)
        out = self.conv_eq2(out)

        return out

x = torch.randn(8, 512, 8, 8).to(device)
model = Conv3x3Block(512, 512).to(device)
output = model(x)
output.shape

torch.Size([8, 512, 8, 8])

In [7]:
class ToRGBLayer(nn.Module):
    def __init__(self, in_channels, img_channels=3, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.rgb_layer = ConvEqualizedLR(in_channels, img_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        out = self.rgb_layer(x)

        return out
    
x = torch.randn(8, 512, 8, 8).to(device)
model = ToRGBLayer(512).to(device)
output = model(x)
output.shape

torch.Size([8, 3, 8, 8])

In [8]:
class SmoothFadeIn(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.act = nn.Tanh()

    def forward(self, alpha, upsampled_rgb, learnt_upsampled_rgb):
        interpolated_img = (1-alpha)*upsampled_rgb + alpha*learnt_upsampled_rgb
        out = self.act(interpolated_img)

        return out

upsampled_rgb = torch.rand(8, 3, 16, 16).to(device)
learnt_upsampled_rgb = torch.rand(8, 3, 16, 16).to(device)
fadein_layer = SmoothFadeIn().to(device)
for alpha in torch.linspace(0, 1, 5):
    out = fadein_layer(alpha, upsampled_rgb, learnt_upsampled_rgb)
    print(out.shape)
(out == nn.Tanh()(learnt_upsampled_rgb)).all()

torch.Size([8, 3, 16, 16])
torch.Size([8, 3, 16, 16])
torch.Size([8, 3, 16, 16])
torch.Size([8, 3, 16, 16])
torch.Size([8, 3, 16, 16])


tensor(True, device='cuda:0')

In [9]:
class ProGenerator(nn.Module):
    def __init__(self, channel_factors, latent_dim=512, starting_channels=512, img_channels=3, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.smooth_fadein = SmoothFadeIn()

        self.first_block = nn.Sequential(
            ConvEqualizedLR(latent_dim, starting_channels, 4, 1, 0, is_transpose=True),
            nn.LeakyReLU(negative_slope=0.2),

            ConvEqualizedLR(starting_channels, starting_channels, 3, 1, 1),
            PixelwiseNorm(),
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.first_torgb = ToRGBLayer(starting_channels, img_channels)
        self.block_modules, self.rgb_modules = nn.ModuleList([self.first_block]), nn.ModuleList([self.first_torgb])

        for idx in range(len(channel_factors) - 1):
            in_channels, out_channels = int(starting_channels * channel_factors[idx]), int(starting_channels * channel_factors[idx+1])
            self.block_modules.append(Conv3x3Block(in_channels, out_channels))
            self.rgb_modules.append(ToRGBLayer(out_channels, img_channels))


    def forward(self, x, resolution_idx, alpha):
        learnt_upsampled = self.first_block(x)
        
        if resolution_idx == 0:
            learnt_upsampled_rgb = self.first_torgb(learnt_upsampled)
            return learnt_upsampled_rgb
        
        elif resolution_idx > 0:
             
            for idx in range(resolution_idx):
                upsampled = self.upsample(learnt_upsampled)
                learnt_upsampled = self.block_modules[idx+1](upsampled)
                        
            upsampled_rgb = self.rgb_modules[resolution_idx - 1](upsampled)
            learnt_upsampled_rgb = self.rgb_modules[resolution_idx](learnt_upsampled)
            
            learnt_upsampled_rgb = self.smooth_fadein(alpha, upsampled_rgb, learnt_upsampled_rgb)
            
            return learnt_upsampled_rgb

x = torch.rand(8, 512, 1, 1).to(device)
channel_factors = [1, 1, 1, 1, 1/2, 1/4, 1/8]
model = ProGenerator(channel_factors).to(device)
img_res = [4, 8, 16, 32, 64, 128, 256]
resolution_indices = [int(np.log2(res/4)) for res in img_res]
for res_idx in resolution_indices:
    output = model(x, res_idx, 0.5)
    print(output.shape)

torch.Size([8, 3, 4, 4])
torch.Size([8, 3, 8, 8])
torch.Size([8, 3, 16, 16])
torch.Size([8, 3, 32, 32])
torch.Size([8, 3, 64, 64])
torch.Size([8, 3, 128, 128])
torch.Size([8, 3, 256, 256])


In [10]:
class FromRGBLayer(nn.Module):
    def __init__(self, out_channels, img_channels=3, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.rgb_layer = ConvEqualizedLR(img_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.act = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        out = self.rgb_layer(x)
        out = self.act(out)

        return out
    
x = torch.randn(8, 3, 8, 8).to(device)
model = FromRGBLayer(512).to(device)
output = model(x)
output.shape

torch.Size([8, 512, 8, 8])

In [11]:
class ProDiscriminator(nn.Module):
    def __init__(self, channel_factors, ending_channels=512, img_channels=3, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        
        self.downsample = nn.AvgPool2d(2)
        self.smooth_fadein = SmoothFadeIn()

        self.block_modules, self.rgb_modules = nn.ModuleList([]), nn.ModuleList([])

        for idx in range(len(channel_factors) - 1, 0, -1):
            in_channels, out_channels = int(ending_channels * channel_factors[idx]), int(ending_channels * channel_factors[idx-1])
            self.rgb_modules.append(FromRGBLayer(in_channels, img_channels)) 
            self.block_modules.append(Conv3x3Block(in_channels, out_channels))

        self.last_fromrgb = FromRGBLayer(ending_channels, img_channels)
        self.minibatch_std = MiniBatchStd()
        self.last_block = nn.Sequential(
            ConvEqualizedLR(ending_channels+1, ending_channels, kernel_size=3, stride=1, padding=1),
            PixelwiseNorm(),
            nn.LeakyReLU(negative_slope=0.2),

            ConvEqualizedLR(ending_channels, ending_channels, kernel_size=4, stride=1, padding=0),
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.block_modules.append(self.last_block), self.rgb_modules.append(self.last_fromrgb)
        self.fc = nn.Linear(ending_channels*1*1, 1)

    def forward(self, x, resolution_idx, alpha):
        module_index = - resolution_idx - 1
        upsampled = self.rgb_modules[module_index](x)
        
        if resolution_idx > 0:
            learnt_upsampled = self.block_modules[module_index](upsampled)
            learnt_downsampled = self.downsample(learnt_upsampled)
            
            downsampled_rgb = self.downsample(x)
            downsampled = self.rgb_modules[module_index + 1](downsampled_rgb)
            
            learnt_downsampled = self.smooth_fadein(alpha, downsampled, learnt_downsampled)
            
            for idx in range(module_index+1, -1):
                learnt_downsampled = self.block_modules[idx](learnt_downsampled)
                learnt_downsampled = self.downsample(learnt_downsampled)

            added_featuremap = self.minibatch_std(learnt_downsampled)
            learnt_downsampled = self.last_block(added_featuremap)

            flattened = learnt_downsampled.view(x.shape[0], -1)
            out = self.fc(flattened)

            return out
        
        elif resolution_idx == 0:
            added_featuremap = self.minibatch_std(upsampled)
            learnt_downsampled = self.last_block(added_featuremap)

            flattened = learnt_downsampled.view(x.shape[0], -1)
            out = self.fc(flattened)
            return out
        
        
x = torch.randn(8, 3, 16, 16).to(device)
channel_factors = [1, 1, 1, 1, 1/2, 1/4, 1/8]
model = ProDiscriminator(channel_factors).to(device)
output = model(x, 2, 0.5)
output.shape

torch.Size([8, 1])

In [12]:
# x = torch.rand(8, 512, 1, 1).to(device)
# channel_factors = [1, 1, 1, 1, 1/2, 1/4, 1/8]
# model_gen = ProGenerator(channel_factors).to(device)
# model_disc = ProDiscriminator(channel_factors).to(device)
# img_res = [4, 8, 16, 32, 64, 128, 256]
# resolution_indices = [int(np.log2(res/4)) for res in img_res]
# for res_idx in resolution_indices:
#     output = model_gen(x, res_idx, 0.5)
#     score = model_disc(output, res_idx, 0.5)
#     print(output.shape, score.shape)

In [13]:
# root_path = '/mnt/c/Users/121js/OneDrive/Desktop/TorchImages/dogs/'
num_channels = 3

class DogsDataset(Dataset):
    def __init__(self, data_path, image_resolution, num_channels=num_channels):
        super().__init__()
        self.data_path = data_path
        self.list_filenames = os.listdir(data_path)

        self.resolution_transform = transforms.Compose(
            [
                transforms.Resize((image_resolution, image_resolution)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean = [0.5 for _ in range(num_channels)],
                    std = [0.5 for _ in range(num_channels)]
                )
            ]
        )

    def __len__(self):
        return len(self.list_filenames)
    
    def __getitem__(self, index):
        image_filename = self.list_filenames[index]
        image_path = os.path.join(self.data_path, image_filename)

        image = Image.open(image_path)

        return self.resolution_transform(image)
    
# train_dataset = DogsDataset(root_path+'mix', 6)
# val_dataset = DogsDataset(root_path+'beagle', 6)

In [14]:
class Trainer:
    def __init__(
            self,
            optimizer_d,
            optimizer_g,
            scaler_d,
            scaler_g,
            model_d,
            model_g,
            penalty_coeff,
            epsilon_drift,
            device = device
    ):
        self.optimizer_d = optimizer_d
        self.optimizer_g = optimizer_g
        self.scaler_d = scaler_d
        self.scaler_g = scaler_g
        self.model_d = model_d
        self.model_g = model_g
        self.penalty_coeff = penalty_coeff
        self.epsilon_drift = epsilon_drift
        self.device = device

    def calc_grad_penalty(self, resolution_index, alpha, real, fake):
        batch_size, num_channels, height, width = real.shape
        epsilon = torch.rand((batch_size, 1, 1, 1)).repeat(1, num_channels, height, width).to(self.device)

        joint_distribution = epsilon*real + (1-epsilon)*fake
        critic_term = self.model_d(joint_distribution, resolution_index, alpha)

        gradient = torch.autograd.grad(
            outputs = critic_term,
            inputs = joint_distribution,
            grad_outputs = torch.ones_like(critic_term),
            retain_graph = True,
            create_graph = True,
        )[0].view(batch_size, -1)

        l2_norm = torch.norm(gradient, p=2, dim=1)
        grad_penalty = torch.mean((l2_norm - 1)**2)        

        return grad_penalty
    
    def calc_disc_loss(self, resolution_index, alpha, real, fake, is_train):
        with torch.set_grad_enabled(is_train):
            real_arg = self.model_d(real, resolution_index, alpha)
            fake_arg = self.model_d(fake, resolution_index, alpha)

            grad_penalty = self.calc_grad_penalty(resolution_index, alpha, real, fake)

            wgan_gp_loss = (torch.mean(fake_arg) - torch.mean(real_arg)) + self.penalty_coeff*grad_penalty

            drift_term = self.epsilon_drift * torch.mean(real_arg**2)

            loss_d = wgan_gp_loss + drift_term

        return loss_d
    
    def calc_gen_loss(self, resolution_index, alpha, fake, is_train):
        with torch.set_grad_enabled(is_train):
            fake_arg = self.model_d(fake, resolution_index, alpha)
            loss_g = - torch.mean(fake_arg)

        return loss_g
    
    def calc_metrics(self, resolution_index, alpha, metrics_dict, fake, train_loader, val_loader):
        self.model_d.eval(), self.model_g.eval()

        final_str = ''
        loaders_list = [('Train', train_loader), ('Val', val_loader)]
            
        if metrics_dict == None:
            metrics_dict = {'Train': {'DiscLoss': [], 'GenLoss': []}, 'Val': {'DiscLoss': [], 'GenLoss': []}}

        for name, loader in loaders_list:
            len_data = 0
            total_loss_d, total_loss_g = 0, 0

            for real in loader:
                real = real.to(device)

                batch_size = real.shape[0]
                len_data += batch_size

                with torch.cuda.amp.autocast():
                    loss_d = self.calc_disc_loss(resolution_index, alpha, real, fake, is_train=False)
                total_loss_d += loss_d

                with torch.cuda.amp.autocast():
                    loss_g = self.calc_gen_loss(resolution_index, alpha, fake, is_train=False)
                total_loss_g += loss_g
                
            disc_loss = total_loss_d/len_data
            gen_loss = total_loss_g/len_data

            final_str += ' -- {} Disc Loss: {:.5f} -- {} Gen Loss: {:.5f}'.format(name, disc_loss, name, gen_loss)
                
            metrics_dict[name]['DiscLoss'].append(disc_loss.item())
            metrics_dict[name]['GenLoss'].append(gen_loss.item()) 

        self.model_d.train(), self.model_g.train()

        return final_str, metrics_dict
    
    def visualize_tensorboard(self, real, fake, loss_g, loss_d, writer_progan, tensorboard_steps):
        self.model_g.eval()

        with torch.no_grad():
            combined_grid = torch.cat([
                fake[:4]*0.5 + 0.5,
                real[:4]*0.5 + 0.5,
            ], dim=0
            )
            image_grid = make_grid(combined_grid, nrow=4, normalize=False)
            writer_progan.add_image('Generated', image_grid, global_step=tensorboard_steps)

            writer_progan.add_scalar('Generator Loss', loss_g.item(), global_step=tensorboard_steps)
            writer_progan.add_scalar('Discriminator Loss', loss_d.item(), global_step=tensorboard_steps)

        self.model_g.train()

        return None
    
    def fit_at_1_res(self, resolution_index, alpha_arr, n_epochs, epoch_desc, train_loader, val_loader, writer_progan, tensorboard_steps):
        metrics_dict = None        
        epoch_loop = tqdm(range(1, n_epochs+1), total=n_epochs, leave=True)
        for epoch in epoch_loop:
            # batch_loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
            batch_loop = enumerate(train_loader)
            for batch_idx, real in batch_loop:
                real = real.to(device)
                
                batch_size = real.shape[0]
                noise_dim = 512
                noise = torch.randn(batch_size, noise_dim, 1, 1).to(self.device)

                alpha = alpha_arr[epoch-1, batch_idx]
                fake = self.model_g(noise, resolution_index, alpha)
                
                with torch.cuda.amp.autocast():
                    loss_d = self.calc_disc_loss(resolution_index, alpha, real, fake, is_train=True)
                    
                self.optimizer_d.zero_grad()
                self.scaler_d.scale(loss_d).backward(retain_graph=True)
                self.scaler_d.step(self.optimizer_d)
                self.scaler_d.update()

                with torch.cuda.amp.autocast():
                    loss_g = self.calc_gen_loss(resolution_index, alpha, fake, is_train=True)

                self.optimizer_g.zero_grad()
                self.scaler_g.scale(loss_g).backward(retain_graph=True)
                self.scaler_g.step(self.optimizer_g)
                self.scaler_g.update()

                epoch_loop.set_description(epoch_desc)
                epoch_loop.set_postfix(batch = f'{batch_idx+1}/{len(train_loader)}', train_loss_discriminator = f'{loss_d.item():.4f}', train_loss_generator = f'{loss_g.item():.4f}')

                if batch_idx % 5 == 0:
                    # print(f'Epoch: {epoch:2d}/{n_epochs} -- Batch: {batch_idx+1:3d}/{len(train_loader)}' + f' -- Train Disc Loss: {loss_d:.4f} -- Train Gen Loss: {loss_g:.4f}')
                    self.visualize_tensorboard(real, fake, loss_g, loss_d, writer_progan, tensorboard_steps)
                    tensorboard_steps += 1
            
            # final_str, metrics_dict = self.calc_metrics(resolution_index, alpha, metrics_dict, fake, train_loader, val_loader)
            # print('Epoch: {:2d}'.format(epoch) + final_str)
            
        return metrics_dict, tensorboard_steps
    
    def fadein_fit(self, resolution_index, n_epochs, epoch_desc, train_loader, val_loader, writer_progan, tensorboard_steps):
        alpha_arr = np.linspace(0, 1, n_epochs*len(train_loader)).reshape(n_epochs, len(train_loader))
        metrics_dict, tensorboard_steps = self.fit_at_1_res(resolution_index, alpha_arr, n_epochs, epoch_desc, train_loader, val_loader, writer_progan, tensorboard_steps)

        return metrics_dict, tensorboard_steps
    
    def stable_fit(self, resolution_index, n_epochs, epoch_desc, train_loader, val_loader, writer_progan, tensorboard_steps):
        alpha_arr = np.ones((n_epochs, len(train_loader)))
        metrics_dict, tensorboard_steps = self.fit_at_1_res(resolution_index, alpha_arr, n_epochs, epoch_desc, train_loader, val_loader, writer_progan, tensorboard_steps)

        return metrics_dict, tensorboard_steps
    
    def get_data_loader(self, resolution, path, batch_size, shuffle):
        num_channels = 3
        data = DogsDataset(path, resolution, num_channels)
        dataloader = DataLoader(data, batch_size=batch_size, shuffle=shuffle)

        return dataloader
        
    def fit(self, img_res, epochs_per_res, batches_per_res, train_path, val_path, writer_progan):
        self.model_d.train(), self.model_g.train()

        metrics_per_res = dict(zip(img_res, [None]*len(img_res)))

        tensorboard_steps = 1

        for res, n_epochs, batch_size in zip(img_res, epochs_per_res, batches_per_res):
            
            train_loader = self.get_data_loader(res, train_path, batch_size, shuffle=True)
            val_loader = self.get_data_loader(res, val_path, batch_size, shuffle=False)

            metrics_dict = metrics_per_res[res]

            res_idx = int(np.log2(res/4))
            if res == 4:
                epoch_desc = f'Initial training at Image Resolution: {res}x{res}'
                metrics_dict, tensorboard_steps = self.fadein_fit(res_idx, n_epochs, epoch_desc, train_loader, val_loader, writer_progan, tensorboard_steps)
            elif res > 4:
                epoch_desc = f'Fade-in training at Image Resolution: {res}x{res}'
                metrics_dict, tensorboard_steps = self.fadein_fit(res_idx, n_epochs, epoch_desc, train_loader, val_loader, writer_progan, tensorboard_steps)
                epoch_desc = f'Stable training at Image Resolution: {res}x{res}'
                metrics_dict, tensorboard_steps = self.stable_fit(res_idx, n_epochs, epoch_desc, train_loader, val_loader, writer_progan, tensorboard_steps)
                
            metrics_per_res[res] = metrics_dict

            

        self.metrics_per_res = metrics_per_res

In [15]:
channel_factors = [1, 1, 1, 1, 1/2, 1/4, 1/8]
model_d = ProDiscriminator(channel_factors, ending_channels=512).to(device)
model_g = ProGenerator(channel_factors, latent_dim=512, starting_channels=512).to(device)

lr = 1e-3
beta1 = 0.0
beta2 = 0.999
optimizer_d = optim.Adam(model_d.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_g = optim.Adam(model_g.parameters(), lr=lr, betas=(beta1, beta2))

scaler_d = torch.cuda.amp.GradScaler()
scaler_g = torch.cuda.amp.GradScaler()

In [16]:
penalty_coeff = 10
epsilon_drift = 1e-3
trainer = Trainer(optimizer_d, optimizer_g, scaler_d, scaler_g, model_d, model_g, penalty_coeff, epsilon_drift, device)

In [17]:
img_res = [4, 8, 16, 32, 64, 128, 256]
epochs_per_res = [80, 90, 70, 50, 50, 50, 50]
batches_per_res = [64, 64, 32, 32, 16, 16, 8]
root_path = '/mnt/c/Users/121js/OneDrive/Desktop/TorchImages/dogs/'
train_path, val_path = root_path+'mix', root_path+'beagle'
writer_progan = SummaryWriter('logs/progan')

trainer.fit(img_res, epochs_per_res, batches_per_res, train_path, val_path, writer_progan)

  0%|          | 0/80 [00:00<?, ?it/s]

Finished training at Image Resolution: 4x4


  0%|          | 0/90 [00:00<?, ?it/s]

  0%|          | 0/90 [00:00<?, ?it/s]

Finished training at Image Resolution: 8x8


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Finished training at Image Resolution: 16x16


  0%|          | 0/50 [00:00<?, ?it/s]

KeyboardInterrupt: 