## <font style="color:lightblue">Header</font>

### <font style="color:lightblue">Imports</font>

In [2]:

import IPython
import sys
import os
import random
import time
from dataclasses import dataclass
from enum import Enum

import math
import statistics
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as fn
from torch import optim
from torchvision import transforms
from torchinfo import summary

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.image import imread, imsave
import h5py
import tifffile
import tqdm




### <font style="color:lightblue">Functions</font>

In [None]:


def eprint(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)


def plotData(dataY, rangeY=None, dataYR=None, rangeYR=None,
             dataX=None, rangeX=None, rangeP=None,
             figsize=(16,8), saveTo=None, show=True):

    if type(dataY) is np.ndarray :
        plotData((dataY,), rangeY=rangeY, dataYR=dataYR, rangeYR=rangeYR,
             dataX=dataX, rangeX=rangeX, rangeP=rangeP,
             figsize=figsize, saveTo=saveTo, show=show)
        return
    if type(dataYR) is np.ndarray :
        plotData(dataY, rangeY=rangeY, dataYR=(dataYR,), rangeYR=rangeYR,
             dataX=dataX, rangeX=rangeX, rangeP=rangeP,
             figsize=figsize, saveTo=saveTo, show=show)
        return
    if type(dataY) is not tuple :
        eprint(f"Unknown data type to plot: {type(dataY)}.")
        return
    if type(dataYR) is not tuple and dataYR is not None:
        eprint(f"Unknown data type to plot: {type(dataYR)}.")
        return

    last = min( len(data) for data in dataY )
    if dataYR is not None:
        last = min( last,  min( len(data) for data in dataYR ) )
    if dataX is not None:
        last = min(last, len(dataX))
    if rangeP is None :
        rangeP = (0,last)
    elif type(rangeP) is int :
        rangeP = (0,rangeP) if rangeP > 0 else (-rangeP,last)
    elif type(rangeP) is tuple :
        rangeP = ( 0    if rangeP[0] is None else rangeP[0],
                   last if rangeP[1] is None else rangeP[1],)
    else :
        eprint(f"Bad data type on plotData input rangeP: {type(rangeP)}")
        raise Exception(f"Bug in the code.")
    rangeP = np.s_[ max(0, rangeP[0]) : min(last, rangeP[1]) ]
    if dataX is None :
        dataX = np.arange(rangeP.start, rangeP.stop)

    plt.style.use('default')
    plt.style.use('dark_background')
    fig, ax1 = plt.subplots(figsize=figsize)
    ax1.xaxis.grid(True, 'both', linestyle='dotted')
    if rangeX is not None :
        ax1.set_xlim(rangeX)
    else :
        ax1.set_xlim(rangeP.start,rangeP.stop-1)

    ax1.yaxis.grid(True, 'both', linestyle='dotted')
    nofPlots = len(dataY)
    if rangeY is not None:
        ax1.set_ylim(rangeY)
    colors = [ matplotlib.colors.hsv_to_rgb((hv/nofPlots, 1, 1)) for hv in range(nofPlots) ]
    for idx , data in enumerate(dataY):
        ax1.plot(dataX, data[rangeP], linestyle='-',  color=colors[idx])

    if dataYR is not None : # right Y axis
        ax2 = ax1.twinx()
        ax2.yaxis.grid(True, 'both', linestyle='dotted')
        nofPlots = len(dataYR)
        if rangeYR is not None:
            ax2.set_ylim(rangeYR)
        colors = [ matplotlib.colors.hsv_to_rgb((hv/nofPlots, 1, 1)) for hv in range(nofPlots) ]
        for idx , data in enumerate(dataYR):
            ax2.plot(dataX, data[rangeP], linestyle='dashed',  color=colors[idx])

    if saveTo:
        fig.savefig(saveTo)
    if not show:
        plt.close(fig)


def plotImage(image) :
    plt.imshow(image, cmap='gray')
    plt.axis("off")
    plt.show()


def plotImages(images) :
    for i, img in enumerate(images) :
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(img.squeeze(), cmap='gray')
        plt.axis("off")
    plt.show()


def tensorStat(stat) :
    print(f"{stat.mean().item():.3e}, {stat.std().item():.3e}, "
          f"{stat.min().item():.3e}, {stat.max().item():.3e}")


