# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from itertools import chain

class GZSL(nn.Module):
    """Conditional VAE over image features with an attribute-regressing discriminator.

    Sub-networks:
      * encoder:       concat(img_feat, cls_feat) -> (mu, logvar) of the latent z
      * decoder:       concat(z, class attributes) -> reconstructed image features
      * discriminator: img_feat -> predicted class-attribute vector

    Args:
        feat_dim: image feature size (e.g. 2048).
        cls_dim: class attribute vector size (e.g. 85).
        enc_hdim1: width of the first encoder hidden layer.
        enc_hdim2: width of the second encoder hidden layer.
        zdim: size of the latent code z.
        dec_hdim1: width of the decoder hidden layer.
        disc_hdim1: width of the discriminator hidden layer.
    """

    def __init__(self, feat_dim, cls_dim, enc_hdim1, enc_hdim2, zdim, dec_hdim1, disc_hdim1):
        super().__init__()

        self.feat_dim = feat_dim
        self.cls_dim = cls_dim
        self.enc_hdim1 = enc_hdim1
        self.enc_hdim2 = enc_hdim2
        self.zdim = zdim
        self.dec_hdim1 = dec_hdim1
        self.disc_hdim1 = disc_hdim1

        # Not used by any forward path in this class; kept so state_dict keys
        # stay compatible with previously saved checkpoints.
        self.bn1 = nn.BatchNorm1d(disc_hdim1)

        # Encoder layers.
        self.enc_lin1 = nn.Linear(self.feat_dim + self.cls_dim, self.enc_hdim1)
        self.enc_lin2 = nn.Linear(self.enc_hdim1, self.enc_hdim2)
        self.mu = nn.Linear(self.enc_hdim2, self.zdim)
        self.log_sigma = nn.Linear(self.enc_hdim2, self.zdim)
        self.dp = nn.Dropout(p=0.3)

        # Decoder layers.
        self.dec_lin1 = nn.Linear(self.zdim + self.cls_dim, self.dec_hdim1)
        self.dec_lin2 = nn.Linear(self.dec_hdim1, self.feat_dim)

        # Discriminator layers.
        self.disc_lin1 = nn.Linear(self.feat_dim, self.disc_hdim1)
        self.disc_lin2 = nn.Linear(self.disc_hdim1, self.cls_dim)

        # Parameter groups for separate optimizers. They are materialized as
        # lists rather than chain/filter iterators so they can be iterated
        # more than once (optimizer, grad clipping, ...) and so the module
        # stays deep-copyable/picklable (generators are not).
        self.enc_params = list(chain(
            self.enc_lin1.parameters(), self.enc_lin2.parameters(),
            self.mu.parameters(), self.log_sigma.parameters(),
        ))
        self.dec_params = list(chain(
            self.dec_lin1.parameters(), self.dec_lin2.parameters(),
        ))
        # Keep only parameters that take gradients (checked at construction).
        self.vae_params = [
            p for p in chain(self.enc_params, self.dec_params) if p.requires_grad
        ]
        self.disc_params = [
            p for p in chain(self.disc_lin1.parameters(), self.disc_lin2.parameters())
            if p.requires_grad
        ]

    def forward_encoder(self, inputs):
        """Encode concat(img_feat, cls_feat) into latent statistics.

        Returns:
            (mu, log_sigma_out). Despite the head's name, forward() and
            sample_z() interpret the second output as the log-variance.
        """
        act_lin1 = self.dp(F.relu(self.enc_lin1(inputs)))
        out_lin2 = F.relu(self.enc_lin2(act_lin1))
        return self.mu(out_lin2), self.log_sigma(out_lin2)

    def forward_decoder(self, z, c):
        """Decode latent code ``z`` conditioned on attributes ``c`` into image features.

        ``z`` is reshaped to (-1, zdim) and ``c`` to (-1, cls_dim); both must
        end up with the same number of rows.
        """
        inputs = torch.cat((z.reshape(-1, self.zdim), c.reshape(-1, self.cls_dim)), 1)
        out_lin1 = F.relu(self.dec_lin1(inputs))
        return self.dec_lin2(out_lin1)

    def forward_discriminator(self, inputs):
        """Predict a class-attribute vector (size cls_dim) from image features."""
        out_lin1 = F.relu(self.disc_lin1(inputs))
        return self.disc_lin2(out_lin1)

    def sample_z(self, mu, logvar):
        """Reparameterization trick: z = mu + std * eps, eps ~ N(0, I).

        A fresh, independent ``eps`` is drawn for every element of ``mu``
        (same shape, device and dtype), so each sample in a batch gets its
        own noise.
        """
        eps = torch.randn_like(mu)
        return mu + torch.exp(logvar / 2) * eps

    def forward(self, img_feat, cls_feat, use_c_prior=True):
        """Compute the VAE losses for one batch.

        Args:
            img_feat: image features, concatenated with ``cls_feat`` along dim 1.
            cls_feat: class-attribute vectors.
            use_c_prior: if True, condition the decoder on ``cls_feat``;
                otherwise on the discriminator's prediction from ``img_feat``.

        Returns:
            (recon_loss, kl_loss): mean-squared reconstruction error over all
            elements, and the KL(q(z|x,c) || N(0, I)) summed over latent
            dims and averaged over the batch.
        """
        # Always switches to training mode (dropout active), as before.
        self.train()

        mu, logvar = self.forward_encoder(torch.cat((img_feat, cls_feat), 1))
        z = self.sample_z(mu, logvar)

        c = cls_feat if use_c_prior else self.forward_discriminator(img_feat)
        y = self.forward_decoder(z, c)

        recon_loss = F.mse_loss(
            y.reshape(-1, self.feat_dim),
            img_feat.reshape(-1, self.feat_dim),
            reduction='mean',  # equivalent to the deprecated size_average=True
        )
        kl_loss = torch.mean(0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1 - logvar, 1))
        return recon_loss, kl_loss

    def generate_images(self, batch_size, cls_attr):
        """Decode one prior sample for each of the first ``batch_size`` attribute rows.

        Args:
            batch_size: number of samples to generate.
            cls_attr: indexable collection of tensors; item ``i`` holds the
                cls_dim attributes conditioning sample ``i``.

        Returns:
            (X_gen, cls_attr): X_gen has shape (batch_size, feat_dim);
            cls_attr is returned unchanged.
        """
        c = torch.cat(
            [cls_attr[i].reshape(-1, self.cls_dim) for i in range(batch_size)], dim=0
        )
        z = self.sample_z_prior(batch_size)
        # One decoder pass for the whole batch instead of one per sample.
        X_gen = self.sample_img(z, c)
        return X_gen, cls_attr

    def sample_img(self, z, c):
        """Run the decoder in eval mode, then restore the previous train/eval mode."""
        was_training = self.training
        self.eval()
        try:
            return self.forward_decoder(z, c)
        finally:
            self.train(was_training)

    def sample_z_prior(self, mbsize):
        """Sample ``mbsize`` latent codes z ~ p(z) = N(0, I).

        The samples are created on the model's device and in its dtype.
        """
        weight = self.dec_lin1.weight
        return torch.randn(mbsize, self.zdim, device=weight.device, dtype=weight.dtype)

