In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from model import HierarchicalProbUNet  # With updated classes
import numpy as np

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---- Architecture hyper-parameters ------------------------------------------
# Presumably one channel count per U-Net resolution level (8 entries here).
channels_per_block = (24, 48, 96, 192, 192, 192, 192, 192)
# NOTE(review): an earlier note claimed these are 4 latent levels of which only
# 3 are used by HierarchicalCore. That cannot be checked from this file —
# confirm in model.py. Components of the model disagreeing on the number of
# latent levels is one plausible cause of the torch.cat size mismatch recorded
# at the end of this cell (this script itself never calls torch.cat) — TODO.
latent_dims = (4, 4, 4, 1)
num_classes = 2
convs_per_block = 3
blocks_per_level = 3
# Half of each entry of channels_per_block, passed as down_channels_per_block.
down_channels_per_block = tuple(width // 2 for width in channels_per_block)

# Loss configuration, handed to HierarchicalProbUNet as ``loss_kwargs``.
loss_kwargs = dict(
    type='geco',
    top_k_percentage=0.02,
    deterministic_top_k=False,
    kappa=0.05,
    decay=0.99,
    rate=1e-2,
    beta=None,
)

# Keyword arguments for the model constructor, gathered in one place.
model_config = dict(
    latent_dims=latent_dims,
    channels_per_block=channels_per_block,
    num_classes=num_classes,
    down_channels_per_block=down_channels_per_block,
    activation_fn=nn.ReLU(),
    convs_per_block=convs_per_block,
    blocks_per_level=blocks_per_level,
    loss_kwargs=loss_kwargs,
)
model = HierarchicalProbUNet(**model_config).to(device)

# Adam with a small L2 penalty; the base learning rate is 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-5)

class DummyLIDCDataset(torch.utils.data.Dataset):
    """Random stand-in for an LIDC-style dataset, for smoke-testing training.

    Every item is a freshly sampled ``(img, seg, mask)`` triple:

    * ``img``  -- float tensor ``(in_channels, image_size, image_size)`` of
      standard-normal noise.
    * ``seg``  -- float tensor ``(num_classes, image_size, image_size)``; a
      one-hot label map, i.e. exactly one channel is 1 at every pixel.
    * ``mask`` -- float tensor ``(image_size, image_size)`` of ones.

    Args:
        size: Number of items reported by ``len()``.
        image_size: Height and width of every returned tensor.
        num_classes: Number of one-hot channels in ``seg``.
        in_channels: Number of channels in ``img``.
    """

    def __init__(self, size=8882, image_size=128, num_classes=2, in_channels=1):
        self.size = size
        self.image_size = image_size
        self.num_classes = num_classes
        self.in_channels = in_channels

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Raising IndexError lets plain iteration (``for x in ds``) stop; without
        # it Python's legacy __getitem__ iteration protocol never terminates.
        if not -self.size <= idx < self.size:
            raise IndexError(f"index {idx} out of range for dataset of size {self.size}")
        hw = (self.image_size, self.image_size)
        img = torch.randn(self.in_channels, *hw)
        # Sample one class per pixel, then one-hot encode it. Independent random
        # channels would produce pixels labelled with zero or several classes.
        labels = torch.randint(0, self.num_classes, hw)
        seg = (
            torch.nn.functional.one_hot(labels, self.num_classes)  # (H, W, C)
            .permute(2, 0, 1)                                      # (C, H, W)
            .float()
            .contiguous()
        )
        mask = torch.ones(hw)
        return img, seg, mask

def _infinite_batches(loader):
    """Yield batches from ``loader`` indefinitely, one full pass (epoch) at a time.

    Unlike ``itertools.cycle`` this does not cache batches, so a shuffling
    loader is reshuffled on every pass and memory stays bounded.

    Raises:
        ValueError: if a complete pass over ``loader`` yields no batches
            (otherwise the generator would spin forever).
    """
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            raise ValueError("train_loader yielded no batches; cannot train.")


train_dataset = DummyLIDCDataset()
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

total_iterations = 240000
lr_steps = [60000, 120000, 180000]
# Halve the optimizer's learning rate at each milestone. Stepping once per
# iteration (after optimizer.step()) makes iteration lr_steps[k] the first one
# to run with the decayed rate.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_steps, gamma=0.5)

model.train()
# A single pass over train_loader is only ~278 batches, so iterate across as
# many epochs as needed. zip() exhausts range() first, so no extra batch is
# drawn once total_iterations is reached.
for iteration, (img, seg, mask) in zip(range(total_iterations), _infinite_batches(train_loader)):
    img, seg, mask = img.to(device), seg.to(device), mask.to(device)

    optimizer.zero_grad()
    # NOTE(review): the RuntimeError recorded at the end of this cell (a
    # torch.cat size mismatch, 32 vs 16) originates inside the model — this
    # script never calls torch.cat. TODO trace it in model.py.
    loss_dict = model.loss(seg, img, mask)
    loss = loss_dict['supervised_loss']
    loss.backward()
    optimizer.step()
    scheduler.step()

    if iteration % 1000 == 0:
        print(f"Iteration {iteration}/{total_iterations}, Loss: {loss.item():.4f}")

# Sampling and reconstruction
model.eval()
with torch.no_grad():
    sample_img = torch.randn(1, 1, 128, 128, device=device)
    samples = model.sample(sample_img, mean=False)
    print(f"Sample shape: {samples.shape}")

    # One-hot target (exactly one class per pixel), matching the training data.
    sample_labels = torch.randint(0, num_classes, (1, 128, 128), device=device)
    sample_seg = (
        torch.nn.functional.one_hot(sample_labels, num_classes)  # (1, H, W, C)
        .permute(0, 3, 1, 2)                                      # (1, C, H, W)
        .float()
        .contiguous()
    )
    recon = model.reconstruct(sample_seg, sample_img, mean=False)
    print(f"Reconstruction shape: {recon.shape}")
    

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 32 but got size 16 for tensor number 1 in the list.