## Imports

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch import optim

In [None]:
from mscnet import *
from loss import *
from dataloaders import *
from metrics import *

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

## Network training

### STARE custom loss

In [None]:
# --- STARE, custom loss: training setup ---
n_epoch = 100  # the training loop runs epochs 0..n_epoch inclusive

# Mean loss per epoch, appended by train() and test().
train_losses = []
eval_losses = []

net_custom_loss_stare = MSCNet(n_channels = 1, n_classes = 1)
net_custom_loss_stare.to(device)

# Fixed-seed 15/5 split; the results cell repeats it with the same seed to
# recover the same held-out images. The subsets get dedicated names because
# `train` and `test` are rebound to functions later in this cell.
dataset = DataLoaderSTARE("data/stare")
train_set, test_set = torch.utils.data.random_split(dataset, [15, 5], generator = torch.Generator().manual_seed(123))

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"), drop_last=True)

optimizer = optim.Adam(net_custom_loss_stare.parameters(), lr = 1e-3)
# Scale the LR by 0.9 once the monitored value (validation loss, passed in by
# the training loop) has not improved for more than 2 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.9)


def train(epoch):
  """Run one training epoch of `net_custom_loss_stare` over `train_loader`.

  Uses the global `optimizer` and appends the epoch's mean per-batch loss to
  the global `train_losses` list.

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0
  net_custom_loss_stare.train()

  # Build the loss once per epoch instead of once per batch; the arguments
  # are CustomLoss hyperparameters from loss.py (kept identical in test()).
  criterion = CustomLoss(0.2, 0.5, 0.1, threshold=0.63)

  # The second mask is yielded by the loader but not used by this loss,
  # so it is not copied to the device.
  for (image, mask_seg1, _mask_seg2, mask_center) in tqdm(train_loader):
    image = image.to(device)
    mask_seg1 = mask_seg1.to(device)
    mask_center = mask_center.to(device)

    optimizer.zero_grad()
    output_top, output_bot = net_custom_loss_stare(image)

    # squeeze() removes singleton dims (batch_size is 1) before the loss.
    loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1), torch.squeeze(output_bot), torch.squeeze(mask_center))

    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  train_loss = running_loss / len(train_loader)
  train_losses.append(train_loss)
  
def test(epoch):
  """Evaluate `net_custom_loss_stare` on `test_loader`.

  Runs in eval mode without gradient tracking and appends the mean per-batch
  loss to the global `eval_losses` list (the training loop feeds the latest
  value to the LR scheduler).

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0

  net_custom_loss_stare.eval()

  # Same loss configuration as train(), built once per evaluation pass.
  criterion = CustomLoss(0.2, 0.5, 0.1, threshold=0.63)

  with torch.no_grad():
    for (image, mask_seg1, _mask_seg2, mask_center) in tqdm(test_loader):
      image = image.to(device)
      mask_seg1 = mask_seg1.to(device)
      mask_center = mask_center.to(device)

      output_top, output_bot = net_custom_loss_stare(image)

      loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1), torch.squeeze(output_bot), torch.squeeze(mask_center))

      running_loss += loss.item()

  test_loss = running_loss / len(test_loader)

  eval_losses.append(test_loss)



# Epochs 0..n_epoch inclusive: one training pass then one validation pass each.
for epoch in range(0, n_epoch + 1):
  train(epoch)
  test(epoch)
  # The plateau scheduler watches the most recent validation loss.
  scheduler.step(eval_losses[-1])
  print(f"Epoch {epoch+1}, training_loss = {train_losses[-1]:.6f}, val_loss = {eval_losses[-1]:.6f}")
  # Checkpoint on multiples of 50 into the working directory.
  # NOTE(review): the results cell loads from drive/MyDrive/wb/, so the file
  # must be copied there first.
  if not epoch % 50:
    torch.save(net_custom_loss_stare.state_dict(), f'stare_custom_loss_stare_{epoch}.pth')



### Results preparation (STARE, custom loss)

In [None]:
# --- STARE, custom loss: evaluate the epoch-100 checkpoint on the test split ---
# Per-image segmentation metrics; their means are printed at the end.
ses = []
sps = []
accs = []
aucs = []

net_custom_loss_stare = MSCNet(n_channels = 1, n_classes = 1)
net_custom_loss_stare.to(device)

# Same dataset, split sizes and seed as the training cell, so `test_set` holds
# the images that were held out during training. The subsets are not bound to
# `train`/`test`, which name the training-cell functions.
dataset = DataLoaderSTARE("data/stare")
train_set, test_set = torch.utils.data.random_split(dataset, [15, 5], generator = torch.Generator().manual_seed(123))

train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"), drop_last=True)

# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
net_custom_loss_stare.load_state_dict(torch.load("drive/MyDrive/wb/stare_custom_loss_stare_100.pth", map_location=device))
# A newly constructed module starts in training mode; switch to eval mode so
# any BatchNorm/Dropout layers in MSCNet use inference behaviour.
net_custom_loss_stare.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
  for (image, mask_seg1, mask_seg2, mask_center) in test_loader:
    image = image.to(device)
    output_top, output_bot = net_custom_loss_stare(image)

    # Raw network outputs (shown as logits below) with singleton dims removed.
    # NOTE(review): SE/SP/accuracy receive these raw values — confirm the
    # helpers in metrics.py apply the intended sigmoid/threshold.
    mask_top = torch.squeeze(output_top).cpu().numpy()
    mask_bot = torch.squeeze(output_bot).cpu().numpy()

    mask_seg1 = torch.squeeze(mask_seg1).cpu().numpy()
    mask_center = torch.squeeze(mask_center).cpu().numpy()

    print("--------------------------------------------")

    plt.hist(mask_top.flatten())
    plt.title("Histogram of segmentation mask")
    plt.show()

    # Histogram of the centerline head (bottom output).
    plt.hist(mask_bot.flatten())
    plt.title("Histogram of centerline extraction mask")
    plt.show()

    ses.append(SE(mask_top, mask_seg1))
    sps.append(SP(mask_top, mask_seg1))
    accs.append(accuracy(mask_top, mask_seg1))
    aucs.append(AUC(mask_top, mask_seg1))

    print("Metrics for segmentation: ")
    print("SE = ",ses[-1])
    print("SP = ",sps[-1])
    print("accuracy = ",accs[-1])
    print("AUC = ", aucs[-1])

    print("Metrics for centerline extraction: ")
    print("SE = ",SE(mask_bot, mask_center))
    print("SP = ",SP(mask_bot, mask_center))
    print("accuracy = ",accuracy(mask_bot, mask_center))
    print("AUC = ", AUC(mask_bot, mask_center))

    plt.figure(figsize=(16,9))
    plt.imshow(mask_top, cmap="gray")
    plt.title("Segmentation mask (logits)")
    plt.show()

    plt.figure(figsize=(16,9))
    plt.imshow(mask_bot, cmap="gray")
    plt.title("Centerline extraction mask (logits)")
    plt.show()

print("STARE, custom loss SE = {:.4f}".format(sum(ses)/len(ses)))
print("STARE, custom loss SP = {:.4f}".format(sum(sps)/len(sps)))
print("STARE, custom loss ACC = {:.4f}".format(sum(accs)/len(accs)))
print("STARE, custom loss AUC = {:.4f}".format(sum(aucs)/len(aucs)))




### STARE BCE loss

In [None]:
# --- STARE, BCE loss: training setup ---
n_epoch = 100  # the training loop runs epochs 0..n_epoch inclusive

# Mean loss per epoch, appended by train() and test().
train_losses = []
eval_losses = []

net_BCE_loss_stare = MSCNet(n_channels = 1, n_classes = 1)
net_BCE_loss_stare.to(device)

# Fixed-seed 15/5 split; the results cell repeats it with the same seed to
# recover the same held-out images. The subsets get dedicated names because
# `train` and `test` are rebound to functions later in this cell.
dataset = DataLoaderSTARE("data/stare")
train_set, test_set = torch.utils.data.random_split(dataset, [15, 5], generator = torch.Generator().manual_seed(123))

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"), drop_last=True)

optimizer = optim.Adam(net_BCE_loss_stare.parameters(), lr = 1e-3)
# Scale the LR by 0.9 once the monitored value (validation loss, passed in by
# the training loop) has not improved for more than 2 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.9)


def train(epoch):
  """Run one training epoch of `net_BCE_loss_stare` over `train_loader`.

  The loss is the sum of two BCE-with-logits terms: top output vs. mask_seg1
  and bottom output vs. mask_seg2. The epoch's mean per-batch loss is
  appended to the global `train_losses` list.

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0
  net_BCE_loss_stare.train()

  # Stateless loss module: build it once per epoch, not once per batch.
  criterion = nn.BCEWithLogitsLoss()

  # The centerline mask is not used by this loss, so it is not moved to the device.
  for (image, mask_seg1, mask_seg2, _mask_center) in tqdm(train_loader):
    image = image.to(device)
    mask_seg1 = mask_seg1.to(device)
    mask_seg2 = mask_seg2.to(device)

    optimizer.zero_grad()
    output_top, output_bot = net_BCE_loss_stare(image)

    loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1)) + criterion(torch.squeeze(output_bot), torch.squeeze(mask_seg2))

    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  train_loss = running_loss / len(train_loader)
  train_losses.append(train_loss)
  
def test(epoch):
  """Evaluate `net_BCE_loss_stare` on `test_loader`.

  Runs in eval mode without gradient tracking and appends the mean per-batch
  loss (same two-term BCE as train()) to the global `eval_losses` list.

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0

  net_BCE_loss_stare.eval()

  criterion = nn.BCEWithLogitsLoss()

  with torch.no_grad():
    for (image, mask_seg1, mask_seg2, _mask_center) in tqdm(test_loader):
      image = image.to(device)
      mask_seg1 = mask_seg1.to(device)
      mask_seg2 = mask_seg2.to(device)

      output_top, output_bot = net_BCE_loss_stare(image)

      loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1)) + criterion(torch.squeeze(output_bot), torch.squeeze(mask_seg2))

      running_loss += loss.item()

  test_loss = running_loss / len(test_loader)

  eval_losses.append(test_loss)



