In [1]:
# %%
# ===============================
# Step 1: Import Necessary Libraries
# ===============================
import math
import os
import pickle
import shutil
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vgg16
from tqdm import tqdm

In [2]:
# %%
# ===============================
# Step 2: Load and Inspect the Dataset
# ===============================

# Load the pre-split dataset dictionary.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file — only load trusted files.
with open('./reformatedNDDs/dataset_16k_20k.pkl', 'rb') as f:
    dataset_data = pickle.load(f)

# X_* are the input (conditioning) sequences and Y_* the target sequences of each split.
train_x = dataset_data['X_train']
train_y = dataset_data['Y_train']
val_x = dataset_data['X_val']
val_y = dataset_data['Y_val']
test_x = dataset_data['X_test']
test_y = dataset_data['Y_test']

# The recorded run printed (65, 10, 1, 160, 280) for X_train and (65, 20, 1, 160, 280) for Y_train.
print(f"X_train: {train_x.shape}")
print(f"Y_train: {train_y.shape}")

X_train: (65, 10, 1, 160, 280)
Y_train: (65, 20, 1, 160, 280)


In [3]:
# %%
# ===============================
# Step 3: Check GPU Availability
# ===============================

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(
    f"GPU is available: {torch.cuda.get_device_name(0)}"
    if device == 'cuda'
    else "GPU is not available."
)

GPU is available: Tesla V100-SXM2-32GB


In [4]:
# %%
# ===============================
# Step 4: Define Data Augmentation and Dataset Classes
# ===============================

class FrameTransform:
    """
    Random geometric augmentation applied identically to a paired (input, target) clip.

    One random horizontal flip decision and one rotation angle are drawn per call and
    shared by every frame of both clips. This keeps the clip temporally coherent and
    keeps the target geometrically aligned with its conditioning input.
    """

    def __init__(self):
        """
        Builds the augmentation: random horizontal flip followed by a random
        rotation in [-15, 15] degrees.
        """
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=15)
        ])

    def __call__(self, xy):
        """
        Args:
            xy (tuple): ``(x, y)`` tensors shaped (frames, channels, H, W). Both must share
                the same (channels, H, W).

        Returns:
            tuple: ``(x_transformed, y_transformed)``, each with ``min(len(x), len(y))`` frames.
        """
        x, y = xy

        # Keep the original frame-pairing semantics (the previous zip(x, y) loop stopped at
        # the shorter clip).
        # NOTE(review): with the shapes printed in Step 2 (10 input frames, 20 target frames)
        # this keeps only the first 10 target frames. Confirm whether the full target
        # clip is intended.
        n_frames = min(x.shape[0], y.shape[0])
        x, y = x[:n_frames], y[:n_frames]

        # Transform both clips in ONE call. torchvision tensor transforms act on the last two
        # dims and accept leading batch dims, so the random parameters are sampled once and
        # applied to every frame of x and y alike.
        both = self.transform(torch.cat([x, y], dim=0))
        return both[:n_frames], both[n_frames:]

