In [None]:

!pip install timm dill
!pip install pytorch-ignite
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.utils import spectral_norm
from torchvision import utils as vutils
from torchsummary import summary
import ignite
import math
import logging
import matplotlib.pyplot as plt
from torchsummary import summary

import torchvision.transforms as transforms
import torchvision.utils as vutils

from ignite.engine import Engine, Events
import ignite.distributed as idist

import argparse
import logging

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

In [None]:
# Fix the random seed through ignite's helper so repeated runs start from the same state.
ignite.utils.manual_seed(999)
# Only show WARNING (and above) messages from these two ignite loggers.
ignite.utils.setup_logger(name="ignite.distributed.auto.auto_dataloader", level=logging.WARNING)
ignite.utils.setup_logger(name="ignite.distributed.launcher.Parallel", level=logging.WARNING)

In [None]:
# Switch to the project folder; every relative path used later in the notebook
# (dataset folders, checkpoints, output directories) is resolved from here.
%cd /content/drive/MyDrive/NN
# Disabled one-off download via gdown (presumably the dataset archive — confirm before re-enabling).
#!pip install --upgrade --no-cache-dir gdown
#!gdown https://drive.google.com/u/0/uc?id=1aAJCZbXNHyraJ6Mi13dSbe7pTyfPXha0&export=download

Data Loader

In [None]:
from torchvision.datasets import ImageFolder

# Side length (in pixels) every training image is resized to. It used to be declared as 64
# but never used, while the resize below was hard-coded to 256; both now agree.
image_size = 256
batch_size = 16

# Augmentation + preprocessing applied to every training image:
# random rotation (up to 30 degrees), resize, random horizontal flip, tensor conversion
# and per-channel normalisation with the mean/std values listed below.
train_transforms = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

# ImageFolder expects one sub-directory per class inside "train/cat-train".
train_data = ImageFolder("train/cat-train", transform=train_transforms)


In [None]:
# Build the training loader through ignite's distributed helper.
# Batches are shuffled and the last incomplete batch is dropped, so every batch
# yielded by the loader has exactly `batch_size` samples.
train_dataloader = idist.auto_dataloader(
    train_data, 
    batch_size=batch_size, 
    num_workers=2, 
    shuffle=True, 
    drop_last=True,
)


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
real_batch = next(iter(train_dataloader))

# Tile (at most) the first 64 images of the batch into a single, value-normalised grid.
preview_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu()

plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns channels-first data; matplotlib wants channels last.
plt.imshow(np.transpose(preview_grid, (1,2,0)))

Utils

In [None]:
def kaiming_init(module):
    """Re-initialise the weight of convolution-like modules with Kaiming-normal values.

    Intended for use with ``nn.Module.apply``. A module is treated as a convolution
    when its class name contains ``"Conv"``; every other module is left untouched.
    """
    if 'Conv' not in module.__class__.__name__:
        return
    torch.nn.init.kaiming_normal_(module.weight, nonlinearity='relu')


def weights_init(m):
    """DCGAN-style initialiser meant to be passed to ``nn.Module.apply``.

    * Modules whose class name contains ``"Conv"`` get their ``weight`` drawn from N(0, 0.02).
    * Modules whose class name contains ``"BatchNorm"`` get ``weight`` ~ N(1, 0.02) and ``bias`` = 0.

    Modules without the relevant tensor (no ``weight`` attribute, or ``None`` as for
    ``affine=False`` batch norm) are skipped explicitly instead of relying on a bare
    ``except`` that would also hide unrelated errors.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        weight = getattr(m, "weight", None)
        if weight is not None:
            # NOTE(review): for layers wrapped with torch's spectral_norm the trainable tensor is
            # `weight_orig`; `weight` is presumably recomputed on forward, so this may be a no-op
            # there — behaviour intentionally kept as in the original.
            weight.data.normal_(0.0, 0.02)
    elif "BatchNorm" in classname:
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)


def load_checkpoint(net, checkpoint):
    """Load ``checkpoint`` into ``net`` after forcing a ``module.`` prefix on every key.

    ``checkpoint`` is either a state dict or a mapping holding one under ``'state_dict'``.
    Keys that already start with ``module.`` are kept; all others get the prefix added,
    so ``net`` is expected to use ``module.``-prefixed parameter names (as a
    ``DataParallel`` wrapper does). Loading is strict.
    """
    from collections import OrderedDict

    if 'state_dict' in checkpoint:
        checkpoint = dict(checkpoint['state_dict'])

    prefixed = OrderedDict(
        (key if key.startswith('module.') else 'module.' + key, checkpoint[key])
        for key in checkpoint
    )
    net.load_state_dict(prefixed, strict=True)



Data Augmentation

In [None]:
class DiffAugment:
    """Apply a comma-separated list of augmentation policies to an image batch.

    Args:
        policy: comma-separated policy names; each name must be a key of ``AUGMENT_FNS``.
            Empty entries (including the default empty string) are ignored, so
            ``DiffAugment()`` is a no-op instead of failing on the lookup of ``''``.
        channels_first: set to ``False`` when the batch is laid out channels-last; it is
            then permuted to channels-first for the augmentations and permuted back.
    """

    def __init__(self, policy='', channels_first=True):
        # ''.split(',') is [''], which is truthy; filter blanks so "no policy" is really empty.
        self.policies = [name.strip() for name in policy.split(',') if name.strip()]
        self.channels_first = channels_first

    def forward(self, x):
        """Return ``x`` with every configured augmentation applied in order.

        When no policy is configured, ``x`` is returned unchanged (same object).
        """
        if self.policies:
            if not self.channels_first:
                x = x.permute(0, 3, 1, 2)
            for p in self.policies:
                for f in AUGMENT_FNS[p]:
                    x = f(x)
            if not self.channels_first:
                x = x.permute(0, 2, 3, 1)
            x = x.contiguous()
        return x

    # Allow `augment(x)` in addition to the explicit `augment.forward(x)`.
    __call__ = forward


def rand_brightness(x):
    """Add one random offset in [-0.5, 0.5) to every element of each sample in the batch."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each sample's deviation from its per-pixel channel mean by a random factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    gain = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * gain + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a random factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    gain = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * gain + sample_mean


