In [1]:
# Optional: route network traffic through a local proxy (uncomment if needed,
# e.g. for downloading pretrained encoder weights or reaching wandb).
# %env ALL_PROXY=http://127.0.0.1:33001
# %env HTTP_PROXY=http://127.0.0.1:33001
# %env HTTPS_PROXY=http://127.0.0.1:33001
# cuBLAS workspace setting; presumably set so that deterministic CUDA/cuBLAS
# operations work together with the seeding done later (see the PyTorch
# reproducibility notes) — TODO confirm this is the intent.
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [2]:
# !curl google.com

# Import

If you change the model, then in the training step you also need to:
1. change the log section (e.g. the wandb run name);
2. change the corresponding model, loss function, encoder, etc.

In [3]:
# Standard library
import os
import sys

# Third-party
import cv2
import numpy as np
import torch
import torch.nn as nn
import albumentations as A
import segmentation_models_pytorch as smp
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import KFold, train_test_split

# Local project modules: make the toolchain root importable first
# sys.path.insert(0, "/root/Soil-Column-Procedures")
sys.path.insert(0, "c:/Users/laish/1_Codes/Image_processing_toolchain/")

from src.API_functions.DL import load_data, log, seed, evaluate
from src.API_functions.Images import file_batch as fb
from src.workflow_tools.model_online import fr_unet

  check_for_updates()


# Hyperparameter and log

In [4]:
# All tunable settings for this run. The whole dict is passed to wandb as the
# run config, so every value here should describe what is actually built.
my_parameters = {
    'seed': 3407,

    'Kfold': None,              # not referenced by the training cells (a single hold-out split is used)
    'ratio': 0.25,              # validation fraction, passed as test_size to train_test_split
    'n_epochs': 1000,
    'patience': 50,

    'model': 'U-Net',      # model = 'U-Net', 'DeepLabv3+', 'PSPNet'
    'encoder': 'efficientnet-b2',
    'optimizer': 'adam',
    # Must describe the criterion actually built in the Model section
    # (evaluate.DiceBCELoss()); this value is logged to wandb via the config.
    'loss_function': 'dice_bce',
    'learning_rate': 0.001,     # Initial learning rate
    'batch_size': 8,

    'scheduler': 'reduce_on_plateau',   # Type of scheduler
    'scheduler_patience': 10,   # Number of epochs to wait before reducing LR
    'scheduler_factor': 0.5,    # Factor by which to reduce LR
    'scheduler_min_lr': 1e-6,   # Minimum learning rate

    'wandb': '34.diceloss_weight_up',  # Name of the wandb run
}

device = 'cuda'
mylogger = log.DataLogger('wandb')  # 'wandb' or 'all'

# Seed all RNGs via the project helper (name spelled as in the project API).
seed.stablize_seed(my_parameters['seed'])

# Transform

In [5]:
# Training pipeline: random horizontal/vertical flips, rotation up to ±90°,
# and additive Gaussian noise, followed by conversion to a torch tensor.
transform_train = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=90, p=0.5),
        A.GaussNoise(p=0.5),
        ToTensorV2(),
    ],
    seed=my_parameters['seed'],
)

# Validation / test pipeline: no augmentation at all, only tensor conversion,
# so evaluation always sees the unmodified images.
transform_val = A.Compose(
    [ToTensorV2()],
    seed=my_parameters['seed'],
)

# Model

In [6]:
# model = smp.PSPNet(
#     encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,                      # model output channels (number of classes in your dataset)
# )
# model = model.to(device)

In [7]:
# U-Net with an ImageNet-pretrained encoder, adapted to single-channel
# (grayscale) input and one output channel of raw logits for the binary mask.
# The encoder name comes from my_parameters so the built model always matches
# the configuration that is logged to wandb.
model = smp.Unet(
    encoder_name=my_parameters['encoder'],  # e.g. "efficientnet-b2"
    encoder_weights="imagenet",             # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                          # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                              # model output channels (number of classes in your dataset)
)
model = model.to(device)

In [8]:
# model = smp.DeepLabV3Plus(
#     encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
#     in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
#     classes=1,                      # model output channels (number of classes in your dataset)
# )

# model.load_state_dict(torch.load(f"c:/Users/laish/1_Codes/Image_processing_toolchain/src/workflow_tools/model_DeepLabv3+_23.drive_again.pth", weights_only=True))
# model = model.to(device)

