### ```1-LIBRARIES```

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
# from google.colab import drive
from PIL import Image
import os
import matplotlib.pyplot as plt
import random
from torch.utils.data import random_split
import numpy as np
import pandas as pd
import gc
import json
import re
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
torch.cuda.empty_cache()
# Set random seeds for reproducibility. Every RNG source imported by this
# notebook is seeded: Python's `random`, NumPy (previously left unseeded)
# and PyTorch.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


<torch._C.Generator at 0x204e34cbe70>

### ```2-MOUNT DRIVE (if running on colab)```

In [2]:
#drive.mount('/content/drive')

#### ```3-addresses (replace the dataset and ground-truth directories and the evaluation path with your own paths). Also place each model's IHC predictions on the BCI dataset, for both train and test, in the corresponding directory. For example, the IHC predictions of model 1 on the test data should be in model1_dir_test_eval```

In [2]:
# Root folder holding the prediction / ground-truth directories below.
# os.sep reproduces the original drive-root placeholder; point it at your data.
# Paths are built with os.path.join so they also work on Linux / Colab,
# where a literal backslash is not a path separator.
data_root = os.sep

# IHC predictions of each of the three models on the train / test splits.
model1_dir_train_eval = os.path.join(data_root, 'DB1_train_IHC')
model1_dir_test_eval = os.path.join(data_root, 'DB1_test_IHC')
model2_dir_train_eval = os.path.join(data_root, 'DB2_train_eval')
model2_dir_test_eval = os.path.join(data_root, 'DB2_test_eval')
model3_dir_train_eval = os.path.join(data_root, 'DB3_train_eval')
model3_dir_test_eval = os.path.join(data_root, 'DB3_test_eval')

# Ground-truth images for the test / train splits.
gt_dir_test  = os.path.join(data_root, 'GT_test')
gt_dir_train  = os.path.join(data_root, 'GT_train')
#-----------------------------------------------------------------
# Output locations: metrics spreadsheet, per-epoch weights, predicted images.
eval_path = os.path.join('result metrics', 'evaluation.xlsx')
weights_dir = os.path.join('experiments', 'weights')
experiments_dir = os.path.join('experiments', 'result images')

#### ```4-preparing the train and test datasets and the loaders```
----

