# Imports

In [1]:
import os

import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2
import torch.nn as nn
from sklearn.model_selection import KFold, train_test_split

import segmentation_models_pytorch as smp
import wandb

In [2]:
import sys
sys.path.insert(0, "c:/Users/laish/1_Codes/Image_processing_toolchain")

from API_functions.DL import load_data, log, seed
from API_functions import file_batch as fb

# Hyperparameters and logging

In [3]:
# Run configuration; the whole dict is also sent to wandb as the run config, so the
# descriptive entries ('model', 'loss_function') must match what the cells below build.
my_parameters = {
    'seed': 402,

    'Kfold': None,             # not referenced below: a single train_test_split is used
    'ratio': 0.2,              # fraction of the training images held out for validation

    'model': 'Unet-efficientnet-b4',
    'optimizer': 'adam',
    'learning_rate': 0.001,
    'batch_size': 32,
    'loss_function': 'dice_bce',

    'n_epochs': 1000,
    'patience': 50,            # intended early-stopping patience, in epochs without val-loss improvement
}

# Fall back to the CPU so the notebook still runs on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
mylogger = log.Logger('all')

seed.stablize_seed(my_parameters['seed'])

In [4]:
# Open a Weights & Biases run in the "U-Net" project and record every hyperparameter
# from my_parameters as the run config.
wandb.init(config=my_parameters, name='5.what???', project="U-Net")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlaishixuan123[0m ([33mlaishixuan123-china-agricultural-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

# Transform

In [5]:
# Define transformations and dataset
# v2.ToImage() + v2.ToDtype(torch.float32, scale=True) is torchvision's documented
# replacement for the deprecated v2.ToTensor(): a 2-D grayscale array (as returned by
# cv2.imread) becomes a 1xHxW float32 tensor, with integer pixel values scaled to [0, 1].
transform_common = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    # Geometric augmentations; these must be applied identically to an image and its mask.
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomRotation(degrees=90),
])

# Photometric augmentation for the input image only (never applied to masks).
transform_train_special = v2.Compose([
    v2.ColorJitter(brightness=0.1, contrast=0.1)
])

transform_train = v2.Compose([
    transform_common,
    transform_train_special
])

# NOTE(review): transform_train and transform_train_label are separate callables, so each
# draws its own random flip/rotation parameters. Unless load_data.my_Dataset reseeds the
# RNG identically before the two calls, a training image and its mask can receive
# different geometric augmentations. Confirm this in my_Dataset, or apply a single
# transform to the (image, mask) pair jointly.
transform_train_label = transform_common

# Validation / test: type conversion only, no augmentation.
transform_val = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

transform_test = transform_val



# Load data

In [6]:
def load_images(paths):
    """Read every file in `paths` as a single-channel grayscale image.

    Args:
        paths: iterable of image file paths.

    Returns:
        list of 2-D arrays, in the same order as `paths`.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the paths. cv2.imread signals
            failure by returning None rather than raising, which would otherwise only
            surface later as an obscure error inside the transforms/DataLoader.
    """
    images = []
    for path in paths:
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f'cv2.imread could not read image: {path}')
        images.append(image)
    return images

DATA_ROOT = 'g:/DL_Data_raw/version0/'

data_paths = fb.get_image_names(DATA_ROOT + 'train_images/', None, 'png')
labels_paths = fb.get_image_names(DATA_ROOT + 'train_labels/', None, 'png')
test_paths = fb.get_image_names(DATA_ROOT + 'test_images/', None, 'png')
test_labels_paths = fb.get_image_names(DATA_ROOT + 'test_labels/', None, 'png')

# Images and masks are paired purely by list position (train_test_split and my_Dataset
# index them together), so fail fast if the two listings do not line up.
assert len(data_paths) == len(labels_paths), \
    f'{len(data_paths)} train images but {len(labels_paths)} train labels'
unpaired = [(img, lbl) for img, lbl in zip(data_paths, labels_paths)
            if os.path.basename(img) != os.path.basename(lbl)]
assert not unpaired, f'train image/label file names do not match, e.g. {unpaired[0]}'
assert len(test_paths) == len(test_labels_paths), \
    f'{len(test_paths)} test images but {len(test_labels_paths)} test labels'

data = load_images(data_paths)
labels = load_images(labels_paths)
tests = load_images(test_paths)
test_labels = load_images(test_labels_paths)

train_data, val_data, train_labels, val_labels = train_test_split(
    data, labels, test_size=my_parameters['ratio'], random_state=my_parameters['seed'])

