In [51]:
#Libraries
import torch
from torch import nn
from utilities import UpBlock, DownBlock, DoubleConv, GenDLoss
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import BinaryRecall, BinaryPrecision, BinaryF1Score
import numpy as np
import os
import cv2
from tqdm import tqdm
import copy
#Custom Libraries
from resize_image import resize_image

In [52]:
#Set seeds
def seed_everything(seed: int = 42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = False
    
seed_everything(42)

In [106]:
#Make Dataset Class
#Class can load any mask as long as the model corresponds to the mask type
class TrainingDataset(Dataset):
    def __init__(self, images, labels, masks=None, augmentation=None, data_size=(512, 512), train=True):
        self.image_paths = [os.path.join(images, img) for img in os.listdir(images)]
        self.label_paths = [os.path.join(labels, lbl) for lbl in os.listdir(labels)]
        self.mask_paths = [os.path.join(masks, mask) for mask in os.listdir(masks)] if masks else None
        self.augmentation = augmentation
        self.data_size = data_size
        self.train = train

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):

        #Apply resizing with padding if image is not expected size
        raw_image = cv2.imread(self.image_paths[idx], cv2.IMREAD_COLOR)

        if (raw_image.shape[0] != self.data_size[0]) or (raw_image.shape[1] != self.data_size[1]): 
            raw_image = resize_image(raw_image, self.data_size[0], self.data_size[1], (0,0,0))
        
        #Read image, label, and mask
        image = cv2.cvtColor(np.array(raw_image), cv2.COLOR_BGR2GRAY)
        label = cv2.imread(self.label_paths[idx], cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE) if self.mask_paths else None

        #Convert mask/label to binary for model classification
        label[label > 0] = 1
        if mask is not None:
            mask[mask > 0] = 1
        
        #Add augmentation clause later
        #Add entity recognition clause later if needed
        
        #Convert to tensors
        image = ToTensor()(image).float()
        label = torch.from_numpy(label).long()
        if mask is not None:
            mask = torch.from_numpy(mask).long()

        return image, label, mask

In [105]:
train = TrainingDataset(
    images="/home/tommytang111/data/sem_adult/SEM_split/s250-259/imgs",
    labels="/home/tommytang111/data/sem_adult/SEM_split/s250-259/gts",
    train=True
)
images="/home/tommytang111/data/sem_adult/SEM_split/s250-259/imgs"
a = [os.path.join(images, img) for img in os.listdir(images) if img.lower().endswith('.png')]
b= cv2.imread(a[0], cv2.IMREAD_COLOR)
data_size = (512, 512)

if (b.shape[-1] != data_size[0]) or (b.shape[-2] != data_size[1]):
    print('hi')
b.shape

hi


(512, 512, 3)

In [62]:
#Initialize and load datasets
train = TrainingDataset(
    images="/home/tommytang111/data/sem_adult/SEM_split/s250-259/imgs",
    labels="/home/tommytang111/data/sem_adult/SEM_split/s250-259/gts",
    train=True
)

valid = TrainingDataset(
    images="/home/tommytang111/data/sem_adult/SEM_split/s200-209/imgs",
    labels="/home/tommytang111/data/sem_adult/SEM_split/s200-209/gts",
    train=False
)   

train_dataloader = DataLoader(train, batch_size=8, shuffle=True, num_workers=8, pin_memory=True)
valid_dataloader = DataLoader(valid, batch_size=8, shuffle=False, num_workers=8, pin_memory=True)

