In [None]:
from TrainLoader import *
import torchvision.transforms as transforms
import itertools
import logging
import itertools
import cv2
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Root folders for the Lensless3D project.
data_dir = '/media/data/salman/Lensless3D/data/'
file_dir = '/media/data/salman/Lensless3D/files/'
# FlyingThings3D subset, located under the raw-data folder.
dataset_dir = f'{data_dir}raw_data/FlyingThings3D_subset/'
# PSF loading via mat73 is currently disabled:
# data_dict_psf = mat73.loadmat('psfs_save_magfs.mat')
# psfs = data_dict_psf['psfs'][:,:,:,-25:][::2,::2]

def _add_image_with_colorbar(fig, ax, image, title, cmap):
    """Draw `image` on `ax` with `title` and a vertical colorbar to its right."""
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.2)

    im = ax.imshow(image, cmap=cmap)
    ax.set_title(title)

    fig.colorbar(im, cax=cax, orientation='vertical')


def show_figure(image1, title1, mode="single", image2=None, title2=None, save=False, img_name=None, cmap='gray'):
    """Display one image, or two images side by side, with matplotlib.

    Parameters
    ----------
    image1 : array-like
        First (or only) image, passed to ``imshow``.
    title1 : str
        Title for ``image1``. Ignored in ``'single'`` mode, which draws no title.
    mode : {'single', 'single-colorbar', 'comparison'}
        ``'single'``: one image with the axes hidden.
        ``'single-colorbar'``: one titled image with a vertical colorbar.
        ``'comparison'``: ``image1`` and ``image2`` side by side, each with its
        own title and colorbar.
    image2, title2 : optional
        Second image and its title; only used in ``'comparison'`` mode.
    save : bool
        If True, save the figure to ``img_name``.
    img_name : str or path-like, optional
        Output path; required when ``save`` is True.
    cmap : str
        Colormap forwarded to ``imshow``.

    Raises
    ------
    ValueError
        If ``mode`` is not recognised, or ``save`` is True without ``img_name``.
    """
    # Validate before drawing anything. An unknown mode used to fall through
    # silently, and then failed with NameError on `fig` when save=True.
    valid_modes = ('single', 'single-colorbar', 'comparison')
    if mode not in valid_modes:
        raise ValueError(f"mode must be one of {valid_modes}, got {mode!r}")
    if save and img_name is None:
        raise ValueError("img_name is required when save=True")

    if mode == 'single':
        fig = plt.figure()
        plt.axis('off')
        plt.imshow(image1, cmap=cmap)

    elif mode == 'single-colorbar':
        fig, ax = plt.subplots()
        _add_image_with_colorbar(fig, ax, image1, title1, cmap)

    else:  # 'comparison'
        fig, (ax1, ax2) = plt.subplots(1, 2)
        _add_image_with_colorbar(fig, ax1, image1, title1, cmap)
        _add_image_with_colorbar(fig, ax2, image2, title2, cmap)
        fig.tight_layout(pad=1.0)
        fig.show()

    if save:
        fig.savefig(img_name)

    
from struct import *
def load_pfm(file_path):
    """
    Load an image stored in PFM (Portable Float Map) format.

    Args:
        file_path (str or path-like): path to the ``.pfm`` file.
    Returns:
        numpy.ndarray: float32 data in (Height, Width) layout for ``Pf``
        (grayscale) files, or (Height, Width, 3) for ``PF`` (color) files,
        flipped vertically so the last stored row comes first. The array is
        C-contiguous and writable, so it can go straight into ``torch.from_numpy``.
    Raises:
        Exception: if the header is not ``PF``/``Pf`` or the dimension line is
            malformed.
        ValueError: if the file holds fewer samples than the header declares.

    Note: the header's scale value is only used to choose the byte order; it
    is not returned.
    """
    # Binary mode: the payload is raw float32 data. Mixing a text-mode reader
    # (which buffers ahead of the header) with a binary read is fragile.
    with open(file_path, 'rb') as fp:
        # Header: 'PF' -> 3 channels (color), 'Pf' -> 1 channel (grayscale).
        header = fp.readline().decode('ISO-8859-1').rstrip()
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        # Dimensions: "<width> <height>", tolerant of extra whitespace.
        dim_line = fp.readline().decode('ISO-8859-1')
        dim_match = re.match(r'^\s*(\d+)\s+(\d+)\s*$', dim_line)
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        # Scale: a negative value marks little-endian data, positive big-endian.
        scale = float(fp.readline().decode('ISO-8859-1').rstrip())
        endian = '<' if scale < 0 else '>'

        payload = fp.read()

    shape = (height, width, 3) if color else (height, width)
    count = height * width * (3 if color else 1)
    # count= ignores any trailing bytes and raises ValueError on a truncated file.
    data = np.frombuffer(payload, dtype=endian + 'f', count=count).reshape(shape)

    # flipud returns a negative-stride view, which torch.from_numpy rejects.
    # ascontiguousarray copies it into a normal (and writable) array.
    return np.ascontiguousarray(np.flipud(data))