def rand_translation(x, ratio=0.125):
    """Shift every sample by a random integer offset along dims 2 and 3, filling with zeros.

    The maximum shift along each dim is ``ratio`` times that dim's size (rounded). The shift
    is implemented as a gather from a copy of ``x`` padded with a one-pixel zero border:
    indices that fall outside the image are clamped onto that border.
    """
    batch, size_2, size_3 = x.size(0), x.size(2), x.size(3)
    max_shift_2 = int(size_2 * ratio + 0.5)
    max_shift_3 = int(size_3 * ratio + 0.5)

    # One offset per sample and per dim, broadcast over the index grids below.
    translation_x = torch.randint(-max_shift_2, max_shift_2 + 1, size=[batch, 1, 1], device=x.device)
    translation_y = torch.randint(-max_shift_3, max_shift_3 + 1, size=[batch, 1, 1], device=x.device)

    idx_b, idx_2, idx_3 = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(size_2, dtype=torch.long, device=x.device),
        torch.arange(size_3, dtype=torch.long, device=x.device),
    )
    # +1 accounts for the padding border; clamping sends out-of-range indices to the border.
    idx_2 = (idx_2 + translation_x + 1).clamp(0, size_2 + 1)
    idx_3 = (idx_3 + translation_y + 1).clamp(0, size_3 + 1)

    padded = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]).permute(0, 2, 3, 1).contiguous()
    return padded[idx_b, idx_2, idx_3].permute(0, 3, 1, 2).contiguous()


