In [1]:
import os
import sys
import matplotlib.pyplot as plt
import torch
import random
import gc
from PIL import Image
sys.path.insert(1, '/home/sayuhika/anaconda3/envs/pytorch3d/lib/python3.9/site-packages')
import pytorch3d

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes, load_obj

# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    SoftPhongShader,
    TexturesUV,
    TexturesVertex,
    BlendParams,
)

# add path for demo utils functions 
sys.path.append(os.path.abspath(''))

ModuleNotFoundError: ignored

In [None]:
# Satellite model name -> base camera distance for that model.
# The distance is passed as `dist` to look_at_view_transform when rendering
# (plus a per-call offset). Insertion order matters: a model's position in
# this mapping is the class label written to labels.csv.
AllSatellitesNames = dict([
    ("AcrimSAT",             3),
    ("Aqua",                 27),
    ("Aura",                 27),
    ("Cassini",              22),
    ("Chandra",              26),
    ("Dawn",                 17),
    ("Galileo",              24),
    ("Mars Global Surveyor", 15),
    ("Mars Odyssey",         10),
    ("Maven",                14),
    ("TESS",                 10),
])

def _select_device():
    """Return ``cuda:0`` (and make it the current CUDA device) if available, else CPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    print(device)
    return device


def _build_renderer(device, resolution):
    """Build the soft-Phong renderer plus the point light and material it is driven with.

    The rasterizer is built without cameras; cameras (and the light/material)
    are passed on every render call instead.

    Returns:
        (renderer, lights, materials)
    """
    raster_settings = RasterizationSettings(
        image_size=resolution,
        blur_radius=0.0,
        faces_per_pixel=1,
    )
    blend_params = BlendParams(background_color=(0, 0, 0))  # black background

    lights = PointLights(device=device)
    # Initial position only; it is re-randomised before every render.
    lights.location = torch.tensor([[7, 7.0, -7.0]], device=device)
    materials = Materials(
        device=device,
        specular_color=[[1.0, 1.0, 1.0]],
        shininess=10.0,
    )

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(raster_settings=raster_settings),
        shader=SoftPhongShader(device=device, lights=lights, blend_params=blend_params),
    )
    return renderer, lights, materials


def _random_signed(rng, low, high):
    """Draw a magnitude uniformly from [low, high] and give it a random sign (+/-)."""
    magnitude = rng.uniform(low, high)
    sign = rng.randint(0, 1) * 2 - 1  # 0 -> -1, 1 -> +1
    return magnitude * sign


def _save_grayscale_png(rgb, path):
    """Save an (H, W, 3) float image tensor as an 8-bit grayscale PNG.

    Values are clamped to [0, 1] (Phong shading can slightly overshoot 1.0,
    which matplotlib's imsave rejects), scaled by 255 and truncated to uint8
    (the same conversion matplotlib's imsave performs), then converted to
    PIL mode "L".
    """
    rgb_u8 = (rgb.detach().clamp(0.0, 1.0) * 255).to(torch.uint8).cpu().numpy()
    Image.fromarray(rgb_u8).convert("L").save(path)


def CreateSatellitesDataset(file_names=None, sample_count=10, resolution=512,
                            lavtE_min=0, lavtE_max=360, lavtA_min=0, lavtA_max=360,
                            lavtD_koff=0,
                            data_dir="/home/sayuhika/CNN/Satellites Detector/Satellite Models",
                            output_dir="/home/sayuhika/CNN/Satellites Detector/SatellitesDataset",
                            seed=None):
    """Render randomly-posed grayscale views of satellite models and write a label CSV.

    For every requested model, ``sample_count`` images are rendered, each with:
    - a random camera elevation and azimuth (integer degrees, both bounds inclusive);
    - camera distance ``AllSatellitesNames[name] + lavtD_koff``;
    - a point light at a random position whose coordinates are each +/-[5, 10].
    Each image is saved as ``<output_dir>/satellites_data/<name>_<i>.png``. A line
    ``"<name>_<i>.png, <label>"`` is written to ``<output_dir>/labels.csv``; the file
    is truncated at the start of each call.

    Args:
        file_names: iterable of model names to render (a single string is treated
            as one name). ``None`` means every model in ``AllSatellitesNames``.
            Names not in ``AllSatellitesNames`` are skipped with a warning. Models
            are always processed in ``AllSatellitesNames`` order.
        sample_count: number of images per model.
        resolution: rendered image size in pixels (square).
        lavtE_min, lavtE_max: inclusive range for the camera elevation.
        lavtA_min, lavtA_max: inclusive range for the camera azimuth.
        lavtD_koff: offset added to each model's base camera distance.
        data_dir: directory containing ``<name>/<name>.obj`` for every model.
        output_dir: dataset root; ``satellites_data/`` is created inside it if missing.
        seed: optional seed for a private RNG. ``None`` uses the global ``random``
            module, as before.

    Labels are the model's index in ``AllSatellitesNames`` insertion order, so they
    stay stable regardless of which subset is rendered.
    """
    import warnings

    all_names = list(AllSatellitesNames)
    if file_names is None:
        file_names = all_names
    elif isinstance(file_names, str):
        file_names = [file_names]
    requested = set(file_names)
    names = [name for name in all_names if name in requested]
    unknown = requested.difference(all_names)
    if unknown:
        warnings.warn(f"Ignoring unknown satellite names: {sorted(unknown)}")

    label_of = {name: idx for idx, name in enumerate(all_names)}
    rng = random if seed is None else random.Random(seed)

    device = _select_device()
    renderer, lights, materials = _build_renderer(device, resolution)

    images_dir = os.path.join(output_dir, "satellites_data")
    os.makedirs(images_dir, exist_ok=True)
    labels_path = os.path.join(output_dir, "labels.csv")

    # `with` guarantees the CSV is closed even if loading/rendering raises.
    with open(labels_path, "w") as labels_file:
        for name in names:
            obj_path = os.path.join(data_dir, name, name + ".obj")
            mesh = load_objs_as_meshes([obj_path], device=device,
                                       create_texture_atlas=True, texture_atlas_size=30)
            distance = AllSatellitesNames[name] + lavtD_koff

            for i in range(sample_count):
                # Camera pose.
                elev = rng.randint(lavtE_min, lavtE_max)
                azim = rng.randint(lavtA_min, lavtA_max)
                R, T = look_at_view_transform(distance, elev, azim)
                cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

                # Random light position; x, y, z drawn in that order.
                light_xyz = [_random_signed(rng, 5, 10) for _ in range(3)]
                lights.location = torch.tensor([light_xyz], device=device)
                images = renderer(mesh, lights=lights, materials=materials, cameras=cameras)

                image_file = f"{name}_{i}.png"
                _save_grayscale_png(images[0, ..., :3], os.path.join(images_dir, image_file))

                row = f"{image_file}, {label_of[name]}\n"
                print(row, end="")
                labels_file.write(row)
                labels_file.flush()  # keep labels in sync with images if interrupted

                # Drop references to this sample's tensors so the cache can actually free them.
                del images, cameras, R, T
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                gc.collect()

            del mesh

In [None]:
# Render 1000 views each of AcrimSAT and Cassini, moving the camera 2 units
# farther out than each model's base distance.
CreateSatellitesDataset(
    file_names=['AcrimSAT', 'Cassini'],
    sample_count=1000,
    lavtD_koff=2,
)