In [1]:
from lib.arch.contrastive_net import build_contrastive_net
from config import load_cfg
import  matplotlib.pyplot  as plt
import numpy as np
from torchsummary import summary
from tqdm.auto import tqdm
import torch.nn as nn
import torch
import tifffile as tif
from glob import glob
import re
import os

# Module-level store filled by the hooks built below: layer name -> captured output.
activation = dict()


def get_activation(name):
    """Build a hook that records a module's forward output under ``name``.

    The returned callable has the ``(module, input, output)`` signature used by
    ``nn.Module.register_forward_hook``. The output is detached before being
    stored so the saved tensor does not keep the autograd graph alive.
    """
    def _store_output(module, inputs, output):
        # TODO(review): original note asked to check whether this hook is
        # registered on the last layer of the classifier.
        activation[name] = output.detach()

    return _store_output


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Load the config and build the network in eval mode. The trained weights are
# loaded in later cells; this cell only constructs the architecture.
cfg = load_cfg("config/contrastive_net.yaml")
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device: {device}")

model = build_contrastive_net(cfg)
model.to(device)
model.eval()

# torchsummary's `summary` defaults to device="cuda"; pass the chosen device so
# its dummy input is created on the same device as the model.
summary(model, (1, 128, 128, 128), device=device)
print(model)






model:  Contrastive_net
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 32, 64, 64, 64]           4,032
          Identity-2       [-1, 32, 64, 64, 64]               0
               ELU-3       [-1, 32, 64, 64, 64]               0
            Conv3d-4       [-1, 64, 32, 32, 32]          55,360
          Identity-5       [-1, 64, 32, 32, 32]               0
               ELU-6       [-1, 64, 32, 32, 32]               0
            Conv3d-7       [-1, 96, 16, 16, 16]         165,984
          Identity-8       [-1, 96, 16, 16, 16]               0
               ELU-9       [-1, 96, 16, 16, 16]               0
           Conv3d-10       [-1, 96, 16, 16, 16]           9,312
          Encoder-11       [-1, 96, 16, 16, 16]               0
        AvgPool3d-12          [-1, 96, 8, 8, 8]               0
           Conv3d-13          [-1, 96, 8, 8, 8]         248,928
        AvgPool

In [4]:
exp_name = 'testbigger_data6_autoencoder_3layers_64batch'
cpkg_pth = f"out/weights/{exp_name}"

# Checkpoint filenames look like Epoch_<n>.pth. The dot is escaped and the match
# is anchored, so only real "<n>.pth" names qualify.
_EPOCH_RE = re.compile(r'Epoch_(\d+)\.pth$')

# Keep only files whose epoch number parses. Sort numerically; lexicographic
# order would place Epoch_1000 before Epoch_200.
ckpts = [p for p in glob(os.path.join(cpkg_pth, 'Epoch_*.pth'))
         if _EPOCH_RE.search(os.path.basename(p))]
ckpts.sort(key=lambda p: int(_EPOCH_RE.search(os.path.basename(p)).group(1)))
if not ckpts:
    raise FileNotFoundError(f"No Epoch_<n>.pth checkpoints found in {cpkg_pth}")
print(ckpts[-1])

# Load the last (highest-epoch) checkpoint. map_location='cpu' lets a GPU-saved
# checkpoint load on any machine; load_state_dict later copies each tensor onto
# the device of the corresponding model parameter.
# NOTE(review): torch.load unpickles arbitrary objects; only use trusted checkpoints.
ckpt = torch.load(ckpts[-1], map_location='cpu')
weight_dict = ckpt['model']

# Strip a leading "module." (the prefix nn.DataParallel / DistributedDataParallel
# add to parameter names). Only the prefix is removed: str.replace would also
# corrupt inner names such as "encoder.submodule.weight".
_PREFIX = 'module.'
new_weight_dict = {
    (key[len(_PREFIX):] if key.startswith(_PREFIX) else key): value
    for key, value in weight_dict.items()
}

