In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
from patchify import patchify
import PIL
from PIL import Image
PIL.Image.MAX_IMAGE_PIXELS = 933120000
import os
import shutil
import random
%matplotlib inline
import torchvision
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
class PatchDataset(Dataset):
    def __init__(self,root,target, train=True, transforms=None):
        super(PatchDataset, self).__init__()

        self.image_path = [os.path.join(root, x) for x in os.listdir(root)]      
        self.ref_path = [os.path.join(target,x) for x in os.listdir(target)]
        
        if transform is not None:
            self.transform = transform

        if train:
            self.images = self.image_path[: int(.8 * len(self.image_path))]
            self.ref = self.ref_path[: int(.8 * len(self.image_path))]
        else:
            self.images = self.image_path[int(.8 * len(self.image_path)):]
            self.ref = self.ref_path[int(.8 * len(self.image_path)):]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        return self.transform(self.images[item]),self.transform(self.ref[item])  

In [3]:
transform = transforms.Compose([
    lambda x: Image.open(x).convert('RGB'),
    transforms.ToTensor(),
#     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [4]:
class UNetGenerator(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNetGenerator, self).__init__()

        # Encoding layers
        self.encoder = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True)
            ),
            self._conv_block(64, 128),
            self._conv_block(128, 256),
            self._conv_block(256, 512)
        ])

        # Decoding layers
        self.decoder = nn.ModuleList([
            self._deconv_block(512, 256),
            self._deconv_block(512, 128),
            self._deconv_block(256, 64),
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        ])

    def forward(self, x):
        skips = []
        for layer in self.encoder:
            x = layer(x)
            skips.append(x)

        skips = list(reversed(skips[:-1]))
        for idx, (skip, layer) in enumerate(zip(skips, self.decoder[:-2])):
            x = layer(x)
            x = torch.cat((x, skip), dim=1)

        x = self.decoder[-2](x)
        x = self.decoder[-1](x)
        return x

    def _conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def _deconv_block(self, in_channels, out_channels, dropout=0.0):
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        return nn.Sequential(*layers)


In [5]:
class Discriminator(nn.Module):
    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            self._conv_block(64, 128),
            self._conv_block(128, 256),
            self._conv_block(256, 512),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )
    def forward(self, x):
        return self.model(x)

    def _conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )


In [6]:
train_dataset = PatchDataset("img415_patches(64)","img418_patches(64)",train=True, transforms=transform)
test_dataset = PatchDataset("img415_patches(64)","img418_patches(64)",train=False, transforms=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
epochs = 100
lr = 0.0002
beta1 = 0.5
lambda_L1 = 100
best_loss = 1e9
# Initialize models
generator = UNetGenerator(3, 3).to(device)
discriminator = Discriminator(6).to(device)

# Define loss functions and optimizers
criterion_GAN = nn.MSELoss()
criterion_L1 = nn.L1Loss()
criterion_GAN = torch.nn.HuberLoss
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))


In [10]:
for epoch in range(epochs):
    for i, (image_1910, image_1970) in enumerate(train_loader):
        image_1910 = image_1910.to(device)
        image_1970 = image_1970.to(device)

        # Train Discriminator
        optimizer_D.zero_grad()
        fake_image_1970 = generator(image_1910)
        real_pair = torch.cat((image_1910, image_1970), dim=1)
        fake_pair = torch.cat((image_1910, fake_image_1970.detach()), dim=1)

        real_preds = discriminator(real_pair)
        fake_preds = discriminator(fake_pair)

        real_targets = torch.ones_like(real_preds).to(device)
        fake_targets = torch.zeros_like(fake_preds).to(device)

        loss_D_real = criterion_GAN(real_preds, real_targets)
        loss_D_fake = criterion_GAN(fake_preds, fake_targets)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        fake_preds = discriminator(fake_pair)

        loss_G_GAN = criterion_GAN(fake_preds, real_targets)
        loss_G_L1 = criterion_L1(fake_image_1970, image_1970) * lambda_L1

        loss_G = loss_G_GAN + loss_G_L1
        loss_G.backward()
        optimizer_G.step()
    
    
    path = 'model_P2P_SSR/generator_{i}.pth'.format(i=epoch)
    print(f'Epoch [{epoch+1}/{epochs}], Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}')
    torch.save(generator.state_dict(), path)

TypeError: unsupported operand type(s) for +: 'HuberLoss' and 'HuberLoss'