def rand_cutout(x, ratio=0.5):
    """Zero out one randomly placed rectangle per sample.

    The rectangle measures ``ratio`` times the size of dims 2 and 3 (rounded); its position
    is drawn per sample and indices are clamped to the image, so a rectangle near an edge
    covers a smaller area. The same mask is applied to every channel.
    """
    batch, size_2, size_3 = x.size(0), x.size(2), x.size(3)
    cut_2 = int(size_2 * ratio + 0.5)
    cut_3 = int(size_3 * ratio + 0.5)

    offset_x = torch.randint(0, size_2 + (1 - cut_2 % 2), size=[batch, 1, 1], device=x.device)
    offset_y = torch.randint(0, size_3 + (1 - cut_3 % 2), size=[batch, 1, 1], device=x.device)

    idx_b, idx_2, idx_3 = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(cut_2, dtype=torch.long, device=x.device),
        torch.arange(cut_3, dtype=torch.long, device=x.device),
    )
    # Centre the rectangle on the sampled offset and keep every index inside the image.
    idx_2 = torch.clamp(idx_2 + offset_x - cut_2 // 2, min=0, max=size_2 - 1)
    idx_3 = torch.clamp(idx_3 + offset_y - cut_3 // 2, min=0, max=size_3 - 1)

    keep = torch.ones(batch, size_2, size_3, dtype=x.dtype, device=x.device)
    keep[idx_b, idx_2, idx_3] = 0
    return x * keep.unsqueeze(1)


# Policy name -> ordered list of augmentation functions applied for that policy.
# DiffAugment.forward looks every configured policy name up in this table.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}

MODEL IMPLEMENTATION

1.   DISCRIMINATOR
2.   PROJECTED GAN



DISCRIMINATOR - MULTI-SCALE DISCRIMINATOR (MSD)


1.   Down-sampling Block for lower resolution inputs implementation
2.   Multi Scale Discriminator implementation



In [None]:
class DownBlock(nn.Module):
    """Stride-2 convolution followed by batch norm and LeakyReLU(0.2).

    The 4x4 / stride 2 / padding 1 convolution halves both spatial dimensions
    (for even input sizes) while mapping ``c_in`` channels to ``c_out``.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(c_out)
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        return self.leaky_relu(self.bn(self.conv(x)))

In [None]:
class MultiScaleDiscriminator(nn.Module):
    """Small convolutional discriminator operating on one feature scale.

    Args:
        channels: number of channels of the incoming feature map.
        l: scale index in 1..4. It selects how many ``DownBlock`` stages are stacked:
            ``l=1`` builds four (…->64->128->256->512), ``l=4`` builds a single one (…->512).
            Every configuration ends with 512 channels, which the spectrally normalised
            3x3 head maps to a one-channel logit map.
        lr: learning rate of the discriminator's own Adam optimizer.
        betas: Adam betas of the discriminator's own optimizer.

    Attributes:
        optim: Adam optimizer owned by this discriminator. It covers *all* parameters of the
            module; previously only ``self.model`` was registered, so the output head
            ``head_conv`` received gradients but was never updated.
    """

    def __init__(self, channels, l, lr=0.0002, betas=(0, 0.99)):
        super(MultiScaleDiscriminator, self).__init__()
        # The head is created first to keep the parameter-initialisation order unchanged.
        self.head_conv = spectral_norm(nn.Conv2d(512, 1, 3, 1, 1))
        first_width = 64 * [1, 2, 4, 8][l - 1]
        layers = [DownBlock(channels, first_width)]
        layers += [DownBlock(64 * i, 64 * i * 2) for i in [1, 2, 4][l - 1:]]
        self.model = nn.Sequential(*layers)
        self.optim = Adam(self.parameters(), lr=lr, betas=betas)

    def forward(self, x):
        """Return the one-channel logit map for feature map ``x``."""
        x = self.model(x)
        return self.head_conv(x)


In [None]:
class CSM(nn.Module):
    """
    Cross-Scale Mixing block with fixed (non-trainable) random weights.

    ``conv1`` is a 1-D convolution (kernel 3) applied to the feature map with its two
    spatial dimensions flattened into one; ``conv3`` is a 3x3 2-D convolution that maps
    ``channels`` to ``conv3_out_channels``. Both are Kaiming-initialised and frozen.
    """

    def __init__(self, channels, conv3_out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, 3, 1, 1)
        self.conv3 = nn.Conv2d(channels, conv3_out_channels, 3, 1, 1)

        # The mixing weights stay at their random initialisation: exclude them from training.
        for frozen_layer in (self.conv1, self.conv3):
            for param in frozen_layer.parameters():
                param.requires_grad = False

        self.apply(kaiming_init)

    def forward(self, high_res, low_res=None):
        """Mix ``high_res`` (optionally summed with ``low_res``) and upsample by 2.

        Args:
            high_res: 4-D feature map with ``channels`` channels.
            low_res: optional tensor added element-wise after ``conv1``; it must be
                broadcast-compatible with ``high_res``.

        Returns:
            Feature map with ``conv3_out_channels`` channels and doubled spatial size
            (bilinear interpolation).
        """
        batch, channels, dim_2, dim_3 = high_res.size()

        flattened = high_res.view(batch, channels, dim_2 * dim_3)
        mixed = self.conv1(flattened).view(batch, channels, dim_2, dim_3)
        if low_res is not None:
            mixed = torch.add(mixed, low_res)

        return F.interpolate(self.conv3(mixed), scale_factor=2., mode="bilinear")


EfficientNet-Lite Implementation

In [None]:
# Per-variant scaling settings, looked up by name in build_efficientnet_lite.
# Only the width/depth coefficients and the dropout rate are consumed there;
# the image_size entry is unpacked but ignored.
efficientnet_lite_params = {
    # width_coefficient, depth_coefficient, image_size, dropout_rate
    'efficientnet_lite0': [1.0, 1.0, 224, 0.2],
    'efficientnet_lite1': [1.0, 1.1, 240, 0.2],
    'efficientnet_lite2': [1.1, 1.2, 260, 0.3],
    'efficientnet_lite3': [1.2, 1.4, 280, 0.3],
    'efficientnet_lite4': [1.4, 1.8, 300, 0.3],
}


In [None]:
def round_filters(filters, multiplier, divisor=8, min_width=None):
    """Scale a channel count by ``multiplier`` and round it to a multiple of ``divisor``.

    A falsy ``multiplier`` returns ``filters`` untouched. The result is never smaller than
    ``min_width`` (``divisor`` when not given), and if rounding would lose more than 10%
    of the scaled value one extra ``divisor`` is added.
    """
    if not multiplier:
        return filters

    scaled = filters * multiplier
    lower_bound = min_width or divisor
    # Round half-up to the nearest multiple of `divisor`.
    rounded = int(scaled + divisor / 2) // divisor * divisor
    rounded = max(lower_bound, rounded)
    if rounded < 0.9 * scaled:
        rounded += divisor
    return int(rounded)

def round_repeats(repeats, multiplier):
    """Scale a block repeat count by the depth ``multiplier``, rounding up.

    A falsy ``multiplier`` returns ``repeats`` unchanged.
    """
    if multiplier:
        return int(math.ceil(multiplier * repeats))
    return repeats

def drop_connect(x, drop_connect_rate, training):
    """Randomly zero whole samples of ``x`` and rescale the survivors.

    Outside training ``x`` is returned as is. In training, each sample is kept with
    probability ``1 - drop_connect_rate``; kept samples are divided by that probability.
    """
    if not training:
        return x

    keep_prob = 1.0 - drop_connect_rate
    # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob and 0 otherwise.
    uniform = torch.rand([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
    binary_mask = torch.floor(keep_prob + uniform)
    return (x / keep_prob) * binary_mask



class MBConvBlock(nn.Module):
    """Mobile inverted-bottleneck block: expand (1x1) -> depthwise (kxk) -> project (1x1).

    Args:
        inp: input channels.
        final_oup: output channels.
        k: depthwise kernel size.
        s: depthwise stride.
        expand_ratio: channel expansion factor; the expansion conv is skipped when it is 1.
        se_ratio: squeeze ratio (relative to ``inp``) of the optional squeeze-and-excitation.
        has_se: whether to build and apply the squeeze-and-excitation branch.

    A residual connection is used when the stride is 1 and ``inp == final_oup``.
    Sub-module attribute names are part of the checkpoint format and must not change.
    """

    def __init__(self, inp, final_oup, k, s, expand_ratio, se_ratio, has_se=False):
        super().__init__()

        self._momentum = 0.01
        self._epsilon = 1e-3
        self.input_filters = inp
        self.output_filters = final_oup
        self.stride = s
        self.expand_ratio = expand_ratio
        self.has_se = has_se
        self.id_skip = True  # enables the residual connection (and drop connect on it)

        bn_options = dict(momentum=self._momentum, eps=self._epsilon)
        oup = inp * expand_ratio  # width of the expanded representation

        # 1x1 expansion (only when the block actually widens the representation)
        if expand_ratio != 1:
            self._expand_conv = nn.Conv2d(inp, oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(oup, **bn_options)

        # kxk depthwise convolution: one filter group per channel
        self._depthwise_conv = nn.Conv2d(
            oup, oup, kernel_size=k, stride=s, padding=(k - 1) // 2, groups=oup, bias=False)
        self._bn1 = nn.BatchNorm2d(oup, **bn_options)

        # optional squeeze-and-excitation gate
        if self.has_se:
            num_squeezed_channels = max(1, int(inp * se_ratio))
            self._se_reduce = nn.Conv2d(oup, num_squeezed_channels, kernel_size=1)
            self._se_expand = nn.Conv2d(num_squeezed_channels, oup, kernel_size=1)

        # 1x1 linear projection to the output width
        self._project_conv = nn.Conv2d(oup, final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(final_oup, **bn_options)
        self._relu = nn.ReLU6(inplace=True)

    def forward(self, x, drop_connect_rate=None):
        """
        :param x: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1); only used
            on the residual branch and only while the module is in training mode
        :return: output of block
        """
        identity = x

        if self.expand_ratio != 1:
            x = self._relu(self._bn0(self._expand_conv(x)))
        x = self._relu(self._bn1(self._depthwise_conv(x)))

        if self.has_se:
            gate = F.adaptive_avg_pool2d(x, 1)
            gate = self._se_expand(self._relu(self._se_reduce(gate)))
            x = torch.sigmoid(gate) * x

        x = self._bn2(self._project_conv(x))

        uses_residual = self.id_skip and self.stride == 1 and self.input_filters == self.output_filters
        if uses_residual:
            if drop_connect_rate:
                x = drop_connect(x, drop_connect_rate, training=self.training)
            x += identity  # skip connection
        return x


class EfficientNetLite(nn.Module):
    """EfficientNet-Lite backbone that also exposes intermediate feature maps.

    Args:
        widthi_multiplier: width coefficient applied to the channel counts of every stage.
        depth_multiplier: depth coefficient applied to the repeat count of every stage
            except the first and the last one.
        num_classes: size of the classification output.
        drop_connect_rate: maximum drop-connect rate; it is scaled linearly with the
            position of each block, from 0 for the first block up to (almost) this value.
        dropout_rate: dropout applied before the classifier (disabled when <= 0).

    ``forward`` returns ``(logits, features)`` where ``features`` holds the outputs of
    stages 1, 2, 3 and 6.
    """

    def __init__(self, widthi_multiplier, depth_multiplier, num_classes, drop_connect_rate, dropout_rate):
        super(EfficientNetLite, self).__init__()

        # Batch norm parameters
        momentum = 0.01
        epsilon = 1e-3
        self.drop_connect_rate = drop_connect_rate

        mb_block_settings = [
            #repeat|kernal_size|stride|expand|input|output|se_ratio
            [1, 3, 1, 1, 32,  16,  0.25],
            [2, 3, 2, 6, 16,  24,  0.25],
            [2, 5, 2, 6, 24,  40,  0.25],
            [3, 3, 2, 6, 40,  80,  0.25],
            [3, 5, 1, 6, 80,  112, 0.25],
            [4, 5, 2, 6, 112, 192, 0.25],
            [1, 3, 1, 6, 192, 320, 0.25]
        ]

        # Stem
        out_channels = 32
        self.stem = nn.Sequential(
            nn.Conv2d(3, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=out_channels, momentum=momentum, eps=epsilon),
            nn.ReLU6(inplace=True),
        )

        # Build blocks: one ModuleList per stage, collected in self.blocks
        self.blocks = nn.ModuleList([])
        for i, stage_setting in enumerate(mb_block_settings):
            stage = nn.ModuleList([])
            num_repeat, kernal_size, stride, expand_ratio, input_filters, output_filters, se_ratio = stage_setting
            # Update block input and output filters based on width multiplier.
            input_filters = input_filters if i == 0 else round_filters(input_filters, widthi_multiplier)
            output_filters = round_filters(output_filters, widthi_multiplier)
            # The first and last stage keep their repeat count regardless of the depth multiplier.
            num_repeat= num_repeat if i == 0 or i == len(mb_block_settings) - 1  else round_repeats(num_repeat, depth_multiplier)

            # The first block needs to take care of stride and filter size increase.
            stage.append(MBConvBlock(input_filters, output_filters, kernal_size, stride, expand_ratio, se_ratio, has_se=False))
            if num_repeat > 1:
                input_filters = output_filters
                stride = 1
            for _ in range(num_repeat - 1):
                stage.append(MBConvBlock(input_filters, output_filters, kernal_size, stride, expand_ratio, se_ratio, has_se=False))

            self.blocks.append(stage)

        # Head
        in_channels = round_filters(mb_block_settings[-1][5], widthi_multiplier)
        out_channels = 1280
        self.head = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(num_features=out_channels, momentum=momentum, eps=epsilon),
            nn.ReLU6(inplace=True),
        )

        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None
        self.fc = torch.nn.Linear(out_channels, num_classes)

        self._initialize_weights()

    def forward(self, x):
        """Run the network and return ``(logits, features)``.

        ``features`` is the list of feature maps produced by stages 1, 2, 3 and 6.
        """
        features = []
        x = self.stem(x)
        # self.blocks holds *stages*; the drop-connect schedule must be normalised by the
        # number of individual blocks, otherwise the rate exceeds the configured maximum.
        total_blocks = sum(len(stage) for stage in self.blocks)
        idx = 0
        for i, stage in enumerate(self.blocks):
            for block in stage:
                drop_connect_rate = self.drop_connect_rate
                if drop_connect_rate:
                    drop_connect_rate *= float(idx) / total_blocks
                x = block(x, drop_connect_rate)
                idx += 1
            if i in [1, 2, 3, 6]:
                features.append(x)
        x = self.head(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc(x)
        return x, features

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear layers (fan-out scaled normal for convs)."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 1.0/float(n))
                m.bias.data.zero_()

    def load_pretrain(self, path):
        """Strictly load a state dict stored at ``path``.

        The file is mapped to CPU first so a checkpoint saved from a GPU can also be
        loaded on a CPU-only machine; ``load_state_dict`` then copies the values into
        the parameters wherever they live.
        """
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict, strict=True)


def build_efficientnet_lite(name, num_classes):
    """Instantiate the EfficientNet-Lite variant called ``name``.

    The width/depth coefficients and dropout rate come from ``efficientnet_lite_params``;
    the table's image-size entry is not used and the drop-connect rate is fixed at 0.2.
    """
    width, depth, _unused_image_size, dropout = efficientnet_lite_params[name]
    return EfficientNetLite(width, depth, num_classes, 0.2, dropout)


def load_checkpoint(net, checkpoint):
    """Strictly load ``checkpoint`` into ``net`` with every key prefixed by ``module.``.

    NOTE: this re-binds the name ``load_checkpoint`` that is already defined with the
    same body in the utilities cell earlier in the notebook.

    ``checkpoint`` may be a plain state dict or a mapping that stores one under
    ``'state_dict'``. Keys already starting with ``module.`` are left as they are.
    """
    from collections import OrderedDict

    source = dict(checkpoint['state_dict']) if 'state_dict' in checkpoint else checkpoint

    remapped = OrderedDict()
    for name in source:
        target = name if name.startswith('module.') else 'module.' + name
        remapped[target] = source[name]

    net.load_state_dict(remapped, strict=True)

# Build EfficientNet-Lite1 with a 1000-way classifier and load its pretrained weights.
model_name = 'efficientnet_lite1'
model = build_efficientnet_lite(model_name, 1000)

use_gpu = torch.cuda.is_available()

# map_location="cpu" lets a checkpoint that was saved from a GPU be read on a CPU-only
# runtime; load_state_dict then copies the values into the (CPU) model parameters.
model.load_state_dict(torch.load("efficientnet_lite1.pth", map_location="cpu"))



FastGAN Generator Implementation

In [None]:
"""
Generator architecture and code taken from "Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis" (arxiv.org/abs/2101.04775) and github.com/odegeasslbc/FastGAN-pytorch, respectively.
"""
import os
import argparse
from tqdm import tqdm
from torchvision.utils import save_image


def conv2d(*args, **kwargs):
    """Create an ``nn.Conv2d`` from the given arguments and wrap it in spectral normalization."""
    return spectral_norm(nn.Conv2d(*args, **kwargs))


def convTranspose2d(*args, **kwargs):
    """Create an ``nn.ConvTranspose2d`` from the given arguments and wrap it in spectral normalization."""
    return spectral_norm(nn.ConvTranspose2d(*args, **kwargs))


def batchNorm2d(*args, **kwargs):
    """Create an ``nn.BatchNorm2d`` from the given arguments (no spectral normalization applied)."""
    return nn.BatchNorm2d(*args, **kwargs)


def linear(*args, **kwargs):
    """Create an ``nn.Linear`` from the given arguments and wrap it in spectral normalization."""
    return spectral_norm(nn.Linear(*args, **kwargs))


class GLU(nn.Module):
    """Gated linear unit over the channel dimension.

    The first half of the channels is multiplied by the sigmoid of the second half,
    so the output has half as many channels as the input.
    """

    def forward(self, x):
        channels = x.size(1)
        assert channels % 2 == 0, 'channels dont divide 2!'
        half = channels // 2
        values, gates = x[:, :half], x[:, half:]
        return values * torch.sigmoid(gates)


class NoiseInjection(nn.Module):
    """Add noise scaled by a single learnable weight (initialised to zero) to a feature map."""

    def __init__(self):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, feat, noise=None):
        """Return ``feat + weight * noise``.

        When ``noise`` is omitted, standard-normal noise with one channel and the same
        batch/spatial size as ``feat`` is sampled and broadcast over the channels.
        """
        if noise is not None:
            return feat + self.weight * noise

        batch, _, height, width = feat.shape
        sampled = torch.randn(batch, 1, height, width).to(feat.device)
        return feat + self.weight * sampled


class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, feat):
        gate = torch.sigmoid(feat)
        return feat * gate


class SEBlock(nn.Module):
    """Skip-layer excitation: gate a large feature map with statistics of a small one.

    ``feat_small`` (``ch_in`` channels) is pooled to 4x4, reduced to a 1x1 map with
    ``ch_out`` channels by a 4x4 convolution, passed through Swish, a 1x1 convolution and
    a sigmoid; the resulting per-channel gate multiplies ``feat_big`` (``ch_out`` channels).
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()

        gate_layers = [
            nn.AdaptiveAvgPool2d(4),
            conv2d(ch_in, ch_out, 4, 1, 0, bias=False),
            Swish(),
            conv2d(ch_out, ch_out, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*gate_layers)

    def forward(self, feat_small, feat_big):
        gate = self.main(feat_small)
        return feat_big * gate


class InitLayer(nn.Module):
    """First generator stage: turn a latent vector into a 4x4 feature map.

    A 4x4 transposed convolution produces ``2 * channel`` channels, which batch norm
    and the GLU reduce to ``channel`` channels.
    """

    def __init__(self, nz, channel):
        super().__init__()
        doubled = channel * 2  # GLU halves the channel count again
        self.init = nn.Sequential(
            convTranspose2d(nz, doubled, 4, 1, 0, bias=False),
            batchNorm2d(doubled),
            GLU(),
        )

    def forward(self, noise):
        # Flatten whatever latent shape was passed into (batch, nz, 1, 1).
        latent_map = noise.view(noise.shape[0], -1, 1, 1)
        return self.init(latent_map)


def UpBlock(in_planes, out_planes):
    """Return a x2 upsampling stage: nearest upsample -> 3x3 conv -> batch norm -> GLU.

    The convolution outputs ``2 * out_planes`` channels, which the GLU halves to ``out_planes``.
    """
    doubled = out_planes * 2
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, doubled, 3, 1, 1, bias=False),
        batchNorm2d(doubled),
        GLU(),
    )


