## <font style="color:lightblue">Header</font>

### <font style="color:lightblue">Imports</font>

In [1]:
#%load_ext autoreload
#%autoreload 2

import math
import numpy as np
import torch
import torch.nn as nn
from torchinfo import summary
import itertools
import random

import sinogap_module_alt as sg




### <font style="color:lightblue">Redefine</font>

In [2]:
sg.plt.rcParams['figure.dpi']=223





### <font style="color:lightblue">Configs</font>

In [3]:
sg.set_seed(7)

sg.TCfg = sg.TCfgClass(
     exec = 1
    ,nofEpochs = None
    ,latentDim = 64
    ,batchSize = 2**5
    ,batchSplit = 1
    ,labelSmoothFac = 0.1 # For Fake labels (or set to 0.0 for no smoothing).
    ,learningRateD = 0.0001
    ,learningRateG = 0.0001
)

sg.DCfg = sg.DCfgClass(16)


### <font style="color:lightblue">Raw Read</font>

In [4]:
trainSet = sg.createTrainSet()
#testSet = sg.createTestSet()

Loading train set 1 of 9: 18515.Lamb1_Eiger_7m_45keV_360Scan ... Done
Loading train set 2 of 9: 18692a.ExpChicken6mGyShift ... Done
Loading train set 3 of 9: 18692b_input_PhantomM ... Done
Loading train set 4 of 9: 18692b.MinceO ... Done
Loading train set 5 of 9: 19022g.11-EggLard ... Done
Loading train set 6 of 9: 19736b.09_Feb.4176862R_Eig_Threshold-4keV ... Done
Loading train set 7 of 9: 19736c.8733147R_Eig_Threshold-8keV.SAMPLE_Y1 ... Done
Loading train set 8 of 9: 20982b.04_774784R ... Done
Loading train set 9 of 9: 23574.8965435L.Eiger.32kev_org ... Done


### <font style="color:lightblue">Show</font>

In [5]:
#sg.refImages, sg.refNoises = sg.createReferences(testSet, 1)
#sg.showMe(testSet, 0 )

## <font style="color:lightblue">Models</font>

### Generator 2pix

In [6]:


