In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from IPython.display import clear_output

In [None]:
import os
import sys
import importlib
import logging
from pathlib import Path

# Debugging aid: force synchronous CUDA kernel launches so errors are reported at the
# offending call. CUDA only reads this when its context is first created, so it must be
# set here, before any tensor/model is moved to the GPU (setting it later is a no-op).
# Slows GPU execution; remove once debugging is done.
os.environ.setdefault('CUDA_LAUNCH_BLOCKING', '1')

# Root of the model repository; its modules (train_loops, run, utils, configs) are imported below.
model_path = "/home/groups/ZuckermanLab/jalim/test_m2cO2vae/"
# Guard so re-running this cell does not keep appending the same entry to sys.path.
if model_path not in sys.path:
    sys.path.append(model_path)

import train_loops
import run
from utils import utils
import wandb
from configs.config_LI204601 import config # Load Pretrained Model Configuration
from torchvision.utils import make_grid

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Report which device ('cuda' or 'cpu') the rest of the notebook will use.
print('Using Device: {}'.format(device))

In [None]:
# Pick up any edits made to utils without restarting the kernel.
importlib.reload(utils)

# Train/test datasets and their loaders, exactly as specified by the pretrained model's config.
(dset, loader,
 dset_test, loader_test) = run.get_datasets_from_config(config)

# Preview one batch of training images (batch size comes from the loader) and display the figure.
fig, axs = utils.plot_sample_data(loader)
fig

In [None]:
# Size the encoder's input channels to the data. The first element of dset[0] is taken
# as the image and its leading dimension as the channel count (assumes channels-first
# layout — TODO confirm against the dataset class).
sample_image = dset[0][0]
config.model.encoder.n_channels = sample_image.shape[0]

# Instantiate the model architecture described by the (updated) config.
model = run.build_model_from_config(config)

### Get pre-trained model weights

In [None]:
# Checkpoint written by an offline wandb run, relative to the model repository root.
PRETRAINED_RUN_SUBPATH = "wandb/offline-run-20240911_171039-egf1_1/files/model.pt"
pretrained_model_path = os.path.join(model_path, PRETRAINED_RUN_SUBPATH)

In [None]:
# map_location remaps stored tensors onto `device`, so a checkpoint saved on GPU also
# loads on a CPU-only machine (and vice versa) without juggling .cpu()/.cuda() calls.
model_checkpoint = torch.load(pretrained_model_path, map_location=device)
model.to(device).train()

# strict=False: the checkpoint is allowed to lack some keys (checked below).
missing_keys, unexpected_keys = model.load_state_dict(model_checkpoint['model_state_dict'], strict=False)

# The only tolerated missing keys are '_basisexpansion' entries (presumably state the
# model reconstructs itself — verify). Anything else means the weights did not load.
non_basis_missing = [k for k in missing_keys if '_basisexpansion' not in k]
assert not non_basis_missing, f"Checkpoint is missing model keys: {non_basis_missing}"

# strict=False would otherwise silently drop checkpoint entries the model doesn't use.
if unexpected_keys:
    print(f"Warning: checkpoint has keys not used by the model: {unexpected_keys}")

In [None]:
from utils import eval_utils
importlib.reload(eval_utils)

# NOTE: CUDA_LAUNCH_BLOCKING is only honoured if set before the CUDA context is created.
# The model was already moved to `device` when the weights were loaded, so setting it here
# would have no effect; set it before the first CUDA use if synchronous errors are needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Switch to evaluation mode on the target device (a single move is sufficient).
model.eval().to(device)

# One batch from the training loader; use `loader_test` instead to inspect the test set.
x, y = next(iter(loader))

print(f"x shape: {x.shape}, device: {x.device}")
print(f"y shape: {y.shape}, device: {y.device}")


### Run on the selected device (GPU if available, otherwise CPU) & visualize the reconstruction grids

In [None]:
# Build the aligned reconstruction grid for the batch. A failure is reported together with
# the device it occurred on and then re-raised, so Run All stops here instead of failing
# later with a confusing NameError on `reconstruct_grid_aligned`.
try:
    reconstruct_grid_aligned = eval_utils.reconstruction_grid(model, x.to(device), align=True, device=device)
except Exception as e:
    print(f"Error on {device}: {e}")
    raise


In [None]:
# The grid tensor may live on `device` (CUDA) and/or be attached to the autograd graph;
# Tensor.numpy() rejects both, so detach and move to CPU before handing it to matplotlib.
fig, axs = plt.subplots(figsize=(20, 20))
axs.imshow(reconstruct_grid_aligned.detach().cpu().numpy(), cmap='gray')
axs.set_title('Reconstruction Grid Aligned')
plt.show()