In [9]:
import numpy as np 
import matplotlib.pyplot as plt

import torch 
from torch import nn, optim 
import napari

import os 

import tifffile 

In [10]:
data_dir = '/home/confetti/mnt/data/processed/t1779/128roi_skip_gapped'
files = os.listdir(data_dir)
files = [file for file in files if file.endswith('.tif')]
files.sort()
files = [os.path.join(data_dir, file) for file in files]

print(len(files))
print(files[0])

128
/home/confetti/mnt/data/processed/t1779/128roi_skip_gapped/0001.tif


In [11]:
#for visualization
img = tifffile.imread(files[0])
thickness = 8
start =64
viewer = napari.Viewer(ndisplay=3)
for thick in range(2,16,2):
    projection = np.max(img[start:start + thick,:,:],axis=0)
    projection = np.squeeze(projection)
    viewer.add_image(projection, name = f"{thick}_mip")
viewer.add_image(img)
viewer.add_image(img[start], name = f"0_mip")

<Image layer '0_mip' at 0x7fd679526ff0>

In [12]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.files = os.listdir(data_dir)
        self.files = [file for file in self.files if file.endswith('.tif')]
        self.files.sort()
        self.files = [os.path.join(data_dir, file) for file in self.files]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = tifffile.imread(self.files[idx])
        img = np.array(img).astype(np.float32)
        img = torch.from_numpy(img)
        return img.unsqueeze(0)


