In [9]:
## Imports
import torch 
import torchvision
import torchvision.transforms as tf

import matplotlib.pyplot as plt
from PIL import Image
import scipy.io as sio

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Utils

### Compute model parameters

In [None]:
def compute_model_params(model):
  """Return the total number of scalar elements in `model.parameters()`.

  Trainable and frozen parameters are both counted.
  """
  return sum(tensor.numel() for tensor in model.parameters())

### Show Image Results

In [6]:
def show_2_Img(img1, title1, img2, title2, size=(10,20)):
  """Display two images side by side, each with its own title and no axes.

  img1, img2: anything plt.imshow accepts.
  title1, title2: captions for the left and right panels.
  size: figure size forwarded to plt.figure.
  """
  plt.figure(figsize=size)
  panels = ((img1, title1), (img2, title2))
  for position, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture)
    plt.axis('off')
    plt.title(caption)
  plt.show()

def show_Result(g, image, size=(10,20)):
  """Run generator `g` on `image` and plot the first input next to its output.

  g: generator module. It is run in eval mode and without autograd; its
     previous train/eval mode is restored afterwards so training can continue.
  image: batch tensor passed straight to `g`; only element [0] is shown.
     Assumes channel-first samples (C, H, W) — TODO confirm.
  size: figure size forwarded to show_2_Img.
  """
  was_training = g.training
  g.eval()
  try:
    # Visualisation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
      image_fake = g(image)
  finally:
    # Don't leak eval mode (e.g. BatchNorm/Dropout behaviour) into later training.
    g.train(was_training)

  # squeeze(-1) turns single-channel (H, W, 1) images into (H, W), a shape
  # imshow accepts; it is a no-op for 3-channel images.
  original = image[0].permute(1,2,0).detach().cpu().squeeze(-1).numpy()
  reconstructed = image_fake[0].permute(1,2,0).detach().cpu().squeeze(-1).numpy()

  show_2_Img(original, 'Original Images',
             reconstructed, 'Reconstructed Images',
             size)

### Compute the Discriminator's Accuracy

In [7]:
def accuracy_Disc(D, images_young, images_elderly, device='cpu'):
  """Print and return the discriminator's binary classification accuracy.

  Young images are labelled 0 and elderly images 1. An output >= 0.5 counts as
  a prediction of 1, anything else as 0. D is expected to produce one score per
  image (e.g. shape (N, 1) or (N,)); presumably a sigmoid probability — TODO confirm.

  Side effects: D is moved to `device` (nn.Module.to moves it in place) and stays
  there; its train/eval mode is restored after the evaluation.

  Returns:
    The accuracy as a percentage (float).
  Raises:
    ValueError: if both image batches are empty.
  """
  labels_y = torch.zeros(images_young.shape[0], 1)
  labels_e = torch.ones(images_elderly.shape[0], 1)

  images = torch.cat((images_young, images_elderly), 0).to(device)
  labels = torch.cat((labels_y, labels_e), 0).to(device)

  total = labels.shape[0]
  if total == 0:
    raise ValueError('accuracy_Disc needs at least one image')

  D.to(device)
  was_training = D.training
  D.eval()
  try:
    # Pure evaluation: no autograd graph needed.
    with torch.no_grad():
      scores = D(images)
  finally:
    D.train(was_training)

  # Threshold into hard 0/1 predictions without mutating D's output in place.
  # Both sides are flattened so an (N,) output is compared element-wise with the
  # (N, 1) labels instead of broadcasting into an (N, N) matrix.
  predicted = (scores >= 0.5).to(labels.dtype).reshape(-1)
  correct = (predicted == labels.reshape(-1)).sum().item()

  acc = 100 * correct / total
  print('Accuracy: {:.2f}%'.format(acc))
  return acc

### Info Class

