#### Wandb Hyperparameter (Agent) Sweep

In [13]:
#Libraries
import copy
import os
import random
import sys
# Add parent directory to sys.path
parent_dir = '/home/tommytang111/gap-junction-segmentation/code/src'
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

import torch
from torch import nn
from utilities import UpBlock, DownBlock, DoubleConv, GenDLoss, FocalLoss
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torch.optim import AdamW, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import BinaryRecall, BinaryPrecision, BinaryF1Score
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import cv2
from tqdm.notebook import tqdm #Change to tqdm.tqdm if not using Jupyter Notebook
import wandb
#Custom Libraries
from resize_image import resize_image

#### Set Reproducible Seeds

In [17]:
def seed_everything(seed: int = 42):
    """Seed every random number generator this notebook draws from.

    Args:
        seed: Value applied to Python's ``random`` module, NumPy, and PyTorch
            (CPU generator plus the current and all CUDA devices).
    """
    # Python's built-in generator used to be left unseeded, so anything drawing
    # from it was not reproducible.
    # NOTE(review): some albumentations versions sample from `random` - confirm for the installed version.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The CUDA seeding calls are silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Restrict cuDNN to deterministic kernels.
    torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = False
    
def worker_init_fn(worker_id):
    """Seed NumPy and Python's ``random`` inside a DataLoader worker.

    The seed is derived from ``torch.initial_seed()``, which inside a worker is
    the per-iterator base seed plus the worker id. The previous implementation
    used the constant ``42 + worker_id``, so every worker restarted from the
    same state whenever it was (re)created.

    Args:
        worker_id: Index of the worker process. Kept for the DataLoader
            callback signature; the id is already folded into
            ``torch.initial_seed()``.
    """
    # NumPy seeds must fit in 32 bits.
    seed = torch.initial_seed() % 2**32
    random.seed(seed)
    np.random.seed(seed)
    
# Seed all RNGs up front with the fixed seed 42 (sweep() calls seed_everything(42) again at the start of each run)
seed_everything(42)

#### Define Augmentation Options

In [21]:
# Custom augmentation
def get_custom_augmentation():
    """Build the first custom training pipeline.

    Steps, in order: random flips, an affine warp, brightness/contrast jitter,
    Gaussian noise, normalisation, resize to 512x512 and tensor conversion.
    """
    geometric_steps = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Affine(scale=(0.9, 1.1), rotate=10, translate_percent=0.15, shear=(-5, 5), p=0.9),
    ]
    intensity_steps = [
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.2, p=0.5),
        A.GaussNoise(p=0.3),
    ]
    output_steps = [
        A.Normalize(mean=0.0, std=1.0),
        A.Resize(512, 512),
        ToTensorV2(),
    ]
    return A.Compose(geometric_steps + intensity_steps + output_steps)
    
def get_custom_augmentation2():
    """Build the second custom training pipeline.

    Same sequence of steps as ``get_custom_augmentation`` but with wider affine
    parameters (scale 0.8-1.2, rotate=360, shear +/-45) and a larger
    brightness limit (0.2).
    """
    geometric_steps = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Affine(scale=(0.8, 1.2), rotate=360, translate_percent=0.15, shear=(-45, 45), p=0.9),
    ]
    intensity_steps = [
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.GaussNoise(p=0.3),
    ]
    output_steps = [
        A.Normalize(mean=0.0, std=1.0),
        A.Resize(512, 512),
        ToTensorV2(),
    ]
    return A.Compose(geometric_steps + intensity_steps + output_steps)

# Light augmentation for gap junction segmentation
def get_light_augmentation():
    """Build the light pipeline: flips, 90-degree rotations, transpose, mild noise and blur."""
    orientation_steps = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
    ]
    degradation_steps = [
        A.GaussNoise(p=0.3),
        A.Blur(blur_limit=3, p=0.2),
    ]
    output_steps = [
        A.Normalize(mean=0.0, std=1.0),  # single mean/std pair, for grayscale input
        ToTensorV2(),
    ]
    return A.Compose(orientation_steps + degradation_steps + output_steps)