def fillWheights(seq) :
    for wh in seq :
        if hasattr(wh, 'weight') :
            torch.nn.init.xavier_uniform_(wh.weight)
            #torch.nn.init.zeros_(wh.weight)
            #torch.nn.init.constant_(wh.weight, 0)
            #torch.nn.init.uniform_(wh.weight, a=0.0, b=1.0, generator=None)
            #torch.nn.init.normal_(wh.weight, mean=0.0, std=0.01)


def unsqeeze4dim(tens):
    orgDims = tens.dim()
    if tens.dim() == 2 :
        tens = tens.unsqueeze(0)
    if tens.dim() == 3 :
        tens = tens.unsqueeze(1)
    return tens, orgDims


def squeezeOrg(tens, orgDims):
    if orgDims == tens.dim():
        return tens
    if tens.dim() != 4 or orgDims > 4 or orgDims < 2:
        raise Exception(f"Unexpected dimensions to squeeze: {tens.dim()} {orgDims}.")
    if orgDims < 4 :
        if tens.shape[1] > 1:
            raise Exception(f"Cant squeeze dimension 1 in: {tens.shape}.")
        tens = tens.squeeze(1)
    if orgDims < 3 :
        if tens.shape[0] > 1:
            raise Exception(f"Cant squeeze dimension 0 in: {tens.shape}.")
        tens = tens.squeeze(0)
    return tens


### <font style="color:lightblue">Configs</font>

In [None]:
def set_seed(SEED_VALUE):
    torch.manual_seed(SEED_VALUE)
    torch.cuda.manual_seed(SEED_VALUE)
    torch.cuda.manual_seed_all(SEED_VALUE)
    np.random.seed(SEED_VALUE)

seed = 7
set_seed(seed)

@dataclass(frozen=True)
class TCfg:
    exec = 1
    device: torch.device = f"cuda:{exec}"
    latentDim: int = 64

