In [5]:
from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from tomotwin.modules.networks.torchmodel import TorchModel

class AutoEncoder(TorchModel):

    NORM_BATCHNORM = "BatchNorm"
    NORM_GROUPNORM = "GroupNorm"

    class Model(nn.Module):
        def make_norm(self, norm: Dict, num_channels: int) -> nn.Module:
            if norm["module"] == nn.BatchNorm3d:
                norm["kwargs"]["num_features"] = num_channels
                return norm["module"](**norm["kwargs"])
            elif norm["module"] == nn.GroupNorm:
                norm["kwargs"]["num_channels"] = num_channels
                return norm["module"](**norm["kwargs"])
            else:
                raise ValueError("Not supported norm", norm["module"])


        def __init__(
            self,
            output_channels: int,
            norm: Dict,
            dropout: float = 0.5,
            repeat_layers=0,
            gem_pooling = None,
        ):
            super().__init__()
            norm_func = self.make_norm(norm, 64)
            self.en_layer0 = self._make_conv_layer(1, 64, norm=norm_func)

            norm_func = self.make_norm(norm, 128)
            self.en_layer1 = self._make_conv_layer(64, 128, norm=norm_func)

            norm_func = self.make_norm(norm, 256)
            self.en_layer2 = self._make_conv_layer(128, 256, norm=norm_func)

            norm_func = self.make_norm(norm, 512)
            self.en_layer3 = self._make_conv_layer(256, 512, norm=norm_func)


            self.max_pooling = nn.MaxPool3d((2, 2, 2))
            if gem_pooling:
                self.adap_max_pool = gem_pooling
            else:
                self.adap_max_pool = nn.AdaptiveAvgPool3d((2, 2, 2))
            
            self.headnet = self._make_headnet(
                 512, 256, 64, 1, dropout=dropout
            )

            norm_func = self.make_norm(norm, 256)
            self.de_layer0 = self._make_deconv_layer(512, 256, norm=norm_func)

            norm_func = self.make_norm(norm, 128)
            self.de_layer1 = self._make_deconv_layer(256, 128, norm=norm_func)

            norm_func = self.make_norm(norm, 64)
            self.de_layer2 = self._make_deconv_layer(128, 64, norm=norm_func)

            #norm_func = self.make_norm(norm, 64)
            #self.de_layer3 = self._make_conv_layer(128, 64, norm=norm_func)

            #norm_func = self.make_norm(norm, 1)
            #self.de_layer4 = self._make_conv_layer(64, 1, norm=norm_func)
            self.de_layer4 = nn.Sequential(
                nn.ConvTranspose3d(64, 1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.ConvTranspose3d(1, 1, kernel_size=3, padding=1),
                nn.Identity() 
            )

            self.up_sampling = nn.Upsample(scale_factor =2)

        @staticmethod
        def _make_conv_layer(in_c: int, out_c: int, norm: nn.Module, padding: int = 1, kernel_size: int =3):
            conv_layer = nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=3, padding=padding),
                norm,
                nn.LeakyReLU(),
                nn.Conv3d(out_c, out_c, kernel_size=3, padding=padding),
                norm,
                nn.LeakyReLU(),
            )
            return conv_layer
        
        @staticmethod
        def _make_deconv_layer(in_c: int, out_c: int, norm: nn.Module, padding: int = 1, kernel_size: int =3):
            conv_layer = nn.Sequential(
                nn.ConvTranspose3d(in_c, out_c, kernel_size=3, padding=padding),
                norm,
                nn.LeakyReLU(),
                nn.ConvTranspose3d(out_c, out_c, kernel_size=3, padding=padding),
                norm,
                nn.LeakyReLU(),
            )
            return conv_layer

        @staticmethod
        def _make_headnet(
            in_c1: int, in_c2: int,out_c1: int, out_head: int, dropout: float
        ) -> nn.Sequential:
            headnet = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Conv3d(in_c1, in_c2, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.Conv3d(in_c2, out_c1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.Conv3d(out_c1, out_head, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.ConvTranspose3d(out_head, out_c1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.ConvTranspose3d(out_c1, in_c2, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.ConvTranspose3d(in_c2,in_c1, kernel_size=3, padding=1),
                nn.LeakyReLU(),


            )
            return headnet

        def forward(self, inputtensor):
            """
            Forward pass through the network
            :param inputtensor: Input tensor
            """
            inputtensor = F.pad(inputtensor, (1, 2, 1, 2, 1, 2))

            out = self.en_layer0(inputtensor)
            out = self.max_pooling(out)
            out = self.en_layer1(out)
            out = self.max_pooling(out)
            out = self.en_layer2(out)
            out = self.max_pooling(out)
            out = self.en_layer3(out)
            #out = self.max_pooling(out)

            #out = self.en_layer4(out)
            #out = self.adap_max_pool(out)
            #out = out.reshape(out.size(0), -1)  # flatten
            out = self.headnet(out)
            #out = out.reshape(-1,512,5,5,5)
            out = self.de_layer0(out)
            out = self.up_sampling(out)
            out = self.de_layer1(out)
            out = self.up_sampling(out)
            out = self.de_layer2(out)
            out = self.up_sampling(out)
            #out = self.de_layer3(out)
            #out = self.up_sampling(out)
            out = self.de_layer4(out)
            #out = F.normalize(out, p=2, dim=1)

            return out

    """
    Custom 3D convnet, nothing fancy
    """

    def setup_norm(self, norm_name : str, norm_kwargs: dict) -> Dict:
        norm = {}
        if norm_name == AutoEncoder.NORM_BATCHNORM:
            norm["module"] = nn.BatchNorm3d
        if norm_name == AutoEncoder.NORM_GROUPNORM:
            norm["module"] = nn.GroupNorm
        norm["kwargs"] = norm_kwargs

        return norm


    def setup_gem_pooling(self,gem_pooling_p : float) -> Union[None, nn.Module]:
        gem_pooling = None
        if gem_pooling_p > 0:
            from tomotwin.modules.networks.GeneralizedMeanPooling import GeneralizedMeanPooling
            gem_pooling = GeneralizedMeanPooling(norm=gem_pooling_p, output_size=(2, 2, 2))
        return gem_pooling

    def __init__(
        self,
        norm_name: str,
        norm_kwargs: Dict = {},
        output_channels: int = 128,
        dropout: float = 0.5,
        gem_pooling_p: float = 0,
        repeat_layers=0,
    ):
        super().__init__()
        norm = self.setup_norm(norm_name, norm_kwargs)
        gem_pooling = self.setup_gem_pooling(gem_pooling_p)


        self.model = self.Model(
            output_channels=output_channels,
            dropout=dropout,
            repeat_layers=repeat_layers,
            norm=norm,
            gem_pooling=gem_pooling
        )

    def init_weights(self):
        def _init_weights(model):
            if isinstance(model, nn.Conv3d):
                torch.nn.init.kaiming_normal_(model.weight)

        self.model.apply(_init_weights)

    def get_model(self) -> nn.Module:
        return self.model

In [6]:
from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from tomotwin.modules.networks.torchmodel import TorchModel

class AutoEncoder(TorchModel):

    NORM_BATCHNORM = "BatchNorm"
    NORM_GROUPNORM = "GroupNorm"

    class Model(nn.Module):
        def make_norm(self, norm: Dict, num_channels: int) -> nn.Module:
            if norm["module"] == nn.BatchNorm3d:
                norm["kwargs"]["num_features"] = num_channels
                return norm["module"](**norm["kwargs"])
            elif norm["module"] == nn.GroupNorm:
                norm["kwargs"]["num_channels"] = num_channels
                return norm["module"](**norm["kwargs"])
            else:
                raise ValueError("Not supported norm", norm["module"])


        def __init__(
            self,
            output_channels: int,
            norm: Dict,
            dropout: float = 0.5,
            repeat_layers=0,
            gem_pooling = None,
        ):
            super().__init__()
            norm_func = self.make_norm(norm, 64)
            self.en_layer0 = self._make_conv_layer(1, 64, norm=norm_func)

            norm_func = self.make_norm(norm, 128)
            self.en_layer1 = self._make_conv_layer(64, 128, norm=norm_func)

            norm_func = self.make_norm(norm, 256)
            self.en_layer2 = self._make_conv_layer(128, 256, norm=norm_func)

            norm_func = self.make_norm(norm, 512)
            self.en_layer3 = self._make_conv_layer(256, 512, norm=norm_func)

            norm_func = self.make_norm(norm, 1024)
            self.en_layer4 = self._make_conv_layer(512, 1024, norm=norm_func)

            self.max_pooling = nn.MaxPool3d((2, 2, 2))
            if gem_pooling:
                self.adap_max_pool = gem_pooling
            else:
                self.adap_max_pool = nn.AdaptiveAvgPool3d((2, 2, 2))
            
            self.headnet = self._make_headnet(
                2 * 2 * 2 * 1024, 2048, output_channels, dropout=dropout
            )

            norm_func = self.make_norm(norm, 512)
            self.de_layer0 = self._make_conv_layer(1024, 512, norm=norm_func)

            norm_func = self.make_norm(norm, 256)
            self.de_layer1 = self._make_conv_layer(512, 256, norm=norm_func)

            norm_func = self.make_norm(norm, 128)
            self.de_layer2 = self._make_conv_layer(256, 128, norm=norm_func)

            norm_func = self.make_norm(norm, 64)
            self.de_layer3 = self._make_conv_layer(128, 64, norm=norm_func)

            #norm_func = self.make_norm(norm, 1)
            #self.de_layer4 = self._make_conv_layer(64, 1, norm=norm_func)
            self.de_layer4 = nn.Sequential(
                nn.Conv3d(64, 1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.Conv3d(1, 1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
            )

            self.up_sampling = nn.Upsample(scale_factor =2)

        @staticmethod
        def _make_conv_layer(in_c: int, out_c: int, norm: nn.Module, padding: int = 1, kernel_size: int =3):
            conv_layer = nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=kernel_size, padding=padding),
                norm,
                nn.LeakyReLU(),
                nn.Conv3d(out_c, out_c, kernel_size=kernel_size, padding=padding),
                norm,
                nn.LeakyReLU(),
            )
            return conv_layer

        @staticmethod
        def _make_headnet(
            in_c1: int, out_c1: int, out_head: int, dropout: float
        ) -> nn.Sequential:
            headnet = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(in_c1, out_c1),
                nn.LeakyReLU(),
                nn.Linear(out_c1, out_c1),
                nn.LeakyReLU(),
                nn.Linear(out_c1, out_head),
                nn.LeakyReLU(),
                nn.Linear(out_head,out_c1),
                nn.LeakyReLU(),
                nn.Linear(out_c1,in_c1)
            )
            return headnet

        def forward(self, inputtensor):
            """
            Forward pass through the network
            :param inputtensor: Input tensor
            """
            #inputtensor = F.pad(inputtensor, (1, 2, 1, 2, 1, 2))

            out = self.en_layer0(inputtensor)
            out = self.max_pooling(out)
            out = self.en_layer1(out)
            out = self.max_pooling(out)
            out = self.en_layer2(out)
            out = self.max_pooling(out)
            out = self.en_layer3(out)
            out = self.max_pooling(out)
            out = self.en_layer4(out)
            #out = self.adap_max_pool(out)
            out = out.reshape(out.size(0), -1) 
            out = self.headnet(out)
            out = out.reshape(-1,1024,2,2,2)
            out = self.de_layer0(out)
            out = self.up_sampling(out)
            out = self.de_layer1(out)
            out = self.up_sampling(out)
            out = self.de_layer2(out)
            out = self.up_sampling(out)
            out = self.de_layer3(out)
            out = self.up_sampling(out)
            out = self.de_layer4(out)
            #out = F.normalize(out, p=2, dim=1)

            return out

    """
    Custom 3D convnet, nothing fancy
    """

    def setup_norm(self, norm_name : str, norm_kwargs: dict) -> Dict:
        norm = {}
        if norm_name == AutoEncoder.NORM_BATCHNORM:
            norm["module"] = nn.BatchNorm3d
        if norm_name == AutoEncoder.NORM_GROUPNORM:
            norm["module"] = nn.GroupNorm
        norm["kwargs"] = norm_kwargs

        return norm


    def setup_gem_pooling(self,gem_pooling_p : float) -> Union[None, nn.Module]:
        gem_pooling = None
        if gem_pooling_p > 0:
            from tomotwin.modules.networks.GeneralizedMeanPooling import GeneralizedMeanPooling
            gem_pooling = GeneralizedMeanPooling(norm=gem_pooling_p, output_size=(2, 2, 2))
        return gem_pooling

    def __init__(
        self,
        norm_name: str,
        norm_kwargs: Dict = {},
        output_channels: int = 128,
        dropout: float = 0.5,
        gem_pooling_p: float = 0,
        repeat_layers=0,
    ):
        super().__init__()
        norm = self.setup_norm(norm_name, norm_kwargs)
        gem_pooling = self.setup_gem_pooling(gem_pooling_p)


        self.model = self.Model(
            output_channels=output_channels,
            dropout=dropout,
            repeat_layers=repeat_layers,
            norm=norm,
            gem_pooling=gem_pooling
        )

    def init_weights(self):
        def _init_weights(model):
            if isinstance(model, nn.Conv3d):
                torch.nn.init.kaiming_normal_(model.weight)

        self.model.apply(_init_weights)

    def get_model(self) -> nn.Module:
        return self.model

In [2]:
norm_name = "GroupNorm"
norm_kwargs = {"num_groups": 64,
        "num_channels": 1024}

In [62]:
M = AutoEncoder(norm_name,norm_kwargs,32)

In [23]:
Model = M.get_model()

In [24]:
out = Model(torch.rand(12,1,37,37,37))

torch.Size([12, 64, 40, 40, 40])
torch.Size([12, 128, 10, 10, 10])
torch.Size([12, 512, 5, 5, 5])
torch.Size([12, 512, 5, 5, 5])
torch.Size([12, 256, 10, 10, 10])
torch.Size([12, 128, 20, 20, 20])
torch.Size([12, 64, 20, 20, 20])
torch.Size([12, 64, 40, 40, 40])
torch.Size([12, 1, 40, 40, 40])


In [19]:
out.shape

torch.Size([12, 1, 40, 40, 40])

In [7]:
from tomotwin.modules.training.mrctriplethandler import MRCTripletHandler
import os
import numpy as np
import torch
from torch.utils.data import Dataset


class MRCVolumeDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.file_paths = self._get_file_paths()
        self.reader = MRCTripletHandler()

    def _get_file_paths(self):
        file_paths = []
        for round_dir in os.listdir(self.root_dir):
            round_path = os.path.join(self.root_dir, round_dir)
            if os.path.isdir(round_path):
                for tomo_dir in os.listdir(round_path):
                    tomo_path = os.path.join(round_path, tomo_dir)
                    if os.path.isdir(tomo_path):
                        mrc_files = [f for f in os.listdir(tomo_path) if f.endswith('.mrc')]
                        for mrc_file in mrc_files:
                            file_paths.append(os.path.join(tomo_path, mrc_file))
        return file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        mrc_path = self.file_paths[idx]
        volume = self.reader.read_mrc_and_norm(mrc_path)

        return {'input': volume, 'target': volume}

In [8]:
from torch.utils.data import DataLoader
root_dir = '/home/yousef.metwally/projects/data/tomotwin_training_data/validation'

dataset = MRCVolumeDataset(root_dir)
batch_size = 32
shuffle = True  
num_workers = 4  
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)


In [9]:
import torch.nn.functional as F

def loss_function(recon_x, x):
    mse_loss = F.mse_loss(recon_x, x)
    return mse_loss

In [13]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

def train_autoencoder(model, data_loader, optimizer, num_epochs=10, device='cuda'):
    writer = SummaryWriter('/home/yousef.metwally/projects/AutoEncoder')
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        with tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as progress_bar:
            for batch_idx, data in enumerate(progress_bar):
                input_data = data['input'].to(device)
                #if batch_idx > 500:
                 #   print(input_data.shape)
                input_data = input_data.reshape(-1,1,37,37,37)
                target_data = data['target'].to(device)
                target_data = F.pad(target_data, (1, 2, 1, 2, 1, 2))
                target_data = target_data.reshape(-1,1,40,40,40)
                optimizer.zero_grad()
                recon_data = model(input_data)
                loss = loss_function(recon_data, target_data)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                avg_loss = total_loss / (batch_idx + 1)
                writer.add_scalar('Loss/train', avg_loss, epoch * len(data_loader) + batch_idx)

                
                progress_bar.set_postfix(loss=avg_loss)

        print(f"Epoch {epoch+1}/{num_epochs}, Avg. Loss: {avg_loss:.4f}")
        torch.save(model.state_dict(), f"/home/yousef.metwally/projects/AutoEncoder/weights/run1/model_weights_epoch_{epoch+1}.pt")
    writer.close()


In [None]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
num_epochs = 200
M = AutoEncoder(norm_name,norm_kwargs,32)
model = M.get_model()
model = nn.DataParallel(model)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_autoencoder(model, data_loader, optimizer, num_epochs, device)




In [10]:
class testDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.file_paths = self._get_file_paths()
        self.reader = MRCTripletHandler()

    def _get_file_paths(self):
        file_paths = [os.path.join(self.root_dir, f) for f in os.listdir(self.root_dir) if f.endswith('.mrc')]
        return file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        mrc_path = self.file_paths[idx]
        volume = self.reader.read_mrc_and_norm(mrc_path)
        return {'input': volume, 'target': volume}

In [None]:
test_dir = '/home/yousef.metwally/projects/data/tomotwin_training_data/test'
test_dataset = testDataset(test_dir)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)
model = UNet3D(1,1)
model = nn.DataParallel(model)
checkpoint_path = '/home/yousef.metwally/projects/UNet/weights/model_weights_epoch_16.pt'
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)
device = 'cuda'
model.to(device)
model.eval()