class VGG(nn.Module):
    """VGG/Perceptual Loss

    Compares two images by the mean-squared error between their activations
    in a truncated VGG19 ``features`` stack.

    Parameters
    ----------
    conv_index : str
        Convolutional layer in VGG model to use as perceptual output:
        ``'22'`` keeps the first 8 modules of ``vgg19().features``,
        ``'54'`` keeps the first 35.
    pretrained : bool
        Forwarded to ``torchvision.models.vgg19``. The default (False) builds
        VGG19 with randomly initialised weights. NOTE(review): a perceptual
        loss normally uses ImageNet-pretrained weights; confirm this is intended.

    Raises
    ------
    ValueError
        If ``conv_index`` is not ``'22'`` or ``'54'``.
    """
    # Number of leading modules of vgg19().features kept for each conv_index.
    _SLICE_END = {'22': 8, '54': 35}

    def __init__(self, conv_index: str = '22', pretrained: bool = False):

        super(VGG, self).__init__()
        # Previously an unknown index left self.vgg undefined and failed later
        # with AttributeError inside forward().
        if conv_index not in self._SLICE_END:
            raise ValueError(f"conv_index must be '22' or '54', got {conv_index!r}")

        vgg_features = torchvision.models.vgg19(pretrained=pretrained).features
        modules = list(vgg_features)
        self.vgg = nn.Sequential(*modules[:self._SLICE_END[conv_index]])

        # Freeze the feature extractor. Assigning `self.vgg.requires_grad = False`
        # only sets a plain attribute on the Module and leaves every parameter
        # trainable; requires_grad_ propagates to all parameters.
        self.vgg.requires_grad_(False)
        # NOTE(review): inputs are not ImageNet mean/std-normalised before the
        # VGG stack; add that shift if pretrained weights are used.

    def _device(self) -> torch.device:
        """Device holding the VGG weights (CPU if the stack has no parameters)."""
        param = next(self.vgg.parameters(), None)
        return param.device if param is not None else torch.device('cpu')

    def forward(self, sr: torch.Tensor, hr: torch.Tensor) -> torch.Tensor:
        """Compute VGG/Perceptual loss between Super-Resolved and High-Resolution

        Parameters
        ----------
        sr : torch.Tensor
            Super-Resolved model output tensor; gradients flow back through it.
        hr : torch.Tensor
            High-Resolution image tensor; treated as a constant target.

        Returns
        -------
        loss : torch.Tensor
            Perceptual VGG loss between sr and hr, on the device of the VGG weights

        """
        # Move the inputs to wherever the VGG weights live. The old code forced
        # `.cpu()`, which kept the loss off the GPU and broke `VGG().to(device)`.
        vgg_device = self._device()

        vgg_sr = self.vgg(sr.to(vgg_device))
        with torch.no_grad():
            vgg_hr = self.vgg(hr.detach().to(vgg_device))

        return F.mse_loss(vgg_sr, vgg_hr)
    
## AIF image metrics: `gt` and `recon` must be scaled to [0, 1], with shape (128, 128, 3).

# Module-level LPIPS scorer (AlexNet backbone), built once and reused by LPIPSval.
loss_fn_alex = lpips.LPIPS(
    net='alex',
)

