<a href="https://colab.research.google.com/github/Venkatakrishnan-Ramesh/PIFu/blob/master/notebook11fd57d5c1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES
# TO THE CORRECT LOCATION (/kaggle/input) IN YOUR NOTEBOOK,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

import os
import sys
from tempfile import NamedTemporaryFile
from urllib.request import urlopen
from urllib.parse import unquote, urlparse
from urllib.error import HTTPError
from zipfile import ZipFile
import tarfile
import shutil

CHUNK_SIZE = 40960
DATA_SOURCE_MAPPING = 'fastnerf:https%3A%2F%2Fstorage.googleapis.com%2Fkaggle-data-sets%2F4314081%2F7415898%2Fbundle%2Farchive.zip%3FX-Goog-Algorithm%3DGOOG4-RSA-SHA256%26X-Goog-Credential%3Dgcp-kaggle-com%2540kaggle-161607.iam.gserviceaccount.com%252F20240127%252Fauto%252Fstorage%252Fgoog4_request%26X-Goog-Date%3D20240127T170810Z%26X-Goog-Expires%3D259200%26X-Goog-SignedHeaders%3Dhost%26X-Goog-Signature%3D8196c46480b1423cf78c228f5f7214ff7cebb77d7d19c8a727965eadc419d4a0e5fca026b94de88b92688dc6d0614fe7aa9be7eb788cb6ee4c6db0749c6d6ebbe4c4b7674fbd7d8e24afefc23b5f3b47419900a2ac06fa65c224fd217041341aa96806fce7b251bd5f06ebd0ca05a8df0ba360262459abbea1b9b8686fa06feb5f377330ba24541cc88a343fd437757bb577a77d825bb1845e794d95245359ac2aabe7bce00b96d748fbebba265378e3a2fafacc2ba0e046b480d6f975fc2d3ed6bc1491adbc4e089744beaeec95613250873763187950f957690bd843e4baa34201dfebf31325f89cacdffcfcf78d72f69c8eaba3aed0fd0ab9f465c3362085,training:https%3A%2F%2Fstorage.googleapis.com%2Fkaggle-data-sets%2F4314100%2F7415942%2Fbundle%2Farchive.zip%3FX-Goog-Algorithm%3DGOOG4-RSA-SHA256%26X-Goog-Credential%3Dgcp-kaggle-com%2540kaggle-161607.iam.gserviceaccount.com%252F20240127%252Fauto%252Fstorage%252Fgoog4_request%26X-Goog-Date%3D20240127T170810Z%26X-Goog-Expires%3D259200%26X-Goog-SignedHeaders%3Dhost%26X-Goog-Signature%3D3e22c53607b613a79024fdfd367725b695deb3dba8129e18429edc6aaff321ea68c5542f0a8aea495ee4e2284b57bc2980528e1446acaf1aae5c5f1ae0ab71039326bd25229be66b5c83978f99f7949095f6655075102c26f70a745b9678dab46c41c6b49efe0c9b9bb3cadfea15e64ff399af05fa1232ae0cdd4b1084feebfd3600024e80aba7cac227336015a609354427ff14d0b94ec9ff76f2a189ae5173e05fd3bcdfa3eed746a0e8d143074ab1b7618385b348d0875166fccbfeab2861a478a45578932a27f51066f726916616c0c69cc410c5731aab80940cccae946d13d2cd257af9f46d96544daf74de1f9c64a70f875a9d137f12339b0014c64b60'

# Directories that this cell (re)creates so the notebook can use Kaggle-style paths.
KAGGLE_INPUT_PATH='/kaggle/input'
KAGGLE_WORKING_PATH='/kaggle/working'
# NOTE(review): KAGGLE_SYMLINK is not referenced anywhere in the visible code.
KAGGLE_SYMLINK='kaggle'