In [16]:
def evaluate (model, data_loader):
    total_loss = 0.0
    model.eval()
    with torch.no_grad():
        with tqdm(data_loader, desc="Evaluating", unit="batch") as progress_bar:
            for batch_idx, data in enumerate(progress_bar):
                input_data = data['input'].to(device)
                input_data = input_data.reshape(-1, 1, 37, 37, 37)
                target_data = data['target'].to(device)
                target_data = target_data.reshape(-1, 1, 37, 37, 37)
                recon_data = model(input_data)
                loss = loss_function(recon_data, target_data)
                total_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item())
    
    avg_loss = total_loss / len(data_loader)
    print(f"Validation Loss: {avg_loss}")

Evaluating: 100%|██████████| 2/2 [00:01<00:00,  1.20batch/s, loss=1.49e-5]

Validation Loss: 5.2851032178912064e-08





In [33]:
def predict(model, data_loader):
    predictions = []
    targets = []
    with torch.no_grad():
        for batch in data_loader:
            input_data = batch['input'].unsqueeze(1)
            target_data = batch['target'].unsqueeze(1)
            output = model(input_data)
            predictions.append(output.squeeze(1).cpu().numpy())
            targets.append(target_data.squeeze(1).cpu().numpy())
    return predictions, targets

In [34]:
import mrcfile
def save_as_mrc(data, filename):
    with mrcfile.new(filename, overwrite=True) as mrc:
        mrc.set_data(data.astype(np.float32))