class NeuriteGrowthDataset(Dataset):
    def __init__(self, inputs, targets, transform=None):
        """
        Custom Dataset for Neurite Growth sequences.

        Args:
            inputs (np.ndarray | torch.Tensor): Healthy sequences with shape (samples, frames, channels, H, W).
            targets (np.ndarray | torch.Tensor): Disorder-affected sequences with shape (samples, frames, channels, H, W).
            transform (callable, optional): Called as ``transform((x, y))`` and must return ``(x, y)``.
        """
        self.inputs = inputs
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.inputs)

    @staticmethod
    def _prepare(seq, name):
        """
        Converts one sequence to a float tensor scaled by ``v * 2 - 1`` with shape (frames, 1, H, W).

        Args:
            seq (np.ndarray | torch.Tensor): One sequence, expected (frames, channels, H, W).
            name (str): Label used in the error message ("x" or "y").

        Raises:
            ValueError: If ``seq`` is not 4-D or has zero channels.
        """
        if isinstance(seq, np.ndarray):
            seq = torch.from_numpy(seq).float()
        else:
            # Copy so later in-place ops can never alias the stored dataset tensor.
            seq = seq.clone().detach().float()

        # Map to [-1, 1] — assumes stored values are in [0, 1]; TODO confirm against the data.
        seq = seq * 2 - 1

        if seq.dim() != 4 or seq.shape[1] == 0:
            raise ValueError(
                f"Expected {name} to have 4 dimensions (frames, channels, H, W) with at least one channel, "
                f"got shape {tuple(seq.shape)}"
            )
        if seq.shape[1] > 1:
            # Collapse multi-channel frames to a single channel by averaging.
            seq = seq.mean(dim=1, keepdim=True)
        return seq

    def __getitem__(self, idx):
        # Each tensor is validated and channel-reduced independently, so a malformed target
        # is caught even when the input is well-formed.
        x = self._prepare(self.inputs[idx], "x")
        y = self._prepare(self.targets[idx], "y")

        if self.transform:
            x, y = self.transform((x, y))

        return x, y

In [5]:
# %%
# ===============================
# Step 5: Define the Enhanced U-Net Architecture
# ===============================

class SinusoidalPosEmb(nn.Module):
    """
    Transformer-style sinusoidal embedding of scalar timesteps.

    Maps ``time`` of shape (batch_size,) to (batch_size, dim): ``dim // 2`` sine features
    followed by ``dim // 2`` cosine features over geometrically spaced frequencies. For odd
    ``dim`` a zero column is appended so the output always has exactly ``dim`` columns.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # The frequencies span 1 .. 1/10000 over (half_dim - 1) steps. Guard the denominator:
        # for dim in {2, 3} it would be 0, and 0 / 0 yields NaN embeddings.
        denom = max(half_dim - 1, 1)
        emb = torch.exp(-torch.arange(half_dim, dtype=torch.float32, device=device) * math.log(10000) / denom)
        emb = time[:, None].float() * emb[None, :]  # (batch_size, half_dim)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # (batch_size, 2 * half_dim)
        if self.dim % 2 == 1:
            # Odd dim: pad one zero column so the width matches the requested dim.
            emb = F.pad(emb, (0, 1))
        return emb  # (batch_size, dim)

class ResidualBlock(nn.Module):
    """
    3-D residual block: conv3x3x3 -> ReLU -> conv3x3x3, plus a shortcut, then ReLU.

    The shortcut is the identity when shape is preserved. Otherwise it is a strided
    1x1x1 convolution that projects the input to ``out_channels`` and the new resolution.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

        needs_projection = (in_channels != out_channels) or (stride != 1)
        self.match_channels = (
            nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride)
            if needs_projection
            else None
        )

    def forward(self, x):
        hidden = self.conv2(F.relu(self.conv1(x)))
        shortcut = x if self.match_channels is None else self.match_channels(x)
        return F.relu(hidden + shortcut)

