In [None]:
# oasis_one_shot.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image
import random
import os

# -----------------------------
# Config / Device
# -----------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using device:', device)

# -----------------------------
# Paths (edit if needed)
# -----------------------------
style_path = '/content/dog.jpg'
content_path = '/content/Screenshot from 2025-10-04 23-08-58.png'
out_dir = './oasis_one_shot_out01'
# Create the output directory up front; an existing directory is not an error.
os.makedirs(out_dir, exist_ok=True)

# -----------------------------
# Transforms / Load images
# -----------------------------
imsize = 256
# Resize every image to a fixed imsize x imsize square, then convert it to a
# float tensor (ToTensor yields values in [0, 1] for 8-bit RGB input).
transform = transforms.Compose([
    transforms.Resize((imsize, imsize)),
    transforms.ToTensor(),
])

# Force three channels so palette / alpha images (e.g. PNG screenshots) work.
content_img = Image.open(content_path).convert('RGB')
style_img = Image.open(style_path).convert('RGB')

# Add a leading batch axis and move both images to the device: [1,3,H,W].
content, style = (
    transform(img).unsqueeze(0).to(device)
    for img in (content_img, style_img)
)

# optional light augmentation to help the discriminator generalize
def augment(x):
    """Randomly mirror a batch of images left-to-right.

    With probability 0.5 the tensor is flipped along dim 3 (the last axis of
    an [N, C, H, W] batch); otherwise it is returned unchanged. This is the
    only augmentation performed.

    Args:
        x: tensor with at least four dimensions.

    Returns:
        Either ``x`` itself or a horizontally flipped copy of it.
    """
    keep_as_is = random.random() >= 0.5
    if keep_as_is:
        return x
    return x.flip(dims=(3,))  # horizontal flip

# -----------------------------
# VGG19 encoder (pretrained, frozen feature extractor)
# -----------------------------
# Load the pretrained VGG19 convolutional stack, move it to the target device
# and put it in eval mode. It serves purely as a fixed feature extractor, so
# gradient tracking is switched off for all of its parameters.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of the `weights=` argument - confirm the installed version before
# changing it.
vgg_pretrained = models.vgg19(pretrained=True).features
vgg_pretrained = vgg_pretrained.to(device).eval()
vgg_pretrained.requires_grad_(False)

class VGGEncoder(nn.Module):
    """Feature extractor that taps selected activations of the shared VGG stack.

    The module wraps the module-level ``vgg_pretrained`` network and, on each
    forward pass, returns the activations of the child modules listed in
    ``self.layers`` under human-readable labels.
    """

    def __init__(self):
        super().__init__()
        # Child-module name in the VGG ``features`` stack -> key used in the
        # dict returned by ``forward``.
        self.layers = {'3':'relu1_2','8':'relu2_2','17':'relu3_4','26':'relu4_4'}
        self.vgg = vgg_pretrained

    def forward(self, x):
        """Run ``x`` through the stack and collect the tapped activations.

        Args:
            x: input batch passed to the first child of ``self.vgg``.

        Returns:
            dict mapping each label in ``self.layers`` to the output of the
            corresponding child module.
        """
        features = {}
        wanted = len(self.layers)
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.layers:
                features[self.layers[name]] = x
                # Every requested tap has been collected; the remaining
                # layers would only produce activations nobody reads.
                if len(features) == wanted:
                    break
        return features

# Shared encoder instance used for the content, style and generated images.
encoder = VGGEncoder().to(device)