In [9]:
# model = fr_unet.FR_UNet(num_channels=1, num_classes=1, feature_scale=2, dropout=0.2, fuse=True, out_ave=True)
# model = model.to(device)

In [10]:

# Freeze encoder parameters
# for param in model.encoder.parameters():
    # param.requires_grad = False

# model = model.to(device)

In [11]:
# Only Adam + ReduceLROnPlateau are implemented below. Fail loudly if the
# logged config asks for something else instead of silently training with a
# different setup than the one recorded in wandb.
if my_parameters['optimizer'] != 'adam':
    raise ValueError(
        f"Unsupported optimizer {my_parameters['optimizer']!r}: only 'adam' is implemented"
    )
if my_parameters['scheduler'] != 'reduce_on_plateau':
    raise ValueError(
        f"Unsupported scheduler {my_parameters['scheduler']!r}: only 'reduce_on_plateau' is implemented"
    )

optimizer = torch.optim.Adam(model.parameters(), lr=my_parameters['learning_rate'])

# Multiply the LR by scheduler_factor once the monitored value (the validation
# loss, passed to scheduler.step in the training loop) has not decreased for
# scheduler_patience epochs; never go below scheduler_min_lr.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',  # lower validation loss is better
    factor=my_parameters['scheduler_factor'],
    patience=my_parameters['scheduler_patience'],
    min_lr=my_parameters['scheduler_min_lr'],
)

# Combined Dice + BCE loss from the project's evaluate module; it is called on
# raw logits in the training loop.
criterion = evaluate.DiceBCELoss()

# Train

The code below is only for training.

For testing, run the code above and then the code in the Test section.

## Wandb

In [12]:
# Open a wandb run; the full hyperparameter dict is stored as the run config
# so each logged curve can be traced back to the settings that produced it.
wandb.init(
    config=my_parameters,           # every hyperparameter, logged with the run
    name=my_parameters['wandb'],    # human-readable run name
    project="U-Net",
)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: laishixuan123 (laishixuan123-china-agricultural-university). Use `wandb login --relogin` to force relogin


## Load_data

In [13]:
# Collect the .tif image and label files (the helper prints what it found).
data_paths = fb.get_image_names(r'g:\DL_Data_raw\version4-classes\7.Final_dataset\train_val_image', None, 'tif')
labels_paths = fb.get_image_names(r'g:\DL_Data_raw\version4-classes\7.Final_dataset\train_val_label', None, 'tif')

# Images and labels are paired purely by list position (train_test_split below
# keeps that pairing), so both folders must hold the same file names in the
# same order. Assumes get_image_names returns path strings — TODO confirm.
image_names = [os.path.basename(p) for p in data_paths]
label_names = [os.path.basename(p) for p in labels_paths]
assert image_names == label_names, (
    f'image/label file names do not match '
    f'({len(image_names)} images vs {len(label_names)} labels)'
)

data = fb.read_images(data_paths, 'gray', read_all=True)
labels = fb.read_images(labels_paths, 'gray', read_all=True)

# Preprocessing is now handled in the Dataset class
train_data, val_data, train_labels, val_labels = train_test_split(
    data, labels, test_size=my_parameters['ratio'], random_state=my_parameters['seed']
)

train_dataset = load_data.my_Dataset(train_data, train_labels, transform=transform_train)
val_dataset = load_data.my_Dataset(val_data, val_labels, transform=transform_val)

