Author: Anurag Vaidya  
Date: 2/18/2022  
Lab: Polina Lab @ CSAIL  
Purpose: Create an encoder-generator (StyleGAN) connected to a classifier (clf), in an attempt to recreate the method from [StylEx](https://arxiv.org/pdf/2104.13369.pdf).

## Notebook Structure
- Imports
- Args
- Dataset
- Model
- Training/ Val loop
- main()

---

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn 
import torch.nn.functional as Fun
import torchvision.transforms.functional as F
from torch.utils.data import Dataset
from torchvision import models
import torch.optim as optim
from torch.optim import lr_scheduler
from torch import autograd

import os
import sys
import random
from argparse import Namespace
import time, copy
import math

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

import wandb
# wandb.init(project="cat-dog-styleSpace", entity="stylespace")


sys.path.append("./")
sys.path.append("../")

---

#### Args

In [None]:
# Global experiment configuration shared by the dataset, model and trainer cells.
# NOTE(review): "cuda:1" is hard-coded; on a single-GPU host this index is not valid.
# NOTE(review): wandb_config repeats lr/batch_size and lists epochs=2 while epochs=50 — confirm which is intended.
args = Namespace(device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu"),
                 train_dir = "../data/afhq/train",
                 val_dir = "../data/afhq/val",
                 save_path = "./checkpoints",
                 log_dir = "./",
                 seed = 7,
                 labels = ["cat", "dog"],
                 batch_size = 64,
                 test_batch_size = 64,
                 epochs = 50,
                 num_workers = 0,
                 class_names = {0:"cat", 1:"dog"} ,
                 lr = 0.0001,
                 lr_d = 0.0004,
                 momentum = 0.9,
                 criterion = nn.CrossEntropyLoss(),
                 optim_name = "ranger",
                 scheduler = "STEP",
                 scheduler_step_size = 7,
                 scheduler_gamma = 0.1,
                 exp_name = "stylespace1",
                 wandb_config = {"learning_rate": 0.0001, "epochs": 2, "batch_size": 64},
                 use_wandb = True,
                 output_size = 512,
                 encoder_type = "gradual",
                 n_styles = 0, # overwritten by net.__init__ from output_size
                 lambdas = {"adv":1, "reg":1, "rec_x":1, "rec_w":1, "lpips":1, "clf":1}, # per-loss weights
                 train_decoder = True, # whether to train the decoder
                 dataset_type = "afhq",
                 max_steps = 50000, # max number of training steps
                 save_interval = 5000, # checkpoint saving interval (steps)
                 start_from_latent_avg = False, # whether to add the average latent vector to the codes produced by the encoder
                 learn_in_w = True, # whether to learn in w space instead of w+
                 img_size = 512, # image size for the model
                 channel_multiplier = 2, # channel multiplier factor for the model. config-f = 2, else = 1
                 # Keys below are read elsewhere in this notebook (GradualStyleEncoder reads
                 # input_nc; the training loop reads the three intervals) but were missing here.
                 # Values follow the pSp defaults — TODO tune for this experiment.
                 input_nc = 3, # number of input image channels for the encoder
                 image_interval = 100, # steps between image logging
                 board_interval = 50, # steps between metric logging
                 val_interval = 1000, # steps between validation runs
    )

---

#### modex class: contains the model (sub-models are encoder, decoder, discriminator), optim, losses, train, val and util methods

In [None]:
from torch.utils.tensorboard import SummaryWriter

class modex:
	r"""
	Training harness: owns the network (encoder, decoder, discriminator), the loss
	modules, both optimizers, the data loaders and the training loop.

	NOTE(review): ``train`` calls ``parse_and_log_images``, ``print_metrics``,
	``log_metrics``, ``validate`` and ``checkpoint_me``; none of them is defined in this
	class as shown — confirm they are provided elsewhere before running.
	"""

	def __init__(self, args):
		r"""
		Build the network, losses, optimizers, data loaders and log/checkpoint directories.

		Args:
			args: configuration namespace; every setting is read from it.
		"""
		self.args = args

		self.global_step = 0

		# TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES (use distributed module)
		self.device = self.args.device

		if self.args.use_wandb:
			self.wb_logger = WBLogger(self.args)

		# Initialize network
		self.net = net(self.args).to(self.device)

		# Estimate latent_avg via dense sampling if latent_avg is not available
		if self.net.latent_avg is None:
			self.net.latent_avg = self.net.decoder.mean_latent(int(1e5))[0].detach()

		# Initialize losses; a loss module is only created when its weight is positive
		# adv loss
		if self.args.lambdas["adv"] > 0:
			self.adv_loss = adv_loss().to(self.device)
		# path regularization
		if self.args.lambdas["reg"] > 0:
			self.reg_loss = path_reg_loss()
		# rec_x
		if self.args.lambdas["rec_x"] > 0:
			self.rec_x_loss = nn.L1Loss().to(self.device)
		# lpips
		if self.args.lambdas["lpips"] > 0:
			self.lpips_loss = LPIPS(net_type='alex').to(self.device)
		# rec_w
		if self.args.lambdas["rec_w"] > 0:
			self.rec_w_loss = nn.L1Loss().to(self.device)
		# clf: holds a classifier network, so it has to live on the training device too
		if self.args.lambdas["clf"] > 0:
			self.clf_loss = clf_loss(self.args).to(self.device)

		# Initialize optimizer
		self.optimizer_g, self.optimizer_d = self.configure_optimizers()

		# Initialize dataset. The config exposes a single `num_workers` value, used for
		# both loaders (there is no `self.opts` object on this class).
		self.train_dataset, self.test_dataset = self.configure_datasets()
		self.train_dataloader = DataLoader(self.train_dataset,
											batch_size=self.args.batch_size,
											shuffle=True,
											num_workers=int(self.args.num_workers),
											drop_last=True)
		self.test_dataloader = DataLoader(self.test_dataset,
											batch_size=self.args.test_batch_size,
											shuffle=False,
											num_workers=int(self.args.num_workers),
											drop_last=True)

		# Initialize logger
		log_dir = os.path.join(self.args.log_dir, 'logs')
		os.makedirs(log_dir, exist_ok=True)
		self.logger = SummaryWriter(log_dir=log_dir)

		# Initialize checkpoint dir
		self.checkpoint_dir = os.path.join(self.args.log_dir, 'checkpoints', 'cat_dog_styleGAN')
		os.makedirs(self.checkpoint_dir, exist_ok=True)
		self.best_val_loss = None
		if self.args.save_interval is None:
			self.args.save_interval = self.args.max_steps

	def configure_optimizers(self):
		r"""
		Create the generator-side optimizer (encoder, plus decoder when
		``args.train_decoder`` is set) and the discriminator optimizer.

		Returns:
			(g_optimizer, d_optimizer)
		"""
		# encoder + decoder optim
		params_g = list(self.net.encoder.parameters())
		if self.args.train_decoder:
			params_g += list(self.net.decoder.parameters())
		if self.args.optim_name == 'adam':
			g_optimizer = torch.optim.Adam(params_g, lr=self.args.lr)
		else:
			g_optimizer = Ranger(params_g, lr=self.args.lr)

		# discriminator optim
		params_d = self.net.discriminator.parameters()
		d_optimizer = optim.Adam(params_d, lr=self.args.lr_d)

		return g_optimizer, d_optimizer

	def configure_datasets(self):
		r"""
		Build the train and validation datasets described by ``DATASETS[args.dataset_type]``.

		Returns:
			(train_dataset, val_dataset)
		"""
		print(f'Loading dataset for {self.args.dataset_type}')
		dataset_args = DATASETS[self.args.dataset_type]
		transforms_dict = dataset_args['transforms'](self.args).get_transforms()
		train_dataset = afhq_dataset(dataset_args["train_dir"],
									 dataset_args["seed"],
									 dataset_args["labels"],
									 transforms_dict["transform_train"])

		val_dataset = afhq_dataset(dataset_args["val_dir"],
									 dataset_args["seed"],
									 dataset_args["labels"],
									 transforms_dict["transform_val"])
		if self.args.use_wandb:
			self.wb_logger.log_dataset_wandb(train_dataset, dataset_name="Train")
			self.wb_logger.log_dataset_wandb(val_dataset, dataset_name="Val")
		print(f"Number of training samples: {len(train_dataset)}")
		print(f"Number of test samples: {len(val_dataset)}")
		return train_dataset, val_dataset

	def requires_grad(self, model, flag=True):
		r"""
		Set ``requires_grad`` to ``flag`` on every parameter of ``model``.
		"""
		for p in model.parameters():
			p.requires_grad = flag

	def train(self):
		r"""
		Run the training loop until ``global_step`` reaches ``args.max_steps``.

		Each iteration performs one discriminator update followed by one
		encoder/generator update, then handles logging, validation and checkpointing.
		"""
		self.net.train()
		mean_path_length = 0
		while self.global_step < self.args.max_steps:

			for batch_idx, batch in enumerate(self.train_dataloader):

				x, y = batch["inputs"], batch["labels"]
				x, y = x.to(self.device).float(), y.to(self.device).float()

				self.optimizer_g.zero_grad()
				self.optimizer_d.zero_grad()

				########## Discriminator training ##########
				self.requires_grad(self.net.decoder, False)
				self.requires_grad(self.net.discriminator, True)

				# Fake images come from the full encoder -> decoder pass: the decoder itself
				# takes style codes, not images. No graph is needed for the discriminator update.
				with torch.no_grad():
					y_hat = self.net.forward(x)

				# get discriminator outputs
				fake_pred = self.net.discriminator(y_hat)
				real_pred = self.net.discriminator(x)

				# Only the adversarial term is evaluated here, so the latent/encoding
				# arguments are not needed yet and are passed as None.
				discriminator_loss, _, _ = self.calc_loss(x, y, y_hat, None, fake_pred, real_pred,
														  None, None, mean_path_length, loss_type=["adv"])

				# discriminator updates
				self.net.discriminator.zero_grad()
				discriminator_loss.backward()
				self.optimizer_d.step()

				########## Generator + encoder training ##########
				self.requires_grad(self.net.decoder, True)
				self.requires_grad(self.net.discriminator, False)

				# passing through encoder and generator (style gan)
				y_hat, latent = self.net.forward(x, return_latents=True)

				# get encodings
				# NOTE(review): both encodings are computed under no_grad, so the rec_w term
				# cannot back-propagate into the encoder/decoder — confirm this is intended.
				self.net.encoder.eval()
				with torch.no_grad():
					w_fake = self.net.get_encodings(y_hat)
					w_real = self.net.get_encodings(x)
				self.net.encoder.train()

				# calculate losses
				# NOTE(review): "adv" is not part of the generator objective here — confirm.
				which_loss = ["reg", "rec_x", "lpips", "rec_w", "clf"]
				loss, loss_dict, mean_path_length = self.calc_loss(x, y, y_hat, latent, fake_pred, real_pred,
												 w_fake, w_real, mean_path_length, loss_type=which_loss)

				# Clear stale gradients BEFORE backward; zeroing between backward() and
				# step() would make the optimizer step on all-zero gradients.
				self.net.zero_grad()
				loss.backward()
				self.optimizer_g.step()

				# No identity-loss logs are produced in this setup; the logging helpers
				# still take the argument, so pass None.
				id_logs = None

				# Logging related
				if self.global_step % self.args.image_interval == 0 or (self.global_step < 1000 and self.global_step % 25 == 0):
					self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces')
				if self.global_step % self.args.board_interval == 0:
					self.print_metrics(loss_dict, prefix='train')
					self.log_metrics(loss_dict, prefix='train')

				# Log images of first batch to wandb
				if self.args.use_wandb and batch_idx == 0:
					self.wb_logger.log_images_to_wandb(x, y, y_hat, id_logs, prefix="train", step=self.global_step, opts=self.args)

				# Validation related
				val_loss_dict = None
				if self.global_step % self.args.val_interval == 0 or self.global_step == self.args.max_steps:
					val_loss_dict = self.validate()
					if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
						self.best_val_loss = val_loss_dict['loss']
						self.checkpoint_me(val_loss_dict, is_best=True)

				if self.global_step % self.args.save_interval == 0 or self.global_step == self.args.max_steps:
					if val_loss_dict is not None:
						self.checkpoint_me(val_loss_dict, is_best=False)
					else:
						self.checkpoint_me(loss_dict, is_best=False)

				if self.global_step == self.args.max_steps:
					print('OMG, finished training!')
					break

				self.global_step += 1

	def calc_loss(self, x, y, y_hat, latent, fake_pred, real_pred, w_fake, w_real, mean_path_length, loss_type):
		r"""
		Compute the weighted sum of the requested loss terms.

		Args:
			x: real images.
			y: labels (currently unused by every term).
			y_hat: generated images.
			latent: latents returned by the decoder (used by "reg").
			fake_pred, real_pred: discriminator outputs (used by "adv").
			w_fake, w_real: encodings of ``y_hat`` and ``x`` (used by "rec_w").
			mean_path_length: running path-length mean (updated by "reg").
			loss_type: list of names from
				["adv", "reg", "rec_x", "lpips", "rec_w", "clf"].

		Returns:
			(loss, loss_dict, mean_path_length); ``loss_dict['loss']`` holds the total
			as a Python float.

		Raises:
			ValueError: if ``loss_type`` contains an unknown name.
		"""
		loss_dict = {}
		loss = 0.0
		types = ["adv", "reg", "rec_x", "lpips", "rec_w", "clf"]

		for curr_loss_name in loss_type:
			# Validate each requested name (the list itself is never a member of `types`).
			if curr_loss_name not in types:
				raise ValueError(f"Invalid loss name: {curr_loss_name!r}")
			# __init__ only builds a loss module when its weight is positive, so a
			# non-positive weight means the term is switched off.
			weight = self.args.lambdas[curr_loss_name]
			if weight <= 0:
				continue

			# adversarial loss
			if curr_loss_name == "adv":
				loss_adv = self.adv_loss(real_pred, fake_pred)
				loss_dict["adv_loss"] = loss_adv
				loss += weight * loss_adv
			# path regularization
			elif curr_loss_name == "reg":
				loss_reg, mean_path_length, _ = self.reg_loss(y_hat, latent, mean_path_length)
				loss_dict["reg"] = loss_reg
				loss += weight * loss_reg
			# rec_x
			elif curr_loss_name == "rec_x":
				loss_rec_x = self.rec_x_loss(x, y_hat)
				loss_dict["rec_x"] = loss_rec_x
				loss += weight * loss_rec_x
			# lpips
			elif curr_loss_name == "lpips":
				loss_lpips = self.lpips_loss(x, y_hat)
				loss_dict["lpips"] = loss_lpips
				loss += weight * loss_lpips
			# rec_w
			elif curr_loss_name == "rec_w":
				loss_rec_w = self.rec_w_loss(w_fake, w_real)
				loss_dict["rec_w"] = loss_rec_w
				loss += weight * loss_rec_w
			# clf
			elif curr_loss_name == "clf":
				loss_clf = self.clf_loss(x, y_hat)
				loss_dict["clf"] = loss_clf
				loss += weight * loss_clf

		loss_dict['loss'] = float(loss)
		return loss, loss_dict, mean_path_length

---

#### Dataset class

In [None]:
# transform config
from abc import abstractmethod
import torchvision.transforms as transforms


class TransformsConfig:
	"""Base holder for a dataset's transform configuration.

	Stores the options object; subclasses supply the actual pipelines through
	``get_transforms``.
	"""

	def __init__(self, opts):
		# Keep the options around for subclasses.
		self.opts = opts

	@abstractmethod
	def get_transforms(self):
		"""Return the transform pipelines; the base implementation returns None."""
		return None


class afhq_Transforms(TransformsConfig):
	"""Train / validation / inference transform pipelines for the AFHQ images."""

	def __init__(self, args):
		super().__init__(args)

	def get_transforms(self):
		"""Return a dict with the 'transform_train', 'transform_val' and
		'transform_inference' pipelines."""

		def pipeline(*leading):
			# Every pipeline ends with tensor conversion followed by normalisation
			# with mean 0.5 and std 0.5 on each of the three channels.
			return transforms.Compose([
				*leading,
				transforms.ToTensor(),
				transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

		transform_train = pipeline(transforms.Resize(512),
								   transforms.RandomHorizontalFlip(0.5))
		transform_val = pipeline(transforms.Resize(512))
		# NOTE(review): inference resizes to 256x256 while train/val use 512 — confirm.
		transform_inference = pipeline(transforms.Resize((256, 256)))

		return {
			'transform_train': transform_train,
			'transform_val': transform_val,
			'transform_inference': transform_inference,
		}


In [None]:
# data config
# Registry of dataset configurations, keyed by dataset type.
# Each entry provides the transforms class, the train/val directories, the shuffle
# seed and the ordered label names.
DATASETS = {
    "afhq": dict(
        transforms=afhq_Transforms,
        train_dir="../data/afhq/train",
        val_dir="../data/afhq/val",
        seed=69,
        labels=["cat", "dog"],
    ),
}

In [None]:
from torch.utils.data import Dataset
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import sys
sys.path.append("./")
sys.path.append("../")

class afhq_dataset(Dataset):
    r"""
    Dataset over a root directory that contains a ``cat`` and a ``dog`` sub-directory.

    Each item is a dict ``{"inputs": image, "labels": int}``. The class of a file is
    taken from the second ``_``-separated token of its file name, which is also used
    as the sub-directory name.
    """
    def __init__(self, root_dir, seed, labels, img_transform=None):
        r"""
        Args:
            root_dir: directory holding the ``cat`` and ``dog`` sub-directories.
            seed: seed for the (global) NumPy RNG used to shuffle the file list.
            labels: ordered class names; the position of a name is its integer label.
            img_transform: optional callable applied to each PIL image.
        """
        self.seed = seed
        np.random.seed(self.seed)

        # this dir has two sub dirs cat and dog. Need to combine them.
        # os.listdir returns entries in arbitrary order, so sort first; otherwise the
        # seeded shuffle below would not be reproducible across machines.
        self.root_dir = root_dir
        self.cat_names = sorted(os.listdir(os.path.join(self.root_dir, "cat")))
        self.dog_names = sorted(os.listdir(os.path.join(self.root_dir, "dog")))
        self.all_names = np.asarray(self.cat_names + self.dog_names)
        np.random.shuffle(self.all_names)
        self.img_transform = img_transform
        self.labels = {name: index for index, name in enumerate(labels)}

    def __len__(self):
        return len(self.all_names)

    def __getitem__(self, idx):
        r"""
        Load item ``idx``; returns the untransformed RGB image when no transform is set.
        """
        file_name = str(self.all_names[idx])
        class_name = file_name.strip().split("_")[1]
        curr_path = os.path.join(self.root_dir, class_name, file_name)
        curr_label = self.labels[class_name]

        # Close the file handle promptly and force three channels so the
        # 3-channel normalisation downstream always applies.
        with Image.open(curr_path) as handle:
            curr_img = handle.convert("RGB")

        if self.img_transform:
            curr_img = self.img_transform(curr_img)

        return {"inputs" : curr_img, "labels" : curr_label}

    def viz_img(self, imgs):
        r"""
        Take a tensor or list of tensors and visualize it
        """
        # Imported locally: the module-level alias ``F`` is rebound to
        # torch.nn.functional by a later cell of this notebook.
        from torchvision.transforms.functional import to_pil_image

        if not isinstance(imgs, list):
            imgs = [imgs]
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
        for i, img in enumerate(imgs):
            img = to_pil_image(img.detach())
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    def what_labels_mean(self):
        r"""
        Return human-readable ``"name: index"`` strings for every label.
        """
        return [f"{name}: {index}" for name, index in self.labels.items()]

---

#### Ranger optimizer

In [None]:
# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.

# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
# and/or
# https://github.com/lessw2020/Best-Deep-Learning-Optimizers

# Ranger has now been used to capture 12 records on the FastAI leaderboard.

# This version = 20.4.11

# Credits:
# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github:  https://github.com/Yonghongwei/Gradient-Centralization
# RAdam -->  https://github.com/LiyuanLucasLiu/RAdam
# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
# Lookahead paper --> MZhang,G Hinton  https://arxiv.org/abs/1907.08610

# summary of changes:
# 4/11/20 - add gradient centralization option.  Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
# changes 8/31/19 - fix references to *self*.N_sma_threshold;
# changed eps to 1e-5 as better default than 1e-8.

import math
import torch
from torch.optim.optimizer import Optimizer


class Ranger(Optimizer):
	"""Ranger optimizer: RAdam + Lookahead + gradient centralization.

	Args:
		params: iterable of parameters or parameter groups.
		lr: learning rate (must be > 0).
		alpha: lookahead interpolation factor for the slow weights, in [0, 1].
		k: number of fast steps between lookahead synchronisations (>= 1).
		N_sma_threshhold: below this SMA length the un-rectified update is used.
		betas: coefficients of the running averages of the gradient and its square.
		eps: term added to the denominator (must be > 0).
		weight_decay: weight decay coefficient.
		use_gc: stored on the instance; gradient centralization itself is controlled
			by the gradient-rank threshold below.
		gc_conv_only: centralize only gradients with more than 3 dims instead of
			every gradient with more than 1 dim.

	Raises:
		ValueError: on an invalid ``alpha``, ``k``, ``lr`` or ``eps``.
	"""

	def __init__(self, params, lr=1e-3,  # lr
				 alpha=0.5, k=6, N_sma_threshhold=5,  # Ranger options
				 betas=(.95, 0.999), eps=1e-5, weight_decay=0,  # Adam options
				 use_gc=True, gc_conv_only=False
				 # Gradient centralization on or off, applied to conv layers only or conv + fc layers
				 ):

		# parameter checks
		if not 0.0 <= alpha <= 1.0:
			raise ValueError(f'Invalid slow update rate: {alpha}')
		if not 1 <= k:
			raise ValueError(f'Invalid lookahead steps: {k}')
		if not lr > 0:
			raise ValueError(f'Invalid Learning Rate: {lr}')
		if not eps > 0:
			raise ValueError(f'Invalid eps: {eps}')

		# parameter comments:
		# beta1 (momentum) of .95 seems to work better than .90...
		# N_sma_threshold of 5 seems better in testing than 4.
		# In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.

		# prep defaults and init torch.optim base
		defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
						eps=eps, weight_decay=weight_decay)
		super().__init__(params, defaults)

		# adjustable threshold
		self.N_sma_threshhold = N_sma_threshhold

		# look ahead params
		self.alpha = alpha
		self.k = k

		# radam buffer for state: caches (step, N_sma, step_size) for 10 step slots
		self.radam_buffer = [[None, None, None] for ind in range(10)]

		# gc on or off
		self.use_gc = use_gc

		# level of gradient centralization
		self.gc_gradient_threshold = 3 if gc_conv_only else 1

	def __setstate__(self, state):
		super(Ranger, self).__setstate__(state)

	def step(self, closure=None):
		"""Perform a single optimization step.

		Args:
			closure: optional callable that re-evaluates the model and returns the loss.

		Returns:
			The closure's return value, or None when no closure is given.
		"""
		loss = None
		if closure is not None:
			# Honour the standard Optimizer.step contract.
			with torch.enable_grad():
				loss = closure()

		# Evaluate averages and grad, update param tensors
		for group in self.param_groups:

			for p in group['params']:
				if p.grad is None:
					continue
				grad = p.grad.data.float()

				if grad.is_sparse:
					raise RuntimeError('Ranger optimizer does not support sparse gradients')

				p_data_fp32 = p.data.float()

				state = self.state[p]  # get state dict for this param

				if len(state) == 0:  # first time for this param: create the state entries
					state['step'] = 0
					state['exp_avg'] = torch.zeros_like(p_data_fp32)
					state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

					# look ahead weight storage now in state dict
					state['slow_buffer'] = torch.empty_like(p.data)
					state['slow_buffer'].copy_(p.data)

				else:
					state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
					state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

				# begin computations
				exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
				beta1, beta2 = group['betas']

				# GC operation for Conv layers and FC layers: remove the mean over all
				# dims except the first.
				if grad.dim() > self.gc_gradient_threshold:
					grad.sub_(grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))

				state['step'] += 1

				# The scalar multipliers below are passed by keyword (`value=` / `alpha=`);
				# the positional-scalar overloads are deprecated/removed in current PyTorch.
				# compute variance mov avg
				exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
				# compute mean moving avg
				exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

				buffered = self.radam_buffer[int(state['step'] % 10)]

				if state['step'] == buffered[0]:
					N_sma, step_size = buffered[1], buffered[2]
				else:
					buffered[0] = state['step']
					beta2_t = beta2 ** state['step']
					N_sma_max = 2 / (1 - beta2) - 1
					N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
					buffered[1] = N_sma
					if N_sma > self.N_sma_threshhold:
						step_size = math.sqrt(
							(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
										N_sma_max - 2)) / (1 - beta1 ** state['step'])
					else:
						step_size = 1.0 / (1 - beta1 ** state['step'])
					buffered[2] = step_size

				if group['weight_decay'] != 0:
					p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

				# apply lr
				if N_sma > self.N_sma_threshhold:
					denom = exp_avg_sq.sqrt().add_(group['eps'])
					p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
				else:
					p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

				p.data.copy_(p_data_fp32)

				# integrated look ahead...
				# we do it at the param level instead of group level
				if state['step'] % group['k'] == 0:
					slow_p = state['slow_buffer']  # get access to slow param tensor
					slow_p.add_(p.data - slow_p, alpha=self.alpha)  # (fast weights - slow weights) * alpha
					p.data.copy_(slow_p)  # copy interpolated weights to RAdam param tensor

		return loss

---

#### Criteria/ losses

In [None]:
# lpips

from collections import OrderedDict

import torch


def normalize_activation(x, eps=1e-10):
    """Scale ``x`` to unit L2 norm along dim 1; ``eps`` keeps the division finite."""
    l2_norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    denominator = l2_norm + eps
    return x / denominator


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    """Download the LPIPS linear-layer weights and return them with renamed keys.

    The substrings 'lin' and 'model.' are stripped from every key of the downloaded
    state dict. Weights are mapped to the CPU when CUDA is not available.
    """
    weights_url = (
        'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/'
        f'master/lpips/weights/v{version}/{net_type}.pth'
    )

    # Default placement when CUDA is available, CPU otherwise.
    target_device = None if torch.cuda.is_available() else torch.device('cpu')
    downloaded = torch.hub.load_state_dict_from_url(
        weights_url, progress=True, map_location=target_device
    )

    return OrderedDict(
        (name.replace('lin', '').replace('model.', ''), tensor)
        for name, tensor in downloaded.items()
    )

####################################

from typing import Sequence

from itertools import chain

import torch
import torch.nn as nn
from torchvision import models


def get_network(net_type: str):
    """Instantiate the feature network named by ``net_type``.

    Raises:
        NotImplementedError: if ``net_type`` is not 'alex', 'squeeze' or 'vgg'.
    """
    # Lambdas defer construction until the requested type is known.
    builders = {
        'alex': lambda: AlexNet(),
        'squeeze': lambda: SqueezeNet(),
        'vgg': lambda: VGG16(),
    }
    build = builders.get(net_type)
    if build is None:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
    return build()


class LinLayers(nn.ModuleList):
    """One frozen 1x1 convolution (bias-free, single output channel) per entry of
    ``n_channels_list``, each wrapped as ``Sequential(Identity, Conv2d)``."""

    def __init__(self, n_channels_list: Sequence[int]):
        heads = []
        for channels in n_channels_list:
            projection = nn.Conv2d(channels, 1, kernel_size=1, stride=1, padding=0, bias=False)
            heads.append(nn.Sequential(nn.Identity(), projection))
        super().__init__(heads)

        # These weights are never trained.
        for weight in self.parameters():
            weight.requires_grad = False


class BaseNet(nn.Module):
    """Shared base for the LPIPS feature extractors.

    Subclasses must set ``self.layers`` (a module container) and
    ``self.target_layers`` (1-based indices of the layers whose outputs are returned).
    """

    def __init__(self):
        super().__init__()

        # Per-channel statistics used by z_score, stored with shape (1, 3, 1, 1).
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188]).view(1, -1, 1, 1))
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450]).view(1, -1, 1, 1))

    def set_requires_grad(self, state: bool):
        """Set ``requires_grad`` on every parameter and buffer."""
        for tensor in chain(self.parameters(), self.buffers()):
            tensor.requires_grad = state

    def z_score(self, x: torch.Tensor):
        """Standardise ``x`` with the stored mean and std."""
        centered = x - self.mean
        return centered / self.std

    def forward(self, x: torch.Tensor):
        """Return the unit-normalised activations of the target layers, in order."""
        x = self.z_score(x)

        wanted = len(self.target_layers)
        features = []
        for position, layer in enumerate(self.layers._modules.values(), 1):
            x = layer(x)
            if position in self.target_layers:
                features.append(normalize_activation(x))
            # Stop as soon as every target layer has been collected.
            if len(features) == wanted:
                break
        return features