class DCfg:
    gapW = 16
    sinoSh = (5*gapW,5*gapW) # 80x80
    readSh = (80, 80)
    sinoSize = math.prod(sinoSh)
    gapSh = (sinoSh[0],gapW)
    gapSize = math.prod(gapSh)
    gapRngX = np.s_[ sinoSh[1]//2 - gapW//2 : sinoSh[1]//2 + gapW//2 ]
    gapRng = np.s_[ : , gapRngX ]
    disRng = np.s_[ gapW:-gapW , gapRngX ]




### <font style="color:lightblue">Save/Load</font>

In [None]:
def load_model(model, model_path):
    model.load_state_dict(torch.load(model_path))
    return model


def addToHDF(filename, containername, data) :
    if len(data.shape) == 2 :
        data=np.expand_dims(data, 0)
    if len(data.shape) != 3 :
        raise Exception(f"Not appropriate input array size {data.shape}.")

    with h5py.File(filename,'a') as file :

        if  containername not in file.keys():
            dset = file.create_dataset(containername, data.shape,
                                       maxshape=(None,data.shape[1],data.shape[2]),
                                       dtype='f')
            dset[()] = data
            return

        dset = file[containername]
        csh = dset.shape
        if csh[1] != data.shape[1] or csh[2] != data.shape[2] :
            raise Exception(f"Shape mismatch: input {data.shape}, file {dset.shape}.")
        msh = dset.maxshape
        newLen = csh[0] + data.shape[0]
        if msh[0] is None or msh[0] >= newLen :
            dset.resize(newLen, axis=0)
        else :
            raise Exception(f"Insufficient maximum shape {msh} to add data"
                            f" {data.shape} to current volume {dset.shape}.")
        dset[csh[0]:newLen,...] = data
        file.close()


    return 0

## <font style="color:lightblue">Data</font>

### <font style="color:lightblue">Raw Read</font>

In [None]:

def getInData(inputString):
    sampleHDF = inputString.split(':')
    if len(sampleHDF) != 2 :
        raise Exception(f"String \"{inputString}\" does not represent an HDF5 format \"fileName:container\".")
    try :
        trgH5F =  h5py.File(sampleHDF[0],'r')
    except :
        raise Exception(f"Failed to open HDF file '{sampleHDF[0]}'.")
    if  sampleHDF[1] not in trgH5F.keys():
        raise Exception(f"No dataset '{sampleHDF[1]}' in input file {sampleHDF[0]}.")
    data = trgH5F[sampleHDF[1]]
    if not data.size :
        raise Exception(f"Container \"{inputString}\" is zero size.")
    sh = data.shape
    if len(sh) != 3 :
        raise Exception(f"Dimensions of the container \"{inputString}\" is not 3: {sh}.")
    return data


def getOutData(outputString, shape) :
    if len(shape) == 2 :
        shape = (1,*shape)
    if len(shape) != 3 :
        raise Exception(f"Not appropriate output array size {shape}.")

    sampleHDF = outputString.split(':')
    if len(sampleHDF) != 2 :
        raise Exception(f"String \"{outputString}\" does not represent an HDF5 format \"fileName:container\".")
    try :
        trgH5F =  h5py.File(sampleHDF[0],'w')
    except :
        raise Exception(f"Failed to open HDF file '{sampleHDF[0]}'.")

    if  sampleHDF[1] not in trgH5F.keys():
        dset = trgH5F.create_dataset(sampleHDF[1], shape, dtype='f')
    else :
        dset = trgH5F[sampleHDF[1]]
        csh = dset.shape
        if csh[0] < shape[0] or csh[1] != shape[1] or csh[2] != shape[2] :
            raise Exception(f"Shape mismatch: input {shape}, file {dset.shape}.")
    return dset, trgH5F



## <font style="color:lightblue">Models</font>

### <font style="color:lightblue">Lower resolution generators</font>

#### <font style="color:lightblue">Two pixels gap</font>

In [None]:


class Generator2(nn.Module):

    def __init__(self):
        super(Generator2, self).__init__()

        self.gapW = 2
        self.sinoSh = (5*self.gapW,5*self.gapW) # 10,10
        self.sinoSize = math.prod(self.sinoSh)
        self.gapSh = (self.sinoSh[0],self.gapW)
        self.gapSize = math.prod(self.gapSh)
        self.gapRngX = np.s_[ self.sinoSh[1]//2 - self.gapW//2 : self.sinoSh[1]//2 + self.gapW//2 ]
        self.gapRng = np.s_[ : , self.gapRngX ]

        latentChannels = 7
        self.noise2latent = nn.Sequential(
            nn.Linear(TCfg.latentDim, self.sinoSize*latentChannels),
            nn.ReLU(),
            nn.Unflatten( 1, (latentChannels,) + self.sinoSh )
        )
        fillWheights(self.noise2latent)

        baseChannels = 64

        self.encode = nn.Sequential(

            nn.Conv2d(latentChannels+1, baseChannels, 3),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, baseChannels, 3),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, baseChannels, 3),
            nn.LeakyReLU(0.2),

        )
        fillWheights(self.encode)

        encSh = self.encode(torch.zeros((1,latentChannels+1,*self.sinoSh))).shape
        linChannels = math.prod(encSh)
        self.link = nn.Sequential(
            nn.Flatten(),
            nn.Linear(linChannels, linChannels),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1, encSh[1:]),
        )
        fillWheights(self.link)


        self.decode = nn.Sequential(

            nn.ConvTranspose2d(baseChannels, baseChannels, (3,1)),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(baseChannels, baseChannels, (3,1)),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(baseChannels, baseChannels, (3,1)),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, 1, (1,3)),
            nn.Tanh()

        )
        fillWheights(self.decode)


        self.body = nn.Sequential(
            self.encode,
            self.link,
            self.decode
        )


    def forward(self, input):
        images, noises = input
        images, orgDims = unsqeeze4dim(images)
        latent = self.noise2latent(noises)
        modelIn = torch.cat((images,latent),dim=1)
        mIn = modelIn[:,0,*self.gapRng]
        mIn[()] = self.preProc(images[:,0,:,:])
        patches = self.body(modelIn)
        mIn = mIn.unsqueeze(1)
        #patches = mIn + torch.where( patches < 0 , patches * mIn , patches ) # no normalization
        patches = mIn + patches * torch.where( patches < 0 , mIn+0.5 , 1 ) # normalization
        return squeezeOrg(patches, orgDims)


    def preProc(self, images) :
        images = images.unsqueeze(0) # for the 2D case
        res = torch.zeros(images[...,*self.gapRng].shape, device=images.device)
        res[...,0] += 2*images[...,self.gapRngX.start-1] + images[...,self.gapRngX.stop]
        res[...,1] += 2*images[...,self.gapRngX.stop] + images[...,self.gapRngX.start-1]
        res = res.squeeze(0) # to compensate for the first squeeze
        return res/3

    def generatePatches(self, images, noises=None) :
        if noises is None :
            noises = torch.randn( 1 if images.dim() < 3 else images.shape[0], TCfg.latentDim).to(TCfg.device)
        return self.forward((images,noises))


    def fillImages(self, images, noises=None) :
        images[...,*self.gapRng] = self.generatePatches(images, noises)
        return images


    def generateImages(self, images, noises=None) :
        clone = images.clone()
        return self.fillImages(clone, noises)