train_loader = DataLoader(train_dataset, batch_size=my_parameters['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=my_parameters['batch_size'], shuffle=False)

print(f'len of train_data: {len(train_data)}, len of val_data: {len(val_data)}')

57 images have been found in g:\DL_Data_raw\version4-classes\7.Final_dataset\train_val_image
The first 3 images are:
g:\DL_Data_raw\version4-classes\7.Final_dataset\train_val_image\0003-02130-patch-00009.tif
g:\DL_Data_raw\version4-classes\7.Final_dataset\train_val_image\0003-02130-patch-00010.tif
g:\DL_Data_raw\version4-classes\7.Final_dataset\train_val_image\0003-02130-patch-00011.tif
Get names completely!
57 images have been found in g:\DL_Data_raw\version4-classes\7.Final_dataset\train_val_label
The first 3 images are:
g:\DL_Data_raw\version4-classes\7.Final_dataset\train_val_label\0003-02130-patch-00009.tif
g:\DL_Data_raw\version4-classes\7.Final_dataset\train_val_label\0003-02130-patch-00010.tif
g:\DL_Data_raw\version4-classes\7.Final_dataset\train_val_label\0003-02130-patch-00011.tif
Get names completely!


100%|██████████| 57/57 [00:00<00:00, 314.04it/s]


57 images have been read
Reading completely!


100%|██████████| 57/57 [00:00<00:00, 530.19it/s]


57 images have been read
Reading completely!
len of train_data: 42, len of val_data: 15


## Train

In [14]:
# Best validation loss so far; a checkpoint is written whenever it improves.
val_loss_best = float('inf')
# Early stopping: stop after `patience` consecutive epochs without a new best
# val_loss. A value of None disables it.
early_stop_patience = my_parameters.get('patience')
epochs_without_improvement = 0
proceed_once = True  # print a one-off sanity check of the first training batch

checkpoint_path = f"model_{my_parameters['model']}_{my_parameters['wandb']}.pth"

for epoch in range(my_parameters['n_epochs']):
    # ---- Training pass ----
    model.train()
    train_loss = 0.0

    for images, labels in tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)

        # Drop a singleton channel dim so outputs match the label shape
        if outputs.dim() == 4 and outputs.size(1) == 1:
            outputs = outputs.squeeze(1)

        # One-off sanity check of shapes, value ranges and class balance
        if proceed_once:
            print(f'outputs.size(): {outputs.size()}, labels.size(): {labels.size()}')
            print(f'outputs.min: {outputs.min()}, outputs.max: {outputs.max()}')
            print(f'images.min: {images.min()}, images.max: {images.max()}')
            print(f'labels.min: {labels.min()}, labels.max: {labels.max()}')
            print(f'count of label 0: {(labels == 0).sum()}, count of label 1:{(labels == 1).sum()}')
            print('')
            proceed_once = False

        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Presumably criterion returns a batch mean; weighting by batch size makes
        # the epoch value a per-sample mean even when the last batch is smaller.
        train_loss += loss.item() * images.size(0)

    train_loss_mean = train_loss / len(train_loader.dataset)

    # ---- Validation pass ----
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            if outputs.dim() == 4 and outputs.size(1) == 1:
                outputs = outputs.squeeze(1)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)

    val_loss_mean = val_loss / len(val_loader.dataset)

    # LR used during this epoch (read before the scheduler may lower it)
    current_lr = optimizer.param_groups[0]['lr']
    epoch_metrics = {
        'train_loss': train_loss_mean,
        'epoch': epoch,
        'val_loss': val_loss_mean,
        'learning_rate': current_lr,
    }
    mylogger.log(epoch_metrics)

    # ReduceLROnPlateau monitors the validation loss
    scheduler.step(val_loss_mean)

    if val_loss_mean < val_loss_best:
        val_loss_best = val_loss_mean
        epochs_without_improvement = 0
        torch.save(model.state_dict(), checkpoint_path)
        print(f'Model saved at epoch {epoch}, val_loss: {val_loss_mean:.3f}')
    else:
        epochs_without_improvement += 1
        if early_stop_patience is not None and epochs_without_improvement >= early_stop_patience:
            print(f'Early stopping at epoch {epoch}: val_loss has not improved for '
                  f'{epochs_without_improvement} epochs (best: {val_loss_best:.3f})')
            break

  0%|          | 0/6 [00:00<?, ?it/s]

outputs.size(): torch.Size([8, 512, 512]), labels.size(): torch.Size([8, 512, 512])
outputs.min: -8.037663459777832, outputs.max: 5.503569602966309
images.min: -0.2830659747123718, images.max: 1.0
labels.min: 0.0, labels.max: 1.0
count of label 0: 2032433, count of label 1:64719



100%|██████████| 6/6 [00:03<00:00,  1.65it/s]


Model saved at epoch 0.000, val_loss: 0.810


100%|██████████| 6/6 [00:02<00:00,  2.25it/s]


Model saved at epoch 1.000, val_loss: 0.756


100%|██████████| 6/6 [00:02<00:00,  2.27it/s]


