In [1]:
import torch
from torch.utils.data import DataLoader
import train_methods as tm
from torch.optim import Adam
from tqdm import tqdm
from train_methods import train, validation, EarlyStopping, WeightedBCEWithLogitsLoss, SegmentationDataset
from UNet import Unet
import json
import os
import nibabel as nib
import numpy as np

In [2]:

C = 16
epochs = 5
learning_rate = 1e-2
act_batch_size = 1  # Can't fit more than one image in GPU!
eff_batch_size = 1  # Efective batch (Gradient accumulation)
momentum = 0.99
device = 'cpu'
channels = [C, 64, 128, 256, 512, 1024]
w0 = 10
sigma = 5
model_path = './model.pt'
Test_image_dir = 'imagesTr'
Test_label_dir = 'labelsTr'

In [3]:
# Early stopping
es = tm.EarlyStopping(patience=100, fname=model_path)

imagesTR_len = len(os.listdir("/home/or/PycharmProjects/AMOS/imagesTr"))
print(f"Number of files in imagesTr: {imagesTR_len}")


# Make progress bars
pbar_epoch = tqdm(total=epochs, unit='epoch', position=0, leave=False)
pbar_train = tqdm(total=imagesTR_len, unit='batch', position=1, leave=False)

# Make model
model = Unet(channels=channels, no_classes=1).double().to(device)

# Make adam optimiser
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,  # Learning rate
    betas=(0.9, 0.999),  # Coefficients used for computing running averages of gradient and its square
    eps=1e-08,  # Term added to denominator to improve numerical stability
    weight_decay=0  # Weight decay (L2 penalty)
)

# Make loss
criterion = tm.WeightedBCEWithLogitsLoss(batch_size=act_batch_size)

Number of files in imagesTr: 240


  0%|          | 0/5 [00:00<?, ?epoch/s]
  0%|          | 0/240 [00:00<?, ?batch/s][A

In [4]:
# Load checkpoint (if it exists)
cur_epoch = 0
if os.path.isfile(model_path):
    checkpoint = torch.load(model_path)
    cur_epoch = checkpoint['epoch']
    es.best_loss = checkpoint['loss']
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Hold stats for training process
stats = {'epoch': [], 'train_loss': [], 'val_loss': []}

# Training  / validation loop
for epoch in range(cur_epoch, epochs):
    print(f"Epoch {epoch + 1}/{epochs}")

    # Train / validate
    pbar_epoch.set_description_str(f'Epoch {epoch + 1}')
    train_loss = tm.train(model, optimizer, Test_image_dir, Test_label_dir, criterion, device, pbar_train)
    val_loss = tm.validation(model, Test_image_dir, Test_label_dir, criterion, device)


    # Append stats
    stats['epoch'].append(epoch)
    stats['train_loss'].append(train_loss)
    stats['val_loss'].append(val_loss)

    # Early stopping (just saves model if validation loss decreases when: pass)
    if es(epoch, val_loss, optimizer, model): pass

    # Update progress bars
    pbar_epoch.set_postfix(train_loss=train_loss, val_loss=val_loss)
    pbar_epoch.update(1)
    pbar_train.reset()

Epoch 1:   0%|          | 0/5 [00:00<?, ?epoch/s]

Epoch 1/5
image length and lable length: 240 240
image shape: (768, 768, 90) (768, 768, 90)
label shape: (768, 768, 90)


  return (image - np.min(image)) * (max_val - min_val) / (np.max(image) - np.min(image)) + min_val


RuntimeError: Input type (float) and bias type (double) should be the same

In [7]:
import os

# Define directories
Test_image_dir = 'imagesTr'
Test_label_dir = 'labelsTr'

# Get list of file names in each directory
image_files = set(os.listdir(Test_image_dir))
label_files = set(os.listdir(Test_label_dir))

# Find files in imagesTr but not in labelsTr
extra_image_files = image_files - label_files

# Print the results
if extra_image_files:
    print("Files in 'imagesTr' but not in 'labelsTr':")
    for file in extra_image_files:
        print(file)
else:
    print("No extra files in 'imagesTr'.")


Files in 'imagesTr' but not in 'labelsTr':
.DS_Store


In [6]:
# load model

model      = Unet(channels = channels, no_classes = 1).double().to(device)
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Make loss
criterion = engine.WeightedBCEWithLogitsLoss(batch_size = act_batch_size)
criterion = nn.BCEWithLogitsLoss()

with torch.no_grad():

    for batch_id, (X, y, weights) in enumerate(test_loader):

        # Forward
        y_hat = model(X)
        y_hat = torch.sigmoid(y_hat)


        # Convert to numpy
        X = np.squeeze(X.cpu().numpy())
        y = np.squeeze(y.cpu().numpy())
        w = np.squeeze(weights.cpu().numpy())
        y_hat = np.squeeze(y_hat.detach().cpu().numpy())

        # Make mask
        y_hat2 = y_hat > 0.5

        # plot
        fig, ax = plt.subplots(nrows = 1, ncols = 2, figsize = (8, 8))

        ax[0].imshow(y, 'gray', interpolation = None)
        ax[0].axis('off');
        ax[0].set_title('Target');

        ax[1].imshow(y_hat, 'gray', interpolation = None)
        ax[1].axis('off');
        ax[1].set_title('Prediction');

  checkpoint = torch.load(model_path)


FileNotFoundError: [Errno 2] No such file or directory: './model.pt'