def UpBlockComp(in_planes, out_planes):
    """Return the heavier x2 upsampling stage with two conv/noise/norm/GLU groups.

    After the nearest-neighbour upsample, each group is: 3x3 conv to ``2 * out_planes``
    channels, noise injection, batch norm, GLU (back to ``out_planes`` channels).
    """
    doubled = out_planes * 2
    stages = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        # first group: in_planes -> out_planes
        conv2d(in_planes, doubled, 3, 1, 1, bias=False),
        NoiseInjection(),
        batchNorm2d(doubled),
        GLU(),
        # second group: out_planes -> out_planes
        conv2d(out_planes, doubled, 3, 1, 1, bias=False),
        NoiseInjection(),
        batchNorm2d(doubled),
        GLU(),
    ]
    return nn.Sequential(*stages)


class Generator(nn.Module):
    """FastGAN-style generator producing ``im_size`` x ``im_size`` images.

    Args:
        ngf: base channel width; per-resolution widths are multiples of it.
        nz: size of the latent vector.
        nc: number of output image channels.
        im_size: output resolution; one of 256, 512 or 1024.

    Previously the output convolution was sized for ``im_size`` while ``forward`` always
    stopped at 256x256, so every ``im_size`` other than 256 (including the default 1024)
    failed with a channel mismatch. The 512 and 1024 stages are now built and run when
    requested. For ``im_size=256`` the set of sub-modules, their names and their creation
    order are unchanged, so existing checkpoints keep loading.

    Raises:
        ValueError: if ``im_size`` is not a supported resolution.
    """

    SUPPORTED_SIZES = (256, 512, 1024)

    def __init__(self, ngf=64, nz=100, nc=3, im_size=1024):
        super(Generator, self).__init__()

        if im_size not in self.SUPPORTED_SIZES:
            raise ValueError(
                f"im_size must be one of {self.SUPPORTED_SIZES}, got {im_size!r}")

        # resolution -> channel multiplier (relative to ngf)
        nfc_multi = {4: 16, 8: 8, 16: 4, 32: 2, 64: 2, 128: 1, 256: 0.5, 512: 0.25, 1024: 0.125}
        nfc = {res: int(mult * ngf) for res, mult in nfc_multi.items()}

        self.im_size = im_size

        self.init = InitLayer(nz, channel=nfc[4])

        self.feat_8 = UpBlockComp(nfc[4], nfc[8])
        self.feat_16 = UpBlock(nfc[8], nfc[16])
        self.feat_32 = UpBlockComp(nfc[16], nfc[32])
        self.feat_64 = UpBlock(nfc[32], nfc[64])
        self.feat_128 = UpBlockComp(nfc[64], nfc[128])
        self.feat_256 = UpBlock(nfc[128], nfc[256])

        # Skip-layer excitation: low-resolution maps gate high-resolution ones.
        self.se_64 = SEBlock(nfc[4], nfc[64])
        self.se_128 = SEBlock(nfc[8], nfc[128])
        self.se_256 = SEBlock(nfc[16], nfc[256])

        # Extra stages, only present for outputs larger than 256x256.
        if im_size > 256:
            self.feat_512 = UpBlockComp(nfc[256], nfc[512])
            self.se_512 = SEBlock(nfc[32], nfc[512])
        if im_size > 512:
            self.feat_1024 = UpBlock(nfc[512], nfc[1024])

        self.to_big = conv2d(nfc[im_size], nc, 3, 1, 1, bias=False)

        self.apply(weights_init)

    def forward(self, x):
        """Map a latent batch ``x`` (any shape that flattens to ``(batch, nz)``) to images."""
        feat_4 = self.init(x)
        feat_8 = self.feat_8(feat_4)
        feat_16 = self.feat_16(feat_8)
        feat_32 = self.feat_32(feat_16)

        feat_64 = self.se_64(feat_4, self.feat_64(feat_32))

        feat_128 = self.se_128(feat_8, self.feat_128(feat_64))

        feat_last = self.se_256(feat_16, self.feat_256(feat_128))

        if self.im_size >= 512:
            feat_last = self.se_512(feat_32, self.feat_512(feat_last))
        if self.im_size >= 1024:
            feat_last = self.feat_1024(feat_last)

        return self.to_big(feat_last)