# Epochs 0..n_epoch inclusive: one training pass then one validation pass each.
for epoch in range(0, n_epoch + 1):
  train(epoch)
  test(epoch)
  # The plateau scheduler watches the most recent validation loss.
  scheduler.step(eval_losses[-1])
  print(f"Epoch {epoch+1}, training_loss = {train_losses[-1]:.6f}, val_loss = {eval_losses[-1]:.6f}")
  # Checkpoint on multiples of 50 into the working directory.
  # NOTE(review): the results cell loads from drive/MyDrive/wb/, so the file
  # must be copied there first.
  if not epoch % 50:
    torch.save(net_BCE_loss_stare.state_dict(), f'stare_BCE_loss_stare_{epoch}.pth')



### Results preparation (STARE, BCE loss)

In [None]:
# --- STARE, BCE loss: evaluate the epoch-100 checkpoint on the test split ---
# Per-image segmentation metrics; their means are printed at the end.
ses = []
sps = []
accs = []
aucs = []

net_BCE_loss_stare = MSCNet(n_channels = 1, n_classes = 1)
net_BCE_loss_stare.to(device)

# Same dataset, split sizes and seed as the training cell, so `test_set` holds
# the images that were held out during training. The subsets are not bound to
# `train`/`test`, which name the training-cell functions.
dataset = DataLoaderSTARE("data/stare")
train_set, test_set = torch.utils.data.random_split(dataset, [15, 5], generator = torch.Generator().manual_seed(123))

train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"), drop_last=True)

# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
net_BCE_loss_stare.load_state_dict(torch.load("drive/MyDrive/wb/stare_BCE_loss_stare_100.pth", map_location=device))
# A newly constructed module starts in training mode; switch to eval mode so
# any BatchNorm/Dropout layers in MSCNet use inference behaviour.
net_BCE_loss_stare.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
  for (image, mask_seg1, mask_seg2, mask_center) in test_loader:
    image = image.to(device)
    output_top, output_bot = net_BCE_loss_stare(image)

    # Raw network outputs (shown as logits below) with singleton dims removed.
    # NOTE(review): SE/SP/accuracy receive these raw values — confirm the
    # helpers in metrics.py apply the intended sigmoid/threshold.
    mask_top = torch.squeeze(output_top).cpu().numpy()
    mask_bot = torch.squeeze(output_bot).cpu().numpy()

    mask_seg1 = torch.squeeze(mask_seg1).cpu().numpy()
    mask_seg2 = torch.squeeze(mask_seg2).cpu().numpy()

    print("--------------------------------------------")

    plt.hist(mask_top.flatten())
    plt.title("Histogram of segmentation mask")
    plt.show()

    # Histogram of the negative-segmentation head (bottom output).
    plt.hist(mask_bot.flatten())
    plt.title("Histogram of negative segmentation mask")
    plt.show()

    ses.append(SE(mask_top, mask_seg1))
    sps.append(SP(mask_top, mask_seg1))
    accs.append(accuracy(mask_top, mask_seg1))
    aucs.append(AUC(mask_top, mask_seg1))

    print("Metrics for segmentation: ")
    print("SE = ",ses[-1])
    print("SP = ",sps[-1])
    print("accuracy = ",accs[-1])
    print("AUC = ", aucs[-1])

    print("Metrics for negative segmentation: ")
    print("SE = ",SE(mask_bot, mask_seg2))
    print("SP = ",SP(mask_bot, mask_seg2))
    print("accuracy = ",accuracy(mask_bot, mask_seg2))
    print("AUC = ", AUC(mask_bot, mask_seg2))

    plt.figure(figsize=(16,9))
    plt.imshow(mask_top, cmap="gray")
    plt.title("Segmentation mask (logits)")
    plt.show()

    plt.figure(figsize=(16,9))
    plt.imshow(mask_bot, cmap="gray")
    plt.title("Negative segmentation mask (logits)")
    plt.show()

print("STARE, BCE loss SE = {:.4f}".format(sum(ses)/len(ses)))
print("STARE, BCE loss SP = {:.4f}".format(sum(sps)/len(sps)))
print("STARE, BCE loss ACC = {:.4f}".format(sum(accs)/len(accs)))
print("STARE, BCE loss AUC = {:.4f}".format(sum(aucs)/len(aucs)))




### CHASE custom loss

In [None]:
# --- CHASE, custom loss: training setup ---
n_epoch = 100  # the training loop runs epochs 0..n_epoch inclusive

# Mean loss per epoch, appended by train() and test().
train_losses = []
eval_losses = []

net_custom_loss_chase = MSCNet(n_channels = 1, n_classes = 1, inner_sizes=(59, 114, 234, 474), output_size=960)
net_custom_loss_chase.to(device)

# Fixed-seed 21/7 split. The subsets get dedicated names because `train` and
# `test` are rebound to functions later in this cell.
dataset = DataLoaderCHASE("data/chase")
train_set, test_set = torch.utils.data.random_split(dataset, [21, 7], generator = torch.Generator().manual_seed(123))

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"), drop_last=True)

optimizer = optim.Adam(net_custom_loss_chase.parameters(), lr = 1e-3)
# Scale the LR by 0.9 once the monitored value (validation loss, passed in by
# the training loop) has not improved for more than 2 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.9)


