## <font style="color:blue">Header</font>

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
import random
from dataclasses import dataclass
from enum import Enum

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as fn
from torch import optim
from torchvision import transforms
from torchinfo import summary

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.image import imread, imsave
import h5py
import tifffile

def eprint(*args, **kwargs):
    """Print to standard error; accepts the same arguments as ``print``.

    ``file`` must not be passed in ``kwargs``: it is already bound to
    ``sys.stderr`` here.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)


def plotData(dataY, rangeY=None, dataYR=None, rangeYR=None,
             dataX=None, rangeX=None, rangeP=None,
             figsize=(16,8), saveTo=None, show=True):
    """Plot one or more 1D series on a dark-background figure.

    Args:
        dataY: array or tuple of arrays plotted against the left Y axis
            (solid lines).
        rangeY: optional (min, max) limits of the left Y axis.
        dataYR: optional array or tuple of arrays plotted against a right
            Y axis (dashed lines).
        rangeYR: optional (min, max) limits of the right Y axis.
        dataX: optional X coordinates; indices are used when omitted.
        rangeX: optional (min, max) limits of the X axis.
        rangeP: range of points to plot: ``None`` for everything, a
            positive int ``n`` for points ``[0, n)``, a non-positive int
            ``-n`` for points ``[n, last)``, or a ``(start, stop)`` pair
            in which ``None`` means "from the beginning" / "to the end".
        figsize: figure size passed to ``plt.subplots``.
        saveTo: optional path the figure is saved to.
        show: when false the figure is closed after (optional) saving.

    Raises:
        TypeError: if ``rangeP`` is of an unsupported type.
    """

    # A bare array is a single series.
    if isinstance(dataY, np.ndarray) :
        dataY = (dataY,)
    if isinstance(dataYR, np.ndarray) :
        dataYR = (dataYR,)
    if not isinstance(dataY, tuple) :
        eprint(f"Unknown data type to plot: {type(dataY)}.")
        return
    if dataYR is not None and not isinstance(dataYR, tuple) :
        eprint(f"Unknown data type to plot: {type(dataYR)}.")
        return
    if not dataY :
        # min() over an empty sequence below would raise.
        eprint("No data to plot.")
        return
    if not dataYR :
        dataYR = None

    # Only the part common to all series (and to dataX) can be plotted.
    last = min( len(data) for data in dataY )
    if dataYR is not None:
        last = min( last,  min( len(data) for data in dataYR ) )
    if dataX is not None:
        last = min(last, len(dataX))

    if rangeP is None :
        rangeP = (0,last)
    elif isinstance(rangeP, (int, np.integer)) :
        rangeP = (0,rangeP) if rangeP > 0 else (-rangeP,last)
    elif isinstance(rangeP, (tuple, list)) :
        rangeP = ( 0    if rangeP[0] is None else rangeP[0],
                   last if rangeP[1] is None else rangeP[1],)
    else :
        eprint(f"Bad data type on plotData input rangeP: {type(rangeP)}")
        raise TypeError("Bug in the code.")
    rangeP = slice( max(0, rangeP[0]), min(last, rangeP[1]) )

    ownX = dataX is not None
    if ownX :
        # Cut X to the same range as the Y series, otherwise the lengths
        # passed to plot() disagree whenever rangeP is narrower than dataX.
        dataX = np.asarray(dataX)[rangeP]
    else :
        dataX = np.arange(rangeP.start, rangeP.stop)

    plt.style.use('default')
    plt.style.use('dark_background')
    fig, ax1 = plt.subplots(figsize=figsize)
    ax1.xaxis.grid(True, 'both', linestyle='dotted')
    if rangeX is not None :
        ax1.set_xlim(rangeX)
    elif not ownX :
        ax1.set_xlim(rangeP.start,rangeP.stop-1)
    elif dataX.size :
        # User-supplied coordinates are not indices: limit to their extent.
        ax1.set_xlim(np.min(dataX), np.max(dataX))

    def plotSeries(axis, series, limits, linestyle) :
        # Spread the hues evenly so that every series gets its own colour.
        axis.yaxis.grid(True, 'both', linestyle='dotted')
        nofPlots = len(series)
        if limits is not None:
            axis.set_ylim(limits)
        colors = [ matplotlib.colors.hsv_to_rgb((hv/nofPlots, 1, 1)) for hv in range(nofPlots) ]
        for color, data in zip(colors, series):
            axis.plot(dataX, data[rangeP], linestyle=linestyle,  color=color)

    plotSeries(ax1, dataY, rangeY, '-')
    if dataYR is not None : # right Y axis
        plotSeries(ax1.twinx(), dataYR, rangeYR, 'dashed')

    if saveTo:
        fig.savefig(saveTo)
    if not show:
        plt.close(fig)





## <font style="color:blue">Configs</font>

In [2]:
def set_seed(SEED_VALUE):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    with the same value so that repeated runs draw the same numbers.

    Args:
        SEED_VALUE: integer seed.
    """
    # The notebook imports and uses ``random`` too; without seeding it the
    # runs are not reproducible.
    random.seed(SEED_VALUE)
    np.random.seed(SEED_VALUE)
    torch.manual_seed(SEED_VALUE)
    torch.cuda.manual_seed(SEED_VALUE)
    torch.cuda.manual_seed_all(SEED_VALUE)

