In [None]:
from pytorch_lightning.loggers import CometLogger

import os

import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
import pytorch_lightning as pl
from kornia.utils import create_meshgrid
from torch.utils.data import DataLoader, TensorDataset
from einops import rearrange
from torchvision.io import read_image
from torchvision import transforms
import torchvision
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
from collections import OrderedDict
from kornia.augmentation import RandomAffine, CenterCrop
import cv2 as cv
import utils
from utils import show
from models import Homography, NeuralRenderer, SineLayer, Siren

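# Model

BARF-style model: a SIREN coordinate MLP represents the reference image, and one learnable homography per input view is optimised jointly with it.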

In [None]:
def get_coords_dataloader(img, batch_size=None):
    """Build a DataLoader over (homogeneous pixel coordinate, pixel value) pairs.

    Args:
        img: image tensor of shape (C, H, W).
        batch_size: samples per batch. Defaults to H*W, i.e. the whole image
            in a single batch (the previous hard-coded behaviour).

    Returns:
        DataLoader yielding ``(x, y)``. ``x`` has shape (B, 3) and holds
        (u, v, 1) coordinates from ``kornia.utils.create_meshgrid``.
        ``y`` has shape (B, C) and holds the matching pixel values, both in
        row-major (H, W) order.
    """
    _, H, W = img.size()
    pixels = rearrange(img, 'c H W -> (H W) c')

    # create_meshgrid returns (1, H, W, 2). Take [0] rather than squeeze() so
    # that images with H == 1 or W == 1 keep their spatial dimensions.
    xy = create_meshgrid(H, W)[0]
    x_hom = torch.cat([xy, torch.ones_like(xy[..., :1])], dim=2)
    x_hom = rearrange(x_hom, 'H W C -> (H W) C')

    dataset = TensorDataset(x_hom, pixels)
    return DataLoader(dataset, batch_size=batch_size or H * W)
    

class BARF_PL(pl.LightningModule):
    """Jointly fit a SIREN image model and one homography per input image.

    A coordinate MLP represents an image. For each ``imgs[i]`` a learnable
    homography ``T_i`` warps the pixel grid, and the MLP evaluated at the
    warped coordinates is compared (L1) with ``imgs[i]``. Optionally, the
    coordinates go through a sinusoidal positional encoding with BARF-style
    coarse-to-fine band masking.

    Args:
        imgs: list of image tensors. ``imgs[0]`` defines the grid used by the
            dataloaders, ``render_image`` and the video. Assumed (C, H, W) —
            TODO confirm for all entries.
        img_mask: mask tensor, stored as ``self.mask``. Not used by the loss.
        img_GT: ground-truth image. Its size defines the grid in
            ``training_step``, and it is used for patch visualisation.
        H_GT: ground-truth homography modules (each exposing ``.H``). Used only
            for logging and visualisation.
        pos_enc: apply positional encoding to the 2-D coordinates.
        L: number of frequency bands of the positional encoding.
        barf_c2f: enable coarse-to-fine masking of the frequency bands.
        video_file: path of the MJPG video written during training.
    """

    def __init__(self, imgs, img_mask, img_GT, H_GT, pos_enc=False, L=10, barf_c2f = True, video_file = "output.avi"):
        super().__init__()

        # Positional encoding: each of the 2 coordinates becomes (sin, cos) x L
        # bands, so the MLP sees 2*2*L features.
        # NOTE(review): unlike the BARF reference implementation, the raw
        # coordinates are not concatenated to the encoding. With barf_c2f the
        # MLP input is therefore all zeros at epoch 0 (every band weight is 0).
        self.L = L
        self.pos_enc = pos_enc
        self.barf_c2f = barf_c2f
        input_size = 2 * 2 * L if pos_enc else 2

        # The mask is a non-persistent buffer. It follows the module across
        # devices (.cuda()/.to()) without a hard GPU dependency, and adds no
        # key to the state_dict.
        self.register_buffer("mask", img_mask, persistent=False)
        self.img = imgs[0]

        # MLP image model (SIREN): coordinates -> 3 output channels.
        self.mlp = Siren(in_features=input_size, out_features=3, hidden_features=256,
                         hidden_layers=3, outermost_linear=True)

        # Photometric loss (optimised) and homography error (logged only).
        self.loss = nn.L1Loss()
        self.H_loss = nn.MSELoss()

        # Input images and one learnable homography per image.
        self.imgs = imgs
        self.homographies = nn.ModuleList([Homography() for _ in imgs])

        # MJPG video at 3 fps, frame size taken from the reference image.
        _, H, W = self.img.size()
        self.out = cv.VideoWriter(video_file, cv.VideoWriter_fourcc('M', 'J', 'P', 'G'), 3, (W, H))

        # Ground truth, used only for logging / visualisation.
        self.homographies_GT = H_GT
        self.img_GT = img_GT

    def forward(self, x):
        """Evaluate the image MLP at 2-D coordinates ``x`` of shape (N, 2)."""
        if self.pos_enc:
            x = self.positional_encoding(x)
        # Siren returns a tuple; element 0 is the network output.
        return self.mlp(x)[0]

    def configure_optimizers(self):
        """Adam (lr=1e-3) on all parameters, with a StepLR decay every 1000 scheduler steps."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # StepLR's default gamma is 0.1. Lightning steps it once per epoch by default.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        """Full-image coordinate batch of the reference image."""
        return get_coords_dataloader(self.img)

    def val_dataloader(self):
        """Same coordinate batch as training."""
        return get_coords_dataloader(self.img)

    def positional_encoding(self, input): # [B,...,N]
        """Sinusoidal encoding of the last axis, with optional coarse-to-fine masking.

        Maps ``input`` of shape [..., N] to [..., 2*N*L]. For each coordinate
        the layout is L sine bands followed by L cosine bands.
        """
        L = self.L
        shape = input.shape
        freq = 2 ** torch.arange(L, dtype=torch.float32, device=input.device) * np.pi  # [L]
        spectrum = input[..., None] * freq                                            # [..., N, L]
        input_enc = torch.stack([spectrum.sin(), spectrum.cos()], dim=-2)             # [..., N, 2, L]
        input_enc = input_enc.view(*shape[:-1], -1)                                   # [..., 2NL]

        if self.barf_c2f:
            # Bands are enabled progressively over the training run (epochs).
            start, end = 0, self.trainer.max_epochs
            if end is None or end <= start:
                # Unbounded/undefined run length (e.g. max_epochs=-1): no
                # schedule can be formed. Keep all bands on instead of dividing
                # by zero or masking everything forever.
                alpha = float(L)
            else:
                alpha = (self.current_epoch - start) / (end - start) * L
            k = torch.arange(L, dtype=torch.float32, device=input.device)
            weight = (1 - (alpha - k).clamp_(min=0, max=1).mul_(np.pi).cos_()) / 2
            # Every trailing group of L values is one (coordinate, sin|cos) band set.
            enc_shape = input_enc.shape
            input_enc = (input_enc.view(-1, L) * weight).view(*enc_shape)

        return input_enc

    def training_step(self, batch, batch_idx):
        """Sum of per-image L1 photometric losses on the full pixel grid.

        For each image i, the grid x is warped by homography T_i, and the MLP
        is queried at the warped coordinates. The prediction is compared with
        image i sampled on the unwarped grid. The (detached) MSE between the
        estimated and ground-truth homography matrices is logged as ``H{i}``.
        """
        _, H, W = self.img_GT.size()
        x, _ = batch  # x: (H*W, 3) homogeneous grid coordinates
        # Sampling grid for grid_sample: (1, H, W, 2).
        grid = rearrange(x[:, :2], '(H W) C -> H W C', H=H, W=W).unsqueeze(0)

        losses = []
        for i, (img, T) in enumerate(zip(self.imgs, self.homographies)):
            x_hom = T(x.T).T
            x_euc = x_hom / x_hom[:, 2:]  # normalize homogeneous coordinates

            # I(T_i(x)) in image layout: (1, C, H, W)
            y_hat = rearrange(self(x_euc[:, :2]), '(H W) C -> C H W', H=H, W=W).unsqueeze(0)
            # Image i resampled on the same grid: (1, C, H, W)
            y = F.grid_sample(img.unsqueeze(0), grid, align_corners=True)

            loss_i = self.loss(y_hat, y)
            losses.append(loss_i)
            self.log(f'L_{i}', loss_i.item())

            H_err = self.H_loss(T.H.detach(), self.homographies_GT[i].H.detach())
            self.log(f'H{i}', H_err.item())

        total_loss = sum(losses)
        self.log('loss', total_loss.item())
        return total_loss

    def render_image(self, T):
        """Render the MLP image through homography ``T`` on the reference grid.

        Returns a tensor of shape (C_out, H, W), where H, W come from ``self.img``.
        """
        _, H, W = self.img.size()
        # [0] instead of squeeze(): keeps dims when H == 1 or W == 1.
        xy = create_meshgrid(H, W)[0].to(self.device)
        x = torch.cat([xy, torch.ones_like(xy[..., :1])], dim=2)
        x = rearrange(x, 'H W C -> (H W) C')

        x_hom = T(x.T).T
        x_euc = x_hom / x_hom[:, 2:]
        y = self(x_euc[:, :2])
        return rearrange(y, '(H W) C -> C H W', H=H, W=W)

    @staticmethod
    def _to_bgr_frame(img):
        """Convert a (3, H, W) RGB tensor in [0, 1] to a contiguous uint8 BGR array."""
        frame = rearrange(img, 'C H W -> H W C').detach().cpu().numpy()
        # Clip before the uint8 cast. MLP outputs can leave [0, 1], and
        # out-of-range float -> uint8 casts wrap instead of saturating.
        frame = np.clip(255 * frame, 0, 255).astype(np.uint8)
        # RGB -> BGR for OpenCV. Copy into a C-contiguous buffer rather than
        # handing cv2 a negative-stride view.
        return np.ascontiguousarray(frame[:, :, ::-1])

    def on_train_epoch_end(self):
        """Every 10 epochs: render all views, write view 0 to the video, and plot diagnostics."""
        if self.current_epoch % 10 != 0:
            return

        # Visualisation only, so no autograd graph is needed.
        with torch.no_grad():
            results = [self.render_image(T) for T in self.homographies]

        # Only the render through the first homography goes into the video.
        self.out.write(self._to_bgr_frame(results[0]))

        show(results)
        show(self.imgs)

        patches = utils.draw_patches(results[0], self.homographies)
        patches_GT = utils.draw_patches(self.img_GT, self.homographies_GT)
        show([patches, patches_GT])
        plt.show()
            
#    def on_after_backward(self):
#        # example to inspect gradient information in tensorboard
#        if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
#            params = self.state_dict()
#            for k, v in params.items():
#                grads = v
#                name = k
#                if name.startswith("homographies"):
#                    print(name)
#                    print(grads, v.grad)

                #self.logger.log_histogram(tag="grads", values=grads, global_step=self.trainer.global_step)

In [None]:
import utils

# Load the test image as float RGB in [0, 1], layout (C, H, W).
img = read_image("baboon.png", torchvision.io.ImageReadMode.RGB) / 255.

img_numpy = rearrange(img * 255, 'C H W -> H W C').numpy().astype(np.uint8)
# NOTE(review): T, w and corners_H are not used below (w is redefined later).
T, w, corners_H = utils.get_random_Warp()

C, H, W = img.size()

# Five random patches. Element 0 of each get_random_Patch result is used as a
# warped image, element 1 as its homography.
img_warps = list(zip(*[utils.get_random_Patch(img) for _ in range(5)]))
warps = list(img_warps[0])
Hs = list(img_warps[1])
print(Hs)
show(warps)
show(utils.draw_patches(img, Hs))

# Reference view: the original image with a K-pixel border blanked out.
K = 120
img_mask = torch.ones_like(img)
if K > 0:  # with K == 0 the `-K:` slices would select (and zero) the whole axis
    img_mask[:, :K, :] = 0
    img_mask[:, -K:, :] = 0
    img_mask[:, :, :K] = 0
    img_mask[:, :, -K:] = 0

img_ref = img * img_mask
warps = [img_ref] + warps
# Flatten the mask to (H*W, C), the per-pixel layout of the MLP output.
img_mask = rearrange(img_mask, 'c H W -> (H W) c')

show(warps)

In [None]:
COMET_ML_PROJECT = "barf"
experiment_name = "BARF"
torch.cuda.empty_cache()

# Credentials come from the environment (COMET_API_KEY); never commit API keys.
# NOTE(review): the key that was previously hard-coded here has been exposed
# and should be rotated.
comet_logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    workspace="semjon",
    project_name=COMET_ML_PROJECT,
    experiment_name=experiment_name,
)

imgs = [im.cuda() for im in warps]
# Ground-truth homographies: a default Homography() for the reference view
# (presumably identity — TODO confirm), then the sampled patch homographies.
# A new name keeps the cell idempotent: re-running it no longer prepends
# another entry to `Hs`.
Hs_GT = [Homography().cuda()] + [h.cuda() for h in Hs]

model = BARF_PL(imgs, img_mask, img_GT=img, H_GT=Hs_GT, pos_enc=True, barf_c2f=True).cuda()
trainer = pl.Trainer(accelerator="gpu", logger=comet_logger, log_every_n_steps=1, max_epochs=3000)
try:
    trainer.fit(model)
finally:
    # Finalise the video file even if training fails or is interrupted.
    model.out.release()

In [None]:
# Hand-picked 8-parameter vector for Homography: entries 0, 1, 4 and 5 set,
# with w[4] tied to w[5].
w = torch.zeros(8)
w[0] = w[1] = 0.25
w[5] = 0.7
w[4] = -0.5 * w[5]

H = Homography(w)

# Overlay of the patch outline on the image, then the extracted patch itself.
show(utils.draw_patches(img, [H]))
show(utils.get_Patch(img, H))

# Apply a known homography (translation) to the learned image

In [None]:
# Pure translation by +0.5 (normalised coordinates) in x and y. The MLP is
# sampled at H_shift^{-1} x, i.e. we view the learned image through H_shift.
# Named H_shift so it is not confused with the image height H below.
H_shift = torch.tensor([[1.0, 0, 0.5],
                        [0, 1.0, 0.5],
                        [0, 0, 1]], device=model.device)
H_inv = torch.linalg.inv(H_shift)

C, H, W = model.img.size()
# [0] instead of squeeze(): keeps dims when H == 1 or W == 1.
xy = create_meshgrid(H, W)[0].to(model.device)
x_hom = torch.cat([xy, torch.ones_like(xy[..., :1])], dim=2)
x_hom = rearrange(x_hom, 'H W C -> (H W) C')

x_hom = (H_inv @ x_hom.T).T
x_euc = x_hom / x_hom[:, 2:]

# Inference only, so no autograd graph.
with torch.no_grad():
    y = model(x_euc[:, :2])
y = rearrange(y, '(H W) C -> C H W', H=H, W=W)
show([model.img, y])
plt.show()