<a href="https://colab.research.google.com/github/VedantDere0104/GANs/blob/main/Self_Attention_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
####

In [16]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np


from torch.optim.optimizer import Optimizer, required
from torch import Tensor
from torch.nn import Parameter

In [17]:
def l2normalize(v, eps=1e-12):
  """Scale ``v`` by the reciprocal of its norm.

  Args:
    v: object exposing ``.norm()`` and supporting division (e.g. a tensor).
    eps: small constant added to the norm so that an all-zero ``v`` does
      not cause a division by zero.

  Returns:
    ``v`` divided by ``v.norm() + eps``.
  """
  magnitude = v.norm()
  denominator = magnitude + eps
  return v / denominator

In [18]:
class SpectralNorm(nn.Module):
  """Wrap a module so that one of its weights is spectrally normalised.

  On construction the wrapped module's parameter ``name`` is replaced by
  three parameters: ``<name>_bar`` (the raw, trainable weight) and
  ``<name>_u`` / ``<name>_v`` (non-trainable vectors used by the power
  iteration). On every forward pass ``<name>`` is re-set on the wrapped
  module to ``<name>_bar / sigma``, where ``sigma`` is the power-iteration
  estimate of the largest singular value of the weight reshaped to
  ``(weight.shape[0], -1)``.

  Args:
    module: module that owns the parameter to normalise.
    name: attribute name of that parameter on ``module``.
    power_iteration: number of power-iteration steps run per forward pass.
  """

  # Same epsilon the vectors were previously normalised with.
  _EPS = 1e-12

  def __init__(self , module , name = 'weight' , power_iteration = 1):
    super(SpectralNorm , self).__init__()
    self.module = module
    self.name = name
    self.power_iteration = power_iteration
    if not self._made_params():
      self._make_params()

  def _update_u_v(self):
    """Run the power iteration and install the normalised weight."""
    u = getattr(self.module , self.name + '_u')
    v = getattr(self.module , self.name + '_v')
    w = getattr(self.module , self.name + '_bar')

    height = w.data.shape[0]
    w_mat = w.view(height , -1)

    # The iteration works on ``.data`` so it is not recorded by autograd.
    # NOTE: the attribute is ``power_iteration`` (singular), as set in
    # ``__init__``; reading a differently spelled name raised AttributeError.
    for _ in range(self.power_iteration):
      v.data = F.normalize(torch.mv(torch.t(w_mat.data) , u.data) , dim=0 , eps=self._EPS)
      u.data = F.normalize(torch.mv(w_mat.data , v.data) , dim=0 , eps=self._EPS)

    # sigma is computed on the parameter itself so gradients reach ``w``.
    sigma = u.dot(w_mat.mv(v))
    setattr(self.module , self.name , w / sigma.expand_as(w))

  def _made_params(self):
    """Return True if the wrapped module already carries u, v and bar."""
    return all(
      hasattr(self.module , self.name + suffix)
      for suffix in ('_u' , '_v' , '_bar')
    )

  def _make_params(self):
    """Replace the original parameter by its ``_u``, ``_v``, ``_bar`` trio."""
    w = getattr(self.module , self.name)

    height = w.data.shape[0]
    width = w.view(height , -1).data.shape[1]

    u = Parameter(w.data.new(height).normal_(0 , 1) , requires_grad=False)
    v = Parameter(w.data.new(width).normal_(0 , 1) , requires_grad=False)
    u.data = F.normalize(u.data , dim=0 , eps=self._EPS)
    v.data = F.normalize(v.data , dim=0 , eps=self._EPS)
    w_bar = Parameter(w.data)

    # Drop the original parameter so that ``_update_u_v`` can later set a
    # plain (non-Parameter) tensor under the same attribute name.
    del self.module._parameters[self.name]

    self.module.register_parameter(self.name + "_u" , u)
    self.module.register_parameter(self.name + "_v" , v)
    self.module.register_parameter(self.name + "_bar" , w_bar)

  def forward(self , *args , **kwargs):
    """Refresh the normalised weight, then run the wrapped module."""
    self._update_u_v()
    return self.module.forward(*args , **kwargs)