In [None]:
predictions, targets = predict(model, test_loader)
save_dir = '/home/yousef.metwally/projects/UNet/output'
def save_ouput (out_dir: str, predictions, targets):
    for i, (pred, target) in enumerate(zip(predictions, targets)):
        pred_filename = os.path.join(out_dir, f'prediction_{i:03d}.mrc')
        target_filename = os.path.join(out_dir, f'target_{i:03d}.mrc')
        save_as_mrc(pred, pred_filename)
        save_as_mrc(target, target_filename)
        print(f'Saved prediction to {pred_filename} and target to {target_filename}')

In [31]:
!tensorboard --logdir='/home/yousef.metwally/projects/UNet/' --bind_all


TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

TensorBoard 2.16.2 at http://gtxr3.srv-local.mpi-dortmund.mpg.de:6006/ (Press CTRL+C to quit)
^C


In [36]:
reader = MRCTripletHandler()
v_path = '/home/yousef.metwally/projects/data/tomotwin_training_data/test/round01_t08_0XXX_000.mrc'
volume = reader.read_mrc_and_norm(v_path)
M = AutoEncoder(norm_name,norm_kwargs,32)
model = M.get_model()
checkpoint_path = '/home/yousef.metwally/projects/AutoEncoder/weights/run1/model_weights_epoch_200.pt'
checkpoint = torch.load(checkpoint_path)
state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
model.load_state_dict(state_dict)



RuntimeError: Error(s) in loading state_dict for Model:
	size mismatch for de_layer0.0.weight: copying a param with shape torch.Size([512, 256, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3, 3]).
	size mismatch for de_layer1.0.weight: copying a param with shape torch.Size([256, 128, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3, 3]).
	size mismatch for de_layer2.0.weight: copying a param with shape torch.Size([128, 64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 128, 3, 3, 3]).
	size mismatch for de_layer4.0.weight: copying a param with shape torch.Size([64, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 64, 3, 3, 3]).
	size mismatch for de_layer4.2.weight: copying a param with shape torch.Size([1, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1, 1]).

