In [1]:
# Load a few packages first 

In [2]:
import argparse
import logging
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

import wandb
from evaluate import evaluate
from unet import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss

In [3]:
#dir_img = Path('./data/imgs/')
#dir_mask = Path('./data/masks/')
#dir_checkpoint = Path('./checkpoints/')


dir_img = 'Z:/Dongyu Fan/2. Data/ImageProcessing/Simulation/2024-09-19/16-57/img/'
dir_mask = 'Z:/Dongyu Fan/2. Data/ImageProcessing/Simulation/2024-09-19/16-57/mask/'


In [4]:
from datetime import datetime
 # Define the base path
base_path = 'Z:/Dongyu Fan/2. Data/ImageProcessing/training'
# Get current date and time
now = datetime.now()
date_str = now.strftime('%Y-%m-%d')  # Format for date (e.g., '2024-08-26')
time_str = now.strftime('%m-%d_%H-%M')  # Format for time (e.g., '08-26_15-30')

# Create directory paths
date_folder = os.path.join(base_path, date_str)
time_folder = os.path.join(date_folder, time_str)

# Create directories
os.makedirs(time_folder, exist_ok=True)  # exist_ok=True to avoid error if the directory already exists

dir_checkpoint = time_folder
wandb_dir = time_folder

In [5]:
def train_model(
        model,
        device,
        epochs: int = 5,
        batch_size: int = 1,
        learning_rate: float = 1e-5,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.999,
        gradient_clipping: float = 1.0,
):
    
    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must', dir = wandb_dir)
    experiment.config.update(
        dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
             val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
    )

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(model.parameters(),
                              lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)
                    if model.n_classes == 1:
                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
                        loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                    else:
                        loss = criterion(masks_pred, true_masks)
                        loss += dice_loss(
                            F.softmax(masks_pred, dim=1).float(),
                            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                            multiclass=True
                        )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                division_step = (n_train // (5 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in model.named_parameters():
                            tag = tag.replace('/', '.')
                            if not (torch.isinf(value) | torch.isnan(value)).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                        val_score = evaluate(model, val_loader, device, amp)
                        scheduler.step(val_score)

                        logging.info('Validation Dice score: {}'.format(val_score))
                        try:
                            experiment.log({
                                'learning rate': optimizer.param_groups[0]['lr'],
                                'validation Dice': val_score,
                                'images': wandb.Image(images[0].cpu()),
                                'masks': {
                                    'true': wandb.Image(true_masks[0].float().cpu()),
                                    'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                                },
                                'step': global_step,
                                'epoch': epoch,
                                **histograms
                            })
                        except:
                            pass

        if save_checkpoint and epoch % 10 == 0:

            # Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            state_dict['mask_values'] = dataset.mask_values
            #torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            checkpoint_path = os.path.join(dir_checkpoint, f'checkpoint_epoch{epoch}.pth')
            torch.save(state_dict, os.path.join(dir_checkpoint, checkpoint_path))
            logging.info(f'Checkpoint {epoch} saved!')
    
    # 6. Save the model 
    model_path = os.path.join(dir_checkpoint, f'model.pth')
    torch.save(model, model_path)
    logging.info(f'Model for epoch {epoch} saved at {model_path}')


In [6]:
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()



In [8]:
#args = get_args()

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')

# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
model = UNet(n_channels=1, n_classes=2, bilinear=False)
model = model.to(memory_format=torch.channels_last)

logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

model.to(device=device)
try:
    train_model(
        model=model,
        epochs=20,
        batch_size=8,
        learning_rate=1e-5,
        device=device,
        img_scale=1.,
        val_percent= 10 / 100,
        amp=False
    )
except torch.cuda.OutOfMemoryError:
    logging.error('Detected OutOfMemoryError! '
                    'Enabling checkpointing to reduce memory usage, but this slows down training. '
                    'Consider enabling AMP (--amp) for fast and memory efficient training')
    torch.cuda.empty_cache()
    model.use_checkpointing()
    train_model(
        model=model,
        epochs=5,
        batch_size=1,
        learning_rate=1e-5,
        device=device,
        img_scale=1,
        val_percent=0.5 / 100,
        amp=False
    )


INFO: Using device cuda
INFO: Network:
	1 input channels
	2 output channels (classes)
	Transposed conv upscaling
INFO: Creating dataset with 1041 examples
INFO: Scanning mask files to determine unique values
100%|██████████| 1041/1041 [01:10<00:00, 14.72it/s]
INFO: Unique mask values: [0.0, 1.0]
ERROR: Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: dyfan777 (llgroup). Use `wandb login --relogin` to force relogin


INFO: Starting training:
        Epochs:          20
        Batch size:      8
        Learning rate:   1e-05
        Training size:   937
        Validation size: 104
        Checkpoints:     True
        Device:          cuda
        Images scaling:  1.0
        Mixed Precision: False
    
Epoch 1/20:  20%|█▉        | 184/937 [00:25<00:46, 16.10img/s, loss (batch)=0.334]
Epoch 1/20:  20%|█▉        | 184/937 [00:38<00:46, 16.10img/s, loss (batch)=0.334]
Validation round:   8%|▊         | 1/13 [00:31<06:17, 31.49s/batch]
Validation round:  15%|█▌        | 2/13 [00:31<02:23, 13.04s/batch]
Validation round: 100%|██████████| 13/13 [00:31<00:00,  1.28s/batch]
                                                                    INFO: Validation Dice score: 0.8085708022117615
Epoch 1/20:  39%|███▉      | 368/937 [01:01<00:16, 34.37img/s, loss (batch)=0.217]
Epoch 1/20:  39%|███▉      | 368/937 [01:18<00:16, 34.37img/s, loss (batch)=0.217]
Validation round:   8%|▊         | 1/13 [00:25<05:07,

In [6]:
dataset = CarvanaDataset(dir_img, dir_mask, 1)

100%|██████████| 1000/1000 [01:02<00:00, 15.98it/s]


In [7]:
train_set, val_set = random_split(dataset, [900, 100], generator=torch.Generator().manual_seed(0))

In [10]:
train_loader = DataLoader(train_set, shuffle=True, )
val_loader = DataLoader(val_set, shuffle=False, drop_last=True,)

In [11]:
for batch in train_loader:
    batch

KeyboardInterrupt: 

0.0

In [13]:
filename = dir_mask + 'img_1_mask.mat'
import h5py
from PIL import Image
mat_file = h5py.File(filename, 'r')
keys = list(mat_file.keys())
if not keys:
    raise ValueError("No datasets found in the .mat file.")
dataset_name = keys[0]
imgfile = Image.fromarray(mat_file[dataset_name][:])  # ONly specifically for this file. Change it to read any file in the mat file
mat_file.close()

In [10]:
filename_m = dir_img + 'img_1.jpg'
img = Image.open(filename_m)
img2 = torch.from_numpy(BasicDataset.preprocess(None, img, 1, is_mask=False))

NameError: name 'Image' is not defined

'Z:/Dongyu Fan/2. Data/ImageProcessing/Simulation/2024-09-18/17-52/mask/mask_1.jpg'

In [14]:
from PIL import Image
filename = './data/imgs/0cdf5b5d0ce1_01.jpg'
img = Image.open(filename)

img2 = torch.from_numpy(BasicDataset.preprocess(None, img, 1, is_mask=False))

In [19]:
filename_m = './data/masks/fff9b3a5373f_14_mask.gif'
mask = Image.open(filename_m)
mask2 = torch.from_numpy(BasicDataset.preprocess([0,1], img, 1, is_mask=True))

In [20]:
mask2

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])