In [63]:
#Initialize model and send to gpu
class UNet(nn.Module):
    """UNet Architecture"""
    def __init__(self, out_classes=2, up_sample_mode='conv_transpose', three=False, attend=False, residual=False, scale=False, spatial=False, dropout=0, classes=2):
        """Initialize the UNet model"""
        super(UNet, self).__init__()
        self.three = three
        self.up_sample_mode = up_sample_mode
        self.dropout=dropout

        # Downsampling Path
        self.down_conv1 = DownBlock(1, 64, three=three, spatial=False, residual=residual) # 3 input channels --> 64 output channels
        self.down_conv2 = DownBlock(64, 128, three=three, spatial=spatial, dropout=self.dropout, residual=residual) # 64 input channels --> 128 output channels
        self.down_conv3 = DownBlock(128, 256, spatial=spatial, dropout=self.dropout, residual=residual) # 128 input channels --> 256 output channels
        self.down_conv4 = DownBlock(256, 512, spatial=spatial, dropout=self.dropout, residual=residual) # 256 input channels --> 512 output channels
        # Bottleneck
        self.double_conv = DoubleConv(512, 1024,spatial=spatial, dropout=self.dropout, residual=residual)
        # Upsampling Path
        self.up_conv4 = UpBlock(512 + 1024, 512, self.up_sample_mode, dropout=self.dropout, residual=residual) # 512 + 1024 input channels --> 512 output channels
        self.up_conv3 = UpBlock(256 + 512, 256, self.up_sample_mode, dropout=self.dropout, residual=residual)
        self.up_conv2 = UpBlock(128+ 256, 128, self.up_sample_mode, dropout=self.dropout, residual=residual)
        self.up_conv1 = UpBlock(128 + 64, 64, self.up_sample_mode)
        # Final Convolution
        self.conv_last = nn.Conv2d(64, 1 if classes == 2 else classes, kernel_size=1)
        self.attend = attend
        if scale:
            self.s1, self.s2 = torch.nn.Parameter(torch.ones(1), requires_grad=True), torch.nn.Parameter(torch.ones(1), requires_grad=True) # learn scaling


    def forward(self, x):
        """Forward pass of the UNet model
        x: (16, 1, 512, 512)
        """
        # print(x.shape)
        x, skip1_out = self.down_conv1(x) # x: (16, 64, 256, 256), skip1_out: (16, 64, 512, 512) (batch_size, channels, height, width)    
        x, skip2_out = self.down_conv2(x) # x: (16, 128, 128, 128), skip2_out: (16, 128, 256, 256)
        if self.three: x = x.squeeze(-3)   
        x, skip3_out = self.down_conv3(x) # x: (16, 256, 64, 64), skip3_out: (16, 256, 128, 128)
        x, skip4_out = self.down_conv4(x) # x: (16, 512, 32, 32), skip4_out: (16, 512, 64, 64)
        x = self.double_conv(x) # x: (16, 1024, 32, 32)
        x = self.up_conv4(x, skip4_out) # x: (16, 512, 64, 64)
        x = self.up_conv3(x, skip3_out) # x: (16, 256, 128, 128)
        if self.three: 
            #attention_mode???
            skip1_out = torch.mean(skip1_out, dim=2)
            skip2_out = torch.mean(skip2_out, dim=2)
        x = self.up_conv2(x, skip2_out) # x: (16, 128, 256, 256)
        x = self.up_conv1(x, skip1_out) # x: (16, 64, 512, 512)
        x = self.conv_last(x) # x: (16, 1, 512, 512)
        return x
    
device = torch.device("cuda")    
model = UNet().to(device)

In [64]:
#Initialize loss function and optimizer
loss_fn = GenDLoss()
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6)

In [65]:
#Send evaluation metrics to device
recall = BinaryRecall().to(device)
precision = BinaryPrecision().to(device)
f1 = BinaryF1Score().to(device)

In [66]:
#Define training function
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    train_loss = 0
    num_batches = len(dataloader)
    
    #Reset metrics for each epoch
    recall.reset()
    precision.reset()
    f1.reset()
    
    for batch, (X, y, _) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        #Calculate metrics after converting predictions to binary
        pred_binary = torch.sigmoid(pred) > 0.5
        
        #Update metrics
        recall.update(pred_binary, y)
        precision.update(pred_binary, y)
        f1.update(pred_binary, y)
        
        train_loss += loss.item()
        
    #Compute final metrics per epoch
    train_recall = recall.compute().item()
    train_precision = precision.compute().item()
    train_f1 = f1.compute().item()
    train_loss_per_epoch = train_loss / num_batches 
    
    return train_loss_per_epoch, train_recall.item(), train_precision.item(), train_f1.item()