def PSNR(gt, recon, max_pixel=1.0):
    """Peak signal-to-noise ratio, in dB, between ``gt`` and ``recon``.

    Args:
        gt, recon: arrays of matching (broadcastable) shape.
        max_pixel: peak signal value; 1.0 for images scaled to [0, 1].
    Returns:
        PSNR in dB, or 100 when the inputs are identical (MSE == 0).
    """
    # Cast before subtracting. Unsigned-integer inputs (e.g. uint8) would
    # otherwise wrap around and produce a huge bogus error.
    diff = np.asarray(gt, dtype=np.float64) - np.asarray(recon, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * np.log10(max_pixel / np.sqrt(mse))

def SSIM(gt, recon, data_range=1.0):
    """Structural similarity between two channels-last images.

    Args:
        gt, recon: images of identical shape with the channel axis last.
        data_range: dynamic range of the inputs (max - min); 1.0 for images
            scaled to [0, 1].
    Returns:
        Mean SSIM as returned by ``ssim``.
    """
    # Pass data_range explicitly. Assuming `ssim` is scikit-image's
    # structural_similarity: for float inputs it otherwise derives the range
    # from the dtype (2.0 in older releases; newer releases refuse float input
    # without data_range) instead of the [0, 1] scaling used here.
    # NOTE(review): `multichannel` is deprecated in scikit-image >= 0.19 in
    # favour of `channel_axis=-1`; switch once the installed version allows.
    return ssim(gt, recon, data_range=data_range, multichannel=True)

def LPIPSval(gt, recon, loss_fn=None):
    """LPIPS distance between two (H, W, 3) numpy images scaled to [0, 1].

    Args:
        gt, recon: numpy arrays of shape (H, W, 3) with values in [0, 1].
        loss_fn: LPIPS module to call; defaults to the module-level
            ``loss_fn_alex``.
    Returns:
        float: the LPIPS score as a Python number.
    """
    if loss_fn is None:
        loss_fn = loss_fn_alex

    def _to_batch(img):
        # (H, W, C) numpy -> (1, C, H, W) float32 tensor, for any H and W
        # (previously hard-coded to 128 x 128).
        chw = np.ascontiguousarray(img.transpose(2, 0, 1))
        return torch.from_numpy(chw[None]).type(torch.float32)

    # lpips.LPIPS expects inputs in [-1, 1]; normalize=True rescales the
    # [0, 1] images used here. Without it the score is computed on
    # wrongly-ranged input.
    return loss_fn(_to_batch(gt), _to_batch(recon), normalize=True).item()
    
## Depth metrics

def RMSE(gt, recon):
    """Root-mean-square error between ``gt`` and ``recon``.

    Inputs are cast to float64 first so that unsigned-integer arrays
    (e.g. uint8) do not wrap around on subtraction.
    """
    diff = np.asarray(gt, dtype=np.float64) - np.asarray(recon, dtype=np.float64)
    return np.sqrt(np.mean(diff ** 2))

In [None]:
batch_size = 13

# NOTE(review): the training dataset is built from the FlyingThings3D `val/`
# folder; confirm that training on the validation split is intended.
train_dataset = TrainDataset(
    '/home/sushanth/psf_captures/',
    '/mnt/data/salman/LenslessDesign/datasets/FlyingThings3D/FlyingThings3D_subset/val/',
)

train_dl = DataLoader(
    train_dataset,
    batch_size,
    shuffle=True,
)

In [None]:
from unet2 import UNet
device = 'cuda:1'
model = UNet(in_channels=12*3,
              out_channels=4,
              in_layer='filter',
              device = device,
              batch_size = 2,
              n_blocks=5,
              start_filts = 64,
              attention = True,
              activation=nn.ELU(),
              normalization='batch',
              conv_mode='same',
              out_layer='linear',
              dim=2).to(device)

# criterion1: VGG perceptual loss, applied in training to output channels 0-2.
# criterion2: L1 loss, applied in training to output channel 3.
criterion1 = VGG()
criterion2 = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters())
n_epochs = 50

# Resume from a saved checkpoint. map_location puts every restored tensor on
# the training device. Without it, torch.load restores tensors onto the device
# they were saved from, which fails if that GPU does not exist here.
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load('/home/sushanth/test_codes/new_chkpoints_50.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [None]:

# Send batches to the device the model already lives on, instead of repeating
# the device string from the setup cell (the two could drift apart).
device = next(model.parameters()).device
batch_limit = 200  # maximum number of batches per epoch

for epoch in range(1, 150):
    # NOTE(review): this restarts at epoch 1 and overwrites the `epoch` restored
    # from the checkpoint; use range(epoch + 1, ...) if the intent is to resume.
    epoch_loss = 0.0
    n_batches = 0
    model.train()

    limited_batches = itertools.islice(train_dl, batch_limit)

    for meas, im, depth in tqdm(limited_batches, desc='Training Batch', leave=False):
        # Resize to 180x180, then zero-pad 6 px on each side -> 192x192 spatial size.
        im = F.pad(F.interpolate(im, [180, 180], mode='bilinear', align_corners=True), (6, 6, 6, 6))
        depth = F.pad(F.interpolate(depth.unsqueeze(1), [180, 180], mode='nearest'), (6, 6, 6, 6)).squeeze(1)

        optimizer.zero_grad()

        # meas appears to be channels-last; permute to (B, C, H, W) for the UNet.
        yhat = model(meas.to(device).permute(0, 3, 1, 2).type(torch.float32))

        # Channels 0-2 are matched to `im` with the perceptual loss, channel 3 to
        # `depth` with L1. The perceptual term is weighted by 100.
        x = criterion1(yhat[:, :3, ...], im.to(device).type(torch.float32))
        y = criterion2(yhat[:, 3:, ...].squeeze(1), depth.to(device).type(torch.float32))
        loss = 100 * x + y
        epoch_loss += loss.item()
        n_batches += 1
        loss.backward()
        optimizer.step()

    # The old `epoch_loss / (i + 1)` raised NameError (or reused a stale `i`)
    # when the loader yielded no batches.
    avg_loss = epoch_loss / max(n_batches, 1)
    tqdm.write(f'Epoch {epoch}: average training loss {avg_loss:.4f}')

    