class ConditionalUNet(nn.Module):
    """
    Small 3-D U-Net that predicts diffusion noise for a clip, conditioned on a second clip.

    ``forward(x, cond, t)``:
        x    -- noisy clip, (batch, in_channels, frames, H, W)
        cond -- conditioning clip with the same (frames, H, W), (batch, cond_channels, frames, H, W)
        t    -- diffusion timesteps, (batch,)
    returns a tensor of shape (batch, in_channels, frames, H, W).

    Shape comments below are for a (10, 160, 280) clip. The transposed convolutions do not
    always invert the stride-2 encoder exactly: e.g. depth 20 -> 10 -> 5 decodes back to 9,
    and H 160 -> 80 -> 40 decodes back to 79. Each decoder map is therefore right-padded
    with zeros, or cropped, to its encoder partner before the skip connection.
    """

    def __init__(self, in_channels=1, cond_channels=1, base_channels=64):
        super(ConditionalUNet, self).__init__()

        # Encoding path with downsampling
        self.enc1 = ResidualBlock(in_channels + cond_channels, base_channels, stride=1)  # (batch_size,64,10,160,280)
        self.enc2 = ResidualBlock(base_channels, base_channels * 2, stride=2)  # (batch_size,128,5,80,140)
        self.enc3 = ResidualBlock(base_channels * 2, base_channels * 4, stride=2)  # (batch_size,256,3,40,70)

        # Decoding path with upsampling
        self.dec3 = nn.ConvTranspose3d(
            base_channels * 4,
            base_channels * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=0
        )  # 3 -> 5 frames, 40x70 -> 79x139 (fixed up to match enc2)
        self.dec2 = nn.ConvTranspose3d(
            base_channels * 2,
            base_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            output_padding=0
        )  # 5 -> 10 frames, 80x140 -> 160x280
        self.dec1 = nn.Conv3d(base_channels, in_channels, kernel_size=3, padding=1)  # Final output

        # Time embedding, added at the bottleneck (width must equal enc3's channels).
        self.time_embed = SinusoidalPosEmb(base_channels * 4)

    @staticmethod
    def _match_size(src, ref):
        """
        Returns ``src`` cropped and/or right-zero-padded so its (D, H, W) equal ``ref``'s.

        Both tensors are (batch, channels, D, H, W); batch and channels are left untouched.
        """
        target_d, target_h, target_w = ref.shape[2:]
        # Crop any dimension that came out too large.
        src = src[..., :target_d, :target_h, :target_w]
        pad_d = target_d - src.size(2)
        pad_h = target_h - src.size(3)
        pad_w = target_w - src.size(4)
        if pad_d or pad_h or pad_w:
            # F.pad order is (W_left, W_right, H_left, H_right, D_left, D_right); pad on the right only.
            src = F.pad(src, (0, pad_w, 0, pad_h, 0, pad_d))
        return src

    def forward(self, x, cond, t):
        # Concatenate noisy input and conditioning clip along channels.
        x = torch.cat([x, cond], dim=1)  # (batch_size, 2, 10, 160, 280)

        time_emb = self.time_embed(t)  # (batch_size, 256)
        time_emb = time_emb.view(-1, time_emb.shape[1], 1, 1, 1)  # (batch_size, 256, 1, 1, 1)

        # Encoding path
        e1 = self.enc1(x)  # (batch_size, 64, 10, 160, 280)
        e2 = self.enc2(e1)  # (batch_size, 128, 5, 80, 140)
        e3 = self.enc3(e2)  # (batch_size, 256, 3, 40, 70)

        # Inject the timestep at the deepest layer (broadcast over D, H, W).
        e3 = e3 + time_emb

        # Decoding path with skip connections; sizes are reconciled on all of D, H and W.
        d3 = F.relu(self.dec3(e3))  # (batch_size, 128, 5, 79, 139)
        d3 = self._match_size(d3, e2) + e2  # (batch_size, 128, 5, 80, 140)

        d2 = F.relu(self.dec2(d3))  # (batch_size, 64, 10, 160, 280)
        d2 = self._match_size(d2, e1) + e1  # (batch_size, 64, 10, 160, 280)

        return self.dec1(d2)  # (batch_size, 1, 10, 160, 280)

In [6]:
# %%
# ===============================
# Step 6: Define the Diffusion Process
# ===============================