# Fixed seed applied once, when this cell is executed.
seed = 7
set_seed(seed)

# Immutable collection of training hyper-parameters. The notebook reads the
# defaults directly from the class (e.g. ``TrainingConfig.batchSize``)
# without creating an instance.
@dataclass(frozen=True)
class TrainingConfig:
    # NOTE(review): annotated as torch.device, but the default is a plain
    # string; save_model() compares the device against the string 'cpu'.
    device: torch.device = 'cuda:0'
    nofEpochs: int = 256
    latentDim: int = 128
    batchSize: int = 16
    labelSmoothFac: float = 0.9 # For Real labels (or set to 1.0 for no smoothing).
    learningRateD: float = 0.0002  # learning rate of the discriminator optimizer
    learningRateG: float = 0.0002  # learning rate of the generator optimizer
    #CHECKPOINT_DIR: str = os.path.join('model_checkpoint', 'dcgan_flickr_faces')

# Geometry of the stripes cut from the input volume.
class DatasetConfig:
    gapWidth = 12                  # width of the gap (presumably the band to be filled in - TODO confirm)
    sinoWid = 3*gapWidth           # full stripe width: three gap widths
    sinoLen = 4096                 # stripe length the data are resized to
    gapSize = gapWidth * sinoLen   # number of pixels in the gap
    sinoSize = sinoWid * sinoLen   # number of pixels in the whole stripe


## <font style="color:blue">Data</font>

In [3]:

