In [None]:
from types import SimpleNamespace
import types

import pytorch_lightning as pl
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt

from ModelDiffusion import get_default_diffusion_model
from ModelDiffusion.diffusion.utils.dataset import ShapeNetCore
from DiffRender import get_default_rasterizer
# from Wasserstein.SlicedWassersteinDistance.swd import swd

In [None]:
def is_interactive():
    """Report whether the code runs in an interactive session rather than as a script.

    The check looks at the `__main__` module: when Python executes a script file,
    `__main__` carries a `__file__` attribute; when no such attribute is present the
    session is treated as interactive.

    Returns:
        bool: True when `__main__` has no `__file__` attribute, False otherwise.
    """
    # Imported locally: the notebook's import cell does not bring `__main__` into scope,
    # so referencing it directly would raise NameError.
    import __main__ as main_module

    return not hasattr(main_module, "__file__")

In [None]:
# Notebook-wide settings, exposed as attributes of `args` (e.g. `args.device`).
args = SimpleNamespace(**{
    # Sizes of the image-encoder head that produces the diffusion context
    "diffusion_context_input_dim": 256,
    "diffusion_context_hidden_dim": 256,
    # Dataset selection
    "dataset_path": './ModelDiffusion/diffusion/data/shapenet.hdf5',
    "categories": ['airplane'],
    "scale_mode": 'shape_unit',
    # Batch sizes and sampling
    "train_batch_size": 32,
    "val_batch_size": 8,
    "sample_num_points": 2048,
    # Evaluation
    "max_test_comparisons": 10,
    "normalize": None,
    # Use the GPU whenever one is visible to torch
    "device": "cuda" if torch.cuda.is_available() else "cpu",
})

# Fix torch's global RNG seed so runs are repeatable.
torch.manual_seed(0)

## Image encoder

In [None]:
# DenseNet-121 initialised with the ImageNet (IMAGENET1K_V1) weights.
densenet_encoder = torchvision.models.densenet121(
    weights=torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
)

# Replace the stock classifier with a two-layer head whose output size is the
# context dimension configured in `args`.
densenet_encoder.classifier = nn.Sequential(
    nn.Linear(densenet_encoder.classifier.in_features, args.diffusion_context_hidden_dim),
    nn.ReLU(),
    nn.Linear(args.diffusion_context_hidden_dim, args.diffusion_context_input_dim),
)

densenet_encoder = densenet_encoder.to(args.device)

## Image reconstruction

In [None]:
i = 0  # Running index for the comparison figures saved by `validation_step` (Comparisons/Img{i}.png)


