In [1]:
from tomoSegmentPipeline.utils.common import read_array, write_array
from tomoSegmentPipeline.utils import setup
from cryoS2Sdrop.dataloader import singleCET_dataset
from cryoS2Sdrop.losses import self2self_L2Loss
from cryoS2Sdrop.model import Denoising_UNet

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchsummary import summary

import os
from torch.utils.data import Dataset, DataLoader

PARENT_PATH = setup.PARENT_PATH
ISONET_PATH = os.path.join(PARENT_PATH, 'data/isoNet/')

%matplotlib inline
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

In [28]:
# Input tomogram: a small dummy volume for fast iteration.
# To run on the full raw tomogram, point this at 'data/raw_cryo-ET/tomo02.mrc' instead.
cet_path = os.path.join(PARENT_PATH, 'data/S2SDenoising/dummy_tomograms/tomo02_dummy.mrc')

# Dataset hyper-parameters. `p`, `n_samples` and `subtomo_length` are reused by the
# model and summary cells further down, so keep these names.
p = 0.3
n_samples = 50
vol_scale_factor = 8
subtomo_length = 96

dataset_kwargs = dict(
    subtomo_length=subtomo_length,
    p=p,
    n_samples=n_samples,
    volumetric_scale_factor=vol_scale_factor,
)
my_dataset = singleCET_dataset(cet_path, **dataset_kwargs)

In [31]:
# Shuffled mini-batches of 2 from the single-tomogram dataset.
# pin_memory=True places batches in page-locked host memory, for faster copies to a GPU.
dloader = DataLoader(
    dataset=my_dataset,
    batch_size=2,
    shuffle=True,
    pin_memory=True,
)

# Model

In [32]:
# Draw exactly one batch to inspect tensor shapes.
# next(iter(...)) raises StopIteration on an empty loader. The previous
# `for ... break` pattern would instead silently leave subtomo/target/mask
# unbound, or stale from an earlier run, for the cells below.
batch = next(iter(dloader))
subtomo, target, mask = batch
subtomo.shape

torch.Size([2, 50, 96, 96, 96])


In [33]:
# Feature channels per layer (the 48-channel outputs in the summary table below).
n_features = 48

# NOTE(review): the meaning of the leading `None, 0` positional arguments is not
# visible from this notebook. Check Denoising_UNet's signature and pass them by name.
unet_args = (None, 0, n_samples, n_features, p)
model = Denoising_UNet(*unet_args)

In [34]:
# Shape smoke test only. Run without autograd so no computation graph is built.
# Keep just the torch.Size, not the full prediction tensor, so nothing large
# stays alive in the kernel.
with torch.no_grad():
    prediction_shape = model(subtomo).shape
prediction_shape

torch.Size([2, 50, 96, 96, 96])

In [35]:
# Per-sample input size for torchsummary: (n_samples, subtomo_length, subtomo_length, subtomo_length).
input_size = (n_samples,) + (subtomo_length,) * 3
summary(model, input_size, device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
     PartialConv3d-1       [-1, 48, 96, 96, 96]          64,848
     PartialConv3d-2       [-1, 48, 96, 96, 96]          62,256
         LeakyReLU-3       [-1, 48, 96, 96, 96]               0
         MaxPool3d-4       [-1, 48, 48, 48, 48]               0
     PartialConv3d-5       [-1, 48, 48, 48, 48]          62,256
         LeakyReLU-6       [-1, 48, 48, 48, 48]               0
         MaxPool3d-7       [-1, 48, 24, 24, 24]               0
     PartialConv3d-8       [-1, 48, 24, 24, 24]          62,256
         LeakyReLU-9       [-1, 48, 24, 24, 24]               0
        MaxPool3d-10       [-1, 48, 12, 12, 12]               0
    PartialConv3d-11       [-1, 48, 12, 12, 12]          62,256
        LeakyReLU-12       [-1, 48, 12, 12, 12]               0
        MaxPool3d-13          [-1, 48, 6, 6, 6]               0
    PartialConv3d-14          [-1, 48, 

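# Loss function

Compare several reductions of the masked residual `y_wiggle - target` (summed MSE, summed per-channel L2 norms, and per-channel L2 norms summed over channels then averaged over the batch), and check that `self2self_L2Loss` reproduces the last one.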

In [14]:
# Weight the prediction by (1 - mask). Assuming a binary mask, this keeps only the
# voxels where mask == 0.
# NOTE(review): the saved output below ([8, 35, ...]) predates the current
# batch_size=2 / n_samples=50 settings. Re-run the notebook to refresh it.
y_wiggle = model(subtomo) * (1 - mask)
y_wiggle.shape

torch.Size([8, 35, 96, 96, 96])

In [15]:
target.shape

torch.Size([8, 35, 96, 96, 96])

In [8]:
# Summed squared error between the masked prediction and the target.
# This uses its own name rather than `loss`, because a later cell rebinds `loss` to
# self2self_L2Loss. Sharing the name would let out-of-order execution silently
# evaluate the wrong loss.
mse_sum_loss = nn.MSELoss(reduction='sum')
mse_sum_loss(y_wiggle, target)

tensor(11296888., grad_fn=<MseLossBackward0>)

In [12]:
# Merge every axis from dim 2 onward into one, giving a flat vector per (dim 0, dim 1) entry.
y_wiggle_flat, target_flat = (t.flatten(start_dim=2) for t in (y_wiggle, target))

pred_minus_target = y_wiggle_flat - target_flat
# NOTE(review): the saved output below has 262144 = 64**3 voxels per vector, which does
# not match subtomo_length=96 set above. It is stale; re-run to refresh.
pred_minus_target.shape

torch.Size([5, 30, 262144])

In [15]:
# L2 norm of each flattened residual vector (over dim 2), then summed over BOTH
# remaining dims. Note the difference from the next cell, which sums over dim 1 but
# averages over dim 0 (the batch).
# torch.norm is deprecated; torch.linalg.vector_norm is its documented replacement.
torch.linalg.vector_norm(pred_minus_target, ord=2, dim=2).sum()

tensor(41155.6641, grad_fn=<SumBackward0>)

In [18]:
(
    # L2 norm of the residual over dims 2-4, leaving one value per (dim 0, dim 1) pair.
    torch.linalg.vector_norm(y_wiggle - target, ord=2, dim=(2, 3, 4))
    .sum(dim=1)  # add up over dim 1 (size n_samples in the shapes printed above)
    .mean()      # average over dim 0 (the DataLoader batch)
)

tensor(18095.8203, grad_fn=<MeanBackward0>)

In [19]:
# Packaged Self2Self L2 loss from cryoS2Sdrop.losses. The import belongs with the
# other imports at the top of the notebook, not mid-way through it.
loss = self2self_L2Loss()

In [20]:
s2s_loss_value = loss(y_wiggle, target)

# Inline check: the packaged loss should reproduce the manual reduction
# (L2 norm over dims 2-4, summed over dim 1, averaged over dim 0). The saved outputs
# of the two computations agree (18095.8203). This also catches `loss` being stale
# from another cell.
with torch.no_grad():
    reference_l2 = torch.linalg.vector_norm(y_wiggle - target, ord=2, dim=(2, 3, 4)).sum(1).mean()
torch.testing.assert_close(s2s_loss_value.detach(), reference_l2)

s2s_loss_value

tensor(18095.8203, grad_fn=<MeanBackward0>)