Model saved at epoch 2.000, val_loss: 0.694


100%|██████████| 6/6 [00:02<00:00,  2.27it/s]


Model saved at epoch 3.000, val_loss: 0.675


100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 6/6 [00:02<00:00,  2.45it/s]
100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 6/6 [00:02<00:00,  2.33it/s]
100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 6/6 [00:02<00:00,  2.42it/s]
100%|██████████| 6/6 [00:02<00:00,  2.48it/s]
100%|██████████| 6/6 [00:02<00:00,  2.36it/s]
100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 6/6 [00:02<00:00,  2.27it/s]
100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 6/6 [00:02<00:00,  2.18it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.32it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.36it/s]
100%|██████████| 6/6 [00:02<00:00,  2.23it/s]
100%|██████████| 6/6 [00:02<00:00,  2.32it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.30it/s]
100%|██████████| 6/6 [00:02<00:00,

Model saved at epoch 36.000, val_loss: 0.674


100%|██████████| 6/6 [00:02<00:00,  2.23it/s]
100%|██████████| 6/6 [00:02<00:00,  2.32it/s]
100%|██████████| 6/6 [00:02<00:00,  2.35it/s]


Model saved at epoch 39.000, val_loss: 0.667


100%|██████████| 6/6 [00:02<00:00,  2.23it/s]


Model saved at epoch 40.000, val_loss: 0.655


100%|██████████| 6/6 [00:02<00:00,  2.21it/s]


Model saved at epoch 41.000, val_loss: 0.623


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]


Model saved at epoch 42.000, val_loss: 0.612


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.27it/s]
100%|██████████| 6/6 [00:02<00:00,  2.32it/s]


Model saved at epoch 45.000, val_loss: 0.606


100%|██████████| 6/6 [00:02<00:00,  2.44it/s]


Model saved at epoch 46.000, val_loss: 0.593


100%|██████████| 6/6 [00:02<00:00,  2.44it/s]
100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 6/6 [00:02<00:00,  2.43it/s]


Model saved at epoch 49.000, val_loss: 0.561


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 6/6 [00:02<00:00,  2.18it/s]
100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 6/6 [00:02<00:00,  2.27it/s]


Model saved at epoch 55.000, val_loss: 0.504


100%|██████████| 6/6 [00:02<00:00,  2.48it/s]


Model saved at epoch 56.000, val_loss: 0.456


100%|██████████| 6/6 [00:02<00:00,  2.46it/s]
100%|██████████| 6/6 [00:02<00:00,  2.47it/s]
100%|██████████| 6/6 [00:02<00:00,  2.46it/s]
100%|██████████| 6/6 [00:02<00:00,  2.52it/s]
100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 6/6 [00:02<00:00,  2.36it/s]
100%|██████████| 6/6 [00:02<00:00,  2.32it/s]
100%|██████████| 6/6 [00:02<00:00,  2.28it/s]
100%|██████████| 6/6 [00:02<00:00,  2.22it/s]


Model saved at epoch 65.000, val_loss: 0.455


100%|██████████| 6/6 [00:02<00:00,  2.29it/s]


Model saved at epoch 66.000, val_loss: 0.415


100%|██████████| 6/6 [00:02<00:00,  2.17it/s]
100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
100%|██████████| 6/6 [00:02<00:00,  2.25it/s]
100%|██████████| 6/6 [00:02<00:00,  2.24it/s]


Model saved at epoch 70.000, val_loss: 0.353


100%|██████████| 6/6 [00:02<00:00,  2.05it/s]


Model saved at epoch 71.000, val_loss: 0.326


100%|██████████| 6/6 [00:02<00:00,  2.22it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]


Model saved at epoch 73.000, val_loss: 0.281


100%|██████████| 6/6 [00:02<00:00,  2.15it/s]
100%|██████████| 6/6 [00:02<00:00,  2.30it/s]
100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 6/6 [00:02<00:00,  2.17it/s]
100%|██████████| 6/6 [00:02<00:00,  2.32it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 6/6 [00:02<00:00,  2.40it/s]
100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.25it/s]
100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
100%|██████████| 6/6 [00:02<00:00,  2.22it/s]
100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
100%|██████████| 6/6 [00:02<00:00,  2.32it/s]
100%|██████████| 6/6 [00:02<00:00,  2.30it/s]
100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 6/6 [00:02<00:00,  2.20it/s]
100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 6/6 [00:02<00:00,

Model saved at epoch 104.000, val_loss: 0.280


100%|██████████| 6/6 [00:02<00:00,  2.42it/s]


Model saved at epoch 105.000, val_loss: 0.275


100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.51it/s]


