In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
from datetime import datetime
from PIL import Image
import torchvision.transforms.functional as TF
from io import BytesIO
import torch 
import torch.nn as nn
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np
import random, torch, os, numpy as np
import torch.nn as nn
import copy
import albumentations as A
from albumentations.pytorch import ToTensorV2
import sys
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
import torch
import matplotlib.pyplot as plt
import pandas as pd
import torchvision.transforms as T
import torchvision
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"


class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) -> InstanceNorm2d -> optional ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        down: when True use ``nn.Conv2d`` with reflect padding; when False use
            ``nn.ConvTranspose2d`` (no padding_mode override).
        use_act: append an in-place ReLU when True, otherwise ``nn.Identity``.
        **kwargs: forwarded unchanged to the convolution layer
            (kernel_size, stride, padding, output_padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        norm_layer = nn.InstanceNorm2d(out_channels)
        activation = nn.ReLU(inplace=True) if use_act else nn.Identity()
        # Keep the Sequential layout (index 0 = conv) so state_dict keys such as
        # "conv.0.weight" stay the same for load_state_dict.
        self.conv = nn.Sequential(conv_layer, norm_layer, activation)

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks wrapped in an identity skip connection: ``x + F(x)``.

    The second ConvBlock has no activation, so the residual branch ends on
    InstanceNorm before being added to the input. With kernel 3 / padding 1 and
    the default stride, channel count and spatial size are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1)
        self.block = nn.Sequential(
            ConvBlock(channels, channels, **conv_args),
            ConvBlock(channels, channels, use_act=False, **conv_args),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Encoder / residual / decoder image generator.

    Pipeline: 7x7 reflect-padded stem conv -> two stride-2 downsampling
    ConvBlocks (channels base -> 2*base -> 4*base) -> ``num_residuals``
    ResidualBlocks at 4*base channels -> two stride-2 transposed-conv
    ConvBlocks (4*base -> 2*base -> base) -> 7x7 output conv -> tanh.

    The output has ``img_channels`` channels with values in [-1, 1]; when H and
    W are divisible by 4 it has the same spatial size as the input.
    Attribute names and submodule order are kept as-is because
    ``load_state_dict`` relies on them.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        base = num_features
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, base, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(base),
            nn.ReLU(inplace=True),
        )
        # Each downsampling stage halves H/W and doubles the channel count.
        self.down_blocks = nn.ModuleList(
            ConvBlock(base * mult, base * mult * 2, kernel_size=3, stride=2, padding=1)
            for mult in (1, 2)
        )
        self.res_blocks = nn.Sequential(
            *(ResidualBlock(base * 4) for _ in range(num_residuals))
        )
        # Each upsampling stage doubles H/W and halves the channel count.
        self.up_blocks = nn.ModuleList(
            ConvBlock(base * mult, base * mult // 2, down=False,
                      kernel_size=3, stride=2, padding=1, output_padding=1)
            for mult in (4, 2)
        )
        self.last = nn.Conv2d(base, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        stages = [self.initial, *self.down_blocks, self.res_blocks, *self.up_blocks]
        for stage in stages:
            x = stage(x)
        return torch.tanh(self.last(x))

In [None]:
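'''
Inference: translate every .bmp image in data/train/good/good_200/ with each
generator checkpoint listed in MODEL_DIRS and save the outputs as PNG files
under generated/<model_dir>/.
'''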
import os
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from PIL import Image
from torchvision.utils import save_image

class CustomImageDataset(Dataset):
    """Images stored directly in one directory, returned as (optionally transformed) RGB.

    Only regular files directly inside ``root_dir`` whose names end with one of
    ``extensions`` are used; sub-directories are ignored. File names are sorted
    so that index ``i`` refers to the same image on every run (``os.listdir``
    order is arbitrary).

    Args:
        root_dir: directory containing the images.
        transform: optional callable applied to each PIL RGB image.
        extensions: suffix string or iterable of suffixes to accept
            (case-sensitive). Defaults to ``('.bmp',)``, the original behavior.
    """

    def __init__(self, root_dir, transform=None, extensions=('.bmp',)):
        self.root_dir = root_dir
        self.transform = transform
        # str.endswith needs a str or a tuple, so normalise lists/sets to a tuple.
        suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
        self.image_files = sorted(
            name for name in os.listdir(root_dir)
            if name.endswith(suffixes) and os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        # The context manager closes the file handle even if decoding fails;
        # convert() returns a new, fully loaded image that outlives the handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

# Preprocessing: resize to a fixed (height, width), convert to a tensor in
# [0, 1], then normalize each channel with mean 0.5 / std 0.5, i.e. to [-1, 1].
IMAGE_SIZE = (576, 640)  # (height, width) as expected by transforms.Resize
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Directory to save the generated images.
# NOTE(review): the generation loop in this cell writes to generated/<model_dir>/,
# not to SAVE_DIR; confirm whether SAVE_DIR is still needed.
SAVE_DIR = "generated_images"
os.makedirs(SAVE_DIR, exist_ok=True)

# Initialize the dataset and dataloader.
DATA_DIR = 'data/train/good/good_200/'
dataset = CustomImageDataset(root_dir=DATA_DIR, transform=transform)
# shuffle=False: this is inference. A fixed order keeps the i-th generated image
# tied to dataset.image_files[i] and identical across models.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

def denormalize(tensor):
    """Undo the mean-0.5 / std-0.5 normalization, mapping [-1, 1] back to [0, 1].

    The result is clamped into [0, 1] so it can be written out as an image.
    The input tensor is not modified; a new tensor is returned.
    """
    restored = tensor.mul(0.5).add(0.5)
    return restored.clamp(min=0, max=1)

# Run directory names under results/. Each must contain
# saved_images/genm.pth.tar holding the generator weights under "state_dict".
MODEL_DIRS = [
    # e.g. "<run_name>",
]

if not MODEL_DIRS:
    print("MODEL_DIRS is empty - add run directory names to generate images.")

# Iterate in a fixed order, independent of how the shared `dataloader` is
# configured (it may shuffle). Output image_{i} then always comes from
# dataset.image_files[i], and is the same source image for every model.
inference_loader = DataLoader(dataset, batch_size=1, shuffle=False)

for model_dir in MODEL_DIRS:
    model_path = os.path.join("results", model_dir, "saved_images", 'genm.pth.tar')
    gen_M = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True on PyTorch versions that support it).
    checkpoint = torch.load(model_path, map_location=DEVICE)
    gen_M.load_state_dict(checkpoint["state_dict"])
    gen_M.eval()  # inference only
    model = gen_M

    # Create a directory to save generated images for this model.
    save_dir = os.path.join('generated', model_dir)
    os.makedirs(save_dir, exist_ok=True)

    # Generate and save the images.
    for i, input_images in enumerate(inference_loader):
        with torch.no_grad():
            generated_images = denormalize(model(input_images.to(DEVICE)))

        save_path = os.path.join(save_dir, f'image_{i:04d}.png')
        # Log the source file so each output can be traced back to its input.
        print(f"{dataset.image_files[i]} -> {save_path}")
        save_image(generated_images, save_path)

    print(f"Images from {model_dir} have been generated and saved.")

print("All images have been generated and saved.")