def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine noise schedule (Nichol & Dhariwal, "Improved DDPM").

    Defines the cumulative signal level as a squared cosine over [0, timesteps],
    normalised to 1 at step 0. Each step's beta is then
    ``1 - alpha_bar[t + 1] / alpha_bar[t]``, clipped to [0, 0.999].

    Args:
        timesteps (int): Number of diffusion steps.
        s (float, optional): Small offset that keeps the earliest betas from being too small. Defaults to 0.008.

    Returns:
        torch.Tensor: Betas of shape (timesteps,).
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    squared_cos = torch.cos(((grid / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = squared_cos / squared_cos[0]
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)

class Diffusion:
    def __init__(self, timesteps=1000, device='cuda'):
        """
        Precomputes the cosine-schedule coefficients used for forward noising
        (``add_noise``) and read by the sampling loop.

        Args:
            timesteps (int, optional): Number of diffusion steps. Defaults to 1000.
            device (str, optional): Device on which all schedule tensors live. Defaults to 'cuda'.
        """
        self.device = device
        self.timesteps = timesteps
        self.betas = cosine_beta_schedule(timesteps).to(device)

        # Derived tensors inherit the device of self.betas, so no further .to() is needed.
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

    def add_noise(self, x0, t):
        """
        Samples x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x0, (1 - abar_t) * I).

        Args:
            x0 (torch.Tensor): Clean data of any rank whose first dim is the batch.
            t (torch.Tensor): Integer timestep per batch element, shape (batch,).

        Returns:
            tuple: ``(xt, noise)``, both with the same shape as ``x0``.
        """
        noise = torch.randn_like(x0)
        # Reshape per-sample coefficients to (batch, 1, ..., 1) so they broadcast over any rank.
        coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(coeff_shape)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(coeff_shape)
        xt = sqrt_alphas_cumprod_t * x0 + sqrt_one_minus_alphas_cumprod_t * noise
        return xt, noise

In [7]:
# %%
# ===============================
# Step 7: Define Loss Functions
# ===============================

# Pre-trained VGG16 (first 16 layers of `features`) used as a fixed feature extractor for the perceptual loss.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of
# `weights=VGG16_Weights.IMAGENET1K_V1`; kept as-is for compatibility with the installed version.
vgg = vgg16(pretrained=True).features[:16].to(device).eval()
# The optimizer only holds the U-Net's parameters, so VGG's weights are never updated.
# Turning off their gradients avoids computing and storing them on every backward pass.
# Gradients still flow through VGG to its input.
vgg.requires_grad_(False)

def _to_vgg_images(clip):
    """Converts a (B, C, F, H, W) clip in [-1, 1] into a (B*F, 3, H, W) ImageNet-normalised batch, one image per frame."""
    # Single-channel clips are replicated to RGB.
    if clip.shape[1] == 1:
        clip = clip.repeat(1, 3, 1, 1, 1)

    mean = torch.tensor([0.485, 0.456, 0.406], device=clip.device).view(1, 3, 1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=clip.device).view(1, 3, 1, 1, 1)
    clip = ((clip + 1) / 2 - mean) / std  # [-1, 1] -> [0, 1] -> ImageNet statistics

    # Move frames next to batch BEFORE flattening. Flattening (B, 3, F, H, W) directly would
    # pack three consecutive frames of one channel into a single "RGB" image.
    return clip.permute(0, 2, 1, 3, 4).reshape(-1, 3, clip.shape[3], clip.shape[4])


def perceptual_loss(pred, target, feature_extractor=None):
    """
    Computes a frame-wise perceptual loss: MSE between feature maps of ``pred`` and ``target``.

    Args:
        pred (torch.Tensor): Predicted clip (batch, channels, frames, H, W), channels 1 or 3,
            values nominally in [-1, 1].
        target (torch.Tensor): Reference clip with the same layout.
        feature_extractor (callable, optional): Maps a (N, 3, H, W) batch to features.
            Defaults to the module-level ``vgg``.

    Returns:
        torch.Tensor: Scalar perceptual loss.
    """
    if feature_extractor is None:
        feature_extractor = vgg

    target_imgs = _to_vgg_images(target)
    pred_imgs = _to_vgg_images(pred)

    # The target side needs no gradient.
    with torch.no_grad():
        target_features = feature_extractor(target_imgs)

    pred_features = feature_extractor(pred_imgs)
    return F.mse_loss(pred_features, target_features)

# Mean absolute error criterion shared by every combined_loss call.
l1_loss = nn.L1Loss()

def combined_loss(pred, target, alpha=0.7):
    """
    Weighted blend of a pixel-space L1 term and the perceptual term.

    Args:
        pred (torch.Tensor): Predicted noise.
        target (torch.Tensor): Actual noise.
        alpha (float, optional): Weight of the L1 term; the perceptual term is weighted
            by ``1 - alpha``. Defaults to 0.7.

    Returns:
        torch.Tensor: ``alpha * L1 + (1 - alpha) * perceptual``.
    """
    # NOTE(review): the perceptual term compares ImageNet-VGG features of noise maps, not
    # images; confirm it actually helps the noise-prediction objective.
    pixel_term = l1_loss(pred, target)
    feature_term = perceptual_loss(pred, target)
    return alpha * pixel_term + (1 - alpha) * feature_term




In [8]:
# %%
# ===============================
# Step 8: Initialize Datasets and DataLoaders with Augmentation
# ===============================

batch_size = 16  # Adjust based on your GPU memory

# Only the training split is augmented; validation and test are used as-is.
train_dataset = NeuriteGrowthDataset(train_x, train_y, transform=FrameTransform())
val_dataset = NeuriteGrowthDataset(val_x, val_y)
test_dataset = NeuriteGrowthDataset(test_x, test_y)

# Settings shared by all three loaders: 4 worker processes and pinned host memory for faster GPU copies.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

train_loader = DataLoader(train_dataset, shuffle=True, prefetch_factor=2, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [9]:
# %%
# ===============================
# Step 9: Initialize the Model, Optimizer, Diffusion Process, and Scaler
# ===============================

# Noise-prediction network, moved to the selected device.
model = ConditionalUNet(in_channels=1, cond_channels=1, base_channels=64).to(device)

# With several visible GPUs, split each batch across them.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Halve the learning rate after 5 epochs without improvement of the value given to scheduler.step().
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=5,
    factor=0.5,
    verbose=True,
)

diffusion = Diffusion(timesteps=1000, device=device)

# Loss scaler for mixed-precision (autocast) training.
scaler = GradScaler()

In [None]:
# %%
# ===============================
# Step 10: Define the Training Loop with Mixed Precision and Perceptual Loss
# ===============================
        
num_epochs = 200  # Adjust as needed

# Directory for best-so-far checkpoints.
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Lowest average training loss seen so far; inf so the first epoch always checkpoints.
best_loss = float('inf')

# Resolve nvidia-smi once. It only exists where NVIDIA drivers are installed, and
# subprocess.run(['nvidia-smi']) raises FileNotFoundError elsewhere.
nvidia_smi_path = shutil.which('nvidia-smi')

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for healthy_seq, disorder_seq in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Dataset yields (batch, frames, channels, H, W); Conv3d expects (batch, channels, frames, H, W).
        healthy_seq = healthy_seq.to(device).permute(0, 2, 1, 3, 4)
        disorder_seq = disorder_seq.to(device).permute(0, 2, 1, 3, 4)

        # Resample the conditioning clip so its (frames, H, W) match the target clip.
        if healthy_seq.size(2) != disorder_seq.size(2):
            healthy_seq = F.interpolate(
                healthy_seq,
                size=(disorder_seq.size(2), disorder_seq.size(3), disorder_seq.size(4)),
                mode='trilinear',
                align_corners=False
            )

        # One uniformly random diffusion timestep per sample.
        t = torch.randint(0, diffusion.timesteps, (disorder_seq.size(0),), device=device).long()

        # Noise the target clip; the model learns to predict the added noise.
        x_t, noise = diffusion.add_noise(disorder_seq, t)

        optimizer.zero_grad()

        with autocast():
            predicted_noise = model(x_t, healthy_seq, t)
            loss = combined_loss(predicted_noise, noise)  # L1 + perceptual, scalar

        scaler.scale(loss).backward()

        # Unscale first so gradient clipping sees the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")

    # NOTE(review): the scheduler and checkpointing use the training loss, which is noisy
    # because t is random per batch. val_loader is built but never evaluated here.
    scheduler.step(avg_loss)

    if avg_loss < best_loss:
        best_loss = avg_loss
        checkpoint_path = os.path.join(checkpoint_dir, f'conditional_unet_best_epoch_{epoch+1}.pth')
        # Save the unwrapped network so keys carry no "module." prefix under nn.DataParallel
        # and load directly into a plain ConditionalUNet.
        model_to_save = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(model_to_save.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path} with a loss of {avg_loss:.6f}")

    # Every 10 epochs, log GPU utilisation and memory when nvidia-smi is available.
    if (epoch + 1) % 10 == 0 and nvidia_smi_path is not None:
        subprocess.run([nvidia_smi_path])

Epoch 1/200: 100%|██████████| 5/5 [00:05<00:00,  1.12s/it]


Epoch [1/200], Loss: 7.677693
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_1.pth with a loss of 7.677693


Epoch 2/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [2/200], Loss: 7.081687
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_2.pth with a loss of 7.081687


Epoch 3/200: 100%|██████████| 5/5 [00:03<00:00,  1.51it/s]


Epoch [3/200], Loss: 6.519331
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_3.pth with a loss of 6.519331


Epoch 4/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [4/200], Loss: 5.908103
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_4.pth with a loss of 5.908103


Epoch 5/200: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s]


Epoch [5/200], Loss: 5.910053


Epoch 6/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [6/200], Loss: 5.173400
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_6.pth with a loss of 5.173400


Epoch 7/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [7/200], Loss: 4.995109
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_7.pth with a loss of 4.995109


Epoch 8/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [8/200], Loss: 4.847290
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_8.pth with a loss of 4.847290


Epoch 9/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [9/200], Loss: 4.878041


Epoch 10/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [10/200], Loss: 4.574693
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_10.pth with a loss of 4.574693
Sun Sep 29 15:42:26 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000000:B3:00.0 Off |                    0 |
| N/A   52C    P0              74W / 300W |  21091MiB / 32768MiB |      1%      Default |
|                                         |                      |                  N/A |
+---------------------------------

Epoch 11/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [11/200], Loss: 4.663915


Epoch 12/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [12/200], Loss: 4.505817
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_12.pth with a loss of 4.505817


Epoch 13/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [13/200], Loss: 3.937246
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_13.pth with a loss of 3.937246


Epoch 14/200: 100%|██████████| 5/5 [00:03<00:00,  1.47it/s]


Epoch [14/200], Loss: 4.066898


Epoch 15/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [15/200], Loss: 4.025561


Epoch 16/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [16/200], Loss: 4.426147


Epoch 17/200: 100%|██████████| 5/5 [00:03<00:00,  1.47it/s]


Epoch [17/200], Loss: 3.789548
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_17.pth with a loss of 3.789548


Epoch 18/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [18/200], Loss: 3.706849
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_18.pth with a loss of 3.706849


Epoch 19/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [19/200], Loss: 4.316695


Epoch 20/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [20/200], Loss: 3.555432
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_20.pth with a loss of 3.555432
Sun Sep 29 15:43:00 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000000:B3:00.0 Off |                    0 |
| N/A   54C    P0              75W / 300W |  21091MiB / 32768MiB |     28%      Default |
|                                         |                      |                  N/A |
+---------------------------------

Epoch 21/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [21/200], Loss: 4.161513


Epoch 22/200: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s]


