In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from datasets.datasets import SHHSdataset
from models.mymodules import CNN_block, CNN_head

In [2]:
# Configuration for this experiment: where the SHHS recordings live and how
# many samples go into one batch. Kept as a string (with trailing slash)
# exactly as SHHSdataset received it before.
DATA_PATH = '../../thesis01/data/'
BATCH_SIZE = 64

# Load a single patient (patient 1) and wrap it in an unshuffled loader so the
# first batch is the same on every run.
ds = SHHSdataset(
    data_path=DATA_PATH,
    first_patient=1,
    num_patients=1
)
# First element of the first dataset item, kept for interactive inspection.
datapoint = ds[0][0]
dl = DataLoader(
    dataset=ds,
    batch_size=BATCH_SIZE,
    shuffle=False
)
batch = next(iter(dl))

In [3]:
x, y = batch
# Drop the singleton axis at dim 1. Judging by the printed result
# ([64, 1, 3000] after the squeeze), the loader yields one extra axis there.
x = torch.squeeze(x, 1)
# Guard against a silent no-op squeeze if the dataset's output shape changes.
assert x.dim() == 3, f"expected a 3-D signal batch, got shape {tuple(x.shape)}"
print(x.size())

torch.Size([64, 1, 3000])


In [4]:
# Fix the global torch RNG so the weight initialisation of the models built
# from here on is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

conv_filters = [32, 64, 64]   # per-block conv channel counts (shared with the decoder below)
representation_dim = 100      # latent size; the encoder output below has this width
encoder = CNN_head(conv_filters, representation_dim)

In [5]:
# Encode the batch, then insert a singleton axis at position 1 so the latent
# vector looks like a one-channel sequence to the transposed-conv decoder.
encoding = encoder(x).unsqueeze(1)
print(encoding.size())

torch.Size([64, 1, 100])


In [6]:
class Inverse_CNN_block(nn.Module):
    """
        Transposed-convolution block intended to invert a regular convolution block.

        Layer order: ConvTranspose1d -> BatchNorm1d -> ReLU.

        With the defaults (stride=2, padding=1, output_padding=1) and
        kernel_size=3 the temporal length is exactly doubled:
            L_out = (L_in - 1) * stride - 2 * padding + (kernel_size - 1) + output_padding + 1
                  = 2 * L_in
        For another odd kernel_size, pass padding=(kernel_size - 1) // 2 to
        keep the exact 2x upsampling.

        Args:
            kernel_size: size of the transposed-convolution kernel.
            in_channels: number of channels of the input tensor.
            out_channels: number of channels produced by the block.
            stride: stride of the transposed convolution (upsampling factor).
            padding: `padding` argument forwarded to nn.ConvTranspose1d.
            output_padding: `output_padding` argument forwarded to nn.ConvTranspose1d.
    """
    def __init__(self, kernel_size, in_channels, out_channels,
                 stride=2, padding=1, output_padding=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose1d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               output_padding=output_padding),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the block; x is expected as (batch, in_channels, length)."""
        return self.net(x)

In [14]:
class CNN_decoder(nn.Module):
    """
        Intended as the inverse of CNN_head, for generative use.

        Stacks one Inverse_CNN_block (kernel 3) per entry of `conv_filters`,
        flattens the result and projects it with a linear layer to a signal of
        length `output_dim`.

        Args:
            conv_filters: output channels of each transposed-conv block, in order.
                The first block consumes a single input channel. Must be non-empty.
            representation_dim: length of the latent sequence fed to the decoder;
                used to size the final linear layer.
            output_dim: length of the produced signal (default 3000).

        Raises:
            ValueError: if `conv_filters` is empty.
    """
    def __init__(self, conv_filters, representation_dim, output_dim=3000):
        super().__init__()
        if len(conv_filters) == 0:
            raise ValueError("conv_filters must contain at least one entry")

        in_channels = [1, *conv_filters[:-1]]
        layers = [Inverse_CNN_block(3, c_in, c_out)
                  for c_in, c_out in zip(in_channels, conv_filters)]

        # Each kernel-3 Inverse_CNN_block (stride 2, padding 1, output_padding 1)
        # doubles the sequence length, so after n blocks the flattened size is
        # channels * representation_dim * 2**n. For the notebook config this is
        # 64 * 100 * 8 = 51200, the value that used to be hard-coded here.
        flat_dim = conv_filters[-1] * representation_dim * 2 ** len(conv_filters)
        layers += [nn.Flatten(), nn.Linear(flat_dim, output_dim)]

        # Same module order/indices as before for 3 filters, so existing
        # state_dicts keep loading.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Decode x, expected as (batch, 1, representation_dim), into (batch, output_dim)."""
        return self.model(x)

In [15]:
decoder = CNN_decoder(conv_filters, representation_dim)
decoding = decoder(encoding)
# The decoder emits a flat signal; make sure it lines up sample-for-sample
# with the (flattened) input batch it is meant to reconstruct.
assert decoding.shape == x.flatten(1).shape, (tuple(decoding.shape), tuple(x.shape))
print(decoding.size())

torch.Size([64, 3000])