In [18]:
x = os.listdir("/home/yousef.metwally/projects/AutoEncoder/weights/run1")

In [19]:
x

['model_weights_epoch_197.pt',
 'model_weights_epoch_198.pt',
 'model_weights_epoch_199.pt',
 'model_weights_epoch_200.pt']

In [38]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Ensure the shapes match for concatenation
        diffZ = x2.size(2) - x1.size(2)
        diffY = x2.size(3) - x1.size(3)
        diffX = x2.size(4) - x1.size(4)

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])
        
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet3D(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet3D, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

        self._initialize_weights()

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

: 

In [3]:
import argparse
from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from tomotwin.modules.networks.torchmodel import TorchModel
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from tomotwin.modules.training.mrctriplethandler import MRCTripletHandler
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffZ = x2.size(2) - x1.size(2)
        diffY = x2.size(3) - x1.size(3)
        diffX = x2.size(4) - x1.size(4)

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])
        
        #x = torch.cat([x2, x1], dim=1)
        return self.conv(x1)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet3D(nn.Module):
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet3D, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.down5 = Down(1024,1024)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.up5 = Up(64,64)
        self.outc = OutConv(64, out_channels)

        self._initialize_weights()

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x = self.up1(x6, x5)
        x = self.up2(x, x4)
        x = self.up3(x, x3)
        x = self.up4(x, x2)
        x = self.up5(x, x1)
        logits = self.outc(x)
        return logits
    

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


In [4]:
model = UNet3D(1,1)
out = model(torch.rand(12,1,37,37,37))

In [32]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv3d(inp, oup, kernel_size=3, stride=stride, padding=(1,1,1), bias=False),
        nn.BatchNorm3d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv3d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm3d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == (1,1,1) and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv3d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm3d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv3d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm3d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv3d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm3d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv3d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm3d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv3d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm3d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, sample_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            [1,  32, 1, (1,1,1)],
            [6,  64, 2, (2,2,2)],
            [6,  128, 3, (2,2,2)],
            [6,  256, 4, (2,2,2)],
            [6,  512, 3, (1,1,1)],
            [6, 1024, 3, (2,2,2)],
            [6, 2048, 1, (1,1,1)],
        ]

        # building first layer
        assert sample_size % 16 == 0.
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(1, input_channel, (1,2,2))]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                stride = s if i == 0 else (1,1,1)
                self.features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        self._initialize_weights()

    def forward(self, x):
        for i, layer in enumerate(self.features):
            x = layer(x)
           # print(f"Output after layer {i}: {x.shape}")
        x = F.avg_pool3d(x, x.data.size()[-3:])
       # print(f"Output after avg_pool3d: {x.shape}")
        x = x.view(x.size(0), -1)
       # print(f"Output after view: {x.shape}")
        #x = self.classifier(x)
        #print(f"Output after classifier: {x.shape}")
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