In [8]:
for epoch in range(epochs):
    for i, (image_1910, image_1970) in enumerate(train_loader):
        image_1910 = image_1910.to(device)
        image_1970 = image_1970.to(device)

        # Train Discriminator
        optimizer_D.zero_grad() 
        fake_image_1970 = generator(image_1910)
        real_pair = torch.cat((image_1910, image_1970), dim=1)
        fake_pair = torch.cat((image_1910, fake_image_1970.detach()), dim=1)

        real_preds = discriminator(real_pair)
        fake_preds = discriminator(fake_pair)

        real_targets = torch.ones_like(real_preds).to(device)
        fake_targets = torch.zeros_like(fake_preds).to(device)

        loss_D_real = criterion_GAN(real_preds, real_targets)
        loss_D_fake = criterion_GAN(fake_preds, fake_targets)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        fake_preds = discriminator(fake_pair)

        loss_G_GAN = criterion_GAN(fake_preds, real_targets)
        loss_G_L1 = criterion_L1(fake_image_1970, image_1970) * lambda_L1

        loss_G = loss_G_GAN + loss_G_L1
        loss_G.backward()
        optimizer_G.step()
    
    
    path = 'model_P2P_SSR/generator_{i}.pth'.format(i=epoch)
    print(f'Epoch [{epoch+1}/{epochs}], Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}')
    torch.save(generator.state_dict(), path)

Epoch [1/100], Loss D: 0.3179, Loss G: 3.6422
Epoch [2/100], Loss D: 0.2473, Loss G: 5.2486
Epoch [3/100], Loss D: 0.1846, Loss G: 2.8642
Epoch [4/100], Loss D: 0.1610, Loss G: 2.8926
Epoch [5/100], Loss D: 0.0552, Loss G: 3.2287
Epoch [6/100], Loss D: 0.2171, Loss G: 2.6547
Epoch [7/100], Loss D: 0.1185, Loss G: 2.6292
Epoch [8/100], Loss D: 0.0427, Loss G: 3.1130
Epoch [9/100], Loss D: 0.0196, Loss G: 3.1847
Epoch [10/100], Loss D: 0.0197, Loss G: 2.8270
Epoch [11/100], Loss D: 0.0713, Loss G: 2.6242
Epoch [12/100], Loss D: 0.0863, Loss G: 2.1308
Epoch [13/100], Loss D: 0.0562, Loss G: 3.2406
Epoch [14/100], Loss D: 0.0086, Loss G: 2.6976
Epoch [15/100], Loss D: 0.0034, Loss G: 2.7385
Epoch [16/100], Loss D: 0.0061, Loss G: 2.6612
Epoch [17/100], Loss D: 0.2217, Loss G: 1.7533
Epoch [18/100], Loss D: 0.0822, Loss G: 1.7919
Epoch [19/100], Loss D: 0.0465, Loss G: 2.8218
Epoch [20/100], Loss D: 0.0459, Loss G: 1.9137
Epoch [21/100], Loss D: 0.0772, Loss G: 2.9829
Epoch [22/100], Loss D

In [12]:
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Resize, Compose

# Create a new generator
eval_generator = UNetGenerator(3, 3).to(device)

# Load the saved weights
eval_generator.load_state_dict(torch.load('model_P2P_SSR\generator_30.pth'))
eval_generator.eval()  # Set the model to evaluation mode


# Define the test dataset and the DataLoader
test_transforms = Compose([
    Resize((64, 64)),
    ToTensor()
])




In [None]:
import torchvision.utils as vutils
import os
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

output_dir = "generated_images(train)"
os.makedirs(output_dir, exist_ok=True)

output_dir2 = "input_images(train)"
os.makedirs(output_dir2, exist_ok=True)

output_dir3 = "target_images(train)"
os.makedirs(output_dir3, exist_ok=True)

psnrs = []
ssims = []

with torch.no_grad():
    for i, (image_1910, image_1970) in enumerate(train_loader):
#         print(image_1910.shape)
#         print(image_1970.shape)
  
        image_1910 = image_1910.to(device)
        image_1970 = image_1970.to(device)
        
        # Generate the output image
        generated_image = eval_generator(image_1910)
#         print(generated_image.shape)
        for index in range(len(image_1910)):
            inp = image_1910[index]
            tar = image_1970[index]
            out = generated_image[index]
            print(inp.shape)
            print(tar.shape)
            print(out.shape)
            
            
            image1 = tensor_to_PIL(inp)
            image2 = tensor_to_PIL(tar)
            image3 = tensor_to_PIL(out)
            
            path1 = 'input_images(train)/patch_{num}.tif'.format(num=index)
            path2 = 'target_images(train)/patch_{num}.tif'.format(num=index)
            path3 = 'generated_images(train)/patch_{num}.tif'.format(num=index)
            
            image1.save(path1)
            image2.save(path2)
            image3.save(path3)
        break