class ImageReconstructionModel(pl.LightningModule):
    """Lightning module that trains an image encoder as the context provider of a point-cloud
    diffusion model, scoring the result by comparing rendered depth maps.
    """

    def __init__(
        self,
        img_encoder=densenet_encoder,
        render_depthmap=get_default_rasterizer(img_height=256, img_width=256),
        diffusion_model=get_default_diffusion_model(sample_num_points=args.sample_num_points, device=args.device, category=args.categories[0]),
        loss_fn=lambda reconstruction, ground_truth: swd(reconstruction, ground_truth, device=args.device)
    ):
        """Store the components of the reconstruction pipeline.

        Args:
            img_encoder: module called with a batch of 3-channel images
                (BATCH_SIZE x 3 x IMG_HEIGHT x IMG_WIDTH, the depth map repeated over channels)
                that returns a batch of encodings (BATCH_SIZE x ENCODING_SIZE).
            render_depthmap: module called with a batch of pointclouds
                (BATCH_SIZE x NUM_POINTCLOUD_POINTS x POINTCLOUD_DIM) that returns a batch of
                rendered images (BATCH_SIZE x IMG_HEIGHT x IMG_WIDTH).
            diffusion_model: context-conditioned diffusion model. This class uses its
                `var_sched.alpha_bars`, `var_sched.uniform_sample_t`, `sample(context=...)` and
                calls it with `noised_pointcloud`, `context` and `denoising_step_number`.
            loss_fn: callable (or nn.Module) receiving the reconstructed renders and the
                ground-truth renders, and returning the loss.
                NOTE: the default relies on `swd`, whose import is commented out at the top of
                the notebook, so pass `loss_fn` explicitly unless that import is restored.
        """
        super().__init__()
        # Only nn.Module losses (e.g. nn.HuberLoss) have `.to`; a plain callable such as the
        # default lambda is stored untouched instead of raising AttributeError.
        self.loss_fn = loss_fn.to(self.device) if isinstance(loss_fn, nn.Module) else loss_fn
        self.render_depthmap = render_depthmap.to(self.device)
        self.img_encoder = img_encoder.to(self.device)
        self.diffusion_model = diffusion_model.to(self.device)

    def pointcloud_denoising_comparison(self, pointcloud_batch, denoising_step_number):
        """Noise a clean pointcloud batch to step t and to step t-1 using the same noise sample.

        Args:
            pointcloud_batch: clean pointclouds, BATCH_SIZE x NUM_POINTCLOUD_POINTS x POINTCLOUD_DIM.
            denoising_step_number: tensor of diffusion step indices t used to index `alpha_bars`.

        Returns:
            (groundtruth_pointcloud, noised_pointcloud): the batch noised to step t-1 (the
            denoising target) and the batch noised to step t.
        """
        alpha_bars = self.diffusion_model.var_sched.alpha_bars

        # Mixing coefficients at step t
        alpha_bar = alpha_bars[denoising_step_number]
        c0 = torch.sqrt(alpha_bar).view(-1, 1, 1)       # (B, 1, 1)
        c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1)   # (B, 1, 1)

        # Mixing coefficients at step t-1. These must come from `alpha_bar_prev`; deriving them
        # from `alpha_bar` would make the target identical to the noised input.
        alpha_bar_prev = alpha_bars[denoising_step_number - 1]
        c0_prev = torch.sqrt(alpha_bar_prev).view(-1, 1, 1)       # (B, 1, 1)
        c1_prev = torch.sqrt(1 - alpha_bar_prev).view(-1, 1, 1)   # (B, 1, 1)
        # For t == 0 the index t-1 wraps around to the last entry, so force the clean pointcloud.
        first_step = torch.where(denoising_step_number == 0)
        c0_prev[first_step] = 1
        c1_prev[first_step] = 0

        # Noise pointcloud
        e_rand = torch.randn_like(pointcloud_batch)  # shape: BATCH_SIZE x NUM_POINTCLOUD_POINTS x POINTCLOUD_DIM
        groundtruth_pointcloud = c0_prev * pointcloud_batch + c1_prev * e_rand
        noised_pointcloud = c0 * pointcloud_batch + c1 * e_rand

        return groundtruth_pointcloud, noised_pointcloud

    def training_step(self, pointcloud_batch, batch_idx, denoising_step_number=None):
        """Single diffusion-step training loss measured in image space.

        The clean pointclouds are rendered and encoded into a context; the diffusion model is
        called on the pointclouds noised to step t together with that context, and the render of
        its output is compared against the render of the pointclouds noised to step t-1.
        Parallel to self.diffusion_model.get_loss.

        Args:
            pointcloud_batch: BATCH_SIZE x NUM_POINTCLOUD_POINTS x POINTCLOUD_DIM.
            batch_idx: index of the batch (unused here).
            denoising_step_number: optional step indices; sampled with
                `var_sched.uniform_sample_t` when omitted.

        Returns:
            The loss produced by `self.loss_fn` (its mean is logged as 'train_loss').
        """
        pointcloud_batch = pointcloud_batch.to(args.device)
        batch_size = pointcloud_batch.size(0)
        if denoising_step_number is None:
            denoising_step_number = self.diffusion_model.var_sched.uniform_sample_t(batch_size)

        # `as_tensor` accepts both the sampled step list and a tensor supplied by the caller.
        groundtruth_pointcloud, noised_pointcloud = self.pointcloud_denoising_comparison(
            pointcloud_batch,
            torch.as_tensor(denoising_step_number)
        )

        # Estimate partially denoised pointcloud
        origin_img_renders = self.render_depthmap(pointcloud_batch)  # shape: BATCH_SIZE x IMG_HEIGHT x IMG_WIDTH

        # The encoder is fed 3-channel input, so the single-channel render is repeated.
        encoded_imgs = self.img_encoder(origin_img_renders.unsqueeze(1).repeat(1,3,1,1))  # shape: BATCH_SIZE x ENCODING_SIZE
        model_pointcloud_reconstructions = self.diffusion_model(
            noised_pointcloud=noised_pointcloud,
            context=encoded_imgs,
            denoising_step_number=denoising_step_number
        )  # shape: BATCH_SIZE x NUM_POINTCLOUD_POINTS x POINTCLOUD_DIM
        reconstruction_img_renders = self.render_depthmap(model_pointcloud_reconstructions)  # shape: BATCH_SIZE x IMG_HEIGHT x IMG_WIDTH

        # Measure success
        groundtruth_img_renders = self.render_depthmap(groundtruth_pointcloud)
        loss = self.loss_fn(
            # reconstruction_img_renders.unsqueeze(1).repeat(1,3,1,1),  for Wasserstein loss
            # groundtruth_img_renders.unsqueeze(1).repeat(1,3,1,1)    for Wasserstein loss
            reconstruction_img_renders,
            groundtruth_img_renders
        )
        self.log('train_loss', loss.mean())
        return loss

    def _render_groundtruth_and_reconstruction(self, pointcloud_batch):
        """Render a pointcloud batch, reconstruct pointclouds from those renders, and render them.

        Returns:
            (groundtruth_img_renders, reconstruction_img_renders), each commented in the callers
            as BATCH_SIZE x IMG_HEIGHT x IMG_WIDTH.
        """
        pointcloud_batch = pointcloud_batch.to(args.device)

        groundtruth_img_renders = self.render_depthmap(pointcloud_batch)  # shape: BATCH_SIZE x IMG_HEIGHT x IMG_WIDTH
        model_pointcloud_reconstructions = self.forward(img=groundtruth_img_renders)  # shape: BATCH_SIZE x NUM_POINTCLOUD_POINTS x POINTCLOUD_DIM
        reconstruction_img_renders = self.render_depthmap(model_pointcloud_reconstructions)  # shape: BATCH_SIZE x IMG_HEIGHT x IMG_WIDTH

        return groundtruth_img_renders, reconstruction_img_renders

    def validation_step(self, pointcloud_batch, batch_idx, visualize_validation=False):
        """Log the full-sampling reconstruction loss and optionally save a comparison figure.

        Args:
            pointcloud_batch: BATCH_SIZE x NUM_POINTCLOUD_POINTS x POINTCLOUD_DIM.
            batch_idx: index of the batch; a figure is only produced for batch 0.
            visualize_validation: when True, save the first sample's ground-truth and
                reconstructed renders to `Comparisons/Img{i}.png` and increment the global `i`.
        """
        groundtruth_img_renders, reconstruction_img_renders = self._render_groundtruth_and_reconstruction(pointcloud_batch)

        loss = self.loss_fn(
            # reconstruction_img_renders.unsqueeze(1).repeat(1,3,1,1),  for Wasserstein loss
            # groundtruth_img_renders.unsqueeze(1).repeat(1,3,1,1)    for Wasserstein loss
            reconstruction_img_renders,
            groundtruth_img_renders
        )
        self.log("validation_loss", loss.mean())

        # Visualize model
        if visualize_validation and batch_idx == 0:
            import os  # local import: only needed when figures are written

            global i
            os.makedirs("Comparisons", exist_ok=True)  # savefig fails if the directory is missing
            fig, axes = plt.subplots(ncols=2, figsize=(15,15))
            fig.suptitle(f"Sample #{batch_idx}")
            axes[0].imshow(groundtruth_img_renders[0].detach().cpu().numpy())
            axes[0].set_title("Ground truth")
            axes[1].imshow(reconstruction_img_renders[0].detach().cpu().numpy())
            axes[1].set_title("Reconstruction")
            fig.savefig(f"Comparisons/Img{i}.png")
            plt.close(fig)  # release the figure; otherwise one figure per call stays open
            i += 1

    def test_step(self, pointcloud_batch, batch_idx):
        """Return the ground-truth renders and the renders of the reconstructed pointclouds.

        Args:
            pointcloud_batch: BATCH_SIZE x NUM_POINTCLOUD_POINTS x POINTCLOUD_DIM.
            batch_idx: index of the batch (unused here).

        Returns:
            (groundtruth_img_renders, reconstruction_img_renders).
        """
        return self._render_groundtruth_and_reconstruction(pointcloud_batch)

    def forward(self, img):
        """Generate pointclouds from single-channel images, without tracking gradients.

        Args:
            img: batch of images; a channel axis is inserted and repeated three times before
                encoding.

        Returns:
            The pointclouds returned by `self.diffusion_model.sample` for the encoded images.
        """
        with torch.no_grad():
            encoded_img = self.img_encoder(img.unsqueeze(1).repeat(1,3,1,1))  # shape: BATCH_SIZE x ENCODING_SIZE
            generated_pointcloud = self.diffusion_model.sample(context=encoded_img)  # shape: BATCH_SIZE x NUM_POINTCLOUD_POINTS x POINTCLOUD_DIM
        return generated_pointcloud

    def configure_optimizers(self):
        """Return a NAdam optimizer over the image encoder's parameters only."""
        return torch.optim.NAdam(self.img_encoder.parameters())

    def train_model(
        self,
        datamodule,
        checkpoint_file_path=None,
        gpus=None,
        max_epochs=2000,
        wandb_run_name=None,
        wandb_group_name=None,
        pretrained_model_path=None,
        wandb_config={"project": None, "entity": None}
    ):
        """Fit this module with a Lightning Trainer that logs to Weights & Biases.

        Args:
            datamodule: Lightning datamodule handed to `trainer.fit`.
            checkpoint_file_path: `dirpath` of the ModelCheckpoint callback.
            gpus: value passed as the Trainer's `devices`; None is replaced by -1 (ALL_GPUS).
            max_epochs: maximum number of training epochs.
            wandb_run_name: run name given to the WandbLogger.
            wandb_group_name: group name given to the WandbLogger.
            pretrained_model_path: checkpoint passed as `ckpt_path` to `trainer.fit`.
            wandb_config: mapping with the "project" and "entity" keys for the WandbLogger
                (only read, never modified).

        Returns:
            The Trainer after fitting.
        """
        ALL_GPUS = -1
        gpus = ALL_GPUS if gpus is None else gpus

        # Keep the single best checkpoint according to the logged 'validation_loss'.
        model_checkpoint = pl.callbacks.ModelCheckpoint(
            dirpath=checkpoint_file_path,
            save_top_k=1,
            monitor='validation_loss',
            mode='min',
            save_on_train_epoch_end=True
        )
        lr_logger = pl.callbacks.LearningRateMonitor(log_momentum=True)
        callbacks_list = [model_checkpoint, lr_logger]

        wandb_logger = pl.loggers.WandbLogger(project=wandb_config["project"], entity=wandb_config["entity"], name=wandb_run_name, group=wandb_group_name)
        # NOTE(review): this runs before any Trainer is attached to the module - confirm that the
        # value actually reaches the logger.
        self.log('batch_size', args.train_batch_size)
        # wandb_logger.watch(self, log='all', log_freq=100)
        if not isinstance(wandb_logger.experiment.config, types.MethodType):  # if this is main process (as opposed to launched by ddp)
            wandb_logger.experiment.config["model_name"] = self.__class__.__name__
            wandb_logger.experiment.config["img_encoder"] = self.img_encoder.__class__.__name__
            wandb_logger.experiment.config["model_architecture"] = self.__str__()

        # Interactive sessions use 'dp'; otherwise DDP is selected.
        strategy = "dp" if is_interactive() else "ddp_find_unused_parameters_false"
        print(f"parallelization strategy = {strategy}")
        trainer = pl.Trainer(
            devices=gpus,
            accelerator="gpu",
            strategy=strategy,
            precision=32,
            max_epochs=max_epochs,
            logger=wandb_logger,
            callbacks=callbacks_list,
            num_sanity_val_steps=1,
            log_every_n_steps=5
        )
        trainer.fit(self, datamodule, ckpt_path=pretrained_model_path)
        print(f"Best model checkpoint saved at: {model_checkpoint.best_model_path}")

        wandb_logger.experiment.finish()
        return trainer