generator2 = Generator2()
generator2 = load_model(generator2, model_path="saves/gap2_cor.model_gen.pt" )
generator2.to(TCfg.device)
generator2.requires_grad_(False)
generator2.eval()
#model_summary = summary(generator2, input_data=[ [refImages, refNoises] ] ).__str__()
#print(model_summary)




  model.load_state_dict(torch.load(model_path))


Generator2(
  (noise2latent): Sequential(
    (0): Linear(in_features=64, out_features=700, bias=True)
    (1): ReLU()
    (2): Unflatten(dim=1, unflattened_size=(7, 10, 10))
  )
  (encode): Sequential(
    (0): Conv2d(8, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): LeakyReLU(negative_slope=0.2)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): LeakyReLU(negative_slope=0.2)
  )
  (link): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=1024, out_features=1024, bias=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3): Unflatten(dim=1, unflattened_size=torch.Size([64, 4, 4]))
  )
  (decode): Sequential(
    (0): ConvTranspose2d(64, 64, kernel_size=(3, 1), stride=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): ConvTranspose2d(64, 64, kernel_size=(3, 1), stride=(1, 1))
    (3): LeakyReLU(negative_slope=0.2)
    (4): ConvTranspose2d

#### <font style="color:lightblue">Four pixels gap</font>

In [None]:


class Generator4(nn.Module):

    def __init__(self):
        super(Generator4, self).__init__()

        self.gapW = 4
        self.sinoSh = (5*self.gapW,5*self.gapW) # 20,20
        self.sinoSize = math.prod(self.sinoSh)
        self.gapSh = (self.sinoSh[0],self.gapW)
        self.gapSize = math.prod(self.gapSh)
        self.gapRngX = np.s_[ self.sinoSh[1]//2 - self.gapW//2 : self.sinoSh[1]//2 + self.gapW//2 ]
        self.gapRng = np.s_[ : , self.gapRngX ]

        latentChannels = 7
        self.noise2latent = nn.Sequential(
            nn.Linear(TCfg.latentDim, self.sinoSize*latentChannels),
            nn.ReLU(),
            nn.Unflatten( 1, (latentChannels,) + self.sinoSh )
        )
        fillWheights(self.noise2latent)

        baseChannels = 128

        self.encode = nn.Sequential(

            nn.Conv2d(latentChannels+1, baseChannels, 3),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, baseChannels, 3, stride=2),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, baseChannels, 3),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, baseChannels, 3),
            nn.LeakyReLU(0.2),

        )
        fillWheights(self.encode)


        encSh = self.encode(torch.zeros((1,latentChannels+1,*self.sinoSh))).shape
        linChannels = math.prod(encSh)
        self.link = nn.Sequential(
            nn.Flatten(),
            nn.Linear(linChannels, linChannels),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1, encSh[1:]),
        )
        fillWheights(self.link)


        self.decode = nn.Sequential(

            nn.ConvTranspose2d(baseChannels, baseChannels, (3,1)),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(baseChannels, baseChannels, (3,1)),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(baseChannels, baseChannels, (4,1), stride=(2,1)),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(baseChannels, baseChannels, (3,1)),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, 1, 1),
            nn.Tanh()

        )
        fillWheights(self.decode)


        self.body = nn.Sequential(
            self.encode,
            self.link,
            self.decode
        )


    def forward(self, input):
        images, noises = input
        images, orgDims = unsqeeze4dim(images)
        latent = self.noise2latent(noises)
        modelIn = torch.cat((images,latent),dim=1)
        mIn = modelIn[:,0,*self.gapRng]
        mIn[()] = self.preProc(images[:,0,:,:])
        patches = self.body(modelIn)
        mIn = mIn.unsqueeze(1)
        #patches = mIn + torch.where( patches < 0 , patches * mIn , patches ) # no normalization
        patches = mIn + patches * torch.where( patches < 0 , mIn+0.5 , 1 ) # normalization
        return squeezeOrg(patches, orgDims)


    def preProc(self, images) :
        images, orgDims = unsqeeze4dim(images)
        preImages = torch.nn.functional.interpolate(images, scale_factor=0.5, mode='area')
        prePatches = generator2.generatePatches(preImages)
        prePatches = torch.nn.functional.interpolate(prePatches, scale_factor=2, mode='bilinear')
        return squeezeOrg(prePatches, orgDims)


    def generatePatches(self, images, noises=None) :
        if noises is None :
            noises = torch.randn( 1 if images.dim() < 3 else images.shape[0], TCfg.latentDim).to(TCfg.device)
        return self.forward((images,noises))


    def fillImages(self, images, noises=None) :
        images[...,*self.gapRng] = self.generatePatches(images, noises)
        return images


    def generateImages(self, images, noises=None) :
        clone = images.clone()
        return self.fillImages(clone, noises)