class SqueezeNet(BaseNet):
    """LPIPS backbone built from the feature layers of torchvision's squeezenet1_1."""

    def __init__(self):
        super().__init__()

        # 1-based indices of the tapped layers and their channel counts.
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
        self.layers = models.squeezenet1_1(True).features

        # Nothing in the backbone is trained.
        self.set_requires_grad(False)


class AlexNet(BaseNet):
    """LPIPS backbone built from the feature layers of torchvision's alexnet."""

    def __init__(self):
        super().__init__()

        # 1-based indices of the tapped layers and their channel counts.
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]
        self.layers = models.alexnet(True).features

        # Nothing in the backbone is trained.
        self.set_requires_grad(False)


class VGG16(BaseNet):
    """LPIPS backbone built from the feature layers of torchvision's vgg16."""

    def __init__(self):
        super().__init__()

        # 1-based indices of the tapped layers and their channel counts.
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]
        self.layers = models.vgg16(True).features

        # Nothing in the backbone is trained.
        self.set_requires_grad(False)


##############################
import torch
import torch.nn as nn

class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).
    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.

    The module is built on the default device; move it with ``.to(device)`` like any
    other ``nn.Module``.
    """
    def __init__(self, net_type: str = 'alex', version: str = '0.1'):

        assert version in ['0.1'], 'v0.1 is only supported now'

        super(LPIPS, self).__init__()

        # pretrained network. Not pinned to "cuda" here: that crashed on CPU-only
        # hosts, and callers already place the whole module with .to(device).
        self.net = get_network(net_type)

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list)
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        r"""Return the LPIPS distance between ``x`` and ``y``, summed over layers and
        divided by the batch size of ``x``."""
        feat_x, feat_y = self.net(x), self.net(y)

        # Squared feature differences, projected by the linear heads and averaged
        # over the spatial dims.
        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]

############################
# adv loss

class adv_loss(nn.Module):
    """Discriminator loss: mean softplus(-real_pred) + mean softplus(fake_pred)."""

    def __init__(self):
        super(adv_loss, self).__init__()

    def forward(self, real_pred: torch.Tensor, fake_pred: torch.Tensor):
        """Return the scalar loss for the given discriminator outputs."""
        # Call softplus through nn.functional explicitly: the notebook-level alias
        # ``F`` is bound to torchvision.transforms.functional in earlier cells,
        # which has no softplus.
        real_loss = nn.functional.softplus(-real_pred)
        fake_loss = nn.functional.softplus(fake_pred)

        return real_loss.mean() + fake_loss.mean()


#############################
# clf loss

class clf_loss(nn.Module):
    """KL divergence between the classifier's predictions on real and generated images.

    Args:
        args: configuration namespace; ``args.device`` is where the module is placed.
        num_classes: number of outputs of the classifier head.
        network: currently unused; a ResNet-18 is always built.
        path_to_weights: checkpoint holding a ``'model_state_dict'`` entry.
    """

    def __init__(self, args, num_classes = 2, network = "resnet", path_to_weights = "./checkpoint/checkpoint_2.pt"):

        super(clf_loss, self).__init__()

        self.model_ft = models.resnet18(pretrained=False)
        self.num_ftrs = self.model_ft.fc.in_features
        self.model_ft.fc = nn.Linear(self.num_ftrs, num_classes)

        # Load onto the CPU first so a checkpoint saved on a GPU can be read on any host.
        checkpoint = torch.load(path_to_weights, map_location="cpu")
        self.model_ft.load_state_dict(checkpoint['model_state_dict'])

        # The classifier is a fixed reference model: inference mode, no weight grads.
        # Gradients still flow through it to the generated image.
        self.model_ft.eval()
        for param in self.model_ft.parameters():
            param.requires_grad = False

        # 'batchmean' is the reduction that matches the mathematical KL definition.
        self.loss_func = nn.KLDivLoss(reduction="batchmean")

        # Place the classifier (not just the parameter-free loss) on the device.
        self.to(args.device)

    def forward(self, x: torch.Tensor, y_hat: torch.Tensor):
        """Return KL(classifier(x) || classifier(y_hat)), averaged over the batch."""
        # KLDivLoss expects log-probabilities as input and probabilities as target;
        # the classifier emits raw logits, so convert both sides explicitly.
        with torch.no_grad():
            target_probs = nn.functional.softmax(self.model_ft(x), dim=1)
        pred_log_probs = nn.functional.log_softmax(self.model_ft(y_hat), dim=1)
        return self.loss_func(pred_log_probs, target_probs)

#############################
# path reg loss

class path_reg_loss(nn.Module):
    def __init__(self):
        
        super(path_reg_loss, self).__init__()

    def forward(self, fake_img, latents, mean_path_length, decay=0.01):
        noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
        grad, = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)
        path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

        path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

        path_penalty = (path_lengths - path_mean).pow(2).mean()

        return path_penalty, path_mean.detach(), path_lengths

#### net class: this class will encapsulate all of the sub-models in it

In [None]:
import math

def get_keys(d, name):
	"""Return the entries of ``d`` (or of ``d['state_dict']`` when present) whose key
	starts with ``name``, with ``name`` and the following character removed."""
	source = d['state_dict'] if 'state_dict' in d else d
	cut = len(name) + 1
	selected = {}
	for key, value in source.items():
		if key.startswith(name):
			selected[key[cut:]] = value
	return selected


class net(nn.Module):
	r"""
	Container for all sub-models: encoder, decoder (generator), discriminator and
	classifier. ``forward`` encodes an image, appends the classifier output to the
	codes and decodes the result.
	"""

	def __init__(self, args):
		r"""
		Args:
			args: configuration namespace. Reads ``output_size``, ``img_size``,
				``channel_multiplier`` and ``encoder_type``; writes ``n_styles``.
		"""
		super(net, self).__init__()
		self.args = args

		# compute number of style inputs based on the output resolution
		# (log2 is exact for powers of two, unlike math.log(x, 2))
		self.args.n_styles = int(math.log2(self.args.output_size)) * 2 - 2

		# Define architecture
		self.encoder = self.set_encoder()

		# define generator
		# NOTE(review): the style dimension is fixed at 512, while forward() appends the
		# classifier output to the codes — confirm the generator accepts the wider codes.
		self.decoder = Generator(self.args.output_size, 512, 8)
		# NOTE(review): pools to 256x256 although img_size is configurable — confirm.
		self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))

		# define discriminator
		self.discriminator = Discriminator(self.args.img_size, self.args.channel_multiplier)

		# define the classifier
		self.classifier = Classifier(self.args)

		self.latent_avg = None

		# Load weights if needed -> we are going to train from scratch
		# self.load_weights()

	def set_encoder(self):
		r"""
		Build the encoder selected by ``args.encoder_type``.

		Raises:
			ValueError: if the encoder type is not supported.
		"""
		if self.args.encoder_type == 'gradual':
			return GradualStyleEncoder(50, 'ir_se', self.args)
		raise ValueError(f"Unsupported encoder type: {self.args.encoder_type!r}")

	# def load_weights(self):
	# 	if self.opts.checkpoint_path is not None:
	# 		print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
	# 		ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
	# 		self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
	# 		self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
	# 		self.__load_latent_avg(ckpt)
	# 	else:
	# 		print('Loading encoders weights from irse50!')
	# 		encoder_ckpt = torch.load(model_paths['ir_se50'])
	# 		# if input to encoder is not an RGB image, do not load the input layer weights
	# 		if self.opts.label_nc != 0:
	# 			encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
	# 		self.encoder.load_state_dict(encoder_ckpt, strict=False)
	# 		print('Loading decoder weights from pretrained!')
	# 		ckpt = torch.load(self.opts.stylegan_weights)
	# 		self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
	# 		if self.opts.learn_in_w:
	# 			self.__load_latent_avg(ckpt, repeat=1)
	# 		else:
	# 			self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)

	def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
	            inject_latent=None, return_latents=False, alpha=None):
		r"""
		Encode ``x`` (or use it directly as codes when ``input_code`` is set), append the
		classifier output, and decode.

		Args:
			x: input images, or latent codes when ``input_code`` is True.
			resize: apply ``face_pool`` to the generated images.
			latent_mask: style indices to overwrite (with ``inject_latent``, optionally
				blended by ``alpha``, or with zeros when ``inject_latent`` is None).
			input_code: treat ``x`` as codes and skip the encoder and the classifier.
			randomize_noise: forwarded to the decoder.
			inject_latent: replacement codes used together with ``latent_mask``.
			return_latents: also return the latents reported by the decoder.
			alpha: blending factor for ``inject_latent``.

		Returns:
			``images``, or ``(images, latents)`` when ``return_latents`` is True.
		"""
		if input_code:
			codes = x
		else:
			codes = self.encoder(x)
			# normalize with respect to the center of an average face
			# if True then will need pretrained weights
			if self.args.start_from_latent_avg:
				if self.args.learn_in_w:
					codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
				else:
					codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)


		if latent_mask is not None:
			for i in latent_mask:
				if inject_latent is not None:
					if alpha is not None:
						codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
					else:
						codes[:, i] = inject_latent[:, i]
				else:
					codes[:, i] = 0

		input_is_latent = not input_code

		### get clf output and concatenate it to the encoder output
		# Concatenate the classifier output (not the image) along the feature dimension.
		# Skipped when x already is a code, since the classifier needs an image.
		# NOTE(review): assumes codes are (batch, dim) or (batch, n_styles, dim) — confirm.
		if not input_code:
			clf_out = self.classifier(x)
			if codes.dim() == 3:
				# w+ codes: repeat the classifier output for every style vector
				clf_out = clf_out.unsqueeze(1).expand(-1, codes.shape[1], -1)
			codes = torch.cat([codes, clf_out], dim=-1)

		images, result_latent = self.decoder([codes],
		                                     input_is_latent=input_is_latent,
		                                     randomize_noise=randomize_noise,
		                                     return_latents=return_latents)

		if resize:
			images = self.face_pool(images)

		if return_latents:
			return images, result_latent
		else:
			return images

	def get_encodings(self, x):
		r"""
		Get the encoding of x. Before coming here, encoder should be set to eval and no_grad should be used
		"""
		return self.encoder(x)

	# def __load_latent_avg(self, ckpt, repeat=None):
	# 	if 'latent_avg' in ckpt:
	# 		self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
	# 		if repeat is not None:
	# 			self.latent_avg = self.latent_avg.repeat(repeat, 1)
	# 	else:
	# 		self.latent_avg = None

In [None]:
# classifier class
class Classifier(nn.Module):
    """Fixed ResNet-18 classifier whose output is appended to the encoder codes.

    Args:
        args: configuration namespace (currently unused).
        num_classes: number of outputs of the classification head.
        network: currently unused; a ResNet-18 is always built.
        path_to_weights: checkpoint holding a ``'model_state_dict'`` entry.
    """

    def __init__(self, args, num_classes = 2, network = "resnet", path_to_weights = "./checkpoint/checkpoint_2.pt"):
        # nn.Module must be initialised before sub-modules can be assigned.
        super(Classifier, self).__init__()

        self.model_ft = models.resnet18(pretrained=False)
        self.num_ftrs = self.model_ft.fc.in_features
        self.model_ft.fc = nn.Linear(self.num_ftrs, num_classes)

        # Load onto the CPU first so a checkpoint saved on a GPU can be read on any host;
        # the owner moves the module to the training device afterwards.
        checkpoint = torch.load(path_to_weights, map_location="cpu")
        self.model_ft.load_state_dict(checkpoint['model_state_dict'])

        self.model_ft.eval()

    def train(self, mode: bool = True):
        """Switch the wrapper's mode but keep the wrapped classifier in eval mode.

        A parent module's ``.train()`` would otherwise undo the ``eval()`` call made
        in ``__init__``.
        """
        super(Classifier, self).train(mode)
        self.model_ft.eval()
        return self

    def forward(self, x: torch.Tensor):
        """Return the classifier logits for ``x``."""
        return self.model_ft(x)

In [None]:
# encoder 

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module

class GradualStyleBlock(Module):
    """Stack of stride-2 convolutions followed by one EqualLinear layer.

    ``int(log2(spatial))`` stride-2 3x3 convolutions (each followed by a
    LeakyReLU) are applied, the result is flattened to ``(-1, out_c)`` and
    passed through the linear layer.
    """

    def __init__(self, in_c, out_c, spatial):
        super().__init__()
        self.out_c = out_c
        self.spatial = spatial

        num_pools = int(np.log2(spatial))

        # First conv changes the channel count, the remaining ones keep it.
        layers = [
            Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
        ]
        for _ in range(num_pools - 1):
            layers.append(Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1))
            layers.append(nn.LeakyReLU())

        self.convs = nn.Sequential(*layers)
        self.linear = EqualLinear(out_c, out_c, lr_mul=1)

    def forward(self, x):
        """Convolve, flatten to ``(-1, out_c)`` and apply the linear layer."""
        features = self.convs(x)
        flat = features.view(-1, self.out_c)
        return self.linear(flat)


class GradualStyleEncoder(Module):
    """Encoder producing one style vector per generator layer.

    A residual body is run over the image; its outputs after modules 6, 20 and
    23 are tapped, merged top-down with 1x1 lateral convolutions, and fed to
    ``opts.n_styles`` GradualStyleBlock heads. The per-head outputs are stacked
    along dim 1.
    """

    def __init__(self, num_layers, mode='ir', opts=None):
        super().__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'

        blocks = get_blocks(num_layers)
        unit_module = bottleneck_IR if mode == 'ir' else bottleneck_IR_SE

        self.input_layer = Sequential(
            Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
            BatchNorm2d(64),
            PReLU(64),
        )

        # Flatten every bottleneck spec of every block into one Sequential.
        self.body = Sequential(*[
            unit_module(spec.in_channel, spec.depth, spec.stride)
            for block in blocks
            for spec in block
        ])

        self.styles = nn.ModuleList()
        self.style_count = opts.n_styles
        self.coarse_ind = 3
        self.middle_ind = 7
        for idx in range(self.style_count):
            # Heads below coarse_ind / middle_ind / the rest are built for
            # spatial sizes 16 / 32 / 64 respectively.
            if idx < self.coarse_ind:
                spatial = 16
            elif idx < self.middle_ind:
                spatial = 32
            else:
                spatial = 64
            self.styles.append(GradualStyleBlock(512, 512, spatial))

        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

    def _upsample_add(self, x, y):
        '''Resize ``x`` to the spatial size of ``y`` and add the two maps.

        Args:
          x: top feature map to be upsampled.
          y: lateral feature map; its (H, W) is the target size.
        Returns:
          the element-wise sum of the resized ``x`` and ``y``.

        Bilinear interpolation to an explicit size is used rather than a fixed
        scale factor, so the result always matches ``y`` even when a
        stride-2 convolution rounded an odd input size
        (e.g. 15x15 -> 8x8, which a x2 upsample would turn into 16x16).
        '''
        target_h, target_w = y.size()[2:]
        resized = F.interpolate(
            x, size=(target_h, target_w), mode='bilinear', align_corners=True
        )
        return resized + y

    def forward(self, x):
        """Return the stacked style vectors for the batch ``x``."""
        x = self.input_layer(x)

        # Walk the body and remember the three tapped feature maps.
        for i, layer in enumerate(self.body._modules.values()):
            x = layer(x)
            if i == 6:
                c1 = x
            elif i == 20:
                c2 = x
            elif i == 23:
                c3 = x

        latents = [self.styles[j](c3) for j in range(self.coarse_ind)]

        p2 = self._upsample_add(c3, self.latlayer1(c2))
        latents.extend(
            self.styles[j](p2) for j in range(self.coarse_ind, self.middle_ind)
        )

        p1 = self._upsample_add(p2, self.latlayer2(c1))
        latents.extend(
            self.styles[j](p1) for j in range(self.middle_ind, self.style_count)
        )

        return torch.stack(latents, dim=1)

In [None]:
# styleGan2

class EqualLinear(nn.Module):
    """Linear layer whose weight is rescaled at call time.

    The weight is stored as ``randn / lr_mul`` and multiplied by
    ``lr_mul / sqrt(in_dim)`` on every forward pass; the bias (if any) is
    multiplied by ``lr_mul``.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        bias: when False no bias parameter is created.
        bias_init: initial value of every bias entry.
        lr_mul: multiplier folded into the weight scale and the bias.
        activation: any truthy value routes the output through
            ``fused_leaky_relu``; otherwise a plain linear map is returned.
    """

    def __init__(
            self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        """Apply the scaled linear map (and the fused activation if enabled)."""
        weight = self.weight * self.scale
        # `bias=False` stores None; multiplying None by lr_mul would raise.
        bias = None if self.bias is None else self.bias * self.lr_mul

        if self.activation:
            if bias is None:
                # fused_leaky_relu is always handed a bias tensor in this file,
                # so a bias-free layer passes an all-zero one.
                bias = input.new_zeros(self.weight.shape[0])
            return fused_leaky_relu(F.linear(input, weight), bias)

        return F.linear(input, weight, bias=bias)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )

In [None]:
import contextlib
import warnings

import torch
from torch import autograd
from torch.nn import functional as F

# Master switch consulted by could_use_op(); False forces the plain torch ops.
enabled = True
# While True, the custom Conv2d backward skips the weight-gradient computation.
weight_gradients_disabled = False


@contextlib.contextmanager
def no_weight_gradients():
    """Temporarily set ``weight_gradients_disabled`` to True.

    The previous value is restored when the ``with`` block exits, including
    when the block raises, so a failed step cannot leave weight gradients
    switched off for the rest of the process.
    """
    global weight_gradients_disabled

    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        weight_gradients_disabled = old


def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """2-D convolution routed through the gradfix op when it is usable.

    When ``could_use_op(input)`` is false the call is forwarded unchanged to
    ``F.conv2d``.
    """
    if not could_use_op(input):
        return F.conv2d(
            input=input,
            weight=weight,
            bias=bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )

    op = conv2d_gradfix(
        transpose=False,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=0,
        dilation=dilation,
        groups=groups,
    )
    return op.apply(input, weight, bias)


def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    """Transposed 2-D convolution routed through the gradfix op when usable.

    When ``could_use_op(input)`` is false the call is forwarded unchanged to
    ``F.conv_transpose2d``.
    """
    if not could_use_op(input):
        return F.conv_transpose2d(
            input=input,
            weight=weight,
            bias=bias,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
        )

    op = conv2d_gradfix(
        transpose=True,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )
    return op.apply(input, weight, bias)


def could_use_op(input):
    """Tell whether the custom gradfix convolution can handle ``input``.

    Returns False when the module switch or cuDNN is off, or when ``input`` is
    not on a CUDA device. Returns True on PyTorch 1.7.x / 1.8.x. On any other
    version a warning is emitted and False is returned.
    """
    switched_on = enabled and torch.backends.cudnn.enabled
    if not switched_on or input.device.type != "cuda":
        return False

    if torch.__version__.startswith(("1.7.", "1.8.")):
        return True

    warnings.warn(
        f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
    )
    return False


def ensure_tuple(xs, ndim):
    """Return ``xs`` as a tuple; a non-sequence value is repeated ``ndim`` times."""
    if isinstance(xs, (tuple, list)):
        return tuple(xs)

    return (xs,) * ndim


conv2d_gradfix_cache = dict()


def conv2d_gradfix(
    transpose, weight_shape, stride, padding, output_padding, dilation, groups
):
    """Build (or fetch from the cache) an autograd Function for one conv setup.

    Args:
        transpose: True for a transposed convolution, False for a regular one.
        weight_shape: shape of the weight tensor the op will be applied to.
        stride, padding, output_padding, dilation: an int or a 2-element
            tuple/list; normalised to tuples with ``ensure_tuple``.
        groups: number of convolution groups.

    Returns:
        The ``Conv2d`` autograd.Function class specialised for these settings.
        Call ``.apply(input, weight, bias)`` on it.
    """
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = ensure_tuple(stride, ndim)
    padding = ensure_tuple(padding, ndim)
    output_padding = ensure_tuple(output_padding, ndim)
    dilation = ensure_tuple(dilation, ndim)

    # One Function class per distinct configuration; reuse it on later calls.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in conv2d_gradfix_cache:
        return conv2d_gradfix_cache[key]

    common_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups
    )

    def calc_output_padding(input_shape, output_shape):
        # output_padding to hand to the opposite-direction op used in backward.
        # When this op is itself transposed, the opposite op is a regular
        # convolution and gets [0, 0].
        if transpose:
            return [0, 0]

        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    class Conv2d(autograd.Function):
        # Forward convolution with a hand-written backward.
        @staticmethod
        def forward(ctx, input, weight, bias):
            if not transpose:
                out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )

            ctx.save_for_backward(input, weight)

            return out

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input, grad_weight, grad_bias = None, None, None

            if ctx.needs_input_grad[0]:
                # Input gradient: apply the opposite-direction op to grad_output.
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, weight, None)

            # Skipped while the module flag is set (see no_weight_gradients()).
            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)

            if ctx.needs_input_grad[2]:
                # Sum over every axis except the channel axis.
                grad_bias = grad_output.sum((0, 2, 3))

            return grad_input, grad_weight, grad_bias

    class Conv2dGradWeight(autograd.Function):
        # Weight gradient as its own Function so that it can be differentiated again.
        @staticmethod
        def forward(ctx, grad_output, input):
            # Look up the cuDNN weight-gradient operator by name.
            op = torch._C._jit_get_operation(
                "aten::cudnn_convolution_backward_weight"
                if not transpose
                else "aten::cudnn_convolution_transpose_backward_weight"
            )
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            grad_weight = op(
                weight_shape,
                grad_output,
                input,
                padding,
                stride,
                dilation,
                groups,
                *flags,
            )
            ctx.save_for_backward(grad_output, input)

            return grad_weight

        @staticmethod
        def backward(ctx, grad_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_grad_output, grad_grad_input = None, None

            if ctx.needs_input_grad[0]:
                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape
                )
                grad_grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, grad_grad_weight, None)

            return grad_grad_output, grad_grad_input

    conv2d_gradfix_cache[key] = Conv2d

    return Conv2d

In [None]:
# generator and discriminator layers
class PixelNorm(nn.Module):
    """Divide each position by the root-mean-square of its values along dim 1."""

    def forward(self, input):
        mean_square = input.pow(2).mean(dim=1, keepdim=True)
        return input * torch.rsqrt(mean_square + 1e-8)

class ConstantInput(nn.Module):
    """Learned constant tensor, repeated once per batch element."""

    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        # Only the leading dimension of `input` is used, as the repeat count.
        return self.input.repeat(input.shape[0], 1, 1, 1)
        
def make_kernel(k):
    """Build a float32 kernel from ``k``, normalised to sum to 1.

    A 1-D ``k`` is expanded to 2-D by multiplying it with itself as a row and
    as a column; any other rank is used as given.
    """
    kernel = torch.tensor(k, dtype=torch.float32)

    if kernel.ndim == 1:
        kernel = kernel.unsqueeze(0) * kernel.unsqueeze(1)

    return kernel / kernel.sum()

class EqualConv2d(nn.Module):
    """2-D convolution whose weight is rescaled at call time.

    The weight is stored as unit-variance noise and multiplied by
    ``1 / sqrt(in_channel * kernel_size ** 2)`` on every forward pass.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: side length of the square kernel.
        stride: convolution stride.
        padding: convolution padding.
        bias: when False no bias parameter is created.
    """

    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        """Convolve ``input`` with the rescaled weight."""
        # In this file `conv2d_gradfix` is a function, not a module, so
        # `conv2d_gradfix.conv2d` raises AttributeError. The wrapper that
        # dispatches to the gradfix op is the module-level `conv2d`.
        out = conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )

class ConvLayer(nn.Sequential):
    """Sequential of an optional Blur, an EqualConv2d and an optional activation.

    With ``downsample=True`` a Blur is placed first and the convolution uses
    stride 2 with no padding; otherwise the convolution uses stride 1 and
    ``kernel_size // 2`` padding. With ``activate=True`` the convolution's own
    bias is disabled and a FusedLeakyReLU is appended.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        stack = []

        if not downsample:
            stride = 1
            self.padding = kernel_size // 2

        else:
            factor = 2
            total_pad = (len(blur_kernel) - factor) + (kernel_size - 1)
            blur_pad = ((total_pad + 1) // 2, total_pad // 2)
            stack.append(Blur(blur_kernel, pad=blur_pad))

            stride = 2
            self.padding = 0

        conv = EqualConv2d(
            in_channel,
            out_channel,
            kernel_size,
            padding=self.padding,
            stride=stride,
            bias=bias and not activate,
        )
        stack.append(conv)

        if activate:
            stack.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*stack)