# Notebook shell escape: unmount a previous /kaggle/input mount, discarding any error output.
!umount /kaggle/input/ 2> /dev/null
# Remove any existing input directory (errors ignored), then create both directories.
shutil.rmtree('/kaggle/input', ignore_errors=True)
os.makedirs(KAGGLE_INPUT_PATH, 0o777, exist_ok=True)
os.makedirs(KAGGLE_WORKING_PATH, 0o777, exist_ok=True)

# Create ../input and ../working symlinks pointing at the directories above;
# a link left over from a previous run is tolerated.
try:
  os.symlink(KAGGLE_INPUT_PATH, os.path.join("..", 'input'), target_is_directory=True)
except FileExistsError:
  pass
try:
  os.symlink(KAGGLE_WORKING_PATH, os.path.join("..", 'working'), target_is_directory=True)
except FileExistsError:
  pass

# Download every "<directory>:<percent-encoded URL>" entry of DATA_SOURCE_MAPPING
# into a temporary file and unpack it under KAGGLE_INPUT_PATH/<directory>.
for data_source_mapping in DATA_SOURCE_MAPPING.split(','):
    # Split on the first colon only, so a literal ':' in the URL part cannot
    # break the two-value unpacking.
    directory, download_url_encoded = data_source_mapping.split(':', 1)
    download_url = unquote(download_url_encoded)
    filename = urlparse(download_url).path
    destination_path = os.path.join(KAGGLE_INPUT_PATH, directory)
    try:
        with urlopen(download_url) as fileres, NamedTemporaryFile() as tfile:
            total_length = fileres.headers['content-length']
            # The header may be absent (e.g. chunked transfer); in that case
            # the byte counter is still shown but the bar stays empty.
            total_bytes = int(total_length) if total_length else 0
            print(f'Downloading {directory}, {total_length} bytes compressed')
            dl = 0
            while True:
                data = fileres.read(CHUNK_SIZE)
                if not data:
                    break
                dl += len(data)
                tfile.write(data)
                done = int(50 * dl / total_bytes) if total_bytes else 0
                sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {dl} bytes downloaded")
                sys.stdout.flush()
            # Push buffered bytes to disk and rewind so the archive readers
            # below see the complete download through the same file object.
            tfile.flush()
            tfile.seek(0)
            if filename.endswith('.zip'):
                with ZipFile(tfile) as zfile:
                    zfile.extractall(destination_path)
            else:
                # Bind the archive to its own name: rebinding `tarfile` here
                # would shadow the module for later loop iterations.
                with tarfile.open(fileobj=tfile) as tar_archive:
                    tar_archive.extractall(destination_path)
            print(f'\nDownloaded and uncompressed: {directory}')
    except HTTPError:
        print(f'Failed to load (likely expired) {download_url} to path {destination_path}')
        continue
    except OSError:
        print(f'Failed to load {download_url} to path {destination_path}')
        continue

print('Data source import complete.')


In [None]:
#%%
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
# from torchsummary import summary
import os
# from torch.utils.tensorboard import SummaryWriter
import sys
import gc
import argparse
from itertools import chain
# import atexit