In [3]:
class ImageDataset(Dataset):
    """Dataset pairing three predictions of the same image with its ground truth.

    The file names found in ``db1_dir`` define the samples; the same file name
    is then looked up in ``db2_dir``, ``db3_dir`` and ``gt_dir``.

    Args:
        db1_dir, db2_dir, db3_dir: directories holding the three input images.
        gt_dir: directory holding the ground-truth images.
        transform: callable applied to each of the four RGB PIL images. It must
            return tensors, because ``__getitem__`` stacks the results with
            ``torch.stack`` (``transform=None`` is therefore not usable).
    """

    def __init__(self, db1_dir, db2_dir, db3_dir, gt_dir, transform=None):

        self.db1_dir = db1_dir
        self.db2_dir = db2_dir
        self.db3_dir = db3_dir
        self.gt_dir = gt_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary, platform-dependent order.
        # Sorting keeps sample indices (and anything keyed by index, such as
        # per-row metrics) reproducible across runs and machines.
        self.image_names = sorted(os.listdir(db1_dir))

    def __len__(self):
        """Number of samples, i.e. number of entries found in ``db1_dir``."""
        return len(self.image_names)

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(inputs, gt)`` for sample ``idx``.

        ``inputs`` is the three transformed images stacked along a new last
        axis; ``gt`` is the transformed ground truth with a new leading axis.
        """
        img_name = self.image_names[idx]

        # Load images from the three databases
        img1 = self._load_rgb(os.path.join(self.db1_dir, img_name))
        img2 = self._load_rgb(os.path.join(self.db2_dir, img_name))
        img3 = self._load_rgb(os.path.join(self.db3_dir, img_name))

        # Load the ground truth image
        gt = self._load_rgb(os.path.join(self.gt_dir, img_name))

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
            gt = self.transform(gt)

        inputs = torch.stack([img1, img2, img3], dim=-1)  # Stack images along the last dimension
        return inputs, gt.unsqueeze(0)

    def get_image_name(self,idx):
        """Return the file name of sample ``idx``."""
        return self.image_names[idx]

# Re-seed here so that re-running this cell alone yields the same shuffling.
seed = 42
torch.manual_seed(seed)
random.seed(seed)

# Single-step pipeline: PIL image -> tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Train split: the three models' predictions plus the matching ground truth.
train_dataset = ImageDataset(
    db1_dir=model1_dir_train_eval,
    db2_dir=model2_dir_train_eval,
    db3_dir=model3_dir_train_eval,
    gt_dir=gt_dir_train,
    transform=transform,
)
# Test split, built the same way from the test directories.
test_dataset = ImageDataset(
    db1_dir=model1_dir_test_eval,
    db2_dir=model2_dir_test_eval,
    db3_dir=model3_dir_test_eval,
    gt_dir=gt_dir_test,
    transform=transform,
)

the_batch_size = 16

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=the_batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=the_batch_size)

#### ```5-Model()```
----



In [4]:
class EncoderDecoder(nn.Module):
    """3D convolutional encoder-decoder that fuses stacked images into one.

    ``forward`` moves the last input axis to the channel position, so the
    three images stacked on the last axis by ``ImageDataset`` become the
    Conv3d input channels. Three stride-(1, 2, 2) convolutions downsample
    the last two axes; three transposed convolutions restore them. The
    final layer has one output channel followed by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Stride 1 along depth, stride 2 along the two trailing axes.
        stride = (1, 2, 2)

        down = []
        for c_in, c_out in ((3, 16), (16, 32), (32, 64)):
            down.append(nn.Conv3d(in_channels=c_in, out_channels=c_out,
                                  kernel_size=3, stride=stride, padding=1))
            down.append(nn.BatchNorm3d(c_out))
            down.append(nn.ReLU(inplace=True))
        self.encoder = nn.Sequential(*down)

        up = []
        for c_in, c_out in ((64, 32), (32, 16)):
            up.append(nn.ConvTranspose3d(c_in, c_out, kernel_size=3, stride=stride,
                                         padding=1, output_padding=(0, 1, 1)))
            up.append(nn.BatchNorm3d(c_out))
            up.append(nn.ReLU(inplace=True))
        # Output head: single channel squashed to (0, 1).
        up.append(nn.ConvTranspose3d(16, 1, kernel_size=3, stride=stride,
                                     padding=1, output_padding=(0, 1, 1)))
        up.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*up)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-normal weights and zero biases for every (transposed) 3D conv."""
        conv_types = (nn.Conv3d, nn.ConvTranspose3d)
        for module in self.modules():
            if not isinstance(module, conv_types):
                continue
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, inputs):
        """Permute axes (0, 4, 1, 2, 3), then run encoder and decoder."""
        channels_first = inputs.permute(0, 4, 1, 2, 3)
        return self.decoder(self.encoder(channels_first))

Model summary:
```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
               input      [batch_size, 3, 3, 1024, 1024]
               
            Conv3d-1       [batch_size, 16, 3, 512, 512]           1,312
       BatchNorm3d-2       [batch_size, 16, 3, 512, 512]              32
              ReLU-3       [batch_size, 16, 3, 512, 512]               0
            Conv3d-4       [batch_size, 32, 3, 256, 256]          13,856
       BatchNorm3d-5       [batch_size, 32, 3, 256, 256]              64
              ReLU-6       [batch_size, 32, 3, 256, 256]               0
            Conv3d-7       [batch_size, 64, 3, 128, 128]          55,360
       BatchNorm3d-8       [batch_size, 64, 3, 128, 128]             128
              ReLU-9       [batch_size, 64, 3, 128, 128]               0
  ConvTranspose3d-10       [batch_size, 32, 3, 256, 256]          55,328
      BatchNorm3d-11       [batch_size, 32, 3, 256, 256]              64
             ReLU-12       [batch_size, 32, 3, 256, 256]               0
  ConvTranspose3d-13       [batch_size, 16, 3, 512, 512]          13,840
      BatchNorm3d-14       [batch_size, 16, 3, 512, 512]              32
             ReLU-15       [batch_size, 16, 3, 512, 512]               0
  ConvTranspose3d-16      [batch_size, 1, 3, 1024, 1024]             433
          Sigmoid-17      [batch_size, 1, 3, 1024, 1024]               0
================================================================
Total params: 140,449
Trainable params: 140,449
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 72.00
Forward/backward pass size (MB): 1968.00
Params size (MB): 0.54
Estimated Total Size (MB): 2040.54
----------------------------------------------------------------
```


#### ```6-To train the model from scratch, or to resume from the last saved epoch, run the cell below (after running the cells above)```
The latest weights are saved after each epoch in the `weights_dir` directory.

In [None]:
#-----------------training method----------------------------
def train_model(model, dataloader, criterion, optimizer, scheduler, num_epochs=21, device='cuda', weights_dir=weights_dir):
    """Train ``model``, saving weights after every epoch and resuming if possible.

    If ``weights_dir`` already contains ``weights_epoch_<N>.pth`` files, the one
    with the highest ``N`` is loaded and training continues at epoch ``N + 1``;
    otherwise the model is Xavier-initialised and training starts at epoch 1.

    Args:
        model: network exposing ``_initialize_weights()``.
        dataloader: iterable of ``(inputs, targets)`` batches with ``.dataset``.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: index of the last epoch to run (epochs are 1-based).
        device: device on which the model and batches are placed.
        weights_dir: directory for per-epoch weight files (created if missing).
            The default is the module-level ``weights_dir`` at definition time.

    Returns:
        The trained model.
    """
    model.to(device)

    total_steps = len(dataloader) * num_epochs
    current_step = 0

    if not os.path.exists(weights_dir):
        os.makedirs(weights_dir)
        print("directory for weights created:",weights_dir)

    # Only files written by this function count as checkpoints. Matching the
    # full name avoids crashing on unrelated .pth files and guarantees the
    # file that is loaded is one that actually exists.
    checkpoint_re = re.compile(r'^weights_epoch_(\d+)\.pth$')
    checkpoints = {}
    for fname in os.listdir(weights_dir):
        found = checkpoint_re.match(fname)
        if found:
            checkpoints[int(found.group(1))] = fname

    if checkpoints:
        last_epoch = max(checkpoints)
        print(f"Resuming training from epoch: {last_epoch + 1}")
        # Load weights from the file
        last_weights_file = os.path.join(weights_dir, checkpoints[last_epoch])
        # map_location lets a checkpoint saved on GPU be resumed on CPU (and vice versa).
        model.load_state_dict(torch.load(last_weights_file, map_location=device))
        # Set the starting epoch for training
        start_epoch = last_epoch + 1
        current_step = len(dataloader) * last_epoch
        # Only model weights are checkpointed, so replay the scheduler for the
        # epochs already completed; otherwise the learning rate would restart
        # from its initial value. Optimizer state (e.g. Adam moments) is not
        # restored. Stepping before any optimizer.step() triggers a harmless
        # PyTorch UserWarning, which is silenced here.
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            for _ in range(last_epoch):
                scheduler.step()
    else:
        print("Initializing weights using Xavier initialization")
        model._initialize_weights()  # Custom function to initialize weights using Xavier initialization
        start_epoch = 1  # Start from the first epoch

    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        running_loss = 0.0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch loss is a per-sample mean.
            running_loss += loss.item() * inputs.size(0)
            current_step += 1

            percentage_done = (current_step / total_steps) * 100

            print(f'Epoch {epoch}/{num_epochs}, Step {current_step}/{total_steps} ({percentage_done:.2f}% complete)')

            # Drop references to the batch tensors before the next iteration.
            del inputs, targets, outputs
            gc.collect()

        scheduler.step()
        epoch_loss = running_loss / len(dataloader.dataset)
        print(f'Epoch {epoch}/{num_epochs}, Loss: {epoch_loss:.4f}')

        weights_file = os.path.join(weights_dir, f'weights_epoch_{epoch}.pth')
        torch.save(model.state_dict(), weights_file)
        print(f'Model weights saved at: {weights_file}')

    return model
#------------------------------------------------------------------------------------------------------

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fresh model, MSE objective, Adam at 1e-3 with a 10x LR decay every 10 epochs.
model = EncoderDecoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=10)

# Train (or resume) up to epoch 25; weights go to the default weights_dir.
model = train_model(
    model,
    train_loader,
    criterion,
    optimizer,
    scheduler,
    num_epochs=25,
    device=device,
)

#### ```7-For loading the pretrained weights on the model (skip cell 6) run this cell```

In [None]:
# Checkpoint to restore. Built from weights_dir with os.path.join so the path
# matches where train_model saves its files and works on any OS (the previous
# hard-coded backslashes only resolved on Windows).
weights_path = os.path.join(weights_dir, 'weights_epoch_13.pth')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EncoderDecoder().to(device)
# Tensors are deserialised on the CPU, then copied into the model's parameters.
model.load_state_dict(torch.load(weights_path,map_location=torch.device('cpu')))

#### ```8-evaluate() (this will only evaluate on the test data and save the metrics in eval_path (the images will not be saved))```
----

In [8]:
def evaluate_model(model, dataloader, device='cuda'):
    """Compute per-image SSIM / PSNR for the first input image and the prediction.

    For every sample, the first of the stacked input images and the model
    output are each compared with the ground truth. The table is written to
    the module-level ``eval_path`` (best effort: a failure is reported, not
    raised) and the column means are printed.

    Args:
        model: network to evaluate; switched to eval mode and moved to ``device``.
        dataloader: iterable of ``(inputs, gt)`` batches.
        device: device used for inference.

    Returns:
        pandas.DataFrame with one row per sample and the columns
        'SSIM Image1', 'SSIM Predicted', 'PSNR Image1', 'PSNR Predicted'.
    """
    model.eval()
    model.to(device)
    results = []
    num_batches = len(dataloader)
    with torch.no_grad():
        for batch_idx, (inputs, gt) in enumerate(dataloader, start=1):
            inputs, gt = inputs.to(device), gt.to(device)
            pred = model(inputs)
            pred = pred.cpu().numpy()
            gt = gt.cpu().numpy()
            inputs = inputs.cpu().numpy()
            print(f'{(100 * batch_idx/(num_batches))} percent evaluated')

            for j in range(inputs.shape[0]):
                # Move the leading axis of each image to the end, as expected
                # by channel_axis=2.
                gt_img = gt[j, 0, :, :, :].transpose(1, 2, 0)
                img1 = inputs[j, :, :, :, 0].transpose(1, 2, 0)
                pred_img = pred[j, 0, :, :, :].transpose(1, 2, 0)

                # data_range is passed explicitly to PSNR as well, so both
                # metrics use the same value range instead of an inferred one.
                ssim_img1 = ssim(img1, gt_img, multichannel=True, channel_axis=2, data_range=1)
                psnr_img1 = psnr(gt_img, img1, data_range=1)

                ssim_pred = ssim(pred_img, gt_img, multichannel=True, channel_axis=2, data_range=1)
                psnr_pred = psnr(gt_img, pred_img, data_range=1)

                results.append((ssim_img1, ssim_pred, psnr_img1, psnr_pred))

    df = pd.DataFrame(results, columns=['SSIM Image1', 'SSIM Predicted', 'PSNR Image1', 'PSNR Predicted'])
    try:
        # Create the target folder first; otherwise the save fails on a fresh checkout.
        out_dir = os.path.dirname(eval_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        df.to_excel(eval_path, index=True)
        print('metrics saved at: ', eval_path)
    except Exception as e:
        print(f"Error saving to Excel: {e}")
    print('-------------------------------------------------')
    print(df.mean())
    return df

# Use the device selected earlier instead of hard-coding 'cuda', so the cell
# also runs on CPU-only machines.
test_metrics = evaluate_model(model, test_loader, device=device)

#### ```9-predict() (this cell predicts the test dataset images (IHC images) and saves them in experiments_dir)```
----

In [None]:
def predict(model, inputs, device='cuda'):
    """Run ``model`` on ``inputs`` without gradients and return the result on CPU.

    The model is put in eval mode; ``inputs`` is moved to ``device`` first.
    The model itself is not moved here.
    """
    model.eval()
    on_device = inputs.to(device)
    with torch.no_grad():
        result = model(on_device)
    return result.cpu()

def predict_and_save(model, inputs, name ,device='cuda'):
    """Predict one image and write it to ``experiments_dir`` under ``name``.

    Args:
        model: network passed to ``predict``.
        inputs: input batch; only the first sample's prediction is saved.
        name: output file name (joined onto the module-level ``experiments_dir``).
        device: device used for inference.
    """
    pred = predict(model, inputs, device=device)
    # First sample, first output channel; move the leading axis to the end
    # so the array is laid out as an image for imsave.
    pred = pred[0,0,:,:,:].numpy().transpose(1, 2, 0)
    # Make sure the output folder exists; imsave does not create it.
    os.makedirs(experiments_dir, exist_ok=True)
    plt.imsave(os.path.join(experiments_dir,name), pred)
    print(f'image prediction {name} saved in {experiments_dir}')

# Predict every test sample one at a time and save it under its original file name.
for sample_idx in range(len(test_dataset)):
    stacked_inputs = test_dataset[sample_idx][0]
    predict_and_save(
        model,
        stacked_inputs.unsqueeze(0),  # add the batch axis the model expects
        test_dataset.get_image_name(sample_idx),
        device=device,
    )