Epoch [22/200], Loss: 3.597090


Epoch 23/200: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s]


Epoch [23/200], Loss: 3.944292


Epoch 24/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [24/200], Loss: 4.159632


Epoch 25/200: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s]


Epoch [25/200], Loss: 3.430937
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_25.pth with a loss of 3.430937


Epoch 26/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [26/200], Loss: 3.603115


Epoch 27/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [27/200], Loss: 3.509348


Epoch 28/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [28/200], Loss: 3.427119
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_28.pth with a loss of 3.427119


Epoch 29/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [29/200], Loss: 3.314096
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_29.pth with a loss of 3.314096


Epoch 30/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [30/200], Loss: 3.485348
Sun Sep 29 15:43:34 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000000:B3:00.0 Off |                    0 |
| N/A   55C    P0              82W / 300W |  21091MiB / 32768MiB |     67%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                     

Epoch 31/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [31/200], Loss: 3.318998


Epoch 32/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [32/200], Loss: 3.975500


Epoch 33/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [33/200], Loss: 3.898452


Epoch 34/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [34/200], Loss: 2.984409
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_34.pth with a loss of 2.984409


Epoch 35/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [35/200], Loss: 2.834085
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_35.pth with a loss of 2.834085


Epoch 36/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [36/200], Loss: 2.562435
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_36.pth with a loss of 2.562435