In [8]:
class Info():
  """Collects, prints and stores training losses (and optional sample images).

  losses_params maps a group name (e.g. 'Gen', 'Disc') to the list of loss
  components tracked for it. The component 't' is treated as the group's total:
  it appears on the one-line summary, and every other component is reported
  together with its share of 't' as a percentage.

  Intended flow (inferred from the methods, not enforced): start_stage() at the
  beginning of an epoch, update_stage() once per batch, print_info() whenever
  progress should be shown, and save() at the end of the epoch.
  """

  def __init__(self, num_epochs, losses_params=None, image=None):
    """
    num_epochs: total number of epochs; used in printed headers and the image grid.
    losses_params: {group: [component, ...]}; None means no groups.
    image: optional reference batch; when given, save() also collects result images.
    """
    if losses_params is None:
      # None default instead of a shared mutable default object.
      losses_params = {}
    self.epochs = num_epochs
    self.current_stage = {}
    self.losses_params = losses_params
    # losses[group][component] -> one averaged value per save() call.
    self.losses = {grup: {par: [] for par in losses_params[grup]}
                   for grup in losses_params}
    self.images = None
    # `is not None` rather than `!= None`: numpy arrays compare element-wise.
    if image is not None:
      self.images = {'Origin': image, 'result': []}

  def start_stage(self):
    """Reset the running loss sums and the batch counter for a new stage."""
    self.current_stage = {'nBatches': 0}
    for grup in self.losses_params:
      self.current_stage[grup] = {par: 0 for par in self.losses_params[grup]}

  def update_stage(self, updates):
    """Add one batch's losses to the running sums.

    updates: flat sequence of loss values, in the order the components appear
    in losses_params (group by group). Extra trailing values are ignored.
    """
    self.current_stage['nBatches'] += 1
    i = 0
    for grup in self.losses_params:
      for par in self.losses_params[grup]:
        self.current_stage[grup][par] += updates[i]
        i += 1

  def _format_losses(self, value_of, all):
    """Build the summary suffix and the per-group breakdown shared by the printers.

    value_of(grup, par) returns the (averaged) loss of one component.
    Returns (summary, breakdown): summary is appended to the header line;
    breakdown holds one tab-indented line per group.
    """
    summary = ""
    breakdown = ""
    for grup in self.losses_params:
      breakdown += "\t" + grup + ".:"
      for par in self.losses_params[grup]:
        if par == 't':
          summary += ', ' + grup + '. Loss: {:.4f}'.format(value_of(grup, 't'))
        elif all:
          total = value_of(grup, 't')
          part = value_of(grup, par)
          # Avoid ZeroDivisionError when the group's total is exactly zero.
          share = part / total * 100 if total else float('nan')
          breakdown += " Loss " + par + ": {:.4f} ({:.0f}%),".format(part, share)
      # Drop the trailing ',' (or the ':' when nothing followed it).
      breakdown = breakdown[:-1] + "\n"
    return summary, breakdown[:-1]

  def print_info(self, epoch, i, t_steps, all=True):
    """Print the running per-batch averages of the current stage.

    epoch, i: zero-based epoch and step indices (printed one-based).
    t_steps: total number of steps, shown after the step index.
    all: also print the per-component breakdown.
    """
    n_batches = self.current_stage['nBatches']
    header = 'Epoch [{}/{}], Step [{}/{}]'.format(epoch+1, self.epochs, i+1, t_steps)
    summary, breakdown = self._format_losses(
        lambda grup, par: self.current_stage[grup][par] / n_batches, all)
    print(header + summary)
    if all:
      print(breakdown)

  def save(self, image=None):
    """Append the current stage's per-batch averages to self.losses.

    image: optional result image, stored only if a reference image was given
    to the constructor. Tensors are detached and moved to CPU first, so the
    stored copy keeps no autograd graph or device memory alive.
    Expects at least one update_stage() call since start_stage().
    """
    n_batches = self.current_stage['nBatches']
    for grup in self.losses_params:
      for par in self.losses_params[grup]:
        self.losses[grup][par].append(self.current_stage[grup][par] / n_batches)
    if self.images is not None and image is not None:
      if isinstance(image, torch.Tensor):
        image = image.detach().cpu()
      self.images['result'].append(image)

  def show_Image(self, size=(5,25)):
    """Plot the reference image next to each saved result, one row per result.

    Assumes channel-first image batches (N, C, H, W) — TODO confirm.
    """
    if self.images is None:
      return
    original = self.images['Origin']
    results = self.images['result']
    # Keep the original epochs-high grid, but grow it if more results were saved
    # than there are epochs (otherwise plt.subplot raises).
    rows = max(self.epochs, len(results))
    plt.figure(figsize=size)
    for i, image_fake in enumerate(results):
      # squeeze(-1): single-channel (H, W, 1) -> (H, W) so imshow accepts it;
      # cmap only affects such 2-D images.
      plt.subplot(rows, 2, 2*i+1)
      plt.imshow(original[0].permute(1,2,0).detach().cpu().squeeze(-1).numpy(), cmap='gray')
      plt.axis('off')

      plt.subplot(rows, 2, 2*i+2)
      plt.imshow(image_fake[0].permute(1,2,0).detach().cpu().squeeze(-1).numpy(), cmap='gray')
      plt.axis('off')
    plt.show()

  def print_all_info(self, all=True, img=False):
    """Print the saved per-epoch losses; optionally show the saved images.

    Only epochs that were actually saved are printed, so a run stopped early
    no longer raises IndexError.
    """
    recorded = min((len(values) for group in self.losses.values()
                    for values in group.values()), default=self.epochs)
    for epoch in range(min(self.epochs, recorded)):
      header = 'Epoch [{}/{}]'.format(epoch+1, self.epochs)
      summary, breakdown = self._format_losses(
          lambda grup, par: self.losses[grup][par][epoch], all)
      print(header + summary)
      if all:
        print(breakdown)
    if img:
      self.show_Image()