class Generator2(sg.GeneratorTemplate):

    def __init__(self):
        super(Generator2, self).__init__(2)
        self.amplitude = 4
        self.encoders =  nn.ModuleList([
            self.encblock(  1/self.baseChannels,
                               1, 3, padding=1, norm=False),
            self.encblock(  1, 1, 3, padding=1),
            self.encblock(  1, 1, 3, stride=(2,1), padding=(1,0)),
            self.encblock(  1, 1, 3, padding=1),
            self.encblock(  1, 1, 3, stride=(2,1), padding=(1,0)),
            self.encblock(  1, 1, 3, padding=1),
            self.encblock(  1, 1, 3, stride=(2,1), padding=(1,0)),
            #self.encblock(  1, 1, 3, padding=1),
            #self.encblock(  1, 1, 3, stride=(2,1), padding=(1,0)),
            ])
        self.fcLink = self.createFClink()
        self.decoders = nn.ModuleList([
            #self.decblock(2, 1, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            #self.decblock(2, 1, 3, padding=1),
            self.decblock(2, 1, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(2, 1, 3, padding=1),
            self.decblock(2, 1, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(2, 1, 3, padding=1),
            self.decblock(2, 1, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(2, 1, 3, padding=1),
            self.decblock(2, 1, 3, padding=1, norm=False),
            ])
        self.lastTouch = self.createLastTouch()
        #sg.load_model(self, model_path="saves/gap2/noBNreNorm_SSIM/model_gen.pt" )

#generator2 = Generator2()
#generator2 = generator2.to(sg.TCfg.device)
#generator2 = generator2.requires_grad_(False)
#generator2 = generator2.eval()
#sg.lowResGenerators[2] = generator2
#
#input_data=[ (torch.randn( (1,1,*generator2.sinoSh), device=sg.TCfg.device),
#              torch.randn( (1,sg.TCfg.latentDim), device=sg.TCfg.device)) ]
#model_summary = summary(generator2, input_data=input_data ).__str__()
#print(model_summary)




### Generator 4pix

In [7]:


class Generator4(sg.GeneratorTemplate):

    def __init__(self):
        super(Generator4, self).__init__(4)
        self.amplitude = 4
        self.encoders =  nn.ModuleList([
            self.encblock( 1/self.baseChannels,
                               1, 3, padding=1, norm=False),
            self.encblock(  1, 1, 3, padding=1),
            self.encblock(  1, 2, 3, stride=(2,2), padding=(1,1)),
            self.encblock(  2, 2, 3, padding=1),
            self.encblock(  2, 2, 3, stride=(2,1), padding=(1,0)),
            self.encblock(  2, 2, 3, padding=1),
            self.encblock(  2, 2, 3, stride=(2,1), padding=(1,0)),
            self.encblock(  2, 2, 3, padding=1),
            self.encblock(  2, 2, 3, stride=(2,1), padding=(1,0)),
            ])
        self.fcLink = self.createFClink()
        self.decoders = nn.ModuleList([
            self.decblock(4, 2, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(4, 2, 3, padding=1),
            self.decblock(4, 2, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(4, 2, 3, padding=1),
            self.decblock(4, 2, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(4, 2, 3, padding=1),
            self.decblock(4, 1, 3, stride=(2,2), outputPadding=(1,1), padding=(1,1)),
            self.decblock(2, 1, 3, padding=1),
            self.decblock(2, 1, 3, padding=1, norm=False),
            ])
        self.lastTouch = self.createLastTouch()
        self.lowResGenerator = Generator2()
        #sg.load_model(self, model_path="saves/gap4/noBNreNorm_SSIM/model_gen.pt" )

#generator4 = Generator4()
#generator4 = generator4.to(sg.TCfg.device)
#generator4 = generator4.requires_grad_(False)
#generator4 = generator4.eval()
#sg.lowResGenerators[4] = generator4
#
#input_data=[ (torch.randn( (1,1,*generator4.sinoSh), device=sg.TCfg.device),
#              torch.randn( (1,sg.TCfg.latentDim), device=sg.TCfg.device)) ]
#model_summary = summary(generator4, input_data=input_data ).__str__()
#print(model_summary)




### Generator 8pix

In [8]:


class Generator8(sg.GeneratorTemplate):

    def __init__(self):
        super(Generator8, self).__init__(8)
        self.amplitude = 4

        self.encoders =  nn.ModuleList([
            self.encblock( 1/self.baseChannels,
                               1, 3, padding=1, norm=False),
            self.encblock(  1, 1, 3, padding=1),
            self.encblock(  1, 2, 3, stride=(2,2), padding=(1,1)),
            self.encblock(  2, 2, 3, padding=1),
            self.encblock(  2, 4, 3, stride=(2,2), padding=(1,1)),
            self.encblock(  4, 4, 3, padding=1),
            self.encblock(  4, 4, 3, stride=(2,1), padding=(1,0)),
            self.encblock(  4, 4, 3, padding=1),
            self.encblock(  4, 4, 3, stride=(2,1), padding=(1,0)),
            self.encblock(  4, 4, 3, padding=1),
            self.encblock(  4, 4, 3, stride=(2,1), padding=(1,0)),
            ])

        self.fcLink = self.createFClink()

        self.decoders = nn.ModuleList([
            self.decblock(8, 4, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(8, 4, 3, padding=1),
            self.decblock(8, 4, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(8, 4, 3, padding=1),
            self.decblock(8, 4, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(8, 4, 3, padding=1),
            self.decblock(8, 2, 3, stride=(2,2), outputPadding=(1,1), padding=(1,1)),
            self.decblock(4, 2, 3, padding=1),
            self.decblock(4, 1, 3, stride=(2,2), outputPadding=(1,1), padding=(1,1)),
            self.decblock(2, 1, 3, padding=1),
            self.decblock(2, 1, 3, padding=1, norm=False),
            ])

        self.lastTouch = self.createLastTouch()
        self.lowResGenerator = Generator4()
        #sg.load_model(self, model_path="saves/gap8/noBNreNorm_SSIM/model_gen.pt" )


#generator8 = Generator8()
#generator8 = generator8.to(sg.TCfg.device)
#generator8 = generator8.requires_grad_(False)
#generator8 = generator8.eval()
#sg.lowResGenerators[8] = generator8
#
#input_data=[ (torch.randn( (1,1,*generator8.sinoSh), device=sg.TCfg.device),
#              torch.randn( (1,sg.TCfg.latentDim), device=sg.TCfg.device)) ]
#model_summary = summary(generator8, input_data=input_data ).__str__()
#print(model_summary)


### Generator 16pix

In [9]:


class Generator16(sg.GeneratorTemplate):

    def __init__(self):
        super(Generator16, self).__init__(16,1)
        self.amplitude = 4

        self.noise2latent = self.createLatent()

        self.encoders =  nn.ModuleList([
            self.encblock( (1+self.latentChannels)/self.baseChannels,
                               1, 3, padding=1, norm=False),
            self.encblock(  1, 1, 3, padding=1),
            self.encblock(  1, 2, 3, stride=(2,2), padding=(1,1)),
            self.encblock(  2, 2, 3, padding=1),
            self.encblock(  2, 4, 3, stride=(2,2), padding=(1,1)),
            self.encblock(  4, 4, 3, padding=1),
            self.encblock(  4, 8, 3, stride=(2,2), padding=(1,1)),
            self.encblock(  8, 8, 3, padding=1),
            self.encblock(  8, 8, 3, stride=(2,1), padding=(1,0)),
            self.encblock(  8, 8, 3, padding=1),
            self.encblock(  8, 8, 3, stride=(2,1), padding=(1,0)),
            self.encblock(  8, 8, 3, padding=1),
            self.encblock(  8, 8, 3, stride=(2,1), padding=(1,0)),
            ])

        self.fcLink = self.createFClink()

        self.decoders = nn.ModuleList([
            self.decblock(16, 8, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(16, 8, 3, padding=1),
            self.decblock(16, 8, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(16, 8, 3, padding=1),
            self.decblock(16, 8, 3, stride=(2,1), outputPadding=(1,0), padding=(1,0)),
            self.decblock(16, 8, 3, padding=1),
            self.decblock(16, 4, 3, stride=(2,2), outputPadding=(1,1), padding=(1,1)),
            self.decblock( 8, 4, 3, padding=1),
            self.decblock( 8, 2, 3, stride=(2,2), outputPadding=(1,1), padding=(1,1)),
            self.decblock( 4, 2, 3, padding=1),
            self.decblock( 4, 1, 3, stride=(2,2), outputPadding=(1,1), padding=(1,1)),
            self.decblock( 2, 1, 3, padding=1),
            self.decblock( 2, 1, 3, padding=1, norm=False),
            ])

        self.lowResGenerator = Generator8()
        self.lastTouch = self.createLastTouch()
        #sg.load_model(self, model_path="saves/gap16/noBNreNorm_SSIM/model_gen.pt" )

generator16 = Generator16()
generator16 = generator16.to(sg.TCfg.device)
sg.lowResGenerators[16] = generator16
#
#input_data=[ (torch.randn( (1,1,*generator16.sinoSh), device=sg.TCfg.device),
#              torch.randn( (1,sg.TCfg.latentDim), device=sg.TCfg.device)) ]
#model_summary = summary(generator16, input_data=input_data ).__str__()
#print(model_summary)


### <font style="color:lightblue">Generator</font>

In [10]:
sg.generator = sg.lowResGenerators[sg.DCfg.gapW]
sg.optimizer_G = sg.createOptimizer(sg.generator, sg.TCfg.learningRateG)
input_data=[ (torch.randn( (1,1,*sg.generator.sinoSh), device=sg.TCfg.device),
              torch.randn( (1,sg.TCfg.latentDim), device=sg.TCfg.device)) ]
#input_data=[ [sg.refImages[[0],...], sg.refNoises[[0],...]] ]
model_summary = summary(sg.generator, input_data=input_data ).__str__()
print(model_summary)


Layer (type:depth-idx)                                  Output Shape              Param #
Generator16                                             [1, 1, 4096, 16]          --
├─Generator8: 1-1                                       --                        --
│    └─Generator4: 2-1                                  --                        --
│    │    └─Generator2: 3-1                             --                        2,102,178
│    │    └─ModuleList: 3-2                             --                        3,988
│    │    └─Sequential: 3-3                             [1, 8, 64, 4]             8,392,704
│    │    └─ModuleList: 3-4                             --                        8,124
│    │    └─Sequential: 3-5                             [1, 1, 1024, 20]          6
│    └─ModuleList: 2-2                                  --                        --
│    │    └─Sequential: 3-6                             [1, 4, 2048, 40]          40
│    │    └─Sequential: 3-7              

### <font style="color:lightblue">Discriminator</font>

In [11]:

class Discriminator(sg.DiscriminatorTemplate):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.param = nn.Parameter(torch.zeros(1))
    def forward(self, images):
        return torch.zeros((images.shape[0],1), device=sg.TCfg.device)


sg.discriminator = Discriminator()
sg.discriminator = sg.discriminator.to(sg.TCfg.device)
#model_summary = summary(sg.discriminator, input_data=sg.refImages[0,...] ).__str__()
model_summary = summary(sg.discriminator, input_data=torch.randn( (1,1,*sg.generator.sinoSh), device=sg.TCfg.device) ).__str__()
#print(model_summary)
#sg.writer.add_graph(sg.discriminator, refImages)

sg.optimizer_D = sg.createOptimizer(sg.discriminator, sg.TCfg.learningRateD)



## <font style="color:lightblue">Restore checkpoint</font>

In [12]:
sg.noAdv = True
#sg.dataLoader = sg.createDataLoader(trainSet, num_workers=16)
#sg.testLoader = sg.createDataLoader(testSet , num_workers=16)

#sg.normRec, sg.normMSE, sg.normL1L = sg.summarizeSet(sg.dataLoader)[0:3]
#sg.normTestRec, sg.normTestMSE, sg.normTestL1L, = sg.summarizeSet(sg.testLoader)[0:3]
sg.normRec, sg.normMSE, sg.normL1L = 3 * 4.021e-03, 6.625e-03, 5.5e-02 #sg.summarizeSet(sg.dataLoader)[0:3]
sg.normTestRec, sg.normTestMSE, sg.normTestL1L, = 2 * 4.846e-03, 1.370e-03, 3.605e-02 # sg.summarizeSet(sg.testLoader)[0:3]
sg.normSSIM = sg.normL1L
sg.normTestSSIM = sg.normTestL1L
print((sg.normRec, sg.normMSE, sg.normL1L))
print((sg.normTestRec, sg.normTestMSE, sg.normTestL1L))

(0.012063, 0.006625, 0.055)
(0.009692, 0.00137, 0.03605)


In [13]:
#sg.scheduler_G = torch.optim.lr_scheduler.StepLR(sg.optimizer_G, 1, gamma=1-0.001)
#sg.scheduler_D = torch.optim.lr_scheduler.StepLR(sg.optimizer_D, 1, gamma=1-0.001)
savedCheckPoint = f"checkPoint_{sg.TCfg.exec}"
sg.epoch, sg.imer, sg.minGEpoch, sg.minGdLoss, sg.startFrom, sg.resAcc = \
    sg.restoreCheckpoint()#savedCheckPoint+".pth")
sg.writer = sg.createWriter(sg.TCfg.logDir, True)
#sg.writer.add_graph(sg.generator, ((sg.refImages, sg.refNoises),) )
#sg.writer.add_graph(sg.discriminator, refImages)
#sg.minGdLoss = 100
#sg.epoch, sg.imer, sg.minGEpoch, sg.minGdLoss, sg.startFrom = 0, 0, 0, 1, 0
#print(sg.epoch, sg.imer, sg.minGEpoch, sg.minGdLoss, sg.scheduler_D.get_last_lr()[0], sg.startFrom)
#lastLR = sg.scheduler_D.get_last_lr()[0]
#initialLR = sg.TCfg.learningRateD
#print(f"Initial LR : {lastLR} {lastLR/initialLR:.3f}")
#sg.initialTest()


## <font style="color:lightblue">Execute</font>

In [16]:
sg.noAdv = True
sg.dataLoader = sg.createDataLoader(trainSet, num_workers=16)
#sg.testLoader = sg.createDataLoader(testSet , num_workers=16)
#torch.autograd.set_detect_anomaly(True)

torch.optim.lr_scheduler.LambdaLR(sg.optimizer_G, lambda epoch: 0.1).step()

#
#def my_afterEachEpoch(epoch) :
#    if sg.minGEpoch < 600 :
#        return
#    if not sg.dataLoader is None :
#        del sg.dataLoader
#        sg.freeGPUmem()
#    if sg.TCfg.batchSize < 131072 :
#    sg.TCfg.batchSize += round( 0.01 * sg.TCfg.batchSize )
#    sg.dataLoader = sg.createTrainLoader(trainSet, num_workers=24)
#    print("Batch size: ",sg.TCfg.batchSize)
#sg.afterEachEpoch = my_afterEachEpoch

#def my_beforeReport() :
#    sg.generator.amplitude = max(4, sg.generator.amplitude * (1-0.0005) )
#    print(f"AMPL : {sg.generator.amplitude}")
#    with open(f"message_{sg.TCfg.exec}.txt", 'a') as file:
#        file.write(f"sg.generator.amplitude: {sg.generator.amplitude}\n")
#    return
#sg.beforeReport = my_beforeReport

sg.SSIM_MSE = 1
sg.ADV_DIF = 0
def my_beforeReport() :
    lastLR = sg.scheduler_D.get_last_lr()[0]
    print(f"LR : {lastLR} {lastLR/sg.TCfg.learningRateD:.3f}")
    if lastLR  >  0.2 * sg.TCfg.learningRateD :
        if sg.scheduler_G is not None :
            sg.scheduler_G.step()
        if sg.scheduler_D is not None :
            sg.scheduler_D.step()
    #sg.generator.amplitude = min(4, sg.generator.amplitude * (1+0.0005) )
    #print(f"AMPL : {sg.generator.amplitude}")
    #with open(f"message_{sg.TCfg.exec}.txt", 'a') as file:
    #    file.write(f"sg.generator.amplitude: {sg.generator.amplitude}\n")
    #message = f" SSIM/MSE : {sg.SSIM_MSE:.3f}"
#    #print(message)
#    #if sg.SSIM_MSE < 0.99 :
#    #    sg.SSIM_MSE += 1e-3
#    #    with open(f"message_{sg.TCfg.exec}.txt", 'a') as file:
#    #        file.write(message + "\n")
    return
sg.beforeReport = my_beforeReport

sg.noAdv = True

try :
    sg.train(savedCheckPoint)
except :
    del sg.dataLoader
    #del sg.testLoader
    sg.freeGPUmem()
    1/10 # to release Jupyuter memory in the next step
    sg.epoch -= 1
    raise



  0%|          | 148/274647 [00:59<30:50:51,  2.47it/s]


AttributeError: module 'sinogap_module_alt' has no attribute 'scheduler_D'

## <font style="color:lightblue">Post</font>

In [None]:
print (sg.generator.amplitude.item(), 2 * torch.sigmoid(sg.generator.amplitude).item() )
sg.initialTest()


In [None]:
sg.testMe(trainSet, 5)

### <font style="color:lightblue">Save results</font>

In [None]:
sg.saveModels()