In [None]:
import os
import cv2
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import sys
sys.path.append('../resnetunet/')

from resnet50_unet import UNetWithResnet50Encoder
from utils import BinaryLovaszHingeLoss, DiceLoss, JaccardLoss, dice_coefficient, show_images_and_masks
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix seed
# Seed every random number generator this notebook has in scope so that a
# fresh "Restart & Run All" starts from the same state each time.
SEED = 42
random.seed(SEED)  # Python's built-in RNG is imported above and must be seeded too
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Prefer deterministic cuDNN kernels and disable the auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False




In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import os
import numpy as np
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Dataset of grayscale PNG images paired with binary segmentation masks.

    Every ``.png`` file found in ``image_dir`` is one sample. Its mask is read
    from ``mask_dir`` under the same filename with ``'image'`` replaced by
    ``'mask'``.

    Parameters:
    - image_dir: Directory containing the input ``.png`` images.
    - mask_dir: Directory containing the matching mask files.
    - augment: If True, each sample goes through the random augmentation
      pipeline; otherwise it is only scaled and converted to a tensor.
    - denoise: If True, non-local-means denoising is applied to the image
      before it is resized.
    """

    def __init__(self, image_dir, mask_dir, augment=False, denoise=False):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.denoise = denoise
        self.image_files = self._list_png_files(image_dir)
        self.augment = augment

        # Updated transformation pipeline for augmented path
        self.augment_transform = A.Compose([
            A.OneOf([
                A.Rotate(limit=180, p=0.5),
                A.ElasticTransform(alpha=2, sigma=50, alpha_affine=50, p=0.5),
            ], p=1.0),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.GaussianBlur(blur_limit=3, p=0.5),
            A.GaussNoise(var_limit=(10, 50), p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomResizedCrop(height=512, width=512, scale=(0.8, 1.0), ratio=(0.75, 1.33), p=0.5),
            A.ToFloat(max_value=255.0),
            ToTensorV2(),
        ])

        # Define a simple transformation pipeline for non-augmented path
        self.basic_transform = A.Compose([
            A.ToFloat(max_value=255.0),
            ToTensorV2(),
        ])

    @staticmethod
    def _list_png_files(directory):
        """Return the ``.png`` filenames in ``directory`` in sorted order.

        ``os.listdir`` returns entries in arbitrary order, so sorting keeps the
        index -> sample mapping stable across runs and machines.
        """
        return sorted(f for f in os.listdir(directory) if f.endswith('.png'))

    def __len__(self):
        """Number of images found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, preprocess and transform the sample at position ``idx``.

        Returns:
        - (image, mask): the transformed image and the transformed mask cast
          to float.

        Raises:
        - FileNotFoundError: if the image or its mask cannot be read.
        """
        image_filename = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_filename)
        mask_path = os.path.join(self.mask_dir, image_filename.replace('image', 'mask'))

        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than later inside OpenCV.
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f'Could not read image file: {image_path}')
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask file: {mask_path}')

        if self.denoise:
            image = cv2.fastNlMeansDenoising(image, None, h=15, templateWindowSize=7, searchWindowSize=21)

        image = cv2.resize(image, (512, 512))
        mask = cv2.resize(mask, (512, 512))
        # Binarise: pixels above 127 become 1, everything else 0.
        _, mask = cv2.threshold(mask, 127, 1, cv2.THRESH_BINARY)

        # Stack image to create 3 channels if needed
        image = np.stack([image] * 3, axis=-1)

        transform = self.augment_transform if self.augment else self.basic_transform
        augmented = transform(image=image, mask=mask)

        image = augmented['image']
        mask = augmented['mask']

        return image, mask.float()


In [None]:
def freeze_all_encoder_layers(model):
    """Turn off gradient tracking for every parameter in ``model.down_blocks``.

    Parameters:
    - model: Object exposing an iterable ``down_blocks`` whose items provide
      ``parameters()``.
    """
    encoder_weights = (
        weight
        for encoder_block in model.down_blocks
        for weight in encoder_block.parameters()
    )
    for weight in encoder_weights:
        weight.requires_grad = False


In [None]:
def unfreeze_model(model):
    """Turn gradient tracking back on for every parameter of ``model``.

    Parameters:
    - model: Object exposing ``parameters()``.
    """
    trainable = True
    for weight in model.parameters():
        weight.requires_grad = trainable

In [None]:
def unfreeze_block(model, block_name):
    """
    Unfreezes a specific block of the ResNet encoder within the U-Net model.

    A block in ``model.down_blocks`` is selected when either
    - ``block_name`` has the form ``'layerN'`` and the block is the N-th entry
      (1-based) of ``model.down_blocks``, or
    - ``block_name`` equals the block's class name (the original matching rule,
      kept for backward compatibility).

    Matching on the class name alone can never select ``'layer4'`` or
    ``'layer3'``, so those calls used to leave the encoder fully frozen without
    any warning.

    Parameters:
    - model: The instance of your UNetWithResnet50Encoder model.
    - block_name: A string name of the block to unfreeze (e.g., 'layer4', 'layer3').
    """
    down_blocks = getattr(model, 'down_blocks', None)
    if down_blocks is None:
        print(f"The model does not have the specified block: {block_name}")
        return

    # NOTE(review): assumes down_blocks[0..] correspond to ResNet layer1.. in
    # order - confirm against the UNetWithResnet50Encoder definition.
    layer_index = None
    suffix = block_name[len('layer'):] if block_name.startswith('layer') else ''
    if suffix.isdecimal():
        layer_index = int(suffix) - 1

    matched = False
    for index, block in enumerate(down_blocks):
        if index == layer_index or block_name == block.__class__.__name__:
            for param in block.parameters():
                param.requires_grad = True
            matched = True

    if not matched:
        # Surface the no-op instead of silently leaving everything frozen.
        print(f"No block matching '{block_name}' found in model.down_blocks; nothing was unfrozen")


In [None]:
# Data Loaders
BATCH_SIZE = 4
train_dataset = CustomDataset('../data_retina/train/images' , '../data_retina/train/masks', augment=True, denoise=False)
val_dataset = CustomDataset('../data_retina/test/images', '../data_retina/test/masks', denoise=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model, Optimizer, and Loss Functions
model = UNetWithResnet50Encoder(n_classes=1).to(device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
model.load_state_dict(
    torch.load('output_dir/retina_encoder_l4.pth', map_location=device)
)
# Freeze the encoder blocks, then re-enable the two blocks to be fine-tuned.
freeze_all_encoder_layers(model)
unfreeze_block(model, 'layer4')
unfreeze_block(model, 'layer3')


In [None]:
import torchsummary

# summary() pushes a dummy batch through the network to record layer shapes.
# Doing that in eval mode keeps the BatchNorm running statistics restored from
# the checkpoint untouched; train_one_epoch() switches back to train mode.
model.eval()
torchsummary.summary(model, (3, 512, 512))

In [None]:
def create_optimizer_param_groups(model, encoder_lr=1e-4, decoder_lr=1e-4):
    """
    Builds optimizer parameter groups for the trainable encoder and decoder weights.

    Only parameters with ``requires_grad=True`` are included, so frozen blocks
    contribute nothing to their group.

    NOTE(review): only ``model.down_blocks`` and ``model.up_blocks`` are
    visited. Parameters the model may hold elsewhere are not returned and will
    therefore not be updated by an optimizer built from these groups - confirm
    this is intended.

    Parameters:
    - model: Model exposing iterable ``down_blocks`` and ``up_blocks``.
    - encoder_lr: Learning rate for the ``down_blocks`` group.
    - decoder_lr: Learning rate for the ``up_blocks`` group.

    Returns:
    - A list of two dicts, encoder group first and decoder group second, each
      with the keys ``'params'`` and ``'lr'``.
    """
    def _trainable(blocks):
        # Flatten the blocks' parameters, keeping only those still trainable.
        return [
            param
            for block in blocks
            for param in block.parameters()
            if param.requires_grad
        ]

    return [
        {'params': _trainable(model.down_blocks), 'lr': encoder_lr},
        {'params': _trainable(model.up_blocks), 'lr': decoder_lr},
    ]

# AdamW over the encoder/decoder groups; each group carries its own learning rate.
optimizer = torch.optim.AdamW(params=create_optimizer_param_groups(model))


In [None]:
# Alternative optimizer / scheduler configurations, kept for reference (disabled):
#optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
#optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, weight_decay=0.0001)
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, verbose=True)
# Active schedule: exponential decay with gamma=0.95 applied to the optimizer's
# learning rates each time scheduler.step() is called.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
# Training/validation objective; DiceLoss is the project helper imported from utils.
loss_fn = DiceLoss().to(device)

In [None]:
# Display train set
# Visual sanity check on the training dataset (built with augment=True above);
# show_images_and_masks is the project helper imported from utils.
show_images_and_masks(train_dataset, num_imgs=2)

In [None]:
# Training Function
def train_one_epoch(epoch):
    """
    Runs one optimisation pass over ``train_loader`` and returns the mean batch loss.

    Works on the notebook-level ``model``, ``optimizer``, ``loss_fn``,
    ``train_loader`` and ``device``. As a side effect, the per-batch loss curve
    is written to ``output_dir/internalloss_graph_epoch_{epoch}.png``.

    Parameters:
    - epoch: Epoch number, used for the progress-bar label and the plot filename.

    Returns:
    - The sum of the per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    batch_losses = []

    for input_img, mask in tqdm(train_loader, desc=f'Training epoch {epoch}'):
        input_img, mask = input_img.to(device), mask.to(device)

        optimizer.zero_grad()
        output = model(input_img)
        loss = loss_fn(output, mask)

        loss.backward()
        optimizer.step()

        # Read the scalar once per batch; it feeds both the plot and the mean.
        batch_losses.append(loss.item())

    # Save graph of loss throughout epoch
    os.makedirs('output_dir', exist_ok=True)
    plt.figure(figsize=(12, 8))
    plt.plot(batch_losses, label='Loss')
    plt.title('Loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'output_dir/internalloss_graph_epoch_{epoch}.png')
    plt.close()

    return sum(batch_losses) / len(train_loader)

# Validation Function
def validate(epoch):
    """
    Evaluates the notebook-level ``model`` on ``val_loader`` without gradients.

    Parameters:
    - epoch: Epoch number, used only for the progress-bar label.

    Returns:
    - The sum of the per-batch losses divided by ``len(val_loader)``.
    """
    model.eval()
    running_loss = 0.0
    progress = tqdm(val_loader, desc=f'Validating epoch {epoch}')

    with torch.no_grad():
        for batch_images, batch_masks in progress:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            batch_loss = loss_fn(model(batch_images), batch_masks)
            running_loss += batch_loss.item()

    n_batches = len(val_loader)
    return running_loss / n_batches


# Plot Loss Function
def plot_loss(epoch, train_inp_losses, val_inp_losses):
    """
    Plots the train and validation loss histories and saves them to
    ``output_dir/loss_graph.png`` (overwritten on every call).

    The x-axis is derived from the length of each history, so the plot cannot
    fail when ``epoch`` and the list lengths disagree.

    Parameters:
    - epoch: Zero-based index of the last finished epoch. Kept for backward
      compatibility; the x-axis no longer depends on it.
    - train_inp_losses: One training loss per finished epoch.
    - val_inp_losses: One validation loss per finished epoch.
    """
    os.makedirs('output_dir', exist_ok=True)

    plt.figure(figsize=(12, 8))

    # Epochs are shown 1-based.
    plt.plot(range(1, len(train_inp_losses) + 1), train_inp_losses, label='Train Loss')
    plt.plot(range(1, len(val_inp_losses) + 1), val_inp_losses, label='Val Loss')
    plt.title('Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.tight_layout()
    plt.savefig('output_dir/loss_graph.png')
    plt.close()


# Save Model Function
def save_model(model, path):
    """
    Writes the model's ``state_dict()`` to ``path`` with ``torch.save``.

    Parameters:
    - model: Object exposing ``state_dict()``.
    - path: Destination handed to ``torch.save``.
    """
    weights = model.state_dict()
    torch.save(weights, path)


train_inp_losses = []
val_inp_losses = []

best_val_loss = float('inf')  # Lowest validation loss seen so far
best_model_path = ''  # Set once the first checkpoint has been written

num_epochs = 15

for epoch in range(num_epochs):
    train_inp_loss = train_one_epoch(epoch)
    # ExponentialLR advances one decay step per call and takes no metric.
    # Passing the loss would be read as the deprecated `epoch` argument and
    # set the learning rate to base_lr * gamma ** loss.
    scheduler.step()
    val_inp_loss = validate(epoch)
    train_inp_losses.append(train_inp_loss)
    val_inp_losses.append(val_inp_loss)

    # Save model if it has the lowest validation loss so far
    if val_inp_loss < best_val_loss:
        best_val_loss = val_inp_loss
        best_model_path = 'output_dir/best_model.pth'
        save_model(model, best_model_path)

    print(f'Epoch {epoch}, Train Loss: {train_inp_loss}, Val Loss: {val_inp_loss}')

    # Save model for every epoch (optional if you only want the best model)
    #save_model(model, f'st_output_dir/model_epoch_{epoch}.pth')

    plot_loss(epoch, train_inp_losses, val_inp_losses)

# Prediction

In [None]:
import matplotlib.pyplot as plt

def visualize(image, true_mask=None, predicted_mask=None):
    """Visualize comparison between input image, true mask, and predicted mask.

    Draws three side-by-side panels. A panel whose data is ``None`` is left
    blank, so the function can be called with any subset of the masks;
    ``predicted_mask`` used to be passed to ``imshow`` even when ``None``.

    Parameters:
    - image: Array-like accepted by ``imshow`` for the first panel.
    - true_mask: Optional ground-truth mask for the second panel.
    - predicted_mask: Optional predicted mask for the third panel.
    """
    fig, axs = plt.subplots(1, 3, figsize=(20, 10))

    panels = (
        ('Input Image', image),
        ('True Mask', true_mask),
        ('Predicted Mask', predicted_mask),
    )
    for ax, (title, data) in zip(axs, panels):
        if data is not None:
            ax.imshow(data, cmap='gray')
            ax.set_title(title)
        ax.axis('off')

    plt.show()


In [None]:
def predict_and_visualize(model, dataset, device, n_images):
    """
    Runs the model on the first ``n_images`` samples of ``dataset``, shows each
    prediction next to its input and ground truth, and prints the mean Dice score.

    Parameters:
    - model: Segmentation network; its output is passed through a sigmoid and
      thresholded at 0.5.
    - dataset: Indexable dataset yielding ``(image, mask)`` tensor pairs.
    - device: Device on which inference runs.
    - n_images: Number of samples to evaluate; capped at ``len(dataset)``.
    """
    model.eval()
    # Never index past the end of the dataset.
    n_images = min(n_images, len(dataset))
    dice_scores = []
    with torch.no_grad():
        for i in range(n_images):
            input_img, true_mask = dataset[i]

            # Predict
            pred_mask_logits = model(input_img.unsqueeze(0).to(device))
            pred_mask = torch.sigmoid(pred_mask_logits) > 0.5

            # Calculate Dice Coefficient
            dice_score = dice_coefficient(pred_mask.squeeze(), true_mask.to(device), smooth=1e-6)
            dice_scores.append(dice_score.item())

            # Convert for visualization
            pred_mask_np = pred_mask.squeeze().cpu().numpy()
            input_img_np = input_img.squeeze().permute(1, 2, 0).cpu().numpy()
            input_img_np = (input_img_np * 255).astype(np.uint8)
            true_mask_np = true_mask.squeeze().cpu().numpy()

            visualize(
                image=input_img_np,
                true_mask=true_mask_np,
                predicted_mask=pred_mask_np
            )

    if not dice_scores:
        # The mean of an empty list is NaN (with a warning); report it plainly instead.
        print('No images were evaluated; average Dice Coefficient is undefined.')
        return

    # Print the average Dice score at the end
    average_dice_score = np.mean(dice_scores)
    print(f'Average Dice Coefficient over {n_images} images: {average_dice_score:.4f}')


In [None]:
test_dataset = CustomDataset('../data_retina/test/images', '../data_retina/test/masks', augment=False, denoise=False)

# Load the model (ensure the model is already trained and weights are loaded)
model = UNetWithResnet50Encoder(n_classes=1).to(device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written during GPU training still loads on a CPU-only machine.
model.load_state_dict(torch.load('output_dir/best_model.pth', map_location=device))
# Predict and visualize on the real data
predict_and_visualize(model, test_dataset, device, n_images=len(test_dataset))