Epoch 37/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [37/200], Loss: 3.127327


Epoch 38/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [38/200], Loss: 2.762575


Epoch 39/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [39/200], Loss: 2.869269


Epoch 40/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [40/200], Loss: 2.208860
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_40.pth with a loss of 2.208860
Sun Sep 29 15:44:08 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000000:B3:00.0 Off |                    0 |
| N/A   54C    P0              74W / 300W |  21091MiB / 32768MiB |     39%      Default |
|                                         |                      |                  N/A |
+---------------------------------

Epoch 41/200: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s]


Epoch [41/200], Loss: 2.289571


Epoch 42/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [42/200], Loss: 2.257437


Epoch 43/200: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s]


Epoch [43/200], Loss: 1.959504
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_43.pth with a loss of 1.959504


Epoch 44/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [44/200], Loss: 2.151804


Epoch 45/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [45/200], Loss: 2.411413


Epoch 46/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [46/200], Loss: 2.095071


Epoch 47/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [47/200], Loss: 2.134788


Epoch 48/200: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s]


Epoch [48/200], Loss: 2.018716


Epoch 49/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [49/200], Loss: 1.851973
Checkpoint saved at ./checkpoints/conditional_unet_best_epoch_49.pth with a loss of 1.851973


Epoch 50/200: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s]


