In [None]:
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=FILEID' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1jmQGX3qoUjxyTU_OCpBnDUuwtpRV0-xJ" -O ShapeNetRendering.tgz && rm -rf /tmp/cookies.txt
!tar -zxvf ShapeNetRendering.tgz
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=FILEID' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1kGkoh5hePVkSrJ5NRhWh1M6pGSWaVJk7" -O ShapeNetVox32.tgz && rm -rf /tmp/cookies.txt
!tar -zxvf ShapeNetVox32.tgz
!pip3 install torch torchvision torchaudio
!cp -r /content/drive/MyDrive/PFC/tecnica1/utils /content
!cp -r /content/drive/MyDrive/PFC/tecnica1/dataset /content

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
ShapeNetRendering/03001627/fa33e83563fc2765e238f87ef5154562/rendering/18.png
ShapeNetRendering/03001627/fa33e83563fc2765e238f87ef5154562/rendering/19.png
ShapeNetRendering/03001627/fa33e83563fc2765e238f87ef5154562/rendering/20.png
ShapeNetRendering/03001627/fa33e83563fc2765e238f87ef5154562/rendering/21.png
ShapeNetRendering/03001627/fa33e83563fc2765e238f87ef5154562/rendering/22.png
ShapeNetRendering/03001627/fa33e83563fc2765e238f87ef5154562/rendering/23.png
ShapeNetRendering/03001627/fa33e83563fc2765e238f87ef5154562/rendering/renderings.txt
ShapeNetRendering/03001627/fa33e83563fc2765e238f87ef5154562/rendering/rendering_metadata.txt
ShapeNetRendering/03001627/fa4155f8091689a273801486f0f205ad/
ShapeNetRendering/03001627/fa4155f8091689a273801486f0f205ad/rendering/
ShapeNetRendering/03001627/fa4155f8091689a273801486f0f205ad/rendering/00.png
ShapeNetRendering/03001627/fa4155f8091689a273801486f0f205ad/rendering/01.png
ShapeNetR

In [None]:
import torch
import torchvision.models
import os
import random
import numpy as np
from datetime import datetime as dt
import utils.data_transforms
import dataset.dataset_manager as dataset_manager

In [None]:
# Encoder
class Encoder(torch.nn.Module):
    """Per-view image encoder: a frozen VGG16-BN prefix followed by three conv heads."""

    def __init__(self):
        super().__init__()

        # ImageNet-pretrained VGG16-BN; all of its parameters are frozen so that
        # only the heads defined below receive gradients.
        backbone = torchvision.models.vgg16_bn(
            weights=torchvision.models.VGG16_BN_Weights.IMAGENET1K_V1)
        for frozen in backbone.parameters():
            frozen.requires_grad = False

        # Only the first 27 modules of the VGG feature stack are used.
        self.vgg = backbone.features[:27]

        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 256, kernel_size=3),
            torch.nn.BatchNorm2d(256),
            torch.nn.ELU(),
            torch.nn.MaxPool2d(kernel_size=4),
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 128, kernel_size=3),
            torch.nn.BatchNorm2d(128),
            torch.nn.ELU(),
        )

    def forward(self, rendering_images):
        """Encode every view independently.

        ``rendering_images`` is a 5-D tensor whose axis 1 indexes the views; the
        result keeps axis 0 and axis 1 in that order and replaces the remaining
        axes with the per-view feature map.
        """
        # Bring the view axis to the front; iterating then yields one 4-D batch per view.
        views_first = rendering_images.permute(1, 0, 2, 3, 4).contiguous()

        encoded_views = []
        for view_batch in views_first:
            encoded = self.vgg(view_batch)
            for head in (self.layer1, self.layer2, self.layer3):
                encoded = head(encoded)
            encoded_views.append(encoded)

        # Stacking on a new axis 1 restores the original (axis 0, view, ...) ordering.
        return torch.stack(encoded_views, dim=1).contiguous()

