In [None]:
###
# This code is used for image style transfer using a pre-trained residual CycleGAN.
###

In [None]:
import os
import re
import torch
import torch.nn as nn
import shutil
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import gc
from torchvision.utils import save_image
import glob

# Select the compute device once; later cells read `device` and `GPU_ID2`.
GPU_ID = 0
device = torch.device(f"cuda:{GPU_ID}" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Device string used as ImageConverter's default. Derived from `device` so a
# CPU-only machine gets "cpu" instead of an unusable "cuda:0".
GPU_ID2 = str(device)

In [None]:

class SimpleImageDataset(Dataset):
    """Dataset of every image file directly inside ``image_dir``.

    Files are selected by extension (.png, .jpg, .jpeg, .tif, .tiff; case-insensitive),
    sub-directories are skipped, and the list is sorted by file name. Each item is
    ``{'image': tensor, 'path': file name}``: the image is converted to RGB, resized
    to ``size`` x ``size`` and normalised from [0, 1] to [-1, 1].

    NOTE: a later cell redefines ``SimpleImageDataset`` with a list-of-paths
    constructor; once that cell runs, it replaces this class.
    """

    def __init__(self, image_dir, size=640):
        self.image_dir = image_dir
        self.size = size

        self.images = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff'))
            # A directory that merely has an image-like name would fail in __getitem__.
            and os.path.isfile(os.path.join(image_dir, f))
        )

        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        name = self.images[idx]
        image_path = os.path.join(self.image_dir, name)
        # Image.open is lazy and may keep the file open (e.g. multi-frame TIFF);
        # the context manager closes it deterministically.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        return {
            'image': self.transform(image),
            'path': name
        }

class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 conv + InstanceNorm stages added onto the input.

    Channels and spatial size are preserved, so ``forward(x)`` has the shape of ``x``.
    The layers live in ``self.block`` (an ``nn.Sequential``) in the order
    pad, conv, norm, ReLU, pad, conv, norm, which keeps checkpoint keys
    ``block.1.*`` / ``block.5.*`` loadable.
    """

    def __init__(self, channels):
        super().__init__()

        def conv_norm():
            # One reflection-padded 3x3 conv followed by instance normalisation.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, kernel_size=3),
                nn.InstanceNorm2d(channels),
            ]

        self.block = nn.Sequential(*conv_norm(), nn.ReLU(inplace=True), *conv_norm())

    def forward(self, x):
        identity = x
        return identity + self.block(x)

class ResidualCycleGANGenerator(nn.Module):
    """Residual U-Net-style generator (loaded from the ``generator_AB`` checkpoint entry).

    With ``f = initial_filters``:
      * encoder: four blocks, channels 3 -> f -> 2f -> 4f -> 8f, each halving
        height/width via ``AvgPool2d(2)``;
      * middle: five ``ResidualBlock(8f)``, plus a skip from the last encoder output;
      * decoder: four blocks, each fed the previous output concatenated with the
        matching encoder feature, each doubling height/width;
      * final: pad/conv over the decoder output concatenated with the input image,
        ending in ``Tanh``.
    ``forward`` returns ``x + final(...)``, i.e. a residual added to the input.

    Submodule names and indices (``encoder``, ``middle``, ``decoder``, ``final``)
    are what saved state_dicts are keyed by and must stay stable.

    Args:
        input_size: only validated (must be divisible by 16 for the four 2x
            down-samplings); the layers themselves do not depend on it.
        initial_filters: channel width ``f`` of the first encoder block.
    """

    def __init__(self, input_size=640, initial_filters=32):
        super().__init__()
        assert input_size % 16 == 0, "Input size must be divisible by 16"
        f = initial_filters

        enc_channels = [3, f, f * 2, f * 4, f * 8]
        self.encoder = nn.ModuleList(
            self._make_encoder_block(c_in, c_out)
            for c_in, c_out in zip(enc_channels, enc_channels[1:])
        )

        self.middle = nn.Sequential(*(ResidualBlock(f * 8) for _ in range(5)))

        # Input widths are doubled by the skip concatenation done in forward().
        dec_channels = [(f * 16, f * 4), (f * 8, f * 2), (f * 4, f), (f * 2, f)]
        self.decoder = nn.ModuleList(
            self._make_decoder_block(c_in, c_out) for c_in, c_out in dec_channels
        )

        # +3 input channels: the original image is concatenated before the head.
        self.final = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(f + 3, f, kernel_size=3),
            nn.InstanceNorm2d(f),
            nn.ReLU(inplace=True),
            nn.Conv2d(f, 3, kernel_size=1),
            nn.Tanh(),
        )

    def _make_encoder_block(self, in_channels, out_channels):
        """Two pad/conv3x3/InstanceNorm/LeakyReLU(0.2) stages, then 2x average pooling."""
        layers = []
        for c_in in (in_channels, out_channels):
            layers += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(c_in, out_channels, kernel_size=3),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.AvgPool2d(kernel_size=2))
        return nn.Sequential(*layers)

    def _make_decoder_block(self, in_channels, out_channels):
        """2x nearest upsampling, then two pad/conv3x3/InstanceNorm/ReLU stages."""
        layers = [nn.Upsample(scale_factor=2, mode='nearest')]
        for c_in in (in_channels, out_channels):
            layers += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(c_in, out_channels, kernel_size=3),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        skips = []
        h = x
        for down in self.encoder:
            h = down(h)
            skips.append(h)

        h = self.middle(h) + skips[-1]

        # Deepest skip pairs with the first decoder block, shallowest with the last.
        for up, skip in zip(self.decoder, reversed(skips)):
            h = up(torch.cat([h, skip], dim=1))

        return x + self.final(torch.cat([h, x], dim=1))

In [None]:


class PTHFileProcessor:
    """Finds generator checkpoints in ``pth_dir`` and parses their translation direction.

    A file qualifies when its name ends in ``.pth`` and matches
    ``_source_<SRC>...to<TGT>_``, where SRC and TGT are three upper-case letters.
    All qualifying files must share a single SRC code.
    """

    def __init__(self, pth_dir):
        self.pth_dir = pth_dir
        # The greedy ``.*`` makes the *last* "to<TGT>_" in the name the target.
        self.pattern = r'_source_([A-Z]{3}).*to([A-Z]{3})_'

    def validate_pth_files(self):
        """Parse all matching checkpoint names.

        Returns:
            ``(valid_files, source)``: ``valid_files`` is a list of dicts with keys
            ``'file'``, ``'source'`` and ``'target'``, sorted by file name (so the
            processing order is deterministic); ``source`` is the shared SRC code.

        Raises:
            ValueError: if there are no ``.pth`` files, none match the pattern,
                or the matching files name more than one source code.
        """
        pth_files = sorted(f for f in os.listdir(self.pth_dir) if f.endswith('.pth'))
        if not pth_files:
            raise ValueError(f"No PTH files found in {self.pth_dir}")

        valid_files = []
        for pth_file in pth_files:
            match = re.search(self.pattern, pth_file)
            if match is None:
                continue  # not a checkpoint following the naming convention
            valid_files.append({
                'file': pth_file,
                'source': match.group(1),
                'target': match.group(2)
            })

        source_institutes = {info['source'] for info in valid_files}
        if not source_institutes:
            raise ValueError(
                f"No PTH file in {self.pth_dir} matches the pattern {self.pattern!r}"
            )
        if len(source_institutes) > 1:
            raise ValueError(f"Found multiple source institutes: {sorted(source_institutes)}")

        return valid_files, source_institutes.pop()

class SimpleImageDataset(Dataset):
    """Dataset over an explicit list of image file paths.

    This definition replaces the directory-based ``SimpleImageDataset`` from an
    earlier cell. Each item is ``{'image': tensor, 'path': the path as given}``:
    the image is converted to RGB, resized to ``size`` x ``size`` and normalised
    from [0, 1] to [-1, 1].
    """

    def __init__(self, image_paths, size=640):
        # Materialise once so any iterable works and len()/indexing stay stable.
        self.image_paths = list(image_paths)
        self.size = size
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # Image.open is lazy and may keep the file open; close it deterministically
        # (matters with many DataLoader workers and multi-frame files).
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        return {
            'image': self.transform(image),
            'path': image_path
        }

def denormalize_image(x):
    """Invert ``Normalize(mean=0.5, std=0.5)``: map values in [-1, 1] back to [0, 1].

    Elementwise; works on torch tensors or plain numbers.
    """
    return (x + 1) / 2

class ImageConverter:
    """Applies a ``ResidualCycleGANGenerator`` checkpoint to image files and saves the results.

    Call ``load_model(path)`` for each checkpoint, then ``convert_images`` to
    write the translated images.
    """

    def __init__(self, device=GPU_ID2):
        """
        Args:
            device: torch device or device string to run on. If a CUDA device is
                requested but CUDA is not available, a warning is printed and the
                CPU is used instead of failing later in ``.to(device)``.
        """
        device = torch.device(device)
        if device.type == 'cuda' and not torch.cuda.is_available():
            print(f"Warning: {device} requested but CUDA is not available; using CPU instead")
            device = torch.device('cpu')
        self.device = device
        self.model = None

    def load_model(self, model_path):
        """Load generator weights from ``model_path`` into a fresh model.

        The checkpoint may be a bare state_dict or a dict holding it under
        ``'generator_AB'``. The previous model is released first, and
        ``self.model`` is only set once the weights have loaded, so a failed load
        leaves ``self.model`` as ``None`` rather than a randomly initialised net.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise (missing file,
            key or shape mismatch, ...).
        """
        if self.model is not None:
            self.model = None
            gc.collect()  # drop Python references first so the CUDA cache can be freed
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        model = ResidualCycleGANGenerator(input_size=640).to(self.device)
        # NOTE(review): torch.load unpickles arbitrary objects unless weights_only=True;
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device)
        state_dict = checkpoint['generator_AB'] if 'generator_AB' in checkpoint else checkpoint

        model.load_state_dict(state_dict)
        model.eval()
        self.model = model

    def convert_images(self, image_paths, output_paths, batch_size=32, num_workers=4, verbose=True):
        """Translate ``image_paths[i]`` with the loaded model and save it to ``output_paths[i]``.

        Args:
            image_paths: sliceable sequence of input image files.
            output_paths: output file for each input, in the same order.
            batch_size: images per forward pass.
            num_workers: DataLoader worker processes.
            verbose: print per-batch / per-image progress and GPU memory (the
                default keeps the original, detailed logging).

        Raises:
            RuntimeError: if ``load_model`` has not succeeded yet.
            ValueError: if the two path lists differ in length (``zip`` would
                otherwise silently skip images or outputs).
        """
        if self.model is None:
            raise RuntimeError("No model loaded; call load_model() first")
        if len(image_paths) != len(output_paths):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and output_paths "
                f"({len(output_paths)}) must have the same length"
            )

        device = torch.device(self.device)
        use_cuda = device.type == 'cuda'
        dataset = SimpleImageDataset(image_paths)
        # Pinned memory only helps host->GPU copies; on CPU it just warns.
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers, pin_memory=use_cuda)

        print("\nChecking paths before conversion:")
        for in_path, out_path in zip(image_paths[:5], output_paths[:5]):
            print(f"Input: {in_path}")
            print(f"Output: {out_path}")
            print("-" * 80)

        start_idx = 0  # index of the first image of the current batch (shuffle=False)
        with torch.no_grad():
            for batch_count, batch in enumerate(tqdm(dataloader, desc="Converting images")):
                images = batch['image'].to(device)
                paths = batch['path']
                current_output_paths = output_paths[start_idx:start_idx + len(images)]

                if verbose:
                    print(f"\nProcessing batch {batch_count}:")
                converted_images = self.model(images)

                for idx, (img, in_path, out_path) in enumerate(
                        zip(converted_images, paths, current_output_paths)):
                    if verbose:
                        print(f"Saving image {start_idx + idx}:")
                        print(f"From: {in_path}")
                        print(f"To: {out_path}")

                    out_dir = os.path.dirname(out_path)
                    # dirname is '' for a bare file name, and makedirs('') would raise.
                    if out_dir and not os.path.exists(out_dir):
                        if verbose:
                            print(f"Creating directory: {out_dir}")
                        os.makedirs(out_dir, exist_ok=True)

                    save_image(denormalize_image(img), out_path)
                    if verbose:
                        print("Saved successfully")
                        print("-" * 40)

                start_idx += len(images)

                if verbose and use_cuda:
                    print(f"GPU Memory allocated: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB")
                    print(f"GPU Memory reserved: {torch.cuda.memory_reserved(device) / 1024**2:.2f} MB")

def _filter_source_images(image_paths, input_dir, source_institute):
    """Keep images whose folder names *below* ``input_dir`` contain ``source_institute``."""
    selected = []
    for img_path in image_paths:
        # Only relative directory parts: a code appearing in input_dir itself or in
        # the file name must not select every image.
        rel_dirs = Path(os.path.relpath(img_path, input_dir)).parent.parts
        if any(source_institute in part for part in rel_dirs):
            selected.append(img_path)
    return selected


def _label_path_for(image_path):
    """Map ``.../img/<stem>.png`` to its sibling label ``.../lab/<stem>.txt``."""
    p = Path(image_path)
    return str(p.parent.parent / 'lab' / f"{p.stem}.txt")


def _plan_outputs(valid_images, input_dir, output_dir, suffix):
    """Build output image paths and ``(input_label, output_label)`` pairs for one checkpoint.

    Each output mirrors the image's location relative to ``input_dir`` under
    ``output_dir``, with ``suffix`` appended to the file stem. A label pair is
    only added when the input label file exists. The first five mappings are printed.
    """
    output_paths = []
    label_pairs = []
    for img_path in valid_images:
        rel_dir = os.path.dirname(os.path.relpath(img_path, input_dir))
        stem = os.path.splitext(os.path.basename(img_path))[0]
        output_path = os.path.join(output_dir, rel_dir, f"{stem}{suffix}.png")
        output_paths.append(output_path)

        label_path = _label_path_for(img_path)
        has_label = os.path.exists(label_path)
        if has_label:
            output_label_path = _label_path_for(output_path)
            label_pairs.append((label_path, output_label_path))

        if len(output_paths) <= 5:
            print(f"\nInput image: {img_path}")
            print(f"Output image: {output_path}")
            if has_label:
                print(f"Input label: {label_path}")
                print(f"Output label: {output_label_path}")
    return output_paths, label_pairs


def _copy_label_files(label_pairs):
    """Copy each label to its output location; return how many copies succeeded.

    Output directories are created as needed. A failed copy is reported and
    skipped so one bad label does not abort the rest.
    """
    copied = 0
    for i, (in_label, out_label) in enumerate(label_pairs):
        out_dir = os.path.dirname(out_label)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        try:
            shutil.copy2(in_label, out_label)
        except Exception as e:
            print(f"Error copying label file {in_label}: {str(e)}")
            continue
        copied += 1
        if i < 5:
            print(f"Copied: {in_label} -> {out_label}")
    return copied


def process_directory(input_dir, output_dir, pth_dir, batch_size=32):
    """Translate every source-institute image in ``input_dir`` with every checkpoint in ``pth_dir``.

    Images are collected from ``<input_dir>/*/woAug/*/img/*.png`` (sorted) and kept
    when a folder name below ``input_dir`` contains the source code parsed from the
    checkpoint names. For each checkpoint, outputs mirror the input layout under
    ``output_dir`` with ``_<SRC>2<TGT>`` appended to each file stem. The sibling
    ``lab/<stem>.txt`` label, when present, is copied under the same new stem.

    Raises:
        ValueError: propagated from ``PTHFileProcessor.validate_pth_files``.
    """
    pth_processor = PTHFileProcessor(pth_dir)
    valid_pth_files, source_institute = pth_processor.validate_pth_files()
    print(f"Found {len(valid_pth_files)} valid PTH files for source institute {source_institute}")
    print("\nValid PTH files:")
    for pth_info in valid_pth_files:
        print(f"- {pth_info['file']}")
        print(f"  Source: {pth_info['source']}")
        print(f"  Target: {pth_info['target']}")

    # Escape input_dir so characters such as '[' in the path are not glob syntax.
    image_pattern = os.path.join(glob.escape(input_dir), "*", "woAug", "*", "img", "*.png")
    all_image_paths = sorted(glob.glob(image_pattern))
    print(f"\nFound {len(all_image_paths)} total image paths")

    valid_images = _filter_source_images(all_image_paths, input_dir, source_institute)
    print(f"\nFiltered to {len(valid_images)} valid images")
    print("\nExample valid image paths:")
    for path in valid_images[:5]:
        print(f"- {path}")

    if not valid_images:
        print("\nNo images to convert; nothing to do.")
        return

    converter = ImageConverter(device=device)

    for pth_index, pth_info in enumerate(valid_pth_files):
        print(f"\n{'='*80}")
        print(f"Processing PTH file {pth_index + 1}/{len(valid_pth_files)}: {pth_info['file']}")
        suffix = f"_{pth_info['source']}2{pth_info['target']}"

        print("\nGenerating output paths...")
        output_paths, label_pairs = _plan_outputs(valid_images, input_dir, output_dir, suffix)
        print(f"\nGenerated {len(output_paths)} output paths")
        print(f"Found {len(label_pairs)} label files")

        print(f"\nLoading model: {pth_info['file']}")
        converter.load_model(os.path.join(pth_dir, pth_info['file']))
        print(f"Model device: {next(converter.model.parameters()).device}")

        print("\nStarting image conversion...")
        converter.convert_images(valid_images, output_paths, batch_size=batch_size)

        print("\nCopying label files...")
        copied = _copy_label_files(label_pairs)

        print(f"\nCompleted processing with {pth_info['file']}")
        print(f"Converted {len(output_paths)} images")
        print(f"Copied {copied} label files")

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if device.type == 'cuda':
            print(f"GPU Memory allocated: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB")
            print(f"GPU Memory reserved: {torch.cuda.memory_reserved(device) / 1024**2:.2f} MB")

def debug_conversion(input_dir, output_dir, pth_dir, batch_size=32):
    """Run ``process_directory`` behind a settings banner, reporting instead of raising errors.

    Any ``Exception`` raised by the pipeline is printed with its message and full
    traceback (``traceback.print_exc``) and then suppressed, so the calling cell
    finishes normally. Returns None.
    """
    import traceback

    print("Starting debug conversion process...")
    settings = (
        ("Input directory", input_dir),
        ("Output directory", output_dir),
        ("PTH directory", pth_dir),
        ("Batch size", batch_size),
    )
    for label, value in settings:
        print(f"{label}: {value}")

    try:
        process_directory(input_dir, output_dir, pth_dir, batch_size=batch_size)
    except Exception as exc:
        print(f"\nError occurred: {str(exc)}")
        traceback.print_exc()


In [None]:
# Set these three paths before running; the expected layout is described in the following cells.
input_dir = "/PATH/TO/YOUR/INPUT"
output_dir = "/PATH/TO/YOUR/OUTPUT"
pth_dir = "/PATH/TO/YOUR/PTH_FILE/DIRECTORY"

debug_conversion(input_dir, output_dir, pth_dir, batch_size=16)

In [None]:
# Images are collected from <input_dir>/<DIR>/woAug/<folder>/img/*.png, so each top-level
#   directory of the input must contain a subdirectory called 'woAug'.
# Images (.png) are stored in the 'img' folder; labels (.txt, same file stem) are stored in the sibling 'lab' folder.
# PTH filenames must contain '_source_XXX' followed later by 'toYYY_' (XXX and YYY are 3-letter
#   upper-case institution codes), e.g. 'abc_source_XXX_toYYY_best_model.pth'; all PTH files must share the same XXX.
# The folders to be processed in the input directory must have names containing XXX (the code after '_source_' in the PTH filenames).
# The output directory mirrors the input structure; '_XXX2YYY' (source and target codes) is appended to each output file name.

In [None]:


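# Example layout (commentary only; commented out so "Run All" does not raise a SyntaxError):
#
# input_dir/
# ├── DIR1/
# │   └── woAug/
# │       └── XXX_/
# │           ├── img/
# │           │   ├── image001.png
# │           │   ├── image002.png
# │           │   └── ...
# │           └── lab/
# │               ├── image001.txt
# │               ├── image002.txt
# │               └── ...
# ├── DIR2/
# │   └── woAug/
# │       └── YYY_/
# │           ├── img/
# │           │   └── ...
# │           └── lab/
# │               └── ...
# └── ...
#
#
# pth_dir/
# ├── abc_source_XXX_toYYY_best_model.pth
# ├── def_source_XXX_toZZZ_best_model.pth
# └── ...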