Epoch [50/200], Loss: 3.050195
Sun Sep 29 15:44:42 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000000:B3:00.0 Off |                    0 |
| N/A   55C    P0              82W / 300W |  21091MiB / 32768MiB |     67%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                     

Epoch 51/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [51/200], Loss: 2.058796


Epoch 52/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [52/200], Loss: 2.380569


Epoch 53/200: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s]


Epoch [53/200], Loss: 1.867178


Epoch 54/200: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s]


Epoch [54/200], Loss: 2.436938


Epoch 55/200:  40%|████      | 2/5 [00:01<00:02,  1.08it/s]

In [None]:
# %%
# ===============================
# Step 11: Define the Sampling and Visualization Functions
# ===============================

@torch.no_grad()
def sample(model, diffusion, conditioning_seq, num_frames=20):
    """
    Generates clips by running the full DDPM reverse chain from pure noise.

    Args:
        model (nn.Module): Noise-prediction model, called as ``model(x_t, cond, t)``.
        diffusion (Diffusion): Supplies ``device``, ``timesteps``, ``alphas``, ``betas`` and ``alphas_cumprod``.
        conditioning_seq (torch.Tensor): Conditioning clip (batch_size, channels, frames, H, W).
            It is moved to ``diffusion.device`` and trilinearly resampled to ``num_frames``
            frames when its length differs.
        num_frames (int, optional): Number of frames to generate. Defaults to 20.

    Returns:
        torch.Tensor: Generated clip (batch_size, 1, num_frames, H, W) on ``diffusion.device``.

    The model is put in eval mode while sampling; its previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    device = diffusion.device
    batch_size = conditioning_seq.size(0)
    conditioning_seq = conditioning_seq.to(device)

    if conditioning_seq.size(2) != num_frames:
        conditioning_seq = F.interpolate(
            conditioning_seq,
            size=(num_frames, conditioning_seq.size(3), conditioning_seq.size(4)),
            mode='trilinear',
            align_corners=False
        )

    # Start from standard Gaussian noise, created directly on the target device.
    shape = (batch_size, 1, num_frames, conditioning_seq.size(-2), conditioning_seq.size(-1))
    img = torch.randn(shape, device=device)

    try:
        for i in tqdm(reversed(range(diffusion.timesteps)), total=diffusion.timesteps, desc="Sampling"):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)

            predicted_noise = model(img, conditioning_seq, t)

            alpha = diffusion.alphas[t].view(-1, 1, 1, 1, 1)
            beta = diffusion.betas[t].view(-1, 1, 1, 1, 1)
            alpha_hat = diffusion.alphas_cumprod[t].view(-1, 1, 1, 1, 1)

            # x_{t-1} = (x_t - (1 - a_t) / sqrt(1 - abar_t) * eps) / sqrt(a_t) + sqrt(b_t) * z,
            # with z = 0 at the final step.
            noise = torch.randn_like(img) if i > 0 else torch.zeros_like(img)
            img = (1 / torch.sqrt(alpha)) * (img - ((1 - alpha) / torch.sqrt(1 - alpha_hat)) * predicted_noise) + torch.sqrt(beta) * noise
    finally:
        model.train(was_training)

    return img

def show_sequence(sequence, nrows=1, ncols=20):
    """
    Displays frames of a batch of clips as a grid: one batch item per row, one frame per column.

    Args:
        sequence (torch.Tensor): Clips to display (batch_size, channels, frames, H, W); only channel 0 is shown.
        nrows (int, optional): Number of rows (batch items) in the grid. Defaults to 1.
        ncols (int, optional): Number of columns (frames) in the grid. Defaults to 20.

    Cells beyond the available batch items or frames are left blank.
    """
    # detach() allows grad-tracking tensors; float() allows fp16/bf16 tensors to reach numpy.
    frames_np = sequence.detach().cpu().float().numpy()
    n_batch, _, n_frames, _, _ = frames_np.shape

    # squeeze=False always yields a 2-D axes array, so indexing also works for nrows == 1 or ncols == 1.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2), squeeze=False)
    for row in range(nrows):
        for col in range(ncols):
            ax = axes[row, col]
            if row < n_batch and col < n_frames:
                ax.imshow(frames_np[row, 0, col], cmap='gray')
            ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# %%
# ===============================
# Step 12: Generate and Visualize Samples
# ===============================

# Draw one batch of conditioning ("healthy") clips from the test split.
test_iter = iter(test_loader)
healthy_seq, _ = next(test_iter)

# (batch, frames, channels, H, W) -> (batch, channels, frames, H, W), the layout the sampler expects.
healthy_seq = healthy_seq.to(device).permute(0, 2, 1, 3, 4)

# Run the reverse diffusion chain for a 10-frame clip: result is (batch, 1, 10, H, W).
generated_seq = sample(model, diffusion, healthy_seq, num_frames=10)

# Map both clips from the [-1, 1] data range to [0, 1] for display.
generated_seq = (generated_seq + 1) / 2
healthy_seq_display = (healthy_seq + 1) / 2

# First figure: conditioning clip; second figure: generated clip (first batch item, 10 frames each).
show_sequence(healthy_seq_display, nrows=1, ncols=10)
show_sequence(generated_seq, nrows=1, ncols=10)

In [None]:
# %%
# ===============================
# Step 13: Save the Trained Model
# ===============================

# Save the unwrapped network: under nn.DataParallel the wrapper's state_dict prefixes every key
# with "module.", which would not load into a plain ConditionalUNet.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), 'conditional_unet.pth')
# To load the model later:
# model = ConditionalUNet(in_channels=1, cond_channels=1, base_channels=64).to(device)
# model.load_state_dict(torch.load('conditional_unet.pth', map_location=device))

# %%