generator4 = Generator4()
generator4 = load_model(generator4, model_path="saves/gap4_cor.model_gen.pt" )
generator4.to(TCfg.device)
generator4.requires_grad_(False)
generator4.eval()
#model_summary = summary(generator4, input_data=[ [refImages, refNoises] ] ).__str__()
#print(model_summary)




  model.load_state_dict(torch.load(model_path))


Generator4(
  (noise2latent): Sequential(
    (0): Linear(in_features=64, out_features=2800, bias=True)
    (1): ReLU()
    (2): Unflatten(dim=1, unflattened_size=(7, 20, 20))
  )
  (encode): Sequential(
    (0): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
    (3): LeakyReLU(negative_slope=0.2)
    (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (5): LeakyReLU(negative_slope=0.2)
    (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (7): LeakyReLU(negative_slope=0.2)
  )
  (link): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=2048, out_features=2048, bias=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3): Unflatten(dim=1, unflattened_size=torch.Size([128, 4, 4]))
  )
  (decode): Sequential(
    (0): ConvTranspose2d(128, 128, kernel_size=(3, 1), stride=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): ConvTranspose

#### <font style="color:lightblue">Eight pixels gap</font>

In [None]:


class Generator8(nn.Module):

    def __init__(self):
        super(Generator8, self).__init__()

        self.gapW = 8
        self.sinoSh = (5*self.gapW,5*self.gapW) # 20,20
        self.sinoSize = math.prod(self.sinoSh)
        self.gapSh = (self.sinoSh[0],self.gapW)
        self.gapSize = math.prod(self.gapSh)
        self.gapRngX = np.s_[ self.sinoSh[1]//2 - self.gapW//2 : self.sinoSh[1]//2 + self.gapW//2 ]
        self.gapRng = np.s_[ : , self.gapRngX ]

        latentChannels = 7
        self.noise2latent = nn.Sequential(
            nn.Linear(TCfg.latentDim, self.sinoSize*latentChannels),
            nn.ReLU(),
            nn.Unflatten( 1, (latentChannels,) + self.sinoSh )
        )
        fillWheights(self.noise2latent)

        baseChannels = 256

        self.encode = nn.Sequential(

            nn.Conv2d(latentChannels+1, baseChannels, 3, bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, baseChannels, 3, stride=2, bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, baseChannels, 3, stride=2, bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, baseChannels, 3, bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, baseChannels, 3, bias=False),
            nn.LeakyReLU(0.2),


        )
        fillWheights(self.encode)


        encSh = self.encode(torch.zeros((1,latentChannels+1,*self.sinoSh))).shape
        linChannels = math.prod(encSh)
        self.link = nn.Sequential(
            nn.Flatten(),
            nn.Linear(linChannels, linChannels),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1, encSh[1:]),
        )
        fillWheights(self.link)


        self.decode = nn.Sequential(

            nn.ConvTranspose2d(baseChannels, baseChannels, (3,1), bias=False),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(baseChannels, baseChannels, (3,1), bias=False),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(baseChannels, baseChannels, (4,1), stride=(2,1), bias=False),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(baseChannels, baseChannels, (4,3), stride=(2,1), bias=False),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(baseChannels, baseChannels, (3,3), bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(baseChannels, 1, 1, bias=False),
            nn.Tanh()

        )
        fillWheights(self.decode)


        self.body = nn.Sequential(
            self.encode,
            self.link,
            self.decode
        )


    def forward(self, input):
        images, noises = input
        images, orgDims = unsqeeze4dim(images)
        latent = self.noise2latent(noises)
        modelIn = torch.cat((images,latent),dim=1)
        mIn = modelIn[:,0,*self.gapRng]
        mIn[()] = self.preProc(images[:,0,:,:])
        patches = self.body(modelIn)
        #return patches
        mIn = mIn.unsqueeze(1)
        #patches = mIn + torch.where( patches < 0 , patches * mIn , patches ) # no normalization
        patches = mIn + patches * torch.where( patches < 0 , mIn+0.5 , 1 ) # normalization
        return squeezeOrg(patches, orgDims)


    def preProc(self, images) :
        images, orgDims = unsqeeze4dim(images)
        preImages = torch.nn.functional.interpolate(images, scale_factor=0.5, mode='area')
        prePatches = generator4.generatePatches(preImages)
        prePatches = torch.nn.functional.interpolate(prePatches, scale_factor=2, mode='bilinear')
        return squeezeOrg(prePatches, orgDims)


    def generatePatches(self, images, noises=None) :
        if noises is None :
            noises = torch.randn( 1 if images.dim() < 3 else images.shape[0], TCfg.latentDim).to(TCfg.device)
        return self.forward((images,noises))


    def fillImages(self, images, noises=None) :
        images[...,*self.gapRng] = self.generatePatches(images, noises)
        return images


    def generateImages(self, images, noises=None) :
        clone = images.clone()
        return self.fillImages(clone, noises)


generator8 = Generator8()
generator8 = load_model(generator8, model_path="saves/gap8_cor.model_gen.pt" )
generator8.to(TCfg.device)
generator8.requires_grad_(False)
generator8.eval()
#model_summary = summary(generator8, input_data=[ [refImages, refNoises] ] ).__str__()
#print(model_summary)




  model.load_state_dict(torch.load(model_path))


Generator8(
  (noise2latent): Sequential(
    (0): Linear(in_features=64, out_features=11200, bias=True)
    (1): ReLU()
    (2): Unflatten(dim=1, unflattened_size=(7, 40, 40))
  )
  (encode): Sequential(
    (0): Conv2d(8, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (9): LeakyReLU(negative_slope=0.2)
  )
  (link): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=4096, out_features=4096, bias=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3): Unflatten(dim=1, unflattened_size=torch.Size([256, 4, 4])

### <font style="color:lightblue">Generator</font>

In [None]:


class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.gapW = DCfg.gapW
        self.sinoSh = (5*self.gapW,5*self.gapW) # 80,80
        self.sinoSize = math.prod(self.sinoSh)
        self.gapSh = (self.sinoSh[0],self.gapW)
        self.gapSize = math.prod(self.gapSh)
        self.gapRngX = np.s_[ self.sinoSh[1]//2 - self.gapW//2 : self.sinoSh[1]//2 + self.gapW//2 ]
        self.gapRng = np.s_[ : , self.gapRngX ]

        latentChannels = 7
        self.noise2latent = nn.Sequential(
            nn.Linear(TCfg.latentDim, self.sinoSize*latentChannels),
            nn.ReLU(),
            nn.Unflatten( 1, (latentChannels,) + self.sinoSh )
        )
        fillWheights(self.noise2latent)

        baseChannels = 64


        def encblock(chIn, chOut, kernel, stride=1) :
            return nn.Sequential (
                nn.Conv2d(chIn, chOut, kernel, stride=stride, bias=True),
                #nn.BatchNorm2d(chOut),
                nn.LeakyReLU(0.2),
                #nn.ReLU(),
            )
        self.encode = nn.Sequential(
            encblock(  latentChannels+1,   baseChannels, 3),
            encblock(  baseChannels,     2*baseChannels, 3, stride=2),
            encblock(2*baseChannels,     2*baseChannels, 3),
            encblock(2*baseChannels,     2*baseChannels, 3),
            encblock(2*baseChannels,     4*baseChannels, 3, stride=2),
            encblock(4*baseChannels,     4*baseChannels, 3),
            encblock(4*baseChannels,     8*baseChannels, 3, stride=2),
            encblock(8*baseChannels,     8*baseChannels, 3),
        )
        fillWheights(self.encode)


        encSh = self.encode(torch.zeros((1,latentChannels+1,*self.sinoSh))).shape
        linChannels = math.prod(encSh)
        self.link = nn.Sequential(
            nn.Flatten(),
            nn.Linear(linChannels, linChannels),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1, encSh[1:]),
        )
        fillWheights(self.link)

        def decblock(chIn, chOut, kernel, stride=1) :
            return nn.Sequential (
                nn.ConvTranspose2d(chIn, chOut, kernel, stride, bias=False),
                #nn.BatchNorm2d(chOut),
                nn.LeakyReLU(0.2),
                #nn.ReLU(),
            )
        self.decode = nn.Sequential(
            decblock(8*baseChannels, 8*baseChannels, 3),
            decblock(8*baseChannels, 4*baseChannels, 4, stride=2),
            decblock(4*baseChannels, 4*baseChannels, 3),
            decblock(4*baseChannels, 2*baseChannels, 4, stride=2),
            decblock(2*baseChannels, 2*baseChannels, 3),
            decblock(2*baseChannels, 2*baseChannels, 3),
            decblock(2*baseChannels,   baseChannels, 4, stride=2),
            decblock(baseChannels, baseChannels, 3),

            nn.Conv2d(baseChannels, 1, 1),
            nn.Tanh()
        )
        fillWheights(self.decode)


        self.body = nn.Sequential(
            self.encode,
            self.link,
            self.decode
        )


    def forward(self, input):
        images, noises = input
        images, orgDims = unsqeeze4dim(images)
        latent = self.noise2latent(noises)
        modelIn = torch.cat((images,latent),dim=1)
        mIn = modelIn[:,0,*self.gapRng]
        mIn[()] = self.preProc(images[:,0,:,:])
        patches = self.body(modelIn)[...,self.gapRngX]
        #return patches
        mIn = mIn.unsqueeze(1)
        patches = mIn + patches * torch.where( patches < 0 , mIn+0.5 , 1 ) # normalization
        return squeezeOrg(patches, orgDims)


    def preProc(self, images) :
        images, orgDims = unsqeeze4dim(images)
        preImages = torch.nn.functional.interpolate(images, scale_factor=0.5, mode='area')
        prePatches = generator8.generatePatches(preImages)
        prePatches = torch.nn.functional.interpolate(prePatches, scale_factor=2, mode='bilinear')
        return squeezeOrg(prePatches, orgDims)


    def generatePatches(self, images, noises=None) :
        if noises is None :
            noises = torch.randn( 1 if images.dim() < 3 else images.shape[0], TCfg.latentDim).to(TCfg.device)
        return self.forward((images,noises))


    def fillImages(self, images, noises=None) :
        images[...,*self.gapRng] = self.generatePatches(images, noises)
        return images


    def generateImages(self, images, noises=None) :
        clone = images.clone()
        return self.fillImages(clone, noises)



generator = Generator()
generator = load_model(generator, model_path="model_1_gen.pt" )
generator = generator.to(TCfg.device)
generator = generator.requires_grad_(False)
generator = generator.eval()
#model_summary = summary(generator, input_data=[ [refImages, refNoises] ] ).__str__()
#print(model_summary)




  model.load_state_dict(torch.load(model_path))


## <font style="color:lightblue">Fill sinogram</font>

In [None]:

def fillSinogram(sinogram) :

    sinoW = sinogram.shape[-1]
    sinoL = sinogram.shape[-2]
    if sinoW % 5 :
        raise Exception(f"Sinogram width {sinoW} is not devisable bny 5.")
    blockW = sinoW // 5
    sinogram, _ = unsqeeze4dim(sinogram)
    sinogram = sinogram.to(TCfg.device)
    resizedSino = torch.zeros(( 1 , 1 , sinoL , DCfg.sinoSh[1] ), device=TCfg.device)
    resizedSino[ ... , : 2*DCfg.gapW ] = torch.nn.functional.interpolate(
        sinogram[ ... , : 2*blockW ], size=( sinoL , 2*DCfg.gapW ), mode='bilinear')
    resizedSino[ ... , 2*DCfg.gapW : 3*DCfg.gapW ] = torch.nn.functional.interpolate(
        sinogram[ ... , : 2*blockW : 3*blockW ], size=( sinoL , DCfg.gapW ), mode='bilinear')
    resizedSino[ ... , 3*DCfg.gapW:] = torch.nn.functional.interpolate(
        sinogram[ ... , 3*blockW : ], size=( sinoL , 2*DCfg.gapW ), mode='bilinear')

    blockH = DCfg.sinoSh[0]
    sinoCutStep = DCfg.gapW
    lastStart = sinoL - blockH
    nofBlocks, lastBlock = divmod(lastStart, sinoCutStep)
    modelIn = torch.empty( ( nofBlocks + bool(lastBlock) , 1 , *DCfg.sinoSh ), device=TCfg.device )
    for block in range(nofBlocks) :
        modelIn[ block, 0, ... ] = resizedSino[0 , 0, block * sinoCutStep : block * sinoCutStep + blockH , : ]
    if lastBlock :
        modelIn[ -1, 0, ... ] = resizedSino[0,0, -blockH : , : ]

    mytransforms = transforms.Compose([
        transforms.Normalize(mean=(0.5), std=(1))
    ])
    modelIn = mytransforms(modelIn)

    modelIn[ -1, 0, ... ] = modelIn[ -1, 0, ... ].flip(dims=(-2,)) # to get rid of the deffect in the end
    results = None
    with torch.no_grad() :
        results = generator.generatePatches(modelIn)
    results[ -1, 0, ... ] = results[ -1, 0, ... ].flip(dims=(-2,)) # to flip back

    if lastBlock :
        newLast = torch.zeros(DCfg.gapSh, device=TCfg.device)
        newLast[:-lastBlock,:] = results[-1,0,lastBlock:,:]
        results[-1,0,...] = newLast
    preBlocks = torch.zeros((4,1,*DCfg.gapSh), device=TCfg.device)
    pstBlocks = torch.zeros((4,1,*DCfg.gapSh), device=TCfg.device)
    for curs in range(4) :
        preBlocks[ -curs-1 , 0 , sinoCutStep*(curs+1) :  , : ] = results[ 0 , 0 , : -sinoCutStep*(curs+1) , : ]
        pstBlocks[ curs , 0 , : (-sinoCutStep*curs) if curs else (blockH+1) , : ] = results[ -1 , 0 , sinoCutStep*curs : , : ]
    resultsPatched = torch.cat( (preBlocks, results, pstBlocks), dim=0 )

    blockCut = blockH / 5
    profileWeight = torch.empty( (blockH,), device=TCfg.device )
    for curi in range(blockH) :
        if curi < blockCut :
            profileWeight[curi] = 0
        elif curi < 2 * blockCut :
            profileWeight[curi] = ( curi - blockCut ) / blockCut
        elif curi < 3 * blockCut :
            profileWeight[curi] = 1
        elif curi < 4 * blockCut :
            profileWeight[curi] = ( 4*blockCut - curi ) / blockCut
        else :
            profileWeight[curi] = 0
    #plotData(profileWeight.numpy())
    resultsProfiled = ( resultsPatched + 0.5 ) * profileWeight.view(1,1,-1,1)
    stitchedGap = torch.zeros( ( (resultsProfiled.shape[0]-1) * sinoCutStep + blockH, DCfg.gapW ), device=TCfg.device )
    for curblock in range(resultsProfiled.shape[0]) :
        stitchedGap[ curblock*sinoCutStep : curblock*sinoCutStep + blockH , : ] += resultsProfiled[curblock,0,...]
    stitchedGap = stitchedGap.unsqueeze(0).unsqueeze(0)
    resizedGap = torch.nn.functional.interpolate(
        stitchedGap, size=( stitchedGap.shape[-2] ,  blockW), mode='bilinear')

    sinogram[..., 2*blockW : 3*blockW ] = resizedGap[0,0, sinoCutStep*4 : sinoCutStep*4 + sinoL, : ] / 2
    return sinogram




## <font style="color:lightblue">Test</font>

In [None]:
#inSinogram = torch.tensor(inData[:,300,605:665], device=TCfg.device)
#print(inSinogram.shape)
#
#filledSinogram = fillSinogram(inSinogram).squeeze()
#tifffile.imwrite("tmp.tif", filledSinogram.cpu().numpy())
#plotImage(filledSinogram.transpose(0,1).cpu().numpy())

## <font style="color:lightblue">Execute</font>

In [None]:
#inputString = "/home/imbl/usr/src/bctppl/data/clean_org_0005_2650_AM.hdf:/data"
inputString = "/home/imbl/usr/src/bctppl/data/clean_sft_2701_5346_AM.hdf:/data"
gapsToProc = [
    np.s_[113:117],
    np.s_[629:641],
    np.s_[1153:1157],
    np.s_[1669:1681],
    np.s_[2193:2197],
]
inData = getInData(inputString)
outputString = "/home/imbl/usr/src/bctppl/data/result.hdf:/data"

outData, outFile = getOutData(outputString, inData.shape)
try :
    for curSl in tqdm.tqdm(range(inData.shape[-2])):
        inSinogram = torch.tensor(inData[:,curSl,:], device=TCfg.device)
        for gap in gapsToProc :
            gapW = gap.stop-gap.start
            stripe=np.s_[ gap.start - 2*gapW : gap.stop + 2*gapW]
            stripeData = inSinogram[:,stripe]
            filledData = fillSinogram(stripeData).squeeze()
            inSinogram[:,stripe] = filledData
        outData[:,curSl,:] = inSinogram.cpu().numpy()
except :
    outFile.close()
    raise
outFile.close()
print("Done")

  0%|          | 0/934 [00:00<?, ?it/s]

100%|██████████| 934/934 [09:57<00:00,  1.56it/s]

Done



