In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision.models import resnet18, vgg19
from torchvision.utils import save_image
from sklearn.model_selection import train_test_split
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os

In [7]:
import numpy as np
import cv2

def dct2(block):
    return cv2.dct(block.astype(np.float32))

def idct2(block):
    return cv2.idct(block.astype(np.float32))

def homomorphic_transform(image):
    return np.log1p(image)

def inverse_homomorphic_transform(image):
    return np.expm1(image)

In [8]:
import os
from PIL import Image

# Directories containing SAR and Optical images
sar_image_dir = '/home/hehe/SIH/RP/images/sar_train'
optical_image_dir = '/home/hehe/SIH/RP/images/oi_train'

# Get lists of file paths in the directories
sar_image_paths = [os.path.join(sar_image_dir, f) for f in os.listdir(sar_image_dir) if os.path.isfile(os.path.join(sar_image_dir, f))]
optical_image_paths = [os.path.join(optical_image_dir, f) for f in os.listdir(optical_image_dir) if os.path.isfile(os.path.join(optical_image_dir, f))]

# Print paths for validation
print("SAR image paths:", sar_image_paths)
print("Optical image paths:", optical_image_paths)

# Process images
for sar_image_path, optical_image_path in zip(sar_image_paths, optical_image_paths):
    try:
        # Open images
        sar_image = Image.open(sar_image_path)
        optical_image = Image.open(optical_image_path)
        
        # Example processing (print image sizes)
        print(f"Processing {sar_image_path} and {optical_image_path}")
        print(f"SAR Image Size: {sar_image.size}, Optical Image Size: {optical_image.size}")
    except IsADirectoryError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"Unexpected error for {sar_image_path} or {optical_image_path}: {e}")


SAR image paths: ['/home/hehe/SIH/RP/images/sar_train/SAR1_1_2.png', '/home/hehe/SIH/RP/images/sar_train/SAR2_0_0.png', '/home/hehe/SIH/RP/images/sar_train/SAR4_0_1.png', '/home/hehe/SIH/RP/images/sar_train/SAR3_2_2.png', '/home/hehe/SIH/RP/images/sar_train/SAR4_1_1.png', '/home/hehe/SIH/RP/images/sar_train/SAR2_0_4.png', '/home/hehe/SIH/RP/images/sar_train/SAR1_1_5.png', '/home/hehe/SIH/RP/images/sar_train/SAR3_4_4.png', '/home/hehe/SIH/RP/images/sar_train/SAR1_0_1.png', '/home/hehe/SIH/RP/images/sar_train/SAR5_4_2.png', '/home/hehe/SIH/RP/images/sar_train/SAR2_5_1.png', '/home/hehe/SIH/RP/images/sar_train/SAR3_2_0.png', '/home/hehe/SIH/RP/images/sar_train/SAR2_8_1.png', '/home/hehe/SIH/RP/images/sar_train/SAR4_7_1.png', '/home/hehe/SIH/RP/images/sar_train/SAR4_5_5.png', '/home/hehe/SIH/RP/images/sar_train/SAR2_0_1.png', '/home/hehe/SIH/RP/images/sar_train/SAR6_0_0.png', '/home/hehe/SIH/RP/images/sar_train/SAR4_2_3.png', '/home/hehe/SIH/RP/images/sar_train/SAR4_7_5.png', '/home/hehe/S

In [9]:
BATCH_SIZE = 32  # Define the batch size

class SAROpticalDataset(Dataset):
    def __init__(self, sar_images, optical_images, transform=None):
        self.sar_images = sar_images
        self.optical_images = optical_images
        self.transform = transform

    def __len__(self):
        return len(self.sar_images)

    def __getitem__(self, idx):
        sar_image = Image.open(self.sar_images[idx]).convert("L")
        optical_image = Image.open(self.optical_images[idx]).convert("RGB")
        if self.transform:
            sar_image = self.transform(sar_image)
            optical_image = self.transform(optical_image)
        return sar_image, optical_image

transform = Compose([
    ToTensor(),
    Normalize(mean=[0.5], std=[0.5])
])

sar_train, sar_test, optical_train, optical_test = train_test_split(
    sar_image_paths, optical_image_paths, test_size=0.2, random_state=42
)

train_dataset = SAROpticalDataset(sar_train, optical_train, transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [10]:
class DCTRB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DCTRB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.dct = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.dct(out)
        out += residual
        return out

In [11]:
class LightASPP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(LightASPP, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=6, dilation=6)
        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=12, dilation=12)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        x4 = self.global_pool(x)
        x4 = self.conv4(x4)
        x4 = nn.functional.interpolate(x4, size=x1.size()[2:], mode='bilinear', align_corners=True)
        out = x1 + x2 + x3 + x4
        out = self.relu(out)
        return out

In [12]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            DCTRB(64, 64),
            LightASPP(64, 64)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [13]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [16]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

adversarial_loss = nn.BCELoss()
pixelwise_loss = nn.SmoothL1Loss()

class VGGFeatureExtractor(nn.Module):
    def __init__(self):
        super(VGGFeatureExtractor, self).__init__()
        vgg = vgg19(pretrained=True).features
        self.slice1 = nn.Sequential(*list(vgg.children())[:4])
        self.slice2 = nn.Sequential(*list(vgg.children())[4:9])
        self.slice3 = nn.Sequential(*list(vgg.children())[9:16])

    def forward(self, x):
        h = self.slice1(x)
        h_relu2_2 = self.slice2(h)
        h_relu3_2 = self.slice3(h_relu2_2)
        return h_relu2_2, h_relu3_2

vgg_extractor = VGGFeatureExtractor().to(DEVICE)

def style_loss(generated, target):
    gen_features = vgg_extractor(generated)
    target_features = vgg_extractor(target)
    loss = 0
    for gf, tf in zip(gen_features, target_features):
        loss += torch.mean((gf - tf) ** 2)
    return loss

In [18]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LEARNING_RATE = 0.0002  # Define the learning rate

generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

optimizer_G = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
optimizer_D = optim.SGD(discriminator.parameters(), lr=LEARNING_RATE)

EPOCHS = 100  # Updated number of epochs

for epoch in range(EPOCHS):
    for i, (sar, optical) in enumerate(train_loader):
        sar, optical = sar.to(DEVICE), optical.to(DEVICE)

        # Train Generator
        optimizer_G.zero_grad()
        gen_optical = generator(sar)
        valid = torch.ones((sar.size(0), 1, 29, 29)).to(DEVICE)
        g_loss = adversarial_loss(discriminator(gen_optical), valid) + pixelwise_loss(gen_optical, optical) + style_loss(gen_optical, optical)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(optical), valid)
        fake_loss = adversarial_loss(discriminator(gen_optical.detach()), torch.zeros_like(valid))
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{EPOCHS}] Batch [{i}/{len(train_loader)}] G Loss: {g_loss.item()} D Loss: {d_loss.item()}")

Epoch [0/100] Batch [0/9] G Loss: 7.144992828369141 D Loss: 0.6931685209274292
Epoch [1/100] Batch [0/9] G Loss: 6.710237979888916 D Loss: 0.6999185085296631


KeyboardInterrupt: 