In [5]:
from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from tomotwin.modules.networks.torchmodel import TorchModel

class AutoEncoder(TorchModel):
    """
    3D convolutional autoencoder whose decoder uses ConvTranspose3d layers.

    Encoder: four double-conv stages (1 -> 64 -> 128 -> 256 -> 512 channels)
    with 2x max-pooling after the first three. Bottleneck (``headnet``): a
    stack of 3x3x3 convolutions squeezing 512 channels down to 1 and
    transposed convolutions expanding back to 512. Decoder: three transposed
    double-conv stages (512 -> 256 -> 128 -> 64), each followed by 2x
    nearest-neighbour upsampling, and a final head projecting to 1 channel.

    The input is padded by 1 voxel before and 2 voxels after on every spatial
    axis, so a 37^3 volume is processed (and reconstructed) at 40^3.
    """

    NORM_BATCHNORM = "BatchNorm"
    NORM_GROUPNORM = "GroupNorm"

    class Model(nn.Module):
        def make_norm(self, norm: Dict, num_channels: int) -> nn.Module:
            """
            Instantiate the configured normalization layer for ``num_channels``.

            :param norm: dict with ``"module"`` (nn.BatchNorm3d or nn.GroupNorm)
                and ``"kwargs"`` (extra constructor keyword arguments).
            :param num_channels: number of feature channels to normalize.
            :return: a new normalization module.
            :raises ValueError: if ``norm["module"]`` is missing or unsupported.
            """
            module = norm.get("module")
            if module == nn.BatchNorm3d:
                channel_key = "num_features"
            elif module == nn.GroupNorm:
                channel_key = "num_channels"
            else:
                raise ValueError("Not supported norm", module)
            # Work on a copy so the caller's kwargs dict (shared by every layer,
            # and possibly by other models) is never mutated.
            kwargs = dict(norm["kwargs"])
            kwargs[channel_key] = num_channels
            return module(**kwargs)

        def __init__(
            self,
            output_channels: int,
            norm: Dict,
            dropout: float = 0.5,
            repeat_layers=0,
            gem_pooling=None,
        ):
            """
            :param output_channels: not used by this variant (kept for interface compatibility).
            :param norm: normalization spec as produced by ``AutoEncoder.setup_norm``.
            :param dropout: dropout probability at the start of the bottleneck.
            :param repeat_layers: not used (kept for interface compatibility).
            :param gem_pooling: optional pooling module stored as ``adap_max_pool``
                (default ``nn.AdaptiveAvgPool3d((2, 2, 2))``); not used in ``forward``.
            """
            super().__init__()
            norm_func = self.make_norm(norm, 64)
            self.en_layer0 = self._make_conv_layer(1, 64, norm=norm_func)

            norm_func = self.make_norm(norm, 128)
            self.en_layer1 = self._make_conv_layer(64, 128, norm=norm_func)

            norm_func = self.make_norm(norm, 256)
            self.en_layer2 = self._make_conv_layer(128, 256, norm=norm_func)

            norm_func = self.make_norm(norm, 512)
            self.en_layer3 = self._make_conv_layer(256, 512, norm=norm_func)

            self.max_pooling = nn.MaxPool3d((2, 2, 2))
            if gem_pooling:
                self.adap_max_pool = gem_pooling
            else:
                self.adap_max_pool = nn.AdaptiveAvgPool3d((2, 2, 2))

            self.headnet = self._make_headnet(512, 256, 64, 1, dropout=dropout)

            norm_func = self.make_norm(norm, 256)
            self.de_layer0 = self._make_deconv_layer(512, 256, norm=norm_func)

            norm_func = self.make_norm(norm, 128)
            self.de_layer1 = self._make_deconv_layer(256, 128, norm=norm_func)

            norm_func = self.make_norm(norm, 64)
            self.de_layer2 = self._make_deconv_layer(128, 64, norm=norm_func)

            # Output head: 64 -> 1 channel, no normalization and a linear final activation.
            self.de_layer4 = nn.Sequential(
                nn.ConvTranspose3d(64, 1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.ConvTranspose3d(1, 1, kernel_size=3, padding=1),
                nn.Identity(),
            )

            self.up_sampling = nn.Upsample(scale_factor=2)

        @staticmethod
        def _make_conv_layer(in_c: int, out_c: int, norm: nn.Module, padding: int = 1, kernel_size: int = 3):
            """
            Two (Conv3d -> norm -> LeakyReLU) stages mapping ``in_c`` to ``out_c`` channels.

            ``norm`` follows the first convolution and an independent deep copy of it
            follows the second, so the stages do not share affine parameters or
            running statistics. State-dict keys are the same as when one instance
            was reused.
            """
            import copy

            return nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=kernel_size, padding=padding),
                norm,
                nn.LeakyReLU(),
                nn.Conv3d(out_c, out_c, kernel_size=kernel_size, padding=padding),
                copy.deepcopy(norm),
                nn.LeakyReLU(),
            )

        @staticmethod
        def _make_deconv_layer(in_c: int, out_c: int, norm: nn.Module, padding: int = 1, kernel_size: int = 3):
            """
            Two (ConvTranspose3d -> norm -> LeakyReLU) stages mapping ``in_c`` to ``out_c``.

            As in ``_make_conv_layer``, the second stage gets its own deep copy of
            ``norm`` instead of sharing the first stage's instance.
            """
            import copy

            return nn.Sequential(
                nn.ConvTranspose3d(in_c, out_c, kernel_size=kernel_size, padding=padding),
                norm,
                nn.LeakyReLU(),
                nn.ConvTranspose3d(out_c, out_c, kernel_size=kernel_size, padding=padding),
                copy.deepcopy(norm),
                nn.LeakyReLU(),
            )

        @staticmethod
        def _make_headnet(
            in_c1: int, in_c2: int, out_c1: int, out_head: int, dropout: float
        ) -> nn.Sequential:
            """
            Fully convolutional bottleneck.

            It applies Dropout, then 3x3x3 convs ``in_c1 -> in_c2 -> out_c1 -> out_head``.
            Transposed convs then map back ``out_head -> out_c1 -> in_c2 -> in_c1``.
            Every conv is followed by LeakyReLU. Stride 1 with padding 1 keeps the
            spatial size unchanged.
            """
            return nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Conv3d(in_c1, in_c2, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.Conv3d(in_c2, out_c1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.Conv3d(out_c1, out_head, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.ConvTranspose3d(out_head, out_c1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.ConvTranspose3d(out_c1, in_c2, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.ConvTranspose3d(in_c2, in_c1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
            )

        def forward(self, inputtensor):
            """
            Reconstruct ``inputtensor``.

            :param inputtensor: 5D tensor with 1 channel, (batch, 1, D, H, W).
                It is zero-padded by (1, 2) voxels on each spatial axis first.
            :return: reconstruction with 1 channel; for a 37^3 input the
                output is 40^3.
            """
            inputtensor = F.pad(inputtensor, (1, 2, 1, 2, 1, 2))

            out = self.en_layer0(inputtensor)
            out = self.max_pooling(out)
            out = self.en_layer1(out)
            out = self.max_pooling(out)
            out = self.en_layer2(out)
            out = self.max_pooling(out)
            out = self.en_layer3(out)

            out = self.headnet(out)

            out = self.de_layer0(out)
            out = self.up_sampling(out)
            out = self.de_layer1(out)
            out = self.up_sampling(out)
            out = self.de_layer2(out)
            out = self.up_sampling(out)
            out = self.de_layer4(out)
            return out

    def setup_norm(self, norm_name: str, norm_kwargs: dict) -> Dict:
        """
        Build the normalization spec consumed by ``Model.make_norm``.

        :param norm_name: ``NORM_BATCHNORM`` or ``NORM_GROUPNORM``.
        :param norm_kwargs: extra constructor kwargs for the norm layers.
        :return: ``{"module": <norm class>, "kwargs": norm_kwargs}``.
        :raises ValueError: for an unknown ``norm_name``.
        """
        modules = {
            AutoEncoder.NORM_BATCHNORM: nn.BatchNorm3d,
            AutoEncoder.NORM_GROUPNORM: nn.GroupNorm,
        }
        if norm_name not in modules:
            raise ValueError(
                f"Unsupported norm_name {norm_name!r}; expected one of {sorted(modules)}"
            )
        return {"module": modules[norm_name], "kwargs": norm_kwargs}

    def setup_gem_pooling(self, gem_pooling_p: float) -> Union[None, nn.Module]:
        """Return a GeneralizedMeanPooling module with norm ``gem_pooling_p`` if it is > 0, else None."""
        gem_pooling = None
        if gem_pooling_p > 0:
            from tomotwin.modules.networks.GeneralizedMeanPooling import GeneralizedMeanPooling
            gem_pooling = GeneralizedMeanPooling(norm=gem_pooling_p, output_size=(2, 2, 2))
        return gem_pooling

    def __init__(
        self,
        norm_name: str,
        norm_kwargs: Union[Dict, None] = None,
        output_channels: int = 128,
        dropout: float = 0.5,
        gem_pooling_p: float = 0,
        repeat_layers=0,
    ):
        """
        :param norm_name: ``NORM_BATCHNORM`` or ``NORM_GROUPNORM``.
        :param norm_kwargs: extra kwargs for every normalization layer (the channel
            count is filled in per layer). ``None`` means no extra kwargs.
        :param output_channels: forwarded to ``Model`` (unused by this variant).
        :param dropout: bottleneck dropout probability.
        :param gem_pooling_p: GeM pooling norm; 0 disables GeM pooling.
        :param repeat_layers: forwarded to ``Model`` (unused).
        """
        super().__init__()
        if norm_kwargs is None:
            # Fresh dict per instance instead of a shared mutable default argument.
            norm_kwargs = {}
        norm = self.setup_norm(norm_name, norm_kwargs)
        gem_pooling = self.setup_gem_pooling(gem_pooling_p)

        self.model = self.Model(
            output_channels=output_channels,
            dropout=dropout,
            repeat_layers=repeat_layers,
            norm=norm,
            gem_pooling=gem_pooling,
        )

    def init_weights(self):
        """Kaiming-normal initialise the weights of every nn.Conv3d (ConvTranspose3d layers are left as is)."""
        def _init_weights(model):
            if isinstance(model, nn.Conv3d):
                torch.nn.init.kaiming_normal_(model.weight)

        self.model.apply(_init_weights)

    def get_model(self) -> nn.Module:
        """Return the wrapped ``nn.Module``."""
        return self.model

In [6]:
from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from tomotwin.modules.networks.torchmodel import TorchModel

class AutoEncoder(TorchModel):
    """
    3D convolutional autoencoder with a fully connected bottleneck.

    Encoder: five double-conv stages (1 -> 64 -> 128 -> 256 -> 512 -> 1024
    channels) with 2x max-pooling between them. The encoder output must be
    1024 x 2 x 2 x 2 (inputs of 32..47 voxels per side). It is flattened,
    passed through a Linear bottleneck of width ``output_channels`` and
    reshaped back. Decoder: four double-conv stages (1024 -> 512 -> 256 ->
    128 -> 64), each followed by 2x nearest-neighbour upsampling, and a final
    head projecting to 1 channel.

    No input padding is applied, so a 37^3 input yields a 32^3 reconstruction.
    NOTE(review): this cell redefines ``AutoEncoder``. ``train_autoencoder``
    pads targets to 40^3, which matches the padded ConvTranspose3d variant
    but not this one. Confirm which definition is meant to be in scope.
    """

    NORM_BATCHNORM = "BatchNorm"
    NORM_GROUPNORM = "GroupNorm"

    class Model(nn.Module):
        def make_norm(self, norm: Dict, num_channels: int) -> nn.Module:
            """
            Instantiate the configured normalization layer for ``num_channels``.

            :param norm: dict with ``"module"`` (nn.BatchNorm3d or nn.GroupNorm)
                and ``"kwargs"`` (extra constructor keyword arguments).
            :param num_channels: number of feature channels to normalize.
            :return: a new normalization module.
            :raises ValueError: if ``norm["module"]`` is missing or unsupported.
            """
            module = norm.get("module")
            if module == nn.BatchNorm3d:
                channel_key = "num_features"
            elif module == nn.GroupNorm:
                channel_key = "num_channels"
            else:
                raise ValueError("Not supported norm", module)
            # Copy so the caller's (shared) kwargs dict is never mutated.
            kwargs = dict(norm["kwargs"])
            kwargs[channel_key] = num_channels
            return module(**kwargs)

        def __init__(
            self,
            output_channels: int,
            norm: Dict,
            dropout: float = 0.5,
            repeat_layers=0,
            gem_pooling=None,
        ):
            """
            :param output_channels: width of the Linear bottleneck.
            :param norm: normalization spec as produced by ``AutoEncoder.setup_norm``.
            :param dropout: dropout probability at the start of the bottleneck.
            :param repeat_layers: not used (kept for interface compatibility).
            :param gem_pooling: optional pooling module stored as ``adap_max_pool``
                (default ``nn.AdaptiveAvgPool3d((2, 2, 2))``); not used in ``forward``.
            """
            super().__init__()
            norm_func = self.make_norm(norm, 64)
            self.en_layer0 = self._make_conv_layer(1, 64, norm=norm_func)

            norm_func = self.make_norm(norm, 128)
            self.en_layer1 = self._make_conv_layer(64, 128, norm=norm_func)

            norm_func = self.make_norm(norm, 256)
            self.en_layer2 = self._make_conv_layer(128, 256, norm=norm_func)

            norm_func = self.make_norm(norm, 512)
            self.en_layer3 = self._make_conv_layer(256, 512, norm=norm_func)

            norm_func = self.make_norm(norm, 1024)
            self.en_layer4 = self._make_conv_layer(512, 1024, norm=norm_func)

            self.max_pooling = nn.MaxPool3d((2, 2, 2))
            if gem_pooling:
                self.adap_max_pool = gem_pooling
            else:
                self.adap_max_pool = nn.AdaptiveAvgPool3d((2, 2, 2))

            # Flattened 1024 x 2 x 2 x 2 encoder output -> bottleneck -> same size.
            self.headnet = self._make_headnet(
                2 * 2 * 2 * 1024, 2048, output_channels, dropout=dropout
            )

            norm_func = self.make_norm(norm, 512)
            self.de_layer0 = self._make_conv_layer(1024, 512, norm=norm_func)

            norm_func = self.make_norm(norm, 256)
            self.de_layer1 = self._make_conv_layer(512, 256, norm=norm_func)

            norm_func = self.make_norm(norm, 128)
            self.de_layer2 = self._make_conv_layer(256, 128, norm=norm_func)

            norm_func = self.make_norm(norm, 64)
            self.de_layer3 = self._make_conv_layer(128, 64, norm=norm_func)

            # Output head: 64 -> 1 channel without normalization.
            self.de_layer4 = nn.Sequential(
                nn.Conv3d(64, 1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
                nn.Conv3d(1, 1, kernel_size=3, padding=1),
                nn.LeakyReLU(),
            )

            self.up_sampling = nn.Upsample(scale_factor=2)

        @staticmethod
        def _make_conv_layer(in_c: int, out_c: int, norm: nn.Module, padding: int = 1, kernel_size: int = 3):
            """
            Two (Conv3d -> norm -> LeakyReLU) stages mapping ``in_c`` to ``out_c`` channels.

            The second stage gets an independent deep copy of ``norm``, so the
            stages do not share affine parameters or running statistics
            (state-dict keys are unchanged).
            """
            import copy

            return nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=kernel_size, padding=padding),
                norm,
                nn.LeakyReLU(),
                nn.Conv3d(out_c, out_c, kernel_size=kernel_size, padding=padding),
                copy.deepcopy(norm),
                nn.LeakyReLU(),
            )

        @staticmethod
        def _make_headnet(
            in_c1: int, out_c1: int, out_head: int, dropout: float
        ) -> nn.Sequential:
            """
            Linear bottleneck on the flattened features.

            It applies Dropout, then ``in_c1 -> out_c1 -> out_c1 -> out_head``
            followed by ``out_head -> out_c1 -> in_c1``. There is a LeakyReLU
            between layers and no activation on the last one.
            """
            return nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(in_c1, out_c1),
                nn.LeakyReLU(),
                nn.Linear(out_c1, out_c1),
                nn.LeakyReLU(),
                nn.Linear(out_c1, out_head),
                nn.LeakyReLU(),
                nn.Linear(out_head, out_c1),
                nn.LeakyReLU(),
                nn.Linear(out_c1, in_c1),
            )

        def forward(self, inputtensor):
            """
            Reconstruct ``inputtensor``.

            :param inputtensor: 5D tensor with 1 channel, (batch, 1, D, H, W).
                The encoder output must be 1024 x 2 x 2 x 2, i.e. 32..47 voxels
                per side.
            :return: 1-channel reconstruction at 16x the encoder output size
                (32^3 for a 37^3 input).
            """
            out = self.en_layer0(inputtensor)
            out = self.max_pooling(out)
            out = self.en_layer1(out)
            out = self.max_pooling(out)
            out = self.en_layer2(out)
            out = self.max_pooling(out)
            out = self.en_layer3(out)
            out = self.max_pooling(out)
            out = self.en_layer4(out)

            out = out.reshape(out.size(0), -1)
            out = self.headnet(out)
            out = out.reshape(-1, 1024, 2, 2, 2)

            out = self.de_layer0(out)
            out = self.up_sampling(out)
            out = self.de_layer1(out)
            out = self.up_sampling(out)
            out = self.de_layer2(out)
            out = self.up_sampling(out)
            out = self.de_layer3(out)
            out = self.up_sampling(out)
            out = self.de_layer4(out)
            return out

    def setup_norm(self, norm_name: str, norm_kwargs: dict) -> Dict:
        """
        Build the normalization spec consumed by ``Model.make_norm``.

        :param norm_name: ``NORM_BATCHNORM`` or ``NORM_GROUPNORM``.
        :param norm_kwargs: extra constructor kwargs for the norm layers.
        :return: ``{"module": <norm class>, "kwargs": norm_kwargs}``.
        :raises ValueError: for an unknown ``norm_name``.
        """
        modules = {
            AutoEncoder.NORM_BATCHNORM: nn.BatchNorm3d,
            AutoEncoder.NORM_GROUPNORM: nn.GroupNorm,
        }
        if norm_name not in modules:
            raise ValueError(
                f"Unsupported norm_name {norm_name!r}; expected one of {sorted(modules)}"
            )
        return {"module": modules[norm_name], "kwargs": norm_kwargs}

    def setup_gem_pooling(self, gem_pooling_p: float) -> Union[None, nn.Module]:
        """Return a GeneralizedMeanPooling module with norm ``gem_pooling_p`` if it is > 0, else None."""
        gem_pooling = None
        if gem_pooling_p > 0:
            from tomotwin.modules.networks.GeneralizedMeanPooling import GeneralizedMeanPooling
            gem_pooling = GeneralizedMeanPooling(norm=gem_pooling_p, output_size=(2, 2, 2))
        return gem_pooling

    def __init__(
        self,
        norm_name: str,
        norm_kwargs: Union[Dict, None] = None,
        output_channels: int = 128,
        dropout: float = 0.5,
        gem_pooling_p: float = 0,
        repeat_layers=0,
    ):
        """
        :param norm_name: ``NORM_BATCHNORM`` or ``NORM_GROUPNORM``.
        :param norm_kwargs: extra kwargs for every normalization layer (the channel
            count is filled in per layer). ``None`` means no extra kwargs.
        :param output_channels: width of the Linear bottleneck.
        :param dropout: bottleneck dropout probability.
        :param gem_pooling_p: GeM pooling norm; 0 disables GeM pooling.
        :param repeat_layers: forwarded to ``Model`` (unused).
        """
        super().__init__()
        if norm_kwargs is None:
            # Fresh dict per instance instead of a shared mutable default argument.
            norm_kwargs = {}
        norm = self.setup_norm(norm_name, norm_kwargs)
        gem_pooling = self.setup_gem_pooling(gem_pooling_p)

        self.model = self.Model(
            output_channels=output_channels,
            dropout=dropout,
            repeat_layers=repeat_layers,
            norm=norm,
            gem_pooling=gem_pooling,
        )

    def init_weights(self):
        """Kaiming-normal initialise the weights of every nn.Conv3d (Linear layers keep their defaults)."""
        def _init_weights(model):
            if isinstance(model, nn.Conv3d):
                torch.nn.init.kaiming_normal_(model.weight)

        self.model.apply(_init_weights)

    def get_model(self) -> nn.Module:
        """Return the wrapped ``nn.Module``."""
        return self.model

In [2]:
# Normalization for every AutoEncoder stage: GroupNorm with 64 groups.
# "num_channels" is only a placeholder; make_norm sets it per layer.
norm_name = "GroupNorm"
norm_kwargs = dict(num_groups=64, num_channels=1024)

In [62]:
M = AutoEncoder(norm_name, norm_kwargs, output_channels=32)

In [23]:
Model: nn.Module = M.get_model()  # the underlying torch module wrapped by AutoEncoder

In [24]:
# Shape smoke test only: no autograd graph is needed to inspect the output size.
with torch.no_grad():
    out = Model(torch.rand(12, 1, 37, 37, 37))

torch.Size([12, 64, 40, 40, 40])
torch.Size([12, 128, 10, 10, 10])
torch.Size([12, 512, 5, 5, 5])
torch.Size([12, 512, 5, 5, 5])
torch.Size([12, 256, 10, 10, 10])
torch.Size([12, 128, 20, 20, 20])
torch.Size([12, 64, 20, 20, 20])
torch.Size([12, 64, 40, 40, 40])
torch.Size([12, 1, 40, 40, 40])


In [19]:
out.size()  # torch.Size of the reconstruction from the smoke test

torch.Size([12, 1, 40, 40, 40])

In [7]:
from tomotwin.modules.training.mrctriplethandler import MRCTripletHandler
import os
import numpy as np
import torch
from torch.utils.data import Dataset


class MRCVolumeDataset(Dataset):
    """
    Autoencoder dataset over every ``.mrc`` file in ``root_dir/<round>/<tomo>/``.

    Each item is ``{'input': volume, 'target': volume}``, where ``volume`` is
    the result of ``MRCTripletHandler().read_mrc_and_norm(path)`` and the same
    object is used for both keys. Paths are collected once at construction and
    sorted, so the index -> file mapping is reproducible across runs and
    filesystems (``os.listdir`` order is arbitrary).
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.file_paths = self._get_file_paths()
        self.reader = MRCTripletHandler()

    def _get_file_paths(self):
        """Return sorted paths of ``.mrc`` entries exactly two directory levels below ``root_dir``."""
        file_paths = []
        for round_dir in sorted(os.listdir(self.root_dir)):
            round_path = os.path.join(self.root_dir, round_dir)
            if not os.path.isdir(round_path):
                continue
            for tomo_dir in sorted(os.listdir(round_path)):
                tomo_path = os.path.join(round_path, tomo_dir)
                if not os.path.isdir(tomo_path):
                    continue
                file_paths.extend(
                    os.path.join(tomo_path, name)
                    for name in sorted(os.listdir(tomo_path))
                    if name.endswith('.mrc')
                )
        return file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        mrc_path = self.file_paths[idx]
        volume = self.reader.read_mrc_and_norm(mrc_path)
        return {'input': volume, 'target': volume}

In [8]:
from torch.utils.data import DataLoader
# NOTE(review): this loader is the one passed to train_autoencoder, yet it reads
# the "validation" split. Confirm that training on it is intended.
root_dir = '/home/yousef.metwally/projects/data/tomotwin_training_data/validation'
batch_size, shuffle, num_workers = 32, True, 4

dataset = MRCVolumeDataset(root_dir)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
)


In [9]:
import torch.nn.functional as F

def loss_function(recon_x, x):
    """Mean-squared reconstruction error between ``recon_x`` and ``x`` (mean over all elements)."""
    return F.mse_loss(recon_x, x)

In [13]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

def train_autoencoder(
    model,
    data_loader,
    optimizer,
    num_epochs=10,
    device='cuda',
    log_dir='/home/yousef.metwally/projects/AutoEncoder',
    weights_dir='/home/yousef.metwally/projects/AutoEncoder/weights/run1',
    input_size=37,
    target_padding=(1, 2, 1, 2, 1, 2),
):
    """
    Train ``model`` to reconstruct its input with ``loss_function``.

    Each batch's ``'input'`` is reshaped to (-1, 1, input_size, input_size,
    input_size). Its ``'target'`` is zero-padded with ``target_padding`` (in
    ``F.pad`` order, last axis first) and reshaped to the matching padded
    size, so the model output must have that padded size (40^3 by default).
    The running mean loss of the epoch is logged to TensorBoard under
    ``Loss/train`` at every step. The ``state_dict`` is saved after every
    epoch as ``weights_dir/model_weights_epoch_{epoch}.pt``.

    :param model: network to train; moved to ``device`` and put in train mode.
    :param data_loader: yields dicts with ``'input'`` and ``'target'`` tensors.
    :param optimizer: optimizer over ``model``'s parameters.
    :param num_epochs: number of passes over ``data_loader``.
    :param device: device for the model and batches.
    :param log_dir: TensorBoard log directory.
    :param weights_dir: checkpoint directory (created if missing).
    :param input_size: edge length of the cubic input volumes.
    :param target_padding: ``F.pad`` padding applied to the targets.
    """
    import os

    # F.pad lists padding for the last axis first: (W_lo, W_hi, H_lo, H_hi, D_lo, D_hi).
    target_shape = (
        input_size + target_padding[4] + target_padding[5],
        input_size + target_padding[2] + target_padding[3],
        input_size + target_padding[0] + target_padding[1],
    )
    os.makedirs(weights_dir, exist_ok=True)

    writer = SummaryWriter(log_dir)
    try:
        model.to(device)
        model.train()

        for epoch in range(num_epochs):
            total_loss = 0.0
            # Defined up front so the epoch summary works even for an empty loader.
            avg_loss = float('nan')
            with tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as progress_bar:
                for batch_idx, data in enumerate(progress_bar):
                    input_data = data['input'].to(device)
                    input_data = input_data.reshape(-1, 1, input_size, input_size, input_size)
                    target_data = data['target'].to(device)
                    target_data = F.pad(target_data, target_padding)
                    target_data = target_data.reshape(-1, 1, *target_shape)

                    optimizer.zero_grad()
                    recon_data = model(input_data)
                    loss = loss_function(recon_data, target_data)
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    avg_loss = total_loss / (batch_idx + 1)
                    writer.add_scalar('Loss/train', avg_loss, epoch * len(data_loader) + batch_idx)
                    progress_bar.set_postfix(loss=avg_loss)

            print(f"Epoch {epoch+1}/{num_epochs}, Avg. Loss: {avg_loss:.4f}")
            torch.save(
                model.state_dict(),
                os.path.join(weights_dir, f"model_weights_epoch_{epoch+1}.pt"),
            )
    finally:
        # Flush/close the event file even if training is interrupted.
        writer.close()


In [None]:
# Restrict training to GPUs 1 and 2.
# NOTE(review): this only takes effect if CUDA has not been initialised yet in this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
num_epochs = 200

M = AutoEncoder(norm_name, norm_kwargs, 32)
model = nn.DataParallel(M.get_model())
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_autoencoder(model, data_loader, optimizer, num_epochs=num_epochs, device=device)




In [10]:
class testDataset(Dataset):
    """
    Flat dataset over the ``.mrc`` files directly inside ``root_dir``.

    Paths are sorted by name, so item ``i`` (and the ``prediction_{i}`` /
    ``target_{i}`` files written from a non-shuffled loader) map to the same
    input file on every run. Each item is ``{'input': volume, 'target': volume}``
    with ``volume`` from ``MRCTripletHandler().read_mrc_and_norm(path)``.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.file_paths = self._get_file_paths()
        self.reader = MRCTripletHandler()

    def _get_file_paths(self):
        """Return sorted paths of the ``.mrc`` entries in ``root_dir``."""
        return [
            os.path.join(self.root_dir, name)
            for name in sorted(os.listdir(self.root_dir))
            if name.endswith('.mrc')
        ]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        mrc_path = self.file_paths[idx]
        volume = self.reader.read_mrc_and_norm(mrc_path)
        return {'input': volume, 'target': volume}

In [None]:
test_dir = '/home/yousef.metwally/projects/data/tomotwin_training_data/test'
test_dataset = testDataset(test_dir)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

device = 'cuda'
# NOTE(review): UNet3D is defined further down in this notebook (twice). Make sure
# the variant that produced this checkpoint has been executed before this cell.
model = nn.DataParallel(UNet3D(1, 1))
checkpoint_path = '/home/yousef.metwally/projects/UNet/weights/model_weights_epoch_16.pt'
# map_location places the tensors on the target device regardless of where they were saved.
# torch.load unpickles the file: only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()

In [16]:
def evaluate(model, data_loader, device=None, input_size=37):
    """
    Compute the mean per-batch reconstruction loss of ``model`` over ``data_loader``.

    Inputs and targets are reshaped to (-1, 1, input_size, input_size,
    input_size) and compared with ``loss_function``. Runs under
    ``torch.no_grad()`` with the model switched to eval mode.

    :param model: network to evaluate.
    :param data_loader: yields dicts with ``'input'`` and ``'target'`` tensors.
    :param device: device for the batches. Defaults to the device of the model's
        first parameter (CPU for a parameter-less model).
    :param input_size: edge length of the cubic volumes.
    :return: average batch loss, or ``nan`` if ``data_loader`` is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    volume_shape = (-1, 1, input_size, input_size, input_size)

    total_loss = 0.0
    model.eval()
    with torch.no_grad():
        with tqdm(data_loader, desc="Evaluating", unit="batch") as progress_bar:
            for data in progress_bar:
                input_data = data['input'].to(device).reshape(volume_shape)
                target_data = data['target'].to(device).reshape(volume_shape)
                recon_data = model(input_data)
                loss = loss_function(recon_data, target_data)
                total_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item())

    num_batches = len(data_loader)
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    print(f"Validation Loss: {avg_loss}")
    return avg_loss

Evaluating: 100%|██████████| 2/2 [00:01<00:00,  1.20batch/s, loss=1.49e-5]

Validation Loss: 5.2851032178912064e-08





In [33]:
def predict(model, data_loader, device=None):
    """
    Run ``model`` over ``data_loader`` and collect outputs and targets as numpy arrays.

    A channel axis is added to each batch's ``'input'`` / ``'target'``
    (``unsqueeze(1)``) and removed again from the outputs (``squeeze(1)``). The
    batch axis is kept, so each list element has the shape of the batch tensors.

    :param model: network to run; switched to eval mode.
    :param data_loader: yields dicts with ``'input'`` and ``'target'`` tensors.
    :param device: device for the inputs. Defaults to the device of the model's
        first parameter (CPU for a parameter-less model), so inputs match a model
        that has been moved to the GPU.
    :return: ``(predictions, targets)``, two lists with one numpy array per batch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for batch in data_loader:
            input_data = batch['input'].unsqueeze(1).to(device)
            target_data = batch['target'].unsqueeze(1)
            output = model(input_data)
            predictions.append(output.squeeze(1).cpu().numpy())
            targets.append(target_data.squeeze(1).cpu().numpy())
    return predictions, targets

In [34]:
import mrcfile
def save_as_mrc(data, filename):
    """Write array ``data`` to ``filename`` as a float32 MRC file, replacing any existing file."""
    with mrcfile.new(filename, overwrite=True) as handle:
        volume = data.astype(np.float32)
        handle.set_data(volume)

In [None]:
save_dir = '/home/yousef.metwally/projects/UNet/output'


def save_ouput(out_dir: str, predictions, targets):
    """
    Write each prediction/target pair to ``out_dir`` as MRC files.

    Files are named ``prediction_{i:03d}.mrc`` / ``target_{i:03d}.mrc``, where
    ``i`` is the position in the lists (the data-loader order). ``out_dir`` is
    created if it does not exist.
    NOTE(review): ``predict`` keeps the batch axis, so with batch_size=1 each
    array has a leading axis of size 1. Confirm that is the layout wanted in the MRC.
    """
    os.makedirs(out_dir, exist_ok=True)
    for i, (pred, target) in enumerate(zip(predictions, targets)):
        pred_filename = os.path.join(out_dir, f'prediction_{i:03d}.mrc')
        target_filename = os.path.join(out_dir, f'target_{i:03d}.mrc')
        save_as_mrc(pred, pred_filename)
        save_as_mrc(target, target_filename)
        print(f'Saved prediction to {pred_filename} and target to {target_filename}')


predictions, targets = predict(model, test_loader)
# Write the predictions and targets computed above to save_dir.
save_ouput(save_dir, predictions, targets)

In [31]:
!tensorboard --logdir='/home/yousef.metwally/projects/UNet/' --bind_all


TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

TensorBoard 2.16.2 at http://gtxr3.srv-local.mpi-dortmund.mpg.de:6006/ (Press CTRL+C to quit)
^C


In [36]:
reader = MRCTripletHandler()
v_path = '/home/yousef.metwally/projects/data/tomotwin_training_data/test/round01_t08_0XXX_000.mrc'
volume = reader.read_mrc_and_norm(v_path)

# NOTE(review): the error below shows this checkpoint holds ConvTranspose3d-shaped
# decoder weights (e.g. de_layer0.0 of shape [512, 256, 3, 3, 3]). The AutoEncoder
# definition in scope must be that variant for loading to succeed.
M = AutoEncoder(norm_name, norm_kwargs, 32)
model = M.get_model()
checkpoint_path = '/home/yousef.metwally/projects/AutoEncoder/weights/run1/model_weights_epoch_200.pt'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Checkpoints saved from nn.DataParallel prefix every key with "module.". Strip only
# that leading prefix (str.replace would also rewrite any inner occurrence).
prefix = 'module.'
state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in checkpoint.items()
}
model.load_state_dict(state_dict)



RuntimeError: Error(s) in loading state_dict for Model:
	size mismatch for de_layer0.0.weight: copying a param with shape torch.Size([512, 256, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3, 3]).
	size mismatch for de_layer1.0.weight: copying a param with shape torch.Size([256, 128, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3, 3]).
	size mismatch for de_layer2.0.weight: copying a param with shape torch.Size([128, 64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 128, 3, 3, 3]).
	size mismatch for de_layer4.0.weight: copying a param with shape torch.Size([64, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 64, 3, 3, 3]).
	size mismatch for de_layer4.2.weight: copying a param with shape torch.Size([1, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 1, 1, 1, 1]).

In [18]:
# os.listdir order is arbitrary; sort for a stable listing (lexicographic, so epoch_10 < epoch_2).
x = sorted(os.listdir("/home/yousef.metwally/projects/AutoEncoder/weights/run1"))

In [19]:
# Files in the run1 weights directory (train_autoencoder saves its checkpoints there).
x

['model_weights_epoch_197.pt',
 'model_weights_epoch_198.pt',
 'model_weights_epoch_199.pt',
 'model_weights_epoch_200.pt']

In [38]:
class DoubleConv(nn.Module):
    """Two rounds of Conv3d(k=3, padding=1) -> BatchNorm3d -> ReLU; spatial size is preserved."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Halve the spatial size with 2x max-pooling, then apply a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """
    U-Net decoder stage.

    It upsamples ``x1`` by 2 and pads it to the spatial size of the skip
    tensor ``x2``. Then it concatenates ``[x2, x1]`` along channels and applies
    a DoubleConv(in_channels -> out_channels).
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        self.up = (
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Per-axis size gap (z, y, x) between the skip tensor and the upsampled tensor.
        dz, dy, dx = (x2.size(axis) - x1.size(axis) for axis in (2, 3, 4))
        # F.pad expects the padding of the last axis first.
        padding = [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2, dz // 2, dz - dz // 2]
        x1 = F.pad(x1, padding)
        return self.conv(torch.cat([x2, x1], dim=1))

class OutConv(nn.Module):
    """1x1x1 Conv3d projecting ``in_channels`` feature maps to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

class UNet3D(nn.Module):
    """
    Classic 3D U-Net.

    It has four 2x down-sampling stages (64 -> 128 -> 256 -> 512 -> 512
    channels) and four up-sampling stages that concatenate the matching encoder
    features, followed by a 1x1x1 projection to ``n_classes`` channels.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

        self._initialize_weights()

    def forward(self, x):
        # Encoder: keep every stage's output as a skip connection.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))
        # Decoder: start from the deepest features and consume the skips in reverse.
        x = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())
        return self.outc(x)

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU) for conv weights, zero biases; BatchNorm to identity."""
        conv_types = (nn.Conv3d, nn.ConvTranspose3d)
        for module in self.modules():
            if isinstance(module, conv_types):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm3d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

: 

In [3]:
import argparse
from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from tomotwin.modules.networks.torchmodel import TorchModel
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from tomotwin.modules.training.mrctriplethandler import MRCTripletHandler
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau

class DoubleConv(nn.Module):
    """Conv3d(k=3, padding=1) -> BatchNorm3d -> ReLU, applied twice; keeps spatial size."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def stage(stage_in):
            return [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]

        self.double_conv = nn.Sequential(*stage(in_channels), *stage(out_channels))

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """2x max-pooling followed by a DoubleConv (in_channels -> out_channels)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [nn.MaxPool3d(kernel_size=2), DoubleConv(in_channels, out_channels)]
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """
    Decoder stage without skip concatenation.

    It upsamples ``x1`` by 2, pads it to the spatial size of ``x2`` and applies
    a DoubleConv(in_channels -> out_channels). Unlike the classic U-Net, ``x2``
    is only used for its spatial size; its features are not concatenated. So
    the upsampled tensor, with ``in_channels`` channels, feeds the DoubleConv
    directly.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            # Without concatenation the transposed conv sees (and must emit) all
            # in_channels channels; the in_channels // 2 sizing of the classic
            # U-Net would fail on the first forward pass.
            self.up = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffZ = x2.size(2) - x1.size(2)
        diffY = x2.size(3) - x1.size(3)
        diffX = x2.size(4) - x1.size(4)

        # F.pad expects the padding of the last axis first.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])
        return self.conv(x1)

class OutConv(nn.Module):
    """Pointwise (1x1x1) Conv3d mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=(1, 1, 1))

    def forward(self, x):
        return self.conv.forward(x)

class UNet3D(nn.Module):
    """
    3D encoder-decoder with five 2x down-sampling stages and five up-sampling stages.

    The encoder widths are 64 -> 128 -> 256 -> 512 -> 1024 -> 1024. The
    network ends in a 1x1x1 projection to ``out_channels``. The encoder
    outputs are passed to the ``Up`` stages, which (in this variant) use them
    only as target spatial sizes and do not concatenate them. For a 37^3 input
    the output is 37^3.

    :param n_channels: input channels.
    :param out_channels: output channels.
    :param bilinear: use trilinear upsampling (True) or transposed convolutions
        in every ``Up`` stage.
    """

    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet3D, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.down5 = Down(1024, 1024)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # Pass `bilinear` here too, so all up-sampling stages use the same mode.
        self.up5 = Up(64, 64, bilinear)
        self.outc = OutConv(64, out_channels)

        self._initialize_weights()

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x = self.up1(x6, x5)
        x = self.up2(x, x4)
        x = self.up3(x, x3)
        x = self.up4(x, x2)
        x = self.up5(x, x1)
        logits = self.outc(x)
        return logits

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU) for conv weights, zero biases; BatchNorm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


In [4]:
model = UNet3D(1, 1)
# Shape smoke test only: skip building the autograd graph.
with torch.no_grad():
    out = model(torch.rand(12, 1, 37, 37, 37))