In [None]:
###
# This notebook trains a residual CycleGAN for image-to-image style transfer.
# List the directories of images to be transformed in `source_dirs`, and set
# `base_config['target_dir']` to the directory of images that have the target style.
###

In [None]:
import os
import torch
import datetime
import glob
import re
import gc
import shutil
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm
from copy import deepcopy

# Metric constants.
# NOTE(review): neither name is referenced by the code visible in this notebook;
# MetricsCalculator uses the literals 11 and 1e-10 directly.
SSIM_WINDOW_SIZE = 11
MSE_EPSILON = 1e-10

# Index of the CUDA device to use when CUDA is available; otherwise the CPU is selected.
GPU_ID = 0 
device = torch.device(f"cuda:{GPU_ID}" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
def setup_experiment(base_dir=None):
    """Create the directory layout for one experiment and return its paths.

    Args:
        base_dir: Root directory of the experiment (str or path-like). It is
            required; the ``None`` default only exists for signature
            compatibility and is rejected with a clear error.

    Returns:
        dict with the keys ``'base'``, ``'checkpoints'``, ``'results'``,
        ``'logs'`` and ``'samples'``; ``'base'`` is ``base_dir`` itself and the
        others are sub-directories of it. All of them exist on return.

    Raises:
        ValueError: if ``base_dir`` is ``None`` or an empty string.
    """
    # Without this guard os.path.join(None, ...) fails with an opaque TypeError.
    if base_dir is None or str(base_dir) == '':
        raise ValueError("base_dir must be a non-empty path to the experiment directory")

    paths = {'base': base_dir}
    for name in ('checkpoints', 'results', 'logs', 'samples'):
        paths[name] = os.path.join(base_dir, name)

    for path in paths.values():
        # Report truthfully whether the directory is new or is being reused.
        already_present = os.path.isdir(path)
        os.makedirs(path, exist_ok=True)
        if already_present:
            print(f"Directory already exists: {path}")
        else:
            print(f"Created directory: {path}")
    return paths


def denormalize_image(x):
    """Linearly map ``x`` from the [-1, 1] range to the [0, 1] range."""
    shifted = x + 1
    return shifted * 0.5

def normalize_image(x):
    """Linearly map ``x`` from the [0, 1] range to the [-1, 1] range."""
    doubled = x * 2
    return doubled - 1

def visualize_samples(real_A, real_B, fake_A=None, fake_B=None, figsize=(15, 5)):
    """Show the real images and, optionally, the translated images side by side.

    Every image is passed through ``denormalize_image`` before display. Tensor
    images are moved to the CPU and transposed with ``(1, 2, 0)``, i.e. they
    are assumed to be single (C, H, W) images — TODO confirm against callers.

    Args:
        real_A, real_B: images from the two domains.
        fake_A, fake_B: translated images; they are shown only when BOTH are
            given (in the order "Fake B", "Fake A").
        figsize: size of the matplotlib figure.
    """
    plt.figure(figsize=figsize)

    # (title, image) pairs in display order.
    panels = [
        ('Real A', denormalize_image(real_A)),
        ('Real B', denormalize_image(real_B)),
    ]
    if not (fake_A is None or fake_B is None):
        panels.append(('Fake B', denormalize_image(fake_B)))
        panels.append(('Fake A', denormalize_image(fake_A)))

    panel_count = len(panels)
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, panel_count, position)
        if isinstance(picture, torch.Tensor):
            # channels-first tensor -> channels-last array for imshow
            picture = picture.detach().cpu().numpy().transpose(1, 2, 0)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