In [None]:
class Decoder(torch.nn.Module):
    """Turns per-view feature vectors into per-view 32x32x32 occupancy volumes."""

    @staticmethod
    def _upsample_stage(in_channels, out_channels):
        # One resolution-doubling stage: transposed conv -> batch norm -> ReLU.
        return torch.nn.Sequential(
            torch.nn.ConvTranspose3d(in_channels, out_channels, kernel_size=4, stride=2, bias=False, padding=1),
            torch.nn.BatchNorm3d(out_channels),
            torch.nn.ReLU(),
        )

    def __init__(self):
        super().__init__()

        # Four stages take the 2x2x2 seed volume up to 32x32x32.
        self.layer1 = self._upsample_stage(256, 128)
        self.layer2 = self._upsample_stage(128, 64)
        self.layer3 = self._upsample_stage(64, 32)
        self.layer4 = self._upsample_stage(32, 8)
        # 1x1x1 projection to a single channel squashed into [0, 1].
        self.layer5 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(8, 1, kernel_size=1, bias=False),
            torch.nn.Sigmoid(),
        )

    def forward(self, image_features):
        """Decode each view separately.

        Returns ``(raw_features, gen_voxels)``: for every view, ``raw_features``
        holds the 8 channels produced by ``layer4`` concatenated with the 1-channel
        output of ``layer5`` (9 channels), and ``gen_voxels`` holds that 1-channel
        output with the channel axis removed. Both keep the view index on axis 1.
        """
        views_first = image_features.permute(1, 0, 2, 3, 4).contiguous()

        volumes_per_view = []
        context_per_view = []
        for view_features in views_first:
            # Reinterpret the flat per-sample features as a 256-channel 2x2x2 volume.
            hidden = view_features.view(-1, 256, 2, 2, 2)
            for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
                hidden = stage(hidden)

            occupancy = self.layer5(hidden)

            context_per_view.append(torch.cat((hidden, occupancy), dim=1))
            volumes_per_view.append(occupancy.squeeze(dim=1))

        gen_voxels = torch.stack(volumes_per_view, dim=1).contiguous()
        raw_features = torch.stack(context_per_view, dim=1).contiguous()

        return raw_features, gen_voxels

In [None]:
class Merger(torch.nn.Module):
    """Fuses per-view coarse volumes into one volume using learned per-voxel weights."""

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        # Shape-preserving 3x3x3 conv -> batch norm -> LeakyReLU(0.2).
        return torch.nn.Sequential(
            torch.nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.BatchNorm3d(out_channels),
            torch.nn.LeakyReLU(.2),
        )

    def __init__(self):
        super().__init__()

        # Channel funnel 9 -> 16 -> 8 -> 4 -> 2 -> 1; the last map is the view's score.
        self.layer1 = self._conv_stage(9, 16)
        self.layer2 = self._conv_stage(16, 8)
        self.layer3 = self._conv_stage(8, 4)
        self.layer4 = self._conv_stage(4, 2)
        self.layer5 = self._conv_stage(2, 1)

    def forward(self, raw_features, coarse_volumes):
        """Return the softmax-weighted sum of ``coarse_volumes`` over the view axis.

        A score volume is computed from each view's slice of ``raw_features``
        (axis 1 is the view axis), the scores are normalised with a softmax across
        views, and the weighted sum of the coarse volumes is clamped to [0, 1].
        """
        # The number of views is taken from the coarse volumes, not from raw_features.
        view_count = coarse_volumes.size(1)

        score_volumes = []
        for view_idx in range(view_count):
            score = raw_features[:, view_idx]
            for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
                score = stage(score)
            # Drop the singleton channel axis left by layer5.
            score_volumes.append(score.squeeze(dim=1))

        view_weights = torch.softmax(torch.stack(score_volumes, dim=1).contiguous(), dim=1)
        fused = (coarse_volumes * view_weights).sum(dim=1)

        return fused.clamp(min=0, max=1)