# -----------------------------
# AdaIN function
# -----------------------------
def adain(content_feat, style_feat, eps=1e-5):
    """Adaptive instance normalisation of content features towards style features.

    Per-sample, per-channel statistics are taken over dims 2 and 3. The
    content features are centred, rescaled by the ratio of style to content
    standard deviation, and shifted to the style mean.

    Args:
        content_feat: feature tensor whose statistics are replaced.
        style_feat: feature tensor providing the target mean and std.
        eps: added to the content std to avoid division by zero.

    Returns:
        Tensor with the shape of ``content_feat`` (after broadcasting against
        the style statistics).
    """
    spatial_dims = [2, 3]
    content_mu = content_feat.mean(spatial_dims, keepdim=True)
    content_sigma = content_feat.std(spatial_dims, keepdim=True)
    style_mu = style_feat.mean(spatial_dims, keepdim=True)
    style_sigma = style_feat.std(spatial_dims, keepdim=True)

    rescaled = style_sigma * (content_feat - content_mu)
    return rescaled / (content_sigma + eps) + style_mu

# -----------------------------
# Decoder (three conv + 2x upsample stages, then a Tanh RGB head)
# -----------------------------
class Decoder(nn.Module):
    """Convolutional decoder that turns a latent feature map into an RGB image.

    Three ``Conv2d -> ReLU -> nearest-neighbour 2x Upsample`` stages reduce the
    channel count ``latent_dim -> 256 -> 128 -> 64`` while enlarging the
    spatial size by 8x overall. A final ``Conv2d -> Tanh`` head produces a
    3-channel output in [-1, 1].
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        widths = (latent_dim, 256, 128, 64)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2, mode='nearest'),
            ]
        # RGB head; Tanh bounds the output to [-1, 1].
        stages += [nn.Conv2d(64, 3, 3, padding=1), nn.Tanh()]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Decode ``x`` ([N, latent_dim, h, w]) into an image [N, 3, 8h, 8w]."""
        return self.model(x)

# The generator is a Decoder taking 512-channel latent feature maps.
generator = Decoder(latent_dim=512).to(device)

