In [1]:
# Cell 1: Imports
import sys
from pathlib import Path

# Make the project modules in ../src importable. The path is resolved to an
# absolute path (so a later os.chdir cannot break it) and only added once, so
# re-running this cell does not keep growing sys.path.
SRC_DIR = str(Path('../src').resolve())
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

import torch
from models import FeatureExtractor, AttentionDecoder
from discriminator import PatchGANDiscriminator
from train_adversarial import AdversarialTrainer
from dataset import get_dataloaders

print("All imports successful!")

All imports successful!


In [2]:
# Cell 2: Seed RNGs and auto-detect a safe device
# ../src is already on sys.path (added in Cell 1), so it is not appended again here.
from train_adversarial import select_device_safe

# Seed torch's RNG so torch-based randomness in later cells (e.g. the
# torch.randn probe and model construction in Cell 3) is reproducible
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# NOTE: these arguments must match `architecture` / `layer_name` in Cell 3.
device = select_device_safe('vgg16', 'block1')
print(f"Using device: {device}")

Using device: cpu


In [3]:
# Cell 3: Create models (frozen encoder, trainable decoder and discriminator)
architecture = 'vgg16'
layer_name = 'block1'
image_size = 224  # input resolution of the probe and output resolution of the decoder

# Encoder (frozen). eval() only switches layers to inference mode; it does not
# stop gradient tracking, so parameters are also frozen explicitly.
encoder = FeatureExtractor(architecture=architecture, layer_name=layer_name)
encoder.eval()
for param in encoder.parameters():
    param.requires_grad_(False)

# Probe the encoder with a dummy image to discover the feature-map shape
with torch.no_grad():
    dummy = torch.randn(1, 3, image_size, image_size)
    feat = encoder(dummy)
    feat_channels, feat_h, feat_w = feat.shape[1:]

print(f"Feature shape: {feat_channels} x {feat_h} x {feat_w}")

# The decoder below is configured with a single input_size (feat_h), which
# assumes a square feature map — fail loudly if that does not hold.
assert feat_h == feat_w, f"Expected a square feature map, got {feat_h} x {feat_w}"

# Decoder: maps encoder features back to an image_size x image_size RGB image
decoder = AttentionDecoder(
    input_channels=feat_channels,
    input_size=feat_h,
    output_size=image_size,
    num_blocks=4
)

# Discriminator operates on RGB images
discriminator = PatchGANDiscriminator(in_channels=3)

print(f"Decoder params: {sum(p.numel() for p in decoder.parameters()):,}")
print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")

Feature shape: 64 x 112 x 112
Decoder params: 233,667
Discriminator params: 2,766,529