class StripesFromHDF :
    """Source of stripes cut from a 3D HDF5 volume.

    The volume is indexed as (frame, y, x). A stripe is the block
    ``volume[:, y, x : x+DatasetConfig.sinoWid]``; only stripes which lie
    entirely within the valid part of the mask are offered.

    Attributes:
        sh: shape of the volume.
        fsh: shape of a single frame (``sh[1:3]``).
        volume: normalised float32 copy of the volume, or ``None`` when
            the data are read from the file on demand.
        data: open HDF5 dataset when the data are read on demand,
            otherwise ``None``.
        mask: boolean frame-shaped array of usable pixels.
        bg, df: background (already reduced by the dark field) and dark
            field images, or ``None``.
        allIndices: list of ``(y, x)`` tuples - left ends of the stripes.
    """

    class Sinos(torch.utils.data.Dataset) :
        """Dataset of the stripes listed in the container's allIndices.

        Defined at class level (not inside a method) so that it can be
        referenced by its qualified name.
        """

        def __init__(self, root, transform=None):
            self.container = root
            self.transform = transforms.Compose([transforms.ToTensor(), transform]) \
                if transform else transforms.ToTensor()

        def __len__(self):
            return len(self.container.allIndices)

        def __getitem__(self, index):
            yCr, xCr = self.container.allIndices[index]
            xRng = slice(xCr, xCr + DatasetConfig.sinoWid)
            if self.container.volume is not None :
                # Already normalised when it was loaded.
                data = self.container.volume[:, yCr, xRng]
            else :
                # Read on demand; convert first because the in-place float
                # arithmetic below is not possible on integer data.
                data = self.container.data[:, yCr, xRng].astype(np.float32)
                if self.container.df is not None :
                    data -= self.container.df[None, yCr, xRng]
                if self.container.bg is not None :
                    data /= self.container.bg[None, yCr, xRng]
            if self.transform :
                data = self.transform(data)
            return data

    def __init__(self, sampleName, maskName, bgName=None, dfName=None, loadToMem=True):
        """Open the volume and prepare the list of valid stripes.

        Args:
            sampleName: volume in the form ``file:dataset``.
            maskName: image of usable pixels (non-zero is usable); all
                pixels are usable if empty or ``None``.
            bgName: optional background image the data are divided by.
            dfName: optional dark field image subtracted from the data
                and from the background.
            loadToMem: load and normalise the whole volume in memory if
                true; otherwise keep the file open and read stripes on
                demand.

        Raises:
            Exception: if the volume description, the dataset or an image
                does not fit the expectations.
        """

        sampleHDF = sampleName.split(':')
        if len(sampleHDF) != 2 :
            raise Exception(f"String \"{sampleName}\" does not represent an HDF5 format.")

        self.volume = None
        self.data = None
        self._h5file = None
        trgH5F = h5py.File(sampleHDF[0],'r')
        try :
            if  sampleHDF[1] not in trgH5F :
                raise Exception(f"No dataset '{sampleHDF[1]}' in input file {sampleHDF[0]}.")
            dataset = trgH5F[sampleHDF[1]]
            if not dataset.size :
                raise Exception(f"Container \"{sampleName}\" is zero size.")
            self.sh = dataset.shape
            if len(self.sh) != 3 :
                raise Exception(f"Dimensions of the container \"{sampleName}\" is not 3 {self.sh}.")
            self.fsh = self.sh[1:3]
            if loadToMem :
                self.volume = np.empty(self.sh, dtype=np.float32)
                dataset.read_direct(self.volume)

            def loadImage(imageName) :
                # Read an image as a float32 frame; colour images are
                # averaged over their first three channels.
                if not imageName:
                    return None
                imdata = imread(imageName).astype(np.float32)
                if len(imdata.shape) == 3 :
                    imdata = np.mean(imdata[:,:,0:3], 2)
                #imdata = imdata.transpose()
                if imdata.shape != self.fsh :
                    raise Exception(f"Dimensions of the input image \"{imageName}\" {imdata.shape} "
                                    f"do not match the face of the container \"{sampleName}\" {self.fsh}.")
                return imdata

            self.mask = loadImage(maskName)
            if self.mask is None :
                self.mask = np.ones(self.fsh, dtype=np.uint8)
            self.mask = self.mask.astype(bool)
            self.bg = loadImage(bgName)
            self.df = loadImage(dfName)
            if self.bg is not None :
                if self.df is not None:
                    self.bg -= self.df
                # Pixels without background signal cannot be normalised.
                self.mask  &=  self.bg > 0.0

            if self.volume is not None :
                # Normalise the masked pixels of all frames at once, in
                # place; the 2D images and the mask broadcast over frames.
                if self.df is not None :
                    np.subtract(self.volume, self.df, out=self.volume, where=self.mask)
                if self.bg is not None :
                    np.divide(self.volume, self.bg, out=self.volume, where=self.mask)

            self.allIndices = self._findStripes()

        except BaseException :
            trgH5F.close()
            raise

        if loadToMem :
            trgH5F.close()
        else :
            # The dataset is only readable while its file is open: keep the
            # file until close() is called.
            self.data = dataset
            self._h5file = trgH5F

    def _findStripes(self) :
        # List the (y, x) left ends of the stripes whose pixels are all valid.
        # NOTE(review): as in the original loop, a stripe must satisfy
        # x + sinoWid < width, so the stripe touching the right edge is
        # never used - confirm whether "<=" was meant.
        sinoWid = DatasetConfig.sinoWid
        height, width = self.fsh
        nofStarts = max(0, width - sinoWid)
        if not nofStarts :
            return []
        # filled[y, x] is the number of valid pixels in mask[y, :x].
        filled = np.zeros((height, width+1), dtype=np.int64)
        filled[:,1:] = np.cumsum(self.mask, axis=1, dtype=np.int64)
        # Number of valid pixels among x+1 ... x+sinoWid-1 for every start.
        rest = filled[:, sinoWid:sinoWid+nofStarts] - filled[:, 1:nofStarts+1]
        full = self.mask[:, :nofStarts] & ( rest == sinoWid - 1 )
        return [ (int(yCr), int(xCr)) for yCr, xCr in np.argwhere(full) ]

    def close(self) :
        """Close the HDF5 file kept open when the data are read on demand."""
        if self._h5file is not None :
            self._h5file.close()
            self._h5file = None
            self.data = None

    def get_dataset(self, transform=None) :
        """Return the dataset of stripes.

        Args:
            transform: optional transform applied after the conversion of
                a stripe to a tensor.
        """
        return StripesFromHDF.Sinos(self, transform)


    def get_data_loader(self, batch_size, shuffle=None, num_workers=os.cpu_count(), transform=None ) :
        """Return a DataLoader over the dataset of stripes.

        Args:
            batch_size: number of stripes in a batch.
            shuffle: passed to the DataLoader.
            num_workers: number of loading processes; ``None`` (which
                ``os.cpu_count()`` may return) is treated as 0.
            transform: optional transform passed to get_dataset().
        """
        return torch.utils.data.DataLoader( self.get_dataset(transform),
                                            batch_size=batch_size,
                                            num_workers=num_workers or 0,
                                            shuffle = shuffle)