In [None]:
# ---- Training configuration ----
# Resolution of the generated images (also used to probe the feature network).
image_size = 256
# Folder of a previous run to resume from; leave empty to start from scratch.
# When set, its last path component must be the epoch number (it is parsed with int()).
loadCheckpoint = ""
# Root folder where per-epoch checkpoints and sample images are written.
checkpoint_path = "afhq-cat2/"
# Adam settings of the generator optimizer.
lr = 0.0002
beta1 = 0.0
beta2 = 0.999
# Pretrained weights of the EfficientNet-Lite feature network.
checkpoint_efficient_net = "efficientnet_lite1.pth"
# Size of the generator's latent vector.
latent_dim = 100
# Number of epochs to run (counted from the starting epoch when resuming).
epochs = 50
# Whether to apply DiffAugment to real and generated images.
diff_aug = True
# NOTE: the training dataloader was already built in an earlier cell; re-assigning
# batch_size here does not change it.
batch_size = 16
# Save samples/checkpoints and log the losses every `log_every` iterations.
log_every = 100

%ls

In [None]:
def get_feature_channels():
    """Return the channel count of every feature map produced by ``efficient_net``.

    A random probe image of size ``img_size`` x ``img_size`` is pushed through the global
    ``efficient_net``. The probe is moved to the device/dtype of the network's parameters
    so the call also works once the network lives on the GPU, and the forward pass runs
    without autograd because only the shapes are needed.
    """
    reference = next(efficient_net.parameters())
    # Sample on CPU first (keeps the CPU random stream consumption unchanged), then move.
    sample = torch.randn(1, 3, img_size, img_size).to(device=reference.device, dtype=reference.dtype)
    with torch.no_grad():
        _, features = efficient_net(sample)
    return [f.shape[1] for f in features]
    