In [13]:
class AutoEncoder3D(nn.Module):
    def __init__(self, in_channels=1, base_channels=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, base_channels, kernel_size=3, stride=2, padding=1),  # downsample by 2 
            nn.LeakyReLU(),
            nn.Conv3d(base_channels, base_channels*2, kernel_size=3, stride=2, padding=1),  # downsample by 2
            nn.LeakyReLU(),
            nn.Conv3d(base_channels*2, base_channels*2, kernel_size=3, stride=2, padding=1),    # downsample by 2
            nn.LeakyReLU(),
            nn.Conv3d(base_channels*2, base_channels, kernel_size=3, stride=2, padding=1),    # downsample by 2
            nn.LeakyReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(base_channels, base_channels*2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(base_channels*2, base_channels*2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(base_channels, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)


In [14]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoEncoder3D(1, 32).to(device)
dataset = Dataset(data_dir)

from torch.utils.data import DataLoader
loader  = DataLoader(dataset,batch_size=8,shuffle=True) 
inputs = dataset[0].to(device)

output = model(inputs)
inputs.shape, output.shape

(torch.Size([1, 128, 128, 128]), torch.Size([1, 128, 128, 128]))

In [21]:
from tqdm import tqdm
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
epochs = 500
for epoch in range(epochs):
    for data in tqdm(loader):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}/{epochs} Loss: {loss.item()}')

100%|██████████| 16/16 [00:11<00:00,  1.40it/s]


Epoch 1/500 Loss: 26264.9140625


100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 2/500 Loss: 19427.265625


100%|██████████| 16/16 [00:08<00:00,  2.00it/s]


Epoch 3/500 Loss: 10576.724609375


100%|██████████| 16/16 [00:07<00:00,  2.06it/s]


Epoch 4/500 Loss: 8945.751953125


100%|██████████| 16/16 [00:07<00:00,  2.08it/s]


Epoch 5/500 Loss: 7663.47216796875


100%|██████████| 16/16 [00:07<00:00,  2.11it/s]


Epoch 6/500 Loss: 6059.0791015625


100%|██████████| 16/16 [00:08<00:00,  1.93it/s]


Epoch 7/500 Loss: 7083.55517578125


100%|██████████| 16/16 [00:07<00:00,  2.01it/s]


Epoch 8/500 Loss: 5961.35986328125


100%|██████████| 16/16 [00:07<00:00,  2.05it/s]


Epoch 9/500 Loss: 6862.63037109375


100%|██████████| 16/16 [00:07<00:00,  2.07it/s]


Epoch 10/500 Loss: 5065.1826171875


100%|██████████| 16/16 [00:07<00:00,  2.08it/s]


Epoch 11/500 Loss: 10621.841796875


100%|██████████| 16/16 [00:07<00:00,  2.06it/s]


Epoch 12/500 Loss: 5756.251953125


100%|██████████| 16/16 [00:08<00:00,  1.93it/s]


Epoch 13/500 Loss: 5550.45703125


100%|██████████| 16/16 [00:08<00:00,  1.80it/s]


Epoch 14/500 Loss: 6292.509765625


100%|██████████| 16/16 [00:12<00:00,  1.31it/s]


Epoch 15/500 Loss: 6909.57470703125


100%|██████████| 16/16 [00:13<00:00,  1.15it/s]


Epoch 16/500 Loss: 5434.78662109375


100%|██████████| 16/16 [00:10<00:00,  1.53it/s]


Epoch 17/500 Loss: 6571.791015625


100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 18/500 Loss: 6122.60986328125


100%|██████████| 16/16 [00:08<00:00,  1.84it/s]


Epoch 19/500 Loss: 6405.1396484375


100%|██████████| 16/16 [00:07<00:00,  2.03it/s]


Epoch 20/500 Loss: 5479.61669921875


100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 21/500 Loss: 5992.27880859375


100%|██████████| 16/16 [00:07<00:00,  2.03it/s]


Epoch 22/500 Loss: 5134.8232421875


100%|██████████| 16/16 [00:07<00:00,  2.07it/s]


Epoch 23/500 Loss: 6187.7646484375


100%|██████████| 16/16 [00:08<00:00,  1.89it/s]


Epoch 24/500 Loss: 4303.3642578125


100%|██████████| 16/16 [00:10<00:00,  1.48it/s]


Epoch 25/500 Loss: 5669.3193359375


100%|██████████| 16/16 [00:07<00:00,  2.03it/s]


Epoch 26/500 Loss: 4633.3505859375


100%|██████████| 16/16 [00:07<00:00,  2.04it/s]


Epoch 27/500 Loss: 4627.6923828125


100%|██████████| 16/16 [00:07<00:00,  2.05it/s]


Epoch 28/500 Loss: 8590.5498046875


100%|██████████| 16/16 [00:07<00:00,  2.04it/s]


Epoch 29/500 Loss: 10543.587890625


100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 30/500 Loss: 20379.455078125


100%|██████████| 16/16 [00:08<00:00,  1.96it/s]


Epoch 31/500 Loss: 7882.88671875


100%|██████████| 16/16 [00:07<00:00,  2.00it/s]


Epoch 32/500 Loss: 7365.107421875


100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 33/500 Loss: 7451.984375


100%|██████████| 16/16 [00:07<00:00,  2.03it/s]


Epoch 34/500 Loss: 5837.73681640625


100%|██████████| 16/16 [00:07<00:00,  2.14it/s]


Epoch 35/500 Loss: 5352.11767578125


100%|██████████| 16/16 [00:07<00:00,  2.04it/s]


Epoch 36/500 Loss: 5321.5595703125


100%|██████████| 16/16 [00:05<00:00,  2.74it/s]


Epoch 37/500 Loss: 6431.6806640625


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 38/500 Loss: 5450.22509765625


100%|██████████| 16/16 [00:05<00:00,  2.74it/s]


Epoch 39/500 Loss: 5413.96484375


100%|██████████| 16/16 [00:05<00:00,  2.77it/s]


Epoch 40/500 Loss: 5497.9619140625


100%|██████████| 16/16 [00:05<00:00,  2.77it/s]


Epoch 41/500 Loss: 6542.99609375


100%|██████████| 16/16 [00:05<00:00,  2.72it/s]


Epoch 42/500 Loss: 5564.19287109375


100%|██████████| 16/16 [00:05<00:00,  2.74it/s]


Epoch 43/500 Loss: 5393.1318359375


100%|██████████| 16/16 [00:05<00:00,  2.78it/s]


Epoch 44/500 Loss: 5552.50048828125


100%|██████████| 16/16 [00:05<00:00,  2.76it/s]


Epoch 45/500 Loss: 5491.75244140625


100%|██████████| 16/16 [00:05<00:00,  2.76it/s]


Epoch 46/500 Loss: 7431.37255859375


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 47/500 Loss: 5764.06005859375


100%|██████████| 16/16 [00:05<00:00,  2.72it/s]


Epoch 48/500 Loss: 5200.515625


100%|██████████| 16/16 [00:05<00:00,  2.70it/s]


Epoch 49/500 Loss: 6311.9169921875


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 50/500 Loss: 4628.23388671875


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 51/500 Loss: 10313.625


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 52/500 Loss: 5042.2353515625


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 53/500 Loss: 5759.939453125


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 54/500 Loss: 6375.171875


100%|██████████| 16/16 [00:08<00:00,  1.85it/s]


Epoch 55/500 Loss: 4114.06689453125


100%|██████████| 16/16 [00:07<00:00,  2.18it/s]


Epoch 56/500 Loss: 6563.431640625


100%|██████████| 16/16 [00:05<00:00,  2.76it/s]


Epoch 57/500 Loss: 5009.51513671875


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 58/500 Loss: 5061.6337890625


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 59/500 Loss: 4956.4833984375


100%|██████████| 16/16 [00:05<00:00,  2.77it/s]


Epoch 60/500 Loss: 3607.0048828125


100%|██████████| 16/16 [00:05<00:00,  2.76it/s]


Epoch 61/500 Loss: 5895.828125


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 62/500 Loss: 5300.318359375


100%|██████████| 16/16 [00:05<00:00,  2.90it/s]


Epoch 63/500 Loss: 5036.17578125


100%|██████████| 16/16 [00:05<00:00,  2.74it/s]


Epoch 64/500 Loss: 4924.77099609375


100%|██████████| 16/16 [00:05<00:00,  2.74it/s]


Epoch 65/500 Loss: 4253.3896484375


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 66/500 Loss: 5046.3916015625


100%|██████████| 16/16 [00:05<00:00,  2.72it/s]


Epoch 67/500 Loss: 6694.689453125


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 68/500 Loss: 4783.1884765625


100%|██████████| 16/16 [00:05<00:00,  2.74it/s]


Epoch 69/500 Loss: 5547.5302734375


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 70/500 Loss: 6272.96240234375


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 71/500 Loss: 4566.72607421875


100%|██████████| 16/16 [00:05<00:00,  2.76it/s]


Epoch 72/500 Loss: 10237.6220703125


100%|██████████| 16/16 [00:05<00:00,  2.74it/s]


Epoch 73/500 Loss: 7411.7841796875


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 74/500 Loss: 4651.9013671875


100%|██████████| 16/16 [00:05<00:00,  2.74it/s]


Epoch 75/500 Loss: 5890.734375


100%|██████████| 16/16 [00:05<00:00,  2.72it/s]


Epoch 76/500 Loss: 4939.35107421875


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 77/500 Loss: 3527.76220703125


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 78/500 Loss: 5041.62353515625


100%|██████████| 16/16 [00:05<00:00,  2.72it/s]


Epoch 79/500 Loss: 5099.5107421875


100%|██████████| 16/16 [00:05<00:00,  2.77it/s]


Epoch 80/500 Loss: 6446.2548828125


100%|██████████| 16/16 [00:05<00:00,  2.77it/s]


Epoch 81/500 Loss: 5675.005859375


100%|██████████| 16/16 [00:05<00:00,  2.91it/s]


Epoch 82/500 Loss: 5553.0634765625


100%|██████████| 16/16 [00:05<00:00,  2.78it/s]


Epoch 83/500 Loss: 4471.02099609375


100%|██████████| 16/16 [00:05<00:00,  2.76it/s]


Epoch 84/500 Loss: 4037.86474609375


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 85/500 Loss: 6622.9619140625


100%|██████████| 16/16 [00:05<00:00,  2.70it/s]


Epoch 86/500 Loss: 5340.37109375


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 87/500 Loss: 7230.69287109375


100%|██████████| 16/16 [00:05<00:00,  2.75it/s]


Epoch 88/500 Loss: 6987.6611328125


100%|██████████| 16/16 [00:06<00:00,  2.66it/s]


Epoch 89/500 Loss: 6980.20166015625


100%|██████████| 16/16 [00:09<00:00,  1.66it/s]


Epoch 90/500 Loss: 5743.970703125


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 91/500 Loss: 7683.8369140625


100%|██████████| 16/16 [00:05<00:00,  2.83it/s]


Epoch 92/500 Loss: 4971.64453125


100%|██████████| 16/16 [00:05<00:00,  2.73it/s]


Epoch 93/500 Loss: 5796.1552734375


100%|██████████| 16/16 [00:05<00:00,  2.79it/s]


Epoch 94/500 Loss: 4881.02490234375


100%|██████████| 16/16 [00:07<00:00,  2.00it/s]


Epoch 95/500 Loss: 8206.5048828125


100%|██████████| 16/16 [00:08<00:00,  1.89it/s]


Epoch 96/500 Loss: 5971.1435546875


100%|██████████| 16/16 [00:08<00:00,  1.94it/s]


Epoch 97/500 Loss: 4179.8671875


100%|██████████| 16/16 [00:07<00:00,  2.11it/s]


Epoch 98/500 Loss: 4540.611328125


100%|██████████| 16/16 [00:07<00:00,  2.04it/s]


Epoch 99/500 Loss: 5088.72021484375


100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 100/500 Loss: 3975.77001953125


100%|██████████| 16/16 [00:07<00:00,  2.01it/s]


Epoch 101/500 Loss: 4462.35205078125


100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 102/500 Loss: 4640.37890625


100%|██████████| 16/16 [00:08<00:00,  1.99it/s]


Epoch 103/500 Loss: 17516.0859375


100%|██████████| 16/16 [00:07<00:00,  2.01it/s]


Epoch 104/500 Loss: 6572.7158203125


100%|██████████| 16/16 [00:07<00:00,  2.01it/s]


Epoch 105/500 Loss: 6489.04296875


100%|██████████| 16/16 [00:07<00:00,  2.04it/s]


Epoch 106/500 Loss: 5097.884765625


100%|██████████| 16/16 [00:07<00:00,  2.04it/s]


Epoch 107/500 Loss: 4066.921142578125


100%|██████████| 16/16 [00:07<00:00,  2.05it/s]


Epoch 108/500 Loss: 5167.46533203125


100%|██████████| 16/16 [00:08<00:00,  1.99it/s]


Epoch 109/500 Loss: 4514.103515625


100%|██████████| 16/16 [00:08<00:00,  1.98it/s]


Epoch 110/500 Loss: 5290.2646484375


100%|██████████| 16/16 [00:07<00:00,  2.02it/s]


Epoch 111/500 Loss: 6091.1015625


  6%|▋         | 1/16 [00:00<00:11,  1.35it/s]


KeyboardInterrupt: 

In [22]:
model.eval()

batch = next(iter(loader))
batch = batch.to(device)
preds = model(batch)
preds = preds.detach().cpu()
preds = np.squeeze(preds)
batch = batch.detach().cpu()
batch = np.squeeze(batch)

idx  = 0
pred = preds[idx]
input = batch[idx]
viewer.add_image(pred,name=f"recon_{idx}")
viewer.add_image(input,name=f"input_{idx}")

pred = np.max(pred,axis=0)
input = np.max(input,axis=0)

fig, axs =plt.subplots(3)
img1 = axs[0].imshow(input,cmap = 'viridis')
axs[0].set_title("input")
fig.colorbar(img1,ax = axs[0])
img2 = axs[1].imshow(pred,cmap = 'viridis')
axs[1].set_title("pred")
fig.colorbar(img2,ax = axs[1])
img3 = axs[2].imshow(input - pred,cmap = 'viridis')
axs[2].set_title("residual")
fig.colorbar(img3,ax = axs[2])
    


TypeError: max() received an invalid combination of arguments - got (out=NoneType, axis=int, ), but expected one of:
 * ()
 * (Tensor other)
 * (int dim, bool keepdim = False)
      didn't match because some of the keywords were incorrect: out, axis
 * (name dim, bool keepdim = False)
      didn't match because some of the keywords were incorrect: out, axis
