In [1]:
import os
if os.path.basename(os.getcwd())!="HUST-CV-Neural-Style-Transfer":
    %cd ../../

e:\pyenv\GTCC\KPG-RL\HUST-CV-Neural-Style-Transfer


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import os
import time
from PIL import Image 
from tqdm import tqdm

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

LEARNING_RATE = 1e-4
BATCH_SIZE = 4
NUM_EPOCHS = 2
IMAGE_SIZE = 256

DATASET_PATH = "./datasets/coco" # This is now the folder with images directly inside
WEIGHTS_DIR = "models"
os.makedirs(WEIGHTS_DIR, exist_ok=True)

LAMBDA_PIXEL = 1.0
LAMBDA_FEATURE = 1.0

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

In [None]:
vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
vgg19.to(DEVICE)
for param in vgg19.parameters():
    param.requires_grad = False

vgg_layer_indices = {
    'relu1_1': 1, 'relu2_1': 6, 'relu3_1': 11, 'relu4_1': 20, 'relu5_1': 29
}
vgg_feature_layers = nn.ModuleList([vgg19[i] for i in range(max(vgg_layer_indices.values()) + 1)])

def get_vgg_features(image, target_layer_name):
    target_index = vgg_layer_indices[target_layer_name]
    features = image
    for i, layer in enumerate(vgg_feature_layers):
        features = layer(features)
        if i == target_index:
            return features
    raise ValueError(f"Target layer {target_layer_name} not reached.")

In [None]:
def decoder_block(in_channels, out_channels):
     return nn.Sequential(
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(in_channels, out_channels, kernel_size=3),
        nn.ReLU(inplace=True)
    )

class Decoder(nn.Module):
    def __init__(self, level):
        super().__init__()
        self.level = level
        layers = []
        if level == 5:
            layers.extend([
                decoder_block(512, 512),
                nn.Upsample(scale_factor=2, mode='nearest'),
                decoder_block(512, 512), decoder_block(512, 512), decoder_block(512, 512),
            ])
        if level >= 4:
            in_ch = 512 if level == 4 else 512
            layers.extend([
                decoder_block(in_ch, 512),
                nn.Upsample(scale_factor=2, mode='nearest'),
                decoder_block(512, 256), decoder_block(256, 256), decoder_block(256, 256), decoder_block(256, 256),
            ])
        if level >= 3:
            in_ch = 256 if level == 3 else 256
            layers.extend([
                decoder_block(in_ch, 256),
                nn.Upsample(scale_factor=2, mode='nearest'),
                decoder_block(256, 128), decoder_block(128, 128),
            ])
        if level >= 2:
            in_ch = 128 if level == 2 else 128
            layers.extend([
                 decoder_block(in_ch, 128),
                 nn.Upsample(scale_factor=2, mode='nearest'),
                 decoder_block(128, 64),
            ])
        if level >= 1:
            in_ch = 64 if level == 1 else 64
            layers.extend([
                decoder_block(in_ch, 64),
                nn.ReflectionPad2d((1, 1, 1, 1)),
                nn.Conv2d(64, 3, kernel_size=3)
            ])
        self.decoder = nn.Sequential(*layers)

    def forward(self, features):
        return self.decoder(features)

In [None]:
class CocoImageDataset(Dataset):
    def __init__(self, image_dir, transform=None, supported_extensions=('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif')):
        self.image_dir = image_dir
        self.transform = transform
        self.supported_extensions = supported_extensions

        if not os.path.isdir(image_dir):
            raise FileNotFoundError(f"Directory not found: {image_dir}")

        self.image_paths = [
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            if fname.lower().endswith(supported_extensions) and os.path.isfile(os.path.join(image_dir, fname))
        ]

        if not self.image_paths:
            print(f"Warning: No images with supported extensions {supported_extensions} found in {image_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            # Open image using PIL
            image = Image.open(img_path).convert('RGB') # Ensure image is RGB
        except Exception as e:
            print(f"Warning: Could not load image {img_path}. Error: {e}")
            # Return a dummy tensor or skip? For training, might be better to skip or handle.
            # Here, we'll return a placeholder, but filtering corrupted files beforehand is ideal.
            # Or, you could re-raise the exception if you want the DataLoader to potentially skip.
            return torch.zeros((3, IMAGE_SIZE, IMAGE_SIZE)) # Adjust size if needed

        # Apply transformations if provided
        if self.transform:
            image = self.transform(image)

        # Custom dataset for this task doesn't need a label, so we return only the image.
        # DataLoader will handle batching.
        return image


# -- Data Loading Setup --
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])


dataset = CocoImageDataset(DATASET_PATH, transform=transform)
if len(dataset) == 0:
    raise ValueError("Dataset is empty. Check path and image extensions.")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
print(f"Found {len(dataset)} images in {DATASET_PATH}")

In [None]:
pixel_loss_fn = nn.L1Loss().to(DEVICE)
feature_loss_fn = nn.MSELoss().to(DEVICE)

vgg_feature_layers.eval()

target_layers = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']

