In [6]:
import torch
from torchvision.utils import save_image
import import_ipynb
import model
from model import CycleGANLosses
import preprocess
import os

# a = landscape images
# b = vangogh images

def train_cyclegan(data_loaders, device, epochs, lambda_cyc=10.0, lambda_id=5.0, learning_rate=0.0002, save_dir="checkpoints"):
    """Train a CycleGAN between landscape (A) and Van Gogh (B) images.

    Each batch runs one generator update (both generators share a single
    optimizer) followed by one update of each discriminator on detached
    fakes. Every 10th epoch a 2x2 grid of the last batch's generated images
    is written to the current working directory, and the final weights of
    all four networks are written to ``save_dir`` when training finishes.

    Args:
        data_loaders: Mapping with the keys ``'landscape_train'`` and
            ``'vangogh_train'``. Each loader must yield indexable batches
            whose element 0 is the image tensor.
        device: Torch device the batches are moved to and the models are
            created on.
        epochs: Number of passes over the zipped loaders.
        lambda_cyc: Weight forwarded to ``CycleGANLosses.compute_g_loss``.
        lambda_id: Weight forwarded to ``CycleGANLosses.compute_g_loss``.
        learning_rate: Adam learning rate used by all three optimizers.
        save_dir: Directory for the final ``*_final.pth`` state dicts; it is
            created if it does not exist.
    """
    # Create the output directory up front so a bad path fails before
    # training starts rather than after the last epoch.
    os.makedirs(save_dir, exist_ok=True)

    # Initialize models
    # NOTE(review): `initialize_models` is not imported in this cell; it is
    # presumably defined elsewhere in the notebook or in `model` - confirm.
    g_ab, g_ba, d_a, d_b = initialize_models(device)

    # Optimizers
    optimizer_g = torch.optim.Adam(list(g_ab.parameters()) + list(g_ba.parameters()), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_d_a = torch.optim.Adam(d_a.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_d_b = torch.optim.Adam(d_b.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    # Loss functions
    losses = CycleGANLosses(device)

    loader_a = data_loaders['landscape_train']
    loader_b = data_loaders['vangogh_train']
    # zip() stops at the shorter loader, so that is the real batch count.
    num_batches = min(len(loader_a), len(loader_b))

    # Training
    print("Beginning training...")
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        # Reset so the sample-saving step below can tell whether this epoch
        # produced any batch at all.
        fake_a = fake_b = None
        for i, (real_a, real_b) in enumerate(zip(loader_a, loader_b)):
            print(f"Processing batch {i + 1}/{num_batches}")
            real_a = real_a[0].to(device)  # Real images from landscape
            real_b = real_b[0].to(device)  # Real images from vangogh

            # Train generators
            optimizer_g.zero_grad()

            # Generate images: translations, cycle reconstructions, and
            # identity mappings (each generator fed its own target domain).
            fake_b = g_ab(real_a)
            fake_a = g_ba(real_b)
            cycle_a = g_ba(fake_b)
            cycle_b = g_ab(fake_a)
            identity_a = g_ba(real_a)
            identity_b = g_ab(real_b)

            # Compute loss: one term per translation direction
            g_loss = losses.compute_g_loss(
                fake_b, real_b, cycle_a, identity_a, lambda_cyc, lambda_id
            )
            g_loss += losses.compute_g_loss(
                fake_a, real_a, cycle_b, identity_b, lambda_cyc, lambda_id
            )

            # Update generator weights
            g_loss.backward()
            optimizer_g.step()

            # Train landscape discriminator
            optimizer_d_a.zero_grad()

            fake_a_detached = fake_a.detach()  # Avoid training G_BA
            d_a_loss = losses.compute_d_loss(d_a(real_a), d_a(fake_a_detached))

            # Update weights
            d_a_loss.backward()
            optimizer_d_a.step()

            # Train vangogh discriminator
            optimizer_d_b.zero_grad()

            fake_b_detached = fake_b.detach()  # Avoid training G_AB
            d_b_loss = losses.compute_d_loss(d_b(real_b), d_b(fake_b_detached))

            # Update weights
            d_b_loss.backward()
            optimizer_d_b.step()

            # Print losses during training
            if i % 10 == 0:
                print(f"[Epoch {epoch+1}/{epochs}] [Batch {i+1}] "
                      f"Generator Loss: {g_loss.item():.4f}, "
                      f"D_A Loss: {d_a_loss.item():.4f}, "
                      f"D_B Loss: {d_b_loss.item():.4f}")

        # Save generated samples (from the last batch) every 10 epochs;
        # skipped when the loaders yielded nothing, since no fakes exist.
        if (epoch + 1) % 10 == 0 and fake_a is not None:
            save_image(fake_b[:4], f"generated_epoch_{epoch+1}_fake_b.jpg", nrow=2, normalize=True)
            save_image(fake_a[:4], f"generated_epoch_{epoch+1}_fake_a.jpg", nrow=2, normalize=True)

    # Save model after training
    torch.save(g_ab.state_dict(), os.path.join(save_dir, "g_ab_final.pth"))
    torch.save(g_ba.state_dict(), os.path.join(save_dir, "g_ba_final.pth"))
    torch.save(d_a.state_dict(), os.path.join(save_dir, "d_a_final.pth"))
    torch.save(d_b.state_dict(), os.path.join(save_dir, "d_b_final.pth"))
    print("Final models saved.")


# Train model

# Guarded so that importing this notebook/module (e.g. through import_ipynb)
# does not kick off a training run as a side effect. In a notebook or when
# run as a script, __name__ is "__main__" and training proceeds as before.
if __name__ == "__main__":
    # Dataset is expected in a "data" folder under the current working directory.
    data_folder = os.path.join(os.getcwd(), "data")
    dataloaders = preprocess.preprocess_data(data_folder, image_size=(256, 256), batch_size=4)

    # Prefer the GPU when one is available; otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_cyclegan(dataloaders, device, epochs=20)

Using device: cpu
Beginning training...
Epoch 1/20
Processing batch 1/53
[Epoch 1/20] [Batch 1] Generator Loss: 17.4542, D_A Loss: 0.7985, D_B Loss: 0.7101
Processing batch 2/53
Processing batch 3/53
Processing batch 4/53
Processing batch 5/53
Processing batch 6/53
Processing batch 7/53
Processing batch 8/53
Processing batch 9/53
Processing batch 10/53
Processing batch 11/53
[Epoch 1/20] [Batch 11] Generator Loss: 14.1547, D_A Loss: 0.1915, D_B Loss: 0.2481
Processing batch 12/53
Processing batch 13/53
Processing batch 14/53
Processing batch 15/53
Processing batch 16/53
Processing batch 17/53
Processing batch 18/53
Processing batch 19/53
Processing batch 20/53
Processing batch 21/53
[Epoch 1/20] [Batch 21] Generator Loss: 15.1437, D_A Loss: 0.0761, D_B Loss: 0.2051
Processing batch 22/53


KeyboardInterrupt: 