class Blur(nn.Module):
    """Filter the input with a fixed kernel via ``upfirdn2d``.

    The kernel is built with ``make_kernel`` and stored as a buffer. When
    ``upsample_factor`` is greater than 1 it is multiplied by
    ``upsample_factor ** 2``.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        fir = make_kernel(kernel)
        if upsample_factor > 1:
            fir = fir * (upsample_factor ** 2)

        self.register_buffer('kernel', fir)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)

class ModulatedConv2d(nn.Module):
    """Convolution whose weight is scaled per sample by a style vector.

    The style is mapped to one factor per input channel by an EqualLinear
    layer, multiplied into the weight, and (optionally) the result is
    normalised per output channel. The per-sample weights are then applied
    in a single grouped convolution with ``groups=batch``.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: side length of the square kernel.
        style_dim: size of the style vector fed to the modulation layer.
        demodulate: normalise the modulated weight per output channel.
        upsample: use a stride-2 transposed convolution followed by a blur.
        downsample: blur first, then use a stride-2 convolution.
        blur_kernel: 1-D kernel handed to Blur.
    """

    def __init__(
            self,
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            demodulate=True,
            upsample=False,
            downsample=False,
            blur_kernel=[1, 3, 3, 1],
    ):
        super().__init__()

        # NOTE(review): stored but not read in this class; forward() uses the literal 1e-8.
        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        # If both flags are set this assignment replaces the upsample blur.
        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        # Leading dim of 1 broadcasts against the per-sample style below.
        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
            f'upsample={self.upsample}, downsample={self.downsample})'
        )

    def forward(self, input, style):
        """Apply the style-modulated convolution to a 4-D ``input``."""
        batch, in_channel, height, width = input.shape

        # One factor per (sample, input channel), broadcast over the weight.
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            # Normalise each (sample, output channel) filter by its L2 norm.
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        # Fold the batch into the output-channel axis for the grouped conv.
        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            # Fold the batch into the channel axis so one grouped call handles all samples.
            input = input.view(1, batch * in_channel, height, width)
            # conv_transpose2d takes the weight with input channels first.
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out

class NoiseInjection(nn.Module):
    """Add noise to an image, scaled by a single learned weight (initially 0)."""

    def __init__(self):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        """Return ``image + weight * noise``.

        When ``noise`` is omitted, standard-normal noise with one channel and
        the image's batch and spatial sizes is drawn.
        """
        if noise is not None:
            return image + self.weight * noise

        n, _, h, w = image.shape
        sampled = image.new_empty(n, 1, h, w).normal_()
        return image + self.weight * sampled

class StyledConv(nn.Module):
    """ModulatedConv2d followed by noise injection and a FusedLeakyReLU."""

    def __init__(
            self,
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=False,
            blur_kernel=[1, 3, 3, 1],
            demodulate=True,
    ):
        super().__init__()

        conv_options = dict(
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )
        self.conv = ModulatedConv2d(
            in_channel, out_channel, kernel_size, style_dim, **conv_options
        )

        self.noise = NoiseInjection()
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        """Convolve with ``style``, add noise, then apply the activation."""
        modulated = self.conv(input, style)
        noisy = self.noise(modulated, noise=noise)
        return self.activate(noisy)

class Upsample(nn.Module):
    """Upsample by ``factor`` with ``upfirdn2d`` and a fixed kernel.

    The kernel from ``make_kernel`` is multiplied by ``factor ** 2`` and stored
    as a buffer; the padding pair is derived from the kernel size and factor.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor

        scaled = make_kernel(kernel) * (factor ** 2)
        self.register_buffer('kernel', scaled)

        excess = scaled.shape[0] - factor
        self.pad = ((excess + 1) // 2 + factor - 1, excess // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

class ToRGB(nn.Module):
    """Project features to 3 channels with a 1x1 modulated conv plus a bias.

    If ``skip`` is given to ``forward`` it is upsampled and added to the
    result; the upsampler only exists when the layer was built with
    ``upsample=True``.
    """

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        rgb = self.conv(input, style) + self.bias

        if skip is None:
            return rgb

        return rgb + self.upsample(skip)

class ResBlock(nn.Module):
    """Residual block: two 3x3 ConvLayers plus a 1x1 skip, both downsampling.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels of the output feature map.
        blur_kernel: kernel used by the Blur inside the downsampling layers.
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        # `blur_kernel` was accepted but never forwarded, so a caller-supplied
        # kernel was silently replaced by ConvLayer's default. Only the
        # downsampling layers contain a Blur, so only they receive it.
        self.conv2 = ConvLayer(
            in_channel, out_channel, 3, downsample=True, blur_kernel=blur_kernel
        )

        self.skip = ConvLayer(
            in_channel,
            out_channel,
            1,
            downsample=True,
            blur_kernel=blur_kernel,
            activate=False,
            bias=False,
        )

    def forward(self, input):
        """Return ``(conv2(conv1(input)) + skip(input)) / sqrt(2)``."""
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        # Dividing by sqrt(2) rescales the sum of the two branches.
        out = (out + skip) / math.sqrt(2)

        return out

class Generator(nn.Module):
    """Style-based image generator.

    A learned constant is passed through a chain of StyledConv layers; each
    resolution contributes to the output image through a ToRGB layer whose
    result is accumulated via the ``skip`` path.

    Args:
        size: output resolution; must be a key of ``self.channels``.
        style_dim: size of the style / latent vectors.
        n_mlp: number of EqualLinear layers in the mapping network.
        channel_multiplier: scales the channel counts for resolutions >= 64.
        blur_kernel: kernel handed to the StyledConv layers.
        lr_mlp: ``lr_mul`` of the mapping network's EqualLinear layers.
    """

    def __init__(
            self,
            size,
            style_dim,
            n_mlp,
            channel_multiplier=2,
            blur_kernel=[1, 3, 3, 1],
            lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        # Mapping network: PixelNorm followed by n_mlp activated linear layers.
        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
                )
            )

        self.style = nn.Sequential(*layers)

        # Channel count per resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        # One conv at 4x4 plus two per resolution above it.
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        # Bare Module used only as a container for the fixed-noise buffers.
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # Fixed per-layer noise, used when forward() gets randomize_noise=False.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))

        # For each resolution from 8 up to `size`: an upsampling conv, a
        # same-resolution conv, and a ToRGB.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        # Number of per-layer latent vectors forward() indexes into.
        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        """Return a fresh list of per-layer noise tensors on the model's device."""
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        """Map ``n_latent`` random vectors through the mapping network and average them."""
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        """Run ``input`` through the mapping network."""
        return self.style(input)

    def forward(
            self,
            styles,
            return_latents=False,
            return_features=False,
            inject_index=None,
            truncation=1,
            truncation_latent=None,
            input_is_latent=False,
            noise=None,
            randomize_noise=True,
    ):
        """Generate an image from one or two style codes.

        Args:
            styles: list of one or two tensors.
            return_latents: return the per-layer latent as the second value.
            return_features: return the last feature map as the second value
                (only consulted when ``return_latents`` is False).
            inject_index: with two styles, the layer index where the second
                one takes over; drawn at random when None.
            truncation: values below 1 pull each style towards
                ``truncation_latent``, which must then be provided.
            truncation_latent: target of the truncation.
            input_is_latent: skip the mapping network when True.
            noise: explicit list of per-layer noise tensors.
            randomize_noise: when ``noise`` is None, pass None to every layer
                if True, otherwise use the buffers stored at construction.

        Returns:
            ``(image, latent)``, ``(image, features)`` or ``(image, None)``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                # None makes each NoiseInjection draw its own noise.
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                # A single code per sample is repeated for every layer.
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                # Already one code per layer; use as is.
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            # Style mixing: first style up to inject_index, second one after.
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        # `i` is the latent index of the first conv of the current resolution.
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent
        elif return_features:
            return image, out
        else:
            return image, None

In [None]:
# define discriminator 
class Discriminator(nn.Module):
    """Convolutional discriminator with a minibatch standard-deviation feature.

    A 1x1 ConvLayer is followed by one ResBlock per halving of the resolution.
    A standard-deviation statistic computed across groups of samples is
    appended as an extra channel before the final conv and the two linear
    layers that produce one score per sample.
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        # Channel count per resolution.
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        in_channel = channels[size]
        stages = [ConvLayer(3, in_channel, 1)]

        log_size = int(math.log(size, 2))

        for level in range(log_size, 2, -1):
            out_channel = channels[2 ** (level - 1)]
            stages.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*stages)

        self.stddev_group = 4
        self.stddev_feat = 1

        # +1 input channel for the appended standard-deviation map.
        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        """Return the output of the final linear layer for each sample."""
        feat = self.convs(input)

        batch, channel, height, width = feat.shape
        group = min(batch, self.stddev_group)

        # Split the batch into `group` slices and take the std across them.
        grouped = feat.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        std = torch.sqrt(grouped.var(0, unbiased=False) + 1e-8)
        std = std.mean([2, 3, 4], keepdims=True).squeeze(2)
        std = std.repeat(group, 1, height, width)

        feat = torch.cat([feat, std], 1)
        feat = self.final_conv(feat)

        return self.final_linear(feat.view(batch, -1))

