In [1]:
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
from patchify import patchify
import PIL
from PIL import Image
PIL.Image.MAX_IMAGE_PIXELS = 933120000
import os
import shutil
import random
%matplotlib inline
import torchvision
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
import torch.onnx

In [2]:
class PatchDataset(Dataset):
    def __init__(self,root,target, train=True, transforms=None):
        super(PatchDataset, self).__init__()
        self.image_path = [os.path.join(root, x) for x in os.listdir(root)]      
        self.ref_path = [os.path.join(target,x) for x in os.listdir(target)]
        
        if transform is not None:
            self.transform = transform

        if train:
            self.images = self.image_path[: int(.8 * len(self.image_path))]
            self.ref = self.ref_path[: int(.8 * len(self.image_path))]
        else:
            self.images = self.image_path[int(.8 * len(self.image_path)):]
            self.ref = self.ref_path[int(.8 * len(self.image_path)):]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        return self.transform(self.images[item]),self.transform(self.ref[item])  

In [3]:
transform = transforms.Compose([
    lambda x: Image.open(x).convert('RGB'),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [4]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        
        self.encoder_ = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )

    def forward(self, x):
        return self.encoder_(x)

In [5]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        
        self.decoder_ = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.ConvTranspose2d(32, 3, kernel_size=2, stride=2, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.decoder_(x)

In [6]:
class CAE(nn.Module):
    def __init__(self):
        super(CAE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [7]:
class ModifiedContrastiveLoss(nn.Module):
    def __init__(self):
        super(ModifiedContrastiveLoss, self).__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, encoded_x1, encoded_x2):
        # Calculate the modified contrastive loss using MSE loss
        loss = self.mse_loss(encoded_x1, encoded_x2)
        return loss

In [8]:
class SiameseCAE(nn.Module):
    def __init__(self):
        super(SiameseCAE, self).__init__()
        self.cae = CAE()

    def forward(self, x1, x2):
        # Pass the input images through the shared-weight CAE
        encoded_x1 = self.cae.encoder(x1)
        encoded_x2 = self.cae.encoder(x2)

        # Decode the encoded representations
        decoded_x1 = self.cae.decoder(encoded_x1)
        decoded_x2 = self.cae.decoder(encoded_x2)

        return encoded_x1, encoded_x2, decoded_x1, decoded_x2

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = CAE().to(device)
# Set loss function and optimizer
criterion = nn.MSELoss().to(device)
l1_loss = nn.L1Loss().to(device)
Hube = torch.nn.HuberLoss().to(device)
CE = nn.CrossEntropyLoss().to(device)

model = SiameseCAE().to(device)
contrastive_loss = ModifiedContrastiveLoss()
reconstruction_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

unloader = transforms.ToPILImage()

train_dataset = PatchDataset("OS_412(64)_SSIM_BlankRemove","OS_415(64)_SSIM_BlankRemove",train=True, transforms=transform)
test_dataset = PatchDataset("OS_412(64)_SSIM_BlankRemove","OS_415(64)_SSIM_BlankRemove",train=False, transforms=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(device)
print(model.eval())

cuda
SiameseCAE(
  (cae): CAE(
    (encoder): Encoder(
      (encoder_): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (10): ReLU()
        (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (decoder): Decoder(
      (decoder_): Sequential(
    

In [12]:
train_losses = []
test_losses = []
# Training loop
num_epochs = 500
for epoch in range(num_epochs):
    for (img1,img2) in train_loader:
        img1 = img1.to(device)
        img2 = img2.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        encoded_x1, encoded_x2, decoded_x1, decoded_x2 = model(img1, img2)
        # Calculate the modified contrastive loss using MSE loss
        loss_contrastive = contrastive_loss(encoded_x1, encoded_x2)
        
        # Calculate the reconstruction loss
        loss_reconstruction1 = reconstruction_loss(decoded_x1, img1)
        loss_reconstruction2 = reconstruction_loss(decoded_x2, img2)

        # Combine the losses with appropriate weights
        loss = loss_contrastive + 0.5 * (loss_reconstruction1 + loss_reconstruction2)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
    
    with torch.no_grad():
        for (img1,img2) in test_loader:
            img1 = img1.to(device)
            img2 = img2.to(device)
            
            # Forward pass
            encoded_x1, encoded_x2, decoded_x1, decoded_x2 = model(img1, img2)
            # Calculate the modified contrastive loss using MSE loss
            loss_contrastive = contrastive_loss(encoded_x1, encoded_x2)

            # Calculate the reconstruction loss
            loss_reconstruction1 = reconstruction_loss(decoded_x1, img1)
            loss_reconstruction2 = reconstruction_loss(decoded_x2, img2)

            # Combine the losses with appropriate weights
            loss_test = loss_contrastive + 0.5 * (loss_reconstruction1 + loss_reconstruction2)            
            
            
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {loss.item():.4f}')
    print(f'Epoch [{epoch+1}/{num_epochs}], Test Loss: {loss_test.item():.4f}')
    train_losses.append(loss.item())
    test_losses.append(loss_test.item())
    # Save the trained model
    path = 'SiameseCAE_Model(64)_MapLearning/epoch_{i}.pth'.format(i=epoch+1)
    torch.save(model.state_dict(), path)
print("Training completed.")

Epoch [1/500], Train Loss: 0.0456
Epoch [1/500], Test Loss: 0.0470
Epoch [2/500], Train Loss: 0.0360
Epoch [2/500], Test Loss: 0.0366
Epoch [3/500], Train Loss: 0.0307
Epoch [3/500], Test Loss: 0.0306
Epoch [4/500], Train Loss: 0.0283
Epoch [4/500], Test Loss: 0.0281
Epoch [5/500], Train Loss: 0.0264
Epoch [5/500], Test Loss: 0.0268
Epoch [6/500], Train Loss: 0.0255
Epoch [6/500], Test Loss: 0.0259
Epoch [7/500], Train Loss: 0.0248
Epoch [7/500], Test Loss: 0.0254
Epoch [8/500], Train Loss: 0.0246
Epoch [8/500], Test Loss: 0.0250
Epoch [9/500], Train Loss: 0.0239
Epoch [9/500], Test Loss: 0.0245
Epoch [10/500], Train Loss: 0.0236
Epoch [10/500], Test Loss: 0.0242
Epoch [11/500], Train Loss: 0.0232
Epoch [11/500], Test Loss: 0.0239
Epoch [12/500], Train Loss: 0.0230
Epoch [12/500], Test Loss: 0.0236
Epoch [13/500], Train Loss: 0.0227
Epoch [13/500], Test Loss: 0.0234
Epoch [14/500], Train Loss: 0.0226
Epoch [14/500], Test Loss: 0.0232
Epoch [15/500], Train Loss: 0.0223
Epoch [15/500], T

Epoch [120/500], Train Loss: 0.0173
Epoch [120/500], Test Loss: 0.0186
Epoch [121/500], Train Loss: 0.0172
Epoch [121/500], Test Loss: 0.0186
Epoch [122/500], Train Loss: 0.0172
Epoch [122/500], Test Loss: 0.0186
Epoch [123/500], Train Loss: 0.0172
Epoch [123/500], Test Loss: 0.0186
Epoch [124/500], Train Loss: 0.0172
Epoch [124/500], Test Loss: 0.0186
Epoch [125/500], Train Loss: 0.0173
Epoch [125/500], Test Loss: 0.0186
Epoch [126/500], Train Loss: 0.0172
Epoch [126/500], Test Loss: 0.0186
Epoch [127/500], Train Loss: 0.0172
Epoch [127/500], Test Loss: 0.0185
Epoch [128/500], Train Loss: 0.0172
Epoch [128/500], Test Loss: 0.0185
Epoch [129/500], Train Loss: 0.0172
Epoch [129/500], Test Loss: 0.0185
Epoch [130/500], Train Loss: 0.0173
Epoch [130/500], Test Loss: 0.0186
Epoch [131/500], Train Loss: 0.0172
Epoch [131/500], Test Loss: 0.0185
Epoch [132/500], Train Loss: 0.0172
Epoch [132/500], Test Loss: 0.0185
Epoch [133/500], Train Loss: 0.0172
Epoch [133/500], Test Loss: 0.0185
Epoch 

Epoch [236/500], Train Loss: 0.0168
Epoch [236/500], Test Loss: 0.0181
Epoch [237/500], Train Loss: 0.0167
Epoch [237/500], Test Loss: 0.0181
Epoch [238/500], Train Loss: 0.0168
Epoch [238/500], Test Loss: 0.0181
Epoch [239/500], Train Loss: 0.0168
Epoch [239/500], Test Loss: 0.0180
Epoch [240/500], Train Loss: 0.0167
Epoch [240/500], Test Loss: 0.0181
Epoch [241/500], Train Loss: 0.0167
Epoch [241/500], Test Loss: 0.0181
Epoch [242/500], Train Loss: 0.0167
Epoch [242/500], Test Loss: 0.0181
Epoch [243/500], Train Loss: 0.0168
Epoch [243/500], Test Loss: 0.0181
Epoch [244/500], Train Loss: 0.0167
Epoch [244/500], Test Loss: 0.0181
Epoch [245/500], Train Loss: 0.0167
Epoch [245/500], Test Loss: 0.0181
Epoch [246/500], Train Loss: 0.0167
Epoch [246/500], Test Loss: 0.0181
Epoch [247/500], Train Loss: 0.0167
Epoch [247/500], Test Loss: 0.0180
Epoch [248/500], Train Loss: 0.0167
Epoch [248/500], Test Loss: 0.0181
Epoch [249/500], Train Loss: 0.0167
Epoch [249/500], Test Loss: 0.0181
Epoch 

Epoch [352/500], Train Loss: 0.0165
Epoch [352/500], Test Loss: 0.0179
Epoch [353/500], Train Loss: 0.0165
Epoch [353/500], Test Loss: 0.0179
Epoch [354/500], Train Loss: 0.0165
Epoch [354/500], Test Loss: 0.0178
Epoch [355/500], Train Loss: 0.0165
Epoch [355/500], Test Loss: 0.0178
Epoch [356/500], Train Loss: 0.0165
Epoch [356/500], Test Loss: 0.0179
Epoch [357/500], Train Loss: 0.0165
Epoch [357/500], Test Loss: 0.0179
Epoch [358/500], Train Loss: 0.0165
Epoch [358/500], Test Loss: 0.0178
Epoch [359/500], Train Loss: 0.0165
Epoch [359/500], Test Loss: 0.0179
Epoch [360/500], Train Loss: 0.0165
Epoch [360/500], Test Loss: 0.0178
Epoch [361/500], Train Loss: 0.0165
Epoch [361/500], Test Loss: 0.0179
Epoch [362/500], Train Loss: 0.0165
Epoch [362/500], Test Loss: 0.0179
Epoch [363/500], Train Loss: 0.0165
Epoch [363/500], Test Loss: 0.0178
Epoch [364/500], Train Loss: 0.0165
Epoch [364/500], Test Loss: 0.0178
Epoch [365/500], Train Loss: 0.0165
Epoch [365/500], Test Loss: 0.0178
Epoch 

Epoch [468/500], Train Loss: 0.0164
Epoch [468/500], Test Loss: 0.0177
Epoch [469/500], Train Loss: 0.0164
Epoch [469/500], Test Loss: 0.0177
Epoch [470/500], Train Loss: 0.0164
Epoch [470/500], Test Loss: 0.0177
Epoch [471/500], Train Loss: 0.0164
Epoch [471/500], Test Loss: 0.0177
Epoch [472/500], Train Loss: 0.0164
Epoch [472/500], Test Loss: 0.0177
Epoch [473/500], Train Loss: 0.0164
Epoch [473/500], Test Loss: 0.0177
Epoch [474/500], Train Loss: 0.0164
Epoch [474/500], Test Loss: 0.0177
Epoch [475/500], Train Loss: 0.0164
Epoch [475/500], Test Loss: 0.0177
Epoch [476/500], Train Loss: 0.0164
Epoch [476/500], Test Loss: 0.0177
Epoch [477/500], Train Loss: 0.0164
Epoch [477/500], Test Loss: 0.0177
Epoch [478/500], Train Loss: 0.0164
Epoch [478/500], Test Loss: 0.0177
Epoch [479/500], Train Loss: 0.0164
Epoch [479/500], Test Loss: 0.0178
Epoch [480/500], Train Loss: 0.0164
Epoch [480/500], Test Loss: 0.0177
Epoch [481/500], Train Loss: 0.0164
Epoch [481/500], Test Loss: 0.0177
Epoch 

In [14]:

state = torch.load("E:\PG_HP_FINAL\CAE Learn Map difference\SiameseCAE_Model(64)_MapLearning\epoch_500.pth")
model = SiameseCAE()
model.load_state_dict(state)
model.eval()

SiameseCAE(
  (cae): CAE(
    (encoder): Encoder(
      (encoder_): Sequential(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (10): ReLU()
        (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (decoder): Decoder(
      (decoder_): Sequential(
        (

In [18]:
unloader = transforms.ToPILImage()
def tensor_to_PIL(tensor):
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = unloader(image)
    return image

In [21]:
with torch.no_grad():
    for idx, (inputs,targets) in enumerate(train_loader):
        encoded_x1, encoded_x2, decoded_x1, decoded_x2 = model(inputs, targets)
        for index,xxx in enumerate(decoded_x1):
            original_x = inputs[index]
            reconstr_x = xxx
            target_x = decoded_x2[index]

            image1 = tensor_to_PIL(original_x)
            image2 = tensor_to_PIL(reconstr_x)
            image3 =  tensor_to_PIL(target_x)
            path1 =  'img(ori)MSE/patch_{num}_original(train).png'.format(num=idx*256+index)
            path2 =  'img(recon)MSE/patch_{num}_reconstruct(train).png'.format(num=idx*256+index)
            path3 =  'img(target)MSE/patch_{num}_target(train).png'.format(num=idx*256+index)
            
            image1.save(path1)
            image2.save(path2)  
            image3.save(path3)
        break