In [1]:
%load_ext autoreload

In [120]:
%autoreload
import copy
import sys
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from functorch import vmap
from tqdm.auto import tqdm

sys.path.append("/scratch/gpfs/js5013/programs/cfilt/")
from cfilt.utils import *

In [84]:
def remove_data_parallel(old_state_dict):
    new_state_dict = OrderedDict()

    for k, v in old_state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v

    return new_state_dict

In [99]:
class Cascade(nn.Module):
    def __init__(self, base_model, load_paths):
        super().__init__()
        self.models = [copy.deepcopy(base_model) for _ in range(len(load_paths))]
        for i, m in tqdm(enumerate(self.models), total=len(load_paths)):
            state_dict = torch.load(load_paths[i])
            state_dict = remove_data_parallel(state_dict)
            m.load_state_dict(state_dict)
        self.models = nn.ModuleList(self.models)
        self.intermediates = []

    def get_intermediates(self):
        return self.intermediates

    def forward(self, x):
        self.intermediates.clear()
        for m in self.models:
            x = m(x)
            self.intermediates.append(x)
        return x

In [86]:
cds = CDS(
    [1, 4, 8, 16, 32, 80, 160, 320, 1600],
    "jx",
    "../out/",
    normalize=True,
    transform=transforms.ToTensor(),
)

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/148 [00:00<?, ?it/s]

  0%|          | 0/148 [00:00<?, ?it/s]

  0%|          | 0/148 [00:00<?, ?it/s]

  0%|          | 0/148 [00:00<?, ?it/s]

  0%|          | 0/148 [00:00<?, ?it/s]

  0%|          | 0/148 [00:00<?, ?it/s]

  0%|          | 0/148 [00:00<?, ?it/s]

  0%|          | 0/148 [00:00<?, ?it/s]

  0%|          | 0/148 [00:00<?, ?it/s]

In [100]:
ac = nn.Sequential(
    ConvXCoder((125, 133), 1, 4, 4, 3, "cuda"),
    ConvXCoder((125, 133), 4, 1, 4, 3, "cuda"),
)

In [101]:
csc = Cascade(
    ac,
    [
        f"../models/convxcoder/{low}-{high}-3l-4c-new.pt"
        for low, high in zip(
            [1, 4, 8, 16, 32, 80, 160, 320], [4, 8, 16, 32, 80, 160, 320, 1600]
        )
    ],
)

  0%|          | 0/8 [00:00<?, ?it/s]

In [102]:
x = next(iter(cds))

In [103]:
out = csc(x[0].cuda())

In [104]:
im = csc.get_intermediates()

In [109]:
len(x[1:])

8

In [107]:
len(im)

8

In [110]:
loss_fn = MS_SSIM_L1_Loss(alpha=0.7)

In [122]:
torch.stack(im).shape

torch.Size([8, 1, 1, 125, 133])

In [112]:
loss_fn(x[1:], im)

AttributeError: 'list' object has no attribute 'shape'