In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
from patchify import patchify
import PIL
from PIL import Image
PIL.Image.MAX_IMAGE_PIXELS = 933120000
import os
import shutil
import random
%matplotlib inline
import torchvision
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
class PatchDataset(Dataset):
    def __init__(self,root,target, train=True, transforms=None):
        super(PatchDataset, self).__init__()

        self.image_path = [os.path.join(root, x) for x in os.listdir(root)]      
        self.ref_path = [os.path.join(target,x) for x in os.listdir(target)]
        
        if transform is not None:
            self.transform = transform

        if train:
            self.images = self.image_path[: int(.8 * len(self.image_path))]
            self.ref = self.ref_path[: int(.8 * len(self.image_path))]
        else:
            self.images = self.image_path[int(.8 * len(self.image_path)):]
            self.ref = self.ref_path[int(.8 * len(self.image_path)):]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        return self.transform(self.images[item]),self.transform(self.ref[item])  

In [3]:
transform = transforms.Compose([
    lambda x: Image.open(x).convert('RGB'),
    transforms.ToTensor(),
#     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.in1 = nn.InstanceNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.in2 = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        out += residual
        return out

class ImprovedUNetGenerator(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ImprovedUNetGenerator, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2),

            ResidualBlock(512, 512)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),

            nn.ConvTranspose2d(64, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Initialize the generator
# generator = ImprovedUNetGenerator(3, 3).to(device)


In [5]:
class Discriminator(nn.Module):
    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            self._conv_block(64, 128),
#             self._conv_block(128, 256),
#             self._conv_block(256, 512),
            nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=1)
        )
    def forward(self, x):
        return self.model(x)

    def _conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )


In [6]:
train_dataset = PatchDataset("OS_412(64)_SSIM_BlankRemove","OS_415(64)_SSIM_BlankRemove",train=True, transforms=transform)
test_dataset = PatchDataset("OS_412(64)_SSIM_BlankRemove","OS_415(64)_SSIM_BlankRemove",train=False, transforms=transform)

train_loader = DataLoader(train_dataset, batch_size=64*10, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64*10, shuffle=False)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
epochs = 1000
lr = 0.001
lr_D = 0.0002
beta1 = 0.5
lambda_L1 = 100
best_loss = 1e9
# Initialize models
# generator = UNetGenerator(3, 3).to(device)
generator = ImprovedUNetGenerator(3, 3).to(device)
discriminator = Discriminator(6).to(device)

# Define loss functions and optimizers
# criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()
# criterion_MSE = nn.MSELoss()
criterion_HUBER = torch.nn.HuberLoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0004, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_D, betas=(beta1, 0.999))

# Training loop
for epoch in range(epochs):
    for i, (image_1910, image_1970) in enumerate(train_loader):
        image_1910 = image_1910.to(device)
        image_1970 = image_1970.to(device)

        # Train Discriminator
        optimizer_D.zero_grad()
        fake_image_1970 = generator(image_1910)
        real_pair = torch.cat((image_1910, image_1970), dim=1)
        fake_pair = torch.cat((image_1910, fake_image_1970.detach()), dim=1)

        real_preds = discriminator(real_pair)
        fake_preds = discriminator(fake_pair)

        real_targets = torch.ones_like(real_preds).to(device)
        fake_targets = torch.zeros_like(fake_preds).to(device)

        loss_D_real = criterion_HUBER(real_preds, real_targets)
        loss_D_fake = criterion_HUBER(fake_preds, fake_targets)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        fake_preds = discriminator(fake_pair)

        loss_G_GAN = criterion_HUBER(fake_preds, real_targets)
        loss_G_L1 = criterion_L1(fake_image_1970, image_1970) * lambda_L1

        loss_G = loss_G_GAN + loss_G_L1
        loss_G.backward()
        optimizer_G.step()
    
    
    path = 'P2P_Model/generator_{i}.pth'.format(i=epoch)
    print(f'Epoch [{epoch+1}/{epochs}], Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}')
    torch.save(generator.state_dict(), path)


