In [None]:
from util.logger import *  # Edit this file to match your local configuration
from util.spectral_normilization import *
import math
import torch
from torch import nn
from torch.optim import Adam
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import random
from torchvision import transforms, datasets
DATA_FOLDER = 'dataset/CIFAR10'  # where CIFAR-10 is downloaded / cached (relative path)

# Seed every RNG the notebook touches -- Python, NumPy, PyTorch CPU and all CUDA
# devices -- with the same value, to make runs as repeatable as possible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(1)
del _seed_fn

In [None]:
def cifar_data(train=True, root=None):
    """Return the CIFAR-10 dataset with images scaled to [-1, 1].

    ``ToTensor`` maps pixels to [0, 1] and ``Normalize(0.5, 0.5)`` then maps them
    to [-1, 1], matching the generator's ``Tanh`` output range.

    Args:
        train: load the training split (default) or, if False, the test split.
        root: dataset directory; defaults to the notebook-level ``DATA_FOLDER``.

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance (downloaded on first use).
    """
    compose = transforms.Compose([
        transforms.Resize(32),  # CIFAR-10 is already 32x32; kept so the size is explicit
        transforms.ToTensor(),
        transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
    ])
    out_dir = str(DATA_FOLDER if root is None else root)
    return datasets.CIFAR10(root=out_dir, train=train, transform=compose, download=True)

# Shared hyper-parameters read by the network definitions below.
channels = 3   # image channels, used by both the discriminator and the generator
leak = 0.1     # negative slope of the discriminator's LeakyReLUs
w_g = 4        # NOTE(review): not read anywhere visible; EviDiscriminator sets its own wg

# Shuffled mini-batch loader over the CIFAR-10 training split.
batch_size = 128
data = cifar_data()
data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
num_batches = len(data_loader)
num_batches