def train(epoch):
  """Run one training epoch of `net_custom_loss_chase` over `train_loader`.

  Uses the global `optimizer` and appends the epoch's mean per-batch loss to
  the global `train_losses` list.

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0
  net_custom_loss_chase.train()

  # Build the loss once per epoch instead of once per batch; the arguments
  # are CustomLoss hyperparameters from loss.py (kept identical in test()).
  criterion = CustomLoss(0.2, 0.5, 0.1, threshold=0.63)

  # The second mask is yielded by the loader but not used by this loss,
  # so it is not copied to the device.
  for (image, mask_seg1, _mask_seg2, mask_center) in tqdm(train_loader):
    image = image.to(device)
    mask_seg1 = mask_seg1.to(device)
    mask_center = mask_center.to(device)

    optimizer.zero_grad()
    output_top, output_bot = net_custom_loss_chase(image)

    # squeeze() removes singleton dims (batch_size is 1) before the loss.
    loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1), torch.squeeze(output_bot), torch.squeeze(mask_center))

    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  train_loss = running_loss / len(train_loader)
  train_losses.append(train_loss)
  
def test(epoch):
  """Evaluate `net_custom_loss_chase` on `test_loader`.

  Runs in eval mode without gradient tracking and appends the mean per-batch
  loss to the global `eval_losses` list (the training loop feeds the latest
  value to the LR scheduler).

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0

  net_custom_loss_chase.eval()

  # Same loss configuration as train(), built once per evaluation pass.
  criterion = CustomLoss(0.2, 0.5, 0.1, threshold=0.63)

  with torch.no_grad():
    for (image, mask_seg1, _mask_seg2, mask_center) in tqdm(test_loader):
      image = image.to(device)
      mask_seg1 = mask_seg1.to(device)
      mask_center = mask_center.to(device)

      output_top, output_bot = net_custom_loss_chase(image)

      loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1), torch.squeeze(output_bot), torch.squeeze(mask_center))

      running_loss += loss.item()

  test_loss = running_loss / len(test_loader)

  eval_losses.append(test_loss)



# Epochs 0..n_epoch inclusive: one training pass then one validation pass each.
for epoch in range(0, n_epoch + 1):
  train(epoch)
  test(epoch)
  # The plateau scheduler watches the most recent validation loss.
  scheduler.step(eval_losses[-1])
  print(f"Epoch {epoch+1}, training_loss = {train_losses[-1]:.6f}, val_loss = {eval_losses[-1]:.6f}")
  # Checkpoint on multiples of 50 into the working directory.
  if not epoch % 50:
    torch.save(net_custom_loss_chase.state_dict(), f'custom_loss_chase_{epoch}.pth')



### Results preparation (CHASE, custom loss)

In [None]:
# --- CHASE, custom loss: evaluate the in-memory network on the test split ---
# Uses `net_custom_loss_chase` and `test_loader` from the CHASE custom-loss
# training cell. Metrics are reported on the held-out images, not on the
# images the network was trained on.
ses = []
sps = []
accs = []
aucs = []

# Make inference behaviour explicit regardless of the module's current mode.
net_custom_loss_chase.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
  for (image, mask_seg1, mask_seg2, mask_center) in test_loader:
    image = image.to(device)
    output_top, output_bot = net_custom_loss_chase(image)

    # Raw network outputs (shown as logits below) with singleton dims removed.
    # NOTE(review): SE/SP/accuracy receive these raw values — confirm the
    # helpers in metrics.py apply the intended sigmoid/threshold.
    mask_top = torch.squeeze(output_top).cpu().numpy()
    mask_bot = torch.squeeze(output_bot).cpu().numpy()

    mask_seg1 = torch.squeeze(mask_seg1).cpu().numpy()
    mask_center = torch.squeeze(mask_center).cpu().numpy()

    print("--------------------------------------------")

    plt.hist(mask_top.flatten())
    plt.title("Histogram of segmentation mask")
    plt.show()

    # Histogram of the centerline head (bottom output).
    plt.hist(mask_bot.flatten())
    plt.title("Histogram of centerline extraction mask")
    plt.show()

    ses.append(SE(mask_top, mask_seg1))
    sps.append(SP(mask_top, mask_seg1))
    accs.append(accuracy(mask_top, mask_seg1))
    aucs.append(AUC(mask_top, mask_seg1))

    print("Metrics for segmentation: ")
    print("SE = ",ses[-1])
    print("SP = ",sps[-1])
    print("accuracy = ",accs[-1])
    print("AUC = ", aucs[-1])

    print("Metrics for centerline extraction: ")
    print("SE = ",SE(mask_bot, mask_center))
    print("SP = ",SP(mask_bot, mask_center))
    print("accuracy = ",accuracy(mask_bot, mask_center))
    print("AUC = ", AUC(mask_bot, mask_center))

    plt.figure(figsize=(16,9))
    plt.imshow(mask_top, cmap="gray")
    plt.title("Segmentation mask (logits)")
    plt.show()

    plt.figure(figsize=(16,9))
    plt.imshow(mask_bot, cmap="gray")
    plt.title("Centerline extraction mask (logits)")
    plt.show()

print("CHASE, custom loss SE = {:.4f}".format(sum(ses)/len(ses)))
print("CHASE, custom loss SP = {:.4f}".format(sum(sps)/len(sps)))
print("CHASE, custom loss ACC = {:.4f}".format(sum(accs)/len(accs)))
print("CHASE, custom loss AUC = {:.4f}".format(sum(aucs)/len(aucs)))





### CHASE BCE loss

In [None]:
# --- CHASE, BCE loss: training setup ---
n_epoch = 100  # the training loop runs epochs 0..n_epoch inclusive