#### Utils

In [None]:
import datetime

class WBLogger:
    """Thin wrapper around the Weights & Biases logging calls used in this notebook."""

    def __init__(self, args):
        """Start a wandb run named ``args.exp_name`` with ``args.wandb_config`` as config."""
        wandb_run_name = args.exp_name
        # `wandb_config` is a plain dict in this notebook's args, and vars() on
        # a dict raises TypeError. A Namespace-like object is still accepted
        # and converted as before.
        config = args.wandb_config
        if not isinstance(config, dict):
            config = vars(config)
        wandb.init(project="style_space_run", config=config, name=wandb_run_name, entity="stylespace")

    @staticmethod
    def log_best_model():
        """Record the current time in the run summary as the best-model save time."""
        wandb.run.summary["best-model-save-time"] = datetime.datetime.now()

    @staticmethod
    def log_dataset_wandb(dataset, dataset_name, device, n_images=16):
        """Log up to ``n_images`` randomly chosen samples of ``dataset`` as images.

        Args:
            dataset: indexable collection whose items expose a tensor under ``"inputs"``.
            dataset_name: prefix of the wandb key.
            device: kept for interface compatibility; tensors are always
                detached and moved to the CPU, which is valid on any device.
            n_images: number of samples to draw without replacement; capped at
                the dataset size so small datasets do not raise.
        """
        n_samples = min(n_images, len(dataset))
        idxs = np.random.choice(a=range(len(dataset)), size=n_samples, replace=False)
        # Transpose from channel-first to channel-last before handing to wandb.Image.
        data = [
            wandb.Image(dataset[idx]["inputs"].detach().cpu().numpy().transpose(1, 2, 0))
            for idx in idxs
        ]

        wandb.log({f"{dataset_name} Data Samples": data})

    @staticmethod
    def log(prefix, metrics_dict, global_step):
        """Log every metric as ``{prefix}_{key}`` together with ``global_step``.

        TODO: update after the training code is written (the exact content of
        ``metrics_dict`` is not settled yet).
        """
        log_dict = {f'{prefix}_{key}': value for key, value in metrics_dict.items()}
        log_dict["global_step"] = global_step
        wandb.log(log_dict)

