In [1]:
import torch
from Models import Params,SepVAE,CrossAttention, SepVAEEncoder
import torch.nn as nn
import torchinfo
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Smoke-test CrossAttention with two random 4-D inputs whose dim-1 sizes
# (17 and 35) match the constructor arguments and whose trailing sizes
# differ (16x16 vs 17x17).

crs = CrossAttention(17, 35)
input1 = torch.randn(3, 17, 16, 16)
input2 = torch.randn(3, 35, 17, 17)

# Only the forward result is inspected, so autograd tracking is disabled.
with torch.no_grad():
    a = crs(input1, input2)

In [3]:
# Print the shapes of the first two elements returned by the cross-attention call.
print(a[0].shape, a[1].shape)

torch.Size([3, 17, 16, 16]) torch.Size([3, 35, 17, 17])


In [18]:
# Split a[0] into 2 chunks along dim 2 and rebind `a` to the resulting tuple.
# NOTE(review): this overwrites the previous `a`, so re-running this cell
# splits an already-split chunk instead of the original tensor.
a = torch.chunk(a[0],2,2)


In [19]:
# Display the shape of the first element of `a`.
a[0].shape

torch.Size([3, 17, 8, 16])

In [6]:
# Random 4-D batch of size (3, 3, 128, 128) used to exercise the model below.
dummy = torch.randn(3, 3, 128, 128)

# Default Params with `useconv` switched off, then the SepVAE built from them.
params = Params()
params.useconv = False
model = SepVAE(params)

In [4]:
# Print torchinfo's per-layer parameter-count summary for `model`.
# NOTE(review): the recorded execution counts around this cell are not
# sequential, so the saved output may not reflect the current `model`.
print(torchinfo.summary(model))

Layer (type:depth-idx)                                       Param #
SepVAE                                                       --
├─SepVAEEncoder: 1-1                                         --
│    └─Conv2d: 2-1                                           448
│    └─ModuleList: 2-2                                       --
│    │    └─ModuleList: 3-1                                  13,984
│    │    └─ModuleList: 3-2                                  15,024
│    │    └─ModuleList: 3-3                                  43,360
│    │    └─ModuleList: 3-4                                  139,968
│    └─ModuleList: 2-3                                       --
│    │    └─ModuleList: 3-5                                  50,848
│    │    └─ModuleList: 3-6                                  35,504
│    │    └─ModuleList: 3-7                                  76,128
│    │    └─ModuleList: 3-8                                  139,968
│    └─CrossAttention: 2-4                                   --


In [7]:
# Run the model on the dummy batch with autograd disabled.
# `a` is rebound here to whatever the model returns.
with torch.no_grad():
    a = model(dummy)

In [8]:
# Evaluate the ELBO from the first four model outputs plus the last one,
# without tracking gradients.
with torch.no_grad():
    b = model.elbo(*a[:4], a[-1])

In [9]:
# Display the value returned by model.elbo for the dummy batch.
b

tensor(49673.2734)

513.0

In [4]:
# Print the torchinfo summary of `model` again.
# NOTE(review): same statement as an earlier cell, but the saved output
# differs, so it was presumably captured against a different `model` — confirm.
print(torchinfo.summary(model))

Layer (type:depth-idx)                                  Param #
SepVAE                                                  --
├─Conv2d: 1-1                                           448
├─ModuleList: 1-2                                       --
│    └─ModuleList: 2-1                                  --
│    │    └─ResnetBlock: 3-1                            4,704
│    │    └─Residual: 3-2                               8,240
│    │    └─Conv2d: 3-3                                 1,040
│    └─ModuleList: 2-2                                  --
│    │    └─ResnetBlock: 3-4                            4,704
│    │    └─Residual: 3-5                               8,240
│    │    └─Conv2d: 3-6                                 2,080
│    └─ModuleList: 2-3                                  --
│    │    └─ResnetBlock: 3-7                            18,624
│    │    └─Residual: 3-8                               16,480
│    │    └─Conv2d: 3-9                                 18,496
├─ModuleList: 1-3   

In [5]:
def Downsample(dim, dim_out=None):
    """Build a 2x2, stride-2 ``nn.Conv2d`` going from ``dim`` channels to ``dim_out``.

    Args:
        dim: number of input channels.
        dim_out: number of output channels; ``None`` means "same as ``dim``".

    Returns:
        An ``nn.Conv2d`` with kernel size 2 and stride 2.
    """
    out_channels = dim if dim_out is None else dim_out
    return nn.Conv2d(in_channels=dim, out_channels=out_channels, kernel_size=2, stride=2)



In [9]:
# Shape check: feed a random (5, 64, 64, 64) batch through Downsample(64)
# and print the resulting shape.
ts = Downsample(64)
with torch.no_grad():
    print(ts(torch.randn(5, 64, 64, 64)).shape)

torch.Size([5, 64, 32, 32])