Epoch [1/1000], Loss D: 0.0155, Loss G: 20.0153
Epoch [2/1000], Loss D: 0.0079, Loss G: 16.1132
Epoch [3/1000], Loss D: 0.0136, Loss G: 13.2508
Epoch [4/1000], Loss D: 0.0102, Loss G: 12.0952
Epoch [5/1000], Loss D: 0.0148, Loss G: 11.6685
Epoch [6/1000], Loss D: 0.0273, Loss G: 10.9880
Epoch [7/1000], Loss D: 0.0043, Loss G: 10.6887
Epoch [8/1000], Loss D: 0.0033, Loss G: 10.6710
Epoch [9/1000], Loss D: 0.0031, Loss G: 10.4873
Epoch [10/1000], Loss D: 0.0029, Loss G: 10.4590
Epoch [11/1000], Loss D: 0.0026, Loss G: 10.3209
Epoch [12/1000], Loss D: 0.0024, Loss G: 10.1614
Epoch [13/1000], Loss D: 0.0021, Loss G: 10.0622
Epoch [14/1000], Loss D: 0.0023, Loss G: 10.0338
Epoch [15/1000], Loss D: 0.0019, Loss G: 10.0365
Epoch [16/1000], Loss D: 0.0210, Loss G: 9.7001
Epoch [17/1000], Loss D: 0.0027, Loss G: 9.6370
Epoch [18/1000], Loss D: 0.0023, Loss G: 9.5540
Epoch [19/1000], Loss D: 0.0019, Loss G: 9.4248
Epoch [20/1000], Loss D: 0.0019, Loss G: 9.3690
Epoch [21/1000], Loss D: 0.0017, L

Epoch [171/1000], Loss D: 0.0044, Loss G: 4.5616
Epoch [172/1000], Loss D: 0.0066, Loss G: 4.4769
Epoch [173/1000], Loss D: 0.0066, Loss G: 4.4739
Epoch [174/1000], Loss D: 0.0054, Loss G: 4.5316
Epoch [175/1000], Loss D: 0.0057, Loss G: 4.4683
Epoch [176/1000], Loss D: 0.0065, Loss G: 4.3601
Epoch [177/1000], Loss D: 0.0044, Loss G: 4.3161
Epoch [178/1000], Loss D: 0.0050, Loss G: 4.3468
Epoch [179/1000], Loss D: 0.0038, Loss G: 4.4147
Epoch [180/1000], Loss D: 0.0036, Loss G: 4.5568
Epoch [181/1000], Loss D: 0.0048, Loss G: 4.3934
Epoch [182/1000], Loss D: 0.0072, Loss G: 4.3001
Epoch [183/1000], Loss D: 0.0057, Loss G: 4.3816
Epoch [184/1000], Loss D: 0.0051, Loss G: 4.2858
Epoch [185/1000], Loss D: 0.0066, Loss G: 4.2558
Epoch [186/1000], Loss D: 0.0060, Loss G: 4.2050
Epoch [187/1000], Loss D: 0.0045, Loss G: 4.3726
Epoch [188/1000], Loss D: 0.0046, Loss G: 4.5013
Epoch [189/1000], Loss D: 0.0047, Loss G: 4.2891
Epoch [190/1000], Loss D: 0.0056, Loss G: 4.2049
Epoch [191/1000], Lo

Epoch [339/1000], Loss D: 0.0165, Loss G: 3.4633
Epoch [340/1000], Loss D: 0.0217, Loss G: 3.4646
Epoch [341/1000], Loss D: 0.0215, Loss G: 3.4779
Epoch [342/1000], Loss D: 0.0206, Loss G: 3.4094
Epoch [343/1000], Loss D: 0.0199, Loss G: 3.4047
Epoch [344/1000], Loss D: 0.0161, Loss G: 3.4564
Epoch [345/1000], Loss D: 0.0144, Loss G: 3.4485
Epoch [346/1000], Loss D: 0.0107, Loss G: 3.5668
Epoch [347/1000], Loss D: 0.0111, Loss G: 3.5344
Epoch [348/1000], Loss D: 0.0094, Loss G: 3.5389
Epoch [349/1000], Loss D: 0.0101, Loss G: 3.5610
Epoch [350/1000], Loss D: 0.0145, Loss G: 3.4651
Epoch [351/1000], Loss D: 0.0125, Loss G: 3.4822
Epoch [352/1000], Loss D: 0.0102, Loss G: 3.5382
Epoch [353/1000], Loss D: 0.0141, Loss G: 3.5311
Epoch [354/1000], Loss D: 0.0164, Loss G: 3.4716
Epoch [355/1000], Loss D: 0.0150, Loss G: 3.4994
Epoch [356/1000], Loss D: 0.0130, Loss G: 3.4998
Epoch [357/1000], Loss D: 0.0144, Loss G: 3.4020
Epoch [358/1000], Loss D: 0.0151, Loss G: 3.3951
Epoch [359/1000], Lo

