In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from monai.networks.nets import UNet
from monai.transforms import AsDiscrete
from monai.metrics import DiceMetric
from loss import VolumeAwareTverskyLoss, compute_metrics

# Pick the compute device: default to the CPU, upgrade to CUDA when available.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# Model definition
# 3D U-Net configuration: one input channel, five output channels, five
# feature levels joined by four stride-2 steps, two residual units per level.
_unet_kwargs = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 5,
    "channels": (16, 32, 64, 128, 256),
    "strides": (2, 2, 2, 2),
    "num_res_units": 2,
}
model = UNet(**_unet_kwargs)
model = model.to(device)

# Dummy loader (replace with actual DataLoader)
def dummy_loader(
    num_batches=2,
    batch_size=2,
    num_classes=5,
    spatial_size=(64, 64, 64),
):
    """Yield random batches that stand in for a real segmentation DataLoader.

    The defaults reproduce the original fixed configuration (two samples,
    five classes, 64x64x64 volumes). Tensors are allocated directly on the
    module-level ``device``.

    Args:
        num_batches: Number of batches to yield.
        batch_size: Number of samples in each batch.
        num_classes: Number of label classes; labels are drawn uniformly
            from ``[0, num_classes)``.
        spatial_size: Sizes of the spatial dimensions of each volume.

    Yields:
        A tuple ``(x, y_onehot, y)`` where
        ``x`` is a float tensor of shape ``(batch_size, 1, *spatial_size)``
        sampled from a standard normal distribution,
        ``y_onehot`` is the float one-hot encoding of ``y`` with shape
        ``(batch_size, num_classes, *spatial_size)``, and
        ``y`` is an integer label tensor of shape
        ``(batch_size, *spatial_size)``.
    """
    spatial_size = tuple(spatial_size)
    for _ in range(num_batches):
        # Allocate on the target device instead of creating on the CPU and copying.
        x = torch.randn(batch_size, 1, *spatial_size, device=device)
        y = torch.randint(
            0, num_classes, (batch_size, *spatial_size), device=device
        )
        # one_hot appends the class axis last; move it to position 1 so the
        # target is channel-first for any number of spatial dimensions.
        y_onehot = F.one_hot(y, num_classes=num_classes).movedim(-1, 1).float()
        yield x, y_onehot, y

# Loss function
# Settings for the project-local VolumeAwareTverskyLoss (imported from the
# `loss` module). NOTE(review): the meaning of each option is defined in that
# module, which is not visible here — confirm before changing values.
_loss_kwargs = {
    "include_background": False,
    "softmax": True,
    "tversky_alpha": 0.3,
    "tversky_beta": 0.7,
    "loss_weights": (1.0, 0.0),
}
loss_fn = VolumeAwareTverskyLoss(**_loss_kwargs)
loss_fn = loss_fn.to(device)

# Optimizer
# Adam over all model parameters with a learning rate of 1e-4.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
)

# Training loop
# For each batch: forward pass, loss, backward pass, optimizer step, then
# print the loss components and the per-class metrics derived from them.
epochs = 1
for epoch in range(epochs):
    model.train()
    for i, (inputs, targets_onehot, targets) in enumerate(dummy_loader()):
        optimizer.zero_grad()
        outputs = model(inputs)
        # loss_fn returns a mapping; the keys read below are "loss",
        # "tversky_loss", "per_class_tp", "per_class_fp" and "per_class_fn".
        loss_out = loss_fn(outputs, targets_onehot)

        loss = loss_out["loss"]
        loss.backward()
        optimizer.step()

        # Detach the per-class statistics before reporting: calling .numpy()
        # on a tensor that still requires grad raises a RuntimeError, and the
        # metrics below do not need the autograd graph.
        per_class_tp = loss_out["per_class_tp"].detach()
        per_class_fp = loss_out["per_class_fp"].detach()
        per_class_fn = loss_out["per_class_fn"].detach()

        print(f"Epoch {epoch+1}, Batch {i+1}, Loss: {loss.item():.4f}")
        print("Tversky Loss:", loss_out["tversky_loss"].item())
        print("Per-class TP:", per_class_tp.cpu().numpy())
        print("Per-class FP:", per_class_fp.cpu().numpy())
        print("Per-class FN:", per_class_fn.cpu().numpy())

        # Metrics
        metrics = compute_metrics(per_class_tp, per_class_fp, per_class_fn)

        print("--- Per-Class Metrics ---")
        # Iterate over the entries compute_metrics actually returned rather
        # than a hard-coded 5, so a shorter per-class vector cannot raise an
        # IndexError.
        # NOTE(review): loss_fn is built with include_background=False; if it
        # drops the background class, index 0 here may correspond to label 1 —
        # confirm against the loss module.
        per_class_rows = zip(
            metrics['dice'], metrics['sensitivity'], metrics['specificity']
        )
        for c, (dice, sens, spec) in enumerate(per_class_rows):
            print(f"Class {c}: Dice={dice:.4f}, Sens={sens:.4f}, Spec={spec:.4f}")

# Visualization (optional)
# Plot, for every non-zero class, the predicted mask (top row) and the
# ground-truth mask (bottom row) on the middle slice of the first sample.
model.eval()
with torch.no_grad():
    inputs, targets_onehot, targets = next(dummy_loader(1))
    preds = model(inputs)
    pred_mask = torch.argmax(torch.softmax(preds, dim=1), dim=1)
    # Middle index along the first spatial axis, derived from the tensor
    # instead of a fixed 32 so other volume sizes still show the center slice.
    slice_idx = pred_mask.shape[1] // 2

    # Extract the slices and move them to the CPU once, not once per class.
    pred_slice = pred_mask[0, slice_idx].cpu()
    gt_slice = targets[0, slice_idx].cpu()

    # One column per class from 1 up to the number of output channels;
    # class 0 is not plotted.
    plotted_classes = range(1, preds.shape[1])
    # squeeze=False keeps `axes` two-dimensional even with a single column.
    fig, axes = plt.subplots(
        2, len(plotted_classes), figsize=(15, 5), squeeze=False
    )
    for col, c in enumerate(plotted_classes):
        pred_ax = axes[0, col]
        gt_ax = axes[1, col]
        pred_ax.imshow((pred_slice == c).numpy(), cmap='Reds')
        pred_ax.set_title(f"Pred Class {c}")
        gt_ax.imshow((gt_slice == c).numpy(), cmap='Greens')
        gt_ax.set_title(f"GT Class {c}")
        for ax in (pred_ax, gt_ax):
            ax.axis('off')

    plt.suptitle("Center Slice Prediction vs Ground Truth")
    plt.tight_layout()
    plt.show()


  from .autonotebook import tqdm as notebook_tqdm


ImportError: cannot import name 'compute_meandice' from 'monai.metrics' (c:\Users\DARSHAN PARMAR\.conda\envs\brats\lib\site-packages\monai\metrics\__init__.py)