#### ops for stylegan2

In [None]:
from torch.autograd import Function
from torch.utils.cpp_extension import load

# Directory holding the extension sources. `__file__` is not defined when this
# code runs as a notebook cell, so fall back to the current working directory.
try:
    module_path = os.path.dirname(__file__)
except NameError:
    module_path = os.getcwd()
# TODO: need these fused_bias files
# Build and load the `fused` extension from its C++/CUDA sources.
fused = load(
    'fused',
    sources=[
        os.path.join(module_path, 'fused_bias_act.cpp'),
        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
    ],
)

class FusedLeakyReLUFunctionBackward(Function):
    """Backward pass of the fused bias + leaky-ReLU op, itself differentiable."""

    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope, ctx.scale = negative_slope, scale

        # Zero-element tensor passed in the bias slot.
        no_bias = grad_output.new_empty(0)
        grad_input = fused.fused_bias_act(
            grad_output, no_bias, out, 3, 1, negative_slope, scale
        )

        # Bias gradient: reduce over every axis except axis 1.
        reduce_dims = [0] + list(range(2, grad_input.ndim))
        grad_bias = grad_input.sum(reduce_dims).detach()

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        (saved_out,) = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input, gradgrad_bias, saved_out, 3, 1, ctx.negative_slope, ctx.scale
        )

        # No gradients for `out`, `negative_slope` and `scale`.
        return gradgrad_out, None, None, None