def csm_forward(features):
    """Run the global ``csms`` chain over ``features`` and return the reversed feature list.

    ``features`` is reversed so the chain starts at its last entry: the first CSM gets that
    entry alone, every following CSM gets the next entry plus the previous CSM output.

    NOTE(review): the CSM outputs are collected in ``csm_features`` but the function
    returns ``features`` (the reversed *inputs*), so the mixing currently has no effect on
    what the discriminators see. Returning ``csm_features`` instead would presumably be the
    intent, but the discriminators are built for the raw feature channel counts, whereas
    each CSM outputs ``conv3_out_channels`` channels — both would have to change together.
    """
    features = features[::-1]
    csm_features = []
    mixed = None
    for position, csm in enumerate(csms):
        mixed = csm(features[position]) if position == 0 else csm(features[position], mixed)
        csm_features.append(mixed)
    return features

In [None]:
img_size = image_size
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Generator ----
gen = idist.auto_model(Generator(im_size=image_size))

summary(gen, (latent_dim, 1, 1))

if loadCheckpoint:
    # map_location keeps resuming possible when the checkpoint was written on another device.
    gen.load_state_dict(torch.load(os.path.join(loadCheckpoint,"Generator.pth"), map_location=device))
    gen.train()

gen_optim = Adam(gen.parameters(), lr=lr, betas=(beta1, beta2))