In [None]:
class Trainer:
    """Training / evaluation loop around a model exposing ``__call__`` and ``elbo``.

    The model is called on a batch and must return a 7-tuple
    ``(fine_pos_mean, fine_pos_var, coarse_pos_mean, coarse_pos_var,
    fine_sample, coarse_sample, output)``. Its ``elbo`` method is called with
    the first four items plus ``output`` and must return a scalar loss tensor.
    """

    def __init__(self, model, optimizer, scheduler, train_loader, test_loader, training_step, report_step, logger=None):
        """Store the training components.

        Args:
            model: model to optimise (see class docstring for the expected API).
            optimizer: optimizer stepped once per training batch.
            scheduler: LR scheduler stepped once per training batch.
            train_loader: iterable of training batches.
            test_loader: iterable of evaluation batches.
            training_step: total number of optimisation steps after which
                ``train`` returns; ``None`` means run until interrupted.
            report_step: run evaluation every ``report_step`` optimisation steps.
            logger: optional object with an ``info(str)`` method; ``print`` is
                used when it is falsy.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.training_step = training_step
        self.report_step = report_step
        self.logger = logger
        self.step = 0

        self.loss_list = []
        self.evaloss_list = []

    def _log(self, message):
        """Send ``message`` to the logger if one was given, otherwise print it."""
        if self.logger:
            self.logger.info(message)
        else:
            print(message)

    def _plot_reconstruction(self, batch, output):
        """Show the first item of ``batch`` next to the first item of ``output``."""
        fig, ax = plt.subplots(ncols=2)
        # permute(1, 2, 0) moves dim 0 to the end before the array goes to imshow.
        ax[0].imshow(batch[0].cpu().detach().permute(1,2,0).numpy())
        ax[1].imshow(output[0].cpu().detach().permute(1,2,0).numpy())
        # Render now and release the figure; evaluate() runs repeatedly during
        # training and open figures would otherwise pile up.
        plt.show()
        plt.close(fig)

    def evaluate(self):
        """Return the mean ``elbo`` loss over ``test_loader``.

        The model is switched to eval mode and autograd is disabled. The last
        evaluated batch and its output are plotted side by side.

        Returns:
            The mean of the per-batch losses as a numpy float, or NaN when
            ``test_loader`` yields no batches.
        """
        batch_losses = []
        last_batch, last_output = None, None
        self.model.eval()
        with torch.no_grad():
            for X in self.test_loader:
                (fine_pos_mean, fine_pos_var, coarse_pos_mean, coarse_pos_var,
                 _fine_sample, _coarse_sample, output) = self.model(X)
                batch_loss = self.model.elbo(fine_pos_mean, fine_pos_var, coarse_pos_mean, coarse_pos_var, output)
                batch_losses.append(batch_loss.item())
                last_batch, last_output = X, output

        if not batch_losses:
            # Nothing was evaluated, so there is nothing to average or plot.
            return np.float64(np.nan)

        self._plot_reconstruction(last_batch, last_output)
        return np.mean(np.array(batch_losses))

    def save(self, path = './'):
        """Write a checkpoint to ``<path>/sepvae.pth``.

        The checkpoint holds the model, optimizer and scheduler state dicts,
        the current step and both loss histories.

        Args:
            path: target directory (``str`` or ``os.PathLike``); created if missing.
        """
        from pathlib import Path

        all_state_dict = {
        'model_sd' : self.model.state_dict(),
        'optimizer_sd' : self.optimizer.state_dict(),
        'sche_sd' : self.scheduler.state_dict(),
        'step' : self.step,
        'loss_list' : self.loss_list,
        'evaloss_list' : self.evaloss_list
        }

        # Wrap in Path so the default string './' supports the '/' join.
        target_dir = Path(path)
        target_dir.mkdir(parents=True, exist_ok=True)
        torch.save(all_state_dict, target_dir / 'sepvae.pth')

    def train(self):
        """Train until ``self.step`` reaches ``training_step``.

        Each batch performs one optimizer step and one scheduler step. Every
        ``report_step`` steps the model is evaluated, the result is appended
        to ``evaloss_list``, and a checkpoint is saved whenever the evaluation
        loss improves on the best value seen in this call.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        best_eva_loss = np.inf
        while self.training_step is None or self.step < self.training_step:
            batches_seen = 0
            for X in self.train_loader:
                batches_seen += 1
                self.model.train()  # evaluate() leaves the model in eval mode
                self.optimizer.zero_grad()
                (fine_pos_mean, fine_pos_var, coarse_pos_mean, coarse_pos_var,
                 _fine_sample, _coarse_sample, output) = self.model(X)
                loss = self.model.elbo(fine_pos_mean, fine_pos_var, coarse_pos_mean, coarse_pos_var, output)
                self.loss_list.append(loss.item())
                self._log(f"step:{self.step}, loss:{loss.item():.6f}")

                loss.backward()

                self.optimizer.step()
                self.scheduler.step()
                self.step += 1

                # Validate inside the batch loop so the report_step interval
                # is honoured per step, not only at the end of a pass.
                if self.step % self.report_step == 0:
                    evaloss = float(self.evaluate())
                    self.evaloss_list.append(evaloss)
                    self._log(f"step:{self.step}, eval_loss:{evaloss:.6f}")

                    # Keep the best checkpoint.
                    if evaloss < best_eva_loss:
                        best_eva_loss = evaloss
                        self.save()

                if self.training_step is not None and self.step >= self.training_step:
                    return

            if batches_seen == 0:
                # An empty loader would make the outer loop spin forever.
                raise ValueError("train_loader yielded no batches")