# Mean loss per epoch, appended by train() and test().
train_losses = []
eval_losses = []

net_BCE_loss_chase = MSCNet(n_channels = 1, n_classes = 1, inner_sizes=(59, 114, 234, 474), output_size=960)
net_BCE_loss_chase.to(device)

# Fixed-seed 21/7 split. The subsets get dedicated names because `train` and
# `test` are rebound to functions later in this cell.
dataset = DataLoaderCHASE("data/chase")
train_set, test_set = torch.utils.data.random_split(dataset, [21, 7], generator = torch.Generator().manual_seed(123))

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"), drop_last=True)

optimizer = optim.Adam(net_BCE_loss_chase.parameters(), lr = 1e-3)
# Scale the LR by 0.9 once the monitored value (validation loss, passed in by
# the training loop) has not improved for more than 2 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.9)


def train(epoch):
  """Run one training epoch of `net_BCE_loss_chase` over `train_loader`.

  The loss is the sum of two BCE-with-logits terms: top output vs. mask_seg1
  and bottom output vs. mask_seg2. The epoch's mean per-batch loss is
  appended to the global `train_losses` list.

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0
  net_BCE_loss_chase.train()

  # Stateless loss module: build it once per epoch, not once per batch.
  criterion = nn.BCEWithLogitsLoss()

  # The centerline mask is not used by this loss, so it is not moved to the device.
  for (image, mask_seg1, mask_seg2, _mask_center) in tqdm(train_loader):
    image = image.to(device)
    mask_seg1 = mask_seg1.to(device)
    mask_seg2 = mask_seg2.to(device)

    optimizer.zero_grad()
    output_top, output_bot = net_BCE_loss_chase(image)

    loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1)) + criterion(torch.squeeze(output_bot), torch.squeeze(mask_seg2))

    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  train_loss = running_loss / len(train_loader)
  train_losses.append(train_loss)
  
def test(epoch):
  """Evaluate `net_BCE_loss_chase` on `test_loader`.

  Runs in eval mode without gradient tracking and appends the mean per-batch
  loss (same two-term BCE as train()) to the global `eval_losses` list.

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0

  net_BCE_loss_chase.eval()

  criterion = nn.BCEWithLogitsLoss()

  with torch.no_grad():
    for (image, mask_seg1, mask_seg2, _mask_center) in tqdm(test_loader):
      image = image.to(device)
      mask_seg1 = mask_seg1.to(device)
      mask_seg2 = mask_seg2.to(device)

      output_top, output_bot = net_BCE_loss_chase(image)

      loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1)) + criterion(torch.squeeze(output_bot), torch.squeeze(mask_seg2))

      running_loss += loss.item()

  test_loss = running_loss / len(test_loader)

  eval_losses.append(test_loss)



# Epochs 0..n_epoch inclusive: one training pass then one validation pass each.
for epoch in range(0, n_epoch + 1):
  train(epoch)
  test(epoch)
  # The plateau scheduler watches the most recent validation loss.
  scheduler.step(eval_losses[-1])
  print(f"Epoch {epoch+1}, training_loss = {train_losses[-1]:.6f}, val_loss = {eval_losses[-1]:.6f}")
  # Checkpoint on multiples of 50 into the working directory.
  if not epoch % 50:
    torch.save(net_BCE_loss_chase.state_dict(), f'BCE_loss_chase_{epoch}.pth')



### Results preparation (CHASE, BCE loss)

In [None]:
# --- CHASE, BCE loss: evaluate the in-memory network on the test split ---
# Uses `net_BCE_loss_chase` and `test_loader` from the CHASE BCE training
# cell. Metrics are reported on the held-out images, not on the images the
# network was trained on.
ses = []
sps = []
accs = []
aucs = []

# Make inference behaviour explicit regardless of the module's current mode.
net_BCE_loss_chase.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
  for (image, mask_seg1, mask_seg2, mask_center) in test_loader:
    image = image.to(device)
    output_top, output_bot = net_BCE_loss_chase(image)

    # Raw network outputs (shown as logits below) with singleton dims removed.
    # NOTE(review): SE/SP/accuracy receive these raw values — confirm the
    # helpers in metrics.py apply the intended sigmoid/threshold.
    mask_top = torch.squeeze(output_top).cpu().numpy()
    mask_bot = torch.squeeze(output_bot).cpu().numpy()

    mask_seg1 = torch.squeeze(mask_seg1).cpu().numpy()
    mask_seg2 = torch.squeeze(mask_seg2).cpu().numpy()

    print("--------------------------------------------")

    plt.hist(mask_top.flatten())
    plt.title("Histogram of segmentation mask")
    plt.show()

    # Histogram of the negative-segmentation head (bottom output).
    plt.hist(mask_bot.flatten())
    plt.title("Histogram of negative segmentation mask")
    plt.show()

    ses.append(SE(mask_top, mask_seg1))
    sps.append(SP(mask_top, mask_seg1))
    accs.append(accuracy(mask_top, mask_seg1))
    aucs.append(AUC(mask_top, mask_seg1))

    print("Metrics for segmentation: ")
    print("SE = ",ses[-1])
    print("SP = ",sps[-1])
    print("accuracy = ",accs[-1])
    print("AUC = ", aucs[-1])

    print("Metrics for negative segmentation: ")
    print("SE = ",SE(mask_bot, mask_seg2))
    print("SP = ",SP(mask_bot, mask_seg2))
    print("accuracy = ",accuracy(mask_bot, mask_seg2))
    print("AUC = ", AUC(mask_bot, mask_seg2))

    plt.figure(figsize=(16,9))
    plt.imshow(mask_top, cmap="gray")
    plt.title("Segmentation mask (logits)")
    plt.show()

    plt.figure(figsize=(16,9))
    plt.imshow(mask_bot, cmap="gray")
    plt.title("Negative segmentation mask (logits)")
    plt.show()