In [67]:
#Define validation function
def validate(dataloader, model, loss_fn):
    model.eval()
    test_loss = 0
    num_batches = len(dataloader)
    
    #Reset metrics for each epoch
    recall.reset()
    precision.reset()
    f1.reset()
    
    with torch.no_grad():
        for X, y, _ in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            
        #Calculate metrics
        pred_binary = torch.sigmoid(pred) > 0.5
        
        #Update metrics
        recall.update(pred_binary, y)
        precision.update(pred_binary, y)
        f1.update(pred_binary, y)
            
        #Compute final metrics per epoch
        val_recall = recall.compute().item()
        val_precision = precision.compute().item()
        val_f1 = f1.compute().item()
        val_loss_per_epoch = test_loss / num_batches

    return val_loss_per_epoch, val_recall.item(), val_precision.item(), val_f1.item()
    print(f"Avg loss: {test_loss:>7f}\n")
    

In [None]:
#Training loop
epochs = 100
best_f1 = 0.0
best_epoch = 0

for epoch in range(epochs):
    progress_bar = tqdm(total=len(train_dataloader), position=0, leave=True)
    print(f"Epoch {epoch+1}\n-------------------------------")
    
    #Training
    train_loss, train_recall, train_precision, train_f1 = train(train_dataloader, model, loss_fn, optimizer)

    #Validation
    val_loss, val_recall, val_precision, val_f1 = validate(valid_dataloader, model, loss_fn)

    #Update learning rate scheduler
    scheduler.step(val_loss)
    
    #Print metrics
    print(f"Train | Loss: {train_loss:.4f}, Recall: {train_recall:.4f}, Precision: {train_precision:.4f}, F1: {train_f1:.4f}")
    print(f"Val   | Loss: {val_loss:.4f}, Recall: {val_recall:.4f}, Precision: {val_precision:.4f}, F1: {val_f1:.4f}")
    
    #Log best model state based on F1 score
    if val_f1 > best_f1:
        best_f1 = val_f1
        best_epoch = epoch
        best_model_state = copy.deepcopy(model.state_dict())
    
print("Training Complete!")

  0%|          | 0/5 [06:05<?, ?it/s]

Epoch 1
-------------------------------





error: Caught error in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/tommytang111/.conda/envs/gap_junction/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/tommytang111/.conda/envs/gap_junction/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/tommytang111/.conda/envs/gap_junction/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_13856/1957242918.py", line 25, in __getitem__
    image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2GRAY)
cv2.error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'cvtColor'
> Overload resolution failed:
>  - src is not a numpy array, neither a scalar
>  - Expected Ptr<cv::UMat> for argument 'src'



Process Process-21:
Traceback (most recent call last):
  File "/home/tommytang111/.conda/envs/gap_junction/lib/python3.10/multiprocessing/process.py", line 317, in _bootstrap
    util._exit_function()
  File "/home/tommytang111/.conda/envs/gap_junction/lib/python3.10/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/tommytang111/.conda/envs/gap_junction/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/tommytang111/.conda/envs/gap_junction/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/tommytang111/.conda/envs/gap_junction/lib/python3.10/multiprocessing/queues.py", line 199, in _finalize_join
    thread.join()
  File "/home/tommytang111/.conda/envs/gap_junction/lib/python3.10/threading.py", line 1096, in join
    self._wait_for_tstate_lock()
  File "/home/tommytang111/.conda/envs/gap_junction/lib/python3.10/threadi

In [None]:
#Save the best logged model state
torch.save(best_model_state, "/home/tommytang111/gap-junction-segmentationmodels/unet_v1.pt")
print("Saved PyTorch Model to unet_v1.pt")