In [19]:
class self_attention(nn.Module):
  """Self-attention over the spatial positions of a 4-D feature map.

  Query and key are 1x1 convolutions to ``in_dim // 8`` channels, value is
  a 1x1 convolution to ``in_dim`` channels. The attended values are scaled
  by the learnable scalar ``gamma`` (initialised to zero) and added to the
  input.

  Args:
    in_dim: number of channels of the input feature map.
    activation: stored on the instance; not used by this block.
  """

  def __init__(self , in_dim , activation):
    super(self_attention , self).__init__()
    self.in_dim = in_dim
    self.activation = activation

    self.query_conv = nn.Conv2d(in_channels=in_dim , out_channels=in_dim // 8 , kernel_size=1)
    self.key_conv = nn.Conv2d(in_channels=in_dim , out_channels=in_dim // 8 , kernel_size=1)
    self.value_conv = nn.Conv2d(in_channels=in_dim , out_channels=in_dim , kernel_size=1)
    self.gamma = nn.Parameter(torch.zeros(1))
    self.softmax = nn.Softmax(dim=-1)

  def forward(self , x):
    '''
    x: input_feature_map (B , C , W , H)
    out : gamma * self_attention_value + input_feature, same shape as x
    attention: B X N X N (N is Width*Height), each row sums to 1
    '''
    m_batchsize , channels , width , height = x.size()
    n_positions = width * height

    proj_query = self.query_conv(x).view(m_batchsize , -1 , n_positions).permute(0 , 2 , 1) # (B , N , C//8)
    proj_key = self.key_conv(x).view(m_batchsize , -1 , n_positions) # (B , C//8 , N)
    energy = torch.bmm(proj_query , proj_key) # (B , N , N)
    attention = self.softmax(energy)
    proj_value = self.value_conv(x).view(m_batchsize , -1 , n_positions) # (B , C , N)

    out = torch.bmm(proj_value , attention.permute(0 , 2 , 1)) # (B , C , N)
    # The channel count must be the one unpacked from x.size(); the former
    # reference to an undefined upper-case name raised NameError here.
    out = out.view(m_batchsize , channels , width , height)
    out = self.gamma * out + x
    return out , attention


In [27]:
class Generator(nn.Module):
  """Transposed-convolution generator with two self-attention blocks.

  The pipeline is l1 -> l2 -> l3 -> attn1 -> l4 -> attn2 -> last. Each of
  l1..l4 is a spectrally normalised ``ConvTranspose2d`` followed by
  ``BatchNorm2d`` and ``ReLU``; l2..l4 halve the channel count. ``last``
  is a ``ConvTranspose2d`` to 3 channels followed by ``Tanh``.

  Args:
    batch_size: accepted for interface compatibility; not used.
    image_size: must be 64, because ``forward`` always runs the fourth
      block that only exists for that size.
    z_dim: number of channels of the latent input.
    conv_dim: base channel width; the first block outputs
      ``conv_dim * 2 ** (log2(image_size) - 3)`` channels.

  Raises:
    ValueError: if ``image_size`` is not 64.
  """

  def __init__(self , batch_size , image_size = 64 , z_dim = 100 , conv_dim = 64):
    super(Generator , self).__init__()
    # Fail fast: any other size used to build a model without ``l4`` whose
    # forward pass then crashed with an AttributeError.
    if image_size != 64:
      raise ValueError(
        'Generator only supports image_size == 64, got {!r}'.format(image_size))
    self.image_size = image_size

    repeat_num = int(np.log2(self.image_size)) - 3
    mult = 2 ** repeat_num # 8

    cur_dim = conv_dim * mult
    layer1 = [
      SpectralNorm(nn.ConvTranspose2d(z_dim , cur_dim , 4)),
      nn.BatchNorm2d(cur_dim),
      nn.ReLU(),
    ]

    layer2 = [
      SpectralNorm(nn.ConvTranspose2d(cur_dim , cur_dim // 2 , 4 , 2 , 1)),
      nn.BatchNorm2d(cur_dim // 2),
      nn.ReLU(),
    ]
    cur_dim = cur_dim // 2

    layer3 = [
      SpectralNorm(nn.ConvTranspose2d(cur_dim , cur_dim // 2 , 4 , 2 , 1)),
      nn.BatchNorm2d(cur_dim // 2),
      nn.ReLU(),
    ]
    cur_dim = cur_dim // 2
    attn1_dim = cur_dim  # channels leaving l3 (128 when conv_dim == 64)

    layer4 = [
      SpectralNorm(nn.ConvTranspose2d(cur_dim , cur_dim // 2 , 4 , 2 , 1)),
      nn.BatchNorm2d(cur_dim // 2),
      nn.ReLU(),
    ]
    cur_dim = cur_dim // 2
    attn2_dim = cur_dim  # channels leaving l4 (64 when conv_dim == 64)

    # Sub-modules are created and registered in the same order as before
    # (l4 first) so parameter ordering and seeded initialisation are unchanged.
    self.l4 = nn.Sequential(*layer4)
    self.l1 = nn.Sequential(*layer1)
    self.l2 = nn.Sequential(*layer2)
    self.l3 = nn.Sequential(*layer3)

    last = [
      nn.ConvTranspose2d(cur_dim , 3 , 4 , 2 , 1),
      nn.Tanh(),
    ]
    self.last = nn.Sequential(*last)

    # Attention widths follow the actual feature widths instead of being
    # hard-coded, so conv_dim values other than 64 work too.
    self.attn1 = self_attention(attn1_dim , 'relu')
    self.attn2 = self_attention(attn2_dim , 'relu')

  def forward(self , z):
    """Generate images from latent vectors.

    Args:
      z: latent batch; reshaped to ``(z.size(0), z.size(1), 1, 1)``.

    Returns:
      Tuple ``(out, p1, p2)``: the generated batch and the attention maps
      returned by the first and second attention blocks.
    """
    z = z.view(z.size(0), z.size(1), 1, 1)
    out = self.l1(z)
    out = self.l2(out)
    out = self.l3(out)
    out , p1 = self.attn1(out)
    out = self.l4(out)
    out , p2 = self.attn2(out)
    out = self.last(out)

    return out , p1 , p2


In [28]:
class Discriminator(nn.Module):
  """Convolutional discriminator with two self-attention blocks.

  The pipeline is l1 -> l2 -> l3 -> attn1 -> l4 -> attn2 -> last. Each of
  l1..l4 is a spectrally normalised stride-2 ``Conv2d`` followed by
  ``LeakyReLU(0.1)``; l2..l4 double the channel count. ``last`` is a
  kernel-4 ``Conv2d`` to a single channel.

  Args:
    batch_size: accepted for interface compatibility; not used.
    image_size: must be 64, because ``forward`` always runs the fourth
      block that only exists for that size.
    conv_dim: number of channels produced by the first block.

  Raises:
    ValueError: if ``image_size`` is not 64.
  """

  def __init__(self , batch_size = 64 , image_size = 64 , conv_dim = 64):
    super(Discriminator , self).__init__()
    # Fail fast: any other size used to build a model without ``l4`` whose
    # forward pass then crashed with an AttributeError.
    if image_size != 64:
      raise ValueError(
        'Discriminator only supports image_size == 64, got {!r}'.format(image_size))
    self.image_size = image_size

    layer1 = [
      SpectralNorm(nn.Conv2d(3 , conv_dim , 4 , 2 , 1)),
      nn.LeakyReLU(0.1),
    ]
    cur_dim = conv_dim

    layer2 = [
      SpectralNorm(nn.Conv2d(cur_dim , cur_dim * 2 , 4 , 2 , 1)),
      nn.LeakyReLU(0.1),
    ]
    cur_dim = cur_dim * 2

    layer3 = [
      SpectralNorm(nn.Conv2d(cur_dim , cur_dim * 2 , 4 , 2 , 1)),
      nn.LeakyReLU(0.1),
    ]
    cur_dim = cur_dim * 2
    attn1_dim = cur_dim  # channels leaving l3 (256 when conv_dim == 64)

    layer4 = [
      SpectralNorm(nn.Conv2d(cur_dim , cur_dim * 2 , 4 , 2 , 1)),
      nn.LeakyReLU(0.1),
    ]
    cur_dim = cur_dim * 2
    attn2_dim = cur_dim  # channels leaving l4 (512 when conv_dim == 64)

    # Sub-modules are created and registered in the same order as before
    # (l4 first) so parameter ordering and seeded initialisation are unchanged.
    self.l4 = nn.Sequential(*layer4)
    self.l1 = nn.Sequential(*layer1)
    self.l2 = nn.Sequential(*layer2)
    self.l3 = nn.Sequential(*layer3)

    last = [nn.Conv2d(cur_dim , 1 , 4)]
    self.last = nn.Sequential(*last)

    # Attention widths follow the actual feature widths instead of being
    # hard-coded, so conv_dim values other than 64 work too.
    self.attn1 = self_attention(attn1_dim , 'relu')
    self.attn2 = self_attention(attn2_dim , 'relu')

  def forward(self , x):
    """Score a batch of images.

    Args:
      x: image batch with 3 channels (the first convolution expects 3).

    Returns:
      Tuple ``(out, p1, p2)``: the output of the final convolution and the
      attention maps returned by the first and second attention blocks.
    """
    out = self.l1(x)
    out = self.l2(out)
    out = self.l3(out)
    out , p1 = self.attn1(out)
    out = self.l4(out)
    out , p2 = self.attn2(out)
    out = self.last(out)

    return out , p1 , p2