print("CHASE, BCE loss SE = {:.4f}".format(sum(ses)/len(ses)))
print("CHASE, BCE loss SP = {:.4f}".format(sum(sps)/len(sps)))
print("CHASE, BCE loss ACC = {:.4f}".format(sum(accs)/len(accs)))
print("CHASE, BCE loss AUC = {:.4f}".format(sum(aucs)/len(aucs)))





### DRIVE custom loss

In [None]:
# --- DRIVE, custom loss: training setup ---
n_epoch = 100  # the training loop runs epochs 0..n_epoch inclusive

# Mean loss per epoch, appended by train() and test().
train_losses = []
eval_losses = []

net_custom_loss_drive = MSCNet(n_channels = 1, n_classes = 1, inner_sizes=(35, 66, 138, 282), output_size=576)
net_custom_loss_drive.to(device)

# Fixed-seed 15/5 split; the results cell repeats it with the same seed to
# recover the same held-out images. The subsets get dedicated names because
# `train` and `test` are rebound to functions later in this cell.
dataset = DataLoaderDRIVE("data/drive")
train_set, test_set = torch.utils.data.random_split(dataset, [15, 5], generator = torch.Generator().manual_seed(123))

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"), drop_last=True)

optimizer = optim.Adam(net_custom_loss_drive.parameters(), lr = 1e-3)
# Scale the LR by 0.9 once the monitored value (validation loss, passed in by
# the training loop) has not improved for more than 2 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.9)


def train(epoch):
  """Run one training epoch of `net_custom_loss_drive` over `train_loader`.

  Uses the global `optimizer` and appends the epoch's mean per-batch loss to
  the global `train_losses` list.

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0
  net_custom_loss_drive.train()

  # Build the loss once per epoch instead of once per batch; the arguments
  # are CustomLoss hyperparameters from loss.py (kept identical in test()).
  criterion = CustomLoss(0.2, 0.5, 0.1, threshold=0.63)

  # The second mask is yielded by the loader but not used by this loss,
  # so it is not copied to the device.
  for (image, mask_seg1, _mask_seg2, mask_center) in tqdm(train_loader):
    image = image.to(device)
    mask_seg1 = mask_seg1.to(device)
    mask_center = mask_center.to(device)

    optimizer.zero_grad()
    output_top, output_bot = net_custom_loss_drive(image)

    # squeeze() removes singleton dims (batch_size is 1) before the loss.
    loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1), torch.squeeze(output_bot), torch.squeeze(mask_center))

    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  train_loss = running_loss / len(train_loader)
  train_losses.append(train_loss)
  
def test(epoch):
  """Evaluate `net_custom_loss_drive` on `test_loader`.

  Runs in eval mode without gradient tracking and appends the mean per-batch
  loss to the global `eval_losses` list (the training loop feeds the latest
  value to the LR scheduler).

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0

  net_custom_loss_drive.eval()

  # Same loss configuration as train(), built once per evaluation pass.
  criterion = CustomLoss(0.2, 0.5, 0.1, threshold=0.63)

  with torch.no_grad():
    for (image, mask_seg1, _mask_seg2, mask_center) in tqdm(test_loader):
      image = image.to(device)
      mask_seg1 = mask_seg1.to(device)
      mask_center = mask_center.to(device)

      output_top, output_bot = net_custom_loss_drive(image)

      loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1), torch.squeeze(output_bot), torch.squeeze(mask_center))

      running_loss += loss.item()

  test_loss = running_loss / len(test_loader)

  eval_losses.append(test_loss)



# Epochs 0..n_epoch inclusive: one training pass then one validation pass each.
for epoch in range(0, n_epoch + 1):
  train(epoch)
  test(epoch)
  # The plateau scheduler watches the most recent validation loss.
  scheduler.step(eval_losses[-1])
  print(f"Epoch {epoch+1}, training_loss = {train_losses[-1]:.6f}, val_loss = {eval_losses[-1]:.6f}")
  # Checkpoint on multiples of 50 into the working directory.
  # NOTE(review): the results cell loads from drive/MyDrive/wb/, so the file
  # must be copied there first.
  if not epoch % 50:
    torch.save(net_custom_loss_drive.state_dict(), f'custom_loss_drive_{epoch}.pth')



### Results preparation (DRIVE, custom loss)

In [None]:
# --- DRIVE, custom loss: evaluate the epoch-100 checkpoint on the test split ---
# Per-image segmentation metrics; their means are printed at the end.
ses = []
sps = []
accs = []
aucs = []

net_custom_loss_drive = MSCNet(n_channels = 1, n_classes = 1, inner_sizes=(35, 66, 138, 282), output_size=576)
net_custom_loss_drive.to(device)