class MetricsCalculator:
    """SSIM / PSNR / MSE between two image batches given in the [-1, 1] range.

    Both inputs are passed through ``denormalize_image`` first, so the metrics
    are computed on [0, 1] data with a dynamic range of 1.

    Args:
        window_size: side length of the uniform (box) window used for the
            local SSIM statistics.
        mse_epsilon: MSE values below this threshold are treated as a perfect
            match and reported as infinite PSNR.
    """

    def __init__(self, window_size=11, mse_epsilon=1e-10):
        self.window_size = window_size
        self.mse_epsilon = mse_epsilon

    def compute_ssim(self, img1, img2):
        """Return the mean SSIM over all pixels and channels as a Python float.

        Local means/variances use a uniform window with zero padding of
        ``window_size // 2``. Inputs are expected as (N, C, H, W) batches with
        matching shapes; any channel count and floating dtype is accepted.
        """
        img1 = denormalize_image(img1)
        img2 = denormalize_image(img2)
        # Stabilising constants for a dynamic range of 1.
        C1 = (0.01 * 1) ** 2
        C2 = (0.03 * 1) ** 2

        channels = img1.size(1)
        pad = self.window_size // 2
        # One box filter per channel, built once and matched to the input's
        # dtype/device; groups=channels filters every channel independently.
        kernel = torch.full(
            (channels, 1, self.window_size, self.window_size),
            1.0 / (self.window_size ** 2),
            dtype=img1.dtype,
            device=img1.device,
        )

        def local_mean(t):
            return F.conv2d(t, kernel, padding=pad, groups=channels)

        mu1 = local_mean(img1)
        mu2 = local_mean(img2)
        mu1_sq = mu1 ** 2
        mu2_sq = mu2 ** 2
        mu1_mu2 = mu1 * mu2

        # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y]
        sigma1_sq = local_mean(img1 * img1) - mu1_sq
        sigma2_sq = local_mean(img2 * img2) - mu2_sq
        sigma12 = local_mean(img1 * img2) - mu1_mu2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
                   ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        # Every channel map has the same size, so the global mean equals the
        # average of the per-channel means.
        return ssim_map.mean().item()

    def compute_psnr(self, img1, img2):
        """Return PSNR in dB for a peak value of 1; ``inf`` for (near-)identical inputs."""
        img1 = denormalize_image(img1)
        img2 = denormalize_image(img2)

        mse = F.mse_loss(img1, img2)
        if mse < self.mse_epsilon:
            return float('inf')
        return 20 * torch.log10(1.0 / torch.sqrt(mse)).item()

    def compute_mse(self, img1, img2):
        """Return the mean squared error of the denormalized images as a float."""
        img1 = denormalize_image(img1)
        img2 = denormalize_image(img2)
        return F.mse_loss(img1, img2).item()

    def compute_all_metrics(self, img1, img2):
        """Return ``{'ssim': ..., 'psnr': ..., 'mse': ...}`` for the two batches."""
        return {
            'ssim': self.compute_ssim(img1, img2),
            'psnr': self.compute_psnr(img1, img2),
            'mse': self.compute_mse(img1, img2)
        }

class StainDataset(Dataset):
    """Unpaired image dataset: a source image by index plus a random target image.

    Args:
        source_dir: directory containing the source-domain images.
        target_dir: directory containing the target-domain images.
        size: both images are resized to ``(size, size)``.

    Each item is a ``(source, target)`` pair of tensors normalized with
    mean 0.5 / std 0.5 per channel. The same random flips are applied to both
    images of a pair.
    """

    # File extensions (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    def __init__(self, source_dir, target_dir, size=640):
        super().__init__()
        self.source_dir = source_dir
        self.target_dir = target_dir
        self.size = size

        self.source_images = self._list_images(source_dir)
        self.target_images = self._list_images(target_dir)

        print(f"Found {len(self.source_images)} source images and {len(self.target_images)} target images")

        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted image file names found directly inside ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB image and close the underlying file right away."""
        # convert() returns a fully loaded copy, so the file handle opened by
        # Image.open can be released here instead of lingering until GC.
        with Image.open(path) as handle:
            return handle.convert('RGB')

    def __len__(self):
        return min(len(self.source_images), len(self.target_images))

    def __getitem__(self, idx):
        source_image = self._load_rgb(os.path.join(self.source_dir, self.source_images[idx]))

        # The target is drawn at random on every access (unpaired training data).
        target_idx = torch.randint(0, len(self.target_images), (1,)).item()
        target_image = self._load_rgb(os.path.join(self.target_dir, self.target_images[target_idx]))

        # Re-seeding before each transform call makes the random flips identical
        # for both images. Note that torch.manual_seed reseeds the process-global
        # torch generator.
        seed = torch.randint(0, 2**32, (1,)).item()

        torch.manual_seed(seed)
        source_image = self.transform(source_image)

        torch.manual_seed(seed)
        target_image = self.transform(target_image)

        return source_image, target_image



def create_dataloaders(source_dir, target_dir, batch_size=4, num_workers=4,
                      train_ratio=0.8, image_size=640):
    """Build the training and validation loaders over one ``StainDataset``.

    The dataset is split with a fixed seed (42) into ``int(train_ratio * n)``
    training items and the remainder for validation. The training loader
    shuffles and drops the last incomplete batch; the validation loader does
    neither.

    Returns:
        ``(train_loader, val_loader)``
    """
    print(f"Creating dataloaders from {source_dir} and {target_dir}")

    full_dataset = StainDataset(source_dir, target_dir, size=image_size)
    total_items = len(full_dataset)
    print(f"Total dataset size: {total_items}")

    n_train = int(train_ratio * total_items)
    n_val = total_items - n_train

    # Fixed seed so the train/validation membership is the same on every run.
    split_rng = torch.Generator().manual_seed(42)
    train_subset, val_subset = torch.utils.data.random_split(
        full_dataset, [n_train, n_val], generator=split_rng
    )

    print(f"Training set size: {len(train_subset)}")
    print(f"Validation set size: {len(val_subset)}")

    # Settings shared by both loaders.
    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_subset, shuffle=True, drop_last=True, **shared)
    val_loader = DataLoader(val_subset, shuffle=False, **shared)

    return train_loader, val_loader

In [None]:
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv + InstanceNorm stages with an identity skip.

    A ReLU sits between the two stages; the output is ``x + block(x)`` and has
    the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        layers = []
        for needs_activation in (True, False):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(channels, channels, kernel_size=3))
            layers.append(nn.InstanceNorm2d(channels))
            if needs_activation:
                layers.append(nn.ReLU(inplace=True))
        # Kept under the attribute name `block` so state_dict keys are unchanged.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        transformed = self.block(x)
        return x + transformed


class ResidualCycleGANGenerator(nn.Module):
    """Encoder/decoder generator that predicts a residual added to its input.

    Four down-sampling encoder stages feed five residual blocks; four
    up-sampling decoder stages each receive the concatenation of the running
    feature map and the matching encoder output. A final head maps the decoder
    output (concatenated with the input image) to a tanh-bounded 3-channel
    residual, and ``forward`` returns ``x + residual``.

    Args:
        input_size: expected spatial size; only checked for divisibility by 16.
        initial_filters: channel width of the first encoder stage.
    """

    def __init__(self, input_size=640, initial_filters=32):
        super().__init__()
        assert input_size % 16 == 0, "Input size must be divisible by 16"

        base = initial_filters

        # Encoder: every stage halves the spatial size
        # (640 → 320 → 160 → 80 → 40 for the default input size).
        encoder_channels = (
            (3, base),
            (base, base * 2),
            (base * 2, base * 4),
            (base * 4, base * 8),
        )
        self.encoder = nn.ModuleList([
            self._make_encoder_block(c_in, c_out) for c_in, c_out in encoder_channels
        ])

        # Bottleneck: five residual blocks at the deepest resolution.
        self.middle = nn.Sequential(*(ResidualBlock(base * 8) for _ in range(5)))

        # Decoder: every stage doubles the spatial size (40 → 80 → 160 → 320 → 640).
        # Input widths include the encoder feature map concatenated in forward().
        decoder_channels = (
            (base * 16, base * 4),
            (base * 8, base * 2),
            (base * 4, base),
            (base * 2, base),
        )
        self.decoder = nn.ModuleList([
            self._make_decoder_block(c_in, c_out) for c_in, c_out in decoder_channels
        ])

        # Output head (difference-map generation); the +3 is the input image
        # concatenated to the decoder output.
        self.final = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(base + 3, base, kernel_size=3),
            nn.InstanceNorm2d(base),
            nn.ReLU(inplace=True),
            nn.Conv2d(base, 3, kernel_size=1),
            nn.Tanh()
        )

    def _make_encoder_block(self, in_channels, out_channels):
        """Two pad/conv3x3/InstanceNorm/LeakyReLU stages followed by 2x average pooling."""
        layers = []
        for width_in in (in_channels, out_channels):
            layers.extend([
                nn.ReflectionPad2d(1),
                nn.Conv2d(width_in, out_channels, kernel_size=3),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        layers.append(nn.AvgPool2d(kernel_size=2))
        return nn.Sequential(*layers)

    def _make_decoder_block(self, in_channels, out_channels):
        """2x nearest-neighbour upsampling followed by two pad/conv3x3/InstanceNorm/ReLU stages."""
        layers = [nn.Upsample(scale_factor=2, mode='nearest')]
        for width_in in (in_channels, out_channels):
            layers.extend([
                nn.ReflectionPad2d(1),
                nn.Conv2d(width_in, out_channels, kernel_size=3),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        return nn.Sequential(*layers)

    def forward(self, x):
        skips = []
        features = x
        for stage in self.encoder:
            features = stage(features)
            skips.append(features)

        # Extra skip connection around the whole bottleneck.
        features = self.middle(features) + skips[-1]

        # Decoder stages consume the encoder outputs deepest-first.
        for stage, skip in zip(self.decoder, reversed(skips)):
            features = stage(torch.cat([features, skip], dim=1))

        residual = self.final(torch.cat([features, x], dim=1))

        return x + residual


class PatchGANDiscriminator(nn.Module):
    """Convolutional discriminator that outputs a 1-channel map of per-patch scores.

    Args:
        input_channels: number of channels of the input image.
        ndf: channel width of the first convolution; later layers use
            2x, 4x and 8x this width.
    """

    def __init__(self, input_channels=3, ndf=64):
        super().__init__()

        # Layer 1: strided conv without normalization.
        layers = [
            nn.Conv2d(input_channels, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Layers 2-4: (input multiplier, output multiplier, stride).
        for in_mult, out_mult, stride in ((1, 2, 2), (2, 4, 2), (4, 8, 1)):
            layers.extend([
                nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=4, stride=stride, padding=1),
                nn.InstanceNorm2d(ndf * out_mult),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Output layer: one score per receptive-field patch, no activation.
        layers.append(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1))

        # Kept under the attribute name `net` so state_dict keys are unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)




In [None]:
class CycleGANTrainer:
    """Bundles the CycleGAN networks with their optimizers, schedulers and losses.

    Args:
        generator_AB, generator_BA: generators for the A→B and B→A directions.
        discriminator_A, discriminator_B: discriminators for domains A and B.
        device: device all four networks are moved to.
        checkpoint_dir: directory ``save_checkpoint`` writes into.
        lr, beta1, beta2: Adam settings shared by both optimizers.
        lambda_cycle, lambda_identity, lambda_residual: weights of the cycle,
            identity and residual terms of the generator loss.

    Both LR schedulers are ``ReduceLROnPlateau`` in ``'max'`` mode, i.e. they
    expect to be stepped with a score where higher is better.
    """

    def __init__(self, generator_AB, generator_BA, discriminator_A, discriminator_B,
                 device, checkpoint_dir, lr=0.0001, beta1=0.5, beta2=0.999,
                 lambda_cycle=10.0, lambda_identity=0.5, lambda_residual=0.2):
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity
        self.lambda_residual = lambda_residual
        self.gen_AB = generator_AB.to(device)
        self.gen_BA = generator_BA.to(device)
        self.disc_A = discriminator_A.to(device)
        self.disc_B = discriminator_B.to(device)

        # One optimizer for both generators, one for both discriminators.
        self.optimizer_G = torch.optim.Adam(
            list(self.gen_AB.parameters()) + list(self.gen_BA.parameters()),
            lr=lr, betas=(beta1, beta2)
        )
        self.optimizer_D = torch.optim.Adam(
            list(self.disc_A.parameters()) + list(self.disc_B.parameters()),
            lr=lr, betas=(beta1, beta2)
        )

        self.scheduler_G = self._make_plateau_scheduler(self.optimizer_G)
        self.scheduler_D = self._make_plateau_scheduler(self.optimizer_D)

        self.metrics_calculator = MetricsCalculator()

        self.loss_history = {
            'G_loss': [], 'D_loss': [],
            'cycle_loss': [], 'metrics': []
        }

    @staticmethod
    def _make_plateau_scheduler(optimizer):
        """Build the ReduceLROnPlateau scheduler used for both optimizers."""
        settings = dict(mode='max', factor=0.5, patience=5, min_lr=1e-6)
        try:
            return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, **settings)
        except TypeError:
            # `verbose` is deprecated in PyTorch; fall back to the same
            # scheduler without it when the installed release rejects it.
            return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **settings)

    def compute_cycle_loss(self, real_image, cycled_image):
        """L1 cycle loss restricted to the hardest pixels of each sample.

        Per sample, only absolute errors strictly above that sample's 90th
        percentile are kept; the rest are zeroed. The mean is then taken over
        ALL elements, so the masked-out pixels still count in the denominator.
        """
        pixel_wise_loss = F.l1_loss(cycled_image, real_image, reduction='none')
        flat_loss = pixel_wise_loss.view(pixel_wise_loss.size(0), -1)
        threshold = torch.quantile(flat_loss, 0.9, dim=1, keepdim=True)
        mask = (flat_loss > threshold).float()
        return (flat_loss * mask).mean()

    def train_step(self, real_A, real_B):
        """Run one generator update followed by one discriminator update.

        Returns:
            dict with the scalar losses (``g_loss``, ``d_loss``, ``cycle_loss``,
            ``identity_loss``, ``residual_loss``) as Python floats and the
            generated batches ``fake_A`` / ``fake_B`` detached from the
            autograd graph.
        """
        fake_B = self.gen_AB(real_A)
        fake_A = self.gen_BA(real_B)
        cycle_A = self.gen_BA(fake_B)
        cycle_B = self.gen_AB(fake_A)

        # ---- generator update ----
        self.optimizer_G.zero_grad()

        # Least-squares GAN objective: generators want "real" (1) verdicts.
        fake_A_disc = self.disc_A(fake_A)
        fake_B_disc = self.disc_B(fake_B)
        g_loss_A = F.mse_loss(fake_A_disc, torch.ones_like(fake_A_disc))
        g_loss_B = F.mse_loss(fake_B_disc, torch.ones_like(fake_B_disc))
        g_loss_adv = g_loss_A + g_loss_B

        cycle_loss_A = self.compute_cycle_loss(real_A, cycle_A)
        cycle_loss_B = self.compute_cycle_loss(real_B, cycle_B)
        cycle_loss = (cycle_loss_A + cycle_loss_B) * self.lambda_cycle

        # Identity term: a generator fed an image of its own output domain
        # should leave it unchanged.
        identity_A = self.gen_BA(real_A)
        identity_B = self.gen_AB(real_B)
        identity_loss = (
            F.l1_loss(identity_A, real_A) +
            F.l1_loss(identity_B, real_B)
        ) * self.lambda_identity

        # Residual term: mean absolute difference between each generated batch
        # and the real batch of the same domain.
        residual_map_in_A_domain = fake_A - real_A
        residual_map_in_B_domain = fake_B - real_B
        residual_loss = self.lambda_residual * (
            torch.mean(torch.abs(residual_map_in_A_domain)) +
            torch.mean(torch.abs(residual_map_in_B_domain))
        )

        g_loss = g_loss_adv + cycle_loss + identity_loss + residual_loss
        g_loss.backward()
        self.optimizer_G.step()

        # ---- discriminator update ----
        # zero_grad also discards the gradients the generator backward pass
        # left on the discriminator parameters.
        self.optimizer_D.zero_grad()

        d_real_A = self.disc_A(real_A)
        d_real_B = self.disc_B(real_B)
        d_fake_A = self.disc_A(fake_A.detach())
        d_fake_B = self.disc_B(fake_B.detach())

        d_loss = (
            F.mse_loss(d_real_A, torch.ones_like(d_real_A)) +
            F.mse_loss(d_real_B, torch.ones_like(d_real_B)) +
            F.mse_loss(d_fake_A, torch.zeros_like(d_fake_A)) +
            F.mse_loss(d_fake_B, torch.zeros_like(d_fake_B))
        ) * 0.5

        d_loss.backward()
        self.optimizer_D.step()

        return {
            'g_loss': g_loss.item(),
            'd_loss': d_loss.item(),
            'cycle_loss': cycle_loss.item(),
            'identity_loss': identity_loss.item(),
            'residual_loss': residual_loss.item(),
            # Detached so callers holding on to the result do not keep the
            # autograd graph of this step alive.
            'fake_A': fake_A.detach(),
            'fake_B': fake_B.detach()
        }

    def validate(self, val_loader):
        """Average SSIM/PSNR/MSE over the validation loader.

        For every batch three groups of metrics are accumulated:
        ``direct_*`` (generated vs. real batch of the same domain),
        ``cycle_*`` (cycled vs. original) and ``combined_*`` (mean of the two).
        Each value is the mean of the A and B directions, averaged over batches.

        Raises:
            ValueError: if ``val_loader`` yields no batches.
        """
        num_batches = len(val_loader)
        if num_batches == 0:
            raise ValueError("Validation loader is empty; cannot compute validation metrics")

        metric_names = ('ssim', 'psnr', 'mse')
        metrics_sum = {
            f'{group}_{name}': 0.0
            for group in ('direct', 'cycle', 'combined')
            for name in metric_names
        }

        self.gen_AB.eval()
        self.gen_BA.eval()
        try:
            with torch.no_grad():
                for real_A, real_B in val_loader:
                    real_A = real_A.to(self.device)
                    real_B = real_B.to(self.device)

                    fake_B = self.gen_AB(real_A)
                    fake_A = self.gen_BA(real_B)

                    cycle_A = self.gen_BA(fake_B)
                    cycle_B = self.gen_AB(fake_A)

                    metrics_AB = self.metrics_calculator.compute_all_metrics(fake_B, real_B)
                    metrics_BA = self.metrics_calculator.compute_all_metrics(fake_A, real_A)

                    metrics_cycle_A = self.metrics_calculator.compute_all_metrics(cycle_A, real_A)
                    metrics_cycle_B = self.metrics_calculator.compute_all_metrics(cycle_B, real_B)

                    for key in metric_names:
                        direct_val = (metrics_AB[key] + metrics_BA[key]) / 2
                        cycle_val = (metrics_cycle_A[key] + metrics_cycle_B[key]) / 2
                        metrics_sum[f'direct_{key}'] += direct_val
                        metrics_sum[f'cycle_{key}'] += cycle_val
                        metrics_sum[f'combined_{key}'] += (direct_val + cycle_val) / 2
        finally:
            # Always hand the generators back in training mode, even on error.
            self.gen_AB.train()
            self.gen_BA.train()

        return {k: v / num_batches for k, v in metrics_sum.items()}

    def save_checkpoint(self, epoch, metrics, is_best=False):
        """Write all network, optimizer and scheduler states to ``checkpoint_dir``.

        The file name starts with ``best`` or ``epoch_<epoch>`` and embeds the
        combined/direct/cycle SSIM values from ``metrics`` plus a timestamp, so
        every call creates a new file.
        """
        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        checkpoint = {
            'epoch': epoch,
            'generator_AB': self.gen_AB.state_dict(),
            'generator_BA': self.gen_BA.state_dict(),
            'discriminator_A': self.disc_A.state_dict(),
            'discriminator_B': self.disc_B.state_dict(),
            'optimizer_G': self.optimizer_G.state_dict(),
            'optimizer_D': self.optimizer_D.state_dict(),
            'scheduler_G': self.scheduler_G.state_dict(),
            'scheduler_D': self.scheduler_D.state_dict(),
            'metrics': metrics,
            'loss_history': self.loss_history
        }

        prefix = 'best' if is_best else f'epoch_{epoch}'
        save_path = os.path.join(
            self.checkpoint_dir,
            f'{prefix}_model_combined_ssim_{metrics["combined_ssim"]:.4f}_direct_{metrics["direct_ssim"]:.4f}_cycle_{metrics["cycle_ssim"]:.4f}_{timestamp}.pth'
        )
        torch.save(checkpoint, save_path)
        print(f"Saved checkpoint: {save_path}")



class EarlyStopping:
    """Flags when a monitored loss has stopped improving.

    Call the instance once per evaluation with the current loss (lower is
    better). ``early_stop`` becomes ``True`` once ``patience`` calls in a row
    failed to beat the best loss by more than ``min_delta``.
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        # The first observation only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True