In [None]:
class EvidLayer(nn.Module):
    """Evidential (Normal-Inverse-Gamma) output head.

    A two-layer spectrally-normalised MLP whose output is split into four equal
    groups ``(mu, v, alpha, beta)``. ``v`` and ``beta`` are made positive with
    softplus (+1e-12) and ``alpha`` is softplus + 1.

    Args:
        n_features: size of the input features (last dimension).
        n_output: total output width; must be a multiple of 4 (four NIG
            parameters per target). The default of 1 cannot be split into four
            groups, so ``forward`` raises ValueError unless it is overridden.
        d_hidden: hidden width.
        n_classes: optional number of classes. When given, a class embedding
            ``linear_y`` is created and ``forward(inputs, y)`` adds the
            projection ``<features, embed(y)>`` to every output column. When
            omitted, no extra parameters are created.
    """

    def __init__(self, n_features, n_output=1, d_hidden=128, n_classes=None):
        super(EvidLayer, self).__init__()
        self.n_features = n_features
        self.n_output = n_output
        self.d_hidden = d_hidden

        self.l1 = SpectralNorm(nn.Linear(n_features, d_hidden))
        self.l2 = SpectralNorm(nn.Linear(d_hidden, self.n_output))
        if n_classes is not None:
            # Created only on request so the default parameter set / state_dict is unchanged.
            self.linear_y = nn.Embedding(n_classes, d_hidden)

    def evidence(self, x):
        """Map raw values to non-negative evidence with softplus."""
        return torch.nn.functional.softplus(x)

    def forward(self, inputs, y=None):
        """Return ``cat([mu, v, alpha, beta], dim=-1)``, as wide as ``l2``'s output.

        Args:
            inputs: features whose last dimension is ``n_features``.
            y: optional integer class labels (one per row); requires ``n_classes``.

        Raises:
            ValueError: if ``y`` is given but the layer was built without
                ``n_classes``, or if the output width is not a multiple of 4.
        """
        output = self.l1(inputs)
        features = F.leaky_relu(output, 0.1, inplace=True)
        d = self.l2(features)
        if y is not None:
            linear_y = getattr(self, 'linear_y', None)
            if linear_y is None:
                raise ValueError("EvidLayer was built without n_classes; cannot condition on y")
            w_y = linear_y(y)
            d = d + (features * w_y).sum(1, keepdim=True)

        width = d.shape[-1]
        if width % 4 != 0:
            raise ValueError(
                "EvidLayer output width must be a multiple of 4 (mu, v, alpha, beta); got %d" % width)
        # Four equal groups; for n_output == 4 each group is a single column.
        mu, logv, logalpha, logbeta = torch.split(d, width // 4, dim=-1)
        v = self.evidence(logv) + 1e-12
        alpha = self.evidence(logalpha) + 1
        beta = self.evidence(logbeta) + 1e-12

        return torch.cat([mu, v, alpha, beta], dim=-1)

# Evidential discriminator network
class EviDiscriminator(torch.nn.Module):
    """Spectrally-normalised convolutional discriminator with an evidential head.

    Six SN conv layers, three of them stride-2, each followed by
    LeakyReLU(``leak``), shrink the input side by a factor of 8. The flattened
    ``wg * wg * 256`` feature map feeds an ``EvidLayer`` with four outputs
    (mu, v, alpha, beta).

    Args:
        wg: spatial side of the last feature map (input side / 8). Use 4 for
            32x32 inputs (default) or 8 for 64x64 inputs.
    """

    def __init__(self, wg=4):
        super(EviDiscriminator, self).__init__()

        self.wg = wg  # 4 for image size 32, 8 for image size 64

        self.conv1 = SpectralNorm(nn.Conv2d(channels, 64, 3, stride=1, padding=(1,1)))
        self.conv2 = SpectralNorm(nn.Conv2d(64, 64, 4, stride=2, padding=(1,1)))
        self.conv3 = SpectralNorm(nn.Conv2d(64, 128, 3, stride=1, padding=(1,1)))
        self.conv4 = SpectralNorm(nn.Conv2d(128, 128, 4, stride=2, padding=(1,1)))
        self.conv5 = SpectralNorm(nn.Conv2d(128, 256, 3, stride=1, padding=(1,1)))
        self.conv6 = SpectralNorm(nn.Conv2d(256, 256, 4, stride=2, padding=(1,1)))
        self.linear = EvidLayer(self.wg * self.wg * 256, n_output=4, d_hidden=128)

    def evidence(self, x):
        """Softplus evidence mapping (not used by ``forward``)."""
        return torch.nn.functional.softplus(x)

    def forward(self, x):
        """Return ``(feat, d)``.

        ``feat`` is the flattened last conv map, shape (B, wg*wg*256). ``d`` is
        the EvidLayer output, shape (B, 4): (mu, v, alpha, beta).
        """
        m = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6):
            # Functional form avoids constructing a new LeakyReLU module on every call.
            m = F.leaky_relu(conv(m), leak)
        # Keep the batch dimension explicit: a wrongly sized input now fails in the
        # linear head instead of being silently regrouped across samples by view(-1, ...).
        feat = m.reshape(m.size(0), -1)
        d = self.linear(feat)

        return feat, d