# Medium augmentation
def get_medium_augmentation():
    """Build the medium pipeline.

    Adds shift/scale/rotate, elastic deformation and CLAHE on top of the
    flips, noise and blur used by the light pipeline.
    """
    # Both warps use the same constant border mode and fill value.
    constant_border = dict(border_mode=cv2.BORDER_CONSTANT, value=0)

    orientation_steps = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
    ]
    warp_steps = [
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5, **constant_border),
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3, **constant_border),
    ]
    degradation_steps = [
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        A.Blur(blur_limit=3, p=0.2),
        A.CLAHE(clip_limit=2.0, p=0.3),
    ]
    output_steps = [
        A.Normalize(mean=0.0, std=1.0),
        ToTensorV2(),
    ]
    return A.Compose(orientation_steps + warp_steps + degradation_steps + output_steps)
    
# Heavy augmentation
def get_heavy_augmentation():
    """Build the heavy pipeline.

    Compared with the medium pipeline it uses larger shift/scale/rotate limits,
    adds grid and optical distortion, picks one of three blur types, and adds
    brightness/contrast jitter.
    """
    # Both warps use the same constant border mode and fill value.
    constant_border = dict(border_mode=cv2.BORDER_CONSTANT, value=0)

    orientation_steps = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
    ]
    warp_steps = [
        A.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.3, rotate_limit=25, p=0.6, **constant_border),
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.4, **constant_border),
        A.GridDistortion(p=0.3),
        A.OpticalDistortion(p=0.3),
    ]
    blur_choices = [
        A.Blur(blur_limit=3),
        A.GaussianBlur(blur_limit=3),
        A.MedianBlur(blur_limit=3),
    ]
    degradation_steps = [
        A.GaussNoise(var_limit=(10.0, 80.0), p=0.4),
        A.OneOf(blur_choices, p=0.3),
        A.CLAHE(clip_limit=2.0, p=0.4),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
    ]
    output_steps = [
        A.Normalize(mean=0.0, std=1.0),
        ToTensorV2(),
    ]
    return A.Compose(orientation_steps + warp_steps + degradation_steps + output_steps)

#### Make Dataset Class