Epoch [507/1000], Loss D: 0.0370, Loss G: 3.0155
Epoch [508/1000], Loss D: 0.0392, Loss G: 3.0132
Epoch [509/1000], Loss D: 0.0404, Loss G: 2.9669
Epoch [510/1000], Loss D: 0.0311, Loss G: 3.0679
Epoch [511/1000], Loss D: 0.0332, Loss G: 3.0194
Epoch [512/1000], Loss D: 0.0371, Loss G: 2.9854
Epoch [513/1000], Loss D: 0.0370, Loss G: 2.9926
Epoch [514/1000], Loss D: 0.0433, Loss G: 2.9826
Epoch [515/1000], Loss D: 0.0446, Loss G: 2.9483
Epoch [516/1000], Loss D: 0.0467, Loss G: 2.9349
Epoch [517/1000], Loss D: 0.0357, Loss G: 2.9367
Epoch [518/1000], Loss D: 0.0387, Loss G: 3.0021
Epoch [519/1000], Loss D: 0.0250, Loss G: 3.1175
Epoch [520/1000], Loss D: 0.0361, Loss G: 3.0831
Epoch [521/1000], Loss D: 0.0483, Loss G: 2.9270
Epoch [522/1000], Loss D: 0.0582, Loss G: 2.9049
Epoch [523/1000], Loss D: 0.0536, Loss G: 2.9308
Epoch [524/1000], Loss D: 0.0699, Loss G: 2.9273
Epoch [525/1000], Loss D: 0.0702, Loss G: 2.8636
Epoch [526/1000], Loss D: 0.0552, Loss G: 2.8960
Epoch [527/1000], Lo

Epoch [675/1000], Loss D: 0.0295, Loss G: 2.5615
Epoch [676/1000], Loss D: 0.0254, Loss G: 2.6084
Epoch [677/1000], Loss D: 0.0274, Loss G: 2.6188
Epoch [678/1000], Loss D: 0.0389, Loss G: 2.5478
Epoch [679/1000], Loss D: 0.0348, Loss G: 2.5660
Epoch [680/1000], Loss D: 0.0213, Loss G: 2.6957
Epoch [681/1000], Loss D: 0.0249, Loss G: 2.6021
Epoch [682/1000], Loss D: 0.0234, Loss G: 2.5663
Epoch [683/1000], Loss D: 0.0154, Loss G: 2.5869
Epoch [684/1000], Loss D: 0.0142, Loss G: 2.6534
Epoch [685/1000], Loss D: 0.0183, Loss G: 2.5813
Epoch [686/1000], Loss D: 0.0205, Loss G: 2.6301
Epoch [687/1000], Loss D: 0.0267, Loss G: 2.6503
Epoch [688/1000], Loss D: 0.0317, Loss G: 2.6970
Epoch [689/1000], Loss D: 0.0372, Loss G: 2.6752
Epoch [690/1000], Loss D: 0.0267, Loss G: 2.6157
Epoch [691/1000], Loss D: 0.0219, Loss G: 2.6286
Epoch [692/1000], Loss D: 0.0187, Loss G: 2.6697
Epoch [693/1000], Loss D: 0.0228, Loss G: 2.5672
Epoch [694/1000], Loss D: 0.0251, Loss G: 2.6062
Epoch [695/1000], Lo

Epoch [843/1000], Loss D: 0.0072, Loss G: 2.5297
Epoch [844/1000], Loss D: 0.0050, Loss G: 2.5110
Epoch [845/1000], Loss D: 0.0055, Loss G: 2.4943
Epoch [846/1000], Loss D: 0.0072, Loss G: 2.4793
Epoch [847/1000], Loss D: 0.0071, Loss G: 2.5265
Epoch [848/1000], Loss D: 0.0055, Loss G: 2.4901
Epoch [849/1000], Loss D: 0.0056, Loss G: 2.5589
Epoch [850/1000], Loss D: 0.0076, Loss G: 2.5154
Epoch [851/1000], Loss D: 0.0062, Loss G: 2.4807
Epoch [852/1000], Loss D: 0.0056, Loss G: 2.4600
Epoch [853/1000], Loss D: 0.0053, Loss G: 2.4958
Epoch [854/1000], Loss D: 0.0046, Loss G: 2.5222
Epoch [855/1000], Loss D: 0.0050, Loss G: 2.5114
Epoch [856/1000], Loss D: 0.0057, Loss G: 2.5152
Epoch [857/1000], Loss D: 0.0064, Loss G: 2.5578
Epoch [858/1000], Loss D: 0.0078, Loss G: 2.5607
Epoch [859/1000], Loss D: 0.0069, Loss G: 2.5159
Epoch [860/1000], Loss D: 0.0058, Loss G: 2.5109
Epoch [861/1000], Loss D: 0.0058, Loss G: 2.5647
Epoch [862/1000], Loss D: 0.0065, Loss G: 2.4967
Epoch [863/1000], Lo