# ---- Feature network (pretrained EfficientNet-Lite1) ----
efficient_net = build_efficientnet_lite("efficientnet_lite1", 1000)
# load_checkpoint prefixes every key with "module.", hence the DataParallel wrapper.
efficient_net = nn.DataParallel(efficient_net)
checkpoint = torch.load(checkpoint_efficient_net, map_location="cpu")
load_checkpoint(efficient_net, checkpoint)
print("carico il checkpoint di efficientnet")

# On a CUDA runtime DataParallel refuses to run while its module is still on the CPU,
# so the network has to be moved before it is probed below.
efficient_net.to(device)
efficient_net.eval()
# No optimizer ever updates the feature network: freeze it so no parameter gradients are
# computed or stored for it (gradients still flow through it to the generator).
for param in efficient_net.parameters():
    param.requires_grad_(False)

feature_sizes = get_feature_channels()

#summary(gen,(latent_dim, 1, 1))

# ---- Cross-scale mixing blocks, from the deepest feature map to the shallowest ----
csms = nn.ModuleList([
    CSM(feature_sizes[3], feature_sizes[2]),
    CSM(feature_sizes[2], feature_sizes[1]),
    CSM(feature_sizes[1], feature_sizes[0]),
    CSM(feature_sizes[0], feature_sizes[0]),
])


if loadCheckpoint:
    # The training loop saves one file per CSM, so restore all of them
    # (the previous `len(csms) - 1` bound skipped the last one).
    for i in range(len(csms)):
        csms[i].load_state_dict(torch.load(os.path.join(loadCheckpoint,f"CSM_{i}.pth"), map_location=device))
        csms[i].train()


# ---- One discriminator per feature scale, ordered like the reversed feature list ----
discs = idist.auto_model(nn.ModuleList([
    MultiScaleDiscriminator(feature_sizes[0], 1),
    MultiScaleDiscriminator(feature_sizes[1], 2),
    MultiScaleDiscriminator(feature_sizes[2], 3),
    MultiScaleDiscriminator(feature_sizes[3], 4),
][::-1]))

# (An additional stand-alone discriminator used to be created here; it was never used and
# its name was immediately re-bound by the training loop, so it has been dropped.)

if loadCheckpoint:
    # Same off-by-one as for the CSMs: restore every discriminator that was saved.
    for i in range(len(discs)):
        discs[i].load_state_dict(torch.load(os.path.join(loadCheckpoint,f"Discriminator_{i}.pth"), map_location=device))
        discs[i].train()

if loadCheckpoint:
    print(f"Checkpoint at : {loadCheckpoint} successfully loaded")

# ---- Differentiable augmentation applied to real and generated images ----
augmentations = 'color,translation,cutout'

DiffAug = DiffAugment(augmentations)






In [None]:


device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {device}")

gen.to(device)
efficient_net.to(device)

for disc in discs:
    disc.to(device)
for csm in csms:
    csm.to(device)

# Loss history, stored as plain Python floats (no tensors, so no autograd graph is kept alive).
disc_losses = []
gen_losses = []

# makedirs also creates a missing `checkpoint_path` and tolerates an existing folder.
ckpts_outputs_path = os.path.join(checkpoint_path, "ckpts_outputs")
os.makedirs(ckpts_outputs_path, exist_ok=True)

if loadCheckpoint:
    # The checkpoint folder is named after the epoch it was written in.
    starting_epoch = int(os.path.basename(loadCheckpoint)) + 1
    logging.info(f"Resuming from epoch {starting_epoch}")
else:
    starting_epoch = 0