class MobileNetV2Autoencoder(nn.Module):
    def __init__(self, num_classes=1000, sample_size=32, width_mult=1.):
        super(MobileNetV2Autoencoder, self).__init__()
        self.encoder = MobileNetV2(num_classes=num_classes, sample_size=sample_size, width_mult=width_mult)
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(2048, 320, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(320),
            nn.ReLU6(inplace=True),
            nn.ConvTranspose3d(320, 160, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(160),
            nn.ReLU6(inplace=True),
            nn.ConvTranspose3d(160, 96, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(96),
            nn.ReLU6(inplace=True),
            nn.ConvTranspose3d(96, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU6(inplace=True),
            nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU6(inplace=True),
            nn.ConvTranspose3d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # Assuming the output needs to be normalized
        )

    def forward(self, x):
       # print("Input shape:", x.shape)
        for i, layer in enumerate(self.encoder.features):
            x = layer(x)
         #   print(f"Encoder output after layer {i}: {x.shape}")
        
        x = F.avg_pool3d(x, x.data.size()[-3:])
       # print(f"Output after avg_pool3d: {x.shape}")
        x = x.view(x.size(0), -1, 1, 1, 1)
      #  print(f"Output after view: {x.shape}")

        for i, layer in enumerate(self.decoder):
            x = layer(x)
           # print(f"Decoder output after layer {i}: {x.shape}")

        return x

In [33]:

import argparse
import random
from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from tomotwin.modules.networks.torchmodel import TorchModel
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from tomotwin.modules.training.mrctriplethandler import MRCTripletHandler
import os
import numpy as np
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import mrcfile
class MRCVolumeDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.file_paths = self._get_file_paths()

    def read_and_normalize_mrc(self, file_path):
        with mrcfile.open(file_path, permissive=True) as mrc:
            data = mrc.data.astype(np.float32)
            min_val = np.min(data)
            max_val = np.max(data)
            normalized_data = (data - min_val) / (max_val - min_val)
        return normalized_data

    def _get_file_paths(self):
        file_paths = []
        for round_dir in os.listdir(self.root_dir):
            round_path = os.path.join(self.root_dir, round_dir)
            if os.path.isdir(round_path):
                for tomo_dir in os.listdir(round_path):
                    tomo_path = os.path.join(round_path, tomo_dir)
                    if os.path.isdir(tomo_path):
                        mrc_files = [f for f in os.listdir(tomo_path) if f.endswith('.mrc')]
                        for mrc_file in mrc_files:
                            file_paths.append(os.path.join(tomo_path, mrc_file))
                                                                                     
        return file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        mrc_path = self.file_paths[idx]
        volume = self.read_and_normalize_mrc(mrc_path)
        volume = volume[2:-3,2:-3,2:-3]

        return {'input': volume, 'target': volume}

    

def loss_function(recon_x, x):
    mse_loss = F.mse_loss(recon_x, x)
    return mse_loss


def train_autoencoder(model, data_loader, val_loader, optimizer, scheduler, logging, num_epochs=10, device='cuda', patience=10):
    writer = SummaryWriter(logging)
    model.to(device)
    model.train()
    best_loss = float('inf')
    early_stopping_counter = 0

    if not os.path.exists(f"{logging}/weights"):
        os.makedirs(f"{logging}/weights")

    for epoch in range(num_epochs):
        total_loss = 0.0
        model.train()
        with tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as progress_bar:
            for batch_idx, data in enumerate(progress_bar):
                input_data = data['input'].to(device)
                input_data = input_data.unsqueeze(1)
                target_data = data['target'].to(device)
                target_data = target_data.unsqueeze(1)
                #target_data = F.pad(target_data, (1, 2, 1, 2, 1, 2))  
                optimizer.zero_grad()
                recon_data = model(input_data)
                loss = loss_function(recon_data, target_data)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                avg_loss = total_loss / (batch_idx + 1)
                writer.add_scalar('Loss/train', avg_loss, epoch * len(data_loader) + batch_idx)
                progress_bar.set_postfix(loss=avg_loss)
                
        val_loss = validate_autoencoder(model,val_loader,writer, epoch+1,num_epochs)
        scheduler.step(val_loss)
        if val_loss < best_loss:
            print (f'{val_loss} < {best_loss}, epoch: {epoch+1} ')
            torch.save(model.state_dict(), f"{logging}/weights/model_weights_epoch_{epoch+1}.pt")
            best_loss = val_loss
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print(f"Validation loss did not improve for {patience} epochs. Early stopping...")
                break

        print(f"Epoch {epoch+1}/{num_epochs}, Avg. Loss: {avg_loss}, Val. Loss: {val_loss}")

    writer.close()


def validate_autoencoder(model, data_loader, writer, epoch,num_epochs, device='cuda'):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        with tqdm(data_loader, desc=f"Validation Epoch {epoch}/{num_epochs}", unit="batch") as progress_bar:
            for batch_idx, data in enumerate(progress_bar):
                input_data = data['input'].to(device)
                input_data = input_data.unsqueeze(1)
                target_data = data['target'].to(device)
                target_data = target_data.unsqueeze(1)
                #target_data = F.pad(target_data, (1, 2, 1, 2, 1, 2))  
                recon_data = model(input_data)
                loss = loss_function(recon_data, target_data)
                total_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item())
    
    avg_loss = total_loss / len(data_loader)
    writer.add_scalar('Loss/val', avg_loss, epoch)
    return avg_loss


In [34]:
dataset_root = '/home/yousef.metwally/projects/data/tomotwin_training_data/training'
val_root = '/home/yousef.metwally/projects/data/tomotwin_training_data/validation'
batch_size = 64
num_workers = 4
device = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 0.001
num_epochs = 1000
logging = '/home/yousef.metwally/projects/mobilenet'
dataset = MRCVolumeDataset(dataset_root)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_set = MRCVolumeDataset(val_root)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
model = MobileNetV2Autoencoder()
model = nn.DataParallel(model)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    
train_autoencoder(model, data_loader, val_loader, optimizer,  scheduler,logging, num_epochs, device)

Epoch 1/1000: 100%|██████████| 1682/1682 [04:40<00:00,  6.00batch/s, loss=0.0134]
Validation Epoch 1/1000: 100%|██████████| 281/281 [00:27<00:00, 10.40batch/s, loss=0.0123]


0.012403838492240558 < inf, epoch: 1 
Epoch 1/1000, Avg. Loss: 0.013419953164368796, Val. Loss: 0.012403838492240558


Epoch 2/1000: 100%|██████████| 1682/1682 [04:40<00:00,  5.99batch/s, loss=0.0127]
Validation Epoch 2/1000: 100%|██████████| 281/281 [00:28<00:00,  9.99batch/s, loss=0.0114]


0.012360871772802174 < 0.012403838492240558, epoch: 2 
Epoch 2/1000, Avg. Loss: 0.012724148594042315, Val. Loss: 0.012360871772802174


Epoch 3/1000: 100%|██████████| 1682/1682 [04:43<00:00,  5.94batch/s, loss=0.0127]
Validation Epoch 3/1000: 100%|██████████| 281/281 [00:27<00:00, 10.17batch/s, loss=0.0116]


Epoch 3/1000, Avg. Loss: 0.01267022962928899, Val. Loss: 0.0123866999926198


Epoch 4/1000: 100%|██████████| 1682/1682 [04:33<00:00,  6.14batch/s, loss=0.0126]
Validation Epoch 4/1000: 100%|██████████| 281/281 [00:25<00:00, 10.92batch/s, loss=0.0128]


0.01235434081856148 < 0.012360871772802174, epoch: 4 
Epoch 4/1000, Avg. Loss: 0.012608975928950076, Val. Loss: 0.01235434081856148


Epoch 5/1000: 100%|██████████| 1682/1682 [04:34<00:00,  6.13batch/s, loss=0.0126]
Validation Epoch 5/1000: 100%|██████████| 281/281 [00:27<00:00, 10.35batch/s, loss=0.0124]


Epoch 5/1000, Avg. Loss: 0.012606804372629817, Val. Loss: 0.012359671490233776


Epoch 6/1000: 100%|██████████| 1682/1682 [04:42<00:00,  5.95batch/s, loss=0.0126]
Validation Epoch 6/1000: 100%|██████████| 281/281 [00:27<00:00, 10.13batch/s, loss=0.0127]


Epoch 6/1000, Avg. Loss: 0.012598589330127597, Val. Loss: 0.012390864704772034


Epoch 7/1000: 100%|██████████| 1682/1682 [04:40<00:00,  6.00batch/s, loss=0.0126]
Validation Epoch 7/1000: 100%|██████████| 281/281 [00:27<00:00, 10.17batch/s, loss=0.0128]


Epoch 7/1000, Avg. Loss: 0.01257380468307969, Val. Loss: 0.012382255002087334


Epoch 8/1000: 100%|██████████| 1682/1682 [04:38<00:00,  6.03batch/s, loss=0.0126]
Validation Epoch 8/1000: 100%|██████████| 281/281 [00:27<00:00, 10.36batch/s, loss=0.0122]


0.01235201530277835 < 0.01235434081856148, epoch: 8 
Epoch 8/1000, Avg. Loss: 0.012568638466218848, Val. Loss: 0.01235201530277835


Epoch 9/1000: 100%|██████████| 1682/1682 [04:40<00:00,  6.00batch/s, loss=0.0126]
Validation Epoch 9/1000: 100%|██████████| 281/281 [00:30<00:00,  9.29batch/s, loss=0.0129]


Epoch 9/1000, Avg. Loss: 0.01256760964491244, Val. Loss: 0.013267714894804241


Epoch 10/1000: 100%|██████████| 1682/1682 [04:33<00:00,  6.14batch/s, loss=0.0126]
Validation Epoch 10/1000: 100%|██████████| 281/281 [00:28<00:00,  9.95batch/s, loss=0.0127]


0.012341636388804565 < 0.01235201530277835, epoch: 10 
Epoch 10/1000, Avg. Loss: 0.012629973882638743, Val. Loss: 0.012341636388804565


Epoch 11/1000: 100%|██████████| 1682/1682 [04:38<00:00,  6.03batch/s, loss=0.0126]
Validation Epoch 11/1000: 100%|██████████| 281/281 [00:27<00:00, 10.33batch/s, loss=0.0123]


0.012336004336507082 < 0.012341636388804565, epoch: 11 
Epoch 11/1000, Avg. Loss: 0.01257794929223133, Val. Loss: 0.012336004336507082


Epoch 12/1000: 100%|██████████| 1682/1682 [04:38<00:00,  6.05batch/s, loss=0.0126]
Validation Epoch 12/1000: 100%|██████████| 281/281 [00:28<00:00, 10.02batch/s, loss=0.0125]


0.012317320756163461 < 0.012336004336507082, epoch: 12 
Epoch 12/1000, Avg. Loss: 0.012552832642528057, Val. Loss: 0.012317320756163461


Epoch 13/1000: 100%|██████████| 1682/1682 [04:36<00:00,  6.08batch/s, loss=0.0125]
Validation Epoch 13/1000: 100%|██████████| 281/281 [00:26<00:00, 10.48batch/s, loss=0.0117]


Epoch 13/1000, Avg. Loss: 0.012549906224344595, Val. Loss: 0.012327491978574478


Epoch 14/1000: 100%|██████████| 1682/1682 [04:43<00:00,  5.93batch/s, loss=0.0125]
Validation Epoch 14/1000: 100%|██████████| 281/281 [00:29<00:00,  9.38batch/s, loss=0.0124]


Epoch 14/1000, Avg. Loss: 0.012546363328755186, Val. Loss: 0.012781323758997952


Epoch 15/1000: 100%|██████████| 1682/1682 [04:42<00:00,  5.95batch/s, loss=0.0126]
Validation Epoch 15/1000: 100%|██████████| 281/281 [00:28<00:00,  9.69batch/s, loss=0.0129]


0.012315323844790671 < 0.012317320756163461, epoch: 15 
Epoch 15/1000, Avg. Loss: 0.01256249902114343, Val. Loss: 0.012315323844790671


Epoch 16/1000: 100%|██████████| 1682/1682 [04:40<00:00,  6.00batch/s, loss=0.0125]
Validation Epoch 16/1000: 100%|██████████| 281/281 [00:25<00:00, 10.94batch/s, loss=0.0126]


0.012292618193792916 < 0.012315323844790671, epoch: 16 
Epoch 16/1000, Avg. Loss: 0.012542468836745879, Val. Loss: 0.012292618193792916


Epoch 17/1000: 100%|██████████| 1682/1682 [04:38<00:00,  6.03batch/s, loss=0.0125]
Validation Epoch 17/1000: 100%|██████████| 281/281 [00:26<00:00, 10.73batch/s, loss=0.0124]


Epoch 17/1000, Avg. Loss: 0.012528378019213464, Val. Loss: 0.012298139319274561


Epoch 18/1000: 100%|██████████| 1682/1682 [04:37<00:00,  6.07batch/s, loss=0.0126]
Validation Epoch 18/1000: 100%|██████████| 281/281 [00:27<00:00, 10.33batch/s, loss=0.0128]


Epoch 18/1000, Avg. Loss: 0.012564352456608722, Val. Loss: 0.0123039803457748


Epoch 19/1000: 100%|██████████| 1682/1682 [04:41<00:00,  5.98batch/s, loss=0.0125]
Validation Epoch 19/1000: 100%|██████████| 281/281 [00:28<00:00,  9.90batch/s, loss=0.0128]


Epoch 19/1000, Avg. Loss: 0.012523948276408823, Val. Loss: 0.012293010227567784


Epoch 20/1000: 100%|██████████| 1682/1682 [04:37<00:00,  6.05batch/s, loss=0.0126]
Validation Epoch 20/1000: 100%|██████████| 281/281 [00:28<00:00,  9.91batch/s, loss=0.0122]


Epoch 20/1000, Avg. Loss: 0.012552360497419495, Val. Loss: 0.01231345408380668


Epoch 21/1000: 100%|██████████| 1682/1682 [04:38<00:00,  6.04batch/s, loss=0.0125]
Validation Epoch 21/1000: 100%|██████████| 281/281 [00:26<00:00, 10.48batch/s, loss=0.0116]


0.012290422025789569 < 0.012292618193792916, epoch: 21 
Epoch 21/1000, Avg. Loss: 0.012530535577456429, Val. Loss: 0.012290422025789569


Epoch 22/1000: 100%|██████████| 1682/1682 [04:36<00:00,  6.07batch/s, loss=0.0125]
Validation Epoch 22/1000: 100%|██████████| 281/281 [00:26<00:00, 10.41batch/s, loss=0.0118]


Epoch 22/1000, Avg. Loss: 0.012515806693774199, Val. Loss: 0.012293305804827036


Epoch 23/1000: 100%|██████████| 1682/1682 [04:37<00:00,  6.07batch/s, loss=0.0125]
Validation Epoch 23/1000: 100%|██████████| 281/281 [00:26<00:00, 10.78batch/s, loss=0.0122]


0.01228850754475042 < 0.012290422025789569, epoch: 23 
Epoch 23/1000, Avg. Loss: 0.01251049694969506, Val. Loss: 0.01228850754475042


Epoch 24/1000: 100%|██████████| 1682/1682 [04:32<00:00,  6.17batch/s, loss=0.0125]
Validation Epoch 24/1000: 100%|██████████| 281/281 [00:26<00:00, 10.53batch/s, loss=0.0126]


0.012282330178228138 < 0.01228850754475042, epoch: 24 
Epoch 24/1000, Avg. Loss: 0.01250447417907284, Val. Loss: 0.012282330178228138


Epoch 25/1000: 100%|██████████| 1682/1682 [04:32<00:00,  6.16batch/s, loss=0.0125]
Validation Epoch 25/1000: 100%|██████████| 281/281 [00:26<00:00, 10.75batch/s, loss=0.0123]


0.012267800926287192 < 0.012282330178228138, epoch: 25 
Epoch 25/1000, Avg. Loss: 0.012504460281215156, Val. Loss: 0.012267800926287192


Epoch 26/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0125]
Validation Epoch 26/1000: 100%|██████████| 281/281 [00:25<00:00, 10.87batch/s, loss=0.0117]


0.012265985659273919 < 0.012267800926287192, epoch: 26 
Epoch 26/1000, Avg. Loss: 0.0124951353231301, Val. Loss: 0.012265985659273919


Epoch 27/1000: 100%|██████████| 1682/1682 [04:37<00:00,  6.05batch/s, loss=0.0125]
Validation Epoch 27/1000: 100%|██████████| 281/281 [00:27<00:00, 10.18batch/s, loss=0.0122]


Epoch 27/1000, Avg. Loss: 0.012505322741289201, Val. Loss: 0.012275534937178112


Epoch 28/1000: 100%|██████████| 1682/1682 [04:36<00:00,  6.08batch/s, loss=0.0125]
Validation Epoch 28/1000: 100%|██████████| 281/281 [00:27<00:00, 10.33batch/s, loss=0.0125]


Epoch 28/1000, Avg. Loss: 0.012494600801003319, Val. Loss: 0.01228830031056315


Epoch 29/1000: 100%|██████████| 1682/1682 [04:37<00:00,  6.07batch/s, loss=0.0125]
Validation Epoch 29/1000: 100%|██████████| 281/281 [00:27<00:00, 10.22batch/s, loss=0.0124]


0.01225878418923697 < 0.012265985659273919, epoch: 29 
Epoch 29/1000, Avg. Loss: 0.012486987839166106, Val. Loss: 0.01225878418923697


Epoch 30/1000: 100%|██████████| 1682/1682 [04:39<00:00,  6.01batch/s, loss=0.0125]
Validation Epoch 30/1000: 100%|██████████| 281/281 [00:25<00:00, 10.87batch/s, loss=0.0122]


Epoch 30/1000, Avg. Loss: 0.012482472581901853, Val. Loss: 0.012282517622639277


Epoch 31/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0125]
Validation Epoch 31/1000: 100%|██████████| 281/281 [00:25<00:00, 10.89batch/s, loss=0.0127]