# Drop the fc1 / fc2 heads so only the remaining layers are loaded into the model.
new_weight_dict = {k: v for k, v in new_weight_dict.items() if not k.startswith(('fc1', 'fc2'))}




out/weights/testbigger_data6_autoencoder_3layers_64batch/Epoch_1000.pth


  ckpt = torch.load(ckpts[-1])


In [None]:
# strict=False because the fc1/fc2 entries were filtered out of the checkpoint
# dict. The returned object lists any missing/unexpected keys; leaving it as the
# cell's last expression displays them for inspection.
load_result = model.load_state_dict(new_weight_dict, strict=False)
load_result

In [None]:

from lib.datasets.visor import get_valid_dataset, get_dataset
from torch.utils.data import Dataset, DataLoader

# Settings shared by both loaders: in-process loading, pinned host memory
# (speeds up host->GPU copies), and incomplete trailing batches dropped.
_LOADER_KWARGS = dict(num_workers=0, pin_memory=True, drop_last=True)

valid_dataset = get_valid_dataset(cfg)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=6, **_LOADER_KWARGS)

train_dataset = get_dataset(cfg)
train_loader = DataLoader(dataset=train_dataset, batch_size=12, **_LOADER_KWARGS)

In [None]:
test_save_dir = f'./valid/{exp_name}'
os.makedirs(test_save_dir, exist_ok=True)
loss_fn = nn.L1Loss(reduction='mean')

# NOTE(review): the names (`valid_loss`, ./valid/...) suggest validation, but this
# cell evaluates on `train_loader`. Switch to `valid_loader` if validation
# reconstructions were intended.
eval_loader = train_loader

batch_losses = []
input_images = []
pred_images = []

# NOTE: every input/reconstruction volume is kept in host memory; this can be
# large for a full training set.
for input_data, _ in tqdm(eval_loader):
    # Use the device chosen in the setup cell instead of hardcoding 'cuda'.
    input_data = input_data.to(device)

    with torch.no_grad():
        preds = model(input_data)
        loss = loss_fn(preds, input_data)
    print(f"loss: {loss.item()}")
    batch_losses.append(loss.item())

    # Drop only the channel axis, assuming single-channel volumes as in the
    # (1, 128, 128, 128) summary input. A bare np.squeeze would also drop the
    # batch axis for a one-sample batch and break the concatenation below.
    pred_images.append(np.squeeze(preds.cpu().numpy(), axis=1))
    input_images.append(np.squeeze(input_data.cpu().numpy(), axis=1))

if not batch_losses:
    raise RuntimeError(
        "Evaluation loader yielded no batches; check dataset size vs. batch_size/drop_last."
    )

# Mean of per-batch mean losses. With drop_last=True all batches have the same
# size, so this equals the per-sample mean.
valid_loss = sum(batch_losses) / len(batch_losses)

input_images = np.concatenate(input_images, axis=0)
pred_images = np.concatenate(pred_images, axis=0)



    


In [None]:
# Draw 16 random integers in [0, 20).
# NOTE(review): unseeded, so the list differs on every run; nothing visible in
# this notebook consumes it.
l = [np.random.randint(0, 20) for _ in range(16)]
print(l)

In [None]:
#for each img in a batch
for idx in range(16):
    x = input_images[idx]
    re_x = pred_images[idx]
    residual = re_x - x
    print(x.shape)
    
    #compress to 2d
    x, re_x, residual = map(
                        lambda img: img[63], 
                        [x, re_x,residual]
                        )
    
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    for i, (data, title) in enumerate(zip(
        [x, re_x,residual],
        ["x", "re_x", "residual"]
    )):
        img = axs[i].imshow(data, cmap='viridis')
        axs[i].set_title(title)
        fig.colorbar(img, ax=axs[i])

    plt.tight_layout()
    plt.show()