### Class and Function to Load Data

In [10]:
#Making native class loader
class FacesDB(torch.utils.data.Dataset):
    """Face images and labels loaded from a MATLAB .mat file.

    The file must contain 'X' (images, indexed along the first axis) and 'L'
    (labels, read from its first row). Each item is an (image, label) pair,
    with the PIL image optionally passed through `transform`.
    """

    def __init__(self, dataDir=None, transform=None, size=1000000, clr_type='RGB'):
        """
        dataDir: path to the .mat file. Defaults to data_path + '/example.mat',
            where `data_path` is a notebook-level global looked up at call time.
        transform: optional callable applied to each PIL image.
        size: keep at most the first `size` samples.
        clr_type: 'RGB' (default) or 'w&B' (case-insensitive) for grayscale;
            any other value falls back to 'RGB'.
        """
        if dataDir is None:
            # Resolved lazily so the class can be defined before data_path exists.
            dataDir = data_path + '/example.mat'
        mat_loaded = sio.loadmat(dataDir)
        # Slicing past the end just returns everything, so no length check is needed.
        self.data = mat_loaded['X'][:size]
        self.label = mat_loaded['L'][0, :size]

        self.transform = transform
        self.clr = 'w&B' if str(clr_type).lower() == 'w&b' else 'RGB'

    def __getitem__(self, index):
        """Return (image, label) for sample `index`."""
        data = self.data[index]
        lbl = self.label[index]

        # Assumes each sample is an (H, W, 3) uint8 array, so fromarray infers
        # RGB (the explicit `mode` argument is deprecated in recent Pillow).
        img = Image.fromarray(data)
        if self.clr != 'RGB':
            img = img.convert('L')  # grayscale
        if self.transform is not None:
            img = self.transform(img)
        return img, lbl

    def __len__(self):
        """Number of samples kept after applying `size`."""
        return self.data.shape[0]

def obtainData(path, batch_size=100, size=1000000, clr_type='RGB', resize=(128,128)):
  """Build a shuffled DataLoader over the faces stored in the .mat file at `path`.

  Each image is resized to `resize` and converted with ToTensor.
  Returns (loader, number_of_samples_in_the_dataset).
  """
  pipeline = tf.Compose([tf.Resize(resize), tf.ToTensor()])
  dataset = FacesDB(path, pipeline, size, clr_type)
  n_samples = len(dataset)
  loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
  return loader, n_samples

### Training Variables

In [None]:
# Printing interval for the training loops (presumably every `mod_print` steps — confirm at use site).
mod_print = 5

# Loss components tracked per group by Info. 't' is the group's total; the
# order matters because Info.update_stage consumes loss values in this order.
gen_params = {
    'Gen': ['t', 'reconst'],
}
disc_params = {
    'Disc': ['t', 'young', 'elderly'],
}
GAN_params = {
    'Gen': ['t', 'reconst', 'disc'],
    'Disc': ['t', 'elderly', 'fake'],
}