# -----------------------------
# PatchGAN Discriminator (OASIS-ish fully conv)
# -----------------------------
class PatchDiscriminator(nn.Module):
    """Fully convolutional PatchGAN-style discriminator.

    Four ``Conv2d(k=4, p=1) [-> BatchNorm2d] -> LeakyReLU(0.2)`` stages
    (strides 2, 2, 2, 1; no normalisation on the first) are followed by a
    stride-1 ``Conv2d`` that emits one logit per patch. For a 256x256 input
    the spatial size goes 128 -> 64 -> 32 -> 31 -> 30.
    """

    def __init__(self, in_channels=3, base=64):
        super().__init__()
        # (in_channels, out_channels, stride, use_batchnorm) for each stage
        plan = [
            (in_channels, base, 2, False),
            (base, base * 2, 2, True),
            (base * 2, base * 4, 2, True),
            (base * 4, base * 8, 1, True),
        ]
        modules = []
        for c_in, c_out, stride, use_norm in plan:
            modules.append(nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1))
            if use_norm:
                modules.append(nn.BatchNorm2d(c_out))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
        # final single-channel patch map
        modules.append(nn.Conv2d(base * 8, 1, 4, stride=1, padding=1))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Return the patch logit map for ``x``, e.g. [B,1,Hp,Wp]."""
        return self.net(x)

# Patch-level discriminator built with its default arguments
# (3 input channels, base width 64).
discriminator = PatchDiscriminator().to(device)

# -----------------------------
# Loss functions
# -----------------------------
mse = nn.MSELoss()  # least-squares adversarial targets and Gram-matrix style loss
bce = nn.BCEWithLogitsLoss()  # if we use logits; not referenced elsewhere in the visible code
l1 = nn.L1Loss()  # content loss and discriminator-output matching loss

def gram_matrix(feat):
    """Return the normalised Gram matrix of a batch of feature maps.

    Args:
        feat: tensor of shape [B, C, H, W]. It does not need to be contiguous.

    Returns:
        Tensor of shape [B, C, C] holding, for every sample, the channel-by-
        channel inner products over all spatial positions, divided by
        ``C * H * W``.
    """
    b, c, h, w = feat.shape
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # permuted or sliced feature maps, copying only when it has to.
    flat = feat.reshape(b, c, h * w)
    return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)

# -----------------------------
# Optimizers
# -----------------------------
# Separate Adam optimizers; the discriminator uses a 4x larger learning rate
# than the generator, both with beta1 = 0.5.
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=4e-4, betas=(0.5, 0.999))

# -----------------------------
# Training hyperparams
# -----------------------------
num_iters = 6000            # number of optimisation steps; increase for better results
print_every = 200           # log and save a sample image every this many iterations
adv_weight = 1.0            # adversarial weight
content_weight = 1.0        # perceptual content loss weight
style_weight = 1e6          # style Gram loss weight (large because the normalised Gram losses are tiny)
fm_weight = 10.0            # weight of the L1 term between discriminator outputs on real and fake (stabilizes)
# label smoothing for discriminator: "real" patches are pushed towards 0.9
# rather than 1.0. real_label is also the target used for the generator's
# adversarial loss.
real_label = 0.9
fake_label = 0.0

# Precompute encoded features for style and content (relu4_4 for AdaIN).
# Both images are fixed, so this runs once and without gradient tracking.
with torch.no_grad():
    content_feats_full = encoder(content)
    style_feats_full = encoder(style)
    content_feat = content_feats_full['relu4_4']
    style_feat = style_feats_full['relu4_4']

# For perceptual/content loss, we will compare relu4_4 features of output to original content_feat
# The generator inputs: the AdaIN-mixed relu4_4 features -> decoder -> output

# -----------------------------
# Training loop (one-shot)
# -----------------------------
# One-shot adversarial training: the generator decodes a fixed AdaIN latent
# and is trained with adversarial, content, style (Gram) and
# discriminator-output matching losses against a single style image.

# `content` and `style` come straight from ToTensor and are already in [0, 1],
# so the reference copies are written unchanged (and only once).
save_image(content, os.path.join(out_dir, 'content.png'))
save_image(style, os.path.join(out_dir, 'style.png'))

# ---- loop-invariant quantities, computed once ----
# AdaIN-mixed relu4_4 latent fed to the generator; both inputs are constants.
t = adain(content_feat, style_feat)

# Style Gram targets from the precomputed style features. `style_feats_full`
# was encoded from the [0, 1] style image, i.e. the same input range the
# encoder sees for the content image and for (fake + 1) / 2 below.
style_layers = ['relu1_2','relu2_2','relu3_4','relu4_4']
style_grams = {layer: gram_matrix(style_feats_full[layer]) for layer in style_layers}

# The generator ends in Tanh, so fakes live in [-1, 1]. Bring the real image
# to the same range; otherwise the discriminator can separate real from fake
# by value range alone.
style_real = style * 2.0 - 1.0

# NOTE(review): the encoder input is not ImageNet mean/std normalised anywhere
# in this file; that is kept as-is here for consistency with the precomputed
# features - confirm whether normalisation is desired.

generator.train()
discriminator.train()

for it in range(1, num_iters+1):
    # --- Generator forward (create fake); outputs in [-1,1] due to Tanh
    fake = generator(t)

    # --------------------------
    # Train Discriminator on real vs fake (PatchGAN, least-squares targets)
    # --------------------------
    d_optimizer.zero_grad()

    # The style image is the "real" reference that defines the target texture.
    real_inp = augment(style_real)  # [1,3,H,W]
    fake_inp = augment(fake.detach())  # detach: no generator gradients in the D step

    d_real = discriminator(real_inp)
    d_fake = discriminator(fake_inp)

    # targets are patch maps with the same shape/device/dtype as the outputs
    real_targets = torch.full_like(d_real, real_label)
    fake_targets = torch.full_like(d_fake, fake_label)

    loss_d_real = mse(d_real, real_targets)
    loss_d_fake = mse(d_fake, fake_targets)
    loss_d = 0.5 * (loss_d_real + loss_d_fake)
    loss_d.backward()
    d_optimizer.step()

    # --------------------------
    # Train Generator (adversarial + perceptual + style + output matching)
    # --------------------------
    g_optimizer.zero_grad()

    # Adversarial loss (want discriminator to predict real_label on fake)
    d_fake_for_g = discriminator(fake)
    adv_targets = torch.full_like(d_fake_for_g, real_label)
    loss_adv = mse(d_fake_for_g, adv_targets)

    # Perceptual (content) loss: relu4_4 features of the fake vs. the content.
    # The encoder was fed [0,1] images above, so map fake from [-1,1] first.
    out_feats = encoder((fake + 1.0) / 2.0)
    loss_content = l1(out_feats['relu4_4'], content_feat)

    # Style loss: Gram-matrix mismatch against the precomputed style targets.
    loss_style = 0.0
    for layer in style_layers:
        loss_style = loss_style + mse(gram_matrix(out_feats[layer]), style_grams[layer])

    # "Feature matching" term: L1 distance between the discriminator's patch
    # outputs on real and fake. The real side is re-evaluated because the
    # discriminator was just updated; the fake side reuses d_fake_for_g
    # instead of running an identical second forward pass.
    with torch.no_grad():
        d_real_feats = discriminator(real_inp)
    d_fake_feats = d_fake_for_g
    loss_fm = l1(d_fake_feats, d_real_feats)

    # Total generator loss
    loss_g = adv_weight * loss_adv + content_weight * loss_content + style_weight * loss_style + fm_weight * loss_fm
    loss_g.backward()
    g_optimizer.step()

    # --------------------------
    # Logging / save intermediate outputs
    # --------------------------
    if it % print_every == 0 or it == 1:
        print(f"Iter {it}/{num_iters} | D_loss: {loss_d.item():.6f} | G_adv: {loss_adv.item():.6f} | Content: {loss_content.item():.6f} | Style: {loss_style.item():.6f} | FM: {loss_fm.item():.6f}")
        # save output image (de-normalize from [-1,1] to [0,1])
        out_vis = (fake.detach().clamp(-1,1) + 1.0) / 2.0
        save_image(out_vis, os.path.join(out_dir, f'fake_{it:06d}.png'))

# -----------------------------
# Final save
# -----------------------------
# Map the last generator output from [-1, 1] to [0, 1] and write it to disk.
final = fake.detach().clamp(-1, 1).add(1.0).div(2.0)
final_path = os.path.join(out_dir, 'final_stylized.png')
save_image(final, final_path)
print("Done. Outputs saved to:", out_dir)


Using device: cuda
Iter 1/6000 | D_loss: 0.346559 | G_adv: 5.584313 | Content: 0.130698 | Style: 0.000005 | FM: 2.333939
Iter 200/6000 | D_loss: 0.175266 | G_adv: 0.194975 | Content: 0.216334 | Style: 0.000001 | FM: 0.239352
Iter 400/6000 | D_loss: 0.036701 | G_adv: 0.724440 | Content: 0.228337 | Style: 0.000000 | FM: 0.658393
Iter 600/6000 | D_loss: 0.019636 | G_adv: 0.767407 | Content: 0.222822 | Style: 0.000000 | FM: 0.749331
Iter 800/6000 | D_loss: 0.011885 | G_adv: 0.966678 | Content: 0.228439 | Style: 0.000000 | FM: 0.795267
Iter 1000/6000 | D_loss: 0.195138 | G_adv: 0.217692 | Content: 0.209739 | Style: 0.000000 | FM: 0.061681
Iter 1200/6000 | D_loss: 0.189582 | G_adv: 0.223911 | Content: 0.195090 | Style: 0.000000 | FM: 0.060032
Iter 1400/6000 | D_loss: 0.198154 | G_adv: 0.230931 | Content: 0.194732 | Style: 0.000000 | FM: 0.059811
Iter 1600/6000 | D_loss: 0.185414 | G_adv: 0.293772 | Content: 0.191995 | Style: 0.000000 | FM: 0.138206
Iter 1800/6000 | D_loss: 0.218063 | G_adv: 