In [3]:
import torch
from torch.utils.data import DataLoader
from generator import Generator
from torchvision import transforms
from PIL import Image
import torchvision.transforms.functional as TF
import os

# Preprocessing for generator inputs: PIL image -> float tensor in [0, 1]
# (ToTensor), then each of the 3 RGB channels mapped to [-1, 1] via
# (x - 0.5) / 0.5 (Normalize with mean = std = 0.5).
_CHANNEL_MEAN = [0.5] * 3
_CHANNEL_STD = [0.5] * 3
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

class TargetDataset(torch.utils.data.Dataset):
    """Unlabelled image dataset over the files in a single directory.

    Each item is an ``(image, filename)`` pair. ``image`` is the file decoded
    as 3-channel RGB and then passed through ``transforms`` if one was given.
    ``filename`` is the bare file name (no directory), handy for naming
    outputs.

    Args:
        image_dir: Directory to scan (non-recursively) for images.
        transforms: Optional callable applied to each PIL image. Inside
            ``__init__`` this name shadows the ``torchvision.transforms``
            module; it is kept as-is for backward compatibility.
        extensions: File-name suffix (str) or tuple of suffixes to include.
            Matching is case-sensitive, as with ``str.endswith``. Defaults to
            ``('.tif',)``, the original behavior.
    """

    def __init__(self, image_dir, transforms=None, extensions=('.tif',)):
        self.image_dir = image_dir
        # Sorted so item order (and therefore output order) is deterministic.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith(extensions)
        )
        self.transforms = transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, name)
        # Close the source file as soon as the RGB copy exists. Pillow may keep
        # multi-frame-capable formats such as TIFF open after decoding, so a
        # bare Image.open would hold one handle per sample until GC.
        with Image.open(img_path) as src:
            image = src.convert("RGB")
        if self.transforms:
            image = self.transforms(image)
        return image, name

# Set device. Prefer the second GPU ("cuda:1") as originally intended, but fall
# back to the default GPU on single-GPU hosts (where "cuda:1" does not exist)
# and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load generator weights onto the chosen device and switch to eval mode.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
# Since this is a plain state_dict, weights_only=True (torch >= 1.13) would be
# a safer choice if the installed torch supports it.
G = Generator(img_channels=3).to(device)
G.load_state_dict(torch.load('/home/umang.shikarvar/Style_GAN/generator/G_epoch_60.pth', map_location=device))
G.eval()

# Dataset of target images, preprocessed with `transform`
target = TargetDataset(
    image_dir='/home/umang.shikarvar/Style_GAN/delhi/images',
    transforms=transform
)
# Pinned host memory only speeds up host->GPU copies; enabling it without an
# accelerator just triggers a warning, so tie it to the selected device.
target_loader = DataLoader(
    target,
    batch_size=1,
    pin_memory=device.type == "cuda",
    num_workers=8,
    shuffle=False,
)

# Directory to save results
os.makedirs("generated_images", exist_ok=True)

# Denormalization helper: inverse of the Normalize([0.5]*3, [0.5]*3) step
# applied by `transform`.
def denormalize(tensor):
    """Map values normalized to [-1, 1] back to [0, 1].

    Computes ``tensor * 0.5 + 0.5`` elementwise. Works for any operand that
    supports ``*`` and ``+`` with floats (tensors, arrays, plain numbers).
    The result is NOT clamped, so inputs outside [-1, 1] stay outside [0, 1].
    """
    std = 0.5
    mean = 0.5
    return tensor * std + mean

# Inference and saving: run the generator on every target image and save each
# output as a PNG named after its source file (e.g. "x.tif" -> "x.png").
with torch.no_grad():
    for img, filename in target_loader:
        # `filename` is the batch of source file names collated by the DataLoader.
        img = img.to(device, non_blocking=True)
        gen = G(img)  # (B, C, H, W); assumed in [-1, 1] — TODO confirm Generator output range

        # Denormalize to [0, 1] and clamp so out-of-range generator outputs
        # do not wrap around during PIL conversion.
        gen = denormalize(gen.cpu()).clamp(0, 1)

        # Save every sample in the batch, so this works for any DataLoader
        # batch_size rather than only batch_size=1.
        for sample, name in zip(gen, filename):
            gen_image = TF.to_pil_image(sample)
            save_name = os.path.splitext(name)[0] + '.png'
            gen_image.save(os.path.join('generated_images', save_name))