class FusedLeakyReLUFunction(Function):
    """Autograd wrapper around the ``fused.fused_bias_act`` extension op."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        # Zero-element tensor passed in the reference-output slot.
        placeholder = input.new_empty(0)
        result = fused.fused_bias_act(
            input, bias, placeholder, 3, 0, negative_slope, scale
        )

        ctx.save_for_backward(result)
        ctx.negative_slope, ctx.scale = negative_slope, scale

        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors

        grads = FusedLeakyReLUFunctionBackward.apply(
            grad_output, result, ctx.negative_slope, ctx.scale
        )
        grad_input, grad_bias = grads

        # No gradients for `negative_slope` and `scale`.
        return grad_input, grad_bias, None, None

class FusedLeakyReLU(nn.Module):
    """Module form of ``fused_leaky_relu`` with an optional learned bias.

    Args:
        channel: length of the bias vector.
        negative_slope: slope passed to the fused op.
        scale: output scale passed to the fused op.
        bias: when False no bias parameter is created. ConvLayer constructs
            this module with ``bias=...``, which the previous signature
            rejected with a TypeError; the keyword is added last so existing
            positional calls keep their meaning.
    """

    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, bias=True):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))
        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        """Apply the fused bias + leaky-ReLU to ``input``."""
        bias = self.bias
        if bias is None:
            # fused_leaky_relu is always handed a bias tensor in this file, so
            # a bias-free layer passes zeros, one entry per index of dim 1
            # (presumably the channel axis the extension expects; the backward
            # pass reduces the bias gradient over every axis but 1).
            bias = input.new_zeros(input.shape[1])

        return fused_leaky_relu(input, bias, self.negative_slope, self.scale)

def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Functional entry point: run ``FusedLeakyReLUFunction`` on ``input`` and ``bias``."""
    activated = FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
    return activated

