In [1]:
import wandb
wandb.init(project='33_ddpm_bio',name='note_half_bio')
import wandb


# Standard libraries
import os
import tempfile
import time
import io
import random
import math
import warnings
from multiprocessing import Manager
from typing import Optional

# Data manipulation libraries
import numpy as np
import pandas as pd
import scipy

# PyTorch and related libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset, random_split

# MONAI libraries
# from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader
from monai.transforms import (
    AddChanneld, 
    CenterSpatialCropd, 
    Compose, 
    Lambdad, 
    LoadImaged, 
    Resized, 
    ScaleIntensityd
)
from monai.utils import set_determinism

# Other medical image processing libraries
import SimpleITK as sitk
import torchio as tio

# Plotting and visualization
import matplotlib.pyplot as plt

# Progress bar
from tqdm import tqdm

# Custom modules
from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler, DDIMScheduler

# Weights and Biases for experiment tracking
from dataloader import Train,Eval






config = {
    'batch_size': 64,
    'imgDimResize':(160,192,160),
    'imgDimPad': (208, 256, 208),
    'spatialDims': '3D',
    'unisotropic_sampling': True, 
    'perc_low': 0, 
    'perc_high': 100,
    'rescaleFactor':2,
    'base_path': '/scratch1/akrami/Latest_Data/Data',
}

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhale-akrami[0m ([33musc_akrami[0m). Use [1m`wandb login --relogin`[0m to force relogin


IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html




In [2]:
wandb.config.update(config )


imgpath = {}
# '/acmenas/hakrami/patched-Diffusion-Models-UAD/Data/splits/BioBank_train.csv'
#'/acmenas/hakrami/patched-Diffusion-Models-UAD/Data/splits/IXI_train_fold0.csv',
#csvpath_trains = ['/project/ajoshi_27/akrami/patched-Diffusion-Models-UAD/Data/splits/BioBank_train.csv', '/project/ajoshi_27/akrami/patched-Diffusion-Models-UAD/Data/splits/BioBank_train.csv']
csvpath_trains=['./combined.csv']
pathBase = '/acmenas/hakrami/patched-Diffusion-Models-UAD/Data_train'
csvpath_val = '/acmenas/hakrami/3D_lesion_DF/splits/IXI_train_fold0.csv'
csvpath_test = '/acmenas/hakrami/3D_lesion_DF/splits/Brats21_sub_test.csv'
var_csv = {}
states = ['train','val','test']

df_list = []

# Loop through each CSV file path and read it into a DataFrame
for csvpath in csvpath_trains:
    df = pd.read_csv(csvpath)
    df_list.append(df)

# dfffff=  pd.concat(df_list, ignore_index=True)
# dfffff.to_csv("./combined.csv", index=False)

In [3]:
var_csv['train'] =pd.concat(df_list, ignore_index=True)
var_csv['val'] = pd.read_csv(csvpath_val)
var_csv['test'] = pd.read_csv(csvpath_test)
# if cfg.mode == 't2':
#     keep_t2 = pd.read_csv(cfg.path.IXI.keep_t2) # only keep t2 images that have a t1 counterpart

for state in states:
    var_csv[state]['settype'] = state
    var_csv[state]['norm_path'] = pathBase  + var_csv[state]['norm_path']
    var_csv[state]['img_path'] = pathBase  + var_csv[state]['img_path']
    var_csv[state]['mask_path'] = pathBase  + var_csv[state]['mask_path']
    if state != 'test':
        var_csv[state]['seg_path'] = None
    else:
        var_csv[state]['seg_path'] = pathBase  + var_csv[state]['seg_path']

    # if cfg.mode == 't2': 
    #     var_csv[state] =var_csv[state][var_csv[state].img_name.isin(keep_t2['0'].str.replace('t2','t1'))]
    #     var_csv[state]['img_path'] = var_csv[state]['img_path'].str.replace('t1','t2')
    
    
data_train = Train(var_csv['train'],config) 
data_val = Train(var_csv['val'],config)                
data_test = Eval(var_csv['test'],config)



#data_train = Train(pd.read_csv('/project/ajoshi_27/akrami/monai3D/GenerativeModels/data/split/IXI_train_fold0.csv', converters={'img_path': pd.eval}), config)
train_loader = DataLoader(data_train, batch_size=config.get('batch_size', 1),shuffle=True,num_workers=8)

#data_val = Train(pd.read_csv('/project/ajoshi_27/akrami/monai3D/GenerativeModels/data/split/IXI_val_fold0.csv', converters={'img_path': pd.eval}), config)
val_loader = DataLoader(data_val, batch_size=config.get('batch_size', 1),shuffle=True,num_workers=8)

#data_test = Train(pd.read_csv('/project/ajoshi_27/akrami/monai3D/GenerativeModels/data/split/Brats21_test.csv', converters={'img_path': pd.eval}), config)
test_loader = DataLoader(data_test, batch_size=config.get('batch_size', 1),shuffle=False,num_workers=8)


device = torch.device("cuda")

model = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    num_channels=[32, 64, 128, 128],
    attention_levels=[False, False, False,True],
    num_head_channels=[0, 0, 0,32],
    num_res_blocks=2,
)
#model_filename = '/acmenas/hakrami/3D_lesion_DF/models/halfres/model_epoch984.pt'
model_filename = '/acmenas/hakrami/3D_lesion_DF/models/norm/model_epoch199.pt'


model.to(device)
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

model.load_state_dict(torch.load(model_filename))
scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0005, beta_end=0.0195)

inferer = DiffusionInferer(scheduler)

optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-5)


n_epochs = 1000
val_interval = 25
epoch_loss_list = []
val_epoch_loss_list = []

scaler = GradScaler()
total_start = time.time()


wandb.watch(model, log_freq=100)

RuntimeError: Error(s) in loading state_dict for DiffusionModelUNet:
	Missing key(s) in state_dict: "conv_in.conv.weight", "conv_in.conv.bias", "time_embed.0.weight", "time_embed.0.bias", "time_embed.2.weight", "time_embed.2.bias", "down_blocks.0.resnets.0.norm1.weight", "down_blocks.0.resnets.0.norm1.bias", "down_blocks.0.resnets.0.conv1.conv.weight", "down_blocks.0.resnets.0.conv1.conv.bias", "down_blocks.0.resnets.0.time_emb_proj.weight", "down_blocks.0.resnets.0.time_emb_proj.bias", "down_blocks.0.resnets.0.norm2.weight", "down_blocks.0.resnets.0.norm2.bias", "down_blocks.0.resnets.0.conv2.conv.weight", "down_blocks.0.resnets.0.conv2.conv.bias", "down_blocks.0.resnets.1.norm1.weight", "down_blocks.0.resnets.1.norm1.bias", "down_blocks.0.resnets.1.conv1.conv.weight", "down_blocks.0.resnets.1.conv1.conv.bias", "down_blocks.0.resnets.1.time_emb_proj.weight", "down_blocks.0.resnets.1.time_emb_proj.bias", "down_blocks.0.resnets.1.norm2.weight", "down_blocks.0.resnets.1.norm2.bias", "down_blocks.0.resnets.1.conv2.conv.weight", "down_blocks.0.resnets.1.conv2.conv.bias", "down_blocks.0.downsampler.op.conv.weight", "down_blocks.0.downsampler.op.conv.bias", "down_blocks.1.resnets.0.norm1.weight", "down_blocks.1.resnets.0.norm1.bias", "down_blocks.1.resnets.0.conv1.conv.weight", "down_blocks.1.resnets.0.conv1.conv.bias", "down_blocks.1.resnets.0.time_emb_proj.weight", "down_blocks.1.resnets.0.time_emb_proj.bias", "down_blocks.1.resnets.0.norm2.weight", "down_blocks.1.resnets.0.norm2.bias", "down_blocks.1.resnets.0.conv2.conv.weight", "down_blocks.1.resnets.0.conv2.conv.bias", "down_blocks.1.resnets.0.skip_connection.conv.weight", "down_blocks.1.resnets.0.skip_connection.conv.bias", "down_blocks.1.resnets.1.norm1.weight", "down_blocks.1.resnets.1.norm1.bias", "down_blocks.1.resnets.1.conv1.conv.weight", "down_blocks.1.resnets.1.conv1.conv.bias", "down_blocks.1.resnets.1.time_emb_proj.weight", "down_blocks.1.resnets.1.time_emb_proj.bias", "down_blocks.1.resnets.1.norm2.weight", "down_blocks.1.resnets.1.norm2.bias", "down_blocks.1.resnets.1.conv2.conv.weight", "down_blocks.1.resnets.1.conv2.conv.bias", "down_blocks.1.downsampler.op.conv.weight", "down_blocks.1.downsampler.op.conv.bias", "down_blocks.2.resnets.0.norm1.weight", "down_blocks.2.resnets.0.norm1.bias", "down_blocks.2.resnets.0.conv1.conv.weight", "down_blocks.2.resnets.0.conv1.conv.bias", "down_blocks.2.resnets.0.time_emb_proj.weight", "down_blocks.2.resnets.0.time_emb_proj.bias", "down_blocks.2.resnets.0.norm2.weight", "down_blocks.2.resnets.0.norm2.bias", "down_blocks.2.resnets.0.conv2.conv.weight", "down_blocks.2.resnets.0.conv2.conv.bias", "down_blocks.2.resnets.0.skip_connection.conv.weight", "down_blocks.2.resnets.0.skip_connection.conv.bias", "down_blocks.2.resnets.1.norm1.weight", "down_blocks.2.resnets.1.norm1.bias", "down_blocks.2.resnets.1.conv1.conv.weight", "down_blocks.2.resnets.1.conv1.conv.bias", "down_blocks.2.resnets.1.time_emb_proj.weight", "down_blocks.2.resnets.1.time_emb_proj.bias", "down_blocks.2.resnets.1.norm2.weight", "down_blocks.2.resnets.1.norm2.bias", "down_blocks.2.resnets.1.conv2.conv.weight", "down_blocks.2.resnets.1.conv2.conv.bias", "down_blocks.2.downsampler.op.conv.weight", "down_blocks.2.downsampler.op.conv.bias", "down_blocks.3.attentions.0.norm.weight", "down_blocks.3.attentions.0.norm.bias", "down_blocks.3.attentions.0.to_q.weight", "down_blocks.3.attentions.0.to_q.bias", "down_blocks.3.attentions.0.to_k.weight", "down_blocks.3.attentions.0.to_k.bias", "down_blocks.3.attentions.0.to_v.weight", "down_blocks.3.attentions.0.to_v.bias", "down_blocks.3.attentions.0.proj_attn.weight", "down_blocks.3.attentions.0.proj_attn.bias", "down_blocks.3.attentions.1.norm.weight", "down_blocks.3.attentions.1.norm.bias", "down_blocks.3.attentions.1.to_q.weight", "down_blocks.3.attentions.1.to_q.bias", "down_blocks.3.attentions.1.to_k.weight", "down_blocks.3.attentions.1.to_k.bias", "down_blocks.3.attentions.1.to_v.weight", "down_blocks.3.attentions.1.to_v.bias", "down_blocks.3.attentions.1.proj_attn.weight", "down_blocks.3.attentions.1.proj_attn.bias", "down_blocks.3.resnets.0.norm1.weight", "down_blocks.3.resnets.0.norm1.bias", "down_blocks.3.resnets.0.conv1.conv.weight", "down_blocks.3.resnets.0.conv1.conv.bias", "down_blocks.3.resnets.0.time_emb_proj.weight", "down_blocks.3.resnets.0.time_emb_proj.bias", "down_blocks.3.resnets.0.norm2.weight", "down_blocks.3.resnets.0.norm2.bias", "down_blocks.3.resnets.0.conv2.conv.weight", "down_blocks.3.resnets.0.conv2.conv.bias", "down_blocks.3.resnets.1.norm1.weight", "down_blocks.3.resnets.1.norm1.bias", "down_blocks.3.resnets.1.conv1.conv.weight", "down_blocks.3.resnets.1.conv1.conv.bias", "down_blocks.3.resnets.1.time_emb_proj.weight", "down_blocks.3.resnets.1.time_emb_proj.bias", "down_blocks.3.resnets.1.norm2.weight", "down_blocks.3.resnets.1.norm2.bias", "down_blocks.3.resnets.1.conv2.conv.weight", "down_blocks.3.resnets.1.conv2.conv.bias", "middle_block.resnet_1.norm1.weight", "middle_block.resnet_1.norm1.bias", "middle_block.resnet_1.conv1.conv.weight", "middle_block.resnet_1.conv1.conv.bias", "middle_block.resnet_1.time_emb_proj.weight", "middle_block.resnet_1.time_emb_proj.bias", "middle_block.resnet_1.norm2.weight", "middle_block.resnet_1.norm2.bias", "middle_block.resnet_1.conv2.conv.weight", "middle_block.resnet_1.conv2.conv.bias", "middle_block.attention.norm.weight", "middle_block.attention.norm.bias", "middle_block.attention.to_q.weight", "middle_block.attention.to_q.bias", "middle_block.attention.to_k.weight", "middle_block.attention.to_k.bias", "middle_block.attention.to_v.weight", "middle_block.attention.to_v.bias", "middle_block.attention.proj_attn.weight", "middle_block.attention.proj_attn.bias", "middle_block.resnet_2.norm1.weight", "middle_block.resnet_2.norm1.bias", "middle_block.resnet_2.conv1.conv.weight", "middle_block.resnet_2.conv1.conv.bias", "middle_block.resnet_2.time_emb_proj.weight", "middle_block.resnet_2.time_emb_proj.bias", "middle_block.resnet_2.norm2.weight", "middle_block.resnet_2.norm2.bias", "middle_block.resnet_2.conv2.conv.weight", "middle_block.resnet_2.conv2.conv.bias", "up_blocks.0.resnets.0.norm1.weight", "up_blocks.0.resnets.0.norm1.bias", "up_blocks.0.resnets.0.conv1.conv.weight", "up_blocks.0.resnets.0.conv1.conv.bias", "up_blocks.0.resnets.0.time_emb_proj.weight", "up_blocks.0.resnets.0.time_emb_proj.bias", "up_blocks.0.resnets.0.norm2.weight", "up_blocks.0.resnets.0.norm2.bias", "up_blocks.0.resnets.0.conv2.conv.weight", "up_blocks.0.resnets.0.conv2.conv.bias", "up_blocks.0.resnets.0.skip_connection.conv.weight", "up_blocks.0.resnets.0.skip_connection.conv.bias", "up_blocks.0.resnets.1.norm1.weight", "up_blocks.0.resnets.1.norm1.bias", "up_blocks.0.resnets.1.conv1.conv.weight", "up_blocks.0.resnets.1.conv1.conv.bias", "up_blocks.0.resnets.1.time_emb_proj.weight", "up_blocks.0.resnets.1.time_emb_proj.bias", "up_blocks.0.resnets.1.norm2.weight", "up_blocks.0.resnets.1.norm2.bias", "up_blocks.0.resnets.1.conv2.conv.weight", "up_blocks.0.resnets.1.conv2.conv.bias", "up_blocks.0.resnets.1.skip_connection.conv.weight", "up_blocks.0.resnets.1.skip_connection.conv.bias", "up_blocks.0.resnets.2.norm1.weight", "up_blocks.0.resnets.2.norm1.bias", "up_blocks.0.resnets.2.conv1.conv.weight", "up_blocks.0.resnets.2.conv1.conv.bias", "up_blocks.0.resnets.2.time_emb_proj.weight", "up_blocks.0.resnets.2.time_emb_proj.bias", "up_blocks.0.resnets.2.norm2.weight", "up_blocks.0.resnets.2.norm2.bias", "up_blocks.0.resnets.2.conv2.conv.weight", "up_blocks.0.resnets.2.conv2.conv.bias", "up_blocks.0.resnets.2.skip_connection.conv.weight", "up_blocks.0.resnets.2.skip_connection.conv.bias", "up_blocks.0.attentions.0.norm.weight", "up_blocks.0.attentions.0.norm.bias", "up_blocks.0.attentions.0.to_q.weight", "up_blocks.0.attentions.0.to_q.bias", "up_blocks.0.attentions.0.to_k.weight", "up_blocks.0.attentions.0.to_k.bias", "up_blocks.0.attentions.0.to_v.weight", "up_blocks.0.attentions.0.to_v.bias", "up_blocks.0.attentions.0.proj_attn.weight", "up_blocks.0.attentions.0.proj_attn.bias", "up_blocks.0.attentions.1.norm.weight", "up_blocks.0.attentions.1.norm.bias", "up_blocks.0.attentions.1.to_q.weight", "up_blocks.0.attentions.1.to_q.bias", "up_blocks.0.attentions.1.to_k.weight", "up_blocks.0.attentions.1.to_k.bias", "up_blocks.0.attentions.1.to_v.weight", "up_blocks.0.attentions.1.to_v.bias", "up_blocks.0.attentions.1.proj_attn.weight", "up_blocks.0.attentions.1.proj_attn.bias", "up_blocks.0.attentions.2.norm.weight", "up_blocks.0.attentions.2.norm.bias", "up_blocks.0.attentions.2.to_q.weight", "up_blocks.0.attentions.2.to_q.bias", "up_blocks.0.attentions.2.to_k.weight", "up_blocks.0.attentions.2.to_k.bias", "up_blocks.0.attentions.2.to_v.weight", "up_blocks.0.attentions.2.to_v.bias", "up_blocks.0.attentions.2.proj_attn.weight", "up_blocks.0.attentions.2.proj_attn.bias", "up_blocks.0.upsampler.conv.conv.weight", "up_blocks.0.upsampler.conv.conv.bias", "up_blocks.1.resnets.0.norm1.weight", "up_blocks.1.resnets.0.norm1.bias", "up_blocks.1.resnets.0.conv1.conv.weight", "up_blocks.1.resnets.0.conv1.conv.bias", "up_blocks.1.resnets.0.time_emb_proj.weight", "up_blocks.1.resnets.0.time_emb_proj.bias", "up_blocks.1.resnets.0.norm2.weight", "up_blocks.1.resnets.0.norm2.bias", "up_blocks.1.resnets.0.conv2.conv.weight", "up_blocks.1.resnets.0.conv2.conv.bias", "up_blocks.1.resnets.0.skip_connection.conv.weight", "up_blocks.1.resnets.0.skip_connection.conv.bias", "up_blocks.1.resnets.1.norm1.weight", "up_blocks.1.resnets.1.norm1.bias", "up_blocks.1.resnets.1.conv1.conv.weight", "up_blocks.1.resnets.1.conv1.conv.bias", "up_blocks.1.resnets.1.time_emb_proj.weight", "up_blocks.1.resnets.1.time_emb_proj.bias", "up_blocks.1.resnets.1.norm2.weight", "up_blocks.1.resnets.1.norm2.bias", "up_blocks.1.resnets.1.conv2.conv.weight", "up_blocks.1.resnets.1.conv2.conv.bias", "up_blocks.1.resnets.1.skip_connection.conv.weight", "up_blocks.1.resnets.1.skip_connection.conv.bias", "up_blocks.1.resnets.2.norm1.weight", "up_blocks.1.resnets.2.norm1.bias", "up_blocks.1.resnets.2.conv1.conv.weight", "up_blocks.1.resnets.2.conv1.conv.bias", "up_blocks.1.resnets.2.time_emb_proj.weight", "up_blocks.1.resnets.2.time_emb_proj.bias", "up_blocks.1.resnets.2.norm2.weight", "up_blocks.1.resnets.2.norm2.bias", "up_blocks.1.resnets.2.conv2.conv.weight", "up_blocks.1.resnets.2.conv2.conv.bias", "up_blocks.1.resnets.2.skip_connection.conv.weight", "up_blocks.1.resnets.2.skip_connection.conv.bias", "up_blocks.1.upsampler.conv.conv.weight", "up_blocks.1.upsampler.conv.conv.bias", "up_blocks.2.resnets.0.norm1.weight", "up_blocks.2.resnets.0.norm1.bias", "up_blocks.2.resnets.0.conv1.conv.weight", "up_blocks.2.resnets.0.conv1.conv.bias", "up_blocks.2.resnets.0.time_emb_proj.weight", "up_blocks.2.resnets.0.time_emb_proj.bias", "up_blocks.2.resnets.0.norm2.weight", "up_blocks.2.resnets.0.norm2.bias", "up_blocks.2.resnets.0.conv2.conv.weight", "up_blocks.2.resnets.0.conv2.conv.bias", "up_blocks.2.resnets.0.skip_connection.conv.weight", "up_blocks.2.resnets.0.skip_connection.conv.bias", "up_blocks.2.resnets.1.norm1.weight", "up_blocks.2.resnets.1.norm1.bias", "up_blocks.2.resnets.1.conv1.conv.weight", "up_blocks.2.resnets.1.conv1.conv.bias", "up_blocks.2.resnets.1.time_emb_proj.weight", "up_blocks.2.resnets.1.time_emb_proj.bias", "up_blocks.2.resnets.1.norm2.weight", "up_blocks.2.resnets.1.norm2.bias", "up_blocks.2.resnets.1.conv2.conv.weight", "up_blocks.2.resnets.1.conv2.conv.bias", "up_blocks.2.resnets.1.skip_connection.conv.weight", "up_blocks.2.resnets.1.skip_connection.conv.bias", "up_blocks.2.resnets.2.norm1.weight", "up_blocks.2.resnets.2.norm1.bias", "up_blocks.2.resnets.2.conv1.conv.weight", "up_blocks.2.resnets.2.conv1.conv.bias", "up_blocks.2.resnets.2.time_emb_proj.weight", "up_blocks.2.resnets.2.time_emb_proj.bias", "up_blocks.2.resnets.2.norm2.weight", "up_blocks.2.resnets.2.norm2.bias", "up_blocks.2.resnets.2.conv2.conv.weight", "up_blocks.2.resnets.2.conv2.conv.bias", "up_blocks.2.resnets.2.skip_connection.conv.weight", "up_blocks.2.resnets.2.skip_connection.conv.bias", "up_blocks.2.upsampler.conv.conv.weight", "up_blocks.2.upsampler.conv.conv.bias", "up_blocks.3.resnets.0.norm1.weight", "up_blocks.3.resnets.0.norm1.bias", "up_blocks.3.resnets.0.conv1.conv.weight", "up_blocks.3.resnets.0.conv1.conv.bias", "up_blocks.3.resnets.0.time_emb_proj.weight", "up_blocks.3.resnets.0.time_emb_proj.bias", "up_blocks.3.resnets.0.norm2.weight", "up_blocks.3.resnets.0.norm2.bias", "up_blocks.3.resnets.0.conv2.conv.weight", "up_blocks.3.resnets.0.conv2.conv.bias", "up_blocks.3.resnets.0.skip_connection.conv.weight", "up_blocks.3.resnets.0.skip_connection.conv.bias", "up_blocks.3.resnets.1.norm1.weight", "up_blocks.3.resnets.1.norm1.bias", "up_blocks.3.resnets.1.conv1.conv.weight", "up_blocks.3.resnets.1.conv1.conv.bias", "up_blocks.3.resnets.1.time_emb_proj.weight", "up_blocks.3.resnets.1.time_emb_proj.bias", "up_blocks.3.resnets.1.norm2.weight", "up_blocks.3.resnets.1.norm2.bias", "up_blocks.3.resnets.1.conv2.conv.weight", "up_blocks.3.resnets.1.conv2.conv.bias", "up_blocks.3.resnets.1.skip_connection.conv.weight", "up_blocks.3.resnets.1.skip_connection.conv.bias", "up_blocks.3.resnets.2.norm1.weight", "up_blocks.3.resnets.2.norm1.bias", "up_blocks.3.resnets.2.conv1.conv.weight", "up_blocks.3.resnets.2.conv1.conv.bias", "up_blocks.3.resnets.2.time_emb_proj.weight", "up_blocks.3.resnets.2.time_emb_proj.bias", "up_blocks.3.resnets.2.norm2.weight", "up_blocks.3.resnets.2.norm2.bias", "up_blocks.3.resnets.2.conv2.conv.weight", "up_blocks.3.resnets.2.conv2.conv.bias", "up_blocks.3.resnets.2.skip_connection.conv.weight", "up_blocks.3.resnets.2.skip_connection.conv.bias", "out.0.weight", "out.0.bias", "out.2.conv.weight", "out.2.conv.bias". 
	Unexpected key(s) in state_dict: "module.conv_in.conv.weight", "module.conv_in.conv.bias", "module.time_embed.0.weight", "module.time_embed.0.bias", "module.time_embed.2.weight", "module.time_embed.2.bias", "module.down_blocks.0.resnets.0.norm1.weight", "module.down_blocks.0.resnets.0.norm1.bias", "module.down_blocks.0.resnets.0.conv1.conv.weight", "module.down_blocks.0.resnets.0.conv1.conv.bias", "module.down_blocks.0.resnets.0.time_emb_proj.weight", "module.down_blocks.0.resnets.0.time_emb_proj.bias", "module.down_blocks.0.resnets.0.norm2.weight", "module.down_blocks.0.resnets.0.norm2.bias", "module.down_blocks.0.resnets.0.conv2.conv.weight", "module.down_blocks.0.resnets.0.conv2.conv.bias", "module.down_blocks.0.resnets.1.norm1.weight", "module.down_blocks.0.resnets.1.norm1.bias", "module.down_blocks.0.resnets.1.conv1.conv.weight", "module.down_blocks.0.resnets.1.conv1.conv.bias", "module.down_blocks.0.resnets.1.time_emb_proj.weight", "module.down_blocks.0.resnets.1.time_emb_proj.bias", "module.down_blocks.0.resnets.1.norm2.weight", "module.down_blocks.0.resnets.1.norm2.bias", "module.down_blocks.0.resnets.1.conv2.conv.weight", "module.down_blocks.0.resnets.1.conv2.conv.bias", "module.down_blocks.0.downsampler.op.conv.weight", "module.down_blocks.0.downsampler.op.conv.bias", "module.down_blocks.1.resnets.0.norm1.weight", "module.down_blocks.1.resnets.0.norm1.bias", "module.down_blocks.1.resnets.0.conv1.conv.weight", "module.down_blocks.1.resnets.0.conv1.conv.bias", "module.down_blocks.1.resnets.0.time_emb_proj.weight", "module.down_blocks.1.resnets.0.time_emb_proj.bias", "module.down_blocks.1.resnets.0.norm2.weight", "module.down_blocks.1.resnets.0.norm2.bias", "module.down_blocks.1.resnets.0.conv2.conv.weight", "module.down_blocks.1.resnets.0.conv2.conv.bias", "module.down_blocks.1.resnets.0.skip_connection.conv.weight", "module.down_blocks.1.resnets.0.skip_connection.conv.bias", "module.down_blocks.1.resnets.1.norm1.weight", "module.down_blocks.1.resnets.1.norm1.bias", "module.down_blocks.1.resnets.1.conv1.conv.weight", "module.down_blocks.1.resnets.1.conv1.conv.bias", "module.down_blocks.1.resnets.1.time_emb_proj.weight", "module.down_blocks.1.resnets.1.time_emb_proj.bias", "module.down_blocks.1.resnets.1.norm2.weight", "module.down_blocks.1.resnets.1.norm2.bias", "module.down_blocks.1.resnets.1.conv2.conv.weight", "module.down_blocks.1.resnets.1.conv2.conv.bias", "module.down_blocks.1.downsampler.op.conv.weight", "module.down_blocks.1.downsampler.op.conv.bias", "module.down_blocks.2.resnets.0.norm1.weight", "module.down_blocks.2.resnets.0.norm1.bias", "module.down_blocks.2.resnets.0.conv1.conv.weight", "module.down_blocks.2.resnets.0.conv1.conv.bias", "module.down_blocks.2.resnets.0.time_emb_proj.weight", "module.down_blocks.2.resnets.0.time_emb_proj.bias", "module.down_blocks.2.resnets.0.norm2.weight", "module.down_blocks.2.resnets.0.norm2.bias", "module.down_blocks.2.resnets.0.conv2.conv.weight", "module.down_blocks.2.resnets.0.conv2.conv.bias", "module.down_blocks.2.resnets.0.skip_connection.conv.weight", "module.down_blocks.2.resnets.0.skip_connection.conv.bias", "module.down_blocks.2.resnets.1.norm1.weight", "module.down_blocks.2.resnets.1.norm1.bias", "module.down_blocks.2.resnets.1.conv1.conv.weight", "module.down_blocks.2.resnets.1.conv1.conv.bias", "module.down_blocks.2.resnets.1.time_emb_proj.weight", "module.down_blocks.2.resnets.1.time_emb_proj.bias", "module.down_blocks.2.resnets.1.norm2.weight", "module.down_blocks.2.resnets.1.norm2.bias", "module.down_blocks.2.resnets.1.conv2.conv.weight", "module.down_blocks.2.resnets.1.conv2.conv.bias", "module.down_blocks.2.downsampler.op.conv.weight", "module.down_blocks.2.downsampler.op.conv.bias", "module.down_blocks.3.attentions.0.norm.weight", "module.down_blocks.3.attentions.0.norm.bias", "module.down_blocks.3.attentions.0.to_q.weight", "module.down_blocks.3.attentions.0.to_q.bias", "module.down_blocks.3.attentions.0.to_k.weight", "module.down_blocks.3.attentions.0.to_k.bias", "module.down_blocks.3.attentions.0.to_v.weight", "module.down_blocks.3.attentions.0.to_v.bias", "module.down_blocks.3.attentions.0.proj_attn.weight", "module.down_blocks.3.attentions.0.proj_attn.bias", "module.down_blocks.3.attentions.1.norm.weight", "module.down_blocks.3.attentions.1.norm.bias", "module.down_blocks.3.attentions.1.to_q.weight", "module.down_blocks.3.attentions.1.to_q.bias", "module.down_blocks.3.attentions.1.to_k.weight", "module.down_blocks.3.attentions.1.to_k.bias", "module.down_blocks.3.attentions.1.to_v.weight", "module.down_blocks.3.attentions.1.to_v.bias", "module.down_blocks.3.attentions.1.proj_attn.weight", "module.down_blocks.3.attentions.1.proj_attn.bias", "module.down_blocks.3.resnets.0.norm1.weight", "module.down_blocks.3.resnets.0.norm1.bias", "module.down_blocks.3.resnets.0.conv1.conv.weight", "module.down_blocks.3.resnets.0.conv1.conv.bias", "module.down_blocks.3.resnets.0.time_emb_proj.weight", "module.down_blocks.3.resnets.0.time_emb_proj.bias", "module.down_blocks.3.resnets.0.norm2.weight", "module.down_blocks.3.resnets.0.norm2.bias", "module.down_blocks.3.resnets.0.conv2.conv.weight", "module.down_blocks.3.resnets.0.conv2.conv.bias", "module.down_blocks.3.resnets.1.norm1.weight", "module.down_blocks.3.resnets.1.norm1.bias", "module.down_blocks.3.resnets.1.conv1.conv.weight", "module.down_blocks.3.resnets.1.conv1.conv.bias", "module.down_blocks.3.resnets.1.time_emb_proj.weight", "module.down_blocks.3.resnets.1.time_emb_proj.bias", "module.down_blocks.3.resnets.1.norm2.weight", "module.down_blocks.3.resnets.1.norm2.bias", "module.down_blocks.3.resnets.1.conv2.conv.weight", "module.down_blocks.3.resnets.1.conv2.conv.bias", "module.middle_block.resnet_1.norm1.weight", "module.middle_block.resnet_1.norm1.bias", "module.middle_block.resnet_1.conv1.conv.weight", "module.middle_block.resnet_1.conv1.conv.bias", "module.middle_block.resnet_1.time_emb_proj.weight", "module.middle_block.resnet_1.time_emb_proj.bias", "module.middle_block.resnet_1.norm2.weight", "module.middle_block.resnet_1.norm2.bias", "module.middle_block.resnet_1.conv2.conv.weight", "module.middle_block.resnet_1.conv2.conv.bias", "module.middle_block.attention.norm.weight", "module.middle_block.attention.norm.bias", "module.middle_block.attention.to_q.weight", "module.middle_block.attention.to_q.bias", "module.middle_block.attention.to_k.weight", "module.middle_block.attention.to_k.bias", "module.middle_block.attention.to_v.weight", "module.middle_block.attention.to_v.bias", "module.middle_block.attention.proj_attn.weight", "module.middle_block.attention.proj_attn.bias", "module.middle_block.resnet_2.norm1.weight", "module.middle_block.resnet_2.norm1.bias", "module.middle_block.resnet_2.conv1.conv.weight", "module.middle_block.resnet_2.conv1.conv.bias", "module.middle_block.resnet_2.time_emb_proj.weight", "module.middle_block.resnet_2.time_emb_proj.bias", "module.middle_block.resnet_2.norm2.weight", "module.middle_block.resnet_2.norm2.bias", "module.middle_block.resnet_2.conv2.conv.weight", "module.middle_block.resnet_2.conv2.conv.bias", "module.up_blocks.0.resnets.0.norm1.weight", "module.up_blocks.0.resnets.0.norm1.bias", "module.up_blocks.0.resnets.0.conv1.conv.weight", "module.up_blocks.0.resnets.0.conv1.conv.bias", "module.up_blocks.0.resnets.0.time_emb_proj.weight", "module.up_blocks.0.resnets.0.time_emb_proj.bias", "module.up_blocks.0.resnets.0.norm2.weight", "module.up_blocks.0.resnets.0.norm2.bias", "module.up_blocks.0.resnets.0.conv2.conv.weight", "module.up_blocks.0.resnets.0.conv2.conv.bias", "module.up_blocks.0.resnets.0.skip_connection.conv.weight", "module.up_blocks.0.resnets.0.skip_connection.conv.bias", "module.up_blocks.0.resnets.1.norm1.weight", "module.up_blocks.0.resnets.1.norm1.bias", "module.up_blocks.0.resnets.1.conv1.conv.weight", "module.up_blocks.0.resnets.1.conv1.conv.bias", "module.up_blocks.0.resnets.1.time_emb_proj.weight", "module.up_blocks.0.resnets.1.time_emb_proj.bias", "module.up_blocks.0.resnets.1.norm2.weight", "module.up_blocks.0.resnets.1.norm2.bias", "module.up_blocks.0.resnets.1.conv2.conv.weight", "module.up_blocks.0.resnets.1.conv2.conv.bias", "module.up_blocks.0.resnets.1.skip_connection.conv.weight", "module.up_blocks.0.resnets.1.skip_connection.conv.bias", "module.up_blocks.0.resnets.2.norm1.weight", "module.up_blocks.0.resnets.2.norm1.bias", "module.up_blocks.0.resnets.2.conv1.conv.weight", "module.up_blocks.0.resnets.2.conv1.conv.bias", "module.up_blocks.0.resnets.2.time_emb_proj.weight", "module.up_blocks.0.resnets.2.time_emb_proj.bias", "module.up_blocks.0.resnets.2.norm2.weight", "module.up_blocks.0.resnets.2.norm2.bias", "module.up_blocks.0.resnets.2.conv2.conv.weight", "module.up_blocks.0.resnets.2.conv2.conv.bias", "module.up_blocks.0.resnets.2.skip_connection.conv.weight", "module.up_blocks.0.resnets.2.skip_connection.conv.bias", "module.up_blocks.0.attentions.0.norm.weight", "module.up_blocks.0.attentions.0.norm.bias", "module.up_blocks.0.attentions.0.to_q.weight", "module.up_blocks.0.attentions.0.to_q.bias", "module.up_blocks.0.attentions.0.to_k.weight", "module.up_blocks.0.attentions.0.to_k.bias", "module.up_blocks.0.attentions.0.to_v.weight", "module.up_blocks.0.attentions.0.to_v.bias", "module.up_blocks.0.attentions.0.proj_attn.weight", "module.up_blocks.0.attentions.0.proj_attn.bias", "module.up_blocks.0.attentions.1.norm.weight", "module.up_blocks.0.attentions.1.norm.bias", "module.up_blocks.0.attentions.1.to_q.weight", "module.up_blocks.0.attentions.1.to_q.bias", "module.up_blocks.0.attentions.1.to_k.weight", "module.up_blocks.0.attentions.1.to_k.bias", "module.up_blocks.0.attentions.1.to_v.weight", "module.up_blocks.0.attentions.1.to_v.bias", "module.up_blocks.0.attentions.1.proj_attn.weight", "module.up_blocks.0.attentions.1.proj_attn.bias", "module.up_blocks.0.attentions.2.norm.weight", "module.up_blocks.0.attentions.2.norm.bias", "module.up_blocks.0.attentions.2.to_q.weight", "module.up_blocks.0.attentions.2.to_q.bias", "module.up_blocks.0.attentions.2.to_k.weight", "module.up_blocks.0.attentions.2.to_k.bias", "module.up_blocks.0.attentions.2.to_v.weight", "module.up_blocks.0.attentions.2.to_v.bias", "module.up_blocks.0.attentions.2.proj_attn.weight", "module.up_blocks.0.attentions.2.proj_attn.bias", "module.up_blocks.0.upsampler.conv.conv.weight", "module.up_blocks.0.upsampler.conv.conv.bias", "module.up_blocks.1.resnets.0.norm1.weight", "module.up_blocks.1.resnets.0.norm1.bias", "module.up_blocks.1.resnets.0.conv1.conv.weight", "module.up_blocks.1.resnets.0.conv1.conv.bias", "module.up_blocks.1.resnets.0.time_emb_proj.weight", "module.up_blocks.1.resnets.0.time_emb_proj.bias", "module.up_blocks.1.resnets.0.norm2.weight", "module.up_blocks.1.resnets.0.norm2.bias", "module.up_blocks.1.resnets.0.conv2.conv.weight", "module.up_blocks.1.resnets.0.conv2.conv.bias", "module.up_blocks.1.resnets.0.skip_connection.conv.weight", "module.up_blocks.1.resnets.0.skip_connection.conv.bias", "module.up_blocks.1.resnets.1.norm1.weight", "module.up_blocks.1.resnets.1.norm1.bias", "module.up_blocks.1.resnets.1.conv1.conv.weight", "module.up_blocks.1.resnets.1.conv1.conv.bias", "module.up_blocks.1.resnets.1.time_emb_proj.weight", "module.up_blocks.1.resnets.1.time_emb_proj.bias", "module.up_blocks.1.resnets.1.norm2.weight", "module.up_blocks.1.resnets.1.norm2.bias", "module.up_blocks.1.resnets.1.conv2.conv.weight", "module.up_blocks.1.resnets.1.conv2.conv.bias", "module.up_blocks.1.resnets.1.skip_connection.conv.weight", "module.up_blocks.1.resnets.1.skip_connection.conv.bias", "module.up_blocks.1.resnets.2.norm1.weight", "module.up_blocks.1.resnets.2.norm1.bias", "module.up_blocks.1.resnets.2.conv1.conv.weight", "module.up_blocks.1.resnets.2.conv1.conv.bias", "module.up_blocks.1.resnets.2.time_emb_proj.weight", "module.up_blocks.1.resnets.2.time_emb_proj.bias", "module.up_blocks.1.resnets.2.norm2.weight", "module.up_blocks.1.resnets.2.norm2.bias", "module.up_blocks.1.resnets.2.conv2.conv.weight", "module.up_blocks.1.resnets.2.conv2.conv.bias", "module.up_blocks.1.resnets.2.skip_connection.conv.weight", "module.up_blocks.1.resnets.2.skip_connection.conv.bias", "module.up_blocks.1.upsampler.conv.conv.weight", "module.up_blocks.1.upsampler.conv.conv.bias", "module.up_blocks.2.resnets.0.norm1.weight", "module.up_blocks.2.resnets.0.norm1.bias", "module.up_blocks.2.resnets.0.conv1.conv.weight", "module.up_blocks.2.resnets.0.conv1.conv.bias", "module.up_blocks.2.resnets.0.time_emb_proj.weight", "module.up_blocks.2.resnets.0.time_emb_proj.bias", "module.up_blocks.2.resnets.0.norm2.weight", "module.up_blocks.2.resnets.0.norm2.bias", "module.up_blocks.2.resnets.0.conv2.conv.weight", "module.up_blocks.2.resnets.0.conv2.conv.bias", "module.up_blocks.2.resnets.0.skip_connection.conv.weight", "module.up_blocks.2.resnets.0.skip_connection.conv.bias", "module.up_blocks.2.resnets.1.norm1.weight", "module.up_blocks.2.resnets.1.norm1.bias", "module.up_blocks.2.resnets.1.conv1.conv.weight", "module.up_blocks.2.resnets.1.conv1.conv.bias", "module.up_blocks.2.resnets.1.time_emb_proj.weight", "module.up_blocks.2.resnets.1.time_emb_proj.bias", "module.up_blocks.2.resnets.1.norm2.weight", "module.up_blocks.2.resnets.1.norm2.bias", "module.up_blocks.2.resnets.1.conv2.conv.weight", "module.up_blocks.2.resnets.1.conv2.conv.bias", "module.up_blocks.2.resnets.1.skip_connection.conv.weight", "module.up_blocks.2.resnets.1.skip_connection.conv.bias", "module.up_blocks.2.resnets.2.norm1.weight", "module.up_blocks.2.resnets.2.norm1.bias", "module.up_blocks.2.resnets.2.conv1.conv.weight", "module.up_blocks.2.resnets.2.conv1.conv.bias", "module.up_blocks.2.resnets.2.time_emb_proj.weight", "module.up_blocks.2.resnets.2.time_emb_proj.bias", "module.up_blocks.2.resnets.2.norm2.weight", "module.up_blocks.2.resnets.2.norm2.bias", "module.up_blocks.2.resnets.2.conv2.conv.weight", "module.up_blocks.2.resnets.2.conv2.conv.bias", "module.up_blocks.2.resnets.2.skip_connection.conv.weight", "module.up_blocks.2.resnets.2.skip_connection.conv.bias", "module.up_blocks.2.upsampler.conv.conv.weight", "module.up_blocks.2.upsampler.conv.conv.bias", "module.up_blocks.3.resnets.0.norm1.weight", "module.up_blocks.3.resnets.0.norm1.bias", "module.up_blocks.3.resnets.0.conv1.conv.weight", "module.up_blocks.3.resnets.0.conv1.conv.bias", "module.up_blocks.3.resnets.0.time_emb_proj.weight", "module.up_blocks.3.resnets.0.time_emb_proj.bias", "module.up_blocks.3.resnets.0.norm2.weight", "module.up_blocks.3.resnets.0.norm2.bias", "module.up_blocks.3.resnets.0.conv2.conv.weight", "module.up_blocks.3.resnets.0.conv2.conv.bias", "module.up_blocks.3.resnets.0.skip_connection.conv.weight", "module.up_blocks.3.resnets.0.skip_connection.conv.bias", "module.up_blocks.3.resnets.1.norm1.weight", "module.up_blocks.3.resnets.1.norm1.bias", "module.up_blocks.3.resnets.1.conv1.conv.weight", "module.up_blocks.3.resnets.1.conv1.conv.bias", "module.up_blocks.3.resnets.1.time_emb_proj.weight", "module.up_blocks.3.resnets.1.time_emb_proj.bias", "module.up_blocks.3.resnets.1.norm2.weight", "module.up_blocks.3.resnets.1.norm2.bias", "module.up_blocks.3.resnets.1.conv2.conv.weight", "module.up_blocks.3.resnets.1.conv2.conv.bias", "module.up_blocks.3.resnets.1.skip_connection.conv.weight", "module.up_blocks.3.resnets.1.skip_connection.conv.bias", "module.up_blocks.3.resnets.2.norm1.weight", "module.up_blocks.3.resnets.2.norm1.bias", "module.up_blocks.3.resnets.2.conv1.conv.weight", "module.up_blocks.3.resnets.2.conv1.conv.bias", "module.up_blocks.3.resnets.2.time_emb_proj.weight", "module.up_blocks.3.resnets.2.time_emb_proj.bias", "module.up_blocks.3.resnets.2.norm2.weight", "module.up_blocks.3.resnets.2.norm2.bias", "module.up_blocks.3.resnets.2.conv2.conv.weight", "module.up_blocks.3.resnets.2.conv2.conv.bias", "module.up_blocks.3.resnets.2.skip_connection.conv.weight", "module.up_blocks.3.resnets.2.skip_connection.conv.bias", "module.out.0.weight", "module.out.0.bias", "module.out.2.conv.weight", "module.out.2.conv.bias". 

In [None]:
# scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0005, beta_end=0.0195)

# inferer = DiffusionInferer(scheduler)

# optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-5)



# epoch_loss_list = []
# val_epoch_loss_list = []

# scaler = GradScaler()
# total_start = time.time()
# n_epochs = config.get('n_epochs',100)
# val_interval =config.get('val_interval',5)

In [5]:
for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=70)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
       # images = batch["image"].to(device)
        images = batch['vol']['data'].to(device)
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=True):
            # Generate random noise
            noise = torch.randn_like(images).to(device)

            # Create timesteps
            timesteps = torch.randint(
                0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
            ).long()

            # Get model prediction
            noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)

            loss = F.mse_loss(noise_pred.float(), noise.float())

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

        progress_bar.set_postfix({"loss": epoch_loss / (step + 1)})
    epoch_loss_list.append(epoch_loss / (step + 1))
    wandb.log({"loss_train": epoch_loss / (step + 1)})

    if (epoch + 1) % val_interval == 0:
        model.eval()
        val_epoch_loss = 0
        for step, batch in enumerate(val_loader):
            images = batch['vol']['data'].to(device)
            noise = torch.randn_like(images).to(device)
            with torch.no_grad():
                with autocast(enabled=True):
                    timesteps = torch.randint(
                        0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
                    ).long()

                    # Get model prediction
                    noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
                    val_loss = F.mse_loss(noise_pred.float(), noise.float())

            val_epoch_loss += val_loss.item()
            progress_bar.set_postfix({"val_loss": val_epoch_loss / (step + 1)})
        val_epoch_loss_list.append(val_epoch_loss / (step + 1))
        wandb.log({"loss_val": val_epoch_loss / (step + 1)})

        # Sampling image during training
        #80, 96, 80
        image = torch.randn_like(images)[0:1,:,:,:]
        image = image.to(device)
        scheduler.set_timesteps(num_inference_steps=1000)
        with autocast(enabled=True):
            image = inferer.sample(input_noise=image, diffusion_model=model, scheduler=scheduler)


        middle_slice_idx = image.size(-1) // 2
        plt.figure(figsize=(2, 2))
        plt.imshow(image[0, 0, :, :, middle_slice_idx].cpu(), vmin=0, vmax=1, cmap="gray")
        plt.tight_layout()
        plt.axis("off")
        plt.show()
        wandb.log({"sample_image": [wandb.Image(plt)]})
        # Modify the filename to include the epoch number
        filename = f"./results/norm_2/sample_epoch{epoch}.png"

        plt.savefig(filename, dpi=300)  
        # Save the model
        model_filename = f"./models/norm_2/model_epoch{epoch}.pt"
        torch.save(model.state_dict(), model_filename)