def train(trainer, train_loader, val_loader, num_epochs, save_dir, save_interval=5):
    """Run the training loop with validation, checkpointing and early stopping.

    Per epoch: train on every batch, validate, record the epoch in
    ``trainer.loss_history``, save sample images and a periodic checkpoint
    every ``save_interval`` epochs, save a "best" checkpoint whenever the
    combined SSIM improves, step both LR schedulers, print a report, and
    finally check for early stopping (patience 10 on the combined SSIM).

    Args:
        trainer: object providing ``device``, ``train_step``, ``validate``,
            ``save_checkpoint``, ``scheduler_G`` and ``scheduler_D``.
        train_loader, val_loader: iterables of ``(real_A, real_B)`` batches.
        num_epochs: maximum number of epochs.
        save_dir: directory for the validation sample images.
        save_interval: period (in epochs) of sample images and checkpoints.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        # Previously surfaced as a ZeroDivisionError when averaging the losses.
        raise ValueError("Training loader is empty; nothing to train on")

    best_combined_ssim = -float('inf')
    early_stopping = EarlyStopping(patience=10, min_delta=0.001)
    tracked_losses = ('g_loss', 'd_loss', 'cycle_loss')

    for epoch in range(num_epochs):
        epoch_losses = {key: 0. for key in tracked_losses}

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for i, (real_A, real_B) in enumerate(pbar):
            real_A = real_A.to(trainer.device)
            real_B = real_B.to(trainer.device)

            losses = trainer.train_step(real_A, real_B)

            for key in tracked_losses:
                epoch_losses[key] += losses[key]

            # Running means over the batches seen so far in this epoch.
            current_losses = {
                key: value / (i + 1)
                for key, value in epoch_losses.items()
            }

            pbar.set_postfix({
                'G_loss': f"{current_losses['g_loss']:.4f}",
                'D_loss': f"{current_losses['d_loss']:.4f}",
                'Cycle_loss': f"{current_losses['cycle_loss']:.4f}"
            })

        avg_losses = {key: value / num_batches for key, value in epoch_losses.items()}

        metrics = trainer.validate(val_loader)

        # Record the epoch before any checkpoint is written so the history
        # stored in the checkpoint includes it.
        history = getattr(trainer, 'loss_history', None)
        if isinstance(history, dict):
            history.setdefault('G_loss', []).append(avg_losses['g_loss'])
            history.setdefault('D_loss', []).append(avg_losses['d_loss'])
            history.setdefault('cycle_loss', []).append(avg_losses['cycle_loss'])
            history.setdefault('metrics', []).append(dict(metrics))

        is_save_epoch = (epoch + 1) % save_interval == 0

        if is_save_epoch:
            save_validation_images(trainer, val_loader, save_dir, epoch)

        if metrics['combined_ssim'] > best_combined_ssim:
            best_combined_ssim = metrics['combined_ssim']
            trainer.save_checkpoint(epoch, metrics, is_best=True)

        if is_save_epoch:
            trainer.save_checkpoint(epoch, metrics)

        if isinstance(trainer.scheduler_G, torch.optim.lr_scheduler.ReduceLROnPlateau):
            trainer.scheduler_G.step(metrics['combined_ssim'])
            trainer.scheduler_D.step(metrics['combined_ssim'])
        else:
            trainer.scheduler_G.step()
            trainer.scheduler_D.step()

        print(f"\nEpoch {epoch+1} Results:")
        print("Training Losses:")
        print(f"  Generator Loss: {avg_losses['g_loss']:.4f}")
        print(f"  Discriminator Loss: {avg_losses['d_loss']:.4f}")
        print(f"  Cycle Loss: {avg_losses['cycle_loss']:.4f}")
        print("\nValidation Metrics:")
        print("Direct Translation Quality:")
        print(f"  SSIM: {metrics['direct_ssim']:.4f}")
        print(f"  PSNR: {metrics['direct_psnr']:.4f}")
        print(f"  MSE: {metrics['direct_mse']:.4f}")
        print("\nCycle Consistency Quality:")
        print(f"  SSIM: {metrics['cycle_ssim']:.4f}")
        print(f"  PSNR: {metrics['cycle_psnr']:.4f}")
        print(f"  MSE: {metrics['cycle_mse']:.4f}")
        print("\nCombined Metrics (Direct + Cycle):")
        print(f"  SSIM: {metrics['combined_ssim']:.4f}")
        print(f"  PSNR: {metrics['combined_psnr']:.4f}")
        print(f"  MSE: {metrics['combined_mse']:.4f}\n")

        # EarlyStopping minimizes, so the (higher-is-better) SSIM is negated.
        # Checked last so the final epoch is still checkpointed and reported.
        early_stopping(-metrics['combined_ssim'])
        if early_stopping.early_stop:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break


def save_validation_images(trainer, val_loader, save_dir, epoch):
    """Save one strip of sample translations from the first validation batch.

    The first image of the batch is used to build, left to right:
    real A, fake B, real B, real B, fake A, real A (concatenated along the
    last tensor dimension). The strip is written to
    ``<save_dir>/epoch_<epoch+1>.png``. Both generators are switched to eval
    mode for the forward passes and back to train mode afterwards.
    """
    os.makedirs(save_dir, exist_ok=True)

    generators = (trainer.gen_AB, trainer.gen_BA)
    for net in generators:
        net.eval()

    with torch.no_grad():
        batch_A, batch_B = next(iter(val_loader))
        batch_A = batch_A.to(trainer.device)
        batch_B = batch_B.to(trainer.device)

        translated_B = trainer.gen_AB(batch_A)
        translated_A = trainer.gen_BA(batch_B)

        # Slicing with 0:1 keeps the batch dimension on every tile.
        first = slice(0, 1)
        tiles = [
            batch_A[first], translated_B[first], batch_B[first],
            batch_B[first], translated_A[first], batch_A[first],
        ]
        image_grid = torch.cat(tiles, dim=3)

        save_path = os.path.join(save_dir, f'epoch_{epoch+1}.png')
        save_image(denormalize_image(image_grid), save_path)
        print(f"Validation images saved to: {save_path}")

    for net in generators:
        net.train()

In [None]:

# Settings shared by every experiment started from this notebook.
base_config = {
    # Directory with the images that have the target style.
    'target_dir': '/PATH/TO/YOUR/TAEGET/DIRECTORY',
    'batch_size': 4,
    'epochs': 100,
    'lr': 0.0001,
    # Kept equal to the lambda_cycle=10.0 that train_multiple_sources passes to
    # CycleGANTrainer (and to the "C10" tag in its experiment name); that call
    # uses its own literal and does not read this entry.
    'lambda_cycle': 10.0,
    'image_size': 640,
    'num_workers': 8,
    'save_interval': 5,
    # NOTE(review): not read by any code visible in this notebook.
    'resume_from': None
}

In [None]:
# Directories of images to translate. train_multiple_sources trains one model
# per entry and derives the experiment name from each directory's name.
source_dirs = [
    '/PATH/TO/YOUR/SOURCE/DIRECTORY'
]

In [None]:
def train_multiple_sources(source_dirs, base_config):
    """Train one residual CycleGAN per source directory, one after another.

    For every entry of ``source_dirs`` a private copy of ``base_config`` is
    made, an experiment directory named after the source directory is set up,
    the networks are trained, and a ``final_model.pth`` with the four network
    state dicts is written to the experiment's checkpoint directory.

    A failure in one source is reported (with traceback) and does not stop the
    remaining sources. Model objects are released and the CUDA cache is emptied
    after every source.

    Args:
        source_dirs: iterable of source image directories.
        base_config: dict providing ``target_dir``, ``batch_size``,
            ``num_workers``, ``image_size``, ``lr``, ``epochs`` and
            ``save_interval``; it is not modified.
    """
    import traceback  # local import: only needed for error reporting here

    for source_dir in source_dirs:
        # normpath drops trailing separators, so '/data/HE/' yields 'HE'
        # instead of an empty name that would make experiments collide.
        source_name = os.path.basename(os.path.normpath(source_dir))
        exp_name = f'residual_cyclegan_{source_name}toHGO_C10_I05_R02'

        config = deepcopy(base_config)
        config['source_dir'] = source_dir
        config['exp_name'] = exp_name

        print(f"\nStarting training for source: {source_dir}")
        print(f"Experiment name: {exp_name}")

        # Pre-declare everything the cleanup below has to release.
        trainer = None
        generator_AB = generator_BA = None
        discriminator_A = discriminator_B = None
        train_loader = val_loader = None
        succeeded = False

        try:
            paths = setup_experiment(config['exp_name'])

            train_loader, val_loader = create_dataloaders(
                source_dir=config['source_dir'],
                target_dir=config['target_dir'],
                batch_size=config['batch_size'],
                num_workers=config['num_workers'],
                image_size=config['image_size']
            )

            generator_AB = ResidualCycleGANGenerator(input_size=config['image_size']).to(device)
            generator_BA = ResidualCycleGANGenerator(input_size=config['image_size']).to(device)
            discriminator_A = PatchGANDiscriminator().to(device)
            discriminator_B = PatchGANDiscriminator().to(device)

            # NOTE(review): the loss weights are literals matching the
            # "C10_I05_R02" tag in exp_name; config['lambda_cycle'] is not
            # consulted here.
            trainer = CycleGANTrainer(
                generator_AB=generator_AB,
                generator_BA=generator_BA,
                discriminator_A=discriminator_A,
                discriminator_B=discriminator_B,
                device=device,
                checkpoint_dir=paths['checkpoints'],
                lr=config['lr'],
                lambda_cycle=10.0,
                lambda_identity=0.5,
                lambda_residual=0.2
            )

            train(
                trainer=trainer,
                train_loader=train_loader,
                val_loader=val_loader,
                num_epochs=config['epochs'],
                save_interval=config['save_interval'],
                save_dir=paths['results']
            )

            final_path = os.path.join(paths['checkpoints'], 'final_model.pth')
            torch.save({
                'generator_AB': generator_AB.state_dict(),
                'generator_BA': generator_BA.state_dict(),
                'discriminator_A': discriminator_A.state_dict(),
                'discriminator_B': discriminator_B.state_dict(),
            }, final_path)
            succeeded = True

        except Exception as e:
            # Best effort: report the failure in full and continue with the next source.
            print(f"Error during training for {source_dir}: {str(e)}")
            traceback.print_exc()

        finally:
            # Drop every reference first, collect garbage, and only then ask
            # CUDA to release its cached blocks.
            trainer = None
            generator_AB = generator_BA = None
            discriminator_A = discriminator_B = None
            train_loader = val_loader = None

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            outcome = "Completed" if succeeded else "Aborted"
            print(f"\n{outcome} training for source: {source_dir}")
            print("Memory cleaned up and ready for next model")

            if os.path.exists('./__pycache__'):
                shutil.rmtree('./__pycache__')

In [None]:
# Entry point: train one model per entry of source_dirs with the shared base_config.
train_multiple_sources(source_dirs, base_config)