#######################################
import os

import torch
from torch.autograd import Function
from torch.utils.cpp_extension import load

# Directory holding the extension sources. `__file__` is not defined when this
# code runs as a notebook cell, so fall back to the current working directory.
try:
    module_path = os.path.dirname(__file__)
except NameError:
    module_path = os.getcwd()
# Build and load the `upfirdn2d` extension from its C++/CUDA sources.
upfirdn2d_op = load(
    'upfirdn2d',
    sources=[
        os.path.join(module_path, 'upfirdn2d.cpp'),
        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
    ],
)


class UpFirDn2dBackward(Function):
    """Gradient of ``UpFirDn2d`` written as its own autograd function.

    Having the gradient as a ``Function`` gives it a ``backward`` of its own,
    which is what makes second-order gradients through ``UpFirDn2d`` possible.
    """

    @staticmethod
    def forward(
            ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):
        """Compute the gradient w.r.t. the input of ``UpFirDn2d``.

        ``grad_kernel`` and ``g_pad`` are used for this pass; ``kernel``, ``up``,
        ``down`` and ``pad`` are stored on ``ctx`` for ``backward``. The result
        is shaped like ``in_size`` (four entries).
        """
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # Keep the forward-direction settings around; backward() replays them.
        ctx.up_x, ctx.up_y = up
        ctx.down_x, ctx.down_y = down
        ctx.pad_x0, ctx.pad_x1, ctx.pad_y0, ctx.pad_y1 = pad
        ctx.in_size = in_size
        ctx.out_size = out_size
        ctx.save_for_backward(kernel)

        # Collapse the leading dims so the op sees (-1, out_h, out_w, 1).
        flat_grad = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        # Note the swapped roles: the down factors go in the op's up slots and
        # the up factors in its down slots.
        grad_input = upfirdn2d_op.upfirdn2d(
            flat_grad,
            grad_kernel,
            ctx.down_x,
            ctx.down_y,
            ctx.up_x,
            ctx.up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        return grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

    @staticmethod
    def backward(ctx, gradgrad_input):
        """Second-order pass: re-run the op with the settings stored by ``forward``.

        Only ``grad_output`` (the first forward argument) gets a gradient; the
        other eight arguments get ``None``.
        """
        (kernel,) = ctx.saved_tensors
        in_size, out_size = ctx.in_size, ctx.out_size

        flat_gradgrad = gradgrad_input.reshape(-1, in_size[2], in_size[3], 1)
        gradgrad_out = upfirdn2d_op.upfirdn2d(
            flat_gradgrad,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        ).view(in_size[0], in_size[1], out_size[0], out_size[1])

        return (gradgrad_out,) + (None,) * 8


class UpFirDn2d(Function):
    """Autograd wrapper around the compiled ``upfirdn2d_op.upfirdn2d`` op."""

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        """Run the op on a 4-D ``input`` with a 2-D ``kernel``.

        Args:
            input: tensor unpacked as ``(batch, channel, in_h, in_w)``.
            kernel: tensor unpacked as ``(kernel_h, kernel_w)``.
            up: ``(up_x, up_y)`` pair.
            down: ``(down_x, down_y)`` pair.
            pad: ``(pad_x0, pad_x1, pad_y0, pad_y1)``.

        Returns:
            The op result viewed as ``(-1, channel, out_h, out_w)``.
        """
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad
        kernel_h, kernel_w = kernel.shape
        _, channel, in_h, in_w = input.shape

        # Output extent after upsampling, padding, filtering and downsampling.
        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

        # Padding used by the gradient pass (see ``backward``).
        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        # Everything backward() hands on to UpFirDn2dBackward.
        ctx.in_size = input.shape
        ctx.out_size = (out_h, out_w)
        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
        # The kernel is saved together with a copy flipped along both axes;
        # the flipped copy is the one used for the gradient.
        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        flat_input = input.reshape(-1, in_h, in_w, 1)
        result = upfirdn2d_op.upfirdn2d(
            flat_input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        return result.view(-1, channel, out_h, out_w)

    @staticmethod
    def backward(ctx, grad_output):
        """Delegate to ``UpFirDn2dBackward``; only ``input`` receives a gradient."""
        kernel, flipped_kernel = ctx.saved_tensors
        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            flipped_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )
        return (grad_input,) + (None,) * 4


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Convenience wrapper around ``UpFirDn2d`` with the same settings on both axes.

    ``up`` and ``down`` are each duplicated into an (x, y) pair, and the
    two-element ``pad`` is reused for both the x and the y direction.
    """
    pad_before, pad_after = pad[0], pad[1]
    return UpFirDn2d.apply(
        input,
        kernel,
        (up, up),
        (down, down),
        (pad_before, pad_after, pad_before, pad_after),
    )


def upfirdn2d_native(
        input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Pure-PyTorch upsample -> pad/crop -> FIR filter -> downsample.

    Args:
        input: tensor unpacked as ``(N, in_h, in_w, minor)``.
        kernel: 2-D filter unpacked as ``(kernel_h, kernel_w)``.
        up_x, up_y: integer zero-insertion upsampling factors.
        down_x, down_y: integer subsampling strides applied to the result.
        pad_x0, pad_x1, pad_y0, pad_y1: padding before/after along width and
            height; negative values crop instead of pad.

    Returns:
        Tensor laid out as ``(N, out_h, out_w, minor)``.
    """
    # Bind torch.nn.functional explicitly. This cell never imports ``F``, and the
    # only ``F`` visible in the file header is torchvision.transforms.functional,
    # which has no ``conv2d`` and a different ``pad`` signature.
    from torch.nn import functional as F

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Zero-insertion upsampling: give every pixel a trailing singleton axis and
    # pad that axis with (up - 1) zeros, then merge it back into H / W.
    # ``reshape`` (rather than ``view``) also accepts non-contiguous input such
    # as a permuted NCHW tensor.
    out = input.reshape(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding is added with zeros; negative padding is a crop.
    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
          :,
          max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
          max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
          :,
          ]

    padded_h = in_h * up_y + pad_y0 + pad_y1
    padded_w = in_w * up_x + pad_x0 + pad_x1

    # Fold ``minor`` into the batch axis so one single-channel conv filters
    # every plane independently.
    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, padded_h, padded_w])

    # conv2d computes a cross-correlation; flipping the kernel on both axes
    # turns it into a true convolution.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        padded_h - kernel_h + 1,
        padded_w - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)

    # Downsample by plain striding.
    return out[:, ::down_y, ::down_x, :]

