In [8]:
import json
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from natsort import natsorted
from torch.utils.data import Dataset
from Get_model_and_data import *


In [10]:

class Global_Local_Dataset(Dataset):
    """Pairs a local array, a global array and a ground-truth mask by position.

    Each folder's listing is naturally sorted and samples are matched purely by
    index, so the three folders must contain the same number of files whose
    names sort into the same order.

    Args:
        array_folder_local: folder of arrays loadable with ``np.load``.
        array_folder_global: folder of arrays loadable with ``np.load``.
        gt_folder: folder of mask images readable by ``cv2.imread``.

    Raises:
        ValueError: if the three folders contain different numbers of files.
    """

    def __init__(self, array_folder_local, array_folder_global, gt_folder):
        self.array_fps_local = self._sorted_paths(array_folder_local)
        self.array_fps_global = self._sorted_paths(array_folder_global)
        self.gt_fps = self._sorted_paths(gt_folder)
        # Fail fast instead of silently pairing unrelated files.
        self._check_alignment()

    @staticmethod
    def _sorted_paths(folder):
        """Return the full paths of every entry in ``folder``, naturally sorted."""
        return natsorted([os.path.join(folder, name) for name in os.listdir(folder)])

    def _check_alignment(self):
        """Raise ValueError if the local/global/gt file lists differ in length."""
        counts = {
            "local": len(self.array_fps_local),
            "global": len(self.array_fps_global),
            "gt": len(self.gt_fps),
        }
        if len(set(counts.values())) != 1:
            raise ValueError(f"Folder file counts differ, samples would be misaligned: {counts}")

    def __getitem__(self, idx):
        """Return ``(merged_array, gt_mask)`` for sample ``idx``.

        ``gt_mask`` is a float64 array that is 1.0 where the grayscale mask
        pixel is greater than 1, and 0.0 elsewhere.
        """
        local_array = np.load(self.array_fps_local[idx])
        global_array = np.load(self.array_fps_global[idx])
        # NOTE(review): joins along the LAST axis; confirm the saved arrays'
        # layout makes this a channel stack of the shape the model expects.
        merged_array = np.concatenate((local_array, global_array), axis=-1)
        gt_image = cv2.imread(self.gt_fps[idx], cv2.IMREAD_GRAYSCALE)
        if gt_image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read ground-truth image: {self.gt_fps[idx]}")
        # NOTE(review): "> 1" treats a pixel value of 1 as background; confirm
        # this matches the label encoding.
        gt_mask = (gt_image > 1).astype('float')

        return merged_array, gt_mask

    def __len__(self):
        """Number of samples, taken from the local-array folder."""
        return len(self.array_fps_local)
    

In [11]:

class _CoefficientModel(nn.Module):
    """Learned per-pixel linear combination of a 2-channel input.

    A single 1x1 convolution with one output channel computes, at every pixel,
    ``w0 * x[:, 0] + w1 * x[:, 1] + b`` with weights shared across the image.
    The attribute name ``fc1`` is kept so ``state_dict`` keys stay ``fc1.*``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Conv2d(2, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, merged_arrayy):
        """Map a (N, 2, H, W) batch to a (N, 1, H, W) prediction."""
        # No fixed-size reshape: Conv2d already yields (N, 1, H, W) for any H, W.
        return self.fc1(merged_arrayy)


def _match_target(gt_mask, output):
    """Reshape a mask batch to ``output``'s shape before computing the loss.

    Grayscale masks batch to (N, H, W) while the model emits (N, 1, H, W).
    Passing them to ``MSELoss`` unchanged broadcasts to (N, N, H, W) when
    N > 1, comparing every prediction with every mask. ``reshape_as`` raises
    instead if the element counts differ.
    """
    return gt_mask.reshape_as(output)


def _train(model, loader, criterion, optimizer, scheduler, device, num_epochs):
    """Train for ``num_epochs`` epochs, printing loss, LR and parameters each epoch."""
    for epoch in range(num_epochs):
        model.train()
        loss = None  # stays None if the loader yields no batches
        for merged_arrayy, gt_mask in loader:
            merged_arrayy = merged_arrayy.to(device, dtype=torch.float32)
            gt_mask = gt_mask.to(device, dtype=torch.float32)
            optimizer.zero_grad()
            output = model(merged_arrayy)
            loss = criterion(output, _match_target(gt_mask, output))
            loss.backward()
            optimizer.step()
        scheduler.step()  # Adjust learning rate
        # Loss of the epoch's last batch (NaN if there were no batches).
        last_loss = loss.item() if loss is not None else float("nan")
        # get_last_lr() is the public accessor; get_lr() is deprecated for this use.
        print(f"Epoch {epoch+1}, Loss: {last_loss}, Learning Rate: {scheduler.get_last_lr()[0]}")
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(name, param.data)


def _evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    Returns NaN when ``loader`` is empty (e.g. a test split of 0 samples).
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        # Each dataset item is (merged_array, gt_mask); the model takes one input.
        for merged_arrayy, gt_mask in loader:
            merged_arrayy = merged_arrayy.to(device, dtype=torch.float32)
            gt_mask = gt_mask.to(device, dtype=torch.float32)
            output = model(merged_arrayy)
            total_loss += criterion(output, _match_target(gt_mask, output)).item()
    return total_loss / len(loader) if len(loader) > 0 else float("nan")


def global_local_prob_model(global_dir, local_dir, device,
                            train_mask_dir="c:/Users/PC/Desktop/label_zoom_1_0_1/train/",
                            num_epochs=15, seed=None):
    """Train and evaluate a 1x1-conv combiner of local and global probability arrays.

    Args:
        global_dir: root folder containing ``pred_probs_caglar_he_train_resized``.
        local_dir: root folder containing ``pred_probs_caglar_he_train``.
        device: torch device (string or object), e.g. ``"cuda"`` or ``"cpu"``.
        train_mask_dir: folder of ground-truth mask images.
        num_epochs: number of training epochs.
        seed: if given, seeds the 90/10 train/test split; ``None`` leaves it
            unseeded as before.

    Returns:
        The trained model (left in eval mode).
    """
    array_folder_local = f"{local_dir}/pred_probs_caglar_he_train"
    array_folder_global = f"{global_dir}/pred_probs_caglar_he_train_resized"

    dataset = Global_Local_Dataset(array_folder_local, array_folder_global, train_mask_dir)

    # 90% train / 10% test split.
    train_size = int(0.9 * len(dataset))
    test_size = len(dataset) - train_size
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size], **split_kwargs)

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    model = _CoefficientModel().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.00001)
    # Halves the LR every 15 epochs. With num_epochs=15 every update uses the
    # base LR; only the value printed after the final epoch is halved.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.5 ** (epoch // 15))

    _train(model, train_loader, criterion, optimizer, scheduler, device, num_epochs)

    avg_loss = _evaluate(model, test_loader, criterion, device)
    print(f"Average Loss on Test Set: {avg_loss}")
    return model

        


In [None]:
# Root folders passed as global_dir / local_dir.
GLOBAL_DIR = "C:/Users/PC/Desktop/Segmentation/IDRiD-Eye-Fundus-Dataset-Lesion-Segmentation/2024_May_12-08_07_57/"
LOCAL_DIR = "C:/Users/PC/Desktop/Segmentation/IDRiD-Eye-Fundus-Dataset-Lesion-Segmentation/2024_May_11-10_44_04/"
# Fall back to CPU so the cell still runs on machines without a CUDA device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Assigned so the cell does not echo a return value.
trained_model = global_local_prob_model(global_dir=GLOBAL_DIR, local_dir=LOCAL_DIR, device=DEVICE)