class FastNerf(nn.Module):
    """Factorised radiance field with a position network and a direction network.

    ``Fpos`` maps an encoded 3-D position to ``3 * D + 1`` values (one density
    logit followed by ``3 * D`` colour-component logits). ``Fdir`` maps an
    encoded 3-D direction to ``D`` mixing logits. ``forward`` combines the two
    into an RGB colour and a density per sample.

    Args:
        embedding_dim_pos: Number of frequency bands used to encode positions.
        embedding_dim_direction: Number of frequency bands used to encode directions.
        hidden_dim_pos: Width of the hidden layers of ``Fpos``.
        hidden_dim_dir: Width of the hidden layers of ``Fdir``.
        D: Number of colour components mixed per sample.
    """

    def __init__(self, embedding_dim_pos=10, embedding_dim_direction=4, hidden_dim_pos=384, hidden_dim_dir=128, D=8):
        super(FastNerf, self).__init__()

        # Each encoded input has 3 raw coordinates plus sin/cos of the 3
        # coordinates for every frequency band: 3 + 6 * L features.
        # Layer order and Sequential indices are kept (Fpos.0 ... Fpos.14,
        # Fdir.0 ... Fdir.6) so existing state_dicts keep loading.
        self.Fpos = self._mlp(embedding_dim_pos * 6 + 3, hidden_dim_pos, 3 * D + 1, n_hidden=7)
        self.Fdir = self._mlp(embedding_dim_direction * 6 + 3, hidden_dim_dir, D, n_hidden=3)

        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        self.D = D

    @staticmethod
    def _mlp(in_features, hidden_features, out_features, n_hidden):
        """Build ``n_hidden`` Linear+ReLU layers followed by a linear output layer."""
        layers = [nn.Linear(in_features, hidden_features), nn.ReLU()]
        for _ in range(n_hidden - 1):
            layers += [nn.Linear(hidden_features, hidden_features), nn.ReLU()]
        layers.append(nn.Linear(hidden_features, out_features))
        return nn.Sequential(*layers)

    @staticmethod
    def positional_encoding(x, L):
        """Return ``x`` concatenated with ``sin(2**j * x)`` and ``cos(2**j * x)`` for ``j < L``.

        The pieces are joined along the last axis, so an input whose last
        dimension is ``C`` yields ``C * (2 * L + 1)`` features.
        """
        out = [x]
        for j in range(L):
            out.append(torch.sin(2 ** j * x))
            out.append(torch.cos(2 ** j * x))
        # The last axis equals axis 1 for the 2-D batches used by forward().
        return torch.cat(out, dim=-1)

    def forward(self, o, d):
        """Evaluate colour and density for a batch of samples.

        Args:
            o: Sample positions, shape ``[batch_size, 3]``.
            d: Viewing directions, shape ``[batch_size, 3]``.

        Returns:
            Tuple ``(color, sigma)`` with shapes ``[batch_size, 3]`` and
            ``[batch_size, 1]``.
        """
        sigma_uvw = self.Fpos(self.positional_encoding(o, self.embedding_dim_pos))

        # Column 0 is the density logit; softplus keeps the density non-negative.
        sigma = torch.nn.functional.softplus(sigma_uvw[:, 0][..., None])  # [batch_size, 1]
        uvw = torch.sigmoid(sigma_uvw[:, 1:].reshape(-1, 3, self.D))  # [batch_size, 3, D]

        # Direction-dependent weights that sum to one over the D components.
        beta = torch.softmax(self.Fdir(self.positional_encoding(d, self.embedding_dim_direction)), -1)
        color = (beta.unsqueeze(1) * uvw).sum(-1)  # [batch_size, 3]

        return color, sigma

def softmax_(x, dim):
    """Apply a softmax along ``dim`` to ``x`` in place and return ``x``.

    The per-slice maximum is subtracted before exponentiating for numerical
    stability. The input tensor itself is overwritten and returned, so the
    result can be used directly in an expression.
    """
    x_max = x.max(dim=dim, keepdim=True).values
    x.sub_(x_max).exp_()
    # x already holds the exponentials here, so this sum is the normaliser.
    x.div_(x.sum(dim=dim, keepdim=True))
    return x

class Cache(nn.Module):
    """Lookup-table replacement for a trained model's two networks.

    At construction time ``model.Fpos`` is evaluated on an ``Np**3`` grid of
    positions and ``model.Fdir`` on an ``Nd**2`` grid of directions, both
    spanning ``[-scale / 2, scale / 2]`` per axis. ``forward`` then answers
    queries by indexing the stored tables instead of running the networks.

    Args:
        model: Object exposing ``Fpos``, ``Fdir``, ``positional_encoding``,
            ``embedding_dim_pos``, ``embedding_dim_direction`` and ``D``.
        scale: Side length of the cached cube, centred on the origin.
        device: Device on which the grids are built and evaluated.
        Np: Grid resolution per position axis.
        Nd: Grid resolution per direction axis.
    """

    def __init__(self, model, scale, device, Np, Nd):
        super(Cache, self).__init__()

        with torch.no_grad():
            # Position
            pos_axis = torch.linspace(-scale / 2, scale / 2, Np).to(device)
            x, y, z = torch.meshgrid([pos_axis, pos_axis, pos_axis])
            xyz = torch.cat((x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)), dim=1)
            sigma_uvw = model.Fpos(model.positional_encoding(xyz, model.embedding_dim_pos))
            self.sigma_uvw = sigma_uvw.reshape((Np, Np, Np, -1))
            # Direction: the grid covers the first two components; the third is
            # reconstructed as the non-negative root of 1 - x^2 - y^2.
            dir_axis = torch.linspace(-scale / 2, scale / 2, Nd).to(device)
            xd, yd = torch.meshgrid([dir_axis, dir_axis])
            xyz_d = torch.cat((xd.reshape(-1, 1), yd.reshape(-1, 1),
                               torch.sqrt((1 - xd ** 2 - yd ** 2).clip(0, 1)).reshape(-1, 1)), dim=1)
            beta = model.Fdir(model.positional_encoding(xyz_d, model.embedding_dim_direction))
            self.beta = beta.reshape((Nd, Nd, -1))
            print ("Beta in cache " ,self.beta.shape)

        self.scale = scale
        self.Np = Np
        self.Nd = Nd
        self.D = model.D

    def forward(self, x, d):
        """Look up colour and density for sample positions ``x`` and directions ``d``.

        Args:
            x: Sample positions, shape ``[batch_size, 3]``.
            d: Viewing directions, shape ``[batch_size, 3]``; only the first
                two components are used for the lookup.

        Returns:
            Tuple ``(color, sigma)`` with shapes ``[batch_size, 3]`` and
            ``[batch_size, 1]``. Samples outside the cached cube keep zero
            colour and zero density.
        """
        color = torch.zeros_like(x)
        sigma = torch.zeros((x.shape[0], 1), device=x.device)

        # Only samples strictly inside the cached cube are looked up.
        mask = (x.abs() < (self.scale / 2)).all(dim=1)

        # Position: map coordinates to grid cells of the position table.
        idx = (x[mask] / (self.scale / self.Np) + self.Np / 2).long().clip(0, self.Np - 1)
        sigma_uvw = self.sigma_uvw[idx[:, 0], idx[:, 1], idx[:, 2]]

        # Direction: the direction table has its own resolution (Nd) and is
        # indexed by the direction's first two components, never by the
        # position indices (which go out of range whenever Np > Nd).
        idx_d = (d[mask][:, :2] / (self.scale / self.Nd) + self.Nd / 2).long().clip(0, self.Nd - 1)
        # The table stores raw logits, so normalise them to mixing weights.
        beta = torch.softmax(self.beta[idx_d[:, 0], idx_d[:, 1]], -1)

        sigma[mask] = torch.nn.functional.softplus(sigma_uvw[:, 0][..., None])
        uvw = torch.sigmoid(sigma_uvw[:, 1:].reshape(-1, 3, self.D))  # [n_inside, 3, D]
        color[mask] = (uvw * beta.unsqueeze(1)).sum(-1)
        return color, sigma

def compute_accumulated_transmittance(alphas):
    """Return the exclusive running product of ``alphas`` along dimension 1.

    Column ``k`` of the result is the product of columns ``0 .. k-1`` of the
    input; column 0 is therefore filled with ones.
    """
    running_product = torch.cumprod(alphas, 1)
    n_rows = running_product.shape[0]
    first_column = torch.ones((n_rows, 1), device=alphas.device)
    shifted = running_product[:, :-1]
    return torch.cat((first_column, shifted), dim=-1)

def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    """Volume-render a batch of rays and composite them over a white background.

    Args:
        nerf_model: Callable taking flattened sample positions and directions
            (each ``[batch_size * nb_bins, 3]``) and returning ``(colors, sigma)``.
        ray_origins: Ray origins, shape ``[batch_size, 3]``.
        ray_directions: Ray directions, shape ``[batch_size, 3]``.
        hn: Near bound of the sampling interval along each ray.
        hf: Far bound of the sampling interval along each ray.
        nb_bins: Number of samples per ray.

    Returns:
        Pixel values of shape ``[batch_size, 3]``.
    """
    device = ray_origins.device
    n_rays = ray_origins.shape[0]
    t = torch.linspace(hn, hf, nb_bins, device=device).expand(n_rays, nb_bins)
    # Perturb sampling along each ray: draw one uniform sample per bin.
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    u = torch.rand(t.shape, device=device)
    t = lower + (upper - lower) * u  # [batch_size, nb_bins]
    # Distance between consecutive samples; the last sample gets a very large
    # extent so whatever transmittance is left is absorbed there.
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e10], device=device).expand(n_rays, 1)), -1)

    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)  # [batch_size, nb_bins, 3]
    # Repeat each ray direction once per sample along that ray.
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)

    colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
    colors = colors.reshape(x.shape)
    sigma = sigma.reshape(x.shape[:-1])

    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
    c = (weights * colors).sum(dim=1)  # Pixel values
    weight_sum = weights.sum(-1).sum(-1)  # Regularization for white background
    return c + 1 - weight_sum.unsqueeze(-1)


@torch.no_grad()
def test(model, hn, hf, dataset, img_index=0, nb_bins=192, H=400, W=400, device=None):
    """Render one image from ``dataset`` and save it under ``novel_views/``.

    Args:
        model: Model passed to ``render_rays``.
        hn: Near bound of the sampling interval.
        hf: Far bound of the sampling interval.
        dataset: Tensor whose rows hold a ray origin in columns 0-2 and a ray
            direction in columns 3-5; image ``i`` occupies rows
            ``i * H * W`` to ``(i + 1) * H * W``.
        img_index: Index of the image to render.
        nb_bins: Number of samples per ray.
        H: Image height in pixels.
        W: Image width in pixels.
        device: Device to render on; defaults to CUDA when available, else CPU.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    first_row = img_index * H * W
    last_row = (img_index + 1) * H * W
    ray_origins = dataset[first_row:last_row, :3]
    ray_directions = dataset[first_row:last_row, 3:6]

    regenerated_px_values = render_rays(model, ray_origins.to(device), ray_directions.to(device), hn=hn, hf=hf,
                                        nb_bins=nb_bins)

    # savefig does not create missing directories.
    os.makedirs('novel_views', exist_ok=True)

    fig = plt.figure()
    # set_size_inches takes (width, height); with dpi=1 below one inch is one pixel.
    fig.set_size_inches(W, H)
    plt.imshow(regenerated_px_values.data.cpu().numpy().reshape(H, W, 3).clip(0, 1))
    plt.axis('off')
    plt.savefig(f'novel_views/img_test_{img_index}.png', dpi=1)
    print(f'Render successful {img_index}')
    # plt.savefig(f'novel_views/img_{img_index}.png', bbox_inches='tight')
    plt.close()

    # generated_px_values = dataset[img_index * H * W: (img_index + 1) * H * W, 6:]
    # fig=plt.figure()
    # fig.set_size_inches(H, W)
    # plt.imshow(generated_px_values.data.cpu().numpy().reshape(H, W, 3).clip(0, 1))
    # plt.axis('off')
    # plt.savefig(f'novel_views/img_generated_{img_index}.png', dpi=1)
    # print('Testing successful_{img_index}')

def cleanup():
    """
    Release cached GPU memory; does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    # Empty the CUDA cache, then collect CUDA IPC memory.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def save_checkpoint(state, epoch, filename='nerf_checkpoint_epoch_{}.pth'):
    """Serialize ``state`` to the path obtained by formatting ``filename`` with ``epoch``."""
    target_path = filename.format(epoch)
    torch.save(state, target_path)

def load_checkpoint(checkpoint_path, model, optimizer, scheduler, device):
    """Restore training state from ``checkpoint_path`` if such a file exists.

    Args:
        checkpoint_path: Path of the checkpoint file, or ``None`` to start fresh.
        model: Object whose ``load_state_dict`` receives ``checkpoint['state_dict']``.
        optimizer: Object whose ``load_state_dict`` receives ``checkpoint['optimizer']``.
        scheduler: Object whose ``load_state_dict`` receives ``checkpoint['scheduler']``.
        device: ``map_location`` used when deserializing the checkpoint.

    Returns:
        The epoch index at which training should resume; 0 when no checkpoint
        was loaded.
    """
    start_epoch = 0  # Default start epoch

    # os.path.isfile(None) raises TypeError, so treat "no path" as "no checkpoint".
    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        print(f"Loading checkpoint '{checkpoint_path}'")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

        # train() stores `epoch + 1`, i.e. the number of completed epochs, which
        # is already the index of the next epoch to run. Adding one more here
        # would silently skip an epoch on every resume.
        start_epoch = checkpoint['epoch']
        print(f"Loaded checkpoint '{checkpoint_path}' (epoch {start_epoch})")
    else:
        print(f"No checkpoint found at '{checkpoint_path}', starting from epoch 0")

    return start_epoch

def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, nb_epochs=int(1e5), nb_bins=192, checkpoint_path=None):
    """Train ``nerf_model`` on ray batches and checkpoint after every epoch.

    Args:
        nerf_model: Model passed to ``render_rays``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        data_loader: Iterable of batches whose columns 0-2 are ray origins,
            3-5 ray directions and 6+ the ground-truth pixel values.
        device: Device on which the model and batches are placed.
        hn: Near bound of the sampling interval.
        hf: Far bound of the sampling interval.
        nb_epochs: Total number of epochs to reach.
        nb_bins: Number of samples per ray.
        checkpoint_path: Optional checkpoint to resume from; ``None`` starts
            from epoch 0.

    Returns:
        List with the average loss of every epoch run by this call.
    """
    training_loss = []

    # Move the model first so that optimizer state restored from a checkpoint
    # ends up on the same device as the parameters it belongs to.
    nerf_model.to(device)

    # Only consult the checkpoint loader when a path was actually supplied.
    start_epoch = 0
    if checkpoint_path is not None:
        start_epoch = load_checkpoint(checkpoint_path, nerf_model, optimizer, scheduler, device)

    for epoch in range(start_epoch, nb_epochs):
        epoch_loss = 0.0
        for batch in tqdm(data_loader, desc=f'Epoch {epoch}/{nb_epochs}', leave=False):
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            ground_truth_px_values = batch[:, 6:].to(device)
            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        scheduler.step()

        # Calculate and record the average loss for the epoch; an empty loader
        # yields 0.0 instead of a ZeroDivisionError.
        avg_epoch_loss = epoch_loss / max(len(data_loader), 1)
        training_loss.append(avg_epoch_loss)

        # Save checkpoint at the end of each epoch. 'epoch' holds the number of
        # completed epochs; the file name uses the zero-based epoch index.
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': nerf_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
        }, epoch)

        print(f'Epoch {epoch+1} completed, Avg Loss: {avg_epoch_loss:.4f}')

    return training_loss

# Usage example
# NOTE(review): checkpoint_dir is not referenced by the visible code;
# save_checkpoint() writes to the path built from its `filename` template,
# which has no directory component by default. Confirm before relying on it.
checkpoint_dir = 'checkpoints'  # Ensure this directory exists


def pattern(dataset_size):
    """Return a NumPy array of indices selected from ``range(dataset_size)``.

    Two windowing passes are applied in sequence:

    1. from the full index sequence, take 32-element windows starting at
       offset 184 and repeating every 400 elements;
    2. from that result, take 1024-element windows starting at offset 5888
       and repeating every 12800 elements.

    The work is done on ``cuda:0`` when CUDA is available, otherwise on CPU.
    """
    target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def gather_windows(values, first, stride, width):
        # Concatenate values[s:s + width] for s = first, first + stride, ...
        starts = torch.arange(first, values.numel(), stride, device=target)
        return torch.cat([values[s:s + width] for s in starts])

    window = 32
    stride = 400
    offset = stride // 2 - window // 2

    # First pass over the raw index sequence.
    all_indices = torch.arange(dataset_size, device=target)
    first_pass = gather_windows(all_indices, offset, stride, window)

    # Second pass over the output of the first one.
    block = window * window
    block_offset = offset * window
    block_stride = block_offset * 2 + block
    second_pass = gather_windows(first_pass, block_offset, block_stride, block)

    # Hand the result back as a host-side NumPy array.
    return second_pass.cpu().numpy()