Model saved at epoch 107.000, val_loss: 0.251


100%|██████████| 6/6 [00:02<00:00,  2.45it/s]


Model saved at epoch 108.000, val_loss: 0.246


100%|██████████| 6/6 [00:02<00:00,  2.48it/s]
100%|██████████| 6/6 [00:02<00:00,  2.45it/s]


Model saved at epoch 110.000, val_loss: 0.228


100%|██████████| 6/6 [00:02<00:00,  2.45it/s]
100%|██████████| 6/6 [00:02<00:00,  2.54it/s]
100%|██████████| 6/6 [00:02<00:00,  2.53it/s]
100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]


Model saved at epoch 115.000, val_loss: 0.217


100%|██████████| 6/6 [00:02<00:00,  2.38it/s]


Model saved at epoch 116.000, val_loss: 0.201


100%|██████████| 6/6 [00:02<00:00,  2.36it/s]
100%|██████████| 6/6 [00:02<00:00,  2.27it/s]
100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 6/6 [00:02<00:00,  2.22it/s]
100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 6/6 [00:02<00:00,  2.30it/s]


Model saved at epoch 125.000, val_loss: 0.198


100%|██████████| 6/6 [00:02<00:00,  2.35it/s]


Model saved at epoch 126.000, val_loss: 0.193


100%|██████████| 6/6 [00:02<00:00,  2.35it/s]


Model saved at epoch 127.000, val_loss: 0.174


100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.30it/s]
100%|██████████| 6/6 [00:02<00:00,  2.34it/s]


Model saved at epoch 132.000, val_loss: 0.169


100%|██████████| 6/6 [00:02<00:00,  2.30it/s]


Model saved at epoch 133.000, val_loss: 0.162


100%|██████████| 6/6 [00:02<00:00,  2.39it/s]


Model saved at epoch 134.000, val_loss: 0.159


100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 6/6 [00:02<00:00,  2.30it/s]
100%|██████████| 6/6 [00:02<00:00,  2.31it/s]


Model saved at epoch 137.000, val_loss: 0.158


100%|██████████| 6/6 [00:02<00:00,  2.35it/s]


Model saved at epoch 138.000, val_loss: 0.154


100%|██████████| 6/6 [00:02<00:00,  2.21it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.28it/s]


Model saved at epoch 144.000, val_loss: 0.142


100%|██████████| 6/6 [00:02<00:00,  2.31it/s]


Model saved at epoch 145.000, val_loss: 0.138


100%|██████████| 6/6 [00:02<00:00,  2.39it/s]


Model saved at epoch 146.000, val_loss: 0.135


100%|██████████| 6/6 [00:02<00:00,  2.24it/s]
100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 6/6 [00:02<00:00,  2.33it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.38it/s]


Model saved at epoch 152.000, val_loss: 0.132


100%|██████████| 6/6 [00:02<00:00,  2.29it/s]


Model saved at epoch 153.000, val_loss: 0.128


100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
100%|██████████| 6/6 [00:02<00:00,  2.37it/s]
100%|██████████| 6/6 [00:02<00:00,  2.39it/s]
100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
100%|██████████| 6/6 [00:02<00:00,  2.33it/s]
100%|██████████| 6/6 [00:02<00:00,  2.31it/s]
100%|██████████| 6/6 [00:02<00:00,  2.29it/s]
100%|██████████| 6/6 [00:02<00:00,  2.07it/s]
100%|██████████| 6/6 [00:02<00:00,  2.22it/s]
100%|██████████| 6/6 [00:02<00:00,  2.35it/s]
100%|██████████| 6/6 [00:02<00:00,  2.34it/s]
100%|██████████| 6/6 [00:02<00:00,  2.33it/s]
100%|██████████| 6/6 [00:02<00:00,  2.22it/s]


Model saved at epoch 170.000, val_loss: 0.125


100%|██████████| 6/6 [00:02<00:00,  2.29it/s]


