# Training UNet for Diffusion Prediction


## Table of Contents

- [1. Imports & Model Setup](#Imports-&-Model-Setup)  
- [2. Data Loading & Check](#Data-Loading-&-Check)  
- [3. Training Loop](#Training-Loop)  

# Imports-&-Model-Setup

In [None]:
# Import packages 
import argparse
import logging
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import numpy as np
import wandb
import matplotlib.pyplot as plt
import time

# Internal import 
from model import unet_locD
from utils.data_loader import data_loader
from utils.loss_calculator import calculate_loss
from utils.train import train

In [None]:
# Set up the folders.
# Replace the 'path/...' placeholders with the locations of your own data.
# All three folders end with a trailing slash so they are written consistently.

dir_img = 'path/imgpadding/'     # Folder holding the image files (default folder name: imgpadding)
dir_label = 'path/labelpading/'  # Folder holding the label files (default folder name: labelpading)
dir_pair = 'path/pairpading/'    # Folder holding the pair files (default folder name: pairpading)

# Suffixes handed to the data loader / trainer together with the folders above.
# presumably they mark the label and pair file that belong to an image -- TODO confirm in utils.data_loader
label_suffix = '_loc'
pair_suffix = '_pair'

# Set up the parameters:
D_range = [0.01, 2]  # Diffusion-coefficient range the data was simulated with. For normalization purposes.

# Data-Loading-&-Check

In [None]:
# Build the network: 3 input channels, 1 output channel, bilinear mode off.
model = unet_locD(
    n_channels=3,
    n_classes=1,
    bilinear=False,
)

In [None]:
# Build the dataset from the image / label / pair folders and their suffixes,
# then wrap it in a torch DataLoader (all default settings) for iteration.
data = data_loader(
    dir_img,
    dir_label,
    dir_pair,
    label_suffix,
    pair_suffix,
)
dataset = DataLoader(dataset=data)

In [None]:
# Grab one example batch from the DataLoader for the sanity checks below.
example_frame = 1  # 1-based position of the batch to keep
i = 0  # keeps `i` defined even when the loader yields no batch
for i, batch in enumerate(dataset, start=1):
    # Unpacked on every pass, so the most recent batch stays available if
    # `example_frame` exceeds the number of batches.
    images, labels = batch['image'], batch['label']
    if i == example_frame:
        break

In [None]:
# Report the tensor shapes of the example batch: label first, image second.
print(
    f"The shape of the label is: {labels.shape}",
    f"The shape of the image is: {images.shape}",
    sep="\n",
)

# Expected shapes:
#   label: 14 * 64 * 64
#   image: 3 (channels) * 64 * 64
# NOTE(review): the printed shapes presumably also carry a leading batch dimension from the DataLoader -- confirm.
# If a shape is not right, adjust the transpose in the data loader.

In [None]:

# Visual sanity check of the example batch.

# First sample of the image batch, summed over its leading axis.
plt.imshow(images.detach().numpy()[0].sum(axis=0), alpha=0.5)
plt.show()

# Same reduction for the label, with its two remaining axes swapped before plotting.
label2 = labels.detach().numpy()[0].sum(axis=0).transpose(1, 0)
plt.imshow(label2, alpha=0.5)


# If the label and the image don't match each other, check the transpose in the data loader.

# Training-Loop

In [None]:
# Create a timestamped output folder for this training run.
from datetime import datetime

# Root folder for the training outputs -- set this to where the trained model should be saved.
base_path = 'path/training'

# A single timestamp feeds both folder levels.
now = datetime.now()
date_str, time_str = (
    now.strftime(fmt)
    for fmt in ('%Y-%m-%d', '%m-%d_%H-%M')  # e.g. '2024-08-26' and '08-26_15-30'
)

# Layout: <base_path>/<date_str>/<time_str>
date_folder = os.path.join(base_path, date_str)
time_folder = os.path.join(base_path, date_str, time_str)

# exist_ok=True: an already existing folder is reused instead of raising.
os.makedirs(time_folder, exist_ok=True)

# The checkpoint directory and the wandb directory both point at the run folder.
dir_checkpoint = wandb_dir = time_folder

In [None]:
# Start training the model.

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')

# Change here to adapt to your data:
# n_channels is the number of input channels,
# n_classes is the number of output channels (classes) per pixel.
model = unet_locD(n_channels=3, n_classes=1, bilinear=False)
model = model.to(memory_format=torch.channels_last)

logging.info(f'Network:\n'
             f'\t{model.n_channels} input channels\n'
             f'\t{model.n_classes} output channels (classes)\n'
             f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

model.to(device=device)
try:
    # NOTE: `images` rebinds the name used for the example batch in the data-check cells.
    # NOTE: `predction` (sic) keeps its original spelling in case other cells refer to it.
    predction, truth, images, loss = train(
        model=model,
        dir_img=dir_img,
        dir_label=dir_label,
        dir_pair=dir_pair,
        label_suffix=label_suffix,
        pair_suffix=pair_suffix,
        wandb_dir=wandb_dir,
        dir_checkpoint=dir_checkpoint,
        epochs=100,
        batch_size=8,
        learning_rate=1e-5,
        device=device,
        val_percent=0.1,
        amp=False,
        wandb_log=True,
        save_checkpoint=True,
    )
except torch.cuda.OutOfMemoryError:
    # Nothing is retried here, so the message must not promise any automatic
    # recovery; release cached GPU memory and tell the user what to change.
    torch.cuda.empty_cache()
    logging.error('Detected OutOfMemoryError! Training was aborted. '
                  'Reduce batch_size, or set amp=True in the train(...) call '
                  'for fast and memory efficient training, then run this cell again.')

In [None]:
# Close the Weights & Biases run once training is done so it is marked as finished.
wandb.finish()

In [10]:
# Go to Weights & Biases (wandb) to check the training results.