# Generator Network
class GenerativeNet(torch.nn.Module):
    """DCGAN-style generator mapping a latent vector to a ``channels`` x 32 x 32 image.

    A 1x1 latent is upsampled to 4x4 by a stride-1 transposed conv, then
    doubled three times (8, 16, 32) by stride-2 transposed convs with
    BatchNorm + ReLU. A final 3x3 transposed conv and ``Tanh`` produce pixels
    in [-1, 1].

    Args:
        z_dim: size of the latent vector.
    """

    def __init__(self,z_dim):
        super(GenerativeNet,self).__init__()

        self.z_dim = z_dim
        stages = [nn.ConvTranspose2d(z_dim, 512, 4, stride=1), nn.BatchNorm2d(512), nn.ReLU()]
        for width_in, width_out in ((512, 256), (256, 128), (128, 64)):
            stages.extend([
                nn.ConvTranspose2d(width_in, width_out, 4, stride=2, padding=(1,1)),
                nn.BatchNorm2d(width_out),
                nn.ReLU(),
            ])
        stages.extend([nn.ConvTranspose2d(64, channels, 3, stride=1, padding=(1,1)), nn.Tanh()])
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Generate images from latents ``z`` (any shape holding ``z_dim`` values per sample)."""
        latent_map = z.view(-1, self.z_dim, 1, 1)
        return self.model(latent_map)

In [None]:
def noise(size, z_dim=100):
    """Sample ``size`` standard-normal latent vectors of length ``z_dim``.

    The sample is drawn on the CPU generator, then moved to the GPU when one is
    available.

    Args:
        size: number of latent vectors (batch size).
        z_dim: latent dimensionality; defaults to the 100 used by the generator here.

    Returns:
        A (size, z_dim) float tensor.
    """
    n = torch.randn(size, z_dim)
    return n.cuda() if torch.cuda.is_available() else n

# Initialize network weights (DCGAN scheme), intended for use with Module.apply
def init_weights(m):
    """Initialise conv and BatchNorm layers following the DCGAN recipe.

    - Modules whose class name contains ``Conv``: weight ~ N(0, 0.02).
    - Modules whose class name contains ``BatchNorm``: weight (gamma) ~ N(1, 0.02)
      and bias = 0. Centering gamma at 1 keeps BatchNorm an approximate identity
      at start; centering it at 0 would scale every normalised activation to
      about 0.02.
    - Modules without a ``weight`` (e.g. ``affine=False`` BatchNorm) are skipped.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if weight is None:
        return
    if 'Conv' in classname:
        weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        weight.data.normal_(1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            m.bias.data.zero_()

In [None]:
# Build the networks. NOTE(review): only Python's `random` is re-seeded here;
# torch's generator, which draws the weight initialisation, keeps the state left
# by earlier cells.
random.seed(1)
generator = GenerativeNet(100).apply(init_weights)  # apply() returns the module itself
discriminator = EviDiscriminator()  # keeps its default (SpectralNorm-wrapped) initialisation

# NOTE(review): these lists are not filled anywhere in the visible notebook.
evid_dist_r = []
evid_dist_f = []

# Move both networks to the GPU when one is present.
if torch.cuda.is_available():
    generator.cuda()
    discriminator.cuda()

In [None]:
# Adam for both players with identical settings: lr = 2e-4, betas = (0.5, 0.999).
d_optimizer = Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Upper bound on epochs; the training loop also stops after 200000 generator steps.
num_epochs = 600

In [None]:
def _constant_target(size, value):
    '''
    (size, 1) float tensor filled with ``value``; on the GPU when one is available.
    '''
    data = torch.full((size, 1), float(value))
    return data.cuda() if torch.cuda.is_available() else data

def real_data_target(size):
    '''
    Evidential regression target for real samples: a (size, 1) tensor filled with 2.
    '''
    return _constant_target(size, 2.0)

def fake_data_target(size):
    '''
    Evidential regression target for fake samples: a (size, 1) tensor filled with -2.
    '''
    return _constant_target(size, -2.0)

def NIG_NLL(y, gamma, v, alpha, beta, reduce=True):
    """Negative log-likelihood of ``y`` under the evidential NIG(gamma, v, alpha, beta) model.

    nll = 0.5*log(pi/v) - alpha*log(Omega) + (alpha+0.5)*log(v*(y-gamma)^2 + Omega)
          + lgamma(alpha) - lgamma(alpha+0.5),   with Omega = 2*beta*(1+v).

    Args:
        y: targets, broadcastable against the NIG parameters.
        gamma: predicted mean.
        v, alpha, beta: NIG parameters; v, beta > 0 and alpha > 0 are required
            for the logs / lgamma to be finite.
        reduce: if True (default) return the mean over all elements, otherwise
            the element-wise NLL tensor.
    """
    two_b_lambda = 2 * beta * (1 + v)
    nll = (0.5 * torch.log(np.pi / v)
           - alpha * torch.log(two_b_lambda)
           + (alpha + 0.5) * torch.log(v * (y - gamma) ** 2 + two_b_lambda)
           + torch.lgamma(alpha)
           - torch.lgamma(alpha + 0.5))
    return torch.mean(nll) if reduce else nll

def NIG_Reg(y, gamma, v, alpha, beta, omega=0.01, reduce=True):
    """Evidence regulariser: per-element Smooth-L1 error times total evidence (2v + alpha).

    The error is kept element-wise (``reduction='none'``), so each sample's
    evidence is penalised by that sample's own error. Reducing the error to a
    batch mean first would penalise all evidence uniformly.

    Args:
        y: targets, broadcastable against ``gamma``.
        gamma: predicted mean.
        v, alpha: NIG evidence parameters.
        beta, omega: accepted for signature compatibility; unused.
        reduce: if True (default) return the mean, otherwise the element-wise tensor.
    """
    error = F.smooth_l1_loss(gamma, y, reduction='none')
    evi = 2 * v + alpha
    reg = error * evi
    return torch.mean(reg) if reduce else reg

def loss_evid(evidential_output,y_true, coeff=1.0):
    """Total evidential loss: ``NIG_NLL + coeff * NIG_Reg`` against ``y_true``.

    ``evidential_output`` is split along its last dimension into four equal
    groups, interpreted as (gamma, v, alpha, beta).
    """
    group = evidential_output.shape[-1] // 4
    gamma, v, alpha, beta = torch.split(evidential_output, group, dim=-1)
    nig_params = (gamma, v, alpha, beta)
    nll_term = NIG_NLL(y_true, *nig_params)
    reg_term = NIG_Reg(y_true, *nig_params)
    return nll_term + coeff * reg_term

In [None]:
# L_Sup for AsymReg
def loss_sup(out1, others, temperature=0.1):
    """Supervised-contrastive term L_Sup.

    Rows of ``others`` are treated as one positive class: each row is pulled
    toward every other row of ``others`` and pushed away from all rows of
    ``out1`` (negatives). Similarities are cosine similarities divided by
    ``temperature``.

    Args:
        out1: (N, D) embeddings used only as negatives.
        others: (M, D) embeddings forming the positive class.
        temperature: softmax temperature.

    Returns:
        Scalar loss averaged over the M rows of ``others``. A row with no
        positives (M == 1) contributes 0 instead of NaN.
    """
    N = out1.size(0)
    out1 = F.normalize(out1)
    others = F.normalize(others)
    outputs = torch.cat([out1, others], dim=0)
    sim_matrix = outputs @ outputs.t() / temperature
    # Large finite negative (not -inf) on the diagonal: self-similarity gets ~0
    # softmax mass while log_softmax stays finite, so multiplying by 0 below stays 0.
    sim_matrix.fill_diagonal_(-5e4)

    # Positives: every other row of `others`; self-pairs excluded.
    mask = torch.zeros_like(sim_matrix)
    mask[N:, N:] = 1
    mask.fill_diagonal_(0)

    sim_matrix = sim_matrix[N:]
    mask = mask[N:]
    # clamp avoids 0/0 = NaN for rows that have no positives.
    mask = mask / mask.sum(1, keepdim=True).clamp(min=1)

    lsm = F.log_softmax(sim_matrix, dim=1) * mask
    d_loss = -lsm.sum(1).mean()
    return d_loss

# L_G_reg for AsymReg
def loss_nt_xent(out1, out2, temperature=0.1,  normalize=True):
    """NT-Xent contrastive loss between two equally sized embedding batches.

    Row i of ``out1`` and row i of ``out2`` form a positive pair. All other rows
    of the concatenated 2N batch act as negatives. Self-similarity is masked
    with a large negative value.

    Args:
        out1, out2: (N, D) embeddings.
        temperature: softmax temperature.
        normalize: L2-normalise rows first (cosine similarity) when True.

    Returns:
        Scalar: mean negative log-probability of the positive over all 2N anchors.
    """
    assert out1.size(0) == out2.size(0)
    if normalize:
        out1, out2 = F.normalize(out1), F.normalize(out2)
    n = out1.size(0)

    stacked = torch.cat([out1, out2], dim=0)
    logits = (stacked @ stacked.t()) / temperature
    logits.fill_diagonal_(-5e4)
    log_prob = F.log_softmax(logits, dim=1)

    idx = torch.arange(n, device=log_prob.device)
    # anchor i in out1 <-> positive n+i in out2, and vice versa
    positive_pairs = log_prob[idx, idx + n] + log_prob[idx + n, idx]
    return -torch.sum(positive_pairs) / (2 * n)

In [None]:
def train_discriminator(optimizer,real_data,fake_data):
    """One evidential-discriminator update on a real batch and a fake batch.

    Both batches go through the global ``discriminator`` in a single forward
    pass. Real outputs regress toward the real target (+2) and fake outputs
    toward the fake target (-2). Each target is sized to its own half, so the
    two batches may differ in length.

    Args:
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real images.
        fake_data: batch of (detached) generated images.

    Returns:
        (error_dis, mean_real, mean_fake): the loss tensor and, for each half,
        the mean over *all* four NIG outputs (mu, v, alpha, beta).
    """
    # Reset gradients
    optimizer.zero_grad()
    n_real = real_data.size(0)
    n_fake = fake_data.size(0)

    # Forward real and fake together, then split the predictions back apart.
    _feat, prediction = discriminator(torch.cat([real_data, fake_data], dim=0))
    prediction_real, prediction_fake = prediction[:n_real], prediction[n_real:]

    # Calculate error and backpropagate
    error_dis = (loss_evid(prediction_real, real_data_target(n_real))
                 + loss_evid(prediction_fake, fake_data_target(n_fake)))
    error_dis.backward()
    # Update weights with gradients
    optimizer.step()
    return error_dis, prediction_real.mean().item(), prediction_fake.mean().item()

def train_generator(optimizer,fake_data):
    """One generator update.

    Minimises the evidential loss of the global ``discriminator``'s predictions
    on ``fake_data`` against the real-data target (+2). ``backward`` also
    populates the discriminator's gradients; only ``optimizer``'s parameters
    are stepped.

    Returns:
        (error_gen, mean_pred): the loss tensor and the mean over all four NIG
        outputs of the discriminator on the fake batch.
    """
    optimizer.zero_grad()
    _, evid_out = discriminator(fake_data)
    mean_pred = evid_out.mean().item()
    target = real_data_target(fake_data.size(0))
    loss = loss_evid(evid_out, target)
    loss.backward()
    optimizer.step()
    return loss, mean_pred

In [None]:
num_test_samples = 16
# Fixed latent batch reused for every preview grid, so progress is comparable over time.
test_noise = noise(num_test_samples)
n_critic = 1   # discriminator updates per generator update
alpha = 1.     # NOTE(review): not referenced in the visible notebook
g_step = 0     # number of generator updates performed so far
# Derive the device flags from the hardware instead of hard-coding CUDA, so they
# agree with the torch.cuda.is_available() checks used throughout the notebook.
use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

In [None]:
logger = Logger(model_name='EviD-GAN-C', data_name='CIFAR10')
max_g_steps = 200000  # stop once this many generator updates have been made
g_error, d_pred_fake2, fake_data = 0, 0, 0
for epoch in range(num_epochs):
    if epoch % 2 == 0:
        # NOTE(review): `display` is never imported in this notebook; presumably it is
        # re-exported by `from util.logger import *` -- verify.
        display.clear_output(True)
    for n_batch, (real_batch, _) in enumerate(data_loader):
        # Global 1-based iteration counter; drives the n_critic schedule.
        step = epoch * len(data_loader) + n_batch + 1

        # 1. Train the discriminator on a real batch and an equally sized, detached fake batch.
        real_data = real_batch.cuda() if torch.cuda.is_available() else real_batch
        fake_data = generator(noise(real_data.size(0))).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer,
                                                                real_data, fake_data)
        logger.log(d_error, g_error, epoch, n_batch, num_batches)

        # 2. Train the generator once every n_critic discriminator updates.
        g_updated = step % n_critic == 0
        if g_updated:
            fake_data = generator(noise(real_batch.size(0)))
            g_error, d_pred_fake2 = train_generator(g_optimizer, fake_data)
            g_step += 1

        # Display progress
        if n_batch % 500 == 0:
            # Preview only, so no autograd graph is built. The generator stays in
            # train mode, so BatchNorm still uses (and updates) batch statistics here.
            with torch.no_grad():
                test_images = generator(test_noise).cpu()
            logger.log_images(test_images, num_test_samples, step, n_batch, num_batches)
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake, d_pred_fake2
            )

        # Checkpoint only on iterations that advanced g_step. Otherwise, with
        # n_critic > 1, the same g_step would be saved once per discriminator step.
        if g_updated and g_step >= 10000 and g_step % 2000 == 0:
            logger.save_models(generator, discriminator, g_step)
        if g_step >= max_g_steps:
            break
    if g_step >= max_g_steps:
        break