Epoch 31/1000, Avg. Loss: 0.012480444382260444, Val. Loss: 0.012267522563581153


Epoch 32/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0125]
Validation Epoch 32/1000: 100%|██████████| 281/281 [00:26<00:00, 10.64batch/s, loss=0.0126]


0.012250383944645046 < 0.01225878418923697, epoch: 32 
Epoch 32/1000, Avg. Loss: 0.012478634753285527, Val. Loss: 0.012250383944645046


Epoch 33/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0125]
Validation Epoch 33/1000: 100%|██████████| 281/281 [00:26<00:00, 10.78batch/s, loss=0.0128]


Epoch 33/1000, Avg. Loss: 0.012474935871131462, Val. Loss: 0.012263539674761457


Epoch 34/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0125]
Validation Epoch 34/1000: 100%|██████████| 281/281 [00:25<00:00, 10.88batch/s, loss=0.0123]


0.012245708161772782 < 0.012250383944645046, epoch: 34 
Epoch 34/1000, Avg. Loss: 0.012474915193223288, Val. Loss: 0.012245708161772782


Epoch 35/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0125]
Validation Epoch 35/1000: 100%|██████████| 281/281 [00:25<00:00, 10.85batch/s, loss=0.0124]


Epoch 35/1000, Avg. Loss: 0.01250464963814184, Val. Loss: 0.01225349381632334