train_dataset = load_data.my_Dataset(train_data, train_labels, transform=transform_train, label_transform=transform_train_label)
val_dataset = load_data.my_Dataset(val_data, val_labels, transform=transform_val, label_transform=transform_val)
test_dataset = load_data.my_Dataset(tests, test_labels, transform=transform_test, label_transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=my_parameters['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=my_parameters['batch_size'], shuffle=False)
# shuffle=False keeps test order equal to test_paths order, which the inference cell relies on.
test_loader = DataLoader(test_dataset, batch_size=my_parameters['batch_size'], shuffle=False)

print(f'len of train_data: {len(train_data)}, len of val_data: {len(val_data)}, len of test_data: {len(tests)}')

10 images have been found in g:/DL_Data_raw/version0/train_images/
The first 3 images are:
g:/DL_Data_raw/version0/train_images\002_ou_DongYing_12633.png
g:/DL_Data_raw/version0/train_images\002_ou_DongYing_12634.png
g:/DL_Data_raw/version0/train_images\002_ou_DongYing_12635.png
[1;3mGet names completely![0m
10 images have been found in g:/DL_Data_raw/version0/train_labels/
The first 3 images are:
g:/DL_Data_raw/version0/train_labels\002_ou_DongYing_12633.png
g:/DL_Data_raw/version0/train_labels\002_ou_DongYing_12634.png
g:/DL_Data_raw/version0/train_labels\002_ou_DongYing_12635.png
[1;3mGet names completely![0m
5 images have been found in g:/DL_Data_raw/version0/test_images/
The first 3 images are:
g:/DL_Data_raw/version0/test_images\002_ou_DongYing_13635_roi_selected.png
g:/DL_Data_raw/version0/test_images\002_ou_DongYing_13636_roi_selected.png
g:/DL_Data_raw/version0/test_images\002_ou_DongYing_13637_roi_selected.png
[1;3mGet names completely![0m
5 images have been found in g:

# Model

In [7]:
# class DoubleConv(nn.Module):
#     """(convolution => [BN] => ReLU) * 2"""

#     def __init__(self, in_channels, out_channels):
#         super().__init__()
#         self.double_conv = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, 3, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, 3, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x):
#         return self.double_conv(x)

# class UNet(nn.Module):
#     def __init__(self):
#         super(UNet, self).__init__()
#         self.dconv_down1 = DoubleConv(1, 64)
#         self.dconv_down2 = DoubleConv(64, 128)
#         self.dconv_down3 = DoubleConv(128, 256)
#         self.dconv_down4 = DoubleConv(256, 512)        

#         self.maxpool = nn.MaxPool2d(2)
#         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)        
        
#         self.dconv_up3 = DoubleConv(256 + 512, 256)
#         self.dconv_up2 = DoubleConv(128 + 256, 128)
#         self.dconv_up1 = DoubleConv(64 + 128, 64)
        
#         self.conv_last = nn.Conv2d(64, 1, 1)
        
#     def forward(self, x):
#         conv1 = self.dconv_down1(x)
#         x = self.maxpool(conv1)
        
#         conv2 = self.dconv_down2(x)
#         x = self.maxpool(conv2)
        
#         conv3 = self.dconv_down3(x)
#         x = self.maxpool(conv3)
        
#         x = self.dconv_down4(x)
        
#         x = self.upsample(x)        
#         x = torch.cat([x, conv3], dim=1)
        
#         x = self.dconv_up3(x)
#         x = self.upsample(x)
#         x = torch.cat([x, conv2], dim=1)

#         x = self.dconv_up2(x)
#         x = self.upsample(x)
#         x = torch.cat([x, conv1], dim=1)

#         x = self.dconv_up1(x)
        
#         x = self.conv_last(x)

#         x = torch.sigmoid(x)  # Apply sigmoid activation to the output
#         return x

In [8]:
class DiceBCELoss(nn.Module):
    """Weighted sum of binary cross-entropy and soft Dice loss for binary segmentation.

    `forward` expects raw logits from the model; the sigmoid is applied internally.
    The Dice term is computed over all pixels of the batch at once (not per sample).

    Args:
        bce_weight: weight of the BCE term; the Dice term gets ``1 - bce_weight``.
            The default 0.5 reproduces the original equal weighting.
    """

    def __init__(self, bce_weight=0.5):
        super().__init__()
        # BCEWithLogitsLoss fuses sigmoid + BCE with the log-sum-exp trick, so it stays
        # accurate for large-magnitude logits where sigmoid() saturates to exactly 0/1
        # (nn.BCELoss would then clamp log(0) to -100 and lose gradient signal).
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight

    def forward(self, inputs, targets, smooth=1):
        """Return the combined loss.

        Args:
            inputs: raw logits, same shape as `targets`.
            targets: float mask tensor with values in [0, 1].
            smooth: additive smoothing in the Dice ratio; avoids 0/0 on empty masks.
        """
        bce_loss = self.bce(inputs, targets)

        probs = torch.sigmoid(inputs)
        # reshape (unlike view) also works on non-contiguous tensors, e.g. after a transpose.
        probs_flat = probs.reshape(-1)
        targets_flat = targets.reshape(-1)

        intersection = (probs_flat * targets_flat).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (probs_flat.sum() + targets_flat.sum() + smooth)

        return self.bce_weight * bce_loss + (1 - self.bce_weight) * dice_loss

# Train

In [9]:
# U-Net whose encoder is an ImageNet-pretrained EfficientNet-B4, taking one (grayscale)
# input channel and producing one output channel (logits for the foreground mask).
model = smp.Unet(
    encoder_name="efficientnet-b4",
    encoder_weights="imagenet",
    in_channels=1,
    classes=1,
).to(device)

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to C:\Users\laish/.cache\torch\hub\checkpoints\efficientnet-b4-6ed6700e.pth
100%|██████████| 74.4M/74.4M [03:01<00:00, 430kB/s]  


In [10]:
# Loss on raw logits, plus Adam over every model parameter at the configured learning rate.
criterion = DiceBCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=my_parameters['learning_rate'])

In [None]:
# Train, keeping the checkpoint with the lowest validation loss in 'model.pth', and stop
# early once it has not improved for my_parameters['patience'] consecutive epochs.
val_loss_best = float('inf')
epochs_without_improvement = 0

for epoch in range(my_parameters['n_epochs']):
    # ---- training pass ----
    model.train()
    train_loss = 0.0

    # batch_* names avoid clobbering the global `labels` list built in the data cell.
    for batch_images, batch_masks in tqdm(train_loader, desc=f'epoch {epoch}'):
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        outputs = model(batch_images)
        loss = criterion(outputs, batch_masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean is per-sample even when the last batch is short.
        train_loss += loss.item() * batch_images.size(0)

    train_loss_mean = train_loss / len(train_loader.dataset)

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_images, batch_masks in val_loader:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)
            outputs = model(batch_images)
            val_loss += criterion(outputs, batch_masks).item() * batch_images.size(0)

    val_loss_mean = val_loss / len(val_loader.dataset)
    mylogger.log({'train_loss': train_loss_mean, 'epoch': epoch, 'val_loss': val_loss_mean})

    # ---- checkpointing / early stopping ----
    if val_loss_mean < val_loss_best:
        val_loss_best = val_loss_mean
        epochs_without_improvement = 0
        torch.save(model.state_dict(), 'model.pth')
        print(f'Model saved at epoch {epoch}, val_loss: {val_loss_mean}')
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= my_parameters['patience']:
            print(f'Early stopping at epoch {epoch}: no val_loss improvement '
                  f'for {epochs_without_improvement} epochs (best {val_loss_best})')
            break

100%|██████████| 1/1 [00:35<00:00, 35.57s/it]


{'epoch': 0, 'train_loss': 0.9476802349090576, 'val_loss': 1.3187569379806519}
Model saved at epoch 0, val_loss: 1.3187569379806519


100%|██████████| 1/1 [00:37<00:00, 37.14s/it]


{'epoch': 1, 'train_loss': 0.9074574708938599, 'val_loss': 1.3432536125183105}


100%|██████████| 1/1 [00:39<00:00, 39.66s/it]


{'epoch': 2, 'train_loss': 0.8770948648452759, 'val_loss': 1.3600742816925049}


100%|██████████| 1/1 [00:36<00:00, 36.76s/it]


{'epoch': 3, 'train_loss': 0.853986382484436, 'val_loss': 1.2472386360168457}
Model saved at epoch 3, val_loss: 1.2472386360168457


100%|██████████| 1/1 [00:37<00:00, 37.19s/it]


{'epoch': 4, 'train_loss': 0.834517240524292, 'val_loss': 1.107879638671875}
Model saved at epoch 4, val_loss: 1.107879638671875


100%|██████████| 1/1 [00:41<00:00, 41.66s/it]


{'epoch': 5, 'train_loss': 0.817495584487915, 'val_loss': 0.9934738874435425}
Model saved at epoch 5, val_loss: 0.9934738874435425


100%|██████████| 1/1 [00:46<00:00, 46.22s/it]


{'epoch': 6, 'train_loss': 0.7956337928771973, 'val_loss': 0.8791986703872681}
Model saved at epoch 6, val_loss: 0.8791986703872681


100%|██████████| 1/1 [00:40<00:00, 40.59s/it]


{'epoch': 7, 'train_loss': 0.7777528166770935, 'val_loss': 0.7803028225898743}
Model saved at epoch 7, val_loss: 0.7803028225898743


100%|██████████| 1/1 [00:34<00:00, 35.00s/it]


{'epoch': 8, 'train_loss': 0.7601549029350281, 'val_loss': 0.7110450267791748}
Model saved at epoch 8, val_loss: 0.7110450267791748


100%|██████████| 1/1 [00:37<00:00, 37.33s/it]


{'epoch': 9, 'train_loss': 0.7487103939056396, 'val_loss': 0.6679297685623169}
Model saved at epoch 9, val_loss: 0.6679297685623169


100%|██████████| 1/1 [00:39<00:00, 39.51s/it]


{'epoch': 10, 'train_loss': 0.7315433025360107, 'val_loss': 0.6470640301704407}
Model saved at epoch 10, val_loss: 0.6470640301704407


100%|██████████| 1/1 [00:38<00:00, 38.72s/it]


{'epoch': 11, 'train_loss': 0.7211588025093079, 'val_loss': 0.6354414820671082}
Model saved at epoch 11, val_loss: 0.6354414820671082


100%|██████████| 1/1 [00:35<00:00, 35.87s/it]


{'epoch': 12, 'train_loss': 0.708203136920929, 'val_loss': 0.6347847580909729}
Model saved at epoch 12, val_loss: 0.6347847580909729


100%|██████████| 1/1 [00:30<00:00, 30.02s/it]


{'epoch': 13, 'train_loss': 0.6935576796531677, 'val_loss': 0.6403273344039917}


100%|██████████| 1/1 [00:29<00:00, 29.82s/it]


{'epoch': 14, 'train_loss': 0.6886283159255981, 'val_loss': 0.6563809514045715}


100%|██████████| 1/1 [00:30<00:00, 30.06s/it]


{'epoch': 15, 'train_loss': 0.6796246767044067, 'val_loss': 0.6963930130004883}


100%|██████████| 1/1 [00:29<00:00, 29.88s/it]


{'epoch': 16, 'train_loss': 0.6647419929504395, 'val_loss': 0.6976496577262878}


100%|██████████| 1/1 [00:29<00:00, 29.60s/it]


{'epoch': 17, 'train_loss': 0.6568018198013306, 'val_loss': 0.67569500207901}


100%|██████████| 1/1 [00:29<00:00, 29.62s/it]


{'epoch': 18, 'train_loss': 0.6512738466262817, 'val_loss': 0.6709840893745422}


100%|██████████| 1/1 [00:29<00:00, 29.65s/it]


{'epoch': 19, 'train_loss': 0.6444740891456604, 'val_loss': 0.684826672077179}


100%|██████████| 1/1 [00:29<00:00, 29.70s/it]


{'epoch': 20, 'train_loss': 0.6381586790084839, 'val_loss': 0.689124345779419}


100%|██████████| 1/1 [00:29<00:00, 29.75s/it]


{'epoch': 21, 'train_loss': 0.6304693222045898, 'val_loss': 0.6847600936889648}


100%|██████████| 1/1 [00:29<00:00, 29.83s/it]


{'epoch': 22, 'train_loss': 0.6266589164733887, 'val_loss': 0.6771689057350159}


100%|██████████| 1/1 [00:29<00:00, 29.74s/it]


{'epoch': 23, 'train_loss': 0.6193190813064575, 'val_loss': 0.6682742834091187}


100%|██████████| 1/1 [00:29<00:00, 29.99s/it]


{'epoch': 24, 'train_loss': 0.6145834922790527, 'val_loss': 0.6592832207679749}


100%|██████████| 1/1 [00:29<00:00, 29.79s/it]


{'epoch': 25, 'train_loss': 0.6106619834899902, 'val_loss': 0.650810182094574}


100%|██████████| 1/1 [00:29<00:00, 29.76s/it]


{'epoch': 26, 'train_loss': 0.6077152490615845, 'val_loss': 0.6390623450279236}


100%|██████████| 1/1 [00:29<00:00, 29.81s/it]


{'epoch': 27, 'train_loss': 0.601283073425293, 'val_loss': 0.6285349726676941}
Model saved at epoch 27, val_loss: 0.6285349726676941


100%|██████████| 1/1 [00:29<00:00, 29.81s/it]


{'epoch': 28, 'train_loss': 0.5974946618080139, 'val_loss': 0.6211463809013367}
Model saved at epoch 28, val_loss: 0.6211463809013367


100%|██████████| 1/1 [00:29<00:00, 29.59s/it]


{'epoch': 29, 'train_loss': 0.5967226028442383, 'val_loss': 0.615733802318573}
Model saved at epoch 29, val_loss: 0.615733802318573


100%|██████████| 1/1 [00:29<00:00, 29.58s/it]


{'epoch': 30, 'train_loss': 0.5912095308303833, 'val_loss': 0.6105242371559143}
Model saved at epoch 30, val_loss: 0.6105242371559143


100%|██████████| 1/1 [00:29<00:00, 29.60s/it]


{'epoch': 31, 'train_loss': 0.5879970192909241, 'val_loss': 0.6054723858833313}
Model saved at epoch 31, val_loss: 0.6054723858833313


100%|██████████| 1/1 [00:29<00:00, 29.66s/it]


{'epoch': 32, 'train_loss': 0.5842812657356262, 'val_loss': 0.6008486747741699}
Model saved at epoch 32, val_loss: 0.6008486747741699


100%|██████████| 1/1 [00:30<00:00, 30.15s/it]


{'epoch': 33, 'train_loss': 0.5819677710533142, 'val_loss': 0.5966256856918335}
Model saved at epoch 33, val_loss: 0.5966256856918335


100%|██████████| 1/1 [00:29<00:00, 29.55s/it]


{'epoch': 34, 'train_loss': 0.5785409212112427, 'val_loss': 0.5924381613731384}
Model saved at epoch 34, val_loss: 0.5924381613731384


100%|██████████| 1/1 [00:29<00:00, 29.47s/it]


{'epoch': 35, 'train_loss': 0.5759872198104858, 'val_loss': 0.588283896446228}
Model saved at epoch 35, val_loss: 0.588283896446228


100%|██████████| 1/1 [00:29<00:00, 29.65s/it]


{'epoch': 36, 'train_loss': 0.5726929306983948, 'val_loss': 0.5841537714004517}
Model saved at epoch 36, val_loss: 0.5841537714004517


100%|██████████| 1/1 [00:29<00:00, 29.56s/it]


{'epoch': 37, 'train_loss': 0.5696883797645569, 'val_loss': 0.5800882577896118}
Model saved at epoch 37, val_loss: 0.5800882577896118


100%|██████████| 1/1 [00:29<00:00, 29.62s/it]


{'epoch': 38, 'train_loss': 0.5669289231300354, 'val_loss': 0.5761675834655762}
Model saved at epoch 38, val_loss: 0.5761675834655762


100%|██████████| 1/1 [00:29<00:00, 29.61s/it]


{'epoch': 39, 'train_loss': 0.5647851824760437, 'val_loss': 0.5725721716880798}
Model saved at epoch 39, val_loss: 0.5725721716880798


100%|██████████| 1/1 [00:29<00:00, 29.52s/it]


{'epoch': 40, 'train_loss': 0.5644711256027222, 'val_loss': 0.5693520903587341}
Model saved at epoch 40, val_loss: 0.5693520903587341


100%|██████████| 1/1 [00:29<00:00, 29.73s/it]


{'epoch': 41, 'train_loss': 0.559907078742981, 'val_loss': 0.5661208629608154}
Model saved at epoch 41, val_loss: 0.5661208629608154


100%|██████████| 1/1 [00:29<00:00, 29.51s/it]


{'epoch': 42, 'train_loss': 0.5576798319816589, 'val_loss': 0.5630660057067871}
Model saved at epoch 42, val_loss: 0.5630660057067871


100%|██████████| 1/1 [00:29<00:00, 29.83s/it]


{'epoch': 43, 'train_loss': 0.5557718873023987, 'val_loss': 0.5602511167526245}
Model saved at epoch 43, val_loss: 0.5602511167526245


100%|██████████| 1/1 [00:29<00:00, 29.61s/it]


{'epoch': 44, 'train_loss': 0.5537248253822327, 'val_loss': 0.5577045679092407}
Model saved at epoch 44, val_loss: 0.5577045679092407


100%|██████████| 1/1 [00:29<00:00, 29.70s/it]


{'epoch': 45, 'train_loss': 0.5524011254310608, 'val_loss': 0.5554431676864624}
Model saved at epoch 45, val_loss: 0.5554431676864624


100%|██████████| 1/1 [00:29<00:00, 29.81s/it]


{'epoch': 46, 'train_loss': 0.5504438281059265, 'val_loss': 0.5532936453819275}
Model saved at epoch 46, val_loss: 0.5532936453819275


100%|██████████| 1/1 [00:29<00:00, 29.56s/it]


{'epoch': 47, 'train_loss': 0.5488656163215637, 'val_loss': 0.5514654517173767}
Model saved at epoch 47, val_loss: 0.5514654517173767


100%|██████████| 1/1 [00:29<00:00, 29.45s/it]


{'epoch': 48, 'train_loss': 0.5470008850097656, 'val_loss': 0.5496166944503784}
Model saved at epoch 48, val_loss: 0.5496166944503784


100%|██████████| 1/1 [00:29<00:00, 29.56s/it]


{'epoch': 49, 'train_loss': 0.5466821193695068, 'val_loss': 0.5480538606643677}
Model saved at epoch 49, val_loss: 0.5480538606643677


100%|██████████| 1/1 [00:29<00:00, 29.57s/it]


{'epoch': 50, 'train_loss': 0.5457179546356201, 'val_loss': 0.5465207099914551}
Model saved at epoch 50, val_loss: 0.5465207099914551


100%|██████████| 1/1 [00:29<00:00, 29.49s/it]


{'epoch': 51, 'train_loss': 0.5437873601913452, 'val_loss': 0.544998824596405}
Model saved at epoch 51, val_loss: 0.544998824596405


100%|██████████| 1/1 [00:30<00:00, 30.19s/it]


{'epoch': 52, 'train_loss': 0.5431753396987915, 'val_loss': 0.5436098575592041}
Model saved at epoch 52, val_loss: 0.5436098575592041


100%|██████████| 1/1 [00:29<00:00, 29.63s/it]


{'epoch': 53, 'train_loss': 0.5411357283592224, 'val_loss': 0.5420998334884644}
Model saved at epoch 53, val_loss: 0.5420998334884644


100%|██████████| 1/1 [00:29<00:00, 29.61s/it]


{'epoch': 54, 'train_loss': 0.5401890873908997, 'val_loss': 0.5407492518424988}
Model saved at epoch 54, val_loss: 0.5407492518424988


100%|██████████| 1/1 [00:29<00:00, 29.66s/it]


{'epoch': 55, 'train_loss': 0.5396156311035156, 'val_loss': 0.5395828485488892}
Model saved at epoch 55, val_loss: 0.5395828485488892


100%|██████████| 1/1 [00:29<00:00, 29.60s/it]


{'epoch': 56, 'train_loss': 0.5382166504859924, 'val_loss': 0.5386209487915039}
Model saved at epoch 56, val_loss: 0.5386209487915039


100%|██████████| 1/1 [00:29<00:00, 29.47s/it]


{'epoch': 57, 'train_loss': 0.5376737713813782, 'val_loss': 0.5378402471542358}
Model saved at epoch 57, val_loss: 0.5378402471542358


100%|██████████| 1/1 [00:29<00:00, 29.51s/it]


{'epoch': 58, 'train_loss': 0.5367609262466431, 'val_loss': 0.5371078252792358}
Model saved at epoch 58, val_loss: 0.5371078252792358


100%|██████████| 1/1 [00:29<00:00, 29.73s/it]


{'epoch': 59, 'train_loss': 0.5363219380378723, 'val_loss': 0.5363435745239258}
Model saved at epoch 59, val_loss: 0.5363435745239258


100%|██████████| 1/1 [00:29<00:00, 29.74s/it]


{'epoch': 60, 'train_loss': 0.5352732539176941, 'val_loss': 0.5354821085929871}
Model saved at epoch 60, val_loss: 0.5354821085929871


100%|██████████| 1/1 [00:29<00:00, 29.96s/it]


{'epoch': 61, 'train_loss': 0.5348397493362427, 'val_loss': 0.5345341563224792}
Model saved at epoch 61, val_loss: 0.5345341563224792


100%|██████████| 1/1 [00:29<00:00, 29.61s/it]


{'epoch': 62, 'train_loss': 0.5341199636459351, 'val_loss': 0.5335695147514343}
Model saved at epoch 62, val_loss: 0.5335695147514343


100%|██████████| 1/1 [00:29<00:00, 29.62s/it]


{'epoch': 63, 'train_loss': 0.5332762002944946, 'val_loss': 0.5327368974685669}
Model saved at epoch 63, val_loss: 0.5327368974685669


100%|██████████| 1/1 [00:29<00:00, 29.57s/it]


{'epoch': 64, 'train_loss': 0.5328578948974609, 'val_loss': 0.5321544408798218}
Model saved at epoch 64, val_loss: 0.5321544408798218


100%|██████████| 1/1 [00:29<00:00, 29.76s/it]


{'epoch': 65, 'train_loss': 0.5319772362709045, 'val_loss': 0.5315733551979065}
Model saved at epoch 65, val_loss: 0.5315733551979065


100%|██████████| 1/1 [00:29<00:00, 29.50s/it]


{'epoch': 66, 'train_loss': 0.5317907333374023, 'val_loss': 0.5310472249984741}
Model saved at epoch 66, val_loss: 0.5310472249984741


100%|██████████| 1/1 [00:29<00:00, 29.54s/it]


{'epoch': 67, 'train_loss': 0.5306298732757568, 'val_loss': 0.5305555462837219}
Model saved at epoch 67, val_loss: 0.5305555462837219


100%|██████████| 1/1 [00:29<00:00, 29.53s/it]


{'epoch': 68, 'train_loss': 0.5302661061286926, 'val_loss': 0.5301129817962646}
Model saved at epoch 68, val_loss: 0.5301129817962646


100%|██████████| 1/1 [00:29<00:00, 29.76s/it]


{'epoch': 69, 'train_loss': 0.5298645496368408, 'val_loss': 0.5298143625259399}
Model saved at epoch 69, val_loss: 0.5298143625259399


100%|██████████| 1/1 [00:29<00:00, 29.88s/it]


{'epoch': 70, 'train_loss': 0.5298923254013062, 'val_loss': 0.529598593711853}
Model saved at epoch 70, val_loss: 0.529598593711853


100%|██████████| 1/1 [00:29<00:00, 29.65s/it]


{'epoch': 71, 'train_loss': 0.5288488268852234, 'val_loss': 0.5293056964874268}
Model saved at epoch 71, val_loss: 0.5293056964874268


100%|██████████| 1/1 [00:29<00:00, 29.57s/it]


{'epoch': 72, 'train_loss': 0.5295147895812988, 'val_loss': 0.5291252136230469}
Model saved at epoch 72, val_loss: 0.5291252136230469


100%|██████████| 1/1 [00:29<00:00, 29.58s/it]


{'epoch': 73, 'train_loss': 0.5281302332878113, 'val_loss': 0.528865396976471}
Model saved at epoch 73, val_loss: 0.528865396976471


100%|██████████| 1/1 [00:29<00:00, 29.74s/it]


{'epoch': 74, 'train_loss': 0.5275752544403076, 'val_loss': 0.5285676717758179}
Model saved at epoch 74, val_loss: 0.5285676717758179


100%|██████████| 1/1 [00:29<00:00, 29.76s/it]


{'epoch': 75, 'train_loss': 0.5272321701049805, 'val_loss': 0.5283032655715942}
Model saved at epoch 75, val_loss: 0.5283032655715942


100%|██████████| 1/1 [00:29<00:00, 29.57s/it]


{'epoch': 76, 'train_loss': 0.5271674394607544, 'val_loss': 0.5281221866607666}
Model saved at epoch 76, val_loss: 0.5281221866607666


100%|██████████| 1/1 [00:29<00:00, 29.48s/it]


{'epoch': 77, 'train_loss': 0.5260985493659973, 'val_loss': 0.5278968214988708}
Model saved at epoch 77, val_loss: 0.5278968214988708


100%|██████████| 1/1 [00:29<00:00, 29.96s/it]


{'epoch': 78, 'train_loss': 0.5259406566619873, 'val_loss': 0.5276662707328796}
Model saved at epoch 78, val_loss: 0.5276662707328796


100%|██████████| 1/1 [00:29<00:00, 29.64s/it]


{'epoch': 79, 'train_loss': 0.5259298086166382, 'val_loss': 0.527465283870697}
Model saved at epoch 79, val_loss: 0.527465283870697


100%|██████████| 1/1 [00:30<00:00, 30.27s/it]


{'epoch': 80, 'train_loss': 0.5254793167114258, 'val_loss': 0.5272982716560364}
Model saved at epoch 80, val_loss: 0.5272982716560364


100%|██████████| 1/1 [00:29<00:00, 29.65s/it]


{'epoch': 81, 'train_loss': 0.524333119392395, 'val_loss': 0.5271586775779724}
Model saved at epoch 81, val_loss: 0.5271586775779724


100%|██████████| 1/1 [00:30<00:00, 30.09s/it]


{'epoch': 82, 'train_loss': 0.5270789861679077, 'val_loss': 0.5267952084541321}
Model saved at epoch 82, val_loss: 0.5267952084541321


100%|██████████| 1/1 [00:29<00:00, 29.74s/it]


{'epoch': 83, 'train_loss': 0.5266185998916626, 'val_loss': 0.526297926902771}
Model saved at epoch 83, val_loss: 0.526297926902771


100%|██████████| 1/1 [00:29<00:00, 29.74s/it]


{'epoch': 84, 'train_loss': 0.5247071981430054, 'val_loss': 0.5257595777511597}
Model saved at epoch 84, val_loss: 0.5257595777511597


100%|██████████| 1/1 [00:29<00:00, 29.50s/it]


{'epoch': 85, 'train_loss': 0.5251723527908325, 'val_loss': 0.5253893733024597}
Model saved at epoch 85, val_loss: 0.5253893733024597


100%|██████████| 1/1 [00:29<00:00, 29.58s/it]


{'epoch': 86, 'train_loss': 0.5248496532440186, 'val_loss': 0.5250434279441833}
Model saved at epoch 86, val_loss: 0.5250434279441833


100%|██████████| 1/1 [00:29<00:00, 29.75s/it]


{'epoch': 87, 'train_loss': 0.5243903398513794, 'val_loss': 0.5246058702468872}
Model saved at epoch 87, val_loss: 0.5246058702468872


100%|██████████| 1/1 [00:29<00:00, 29.73s/it]


{'epoch': 88, 'train_loss': 0.5242964029312134, 'val_loss': 0.5242103338241577}
Model saved at epoch 88, val_loss: 0.5242103338241577


100%|██████████| 1/1 [00:29<00:00, 29.91s/it]


{'epoch': 89, 'train_loss': 0.5236587524414062, 'val_loss': 0.5238298177719116}
Model saved at epoch 89, val_loss: 0.5238298177719116


100%|██████████| 1/1 [00:29<00:00, 29.52s/it]


{'epoch': 90, 'train_loss': 0.5249167680740356, 'val_loss': 0.523541271686554}
Model saved at epoch 90, val_loss: 0.523541271686554


100%|██████████| 1/1 [00:29<00:00, 29.53s/it]


{'epoch': 91, 'train_loss': 0.5233462452888489, 'val_loss': 0.5231240391731262}
Model saved at epoch 91, val_loss: 0.5231240391731262


100%|██████████| 1/1 [00:29<00:00, 29.54s/it]


{'epoch': 92, 'train_loss': 0.5231229066848755, 'val_loss': 0.5227711796760559}
Model saved at epoch 92, val_loss: 0.5227711796760559


100%|██████████| 1/1 [00:29<00:00, 29.47s/it]


{'epoch': 93, 'train_loss': 0.5230998396873474, 'val_loss': 0.5225046277046204}
Model saved at epoch 93, val_loss: 0.5225046277046204


100%|██████████| 1/1 [00:29<00:00, 29.86s/it]


{'epoch': 94, 'train_loss': 0.5228769183158875, 'val_loss': 0.5222724676132202}
Model saved at epoch 94, val_loss: 0.5222724676132202


  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Close the W&B run opened by wandb.init() in the logging section.
wandb.finish()

# Test

In [None]:
def save_image(image, path, vmin=None, vmax=None):
    """Save a single-image tensor to `path` as a grayscale image file.

    Args:
        image: tensor expected to squeeze to a 2-D (H, W) array, e.g. 1xHxW; any device;
            bool masks are accepted and written as 0/1.
        path: destination file; the extension selects the format (matplotlib imsave).
        vmin, vmax: optional fixed intensity range. With the default None, matplotlib
            autoscales to the image's own min/max, so a constant image (e.g. an
            all-foreground mask) is written uniformly black. Pass 0 and 1 for masks.
    """
    array = image.detach().squeeze().float().cpu().numpy()
    plt.imsave(path, array, cmap='gray', vmin=vmin, vmax=vmax)


def test_model(model, test_loader, device='cuda',
               out_dir='g:/DL_Data_raw/version0/inference/tests_inference5/',
               start_index=13635, threshold=0.5):
    """Run binary-segmentation inference over `test_loader` and save one mask per sample.

    Args:
        model: network returning logits with one output channel.
        test_loader: DataLoader yielding (images, labels); labels are ignored here.
        device: device the images are moved to (must match the model's device).
        out_dir: directory the PNG masks are written to (created if missing).
        start_index: number used in the first output file name.
        threshold: probability above which a pixel is foreground.

    NOTE(review): output names are built as start_index + running sample position, which
    assumes the loader is unshuffled and the test files are consecutively numbered from
    start_index. Confirm this against the test image file names.
    """
    model.eval()  # BatchNorm/Dropout must use inference behaviour, not training statistics
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():  # no autograd bookkeeping needed for inference
        for batch_idx, (images, _labels) in enumerate(test_loader):
            images = images.to(device)

            probs = torch.sigmoid(model(images))  # logits -> probabilities in [0, 1]
            masks = probs > threshold             # binary foreground mask

            for idx, mask in enumerate(masks):
                sample_number = batch_idx * test_loader.batch_size + idx + start_index
                save_path = os.path.join(out_dir, f'002_ou_DongYing_{sample_number}_roi_selected.png')
                # Fixed 0..1 range so all-foreground masks are not rendered black.
                save_image(mask, save_path, vmin=0, vmax=1)

            print(f'Processed batch {batch_idx+1}/{len(test_loader)}')


In [None]:
# Test the model
# Rebuild exactly the architecture that was trained (efficientnet-b4 in the training cell).
# A different encoder makes load_state_dict fail on missing/unexpected keys and shape
# mismatches. encoder_weights=None skips the ImageNet download, because every weight is
# overwritten by the checkpoint anyway.
model = smp.Unet(
    encoder_name="efficientnet-b4",
    encoder_weights=None,
    in_channels=1,
    classes=1,
)
# map_location lets a checkpoint saved on the GPU load on whatever `device` is;
# weights_only avoids unpickling arbitrary objects from the file.
model.load_state_dict(torch.load('model.pth', map_location=device, weights_only=True))
model = model.to(device)
test_model(model, test_loader, device=device)

In [None]:
from API_functions import file_compare as fc
%matplotlib qt

db = fc.ImageDatabase()
# image_processor.add_result('pre_processed', tpi.user_threshold(image_processor.image, 160))
zoom = fc.ZoomRegion(350, 450, 100, 200)
db.add_additional_folder('f:/3.Experimental_Data/DL_Data_raw/tests/', 'test_set')
db.add_additional_folder('f:/3.Experimental_Data/DL_Data_raw/tests_inference4/', 'test_inference')
db.add_additional_folder('f:/3.Experimental_Data/DL_Data_raw/test_labels/', 'test_labels')
image_processor = db.get_image_processor('002_ou_DongYing_13636_roi_selected.png')
image_processor.show_images('test_set', 'test_inference', 'test_labels', zoom_region=zoom)