import time
def main():
    """Train a FastNerf model on the Kaggle training rays and save it to ``model.pth``."""
    # Resolve the device locally so main() does not depend on a module-level
    # global being defined by the caller.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = FastNerf().to(device)
    # Use at most two GPUs, but never ask for a GPU that does not exist:
    # a fixed device_ids=[0, 1] fails on single-GPU machines. The wrapper is
    # kept in every case so parameter names stay prefixed with "module.".
    gpu_ids = list(range(min(2, torch.cuda.device_count())))
    model = torch.nn.DataParallel(model, device_ids=gpu_ids or None)

    # NOTE: allow_pickle=True unpickles the file, which can execute arbitrary
    # code; only load training data from a trusted source.
    training_dataset = torch.from_numpy(np.load('/kaggle/input/training/training_data.pkl', allow_pickle=True))
    # Optional subsampling (disabled):
    # training_dataset = training_dataset[pattern(training_dataset.shape[0]), :]

    model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)
    data_loader = DataLoader(training_dataset, batch_size=8192, shuffle=True)

    time1 = time.time()
    train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=device, hn=2, hf=6)
    time2 = time.time()
    print(time2-time1 , "time taken for training")

    # Save the model (the whole wrapped module object, not just a state_dict).
    model_path = 'model.pth'
    torch.save(model, model_path)
    model_size = os.path.getsize(model_path)
    print(f"Size of the saved model on disk: {model_size / (1024*1024)} MB")

#     elif args.mode == 'test':
#         model = torch.load('model.pth')
#         # model.eval()
#         cache = Cache(model, 2.2, 'cuda', 64, 64)
#         print("cache")

#         for idx in range(200):
#             test(cache, 2., 6., testing_dataset, img_index=idx, nb_bins=12, H=400, W=400)
#             torch.cuda.empty_cache()

import time
if __name__ == '__main__':
    # Time the whole run, including device selection.
    start = time.time()
    # `device` is a module-level name that other functions in this file read.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    main()
    end = time.time()
    elapsed = end - start
    print(elapsed, "Time")


# %%



In [None]:
import zipfile
import os
from IPython.display import FileLink

def zip_dir(directory = os.curdir, file_name = 'directory.zip'):
    """
    Zip all the files in a directory.

    The process working directory is changed to ``directory`` and the archive
    is written there.

    Parameters
    _____
    directory: str
        directory that needs to be zipped, default is the current working directory

    file_name: str
        the name of the zipped file (including .zip), default is 'directory.zip'

    Returns
    _____
    A FileLink for ``file_name``, which can be used to download the zip file.
    """
    os.chdir(directory)
    archive_path = os.path.abspath(file_name)
    # The context manager closes the archive, which is what writes the zip's
    # central directory; without it the file may be unreadable.
    with zipfile.ZipFile(file_name, mode='w') as zip_ref:
        # After the chdir above the target directory *is* the current one, so
        # walk '.' rather than re-resolving a possibly relative `directory`.
        for folder, _, files in os.walk(os.curdir):
            for file in files:
                path = os.path.join(folder, file)
                # Skip exactly the archive being written, not every file whose
                # name merely contains `file_name`.
                if os.path.abspath(path) == archive_path:
                    continue
                zip_ref.write(path)

    return FileLink(file_name)

In [None]:
# Zip the current working directory into 'directory.zip' (the defaults of
# zip_dir); the call returns a FileLink for the archive.
zip_dir()

In [None]:
!nvidia-smi