Model saved at epoch 171.000, val_loss: 0.123


100%|██████████| 6/6 [00:02<00:00,  2.25it/s]
100%|██████████| 6/6 [00:02<00:00,  2.16it/s]
100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 6/6 [00:02<00:00,  2.51it/s]
100%|██████████| 6/6 [00:02<00:00,  2.51it/s]
100%|██████████| 6/6 [00:02<00:00,  2.48it/s]
100%|██████████| 6/6 [00:02<00:00,  2.46it/s]


Model saved at epoch 178.000, val_loss: 0.121


100%|██████████| 6/6 [00:02<00:00,  2.27it/s]
100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
100%|██████████| 6/6 [00:02<00:00,  2.28it/s]


Model saved at epoch 181.000, val_loss: 0.119


100%|██████████| 6/6 [00:02<00:00,  2.31it/s]


Model saved at epoch 182.000, val_loss: 0.118


100%|██████████| 6/6 [00:02<00:00,  2.35it/s]


Model saved at epoch 183.000, val_loss: 0.117


100%|██████████| 6/6 [00:02<00:00,  2.34it/s]


Model saved at epoch 184.000, val_loss: 0.117


100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 6/6 [00:02<00:00,  2.43it/s]


Model saved at epoch 186.000, val_loss: 0.117


100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 6/6 [00:02<00:00,  2.44it/s]
100%|██████████| 6/6 [00:02<00:00,  2.46it/s]
100%|██████████| 6/6 [00:02<00:00,  2.50it/s]
100%|██████████| 6/6 [00:02<00:00,  2.47it/s]
100%|██████████| 6/6 [00:02<00:00,  2.48it/s]
100%|██████████| 6/6 [00:02<00:00,  2.43it/s]
100%|██████████| 6/6 [00:02<00:00,  2.51it/s]
100%|██████████| 6/6 [00:02<00:00,  2.49it/s]
100%|██████████| 6/6 [00:02<00:00,  2.49it/s]


Model saved at epoch 197.000, val_loss: 0.115


100%|██████████| 6/6 [00:02<00:00,  2.48it/s]
100%|██████████| 6/6 [00:02<00:00,  2.50it/s]
100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
100%|██████████| 6/6 [00:02<00:00,  2.41it/s]
100%|██████████| 6/6 [00:02<00:00,  2.53it/s]
100%|██████████| 6/6 [00:02<00:00,  2.51it/s]
100%|██████████| 6/6 [00:02<00:00,  2.47it/s]
100%|██████████| 6/6 [00:02<00:00,  2.44it/s]
100%|██████████| 6/6 [00:02<00:00,  2.56it/s]
100%|██████████| 6/6 [00:02<00:00,  2.54it/s]
100%|██████████| 6/6 [00:02<00:00,  2.45it/s]
100%|██████████| 6/6 [00:02<00:00,  2.53it/s]
100%|██████████| 6/6 [00:02<00:00,  2.55it/s]
100%|██████████| 6/6 [00:02<00:00,  2.53it/s]
100%|██████████| 6/6 [00:02<00:00,  2.44it/s]
100%|██████████| 6/6 [00:02<00:00,  2.44it/s]
100%|██████████| 6/6 [00:02<00:00,  2.57it/s]
100%|██████████| 6/6 [00:02<00:00,  2.46it/s]
100%|██████████| 6/6 [00:02<00:00,  2.45it/s]
100%|██████████| 6/6 [00:02<00:00,  2.36it/s]
100%|██████████| 6/6 [00:02<00:00,

KeyboardInterrupt: 

In [15]:
# End the wandb run opened by wandb.init; the output shows the final upload of
# the run data and the run summary.
wandb.finish()

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
learning_rate,██▃▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,▄█▆▅▅▃▄▃▄▃▆▃▆▄▆▄▅▆▃▅█▆▄▄▅▄▄▇▃▅▅▁▃▁▂▆▂▃▄▂
val_loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
learning_rate,0.0
train_loss,0.35818
val_loss,0.12396


# Test

**The code below is no longer used!**

In [None]:
def test_model(model, test_loader, test_names, device='cuda',
               save_dir='/root/Soil-Column-Procedures/data/version1/inference/1/'):
    """Evaluate a binary-segmentation model on a test set and save predicted masks.

    For each batch, the module-level ``criterion`` is applied to the raw model
    outputs. Dice, soft dice, BCE, IoU and F1 (from ``evaluate`` / ``nn.BCELoss``)
    are computed on the sigmoid probabilities. Each prediction is then
    thresholded at 0.5 and written as a 0/255 uint8 image.

    Args:
        model: network to evaluate; it is switched to eval mode.
        test_loader: iterable of ``(images, labels)`` batches that supports
            ``len()`` (e.g. a DataLoader, ideally with ``shuffle=False``).
        test_names: output file names, one per sample, in the order the loader
            yields samples across ALL batches.
        device: device the batches are moved to.
        save_dir: prefix for saved masks. Names are appended directly, so it
            should end with a path separator.

    Returns:
        dict mapping metric name to its mean over batches. Callers that ignore
        the result are unaffected.

    Raises:
        IndexError: if ``test_names`` has fewer entries than the loader yields samples.
        OSError: if a predicted mask cannot be written.
    """
    model.eval()  # Set the model to evaluation mode
    calculate_bce = nn.BCELoss()  # stateless, so build it once rather than per batch

    loss = []
    dice = []
    soft_dice = []
    bce_loss = []
    iou = []
    f1_score = []

    proceed_once = True  # print a one-off sanity check for the first batch
    sample_idx = 0       # running index into test_names across all batches

    with torch.no_grad():  # Turn off gradients to speed up inference
        for i, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)

            # Drop a singleton channel dim so outputs match the label shape
            if outputs.dim() == 4 and outputs.size(1) == 1:
                outputs = outputs.squeeze(1)

            if proceed_once:
                print(f'outputs.size(): {outputs.size()}, labels.size(): {labels.size()}')
                print(f'outputs.min: {outputs.min()}, outputs.max: {outputs.max()}')
                print(f'labels.min: {labels.min()}, labels.max: {labels.max()}')
                print(f'count of label 0: {(labels == 0).sum()}, count of label 1:{(labels == 1).sum()}')
                print('')
                proceed_once = False

            # 1. Training criterion on the raw logits. Per the original note, it
            #    applies a sigmoid internally, so it must run before torch.sigmoid.
            loss.append(criterion(outputs, labels).item())

            # 2. Probability-based metrics
            probs = torch.sigmoid(outputs)  # values between 0 and 1

            dice.append(evaluate.dice_coefficient(probs, labels))
            soft_dice.append(evaluate.soft_dice_coefficient(labels, probs))
            bce_loss.append(calculate_bce(probs, labels).item())
            iou.append(evaluate.iou(pred=probs, target=labels, n_classes=2))
            f1_score.append(evaluate.f1_score(pred=probs, gt=labels))

            # Save binary masks. Names use a running sample index: indexing
            # test_names by the in-batch position would reuse the first batch's
            # names for every batch and overwrite those files.
            masks = probs > 0.5  # Threshold the probabilities to create a binary mask
            for mask in masks:
                save_path = save_dir + test_names[sample_idx]
                output_np = mask.cpu().numpy().astype(np.uint8) * 255  # scale to 0-255
                # cv2.imwrite reports failure (e.g. missing directory) only via its return value
                if not cv2.imwrite(save_path, output_np):
                    raise OSError(f'Failed to write predicted mask to {save_path}')
                sample_idx += 1

            print(f'Processed batch {i+1}/{len(test_loader)}')

    n_batches = len(test_loader)
    loss_avg = sum(loss) / n_batches
    dice_avg = sum(dice) / n_batches
    soft_dice_avg = sum(soft_dice) / n_batches
    bce_loss_avg = sum(bce_loss) / n_batches
    iou_avg = sum(iou) / n_batches
    f1_score_avg = sum(f1_score) / n_batches
    print(f'Loss: {loss_avg:.3f}, Dice: {dice_avg:.3f}, soft_dice: {soft_dice_avg:.3f}, BCE Loss: {bce_loss_avg:.3f}, IOU: {iou_avg:.3f}, f1_score: {f1_score_avg:.3f}')

    return {
        'loss': loss_avg,
        'dice': dice_avg,
        'soft_dice': soft_dice_avg,
        'bce_loss': bce_loss_avg,
        'iou': iou_avg,
        'f1_score': f1_score_avg,
    }