# Same dataset, split sizes and seed as the training cell, so `test_set` holds
# the images that were held out during training. The subsets are not bound to
# `train`/`test`, which name the training-cell functions.
dataset = DataLoaderDRIVE("data/drive")
train_set, test_set = torch.utils.data.random_split(dataset, [15, 5], generator = torch.Generator().manual_seed(123))

train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"), drop_last=True)

# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
net_custom_loss_drive.load_state_dict(torch.load("drive/MyDrive/wb/custom_loss_drive_100.pth", map_location=device))
# A newly constructed module starts in training mode; switch to eval mode so
# any BatchNorm/Dropout layers in MSCNet use inference behaviour.
net_custom_loss_drive.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
  for (image, mask_seg1, mask_seg2, mask_center) in test_loader:
    image = image.to(device)
    output_top, output_bot = net_custom_loss_drive(image)

    # Raw network outputs (shown as logits below) with singleton dims removed.
    # NOTE(review): SE/SP/accuracy receive these raw values — confirm the
    # helpers in metrics.py apply the intended sigmoid/threshold.
    mask_top = torch.squeeze(output_top).cpu().numpy()
    mask_bot = torch.squeeze(output_bot).cpu().numpy()

    mask_seg1 = torch.squeeze(mask_seg1).cpu().numpy()
    mask_center = torch.squeeze(mask_center).cpu().numpy()

    print("--------------------------------------------")

    plt.hist(mask_top.flatten())
    plt.title("Histogram of segmentation mask")
    plt.show()

    plt.hist(mask_bot.flatten())
    plt.title("Histogram of centerline extraction mask")
    plt.show()

    ses.append(SE(mask_top, mask_seg1))
    sps.append(SP(mask_top, mask_seg1))
    accs.append(accuracy(mask_top, mask_seg1))
    aucs.append(AUC(mask_top, mask_seg1))

    print("Metrics for segmentation: ")
    print("SE = ",ses[-1])
    print("SP = ",sps[-1])
    print("accuracy = ",accs[-1])
    print("AUC = ", aucs[-1])

    print("Metrics for centerline extraction: ")
    print("SE = ",SE(mask_bot, mask_center))
    print("SP = ",SP(mask_bot, mask_center))
    print("accuracy = ",accuracy(mask_bot, mask_center))
    print("AUC = ", AUC(mask_bot, mask_center))

    plt.figure(figsize=(16,9))
    plt.imshow(mask_top, cmap="gray")
    plt.title("Segmentation mask (logits)")
    plt.show()

    plt.figure(figsize=(16,9))
    plt.imshow(mask_bot, cmap="gray")
    plt.title("Centerline extraction mask (logits)")
    plt.show()

print("DRIVE, custom loss SE = {:.4f}".format(sum(ses)/len(ses)))
print("DRIVE, custom loss SP = {:.4f}".format(sum(sps)/len(sps)))
print("DRIVE, custom loss ACC = {:.4f}".format(sum(accs)/len(accs)))
print("DRIVE, custom loss AUC = {:.4f}".format(sum(aucs)/len(aucs)))




### DRIVE BCE loss

In [None]:
# --- DRIVE, BCE loss: training setup ---
n_epoch = 100  # the training loop runs epochs 0..n_epoch inclusive

# Mean loss per epoch, appended by train() and test().
train_losses = []
eval_losses = []

net_BCE_loss_drive = MSCNet(n_channels = 1, n_classes = 1, inner_sizes=(35, 66, 138, 282), output_size=576)
net_BCE_loss_drive.to(device)

# Fixed-seed 15/5 split; the results cell repeats it with the same seed to
# recover the same held-out images. The subsets get dedicated names because
# `train` and `test` are rebound to functions later in this cell.
dataset = DataLoaderDRIVE("data/drive")
train_set, test_set = torch.utils.data.random_split(dataset, [15, 5], generator = torch.Generator().manual_seed(123))

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"), drop_last=True)

optimizer = optim.Adam(net_BCE_loss_drive.parameters(), lr = 1e-3)
# Scale the LR by 0.9 once the monitored value (validation loss, passed in by
# the training loop) has not improved for more than 2 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.9)