In [None]:
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Resize, Compose
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
epochs = 100
lr = 0.0002
beta1 = 0.5
lambda_L1 = 100
best_loss = 1e9
# Initialize models
generator = ImprovedUNetGenerator(3, 3).to(device)
discriminator = Discriminator(6).to(device)

# Define loss functions and optimizers
criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()

optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Create a new generator
eval_generator = ImprovedUNetGenerator(3, 3).to(device)

# Load the saved weights
eval_generator.load_state_dict(torch.load('model_save4\generator_999.pth'))
eval_generator.eval()  # Set the model to evaluation mode


# Define the test dataset and the DataLoader
test_transforms = Compose([
    Resize((64, 64)),
    ToTensor()
])




In [8]:
unloader = transforms.ToPILImage()
def tensor_to_PIL(tensor):
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = unloader(image)
    return image

In [14]:
import torchvision.utils as vutils
import os
import numpy as np
# from skimage.metrics import peak_signal_noise_ratio, structural_similarity

output_dir = "generated_images(train)"
os.makedirs(output_dir, exist_ok=True)

output_dir2 = "input_images(train)"
os.makedirs(output_dir2, exist_ok=True)

output_dir3 = "target_images(train)"
os.makedirs(output_dir3, exist_ok=True)

psnrs = []
ssims = []

with torch.no_grad():
    for i, (image_1910, image_1970) in enumerate(train_loader):
#         print(image_1910.shape)
#         print(image_1970.shape)
  
        image_1910 = image_1910.to(device)
        image_1970 = image_1970.to(device)
        
        # Generate the output image
        generated_image = generator(image_1910)
#         print(generated_image.shape)
        for index in range(len(image_1910)):
            inp = image_1910[index]
            tar = image_1970[index]
            out = generated_image[index]
                        
            image1 = tensor_to_PIL(inp)
            image2 = tensor_to_PIL(tar)
            image3 = tensor_to_PIL(out)
            
            path1 = 'input_images(train)/patch_{num}(train).png'.format(num=index)
            path2 = 'target_images(train)/patch_{num}(train).png'.format(num=index)
            path3 = 'generated_images(train)/patch_{num}(train).png'.format(num=index)
            
            image1.save(path1)
            image2.save(path2)
            image3.save(path3)
#         # Save the generated image
#         vutils.save_image(generated_image, f"{output_dir}/generated_{i}.tif", normalize=True)
# #         vutils.save_image(generated_image, f"{output_dir2}/input_{i}.tif", normalize=True)
# #         vutils.save_image(generated_image, f"{output_dir3}/target_{i}.tif", normalize=True)
#         # Calculate PSNR and SSIM
#         generated_image_np = generated_image.squeeze().permute(1, 2, 0).cpu().numpy()
#         image_1970_np = image_1970.squeeze().permute(1, 2, 0).cpu().numpy()
        
#         psnr = peak_signal_noise_ratio(image_1970_np, generated_image_np)
#         ssim = structural_similarity(image_1970_np, generated_image_np, multichannel=True)
        
#         psnrs.append(psnr)
#         ssims.append(ssim)
        break
# print(f"Average PSNR: {np.mean(psnrs):.2f}")
# print(f"Average SSIM: {np.mean(ssims):.2f}")


In [None]:
with torch.no_grad():
    for idx, (inputs,targets) in enumerate(train_loader):
#         inputs = inputs.to(device)
#         targets = targets.to(device)
        input_recon, target_recon, diff = model(inputs, targets)
        for index in range(len(input_recon)):
            inp = inputs[index]
            tar = target_recon[index]
            out = input_recon[index]
            
            image1 = tensor_to_PIL(inp)
            image2 = tensor_to_PIL(tar)
            image3 = tensor_to_PIL(out)

#             path1 = 'img_input/patch_{num}.tif'.format(num=index)
#             path2 = 'img_target/patch_{num}.tif'.format(num=index)
#             path3 = 'img_output/patch_{num}.tif'.format(num=index)
            
            path1 = 'img_input(train)/patch_{num}.tif'.format(num=index)
            path2 = 'img_target(train)/patch_{num}.tif'.format(num=index)
            path3 = 'img_output(train)/patch_{num}.tif'.format(num=index)
            
            image1.save(path1)
            image2.save(path2)
            image3.save(path3)
            
        break