for level, layer_name in enumerate(target_layers, 1):
    print(f"\n--- Training Decoder for {layer_name} (Level {level}) ---")

    decoder = Decoder(level).to(DEVICE)
    optimizer = optim.Adam(decoder.parameters(), lr=LEARNING_RATE)

    total_steps = len(dataloader)
    start_time = time.time()

    for epoch in range(NUM_EPOCHS):
        epoch_pixel_loss = 0.0
        epoch_feature_loss = 0.0
        epoch_total_loss = 0.0
        decoder.train()

        pbar = tqdm(enumerate(dataloader), total=total_steps, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        
        for i, images in pbar: 
            images = images.to(DEVICE)

            with torch.no_grad():
                target_features = get_vgg_features(images, layer_name)

            reconstructed_images = decoder(target_features)

            loss_p = pixel_loss_fn(reconstructed_images, images)

            recon_features = get_vgg_features(reconstructed_images, layer_name)
            loss_f = feature_loss_fn(recon_features, target_features.detach())

            total_loss = LAMBDA_PIXEL * loss_p + LAMBDA_FEATURE * loss_f

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            epoch_pixel_loss += loss_p.item()
            epoch_feature_loss += loss_f.item()
            epoch_total_loss += total_loss.item()

            pbar.set_postfix({
                'PixLoss': f"{loss_p.item():.4f}",
                'FeatLoss': f"{loss_f.item():.4f}",
                'TotalLoss': f"{total_loss.item():.4f}"
            })

        avg_pixel_loss = epoch_pixel_loss / total_steps
        avg_feature_loss = epoch_feature_loss / total_steps
        avg_total_loss = epoch_total_loss / total_steps
        epoch_time = time.time() - start_time

        print(f"Epoch {epoch+1}/{NUM_EPOCHS} Summary:")
        print(f"  Avg Pixel Loss: {avg_pixel_loss:.4f}")
        print(f"  Avg Feature Loss: {avg_feature_loss:.4f}")
        print(f"  Avg Total Loss: {avg_total_loss:.4f}")
        print(f"  Time: {epoch_time:.2f}s")

    decoder_save_path = os.path.join(WEIGHTS_DIR, f"decoder_relu{level}_1.pth")
    torch.save(decoder.state_dict(), decoder_save_path)
    print(f"Saved decoder weights for {layer_name} to {decoder_save_path}")

Using device: cuda
Found 50 images in ./datasets/dataset

--- Training Decoder for relu1_1 (Level 1) ---


Epoch 1/2: 100%|██████████| 13/13 [00:05<00:00,  2.29it/s, PixLoss=1.0998, FeatLoss=0.1317, TotalLoss=1.2315]


Epoch 1/2 Summary:
  Avg Pixel Loss: 1.0822
  Avg Feature Loss: 0.2523
  Avg Total Loss: 1.3345
  Time: 5.93s


Epoch 2/2: 100%|██████████| 13/13 [00:04<00:00,  3.06it/s, PixLoss=0.9007, FeatLoss=0.2427, TotalLoss=1.1434]


Epoch 2/2 Summary:
  Avg Pixel Loss: 0.9894
  Avg Feature Loss: 0.1894
  Avg Total Loss: 1.1788
  Time: 10.44s
Saved decoder weights for relu1_1 to decoder_weights_wct\decoder_relu1_1.pth

--- Training Decoder for relu2_1 (Level 2) ---


Epoch 1/2: 100%|██████████| 13/13 [00:04<00:00,  3.24it/s, PixLoss=0.9379, FeatLoss=1.1128, TotalLoss=2.0507]


Epoch 1/2 Summary:
  Avg Pixel Loss: 1.0653
  Avg Feature Loss: 1.8208
  Avg Total Loss: 2.8861
  Time: 4.22s


Epoch 2/2: 100%|██████████| 13/13 [00:03<00:00,  3.30it/s, PixLoss=0.8118, FeatLoss=1.2275, TotalLoss=2.0393]


Epoch 2/2 Summary:
  Avg Pixel Loss: 0.9640
  Avg Feature Loss: 1.2324
  Avg Total Loss: 2.1963
  Time: 8.37s
Saved decoder weights for relu2_1 to decoder_weights_wct\decoder_relu2_1.pth

--- Training Decoder for relu3_1 (Level 3) ---


Epoch 1/2: 100%|██████████| 13/13 [00:04<00:00,  2.75it/s, PixLoss=0.9170, FeatLoss=8.0694, TotalLoss=8.9864] 


Epoch 1/2 Summary:
  Avg Pixel Loss: 1.0768
  Avg Feature Loss: 8.7448
  Avg Total Loss: 9.8216
  Time: 5.07s


Epoch 2/2: 100%|██████████| 13/13 [00:04<00:00,  2.86it/s, PixLoss=0.9345, FeatLoss=3.3505, TotalLoss=4.2851] 


Epoch 2/2 Summary:
  Avg Pixel Loss: 0.9834
  Avg Feature Loss: 6.7129
  Avg Total Loss: 7.6963
  Time: 9.90s
Saved decoder weights for relu3_1 to decoder_weights_wct\decoder_relu3_1.pth

--- Training Decoder for relu4_1 (Level 4) ---


Epoch 1/2:  77%|███████▋  | 10/13 [00:04<00:01,  2.12it/s, PixLoss=1.0252, FeatLoss=20.3209, TotalLoss=21.3461]


KeyboardInterrupt: 