In [4]:
# Cell 4: Load a small dataset (limit=10) for a quick smoke test
# The directory is passed without a trailing slash: the loader's recorded log
# ("../data//DIV2K_train_HR") shows it appends its own "/" separator.
train_loader, val_loader, _ = get_dataloaders(
    data_dir='../data',
    batch_size=2,
    num_workers=0,
    # In the recorded run this did not cap the test split (DIV2K_test_HR reported
    # 100 images); the test loader is discarded here anyway.
    limit=10,
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")


✓ Found 10 images in ../data//DIV2K_train_HR
✓ Found 10 images in ../data//DIV2K_train_HR
✓ Found 100 images in ../data//DIV2K_test_HR
Train batches: 4
Val batches: 1


In [5]:
# Cell 5: Wire the encoder, decoder and discriminator into the adversarial trainer.
# The recorded outputs show generator loss = mse_weight * MSE + adv_weight * Adv,
# and one learning rate (lr) reported for both G and D.
trainer = AdversarialTrainer(
    # networks
    encoder=encoder,
    decoder=decoder,
    discriminator=discriminator,
    # loss weighting
    adv_weight=0.01,
    mse_weight=1.0,
    # optimisation and placement
    lr=0.001,
    device=device,
)

print("Trainer initialized!")

Trainer initialized!


In [6]:
# Cell 6: Smoke-test a single forward pass through encoder, decoder and discriminator
test_batch = next(iter(train_loader)).to(device)
print(f"Test batch shape: {test_batch.shape}")

# Encode (the encoder is fixed, so no autograd graph is needed)
with torch.no_grad():
    features = encoder(test_batch)
    print(f"Features shape: {features.shape}")
assert features.shape[1:] == (feat_channels, feat_h, feat_w), \
    "encoder output differs from the shape probed in Cell 3"

# Decode — kept outside no_grad so `reconstructed` still carries gradients to
# the decoder if it is reused for a training step.
reconstructed = decoder(features)
print(f"Reconstructed shape: {reconstructed.shape}")
assert reconstructed.shape == test_batch.shape, \
    "decoder must reproduce the input image shape"

# Discriminator — this is only a shape check, so skip building a graph
with torch.no_grad():
    D_output = discriminator(test_batch)
print(f"Discriminator output shape: {D_output.shape}")
assert D_output.shape[0] == test_batch.shape[0], \
    "expected one discriminator output per image in the batch"

print("\nForward pass successful!")

Test batch shape: torch.Size([2, 3, 224, 224])
Features shape: torch.Size([2, 64, 112, 112])
Reconstructed shape: torch.Size([2, 3, 224, 224])
Discriminator output shape: torch.Size([2, 1, 26, 26])

Forward pass successful!


In [7]:
# Cell 7: Smoke-test one discriminator step followed by one generator step
print("Testing one training step...")

# Recompute the reconstruction here instead of reusing the tensor from Cell 6.
# train_generator presumably backpropagates through `reconstructed`, and PyTorch
# frees a graph after backward by default. Reusing a stale tensor would
# therefore make this cell fail when re-run.
with torch.no_grad():
    features = encoder(test_batch)
reconstructed = decoder(features)

# Train discriminator
# NOTE(review): the recorded D_real / D_fake values are negative, so they look
# like raw scores/logits rather than probabilities — confirm in AdversarialTrainer.
loss_D, D_real, D_fake = trainer.train_discriminator(test_batch, reconstructed)
print(f"Discriminator - Loss: {loss_D:.4f}, D_real: {D_real:.4f}, D_fake: {D_fake:.4f}")

# Train generator
loss_G, mse, adv = trainer.train_generator(test_batch, reconstructed)
print(f"Generator - Loss: {loss_G:.4f}, MSE: {mse:.4f}, Adv: {adv:.4f}")

print("\nTraining step successful!")

Testing one training step...
Discriminator - Loss: 0.7130, D_real: -0.1480, D_fake: -0.1318
Generator - Loss: 1.5700, MSE: 1.5528, Adv: 1.7267

Training step successful!


In [8]:
# Cell 8: Short end-to-end run of the full training loop.
# Timing: the recorded CPU run took ~12 minutes per epoch (4 batches at
# ~3 minutes each), so expect roughly 25 minutes for 2 epochs on CPU.
num_epochs = 2
print(f"Running {num_epochs}-epoch test...")

history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=num_epochs,
    save_dir='../results/test_adversarial/checkpoints'
)

print(f"\n{num_epochs}-epoch test complete!")
# Both lines report the generator loss (history keys *_loss_G), last recorded entry.
print(f"Final train loss: {history['train_loss_G'][-1]:.4f}")
print(f"Final val loss: {history['val_loss_G'][-1]:.4f}")

Running 2-epoch test...

Epoch 1/2
----------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 4/4 [12:28<00:00, 187.06s/it]


Train - Loss_G: 1.329113 (MSE: 1.316924, Adv: 1.218884)
Train - Loss_D: 1.112810 (D_real: 0.134, D_fake: -0.371)
Val   - Loss_G: 1.060474 (MSE: 1.056406, Adv: 0.406795)
LR    - G: 0.001000, D: 0.001000
[SAVED] Best model at epoch 1 (val_loss: 1.060474)

Epoch 2/2
----------------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 4/4 [12:07<00:00, 181.89s/it]


Train - Loss_G: 1.064019 (MSE: 1.053793, Adv: 1.022615)
Train - Loss_D: 0.674401 (D_real: -0.171, D_fake: -0.474)
Val   - Loss_G: 0.881816 (MSE: 0.875698, Adv: 0.611715)
LR    - G: 0.001000, D: 0.001000
[SAVED] Best model at epoch 2 (val_loss: 0.881816)

[SAVED] Training history: ../results/test_adversarial/metrics_adversarial/training_history.csv

2-epoch test complete!
Final train loss: 1.0640
Final val loss: 0.8818


In [9]:
# Cell 9: Standalone environment sanity check. It imports torch itself so it
# also works on a fresh kernel, then reports the installed PyTorch version.
import torch

print(f"PyTorch version: {torch.__version__}", "Cell 1 OK", sep="\n")

PyTorch version: 2.9.0
Cell 1 OK