In [None]:
from collections import namedtuple
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module

def get_blocks(num_layers):
	"""Return the four-stage block layout for a 50-, 100- or 152-layer backbone.

	Each stage is the list produced by ``get_block``. The (in_channel, depth)
	pairs are the same for every depth; only the number of units per stage
	changes.

	Raises:
		ValueError: if ``num_layers`` is not 50, 100 or 152.
	"""
	# (in_channel, depth) of each of the four stages.
	stage_channels = ((64, 64), (64, 128), (128, 256), (256, 512))
	# Units per stage for every supported network depth.
	units_by_depth = (
		(50, (3, 4, 14, 3)),
		(100, (3, 13, 30, 3)),
		(152, (3, 8, 36, 3)),
	)

	for supported_depth, stage_units in units_by_depth:
		if num_layers == supported_depth:
			return [
				get_block(in_channel=c_in, depth=c_out, num_units=count)
				for (c_in, c_out), count in zip(stage_channels, stage_units)
			]

	raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))

class Bottleneck(namedtuple('Block', 'in_channel depth stride')):
	"""Immutable ``(in_channel, depth, stride)`` record describing one ResNet unit."""

def get_block(in_channel, depth, num_units, stride=2):
	"""Describe one stage as a list of ``Bottleneck`` records.

	The first unit maps ``in_channel`` to ``depth`` using ``stride``; the
	remaining ``num_units - 1`` units keep ``depth`` channels with stride 1.
	"""
	units = [Bottleneck(in_channel, depth, stride)]
	units.extend(Bottleneck(depth, depth, 1) for _ in range(num_units - 1))
	return units

class bottleneck_IR(Module):
	"""Residual unit: BN -> 3x3 conv -> PReLU -> strided 3x3 conv -> BN, plus a shortcut."""

	def __init__(self, in_channel, depth, stride):
		"""Build the shortcut first and the residual path second.

		When ``in_channel`` differs from ``depth`` the shortcut is a strided
		1x1 convolution followed by batch norm; otherwise it is a kernel-size-1
		max pool that only applies the stride.
		"""
		super().__init__()

		if in_channel != depth:
			self.shortcut_layer = Sequential(
				Conv2d(in_channel, depth, (1, 1), stride, bias=False),
				BatchNorm2d(depth),
			)
		else:
			self.shortcut_layer = MaxPool2d(1, stride)

		main_path = [
			BatchNorm2d(in_channel),
			Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
			PReLU(depth),
			Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
			BatchNorm2d(depth),
		]
		self.res_layer = Sequential(*main_path)

	def forward(self, x):
		"""Return the residual path output added to the shortcut output."""
		identity = self.shortcut_layer(x)
		return self.res_layer(x) + identity

class SEModule(Module):
	"""Channel gate: global average pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid."""

	def __init__(self, channels, reduction):
		"""Create the gate with a hidden width of ``channels // reduction``."""
		super().__init__()
		squeezed = channels // reduction
		self.avg_pool = AdaptiveAvgPool2d(1)
		self.fc1 = Conv2d(channels, squeezed, kernel_size=1, padding=0, bias=False)
		self.relu = ReLU(inplace=True)
		self.fc2 = Conv2d(squeezed, channels, kernel_size=1, padding=0, bias=False)
		self.sigmoid = Sigmoid()

	def forward(self, x):
		"""Scale ``x`` by the per-channel gate computed from its pooled features."""
		gate = self.avg_pool(x)
		for stage in (self.fc1, self.relu, self.fc2, self.sigmoid):
			gate = stage(gate)
		return x * gate


class bottleneck_IR_SE(Module):
	"""Residual unit like ``bottleneck_IR`` with an ``SEModule`` ending the residual path."""

	def __init__(self, in_channel, depth, stride):
		"""Build the shortcut first and the residual path second.

		When ``in_channel`` differs from ``depth`` the shortcut is a strided
		1x1 convolution followed by batch norm; otherwise it is a kernel-size-1
		max pool that only applies the stride. The residual path ends in
		``SEModule(depth, 16)``.
		"""
		super().__init__()

		if in_channel != depth:
			self.shortcut_layer = Sequential(
				Conv2d(in_channel, depth, (1, 1), stride, bias=False),
				BatchNorm2d(depth),
			)
		else:
			self.shortcut_layer = MaxPool2d(1, stride)

		main_path = [
			BatchNorm2d(in_channel),
			Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
			PReLU(depth),
			Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
			BatchNorm2d(depth),
			SEModule(depth, 16),
		]
		self.res_layer = Sequential(*main_path)

	def forward(self, x):
		"""Return the residual path output added to the shortcut output."""
		identity = self.shortcut_layer(x)
		return self.res_layer(x) + identity

class Flatten(Module):
	"""Collapse every dimension after the first into a single one."""

	def forward(self, input):
		"""Return ``input`` viewed as ``(input.size(0), -1)``."""
		batch = input.shape[0]
		return input.view(batch, -1)


def l2_norm(input, axis=1):
	"""Divide ``input`` by its L2 norm taken along ``axis`` (kept for broadcasting)."""
	magnitude = torch.norm(input, p=2, dim=axis, keepdim=True)
	return input / magnitude