Epoch 36/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0125]
Validation Epoch 36/1000: 100%|██████████| 281/281 [00:25<00:00, 10.90batch/s, loss=0.0126]


0.012238296481293504 < 0.012245708161772782, epoch: 36 
Epoch 36/1000, Avg. Loss: 0.012471850735537362, Val. Loss: 0.012238296481293504


Epoch 37/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0125]
Validation Epoch 37/1000: 100%|██████████| 281/281 [00:25<00:00, 10.88batch/s, loss=0.0118]


Epoch 37/1000, Avg. Loss: 0.012479541158274826, Val. Loss: 0.012248791638244726


Epoch 38/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0125]
Validation Epoch 38/1000: 100%|██████████| 281/281 [00:26<00:00, 10.61batch/s, loss=0.0116]


Epoch 38/1000, Avg. Loss: 0.012470336796496591, Val. Loss: 0.012248143344931968


Epoch 39/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.21batch/s, loss=0.0125]
Validation Epoch 39/1000: 100%|██████████| 281/281 [00:26<00:00, 10.70batch/s, loss=0.0115]


0.012233959759496073 < 0.012238296481293504, epoch: 39 
Epoch 39/1000, Avg. Loss: 0.012466550481314112, Val. Loss: 0.012233959759496073


Epoch 40/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0125]
Validation Epoch 40/1000: 100%|██████████| 281/281 [00:25<00:00, 10.83batch/s, loss=0.012] 


Epoch 40/1000, Avg. Loss: 0.012465715532353885, Val. Loss: 0.012242226069295958


Epoch 41/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0125]
Validation Epoch 41/1000: 100%|██████████| 281/281 [00:25<00:00, 10.88batch/s, loss=0.0128]


Epoch 41/1000, Avg. Loss: 0.012462104500163962, Val. Loss: 0.012251047526893879


Epoch 42/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0125]
Validation Epoch 42/1000: 100%|██████████| 281/281 [00:27<00:00, 10.32batch/s, loss=0.0118]


0.012223453751476846 < 0.012233959759496073, epoch: 42 
Epoch 42/1000, Avg. Loss: 0.012460000962506223, Val. Loss: 0.012223453751476846


Epoch 43/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0125]
Validation Epoch 43/1000: 100%|██████████| 281/281 [00:26<00:00, 10.74batch/s, loss=0.0117]


Epoch 43/1000, Avg. Loss: 0.012455436331856726, Val. Loss: 0.01224585263277288


Epoch 44/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.21batch/s, loss=0.0125]
Validation Epoch 44/1000: 100%|██████████| 281/281 [00:25<00:00, 10.84batch/s, loss=0.012] 


Epoch 44/1000, Avg. Loss: 0.01245517075874629, Val. Loss: 0.012228671196198549


Epoch 45/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0125]
Validation Epoch 45/1000: 100%|██████████| 281/281 [00:25<00:00, 10.86batch/s, loss=0.0116]


Epoch 45/1000, Avg. Loss: 0.012450681130017523, Val. Loss: 0.012225683089146835


Epoch 46/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 46/1000: 100%|██████████| 281/281 [00:26<00:00, 10.47batch/s, loss=0.0123]


Epoch 46/1000, Avg. Loss: 0.012447136530694724, Val. Loss: 0.012271886714651278


Epoch 47/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0125]
Validation Epoch 47/1000: 100%|██████████| 281/281 [00:25<00:00, 10.83batch/s, loss=0.0126]


Epoch 47/1000, Avg. Loss: 0.01245223550296496, Val. Loss: 0.012227872270573713


Epoch 48/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0124]
Validation Epoch 48/1000: 100%|██████████| 281/281 [00:25<00:00, 10.90batch/s, loss=0.0122]


0.012210643508892467 < 0.012223453751476846, epoch: 48 
Epoch 48/1000, Avg. Loss: 0.01244552106420453, Val. Loss: 0.012210643508892467


Epoch 49/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0124]
Validation Epoch 49/1000: 100%|██████████| 281/281 [00:25<00:00, 10.90batch/s, loss=0.013] 


Epoch 49/1000, Avg. Loss: 0.012442945716426432, Val. Loss: 0.012236524853868629


Epoch 50/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0124]
Validation Epoch 50/1000: 100%|██████████| 281/281 [00:27<00:00, 10.40batch/s, loss=0.0125]


Epoch 50/1000, Avg. Loss: 0.012439376129198479, Val. Loss: 0.0122237100905583


Epoch 51/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0124]
Validation Epoch 51/1000: 100%|██████████| 281/281 [00:26<00:00, 10.79batch/s, loss=0.012] 


Epoch 51/1000, Avg. Loss: 0.012436550689748223, Val. Loss: 0.012215736332047877


Epoch 52/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0124]
Validation Epoch 52/1000: 100%|██████████| 281/281 [00:26<00:00, 10.81batch/s, loss=0.0115]


0.012205947531897614 < 0.012210643508892467, epoch: 52 
Epoch 52/1000, Avg. Loss: 0.012433177008604749, Val. Loss: 0.012205947531897614


Epoch 53/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0124]
Validation Epoch 53/1000: 100%|██████████| 281/281 [00:25<00:00, 10.86batch/s, loss=0.0113]


Epoch 53/1000, Avg. Loss: 0.012438012320242654, Val. Loss: 0.012279720108598152


Epoch 54/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0124]
Validation Epoch 54/1000: 100%|██████████| 281/281 [00:26<00:00, 10.69batch/s, loss=0.0121]


Epoch 54/1000, Avg. Loss: 0.01243888231672072, Val. Loss: 0.012218968674301889


Epoch 55/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 55/1000: 100%|██████████| 281/281 [00:26<00:00, 10.68batch/s, loss=0.0118]


Epoch 55/1000, Avg. Loss: 0.012432998527443827, Val. Loss: 0.012208329533554906


Epoch 56/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 56/1000: 100%|██████████| 281/281 [00:25<00:00, 10.86batch/s, loss=0.0114]