total_time = time.time() - total_start
print(f"train completed, total time: {total_time}.")

Epoch 0: 100%|████████| 302/302 [09:38<00:00,  1.91s/it, loss=0.00701]
Epoch 1: 100%|████████| 302/302 [02:23<00:00,  2.10it/s, loss=0.00655]
Epoch 2: 100%|████████| 302/302 [02:17<00:00,  2.19it/s, loss=0.00651]
Epoch 3: 100%|████████| 302/302 [02:13<00:00,  2.26it/s, loss=0.00644]
Epoch 4: 100%|████████| 302/302 [02:13<00:00,  2.26it/s, loss=0.00648]
Epoch 5: 100%|█████████| 302/302 [02:14<00:00,  2.25it/s, loss=0.0064]
Epoch 6: 100%|████████| 302/302 [02:13<00:00,  2.27it/s, loss=0.00629]
Epoch 7: 100%|████████| 302/302 [02:13<00:00,  2.26it/s, loss=0.00645]
Epoch 8: 100%|████████| 302/302 [02:12<00:00,  2.28it/s, loss=0.00603]
Epoch 9: 100%|████████| 302/302 [02:13<00:00,  2.27it/s, loss=0.00627]
Epoch 10: 100%|████████| 302/302 [02:12<00:00,  2.27it/s, loss=0.0062]
Epoch 11: 100%|███████| 302/302 [02:13<00:00,  2.27it/s, loss=0.00627]
Epoch 12: 100%|███████| 302/302 [02:12<00:00,  2.28it/s, loss=0.00589]
Epoch 13: 100%|█████████| 302/302 [02:12<00:00,  2.28it/s, loss=0.006]
Epoch 

In [None]:
f"/scratch1/akrami/models/3Ddiffusion/half_norm/model_epoch{epoch}.pt"
f"./results/half_norm/sample_epoch{epoch}.png"