# Container of the training data: sample volume, mask and background image;
# no dark field image is given.
# NOTE(review): absolute paths are hard-coded for one particular machine.
sinoRoot = StripesFromHDF("/mnt/ssdData/4176862R_Eig_Threshold-4keV/input/SAMPLE_Y0.hdf:/entry/data/data",
                          "/mnt/ssdData/4176862R_Eig_Threshold-4keV/input/mask.tif",
                          "/mnt/ssdData/4176862R_Eig_Threshold-4keV/output/bgo.tif",
                          None)



In [4]:
# Every stripe is resized to (sinoLen, sinoWid) and shifted/scaled with
# mean 0.5 and std 0.5.
# NOTE: `(0.5)` is a plain float, not a one-element tuple.
dataTransform =  transforms.Compose([
    transforms.Resize((DatasetConfig.sinoLen, DatasetConfig.sinoWid)),
    transforms.Normalize(mean=(0.5), std=(0.5))
])

# Shuffled training batches, loaded by one worker per CPU reported by the OS.
trainSet = sinoRoot.get_dataset(dataTransform)
trainLoader = torch.utils.data.DataLoader(
    dataset=trainSet,
    batch_size=TrainingConfig.batchSize,
    shuffle=True,
    num_workers=os.cpu_count()
)




In [5]:
#dataTransform =  transforms.Compose([
#    transforms.Resize((DatasetConfig.sinoLen, DatasetConfig.sinoWid)),
#    #transforms.Resize((500, 500)),
#    transforms.Normalize(mean=(0.5), std=(0.5))
#])
#
#testSet = sinoRoot.get_dataset(dataTransform)
#randIdx = random.randint(0,len(testSet)-1)
#print(randIdx, sinoRoot.allIndices[randIdx])
#image = testSet[randIdx].squeeze().transpose(0,1)
#plt.imshow(image, cmap='gray')
#plt.axis("off")
##tifffile.imwrite("tmp.tif", image)


## <font style="color:blue">Save/Load model</font>

In [6]:
def save_model(model, device, model_path):
    """Save the state dict of the model with all tensors on the CPU.

    A model living on another device is moved to the CPU for saving and
    moved back afterwards, even if saving fails.

    Args:
        model: the module to save.
        device: device the model lives on (string or ``torch.device``).
        model_path: destination file.
    """
    # torch.device() accepts both strings and device objects, so
    # torch.device('cpu') is recognised as the CPU as well as 'cpu'.
    onCpu = torch.device(device).type == 'cpu'
    if not onCpu:
        model.to('cpu')
    try:
        torch.save(model.state_dict(), model_path)
    finally:
        # Do not leave the model on the CPU if saving has failed.
        if not onCpu:
            model.to(device)
    return

def load_model(model, model_path):
    """Load the state dict stored in a file into the model.

    Args:
        model: module whose parameters are overwritten.
        model_path: file written by ``torch.save`` with a state dict.

    Returns:
        The same model object.
    """
    # Map the stored tensors to the CPU first: a checkpoint holding CUDA
    # tensors can then be read on a host without that device;
    # load_state_dict() copies the values to wherever the model lives.
    # NOTE: torch.load() unpickles the file - only load trusted checkpoints.
    state = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state)
    return model

## <font style="color:blue">Models</font>

In [9]:


class Generator(nn.Module):
    """Fills the middle third (the gap) of an image.

    The image is cut along its width into three stripes of equal width,
    which are presented to the convolutional body as three channels. The
    body returns a single stripe, which replaces the middle one in the
    output; the side stripes are passed through unchanged.
    """

    def __init__(self):
        super().__init__()

        self.body = nn.Sequential(

            nn.Conv2d(3, 64, (5,3), bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, 3, stride=(2,1), bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, 3, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, 3, stride=(2,1), bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, 3, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, 3, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, 3, dilation=2, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, 3, dilation=4, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, 3, dilation=8, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, 3, dilation=16, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, 3, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, 3, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(256, 128, 4, stride=(2,1), bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, 3, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.ConvTranspose2d(128, 64, 4, stride=(2,1), bias=False),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 16, 3, bias=False, padding='same'),
            nn.LeakyReLU(0.2),

            nn.Conv2d(16, 1, 3, bias=False),
            nn.Tanh()

        )
        torch.nn.init.xavier_uniform_(self.body[0].weight)


    def forward(self, input):
        """Return a copy of the images with the gap replaced.

        Args:
            input: images of shape (batch, 1, length, 3*gap). The body
                restores the input size only for lengths which are a
                multiple of 4. Input of any other layout is reshaped to
                (batch, 1, DatasetConfig.sinoLen, DatasetConfig.sinoWid).

        Returns:
            Tensor of the same (batch, 1, length, 3*gap) shape. The input
            tensor is not modified.
        """
        if input.dim() == 4 and input.shape[1] == 1 and input.shape[3] % 3 == 0 :
            images = input
        else :
            images = input.reshape(input.shape[0], 1, DatasetConfig.sinoLen, DatasetConfig.sinoWid)
        gapWidth = images.shape[3] // 3

        # Cut the width into left | gap | right. A view() to three channels
        # would only reinterpret the memory: with the width being the
        # fastest axis it splits the image by rows, not into the stripes.
        left, gap, right = torch.split(images, gapWidth, dim=3)
        conv = self.body( torch.cat((left, gap, right), dim=1) )

        # Assemble a new tensor instead of writing into the input: the
        # in-place write invalidates the tensor saved by the first
        # convolution for the backward pass and alters the caller's data.
        return torch.cat((left, conv, right), dim=3)



#if model is not None:
#    del model
#    gc.collect()
#    torch.cuda.empty_cache()
# Instantiate the generator and print its structure and the torchinfo
# summary for a single one-channel image of size (sinoLen, sinoWid).
generator = Generator()
#model = load_model(model, "/home/imbl/usr/src/ReMuse/experiments/e0134_model.pt")
print(generator)
model_summary = summary(generator, (1,1,DatasetConfig.sinoLen,DatasetConfig.sinoWid) ).__str__()
print(model_summary)




Generator(
  (body): Sequential(
    (0): Conv2d(3, 64, kernel_size=(5, 3), stride=(1, 1), padding=same, bias=False)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 1), bias=False)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 1), bias=False)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
    (9): LeakyReLU(negative_slope=0.2)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
    (11): LeakyReLU(negative_slope=0.2)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(2, 2), bias=False)
    (13): LeakyReLU(negative_slope=0.2)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=same, dilation=(4, 4), bia

In [10]:

class Discriminator(nn.Module):
    """Convolutional classifier returning one sigmoid score per image.

    The body is a chain of bias-free convolutions, each followed by a
    LeakyReLU(0.2); the head flattens the result, applies dropout and maps
    the 1024 features to a single value in (0, 1).
    """

    # (input channels, output channels, kernel, stride) of the body layers.
    _CONVOLUTIONS = (
        (  1,  64, (5,3), (2,1)),
        ( 64, 128, (5,3), (2,1)),
        (128, 256, (5,3), (2,1)),
        (256, 512,     5, (2,1)),
        (512, 512,     5, (2,1)),
        (512, 512,     5, (2,1)),
        (512, 512,     5, (2,1)),
        (512, 512,     5, (2,1)),
        (512, 512,     5, (2,1)),
        (512, 512,     5,     1),
    )

    def __init__(self):
        super().__init__()

        layers = []
        for chIn, chOut, kernel, stride in self._CONVOLUTIONS :
            layers.append( nn.Conv2d(chIn, chOut, kernel, stride=stride, bias=False) )
            layers.append( nn.LeakyReLU(0.2) )
        self.body = nn.Sequential(*layers)
        # Only the first convolution gets the Xavier initialisation.
        torch.nn.init.xavier_uniform_(self.body[0].weight)

        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )


    def forward(self, input):
        """Return the score of every image in the batch, shape (batch, 1)."""
        return self.head( self.body(input) )