0.01220454051891786 < 0.012205947531897614, epoch: 56 
Epoch 56/1000, Avg. Loss: 0.012435250590011637, Val. Loss: 0.01220454051891786


Epoch 57/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 57/1000: 100%|██████████| 281/281 [00:25<00:00, 10.82batch/s, loss=0.0119]


Epoch 57/1000, Avg. Loss: 0.012430001386883143, Val. Loss: 0.012223588451875699


Epoch 58/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 58/1000: 100%|██████████| 281/281 [00:26<00:00, 10.62batch/s, loss=0.0115]


Epoch 58/1000, Avg. Loss: 0.012427704257051173, Val. Loss: 0.01222451924415988


Epoch 59/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 59/1000: 100%|██████████| 281/281 [00:26<00:00, 10.73batch/s, loss=0.0125]


Epoch 59/1000, Avg. Loss: 0.01242100279811464, Val. Loss: 0.01221617956204983


Epoch 60/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0124]
Validation Epoch 60/1000: 100%|██████████| 281/281 [00:26<00:00, 10.56batch/s, loss=0.0119]


0.012194960520625964 < 0.01220454051891786, epoch: 60 
Epoch 60/1000, Avg. Loss: 0.012420265131389698, Val. Loss: 0.012194960520625964


Epoch 61/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 61/1000: 100%|██████████| 281/281 [00:26<00:00, 10.61batch/s, loss=0.0118]


Epoch 61/1000, Avg. Loss: 0.01243225454577598, Val. Loss: 0.012204971144579059


Epoch 62/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 62/1000: 100%|██████████| 281/281 [00:26<00:00, 10.69batch/s, loss=0.0119]


Epoch 62/1000, Avg. Loss: 0.012419795951895709, Val. Loss: 0.012205513936609983


Epoch 63/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 63/1000: 100%|██████████| 281/281 [00:25<00:00, 10.84batch/s, loss=0.0105]


Epoch 63/1000, Avg. Loss: 0.012416784068531889, Val. Loss: 0.01221347789137601


Epoch 64/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 64/1000: 100%|██████████| 281/281 [00:26<00:00, 10.46batch/s, loss=0.0131]


0.012190266542163184 < 0.012194960520625964, epoch: 64 
Epoch 64/1000, Avg. Loss: 0.012416197854126322, Val. Loss: 0.012190266542163184


Epoch 65/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0124]
Validation Epoch 65/1000: 100%|██████████| 281/281 [00:25<00:00, 10.81batch/s, loss=0.0127]


Epoch 65/1000, Avg. Loss: 0.01241142260172049, Val. Loss: 0.01219364742871069


Epoch 66/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0124]
Validation Epoch 66/1000: 100%|██████████| 281/281 [00:25<00:00, 10.89batch/s, loss=0.0118]


0.012188254676600155 < 0.012190266542163184, epoch: 66 
Epoch 66/1000, Avg. Loss: 0.012412048003101427, Val. Loss: 0.012188254676600155


Epoch 67/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.18batch/s, loss=0.0124]
Validation Epoch 67/1000: 100%|██████████| 281/281 [00:25<00:00, 10.94batch/s, loss=0.0122]


Epoch 67/1000, Avg. Loss: 0.012410368500709286, Val. Loss: 0.012206497257475963


Epoch 68/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0124]
Validation Epoch 68/1000: 100%|██████████| 281/281 [00:26<00:00, 10.72batch/s, loss=0.0117]


Epoch 68/1000, Avg. Loss: 0.012410075943053265, Val. Loss: 0.01220995009449241


Epoch 69/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.19batch/s, loss=0.0124]
Validation Epoch 69/1000: 100%|██████████| 281/281 [00:26<00:00, 10.76batch/s, loss=0.0128]


Epoch 69/1000, Avg. Loss: 0.012409784138982275, Val. Loss: 0.012222449742655312


Epoch 70/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0124]
Validation Epoch 70/1000: 100%|██████████| 281/281 [00:25<00:00, 10.87batch/s, loss=0.0125]


Epoch 70/1000, Avg. Loss: 0.012405869023881046, Val. Loss: 0.012202036089770947


Epoch 71/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0124]
Validation Epoch 71/1000: 100%|██████████| 281/281 [00:25<00:00, 10.87batch/s, loss=0.0124]


Epoch 71/1000, Avg. Loss: 0.01240978153991678, Val. Loss: 0.01221894753891378


Epoch 72/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0124]
Validation Epoch 72/1000: 100%|██████████| 281/281 [00:27<00:00, 10.38batch/s, loss=0.0113]


Epoch 72/1000, Avg. Loss: 0.012413969626107588, Val. Loss: 0.012212173900570309


Epoch 73/1000: 100%|██████████| 1682/1682 [04:31<00:00,  6.20batch/s, loss=0.0124]
Validation Epoch 73/1000: 100%|██████████| 281/281 [00:26<00:00, 10.76batch/s, loss=0.0126]


Epoch 73/1000, Avg. Loss: 0.012400113463105012, Val. Loss: 0.012206023144684865


Epoch 74/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 74/1000: 100%|██████████| 281/281 [00:25<00:00, 10.86batch/s, loss=0.0114]


Epoch 74/1000, Avg. Loss: 0.012396652824007303, Val. Loss: 0.012189609505686896


Epoch 75/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 75/1000: 100%|██████████| 281/281 [00:27<00:00, 10.27batch/s, loss=0.012] 


Epoch 75/1000, Avg. Loss: 0.012394623320957249, Val. Loss: 0.012216815270907726


Epoch 76/1000: 100%|██████████| 1682/1682 [04:30<00:00,  6.21batch/s, loss=0.0124]
Validation Epoch 76/1000: 100%|██████████| 281/281 [00:25<00:00, 10.83batch/s, loss=0.0128]

Validation loss did not improve for 10 epochs. Early stopping...





In [21]:
model = MobileNetV2Autoencoder()
out = model(torch.rand(12,1,32,32,32))

Input shape: torch.Size([12, 1, 32, 32, 32])
Encoder output after layer 0: torch.Size([12, 32, 32, 16, 16])
Encoder output after layer 1: torch.Size([12, 16, 32, 16, 16])
Encoder output after layer 2: torch.Size([12, 24, 16, 8, 8])
Encoder output after layer 3: torch.Size([12, 24, 16, 8, 8])
Encoder output after layer 4: torch.Size([12, 32, 8, 4, 4])
Encoder output after layer 5: torch.Size([12, 32, 8, 4, 4])
Encoder output after layer 6: torch.Size([12, 32, 8, 4, 4])
Encoder output after layer 7: torch.Size([12, 64, 4, 2, 2])
Encoder output after layer 8: torch.Size([12, 64, 4, 2, 2])
Encoder output after layer 9: torch.Size([12, 64, 4, 2, 2])
Encoder output after layer 10: torch.Size([12, 64, 4, 2, 2])
Encoder output after layer 11: torch.Size([12, 96, 4, 2, 2])
Encoder output after layer 12: torch.Size([12, 96, 4, 2, 2])
Encoder output after layer 13: torch.Size([12, 96, 4, 2, 2])
Encoder output after layer 14: torch.Size([12, 160, 2, 1, 1])
Encoder output after layer 15: torch.Size