In [2]:
!unzip train3.zip -d train3

Archive:  train3.zip
   creating: train3/train3/
  inflating: train3/train3/.DS_Store  
  inflating: train3/__MACOSX/train3/._.DS_Store  
   creating: train3/train3/input/
   creating: train3/train3/output/
  inflating: train3/train3/input/.DS_Store  
  inflating: train3/__MACOSX/train3/input/._.DS_Store  
  inflating: train3/train3/input/s1.jpg  
  inflating: train3/__MACOSX/train3/input/._s1.jpg  
  inflating: train3/train3/input/s3.jpg  
  inflating: train3/__MACOSX/train3/input/._s3.jpg  
  inflating: train3/train3/input/s2.jpg  
  inflating: train3/__MACOSX/train3/input/._s2.jpg  
  inflating: train3/train3/output/.DS_Store  
  inflating: train3/__MACOSX/train3/output/._.DS_Store  
  inflating: train3/train3/output/o2.jpg  
  inflating: train3/__MACOSX/train3/output/._o2.jpg  
  inflating: train3/train3/output/o3.jpg  
  inflating: train3/__MACOSX/train3/output/._o3.jpg  
  inflating: train3/train3/output/o1.jpg  
  inflating: train3/__MACOSX/train3/output/._o1.jpg  


In [1]:
pip install piq


Collecting piq
  Downloading piq-0.8.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchvision>=0.10.0->piq)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
C

In [3]:
import torch
import torchvision

In [4]:
from torch.utils.data import Dataset
import torch
from PIL import Image
import os
import torchvision.transforms as T

class PairedSAROpticalDataset(Dataset):
    def __init__(self, input_dir, output_dir):
        self.input_dir = input_dir
        self.output_dir = output_dir

        valid_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

        self.input_files = sorted([
            f for f in os.listdir(input_dir)
            if f.lower().endswith(valid_exts)
        ])

        self.output_files = sorted([
            f for f in os.listdir(output_dir)
            if f.lower().endswith(valid_exts)
        ])

        assert len(self.input_files) == len(self.output_files), "Mismatch in number of SAR and optical images."

        self.transform_sar = T.Compose([
            T.Resize((256, 256)),
            T.ToTensor(),  # Shape: (1, H, W)
        ])

        self.transform_opt = T.Compose([
            T.Resize((256, 256)),
            T.ToTensor(),  # Shape: (3, H, W)
        ])

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        sar_path = os.path.join(self.input_dir, self.input_files[idx])
        opt_path = os.path.join(self.output_dir, self.output_files[idx])

        sar = Image.open(sar_path).convert('L')  # Grayscale
        opt = Image.open(opt_path).convert('RGB')  # Color

        sar_tensor = self.transform_sar(sar)      # (1, H, W)
        opt_tensor = self.transform_opt(opt)      # (3, H, W)

        return sar_tensor, opt_tensor


In [5]:
import torch
import torch.nn as nn

class EColorNetANN(nn.Module):
    def __init__(self, img_size=256):
        super(EColorNetANN, self).__init__()

        self.img_size = img_size
        self.input_dim = img_size * img_size       # 1-channel SAR image flattened
        self.output_dim = self.input_dim * 3       # RGB output flattened

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 2048),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(2048, 4096),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(4096, self.output_dim),
            nn.Sigmoid()  # Output pixel values in [0, 1]
        )

    def forward(self, x):  # x: (B, 1, H, W)
        B, _, H, W = x.shape
        x_flat = x.view(B, -1)                  # (B, H*W)
        out = self.fc(x_flat)                   # (B, H*W*3)
        out = out.view(B, 3, H, W)              # (B, 3, H, W)
        r = out[:, 0:1, :, :]
        g = out[:, 1:2, :, :]
        b = out[:, 2:3, :, :]
        return r, g, b


In [6]:
import torch
import torch.nn as nn
from piq import ssim

class ChannelwiseCombinedLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        self.mse = nn.MSELoss()
        self.alpha = alpha

    def forward(self, preds, targets):
        # preds, targets: (B,3,H,W)
        losses = []
        for c in range(3):
            pred = preds[:, c:c+1, :, :]
            target = targets[:, c:c+1, :, :]
            mse_loss = self.mse(pred, target)
            ssim_score = ssim(pred, target, data_range=1.0)
            ssim_loss = 1 - ssim_score
            combined = self.alpha * mse_loss + (1 - self.alpha) * ssim_loss
            losses.append(combined)

        return sum(losses)


In [7]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

# ==== Paths ====
input_dir = '/content/train3/train3/input'
output_dir = '/content/train3/train3/output'
save_dir = './outputs'
os.makedirs(save_dir, exist_ok=True)

# ==== Dataset and Dataloader ====
dataset = PairedSAROpticalDataset(input_dir, output_dir)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

# ==== Model, Loss Function, Optimizer ====
model = EColorNetANN().cuda()
criterion = ChannelwiseCombinedLoss(alpha=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# ==== Training Loop ====
epochs = 300
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0

    for i, (sar, rgb_gt) in enumerate(dataloader):
        sar = sar.cuda()
        rgb_gt = rgb_gt.cuda()

        optimizer.zero_grad()

        r_pred, g_pred, b_pred = model(sar)  # Each (B,1,H,W)
        rgb_pred = torch.cat([r_pred, g_pred, b_pred], dim=1)  # (B,3,H,W)

        loss = criterion(rgb_pred, rgb_gt)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Visualize only the first batch of the epoch
        if i == 0:
            model.eval()
            with torch.no_grad():
                rgb_fake = torch.cat([r_pred, g_pred, b_pred], dim=1)
                rgb_gt = torch.cat([r_gt, g_gt, b_gt], dim=1)
                sar_vis = sar.repeat(1, 3, 1, 1)

                combined = torch.cat([sar_vis, rgb_fake, rgb_gt], dim=0)
                grid = make_grid(combined, nrow=sar.size(0), normalize=True)
                save_image(grid, os.path.join(save_dir, f'epoch_{epoch+1:03d}.png'))
            model.train()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs} - Avg Loss: {avg_loss:.4f}")

    # Save model checkpoint
torch.save(model.state_dict(), os.path.join(save_dir, f'enet_epoch_{epoch+1:03d}.pth'))


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 462.12 MiB is free. Process 5019 has 14.29 GiB memory in use. Of the allocated memory 14.15 GiB is allocated by PyTorch, and 6.16 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)