for epoch in range(starting_epoch, epochs + starting_epoch):
    logging.info(f"Starting epoch {epoch}")

    for i, (real_imgs, _) in enumerate(train_dataloader):
        real_imgs = real_imgs.to(device)

        # ------------------------------------------------------------------
        # Discriminator step
        # ------------------------------------------------------------------
        z = torch.randn(real_imgs.shape[0], latent_dim)
        z = z.to(device)

        # Nothing upstream of the discriminators is trained in this step, so the images
        # and their features are computed without building an autograd graph.
        with torch.no_grad():
            gen_imgs_disc = gen(z)
            batch_gen_imgs = torch.clone(gen_imgs_disc)  # un-augmented copy for the snapshots

            if diff_aug:
                gen_imgs_disc = DiffAug.forward(gen_imgs_disc)
                real_imgs = DiffAug.forward(real_imgs)

            _, features_fake = efficient_net(gen_imgs_disc)
            _, features_real = efficient_net(real_imgs)

            features_real = csm_forward(features_real)
            features_fake = csm_forward(features_fake)

        batch_disc_losses = []
        for feature_real, feature_fake, disc in zip(features_real, features_fake, discs):
            # Clears every parameter gradient of this discriminator, including the ones
            # accumulated by the previous generator step.
            disc.zero_grad()
            y_hat_real = disc(feature_real).sum(1)  # sum along the channel axis
            y_hat_fake = disc(feature_fake).sum(1)
            # Hinge loss
            loss_real = torch.mean(F.relu(1. - y_hat_real))
            loss_fake = torch.mean(F.relu(1. + y_hat_fake))
            disc_loss = loss_real + loss_fake
            disc_loss.backward()
            disc.optim.step()
            batch_disc_losses.append(disc_loss.item())
        disc_losses.extend(batch_disc_losses)

        # ------------------------------------------------------------------
        # Generator step (fresh latent batch)
        # ------------------------------------------------------------------
        z = torch.randn(real_imgs.shape[0], latent_dim)
        z = z.to(device)

        gen_imgs_gen = gen(z)

        if diff_aug:
            gen_imgs_gen = DiffAug.forward(gen_imgs_gen)

        _, features_fake = efficient_net(gen_imgs_gen)
        features_fake = csm_forward(features_fake)

        gen_optim.zero_grad()
        # Sum the adversarial loss over *all* discriminators; assigning inside the loop
        # used to keep only the contribution of the last one.
        gen_loss = 0.
        for feature_fake, disc in zip(features_fake, discs):
            y_hat = disc(feature_fake).sum(1)
            gen_loss = gen_loss - torch.mean(y_hat)
        gen_loss.backward()
        gen_optim.step()
        gen_losses.append(gen_loss.item())

        if i % log_every == 0:
            path = os.path.join(checkpoint_path, str(epoch))
            os.makedirs(path, exist_ok=True)
            with torch.no_grad():
                # NOTE(review): the (x + 1) / 2 mapping assumes generator outputs in [-1, 1],
                # while the training images are normalised with per-channel mean/std — confirm.
                vutils.save_image(batch_gen_imgs.add(1).mul(0.5), os.path.join(ckpts_outputs_path, f'{epoch}_{i}.jpg'), nrow=4)
            # Report this iteration's mean discriminator loss instead of the whole history.
            mean_disc_loss = sum(batch_disc_losses) / max(len(batch_disc_losses), 1)
            logging.info(f"Iteration {i}: Gen Loss = {gen_losses[-1]:.4f}, Disc Loss = {mean_disc_loss:.4f}.")
            torch.save(gen.state_dict(), os.path.join(path, "Generator.pth"))
            for j in range(len(discs)):
                torch.save(discs[j].state_dict(), os.path.join(path, f"Discriminator_{j}.pth"))
                torch.save(csms[j].state_dict(), os.path.join(path, f"CSM_{j}.pth"))


Image generation: load the generator weights from `weights_path` and generate `n_images` images of size 256x256.

In [None]:
# ---- Generation settings ----
image_size=256
weights_path="afhq-dog/44/Generator.pth"
verbose=True
out_dir="./afhq-dog-generated"
n_images=100
latent_dim=100
# 0: save individual images, 1: save grids only, 2: save both
mode=0
# Grids contain grid_size x grid_size images
grid_size=8

device = "cuda" if torch.cuda.is_available() else "cpu"

print("Using device : "+ device)

# Loading state dict (mapped onto the current device so GPU checkpoints load on CPU too)
gen = Generator(im_size=image_size)
gen.load_state_dict(torch.load(weights_path, map_location=device))
gen = gen.to(device)
# NOTE(review): the generator is left in training mode, as before, so batch norm uses the
# statistics of each single-image batch — confirm whether gen.eval() is wanted here.
if verbose:
    print("Successfully loaded generator's state dict at : " + weights_path)
# Creating the output dir (including missing parents)
if not os.path.exists(out_dir):
    if verbose:
        print("Creating the output dir : ",out_dir)
    os.makedirs(out_dir, exist_ok=True)

# Images waiting to be written as a grid / every individually saved image
gen_imgs_list=[]
fake_imgs=[]

# Pure inference: without no_grad every stored image would keep its autograd graph alive.
with torch.no_grad():
    for i in tqdm(range(1,n_images + 1), position=0, leave=True):

        noise = torch.randn(latent_dim,1,1,device=device)
        # Generating a single image
        gen_img = gen(noise.unsqueeze(0))

        # Saving the individual image
        if mode in [0,2]:
            img_path = os.path.join(out_dir,f"img-{i}.jpg")
            save_image(gen_img, img_path)
            fake_imgs.append(gen_img)

        # Saving grid of images (N x N)
        if mode in [1,2]:
            gen_imgs_list.append(gen_img.detach().cpu().squeeze(0))
            if len(gen_imgs_list) == grid_size ** 2:
                grid_img_path=os.path.join(out_dir,f"grid-imgs-{(i - grid_size ** 2 ) + 1}-{i}.jpg")
                save_image( gen_imgs_list, grid_img_path,nrow=grid_size)
                gen_imgs_list = []

# If there are remaining images in the grid array
if len(gen_imgs_list) and mode in [1,2]:
    # First index of the remainder, numbered like the full grids above (hence the + 1).
    first_remaining = i - len(gen_imgs_list) + 1
    save_image(gen_imgs_list ,os.path.join(out_dir,f"grid-imgs-{first_remaining}-{i}-remaining_batch.jpg"),nrow=grid_size)

print("Done.")


In [None]:
# Show the first generated image. With mode=0 the images live in `fake_imgs` while
# `gen_imgs_list` stays empty (and the reverse for mode=1), so use whichever list was filled.
generated_samples = fake_imgs or gen_imgs_list
if not generated_samples:
    raise RuntimeError("No generated images available: run the generation cell first.")

fake_batch = generated_samples[0].detach().cpu()
plt.figure(figsize=(2,2))
plt.axis("off")
plt.title("Fake Images")
# make_grid returns channels-first data; matplotlib wants channels last.
plt.imshow(np.transpose(vutils.make_grid(fake_batch, padding=2, normalize=True),(1,2,0)))
plt.show()

Install the pytorch-fid package

In [None]:
!pip install pytorch-fid

In [None]:
# NOTE(review): both directories passed below are the same training folder, so this compares
# the dataset with itself. Presumably the second path should be the folder holding the
# generated images — confirm before reporting the score.
!python -m pytorch_fid train/cat-train/cat/ train/cat-train/cat/ --device cpu