# Instantiate the discriminator and print its structure and the torchinfo
# summary for a single one-channel image of size (sinoLen, sinoWid).
discriminator = Discriminator()
#model = load_model(model, "/home/imbl/usr/src/ReMuse/experiments/e0134_model.pt")
print(discriminator)
model_summary = summary(discriminator, (1,1,DatasetConfig.sinoLen,DatasetConfig.sinoWid) ).__str__()
print(model_summary)

Discriminator(
  (body): Sequential(
    (0): Conv2d(1, 64, kernel_size=(5, 3), stride=(2, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(5, 3), stride=(2, 1), bias=False)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Conv2d(128, 256, kernel_size=(5, 3), stride=(2, 1), bias=False)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 1), bias=False)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(512, 512, kernel_size=(5, 5), stride=(2, 1), bias=False)
    (9): LeakyReLU(negative_slope=0.2)
    (10): Conv2d(512, 512, kernel_size=(5, 5), stride=(2, 1), bias=False)
    (11): LeakyReLU(negative_slope=0.2)
    (12): Conv2d(512, 512, kernel_size=(5, 5), stride=(2, 1), bias=False)
    (13): LeakyReLU(negative_slope=0.2)
    (14): Conv2d(512, 512, kernel_size=(5, 5), stride=(2, 1), bias=False)
    (15): LeakyReLU(negative_slope=0.2)
    (16): Conv2d(512, 512, kernel_size=(5, 5), stride=(2, 1), bia

## <font style="color:blue">Metrics</font>

In [37]:
# Binary cross entropy between predicted probabilities and target labels.
BCE = nn.BCELoss()

def loss_func(y_true, y_pred):
    """Return the binary cross entropy of the predictions.

    Note the order of the arguments: the targets come first here, while
    the loss module expects the predictions first.
    """
    return BCE(input=y_pred, target=y_true)


## <font style="color:blue">Optimizers</font>

In [42]:

# Separate Adam optimizers for the generator and the discriminator, each
# with its own learning rate from TrainingConfig and betas of (0.5, 0.999).
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=TrainingConfig.learningRateG,
    betas=(0.5, 0.999)
)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=TrainingConfig.learningRateD,
    betas=(0.5, 0.999)
)

## <font style="color:blue">Train step</font>

In [11]:
def train_step(images):
    """Perform one adversarial training step on a batch of real images.

    The gap (middle third of the width) of a copy of the images is filled
    with noise and given to the generator. The discriminator is then
    updated to tell the real images from the generated ones, and the
    generator is updated to make the discriminator accept its output.

    Args:
        images: batch of real images with the width as the last axis.
            Both models are expected to live on TrainingConfig.device.

    Returns:
        Tuple (D_loss, G_loss) of detached scalar tensors.
    """

    device = TrainingConfig.device
    real = images.to(device)
    # The last batch of an epoch may be smaller than the configured size.
    nofIm = real.shape[0]
    gap = slice(DatasetConfig.gapWidth, 2*DatasetConfig.gapWidth)

    # Work on a copy: the real images are still needed by the discriminator.
    holed = real.clone()
    holed[..., gap] = torch.randn_like(holed[..., gap])
    fake = generator(holed)

    # Discriminator update. The generated images are detached so that this
    # backward pass neither reaches nor frees the graph of the generator.
    optimizer_D.zero_grad()
    y_pred_real = discriminator(real)
    y_pred_fake = discriminator(fake.detach())
    y_pred_both = torch.cat((y_pred_real, y_pred_fake), dim=0)
    labels = torch.cat((
        torch.full((nofIm, 1), TrainingConfig.labelSmoothFac, device=device), # Labels for real data
        torch.zeros(nofIm, 1, device=device), # Labels for fake data
    ), dim=0)
    D_loss = loss_func(labels, y_pred_both)
    D_loss.backward()
    optimizer_D.step()

    # Generator update. The discriminator has just been changed in place,
    # so the generated images are scored again to get a valid graph.
    optimizer_G.zero_grad()
    y_pred_fake = discriminator(fake)
    labels = torch.ones(nofIm, 1, device=device)
    G_loss = loss_func(labels, y_pred_fake)
    G_loss.backward()
    optimizer_G.step()

    return D_loss.detach(), G_loss.detach()