In [None]:
def init_weights(m):
    """Initialise the parameters of a single layer in place.

    Meant to be passed to ``torch.nn.Module.apply``. Layers are matched by their
    exact type (subclasses are left untouched):

    * ``Conv2d`` / ``Conv3d`` / ``ConvTranspose3d``: Kaiming-normal weight, zero bias.
    * ``BatchNorm2d`` / ``BatchNorm3d``: weight 1, bias 0.
    * ``Linear``: weight ~ N(0, 0.01), zero bias.

    Any other module is ignored. Parameters that a layer does not have
    (``bias=False``, ``affine=False``) are skipped instead of raising.

    Args:
        m: the module to initialise.
    """
    layer_type = type(m)

    if layer_type in (torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose3d):
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif layer_type in (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d):
        # With affine=False both weight and bias are None.
        if m.weight is not None:
            torch.nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif layer_type is torch.nn.Linear:
        torch.nn.init.normal_(m.weight, 0, 0.01)
        # bias is None when the layer was built with bias=False.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

In [None]:
class AverageMeter:
    """Keeps the most recent value together with a running sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
def var_or_cuda(x):
    """Return ``x`` copied to the GPU (non-blocking) when CUDA is available, else ``x`` itself."""
    if not torch.cuda.is_available():
        return x
    return x.cuda(non_blocking=True)

In [None]:
def train(n_epochs=100, batch_size=64):
    """Train the Encoder / Decoder / Merger reconstruction network.

    Builds the training and validation data pipelines, creates the three
    sub-networks with one Adam optimiser each, and minimises 10x the binary
    cross-entropy between the merged volume and the ground-truth volume.
    The average loss of every epoch is printed. Nothing is returned or saved.

    Args:
        n_epochs: number of passes over the training set (default 100).
        batch_size: number of samples per training batch (default 64).
    """
    torch.backends.cudnn.benchmark = True

    # Data transformation settings
    IMG_SIZE = 224, 224
    CROP_SIZE = 128, 128
    TRAIN_RANDOM_BG_COLOR_RANGE = [[225, 255], [225, 255], [225, 255]]
    TEST_RANDOM_BG_COLOR_RANGE = [[240, 240], [240, 240], [240, 240]]
    TRAIN_BRIGHTNESS = .4
    TRAIN_CONTRAST = .4
    TRAIN_SATURATION = .4
    TRAIN_NOISE_STD = .1
    DATASET_MEAN = [0.5, 0.5, 0.5]
    DATASET_STD = [0.5, 0.5, 0.5]

    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(TRAIN_RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(TRAIN_BRIGHTNESS, TRAIN_CONTRAST, TRAIN_SATURATION),
        utils.data_transforms.RandomNoise(TRAIN_NOISE_STD),
        utils.data_transforms.Normalize(mean=DATASET_MEAN, std=DATASET_STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])

    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(TEST_RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=DATASET_MEAN, std=DATASET_STD),
        utils.data_transforms.ToTensor(),
    ])

    train_dataset_loader = dataset_manager.ShapeNetDataLoader()
    val_dataset_loader = dataset_manager.ShapeNetDataLoader()

    # NOTE(review): the literal 1 passed to get_dataset is presumably the number of
    # rendered views per sample — confirm against dataset_manager.
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(dataset_manager.DatasetType.TRAIN, 1, train_transforms),
        batch_size=batch_size,
        num_workers=2,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    # The validation split uses the deterministic validation transforms (it was
    # previously built with the augmenting training transforms). The loader is
    # constructed here but not consumed by the loop below.
    val_data_loader = torch.utils.data.DataLoader(
        dataset=val_dataset_loader.get_dataset(dataset_manager.DatasetType.VAL, 1, val_transforms),
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        shuffle=True)

    encoder = Encoder()
    decoder = Decoder()
    merger = Merger()

    print('[DEBUG] %s Parameters in Encoder: %d.' % (dt.now(), sum(p.numel() for p in encoder.parameters())))
    print('[DEBUG] %s Parameters in Decoder: %d.' % (dt.now(), sum(p.numel() for p in decoder.parameters())))
    print('[DEBUG] %s Parameters in Merger: %d.' % (dt.now(), sum(p.numel() for p in merger.parameters())))

    # The encoder loads ImageNet weights into its VGG prefix and freezes them.
    # Re-initialising the whole encoder would replace those frozen weights with
    # random values, so only children that still have trainable parameters are
    # initialised.
    for child in encoder.children():
        if any(p.requires_grad for p in child.parameters()):
            child.apply(init_weights)
    decoder.apply(init_weights)
    merger.apply(init_weights)

    # Frozen parameters are excluded from the encoder's optimiser.
    encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()),
                                      lr=1e-3,
                                      betas=(.9, .999))
    decoder_solver = torch.optim.Adam(decoder.parameters(),
                                      lr=1e-3,
                                      betas=(.9, .999))
    merger_solver = torch.optim.Adam(merger.parameters(), lr=1e-4, betas=(.9, .999))

    # Halve each learning rate at epoch 150 (never reached with the default n_epochs=100).
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(encoder_solver,
                                                                milestones=[150],
                                                                gamma=.5)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(decoder_solver,
                                                                milestones=[150],
                                                                gamma=.5)
    merger_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(merger_solver,
                                                               milestones=[150],
                                                               gamma=.5)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        merger = torch.nn.DataParallel(merger).cuda()

    bce_loss = torch.nn.BCELoss()

    for epoch_idx in range(n_epochs):
        encoder_losses = AverageMeter()

        encoder.train()
        decoder.train()
        merger.train()

        for taxonomy_names, sample_names, rendering_images, ground_truth_volumes in train_data_loader:
            rendering_images = var_or_cuda(rendering_images)
            ground_truth_volumes = var_or_cuda(ground_truth_volumes)

            # Forward pass: encode each view, decode to coarse volumes, fuse the views.
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)
            generated_volumes = merger(raw_features, generated_volumes)

            encoder_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10

            # Gradient descent step on all three sub-networks
            encoder.zero_grad()
            decoder.zero_grad()
            merger.zero_grad()

            encoder_loss.backward()

            encoder_solver.step()
            decoder_solver.step()
            merger_solver.step()

            encoder_losses.update(encoder_loss.item())

        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        merger_lr_scheduler.step()

        # Report against the real epoch budget (the total was hard-coded to 250 before).
        print('[INFO] %s Epoch [%d/%d] EDLoss = %.4f' %
              (dt.now(), epoch_idx + 1, n_epochs, encoder_losses.avg))


In [None]:
# Run the training loop; train() prints the average loss after each epoch.
train()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
03001627 /f1167a0c4bfc1f3fcf004563556ddb36 .OK
03001627 /3936ef166d22e60ff7628281ecb18112 .OK
03001627 /758b4dd493ebb4b34ec0aa53d814a8cb .OK
03001627 /56300f790763af1a872860b02b1bf58 .OK
03001627 /bdd51e6d7ff84be7492d9da2668ec34c .OK
03001627 /5d0a9fa5c8d9bef386f6991406b6a562 .OK
03001627 /ff167d9f25fb6ede2419ec0765e66c90 .OK
03001627 /ed53217c9a4443b8a4ad5308cbfec5eb .OK
03001627 /3774a2b8c71e70b9f18a36d57b7cced0 .OK
03001627 /f4a36a5ae5a596942d19175e7d19b7cb .OK
03001627 /d2992fd5e6715bad3bbf93f83cbaf271 .OK
03001627 /b9e93c2036f24661ae890f02c6b951ff .OK
03001627 /107caefdad02cf1c8ab8e68cb52baa6a .OK
03001627 /d0894aed032460fafebad4f49b26ec52 .OK
03001627 /bec78ebd204764f637a0eda928b574d2 .OK
03001627 /587ee5822bb56bd07b11ae648ea92233 .OK
03001627 /f55a514cc8f2d255f51f77a6d7299806 .OK
03001627 /48fd6cc3f407f1d650c04806fcb7ceb6 .OK
03001627 /8f6634a231e3f3ccdfe9cab879fd37e8 .OK
03001627 /3d697c411b8bf8a0df6cfab91d65bb91 