
In [None]:
! pip install kaggle



Mount Google Drive so you can store your Kaggle API credentials for future use.

In [14]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Create a `.kaggle` directory in the home folder of the temporary Colab instance.

Download your Kaggle API key (a `.json` file) from your Kaggle account page by clicking 'Create new API token' under the API section.

In [2]:
! mkdir -p ~/.kaggle

Alternatively, if you uploaded `kaggle.json` to the Colab working directory, copy it from there by uncommenting the line below. I recommend keeping the file on Google Drive instead, so it persists between sessions.

In [3]:
#! cp kaggle.json ~/.kaggle/

Upload the JSON file to Google Drive once, then copy it to `~/.kaggle` in each new session.

In [4]:
!cp /content/drive/MyDrive/ColabNotebooks/kaggle_API_credentials/kaggle.json ~/.kaggle/kaggle.json

Restrict the file permissions to read/write for the owner only (the Kaggle CLI warns if the key is readable by other users).

In [5]:
! chmod 600 ~/.kaggle/kaggle.json
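The three shell steps above (create the folder, copy the key, restrict permissions) can also be done from Python in one call, which avoids the `mkdir` error on re-runs. This is a minimal sketch: `install_kaggle_credentials` is a hypothetical helper name. `~/.kaggle/kaggle.json` is the location the cells above copy the key to, because that is where the Kaggle CLI looks for it.

```python
import os
import shutil
from pathlib import Path

def install_kaggle_credentials(src, kaggle_dir=Path.home() / ".kaggle"):
    """Copy kaggle.json into kaggle_dir and make it owner read/write only."""
    kaggle_dir = Path(kaggle_dir)
    kaggle_dir.mkdir(parents=True, exist_ok=True)  # like `mkdir -p ~/.kaggle`
    dest = kaggle_dir / "kaggle.json"
    shutil.copyfile(src, dest)                     # like `cp <src> ~/.kaggle/kaggle.json`
    os.chmod(dest, 0o600)                          # like `chmod 600 ~/.kaggle/kaggle.json`
    return dest

# Usage in Colab (path taken from the cell above):
# install_kaggle_credentials(
#     "/content/drive/MyDrive/ColabNotebooks/kaggle_API_credentials/kaggle.json")
```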

**Competitions and Datasets are the two types of Kaggle data**

**1. Download competition data**

If you get a 403 Forbidden error, you have most likely not accepted the competition rules yet. Open the competition page on Kaggle and join it (click 'Late Submission' if the competition has already closed).

In [6]:
! kaggle competitions download -c gan-getting-started

Downloading gan-getting-started.zip to /content
 87% 320M/367M [00:06<00:01, 38.2MB/s]
100% 367M/367M [00:06<00:00, 58.5MB/s]


Unzip the downloaded archive. It contains the `photo_jpg` and `monet_jpg` folders used below. Refresh the file browser on the left-hand side to update the view.

In [None]:
! unzip gan-getting-started
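The unzip step can also be scripted with Python's standard `zipfile` module. This makes it easy to confirm what the archive contains before training. This is a minimal sketch: `extract_and_count` is a hypothetical helper. The folder names in the usage note are the ones the training cells below read from.

```python
import zipfile
from collections import Counter
from pathlib import PurePosixPath

def extract_and_count(zip_path, out_dir="."):
    """Extract a zip archive and count .jpg files in each top-level folder."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(out_dir)
        names = zf.namelist()
    counts = Counter()
    for name in names:
        parts = PurePosixPath(name).parts  # zip member names always use "/"
        if len(parts) > 1 and name.lower().endswith(".jpg"):
            counts[parts[0]] += 1
    return dict(counts)

# Usage (hypothetical): extract_and_count("gan-getting-started.zip")
# should list the photo_jpg and monet_jpg folders used by the training cells.
```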

In [None]:
!mkdir -p /content/drive/MyDrive/cyclegan_checkpoints


In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import random
from tqdm import tqdm
import wandb

# ============================================
# 1. GENERATOR (FROM SCRATCH)
# ============================================
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, 7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Downsampling
        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Residual blocks
        res_blocks = []
        for _ in range(num_residual_blocks):
            res_blocks.append(ResidualBlock(256))
        self.res_blocks = nn.Sequential(*res_blocks)

        # Upsampling
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Output
        self.output = nn.Sequential(
            nn.Conv2d(64, out_channels, 7, padding=3, padding_mode='reflect'),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.initial(x)
        x = self.down1(x)
        x = self.down2(x)
        x = self.res_blocks(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.output(x)
        return x

# ============================================
# 2. DISCRIMINATOR (FROM SCRATCH)
# ============================================
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, x):
        return self.model(x)

# ============================================
# 3. DATASET
# ============================================
class MonetDataset(Dataset):
    def __init__(self, photo_dir, monet_dir, img_size=256):
        self.photo_files = [os.path.join(photo_dir, f) for f in os.listdir(photo_dir)]
        self.monet_files = [os.path.join(monet_dir, f) for f in os.listdir(monet_dir)]

        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def __len__(self):
        return max(len(self.photo_files), len(self.monet_files))

    def __getitem__(self, idx):
        photo_path = self.photo_files[idx % len(self.photo_files)]
        monet_path = random.choice(self.monet_files)

        photo = Image.open(photo_path).convert('RGB')
        monet = Image.open(monet_path).convert('RGB')

        return self.transform(photo), self.transform(monet)
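The generator should return an image the same size as its input. The discriminator, in contrast, returns a grid of patch scores rather than a single number. You can check both claims with the output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d`. The pure-Python sketch below copies the kernel, stride, and padding values from the classes above; the helper names are hypothetical.

It shows that a 256×256 input is downsampled to a 64×64 bottleneck and restored to 256×256. It also shows that this discriminator, which uses stride 2 in all four blocks, produces a 15×15 score map in which each score sees a 94×94 input patch. The reference CycleGAN discriminator uses stride 1 in its 512-channel block, which gives the familiar 70×70 patch.

```python
def conv_out(n, k, s=1, p=0):
    # Output size of nn.Conv2d (dilation=1)
    return (n + 2 * p - k) // s + 1

def deconv_out(n, k, s=1, p=0, op=0):
    # Output size of nn.ConvTranspose2d (dilation=1)
    return (n - 1) * s - 2 * p + k + op

def generator_size(n):
    n = conv_out(n, 7, 1, 3)        # initial 7x7, reflect padding 3
    n = conv_out(n, 3, 2, 1)        # down1
    n = conv_out(n, 3, 2, 1)        # down2 (residual blocks keep this size)
    bottleneck = n
    n = deconv_out(n, 3, 2, 1, 1)   # up1
    n = deconv_out(n, 3, 2, 1, 1)   # up2
    n = conv_out(n, 7, 1, 3)        # output 7x7
    return bottleneck, n

# (kernel, stride, padding) for each conv in Discriminator.model
DISC_LAYERS = [(4, 2, 1)] * 4 + [(4, 1, 1)]

def disc_size(n, layers=DISC_LAYERS):
    for k, s, p in layers:
        n = conv_out(n, k, s, p)
    return n

def receptive_field(layers=DISC_LAYERS):
    # Walk backwards from one output score to the input pixels it depends on
    r = 1
    for k, s, _ in reversed(layers):
        r = (r - 1) * s + k
    return r

print(generator_size(256))    # (64, 256)
print(disc_size(256))         # 15
print(receptive_field())      # 94
print(receptive_field([(4, 2, 1)] * 3 + [(4, 1, 1)] * 2))  # 70
```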

In [None]:
def train_cyclegan(
    photo_dir='data/photo_jpg',
    monet_dir='data/monet_jpg',
    num_epochs=30,
    batch_size=1,
    checkpoint_dir='/content/drive/MyDrive/cyclegan_checkpoints',
    save_every=5,
    use_wandb=True
):
    """
    Train CycleGAN from scratch
    """
    # Device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"🚀 Using device: {device}")

    # Data
    print("📁 Loading data...")
    dataset = MonetDataset(photo_dir, monet_dir)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    print(f"✅ Loaded {len(dataset)} samples")

    # Models
    print("🏗️  Creating models...")
    gen_A2B = Generator().to(device)
    gen_B2A = Generator().to(device)
    disc_A = Discriminator().to(device)
    disc_B = Discriminator().to(device)

    # Optimizers
    opt_G = optim.Adam(
        list(gen_A2B.parameters()) + list(gen_B2A.parameters()),
        lr=0.0002, betas=(0.5, 0.999)
    )
    opt_D = optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()),
        lr=0.0002, betas=(0.5, 0.999)
    )

    # Loss functions (FROM SCRATCH)
    mse_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    lambda_cycle = 10.0
    lambda_identity = 5.0

    # WandB
    if use_wandb:
        wandb.init(project='cyclegan-monet', config={
            'epochs': num_epochs,
            'batch_size': batch_size,
            'lambda_cycle': lambda_cycle,
            'lambda_identity': lambda_identity
        })

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Training loop
    print(f"\n{'='*60}")
    print("🎨 STARTING TRAINING")
    print(f"{'='*60}\n")

    for epoch in range(1, num_epochs + 1):
        gen_A2B.train()
        gen_B2A.train()
        disc_A.train()
        disc_B.train()

        epoch_losses = {'G': 0, 'D': 0, 'Cycle': 0}

        pbar = tqdm(dataloader, desc=f'Epoch {epoch}/{num_epochs}')
        for real_A, real_B in pbar:
            real_A, real_B = real_A.to(device), real_B.to(device)

            # ==================
            # TRAIN GENERATORS
            # ==================
            opt_G.zero_grad()

            # Generate fakes
            fake_B = gen_A2B(real_A)
            fake_A = gen_B2A(real_B)

            # Cycle consistency
            reconstructed_A = gen_B2A(fake_B)
            reconstructed_B = gen_A2B(fake_A)

            # Identity
            identity_A = gen_B2A(real_A)
            identity_B = gen_A2B(real_B)

            # Adversarial loss (LSGAN - FROM SCRATCH)
            pred_fake_B = disc_B(fake_B)
            pred_fake_A = disc_A(fake_A)
            loss_adv_B = mse_loss(pred_fake_B, torch.ones_like(pred_fake_B))
            loss_adv_A = mse_loss(pred_fake_A, torch.ones_like(pred_fake_A))

            # Cycle consistency loss (FROM SCRATCH)
            loss_cycle_A = l1_loss(reconstructed_A, real_A)
            loss_cycle_B = l1_loss(reconstructed_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) * lambda_cycle

            # Identity loss (FROM SCRATCH)
            loss_identity_A = l1_loss(identity_A, real_A)
            loss_identity_B = l1_loss(identity_B, real_B)
            loss_identity = (loss_identity_A + loss_identity_B) * lambda_identity

            # Total generator loss
            loss_G = loss_adv_A + loss_adv_B + loss_cycle + loss_identity
            loss_G.backward()
            opt_G.step()

            # =======================
            # TRAIN DISCRIMINATORS
            # =======================
            opt_D.zero_grad()

            # Discriminator A (FROM SCRATCH)
            pred_real_A = disc_A(real_A)
            pred_fake_A = disc_A(fake_A.detach())
            loss_real_A = mse_loss(pred_real_A, torch.ones_like(pred_real_A))
            loss_fake_A = mse_loss(pred_fake_A, torch.zeros_like(pred_fake_A))
            loss_D_A = (loss_real_A + loss_fake_A) * 0.5

            # Discriminator B (FROM SCRATCH)
            pred_real_B = disc_B(real_B)
            pred_fake_B = disc_B(fake_B.detach())
            loss_real_B = mse_loss(pred_real_B, torch.ones_like(pred_real_B))
            loss_fake_B = mse_loss(pred_fake_B, torch.zeros_like(pred_fake_B))
            loss_D_B = (loss_real_B + loss_fake_B) * 0.5

            loss_D = loss_D_A + loss_D_B
            loss_D.backward()
            opt_D.step()

            # Track losses
            epoch_losses['G'] += loss_G.item()
            epoch_losses['D'] += loss_D.item()
            epoch_losses['Cycle'] += loss_cycle.item()

            pbar.set_postfix({
                'G': f"{loss_G.item():.3f}",
                'D': f"{loss_D.item():.3f}",
                'Cycle': f"{loss_cycle.item():.3f}"
            })

        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= len(dataloader)

        # Log
        if use_wandb:
            wandb.log({'epoch': epoch, **epoch_losses})

        print(f"\n📊 Epoch {epoch}: G={epoch_losses['G']:.4f}, D={epoch_losses['D']:.4f}, Cycle={epoch_losses['Cycle']:.4f}\n")

        # Save checkpoint (optimizer states included so resume_training can restore them)
        if epoch % save_every == 0 or epoch == num_epochs:
            checkpoint_path = f"{checkpoint_dir}/checkpoint_epoch_{epoch}.pth"
            torch.save({
                'epoch': epoch,
                'gen_A2B': gen_A2B.state_dict(),
                'gen_B2A': gen_B2A.state_dict(),
                'disc_A': disc_A.state_dict(),
                'disc_B': disc_B.state_dict(),
                'opt_G': opt_G.state_dict(),
                'opt_D': opt_D.state_dict(),
            }, checkpoint_path)
            print(f"💾 Saved: {checkpoint_path}\n")

    print("\n✅ TRAINING COMPLETE!")
    if use_wandb:
        wandb.finish()

    return gen_A2B, gen_B2A
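The generator and discriminator objectives in the loop above reduce to a few scalar formulas. The sketch below restates them in pure Python on flat lists of numbers standing in for tensors. This is an assumption made for readability; the sketch is not a replacement for the torch code. Note that `lambda_identity = 5.0` is half of `lambda_cycle`, which matches the 0.5·λ identity weight the CycleGAN paper suggests for photo↔painting translation.

```python
def mse(pred, target):
    # nn.MSELoss (mean reduction) on flat lists
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def l1(x, y):
    # nn.L1Loss (mean reduction) on flat lists
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)

def generator_loss(pred_fake_A, pred_fake_B, real_A, real_B, rec_A, rec_B,
                   id_A, id_B, lambda_cycle=10.0, lambda_identity=5.0):
    """LSGAN + cycle + identity objective minimized by opt_G; returns (total, cycle)."""
    adv = (mse(pred_fake_A, [1.0] * len(pred_fake_A))     # try to fool disc_A
           + mse(pred_fake_B, [1.0] * len(pred_fake_B)))  # try to fool disc_B
    cycle = (l1(rec_A, real_A) + l1(rec_B, real_B)) * lambda_cycle
    identity = (l1(id_A, real_A) + l1(id_B, real_B)) * lambda_identity
    return adv + cycle + identity, cycle

def discriminator_loss(pred_real, pred_fake):
    """LSGAN loss for one discriminator: real -> 1, fake -> 0, averaged."""
    return 0.5 * (mse(pred_real, [1.0] * len(pred_real))
                  + mse(pred_fake, [0.0] * len(pred_fake)))

# Perfect reconstructions and identities leave only the adversarial term
real = [0.2, -0.4]
total, cycle = generator_loss([0.5], [0.5], real, real, real, real, real, real)
print(total, cycle)  # 0.5 0.0
```

In these terms, the epoch-25 value `Cycle=1.91` logged later in the notebook is the already-weighted term. Dividing by `lambda_cycle` gives a combined L1 reconstruction error of about 0.19 across both directions, on images scaled to [-1, 1].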

In [10]:
!pip install wandb -q
import wandb
wandb.login()

wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice:

 2


wandb: You chose 'Use an existing W&B account'
wandb: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: Find your API key here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: abarb2022 (abarb2022-free-university-of-tbilisi-) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

In [None]:
gen_A2B, gen_B2A = train_cyclegan(
    photo_dir='photo_jpg',
    monet_dir='monet_jpg',
    num_epochs=5
)

🚀 Using device: cpu
📁 Loading data...
✅ Loaded 7038 samples
🏗️  Creating models...


KeyboardInterrupt: 
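The 7038 samples reported above come from `MonetDataset.__len__`, which returns the larger of the two folder sizes. Because the data is unpaired, each item couples one photo (taken by index) with a Monet painting drawn at random with replacement. The competition has 300 paintings, which is consistent with the 6 batches of 50 in the FID log further down. With that count, each painting appears about 23.5 times per epoch on average.

The pure-Python sketch below reproduces that indexing. `epoch_pairs` is a hypothetical helper, and it ignores the `DataLoader` shuffling.

```python
import random

def epoch_pairs(photo_files, monet_files, seed=0):
    """Mimic MonetDataset indexing for one epoch (DataLoader shuffling ignored)."""
    rng = random.Random(seed)
    n = max(len(photo_files), len(monet_files))     # MonetDataset.__len__
    for idx in range(n):
        photo = photo_files[idx % len(photo_files)]  # every photo once per epoch
        monet = rng.choice(monet_files)              # drawn with replacement
        yield photo, monet

photos = [f"photo_{i}.jpg" for i in range(7038)]
monets = [f"monet_{i}.jpg" for i in range(300)]
pairs = list(epoch_pairs(photos, monets))
print(len(pairs), len(set(p for p, _ in pairs)))  # 7038 7038
print(round(len(pairs) / len(monets), 1))         # 23.5
```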

In [14]:
def resume_training(
    checkpoint_path='/content/drive/MyDrive/cyclegan_checkpoints/checkpoint_epoch_5.pth',
    photo_dir='photo_jpg',
    monet_dir='monet_jpg',
    num_epochs=25,
    batch_size=1,
    checkpoint_dir='/content/drive/MyDrive/cyclegan_checkpoints',
    save_every=1,
    wandb_run_id=None
):
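    """
    Resume training from a checkpoint.

    Training continues from epoch start_epoch + 1 (e.g. epoch 6 for a
    checkpoint saved at epoch 5) up to and including num_epochs, so
    num_epochs is the final epoch number, not a number of extra epochs.
    """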

    # Check device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"🚀 Using device: {device}")

    if device == 'cpu':
        print("⚠️  WARNING: No GPU detected!")
        print("    Training this model on CPU is impractically slow.")
        print("    Please reconnect to a GPU: Runtime → Change runtime type → T4 GPU")
        return None, None

    # Load checkpoint
    print(f"📂 Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    start_epoch = checkpoint['epoch']
    print(f"✅ Loaded checkpoint from epoch {start_epoch}")
    print(f"🎯 Will train from epoch {start_epoch + 1} to {num_epochs}")

    # Data
    print("\n📁 Loading data...")
    dataset = MonetDataset(photo_dir, monet_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True
    )
    print(f"✅ Loaded {len(dataset)} samples")

    # Models
    print("\n🏗️  Creating models...")
    gen_A2B = Generator().to(device)
    gen_B2A = Generator().to(device)
    disc_A = Discriminator().to(device)
    disc_B = Discriminator().to(device)

    # Load weights
    print("⚙️  Loading model weights from checkpoint...")
    gen_A2B.load_state_dict(checkpoint['gen_A2B'])
    gen_B2A.load_state_dict(checkpoint['gen_B2A'])
    disc_A.load_state_dict(checkpoint['disc_A'])
    disc_B.load_state_dict(checkpoint['disc_B'])
    print("✅ All model weights loaded!")

    # Optimizers
    opt_G = optim.Adam(
        list(gen_A2B.parameters()) + list(gen_B2A.parameters()),
        lr=0.0002, betas=(0.5, 0.999)
    )
    opt_D = optim.Adam(
        list(disc_A.parameters()) + list(disc_B.parameters()),
        lr=0.0002, betas=(0.5, 0.999)
    )

    # Load optimizer states if available
    if 'opt_G' in checkpoint:
        print("⚙️  Loading optimizer states...")
        opt_G.load_state_dict(checkpoint['opt_G'])
        opt_D.load_state_dict(checkpoint['opt_D'])
        print("✅ Optimizer states loaded!")

    # Loss functions
    mse_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    lambda_cycle = 10.0
    lambda_identity = 5.0

    # WandB
    print("\n🌐 Initializing WandB...")
    if wandb_run_id:
        print(f"   Resuming run: {wandb_run_id}")
        wandb.init(
            project='cyclegan-monet',
            id=wandb_run_id,
            resume='must'
        )
    else:
        print("   Creating new run")
        wandb.init(
            project='cyclegan-monet',
            config={
                'resumed_from': start_epoch,
                'target_epochs': num_epochs,
                'lambda_cycle': lambda_cycle,
                'lambda_identity': lambda_identity
            }
        )

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Training loop
    print(f"\n{'='*60}")
    print("🎨 RESUMING TRAINING")
    print(f"   Starting from: Epoch {start_epoch + 1}")
    print(f"   Training until: Epoch {num_epochs}")
    print(f"   Total remaining: {num_epochs - start_epoch} epochs")
    print(f"{'='*60}\n")

    for epoch in range(start_epoch + 1, num_epochs + 1):
        gen_A2B.train()
        gen_B2A.train()
        disc_A.train()
        disc_B.train()

        epoch_losses = {'G': 0, 'D': 0, 'Cycle': 0}

        pbar = tqdm(dataloader, desc=f'Epoch {epoch}/{num_epochs}')
        for real_A, real_B in pbar:
            real_A, real_B = real_A.to(device), real_B.to(device)

            # TRAIN GENERATORS
            opt_G.zero_grad()

            fake_B = gen_A2B(real_A)
            fake_A = gen_B2A(real_B)

            reconstructed_A = gen_B2A(fake_B)
            reconstructed_B = gen_A2B(fake_A)

            identity_A = gen_B2A(real_A)
            identity_B = gen_A2B(real_B)

            pred_fake_B = disc_B(fake_B)
            pred_fake_A = disc_A(fake_A)
            loss_adv_B = mse_loss(pred_fake_B, torch.ones_like(pred_fake_B))
            loss_adv_A = mse_loss(pred_fake_A, torch.ones_like(pred_fake_A))

            loss_cycle_A = l1_loss(reconstructed_A, real_A)
            loss_cycle_B = l1_loss(reconstructed_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) * lambda_cycle

            loss_identity_A = l1_loss(identity_A, real_A)
            loss_identity_B = l1_loss(identity_B, real_B)
            loss_identity = (loss_identity_A + loss_identity_B) * lambda_identity

            loss_G = loss_adv_A + loss_adv_B + loss_cycle + loss_identity
            loss_G.backward()
            opt_G.step()

            # TRAIN DISCRIMINATORS
            opt_D.zero_grad()

            pred_real_A = disc_A(real_A)
            pred_fake_A = disc_A(fake_A.detach())
            loss_real_A = mse_loss(pred_real_A, torch.ones_like(pred_real_A))
            loss_fake_A = mse_loss(pred_fake_A, torch.zeros_like(pred_fake_A))
            loss_D_A = (loss_real_A + loss_fake_A) * 0.5

            pred_real_B = disc_B(real_B)
            pred_fake_B = disc_B(fake_B.detach())
            loss_real_B = mse_loss(pred_real_B, torch.ones_like(pred_real_B))
            loss_fake_B = mse_loss(pred_fake_B, torch.zeros_like(pred_fake_B))
            loss_D_B = (loss_real_B + loss_fake_B) * 0.5

            loss_D = loss_D_A + loss_D_B
            loss_D.backward()
            opt_D.step()

            epoch_losses['G'] += loss_G.item()
            epoch_losses['D'] += loss_D.item()
            epoch_losses['Cycle'] += loss_cycle.item()

            pbar.set_postfix({
                'G': f"{loss_G.item():.3f}",
                'D': f"{loss_D.item():.3f}",
                'Cycle': f"{loss_cycle.item():.3f}"
            })

        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= len(dataloader)

        # Log to WandB
        wandb.log({'epoch': epoch, **epoch_losses})

        print(f"\n📊 Epoch {epoch}: G={epoch_losses['G']:.4f}, D={epoch_losses['D']:.4f}, Cycle={epoch_losses['Cycle']:.4f}\n")

        # Save checkpoint
        if epoch % save_every == 0 or epoch == num_epochs:
            checkpoint_path = f"{checkpoint_dir}/checkpoint_epoch_{epoch}.pth"
            torch.save({
                'epoch': epoch,
                'gen_A2B': gen_A2B.state_dict(),
                'gen_B2A': gen_B2A.state_dict(),
                'disc_A': disc_A.state_dict(),
                'disc_B': disc_B.state_dict(),
                'opt_G': opt_G.state_dict(),
                'opt_D': opt_D.state_dict(),
            }, checkpoint_path)
            print(f"💾 Saved: {checkpoint_path}\n")

    print("\n✅ TRAINING COMPLETE!")
    wandb.finish()

    return gen_A2B, gen_B2A
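The resume cell below hard-codes `checkpoint_epoch_22.pth`. A small helper can pick the newest checkpoint automatically instead. The epoch must be compared as an integer: a plain string sort would rank `checkpoint_epoch_9.pth` above `checkpoint_epoch_25.pth`. This sketch uses hypothetical helper names. It also restates the epoch range `resume_training` iterates over, where `num_epochs` is the final epoch number rather than a count of extra epochs.

```python
import re
from pathlib import Path

CKPT_RE = re.compile(r"^checkpoint_epoch_(\d+)\.pth$")

def find_latest_checkpoint(checkpoint_dir):
    """Return (epoch, path) for the highest checkpoint_epoch_N.pth, or None."""
    found = []
    for path in Path(checkpoint_dir).iterdir():
        m = CKPT_RE.match(path.name)
        if m:
            found.append((int(m.group(1)), path))  # compare epochs as integers
    return max(found) if found else None

def epochs_to_run(start_epoch, num_epochs):
    """Epochs resume_training will run: start_epoch + 1 .. num_epochs inclusive."""
    return list(range(start_epoch + 1, num_epochs + 1))

# Usage (hypothetical):
# latest = find_latest_checkpoint('/content/drive/MyDrive/cyclegan_checkpoints')
# if latest:
#     epoch, path = latest
#     gen_A2B, gen_B2A = resume_training(checkpoint_path=str(path), num_epochs=25)
```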

In [15]:
gen_A2B, gen_B2A = resume_training(
    checkpoint_path='/content/drive/MyDrive/cyclegan_checkpoints/checkpoint_epoch_22.pth',
    photo_dir='photo_jpg',
    monet_dir='monet_jpg',
    num_epochs=25,
    wandb_run_id='d1vlk1je'
)

🚀 Using device: cuda
📂 Loading checkpoint: /content/drive/MyDrive/cyclegan_checkpoints/checkpoint_epoch_22.pth
✅ Loaded checkpoint from epoch 22
🎯 Will train from epoch 23 to 25

📁 Loading data...
✅ Loaded 7038 samples

🏗️  Creating models...
⚙️  Loading model weights from checkpoint...
✅ All model weights loaded!
⚙️  Loading optimizer states...
✅ Optimizer states loaded!

🌐 Initializing WandB...
   Resuming run: d1vlk1je

🎨 RESUMING TRAINING
   Starting from: Epoch 23
   Training until: Epoch 25
   Total remaining: 3 epochs

Epoch 23/25: 100%|██████████| 7038/7038 [57:19<00:00,  2.05it/s, G=3.371, D=0.246, Cycle=1.504]

📊 Epoch 23: G=4.1816, D=0.1638, Cycle=1.9433

💾 Saved: /content/drive/MyDrive/cyclegan_checkpoints/checkpoint_epoch_23.pth

Epoch 24/25: 100%|██████████| 7038/7038 [57:20<00:00,  2.05it/s, G=4.875, D=0.222, Cycle=2.599]

📊 Epoch 24: G=4.1580, D=0.1618, Cycle=1.9244

💾 Saved: /content/drive/MyDrive/cyclegan_checkpoints/checkpoint_epoch_24.pth

Epoch 25/25: 100%|██████████| 7038/7038 [57:21<00:00,  2.05it/s, G=3.838, D=0.120, Cycle=1.597]

📊 Epoch 25: G=4.1212, D=0.1652, Cycle=1.9109

💾 Saved: /content/drive/MyDrive/cyclegan_checkpoints/checkpoint_epoch_25.pth

✅ TRAINING COMPLETE!


Run history:
  Cycle  █▄▁
  D      ▅▁█
  G      █▅▁
  epoch  ▁▅█

Run summary:
  Cycle  1.91086
  D      0.16524
  G      4.12115
  epoch  25.0
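The progress bars give a direct way to budget Colab time. With `batch_size=1`, an epoch is 7038 steps, and at 2.05 it/s that works out to the roughly 57-minute epochs shown above. This log only covers epochs 23 to 25. Assuming the earlier epochs ran at the same rate, the full 25-epoch run is close to a day of GPU time, which is why saving a checkpoint to Drive every epoch (`save_every=1`) is worthwhile.

```python
def epoch_minutes(iterations, its_per_sec):
    """Wall-clock minutes for one epoch at a steady iteration rate."""
    return iterations / its_per_sec / 60

per_epoch = epoch_minutes(7038, 2.05)  # batch_size=1, so one step per photo
print(f"{per_epoch:.1f} min per epoch")              # 57.2 min per epoch
print(f"{25 * per_epoch / 60:.1f} h for 25 epochs")  # 23.8 h for 25 epochs
```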


In [None]:
import os
import torch
import numpy as np
import wandb
from PIL import Image
from torchvision import transforms


wandb.init(
    project="cyclegan-monet",
    id="m2smdjhf",
    resume="allow"
)

checkpoint = torch.load(
    "/content/drive/MyDrive/cyclegan_checkpoints/checkpoint_epoch_5.pth",
    map_location=torch.device("cpu")
)

gen_A2B = Generator()
gen_A2B.load_state_dict(checkpoint["gen_A2B"])
gen_A2B.eval()


transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
])

photo_dir = "photo_jpg"
photo_files = sorted(os.listdir(photo_dir))[:5]

wandb_images = []


with torch.no_grad():
    for fname in photo_files:
        photo_path = os.path.join(photo_dir, fname)

        # Load image
        photo = Image.open(photo_path).convert("RGB")

        # Model input
        photo_tensor = transform(photo).unsqueeze(0)

        # Generate Monet-style image
        fake_monet = gen_A2B(photo_tensor)

        fake_monet_img = (
            (fake_monet[0] + 1) / 2
        ).permute(1, 2, 0).numpy()

        photo_vis = np.array(photo.resize((256, 256))) / 255.0

        wandb_images.append(
            wandb.Image(photo_vis, caption=f"Input: {fname}")
        )
        wandb_images.append(
            wandb.Image(fake_monet_img, caption=f"Output: {fname}")
        )


wandb.log({
    # Label with the epoch stored in the checkpoint that was actually loaded
    f"Epoch {checkpoint.get('epoch', 5)} / Input → Monet Samples": wandb_images
})

wandb.finish()
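The sample cell above converts generator output back to displayable pixels with `(x + 1) / 2`. That is the inverse of `transforms.Normalize` with mean 0.5 and std 0.5, which maps `ToTensor`'s [0, 1] range onto [-1, 1], the same range as the generator's final `Tanh`. The pure-Python sketch below shows both directions for a single pixel value; the clamp mirrors the `.clamp(0, 1)` used in the FID cell.

```python
def normalize(x, mean=0.5, std=0.5):
    # transforms.Normalize: (x - mean) / std, applied per channel
    return (x - mean) / std

def denormalize(y):
    # Inverse for mean = std = 0.5, clamped to the valid [0, 1] range
    return min(max((y + 1) / 2, 0.0), 1.0)

for pixel in (0.0, 0.5, 1.0):
    print(pixel, normalize(pixel), denormalize(normalize(pixel)))
# 0.0 -1.0 0.0
# 0.5 0.0 0.5
# 1.0 1.0 1.0
```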


In [18]:
import os
import torch
import numpy as np
import wandb
from PIL import Image
from torchvision import transforms

!pip install pytorch-fid
from pytorch_fid import fid_score
from pathlib import Path

# ============================================
# STEP 1: Initialize WandB
# ============================================
wandb.init(
    project="cyclegan-monet",
    id="d1vlk1je",
    resume="allow"
)

# ============================================
# STEP 2: Load Checkpoint
# ============================================
checkpoint = torch.load(
    "/content/drive/MyDrive/cyclegan_checkpoints/checkpoint_epoch_25.pth",
    map_location=torch.device("cpu")
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gen_A2B = Generator()
gen_A2B.load_state_dict(checkpoint["gen_A2B"])
gen_A2B.to(device)
gen_A2B.eval()

# ============================================
# STEP 3: Setup transforms
# ============================================
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
])

# ============================================
# STEP 4: Generate images and save for FID
# ============================================
photo_dir = "photo_jpg"
monet_dir = "monet_jpg"  # Real Monet paintings directory

# Create temporary directories for FID calculation
generated_dir = "generated_monet_fid"
os.makedirs(generated_dir, exist_ok=True)

photo_files = sorted(os.listdir(photo_dir))
wandb_images = []

print("🎨 Generating images for FID calculation...")
print(f"Processing {len(photo_files)} photos...")

with torch.no_grad():
    for i, fname in enumerate(photo_files):
        photo_path = os.path.join(photo_dir, fname)

        # Load image
        photo = Image.open(photo_path).convert("RGB")

        # Model input
        photo_tensor = transform(photo).unsqueeze(0).to(device)

        # Generate Monet-style image
        fake_monet = gen_A2B(photo_tensor)

        # Convert to image for saving (denormalize)
        fake_monet_img = (fake_monet[0].cpu() + 1) / 2
        fake_monet_img = fake_monet_img.clamp(0, 1)
        fake_monet_pil = transforms.ToPILImage()(fake_monet_img)

        # Save generated image for FID calculation
        save_path = os.path.join(generated_dir, fname)
        fake_monet_pil.save(save_path)

        # Log first 5 samples to WandB
        if i < 5:
            fake_monet_np = fake_monet_img.permute(1, 2, 0).numpy()
            photo_vis = np.array(photo.resize((256, 256))) / 255.0

            wandb_images.append(
                wandb.Image(photo_vis, caption=f"Input: {fname}")
            )
            wandb_images.append(
                wandb.Image(fake_monet_np, caption=f"Output: {fname}")
            )

        if (i + 1) % 100 == 0:
            print(f"  ‚úì Generated {i + 1}/{len(photo_files)} images")

print(f"‚úÖ Generated {len(photo_files)} images")

# ============================================
# STEP 5: Calculate FID Score
# ============================================
print("\n📊 Calculating FID score...")
print(f"  Real Monet dir: {monet_dir}")
print(f"  Generated dir: {generated_dir}")

try:
    fid_value = fid_score.calculate_fid_given_paths(
        paths=[monet_dir, generated_dir],
        batch_size=50,
        device=device,
        dims=2048,  # InceptionV3 feature dimension
        num_workers=0
    )

    print(f"\n✨ FID Score: {fid_value:.2f}")

    # Interpretation (informal rule-of-thumb thresholds, not a standard scale;
    # FID against only a few hundred real paintings is a noisy estimate)
    if fid_value < 50:
        quality = "Excellent! 🌟"
    elif fid_value < 100:
        quality = "Good! 👍"
    elif fid_value < 200:
        quality = "Okay 😐"
    else:
        quality = "Needs improvement 📉"

    print(f"   Quality: {quality}")

except Exception as e:
    print(f"❌ Error calculating FID: {e}")
    fid_value = None

# ============================================
# STEP 6: Log everything to WandB
# ============================================
log_dict = {
    "Epoch 25 / Input → Monet Samples": wandb_images,
    "epoch": checkpoint.get("epoch", 25)
}

# Add FID score if calculated successfully
if fid_value is not None:
    log_dict["FID Score"] = fid_value

wandb.log(log_dict)

# ============================================
# STEP 7: Record a run summary
# ============================================
if fid_value is not None:
    # The samples and the FID were already logged in STEP 6, so store summary
    # values on the run instead of logging the images a second time
    wandb.run.summary["FID Score"] = fid_value
    wandb.run.summary["num_images_generated"] = len(photo_files)

    print("\n📊 Summary:")
    print(f"   FID Score: {fid_value:.2f} ({quality})")
    print(f"   Images Generated: {len(photo_files)}")
    print(f"   Checkpoint Epoch: {checkpoint.get('epoch', 25)}")

print("\n✅ Logged to WandB!")
wandb.finish()




Run summary:
  Cycle      1.91086
  D          0.16524
  FID Score  86.6775
  G          4.12115
  epoch      25.0


🎨 Generating images for FID calculation...
Processing 7038 photos...
  ✓ Generated 100/7038 images
  ✓ Generated 200/7038 images
  ✓ Generated 300/7038 images
  ✓ Generated 400/7038 images
  ✓ Generated 500/7038 images
  ✓ Generated 600/7038 images
  ✓ Generated 700/7038 images
  ✓ Generated 800/7038 images
  ✓ Generated 900/7038 images
  ✓ Generated 1000/7038 images
  ✓ Generated 1100/7038 images
  ✓ Generated 1200/7038 images
  ✓ Generated 1300/7038 images
  ✓ Generated 1400/7038 images
  ✓ Generated 1500/7038 images
  ✓ Generated 1600/7038 images
  ✓ Generated 1700/7038 images
  ✓ Generated 1800/7038 images
  ✓ Generated 1900/7038 images
  ✓ Generated 2000/7038 images
  ✓ Generated 2100/7038 images
  ✓ Generated 2200/7038 images
  ✓ Generated 2300/7038 images
  ✓ Generated 2400/7038 images
  ✓ Generated 2500/7038 images
  ✓ Generated 2600/7038 images
  ✓ Generated 2700/7038 images
  ✓ Generated 2800/7038 images
  ✓ Generat

100%|██████████| 6/6 [00:01<00:00,  3.75it/s]
100%|██████████| 141/141 [00:40<00:00,  3.44it/s]



✨ FID Score: 86.68
   Quality: Good! 👍

📊 Summary:
   FID Score: 86.68 (Good! 👍)
   Images Generated: 7038
   Checkpoint Epoch: 25


NameError: name 'table' is not defined
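For reference, FID fits a Gaussian to the 2048-dimensional InceptionV3 features of each image folder (the `dims=2048` argument above). It then reports the Fréchet distance ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). In one dimension this reduces to (μ_r − μ_g)² + (σ_r − σ_g)². The pure-Python sketch below computes that special case; the toy feature values are made up for illustration.

With only a few hundred real Monet paintings, the 2048×2048 covariance estimate is rank-deficient. The 86.68 above is therefore best used to compare checkpoints within this notebook rather than as an absolute quality grade.

```python
import math

def gaussian_fit(xs):
    """Sample mean and standard deviation (n - 1 denominator)."""
    mu = sum(xs) / len(xs)
    var = sum((x - mu) ** 2 for x in xs) / (len(xs) - 1)
    return mu, math.sqrt(var)

def frechet_distance_1d(mu1, sigma1, mu2, sigma2):
    """Squared Frechet distance between N(mu1, sigma1^2) and N(mu2, sigma2^2).

    1-D case of ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)).
    """
    return (mu1 - mu2) ** 2 + sigma1 ** 2 + sigma2 ** 2 - 2 * sigma1 * sigma2

real_features = [0.0, 1.0, 2.0, 3.0]  # toy stand-ins for Inception features
fake_features = [1.0, 2.0, 3.0, 4.0]  # same spread, mean shifted by 1
d = frechet_distance_1d(*gaussian_fit(real_features), *gaussian_fit(fake_features))
print(round(d, 6))  # 1.0
```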