### Data

In [None]:
class ShapenetDatamodule(pl.LightningDataModule):
    """Lightning datamodule serving the ShapeNetCore train/val/test splits configured in `args`."""

    def _loader_for_split(self, split, batch_size, shuffle):
        """Build the DataLoader for one ShapeNetCore split; settings other than the arguments are shared."""
        split_dataset = ShapeNetCore(
            path=args.dataset_path,
            cates=args.categories,
            split=split,
            scale_mode=args.scale_mode,
        )
        return torch.utils.data.DataLoader(
            dataset=split_dataset,
            shuffle=shuffle,
            batch_size=batch_size,
            num_workers=80,
        )

    def train_dataloader(self):
        """Shuffled loader over the 'train' split with `args.train_batch_size`."""
        return self._loader_for_split('train', args.train_batch_size, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the 'val' split with `args.val_batch_size`."""
        return self._loader_for_split('val', args.val_batch_size, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the 'test' split, one sample per batch."""
        return self._loader_for_split('test', 1, shuffle=False)


shapenet_datamodule = ShapenetDatamodule()

### Train

In [None]:
# Huber loss between rendered images. Future direction for more compute power: Wasserstein loss.
model = ImageReconstructionModel(
    loss_fn=nn.HuberLoss(),
)
# Training resumes from the checkpoint given as `pretrained_model_path`.
trainer = model.train_model(
    shapenet_datamodule,
    gpus=None,
    wandb_run_name="Run2",
    pretrained_model_path="3dDiffusion_ImgMatching/v8xe37f8/checkpoints/epoch=29-step=3240.ckpt",
)

### Test

In [None]:
# Inference setup: make sure the model sits on the evaluation device (the trainer may have left
# it elsewhere - verify for the strategy in use) and switch layers such as batch norm to eval mode.
model = model.to(args.device)
model.eval()

for batch_idx, pointcloud_batch in enumerate(shapenet_datamodule.test_dataloader()):
    # Check the limit before plotting so that exactly `max_test_comparisons` samples are shown
    # (checking after plotting would display one extra sample).
    if batch_idx >= args.max_test_comparisons:
        break

    pointcloud_batch = pointcloud_batch.to(args.device)

    groundtruth_img, reconstructed_img = model.test_step(pointcloud_batch, batch_idx)
    # Drop the batch axis (the test loader uses batch_size=1) and convert to NumPy on the CPU,
    # since matplotlib cannot draw tensors that live on the GPU.
    groundtruth_img = groundtruth_img.squeeze().detach().cpu().numpy()
    reconstructed_img = reconstructed_img.squeeze().detach().cpu().numpy()

    fig, axes = plt.subplots(ncols=2, figsize=(15,15))
    fig.suptitle(f"Sample #{batch_idx}")  # Figure has no `title`; `suptitle` sets the figure-level title
    axes[0].imshow(groundtruth_img)
    axes[0].set_title("Ground truth")
    axes[1].imshow(reconstructed_img)
    axes[1].set_title("Reconstruction")
    plt.show()