def train(epoch):
  """Run one training epoch of `net_BCE_loss_drive` over `train_loader`.

  The loss is the sum of two BCE-with-logits terms: top output vs. mask_seg1
  and bottom output vs. mask_seg2. The epoch's mean per-batch loss is
  appended to the global `train_losses` list.

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0
  net_BCE_loss_drive.train()

  # Stateless loss module: build it once per epoch, not once per batch.
  criterion = nn.BCEWithLogitsLoss()

  # The centerline mask is not used by this loss, so it is not moved to the device.
  for (image, mask_seg1, mask_seg2, _mask_center) in tqdm(train_loader):
    image = image.to(device)
    mask_seg1 = mask_seg1.to(device)
    mask_seg2 = mask_seg2.to(device)

    optimizer.zero_grad()
    output_top, output_bot = net_BCE_loss_drive(image)

    loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1)) + criterion(torch.squeeze(output_bot), torch.squeeze(mask_seg2))

    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  train_loss = running_loss / len(train_loader)
  train_losses.append(train_loss)
  
def test(epoch):
  """Evaluate `net_BCE_loss_drive` on `test_loader`.

  Runs in eval mode without gradient tracking and appends the mean per-batch
  loss (same two-term BCE as train()) to the global `eval_losses` list.

  Args:
    epoch: epoch index supplied by the training loop (not used here).
  """
  running_loss = 0.0

  net_BCE_loss_drive.eval()

  criterion = nn.BCEWithLogitsLoss()

  with torch.no_grad():
    for (image, mask_seg1, mask_seg2, _mask_center) in tqdm(test_loader):
      image = image.to(device)
      mask_seg1 = mask_seg1.to(device)
      mask_seg2 = mask_seg2.to(device)

      output_top, output_bot = net_BCE_loss_drive(image)

      loss = criterion(torch.squeeze(output_top), torch.squeeze(mask_seg1)) + criterion(torch.squeeze(output_bot), torch.squeeze(mask_seg2))

      running_loss += loss.item()

  test_loss = running_loss / len(test_loader)

  eval_losses.append(test_loss)



# Epochs 0..n_epoch inclusive: one training pass then one validation pass each.
for epoch in range(0, n_epoch + 1):
  train(epoch)
  test(epoch)
  # The plateau scheduler watches the most recent validation loss.
  scheduler.step(eval_losses[-1])
  print(f"Epoch {epoch+1}, training_loss = {train_losses[-1]:.6f}, val_loss = {eval_losses[-1]:.6f}")
  # Checkpoint on multiples of 50 into the working directory.
  # NOTE(review): the results cell loads from drive/MyDrive/wb/, so the file
  # must be copied there first.
  if not epoch % 50:
    torch.save(net_BCE_loss_drive.state_dict(), f'BCE_loss_drive_{epoch}.pth')



### Results preparation (DRIVE, BCE loss)

In [None]:
# --- DRIVE, BCE loss: evaluate the epoch-100 checkpoint on the test split ---
# Per-image segmentation metrics; their means are printed at the end.
ses = []
sps = []
accs = []
aucs = []

net_BCE_loss_drive = MSCNet(n_channels = 1, n_classes = 1, inner_sizes=(35, 66, 138, 282), output_size=576)
net_BCE_loss_drive.to(device)

# Same dataset, split sizes and seed as the training cell, so `test_set` holds
# the images that were held out during training. The subsets are not bound to
# `train`/`test`, which name the training-cell functions.
dataset = DataLoaderDRIVE("data/drive")
train_set, test_set = torch.utils.data.random_split(dataset, [15, 5], generator = torch.Generator().manual_seed(123))

train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"))
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=1, num_workers=0, pin_memory=(device.type == "cuda"), drop_last=True)

# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
net_BCE_loss_drive.load_state_dict(torch.load("drive/MyDrive/wb/BCE_loss_drive_100.pth", map_location=device))
# A newly constructed module starts in training mode; switch to eval mode so
# any BatchNorm/Dropout layers in MSCNet use inference behaviour.
net_BCE_loss_drive.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
  for (image, mask_seg1, mask_seg2, mask_center) in test_loader:
    image = image.to(device)
    output_top, output_bot = net_BCE_loss_drive(image)

    # Raw network outputs (shown as logits below) with singleton dims removed.
    # NOTE(review): SE/SP/accuracy receive these raw values — confirm the
    # helpers in metrics.py apply the intended sigmoid/threshold.
    mask_top = torch.squeeze(output_top).cpu().numpy()
    mask_bot = torch.squeeze(output_bot).cpu().numpy()

    mask_seg1 = torch.squeeze(mask_seg1).cpu().numpy()
    mask_seg2 = torch.squeeze(mask_seg2).cpu().numpy()

    print("--------------------------------------------")

    plt.hist(mask_top.flatten())
    plt.title("Histogram of segmentation mask")
    plt.show()

    # Histogram of the negative-segmentation head (bottom output).
    plt.hist(mask_bot.flatten())
    plt.title("Histogram of negative segmentation mask")
    plt.show()

    ses.append(SE(mask_top, mask_seg1))
    sps.append(SP(mask_top, mask_seg1))
    accs.append(accuracy(mask_top, mask_seg1))
    aucs.append(AUC(mask_top, mask_seg1))

    print("Metrics for segmentation: ")
    print("SE = ",ses[-1])
    print("SP = ",sps[-1])
    print("accuracy = ",accs[-1])
    print("AUC = ", aucs[-1])

    print("Metrics for negative segmentation: ")
    print("SE = ",SE(mask_bot, mask_seg2))
    print("SP = ",SP(mask_bot, mask_seg2))
    print("accuracy = ",accuracy(mask_bot, mask_seg2))
    print("AUC = ", AUC(mask_bot, mask_seg2))

    plt.figure(figsize=(16,9))
    plt.imshow(mask_top, cmap="gray")
    plt.title("Segmentation mask (logits)")
    plt.show()

    plt.figure(figsize=(16,9))
    plt.imshow(mask_bot, cmap="gray")
    plt.title("Negative segmentation mask (logits)")
    plt.show()

print("DRIVE, BCE loss SE = {:.4f}".format(sum(ses)/len(ses)))
print("DRIVE, BCE loss SP = {:.4f}".format(sum(sps)/len(sps)))
print("DRIVE, BCE loss ACC = {:.4f}".format(sum(accs)/len(accs)))
print("DRIVE, BCE loss AUC = {:.4f}".format(sum(aucs)/len(aucs)))