In [22]:
#Class can load any mask as long as the model corresponds to the mask type
class TrainingDataset(Dataset):
    """Paired grayscale image / label (and optional mask) dataset read from directories.

    Files are paired by position after sorting each directory listing, so the
    image, label and mask directories must contain matching, identically
    ordered files.

    Args:
        images: Directory containing the input images.
        labels: Directory containing the ground-truth label images.
        masks: Optional directory containing an additional mask per image.
        augmentation: Optional albumentations pipeline. It is applied only when
            ``train`` is True; with ``train=False`` it is ignored and the
            built-in tensor conversion is used instead.
        data_size: Expected (height, width). Images of any other size are
            passed through ``resize_image`` together with their label/mask.
        train: Whether ``augmentation`` should be applied.

    Raises:
        ValueError: If the label (or mask) directory does not contain the same
            number of files as the image directory.
    """

    def __init__(self, images, labels, masks=None, augmentation=None, data_size=(512, 512), train=True):
        self.image_paths = self._list_dir(images)
        self.label_paths = self._list_dir(labels)
        self.mask_paths = self._list_dir(masks) if masks else None

        # Pairing is purely positional, so a count mismatch means images would be
        # matched with the wrong targets; fail loudly instead.
        if len(self.label_paths) != len(self.image_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in {images} but "
                f"{len(self.label_paths)} labels in {labels}"
            )
        if self.mask_paths is not None and len(self.mask_paths) != len(self.image_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images in {images} but "
                f"{len(self.mask_paths)} masks in {masks}"
            )

        self.augmentation = augmentation
        self.data_size = data_size
        self.train = train

    @staticmethod
    def _list_dir(directory):
        """Return the sorted full paths of every entry in ``directory``."""
        return sorted(os.path.join(directory, name) for name in os.listdir(directory))

    @staticmethod
    def _read_grayscale(path):
        """Read ``path`` as a single-channel image, raising if OpenCV cannot load it."""
        array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if array is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"Could not read image file: {path}")
        return array

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label, mask)`` tensors for sample ``idx``.

        ``label`` and ``mask`` are binarised (non-zero -> 1) and returned as
        int64 tensors. When no mask directory was given, ``mask`` is all zeros
        with the same shape as ``label``.
        """
        #Read image, label, and mask
        image = self._read_grayscale(self.image_paths[idx])
        label = self._read_grayscale(self.label_paths[idx])
        mask = self._read_grayscale(self.mask_paths[idx]) if self.mask_paths else None

        #Apply resizing with padding if image is not expected size and then convert back to ndarray
        if (image.shape[0] != self.data_size[0]) or (image.shape[1] != self.data_size[1]):
            image = np.array(resize_image(image, self.data_size[0], self.data_size[1], (0,0,0)))
            label = np.array(resize_image(label, self.data_size[0], self.data_size[1], (0,0,0)))
            if mask is not None:
                mask = np.array(resize_image(mask, self.data_size[0], self.data_size[1], (0,0,0)))

        #Convert mask/label to binary for model classification
        label[label > 0] = 1
        if mask is not None:
            mask[mask > 0] = 1

        #Apply augmentation if provided
        if self.augmentation and self.train:
            if mask is not None:
                # Send both targets through the `masks` list so they receive the same
                # spatial transform as the image. The earlier `label=mask` keyword is
                # not a registered albumentations target.
                augmented = self.augmentation(image=image, masks=[label, mask])
                image = augmented['image']
                label, mask = augmented['masks'][0], augmented['masks'][1]
            else:
                #Without mask
                augmented = self.augmentation(image=image, mask=label)
                image = augmented['image']
                label = augmented['mask']

        #Add entity recognition clause later if needed

        # Convert to tensors if not already converted from augmentation
        if not torch.is_tensor(image):
            image = ToTensor()(image).float()
        # Always return int64 targets so augmented and non-augmented samples share a dtype.
        label = torch.as_tensor(label).long()
        if mask is not None:
            mask = torch.as_tensor(mask).long()
        else:
            mask = torch.zeros_like(label)

        return image, label, mask

#### Set Augmentation

In [23]:
# Training pipeline - swap in get_medium_augmentation() or get_heavy_augmentation() as needed
train_augmentation = get_custom_augmentation()

# Validation pipeline - normalisation and tensor conversion only
valid_augmentation = A.Compose([A.Normalize(mean=0.0, std=1.0), ToTensorV2()])

#### Initialize and Load Datasets

In [24]:
# Named *_dataset (not `train` / `valid`) so these objects do not share a name with
# the train() function defined further down. With the old name, re-running this cell
# after that definition rebound `train` to a dataset and broke the call inside sweep().
train_dataset = TrainingDataset(
    images="/home/tommytang111/gap-junction-segmentation/data/sem_adult/SEM_split/s250-259/imgs",
    labels="/home/tommytang111/gap-junction-segmentation/data/sem_adult/SEM_split/s250-259/gts",
    augmentation=train_augmentation,
    train=True,
)

valid_dataset = TrainingDataset(
    images="/home/tommytang111/gap-junction-segmentation/data/sem_adult/SEM_split/s200-209/imgs",
    labels="/home/tommytang111/gap-junction-segmentation/data/sem_adult/SEM_split/s200-209/gts",
    augmentation=valid_augmentation,
    train=False
)

# Only the training loader shuffles; both use 4 worker processes seeded through worker_init_fn.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

#### Initialize Model and Send to GPU

In [25]:
class UNet(nn.Module):
    """UNet built from the project's DownBlock / DoubleConv / UpBlock modules.

    Four DownBlocks feed a DoubleConv bottleneck, followed by four UpBlocks that
    each take the matching skip output, and a final 1x1 convolution.
    """
    def __init__(self, out_classes=2, up_sample_mode='conv_transpose', three=False, attend=False, residual=False, scale=False, spatial=False, dropout=0, classes=2):
        """Initialize the UNet model

        Args:
            out_classes: Not referenced anywhere in this class.
            up_sample_mode: Forwarded to every UpBlock.
            three: Forwarded to the first two DownBlocks; also makes forward()
                squeeze dim -3 after the second DownBlock and average the first
                two skip outputs over dim 2.
            attend: Stored on the instance; not used in forward().
            residual: Forwarded to every block except the last UpBlock.
            scale: When truthy, creates two learnable scalars (s1, s2); they are
                not used in forward().
            spatial: Forwarded to DownBlocks 2-4 and the bottleneck (the first
                DownBlock always receives spatial=False).
            dropout: Forwarded to every block except the first DownBlock and the
                last UpBlock.
            classes: The final convolution has 1 output channel when this is 2,
                otherwise `classes` output channels.
        """
        super(UNet, self).__init__()
        self.three = three
        self.up_sample_mode = up_sample_mode
        self.dropout=dropout

        # Downsampling Path
        self.down_conv1 = DownBlock(1, 64, three=three, spatial=False, residual=residual) # 1 input channel --> 64 output channels
        self.down_conv2 = DownBlock(64, 128, three=three, spatial=spatial, dropout=self.dropout, residual=residual) # 64 input channels --> 128 output channels
        self.down_conv3 = DownBlock(128, 256, spatial=spatial, dropout=self.dropout, residual=residual) # 128 input channels --> 256 output channels
        self.down_conv4 = DownBlock(256, 512, spatial=spatial, dropout=self.dropout, residual=residual) # 256 input channels --> 512 output channels
        # Bottleneck
        self.double_conv = DoubleConv(512, 1024,spatial=spatial, dropout=self.dropout, residual=residual)
        # Upsampling Path
        self.up_conv4 = UpBlock(512 + 1024, 512, self.up_sample_mode, dropout=self.dropout, residual=residual) # 512 + 1024 input channels --> 512 output channels
        self.up_conv3 = UpBlock(256 + 512, 256, self.up_sample_mode, dropout=self.dropout, residual=residual)
        self.up_conv2 = UpBlock(128+ 256, 128, self.up_sample_mode, dropout=self.dropout, residual=residual)
        self.up_conv1 = UpBlock(128 + 64, 64, self.up_sample_mode) # no dropout / residual passed here
        # Final Convolution
        self.conv_last = nn.Conv2d(64, 1 if classes == 2 else classes, kernel_size=1)
        self.attend = attend
        if scale:
            self.s1, self.s2 = torch.nn.Parameter(torch.ones(1), requires_grad=True), torch.nn.Parameter(torch.ones(1), requires_grad=True) # learn scaling


    def forward(self, x):
        """Forward pass of the UNet model
        x: (16, 1, 512, 512)

        The shape comments below are illustrative for a batch of 16 inputs of
        512x512 and assume each DownBlock halves the spatial size and each
        UpBlock doubles it - TODO confirm against `utilities`.
        Returns the output of the final 1x1 convolution; no activation is
        applied here.
        """
        # print(x.shape)
        x, skip1_out = self.down_conv1(x) # x: (16, 64, 256, 256), skip1_out: (16, 64, 512, 512) (batch_size, channels, height, width)
        x, skip2_out = self.down_conv2(x) # x: (16, 128, 128, 128), skip2_out: (16, 128, 256, 256)
        if self.three: x = x.squeeze(-3)
        x, skip3_out = self.down_conv3(x) # x: (16, 256, 64, 64), skip3_out: (16, 256, 128, 128)
        x, skip4_out = self.down_conv4(x) # x: (16, 512, 32, 32), skip4_out: (16, 512, 64, 64)
        x = self.double_conv(x) # x: (16, 1024, 32, 32)
        x = self.up_conv4(x, skip4_out) # x: (16, 512, 64, 64)
        x = self.up_conv3(x, skip3_out) # x: (16, 256, 128, 128)
        if self.three:
            #attention_mode???
            # Collapse dim 2 of the two earliest skip outputs by averaging before they are merged
            skip1_out = torch.mean(skip1_out, dim=2)
            skip2_out = torch.mean(skip2_out, dim=2)
        x = self.up_conv2(x, skip2_out) # x: (16, 128, 256, 256)
        x = self.up_conv1(x, skip1_out) # x: (16, 64, 512, 512)
        x = self.conv_last(x) # x: (16, 1, 512, 512)
        return x
    
# Prefer the GPU, but fall back to the CPU so this cell still runs on a machine without CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)

#### Initialize loss function and optimizer

In [26]:
# Training objects for the module-level model; sweep() builds its own set for every run
loss_fn = GenDLoss()
optimizer = AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)
# Halve the learning rate after 10 scheduler steps without a lower monitored value, down to 1e-6
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=1e-6,
)

In [27]:
# Build the three binary metrics and move each one to the training device
recall, precision, f1 = (
    metric_cls().to(device)
    for metric_cls in (BinaryRecall, BinaryPrecision, BinaryF1Score)
)

In [28]:
def train(dataloader, model, loss_fn, optimizer, recall, precision, f1):
    """Run one training epoch.

    Args:
        dataloader: Sized iterable yielding ``(image, label, mask)`` batches;
            the mask is ignored.
        model: Network producing one logit map per image.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        recall, precision, f1: Metric objects exposing ``reset``, ``update`` and
            ``compute``. They are reset at the start of the epoch.

    Returns:
        Tuple ``(mean_loss, recall, precision, f1)`` of Python floats for the
        epoch. ``mean_loss`` is 0.0 when the dataloader yields no batches.
    """
    model.train()
    num_batches = len(dataloader)

    # Follow the model's own device rather than a notebook-level `device` global,
    # so the function works with whichever model it is handed.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # Reset metrics for each epoch
    recall.reset()
    precision.reset()
    f1.reset()

    train_loss = 0.0
    for X, y, _ in tqdm(dataloader, total=num_batches, desc="Training", leave=False):
        X, y = X.to(run_device), y.to(run_device)
        # Losses such as BCEWithLogitsLoss need a float target with an explicit channel dim
        if y.dim() == 3:
            y = y.unsqueeze(1).float()

        # Clear gradients before the forward/backward pass so stale ones never leak in
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Threshold the probabilities to a binary map; detach so no graph is kept for metrics
        pred_binary = (torch.sigmoid(pred.detach()) > 0.5).squeeze(1)

        # Update metrics
        if y.dim() == 4 and y.size(1) == 1:
            y = y.squeeze(1)  # [B, 1, H, W] -> [B, H, W]
        recall.update(pred_binary, y)
        precision.update(pred_binary, y)
        f1.update(pred_binary, y)

        train_loss += loss.item()

    # Compute final metrics per epoch
    train_recall = recall.compute().item()
    train_precision = precision.compute().item()
    train_f1 = f1.compute().item()
    # Guard the average so an empty dataloader does not raise ZeroDivisionError
    train_loss_per_epoch = train_loss / num_batches if num_batches else 0.0

    return train_loss_per_epoch, train_recall, train_precision, train_f1

def validate(dataloader, model, loss_fn, recall, precision, f1):
    """Evaluate the model for one pass over ``dataloader`` without updating weights.

    Args:
        dataloader: Sized iterable yielding ``(image, label, mask)`` batches;
            the mask is ignored.
        model: Network producing one logit map per image. It is switched to
            eval mode.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        recall, precision, f1: Metric objects exposing ``reset``, ``update`` and
            ``compute``. They are reset at the start of the pass.

    Returns:
        Tuple ``(mean_loss, recall, precision, f1)`` of Python floats.
        ``mean_loss`` is 0.0 when the dataloader yields no batches.
    """
    model.eval()
    num_batches = len(dataloader)

    # Follow the model's own device rather than a notebook-level `device` global
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # Reset metrics for each epoch
    recall.reset()
    precision.reset()
    f1.reset()

    test_loss = 0.0
    with torch.no_grad():
        for X, y, _ in tqdm(dataloader, desc="Validation", leave=False):
            X, y = X.to(run_device), y.to(run_device)
            # Losses such as BCEWithLogitsLoss need a float target with an explicit channel dim
            if y.dim() == 3:
                y = y.unsqueeze(1).float()

            pred = model(X)
            test_loss += loss_fn(pred, y).item()

            # Threshold the probabilities to a binary map
            pred_binary = (torch.sigmoid(pred) > 0.5).squeeze(1)

            # Update metrics
            if y.dim() == 4 and y.size(1) == 1:
                y = y.squeeze(1)  # [B, 1, H, W] -> [B, H, W]
            recall.update(pred_binary, y)
            precision.update(pred_binary, y)
            f1.update(pred_binary, y)

    # Compute final metrics per epoch
    val_recall = recall.compute().item()
    val_precision = precision.compute().item()
    val_f1 = f1.compute().item()
    # Guard the average so an empty dataloader does not raise ZeroDivisionError
    val_loss_per_epoch = test_loss / num_batches if num_batches else 0.0

    return val_loss_per_epoch, val_recall, val_precision, val_f1

In [33]:
def _select_train_augmentation(name):
    """Return the training pipeline for a sweep ``augmentation`` value."""
    builders = {
        'custom1': get_custom_augmentation,
        'custom2': get_custom_augmentation2,
    }
    if name not in builders:
        raise ValueError(f"Unknown augmentation strategy {name!r}; expected one of {sorted(builders)}")
    return builders[name]()


def _select_loss(name, device):
    """Return the loss module for a sweep ``loss_function`` value."""
    if name == "GenDLoss":
        return GenDLoss()
    #if name == "BCEWithLogitsLoss":
        #return nn.BCEWithLogitsLoss(pos_weight=torch.tensor([2.0], device=device))
    if name == "FocalLoss":
        return FocalLoss(alpha=torch.Tensor([0.08, 0.92]), device=device)
    raise ValueError(f"Unknown loss function {name!r}; expected 'GenDLoss' or 'FocalLoss'")


def _select_optimizer(name, parameters, learning_rate):
    """Return the optimizer for a sweep ``optimizer`` value."""
    if name == "AdamW":
        return AdamW(parameters, lr=learning_rate, weight_decay=1e-4)
    if name == "SGD":
        return SGD(parameters, lr=learning_rate, momentum=0.9, weight_decay=1e-4)
    raise ValueError(f"Unknown optimizer {name!r}; expected 'AdamW' or 'SGD'")


def sweep(epochs=50):
    """Train and evaluate one W&B sweep run.

    Reads ``augmentation``, ``batch_size``, ``dropout``, ``loss_function``,
    ``optimizer`` and ``learning_rate`` from ``wandb.config``, trains a fresh
    UNet on the pilot1 split, logs per-epoch metrics, and saves the weights
    from the epoch with the best validation F1.

    Args:
        epochs: Number of epochs to train (default 50, as before).

    Raises:
        ValueError: If the config names an unknown augmentation, loss function
            or optimizer.
    """
    # Credentials come from the WANDB_API_KEY environment variable, or from wandb's
    # cached login when it is unset. The API key that used to be hard-coded here was
    # committed in plain text and should be revoked and rotated.
    wandb.login(key=os.environ.get("WANDB_API_KEY"))
    wandb.init(
        project="gap-junction-segmentation",
        entity="zhen_lab",
        dir="/home/tommytang111/gap-junction-segmentation/wandb"
    )

    run_failed = True
    try:
        # Get hyperparameters from wandb config
        config = wandb.config

        # Set seeds
        seed_everything(42)

        # Get augmentation strategy from config, default to 'custom2'
        train_aug = _select_train_augmentation(config.get('augmentation', 'custom2'))
        valid_aug = A.Compose([A.Normalize(mean=0.0, std=1.0), ToTensorV2()])

        # Initialize datasets
        train_dataset = TrainingDataset(
            images="/home/tommytang111/gap-junction-segmentation/data/pilot1/train/imgs",
            labels="/home/tommytang111/gap-junction-segmentation/data/pilot1/train/gts",
            augmentation=train_aug,
            train=True
        )

        valid_dataset = TrainingDataset(
            images="/home/tommytang111/gap-junction-segmentation/data/pilot1/val/imgs",
            labels="/home/tommytang111/gap-junction-segmentation/data/pilot1/val/gts",
            augmentation=valid_aug,
            train=False
        )

        # Dataloaders use the batch size from the sweep config
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            worker_init_fn=worker_init_fn
        )
        valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            worker_init_fn=worker_init_fn
        )

        # Initialize model with config dropout
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = UNet(dropout=config.dropout).to(device)

        # Unknown names raise ValueError instead of leaving a variable undefined
        loss_fn = _select_loss(config.loss_function, device)
        optimizer = _select_optimizer(config.optimizer, model.parameters(), config.learning_rate)

        # Initialize learning rate scheduler
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=10,
            min_lr=1e-6
        )

        # Initialize metrics
        recall = BinaryRecall().to(device)
        precision = BinaryPrecision().to(device)
        f1 = BinaryF1Score().to(device)

        # Training loop
        torch.cuda.empty_cache()
        best_f1 = 0.0
        best_val_loss = float('inf')
        best_epoch = 0
        # Defined up front so the final save cannot hit an undefined name when
        # validation F1 never rises above its starting value.
        best_model_state = None
        model_path = f"/home/tommytang111/gap-junction-segmentation/models/sweep_model_{wandb.run.id}.pt"

        for epoch in range(epochs):
            print(f"Epoch {epoch+1}")

            # Training
            train_loss, train_recall, train_precision, train_f1 = train(
                train_dataloader, model, loss_fn, optimizer, recall, precision, f1
            )

            # Validation
            val_loss, val_recall, val_precision, val_f1 = validate(
                valid_dataloader, model, loss_fn, recall, precision, f1
            )

            # Update learning rate scheduler
            scheduler.step(val_loss)

            # Print metrics
            print(f"Train | Loss: {train_loss:.4f}, Recall: {train_recall:.4f}, Precision: {train_precision:.4f}, F1: {train_f1:.4f}")
            print(f"Val   | Loss: {val_loss:.4f}, Recall: {val_recall:.4f}, Precision: {val_precision:.4f}, F1: {val_f1:.4f}")
            print("-----------------------------")

            # Track the lowest validation loss seen so far
            if val_loss < best_val_loss:
                best_val_loss = val_loss

            # Snapshot the weights on the first epoch and whenever validation F1 improves
            if best_model_state is None or val_f1 > best_f1:
                best_f1 = val_f1
                best_epoch = epoch
                best_model_state = copy.deepcopy(model.state_dict())

            # Log metrics to wandb
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_recall": train_recall,
                "train_precision": train_precision,
                "train_f1": train_f1,
                "val_loss": val_loss,
                "val_recall": val_recall,
                "val_precision": val_precision,
                "val_f1": val_f1,
                "best_val_f1": best_f1,
                "best_val_loss": best_val_loss,
                "best_epoch": best_epoch,
                "lr": optimizer.param_groups[0]["lr"]
            })

        print("Training Complete!")
        # Nothing to save when no epoch ran (epochs == 0)
        if best_model_state is not None:
            torch.save(best_model_state, model_path)
            print("Saved PyTorch Model to ", model_path)
        run_failed = False
    finally:
        # Always close the run; a non-zero exit code marks it as failed in W&B
        wandb.finish(exit_code=1 if run_failed else 0)


#### **Sweep**

In [30]:
#Define sweep configuration
# Every parameter is a discrete list of values; batch_size, optimizer and
# loss_function each offer a single value, so they are fixed for the whole sweep.
sweep_config = {
    'method': 'bayes',  # alternatives: 'grid', 'random'
    # Objective the sweep optimises; 'val_f1' is one of the keys sweep() logs every epoch
    'metric': {
        'name': 'val_f1',
        'goal': 'maximize'
    },
    'parameters': {
        'learning_rate': {
            'values': [0.01, 0.001, 0.0001]
        },
        'batch_size': {
            'values': [8]
        },
        'optimizer': {
            'values': ['AdamW']
        },
        'loss_function': {
            'values': ['GenDLoss']
        },
        'dropout': {
            'values': [0, 0.1]
        },
        # Selects between get_custom_augmentation() and get_custom_augmentation2() inside sweep()
        'augmentation': {
            'values': ['custom1', 'custom2']
        },
    }
}

In [34]:
#Initialize sweep
# Registers sweep_config under the given project and returns the sweep's ID
sweep_id = wandb.sweep(sweep_config, project="gap-junction-segmentation")
print(f"Sweep ID: {sweep_id}")

# Start the sweep agent
# The agent calls sweep() once per run. No `count` is passed; the recorded output
# below shows the agent being stopped manually with Ctrl+C.
wandb.agent(sweep_id=sweep_id, function=sweep)

Create sweep with ID: nsu0z7xy
Sweep URL: https://wandb.ai/zhen_lab/gap-junction-segmentation/sweeps/nsu0z7xy
Sweep ID: nsu0z7xy


[34m[1mwandb[0m: Agent Starting Run: dxfg093q with config:
[34m[1mwandb[0m: 	augmentation: custom1
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	dropout: 0
[34m[1mwandb[0m: 	learning_rate: 0.01
[34m[1mwandb[0m: 	loss_function: GenDLoss
[34m[1mwandb[0m: 	optimizer: AdamW
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/tommytang111/.netrc


Epoch 1


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Train | Loss: 0.9614, Recall: 0.7559, Precision: 0.0196, F1: 0.0383
Val   | Loss: 0.8702, Recall: 0.3660, Precision: 0.1152, F1: 0.1753
-----------------------------
Epoch 2


Training:   0%|          | 0/25 [00:00<?, ?it/s]

[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7fd75b184490>> (for post_run_cell), with arguments args (<ExecutionResult object at 7fd78b62e230, execution_count=34 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 7fd78b62fee0, raw_cell="#Initialize sweep
sweep_id = wandb.sweep(sweep_con.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://wsl%2Bubuntu/home/tommytang111/gap-junction-segmentation/code/notebooks/sweep.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe