In [11]:
mode = "LS" # LS or IS (latent/image)

import torch
import json
import sys
import os
utils_path = os.path.abspath(os.path.join('../'))
if utils_path not in sys.path:
    sys.path.append(utils_path)
from utils.notebookutils import SimaseUSLatentDataset,SimaseUSVideoDataset, SiameseNetwork, model_forward_to_corrcoeff,model_forward_to_pred, model_forward_to_bin_pred

normalization =lambda x: (x  - x.min())/(x.max() - x.min()) * 2 - 1  # should be -1 to 1 due to way we trained the model

#datasets
if mode == "LS": 
    ds_test_dynamic = SimaseUSLatentDataset(phase="testing", transform=normalization, latents_csv="/vol/ideadata/at70emic/projects/TMI23/data/diffusion/dynamic/FileList.csv", training_latents_base_path="/vol/ideadata/at70emic/projects/TMI23/data/diffusion/dynamic/Latents", in_memory=False, generator_seed=0)
    ds_test_psax = SimaseUSLatentDataset(phase="testing", transform=normalization, latents_csv= "/vol/ideadata/at70emic/projects/TMI23/data/diffusion/PSAX/FileList.csv", training_latents_base_path= "/vol/ideadata/at70emic/projects/TMI23/data/diffusion/PSAX/Latents", in_memory=False, generator_seed=0)
    ds_test_a4c = SimaseUSLatentDataset(phase="testing", transform=normalization, latents_csv="/vol/ideadata/at70emic/projects/TMI23/data/diffusion/A4C/FileList.csv", training_latents_base_path= "/vol/ideadata/at70emic/projects/TMI23/data/diffusion/A4C/Latents", in_memory=False, generator_seed=0)
else: 

    ds_test_dynamic = SimaseUSVideoDataset(phase="testing", transform=normalization, latents_csv="/vol/ideadata/at70emic/datasets/EchoNet-Dynamic/FileList.csv", training_latents_base_path= "/vol/ideadata/at70emic/datasets/EchoNet-Dynamic/Videos", in_memory=False, generator_seed=0)
    ds_test_psax = SimaseUSVideoDataset(phase="testing", transform=normalization, latents_csv= "/vol/ideadata/at70emic/datasets/Echonet-Peds/PSAX/processed/FileList.csv", training_latents_base_path= "/vol/ideadata/at70emic/datasets/Echonet-Peds/PSAX/processed/Videos", in_memory=False, generator_seed=0)
    ds_test_a4c = SimaseUSVideoDataset(phase="testing", transform=normalization, latents_csv="/vol/ideadata/at70emic/datasets/Echonet-Peds/A4C/processed/FileList.csv", training_latents_base_path= "/vol/ideadata/at70emic/datasets/Echonet-Peds/A4C/processed/Videos", in_memory=False, generator_seed=0)

datasets = {"d": ds_test_dynamic, "p": ds_test_psax, "a": ds_test_a4c}
ds_name_to_name = {"d": "Dynamic", "p": "PSAX", "a": "A4C"}

#load models
models = {"a": None, "d": None, "p": None}
for model_name, model_ending in zip(["a", "d", "p"], ["a4c", "Dynamic", "psax"]): 
    model_basepath = f"/vol/ideadata/ed52egek/pycharm/privatis_us/archive/{model_ending}{mode}Best"
    with open(os.path.join(model_basepath, "config.json")) as config:
        config = config.read()

    # parse config
    config = json.loads(config)
    net = SiameseNetwork(network=config['siamese_architecture'], in_channels=config['n_channels'], n_features=config['n_features'])
    net.eval()
    net = net.cuda()
    best_name = [x for x in os.listdir(model_basepath) if x.endswith("best_network.pth")][0]
    net.load_state_dict(torch.load(os.path.join(model_basepath, best_name)))
    models[model_name] = net

Set testing dataset seed to 0
Set testing dataset seed to 0
Set testing dataset seed to 0




In [41]:
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision
from einops import rearrange
from PIL import Image

mode = "LV"
rdm = torch.randn((3, 1, 5))
for frame_num in range(5):
    frame = ds_test_dynamic[0][frame_num*10]

    if mode == "RDM": 
        frame = torch.randn((3, 112//8, 112//8))
    if mode == "LV":
        frame = rdm[:,:, torch.randperm(5)]# + 0.05 * torch.randn((3, 1, 5))

    frame = ((frame + 1) * 127.5).to(torch.uint8)
    if frame.size()[-1] != 112:
        frame = Resize(112, interpolation=torchvision.transforms.InterpolationMode.NEAREST,)(frame)
        #pass

    image = Image.fromarray(rearrange(frame.numpy(), "c h w -> h w c"))
    image.save(f"USframe{frame_num}{mode}.png")

In [37]:
rdm[:,:, torch.randperm(len(rdm))]

tensor([[[ 0.8744, -1.4230,  0.1571]],

        [[ 0.6223, -1.4236, -1.0518]],

        [[ 0.2928, -0.6817,  1.0972]]])