## Feed Forward Style Transfer

In [1]:
import os, random, time
from pathlib import Path
from PIL import Image, ImageOps
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, models, utils

In [2]:
#select computation device
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    USE_AMP = True
elif getattr(torch, "has_mps", False) and torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    USE_AMP = False
else:
    DEVICE = torch.device("cpu")
    USE_AMP = False

print("Device:", DEVICE, "USE_AMP:", USE_AMP)

Device: mps USE_AMP: False


  elif getattr(torch, "has_mps", False) and torch.backends.mps.is_available():


In [3]:
#Hyperparameters
IMG_SIZE = 512
BATCH_SIZE = 6
NUM_EPOCHS = 12
LR = 1e-4
CONTENT_WEIGHT = 1.0
STYLE_WEIGHT = 1e6
TV_WEIGHT = 1e-6

In [4]:
#fielpaths
CONTENT_ROOT = "../Data/dataset/clean/animals_balanced"   
STYLE_ROOT= "../Data/dataset/clean/origami_images"
SPLIT_ROOT   = "../Data/dataset/split"   
CHECKPOINT_DIR = "./checkpoints_nststyle"
SAMPLES_DIR    = "./samples_nststyle"

for d in [SPLIT_ROOT, CHECKPOINT_DIR, SAMPLES_DIR]:
    os.makedirs(d, exist_ok=True)

TARGET_CLASS = "butterfly" #single class (inital)

for split in ['train', 'val', 'test']:
    for root in ['content', 'style']:
        path = os.path.join(SPLIT_ROOT, root, split, TARGET_CLASS)
        os.makedirs(path, exist_ok=True)
        

### VGG Configuration

In [8]:
#Vsame as NST
LAYER_INDICES = {
    'conv1_1': '0', 
    'conv1_2': '2', 
    'conv2_1': '5', 
    'conv2_2': '7',
    'conv3_1': '10', 
    'conv3_2': '12', 
    'conv3_3': '14', 
    'conv3_4': '16',
    'conv4_1': '19', 
    'conv4_2': '21', 
    'conv4_3': '23', 
    'conv4_4': '25',
    'conv5_1': '28', 
    'conv5_2': '30', 
    'conv5_3': '32', 
    'conv5_4': '34'
}

LAYER_CONFIGS = {
    'gatys': {
        'content': ['conv4_2'],
        'style': ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1'],
        'style_weights': {
            'conv1_1': 1.0,
            'conv2_1': 0.8,
            'conv3_1': 0.5,
            'conv4_1': 0.3,
            'conv5_1': 0.1
        },
    }
}
ACTIVE_LAYER_CONFIG = 'gatys'


In [5]:
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD  = [0.229, 0.224, 0.225]

def list_images(dir_):
    return sorted([os.path.join(dir_,f) for f in os.listdir(dir_)
                   if f.lower().endswith(('.jpg','.jpeg','.png'))])
    
def exif_fix_and_open(path):
    img = Image.open(path)
    img = ImageOps.exif_transpose(img)
    return img.convert("RGB")

def print_progress(prefix, step, total, every=10, end=False):
    if end:
        print(f"\n{prefix} done.")
        return
    if step % every == 0:
        print("==> " + "." * (step // every), end="\r")


In [6]:
#same as nst.py
def normalize_for_vgg(x):
    mean = torch.tensor(IMG_MEAN, device=DEVICE).view(1,3,1,1)
    std  = torch.tensor(IMG_STD,  device=DEVICE).view(1,3,1,1)
    return (x - mean) / std

def gram_matrix_batch(tensor):
    b, c, h, w = tensor.size()
    f = tensor.view(b, c, h*w)
    return torch.bmm(f, f.transpose(1,2)) / (c * h * w)

#### Data Splitting

In [7]:
def ensure_splits(class_name, val_frac=0.1, test_frac=0.05, seed=42):
    random.seed(seed)
    src_content = os.path.join(CONTENT_ROOT, class_name)
    src_style   = os.path.join(STYLE_ROOT,   class_name)
    assert os.path.isdir(src_content), f"Missing content folder: {src_content}"
    assert os.path.isdir(src_style),   f"Missing style folder: {src_style}"

    def split_list(files):
        n = len(files)
        n_val  = int(n * val_frac)
        n_test = int(n * test_frac)
        return files[n_val+n_test:], files[:n_val], files[n_val:n_val+n_test]

    def copy_split(src_folder, dst_folder, files):
        os.makedirs(dst_folder, exist_ok=True)
        for f in files:
            src = os.path.join(src_folder, f)
            dst = os.path.join(dst_folder, f)
            if not os.path.exists(dst):
                exif_fix_and_open(src).save(dst, "JPEG", quality=90)

    for domain, root in [("content", CONTENT_ROOT), ("style", STYLE_ROOT)]:
        src = os.path.join(root, class_name)
        files = [f for f in os.listdir(src) if f.lower().endswith(('.jpg','.jpeg','.png'))]
        print(f"{domain} - {class_name} - found {len(files)} files in {src}")
        random.shuffle(files)
        train, val, test = split_list(files)
        for split_name, flist in zip(["train","val","test"], [train, val, test]):
            out_dir = os.path.join(SPLIT_ROOT, domain, split_name, class_name)
            copy_split(src, out_dir, flist)

    print(f"Created train/val/test splits under {SPLIT_ROOT}/{class_name}")

ensure_splits(TARGET_CLASS)

splits ready under ../Data/dataset/split/butterfly


In [9]:
transform_img = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor()
])

### Data Sampler

In [None]:
class SingleClassPairedSampler:
    def __init__(self, split_root, class_name, transform):
        self.split_root = split_root
        self.class_name = class_name
        self.transform  = transform
        self.split = 'train'
        self._build_index()

    def _build_index(self):
        for split in ['train','val','test']:
            c_path = os.path.join(self.split_root, "content", split, self.class_name)
            s_path = os.path.join(self.split_root, "style",   split, self.class_name)
            c_files = sorted([os.path.join(c_path,f) for f in os.listdir(c_path)
                              if f.lower().endswith(('.jpg','.jpeg','.png'))])
            s_files = sorted([os.path.join(s_path,f) for f in os.listdir(s_path)
                              if f.lower().endswith(('.jpg','.jpeg','.png'))])
            setattr(self, f"{split}_content_files", c_files)
            setattr(self, f"{split}_style_files",   s_files)

    def set_split(self, split):
        self.split = split

    def sample_batch(self, batch_size):
        c_files = getattr(self, f"{self.split}_content_files")
        s_files = getattr(self, f"{self.split}_style_files")
        paths_c = random.choices(c_files, k=batch_size)
        paths_s = random.choices(s_files, k=batch_size)
        c_imgs = [self.transform(exif_fix_and_open(p)) for p in paths_c]
        s_imgs = [self.transform(exif_fix_and_open(p)) for p in paths_s]
        return torch.stack(c_imgs), torch.stack(s_imgs)

sampler = SingleClassPairedSampler(SPLIT_ROOT, TARGET_CLASS, transform=transform_img)
print("Sampler ready | train content:", len(sampler.train_content_files), "| train style:", len(sampler.train_style_files))

### VGG19 feature extractor

In [None]:
vgg = models.vgg19(pretrained=True).features.to(DEVICE).eval()
for p in vgg.parameters():
    p.requires_grad = False

class VGGFeatureExtractor(nn.Module):
    def __init__(self, vgg, layer_indices):
        super().__init__()
        self.vgg = vgg
        self.idx_to_name = {int(idx_str): name for name, idx_str in layer_indices.items()}
    def forward(self, x):
        cur = x
        feats = {}
        for idx, layer in self.vgg._modules.items():
            cur = layer(cur)
            i = int(idx)
            if i in self.idx_to_name:
                feats[self.idx_to_name[i]] = cur
        return feats

vgg_feat = VGGFeatureExtractor(vgg, LAYER_INDICES).to(DEVICE).eval()

# Active layer config
_cfg = LAYER_CONFIGS[ACTIVE_LAYER_CONFIG]
content_layers  = _cfg['content']
style_layers    = _cfg['style']
style_weights   = _cfg['style_weights']
print("Content layers:", content_layers)
print("Style layers:", style_layers)

#### Transformer Network

In [12]:
class ConvLayer(nn.Module):
    def __init__(self, in_c, out_c, kernel, stride):
        super().__init__()
        padding = kernel // 2
        self.conv = nn.Conv2d(in_c, out_c, kernel, stride, padding)
        self.inorm = nn.InstanceNorm2d(out_c, affine=True)
    def forward(self, x):
        return F.relu(self.inorm(self.conv(x)))

In [13]:
# class ResidualBlock(nn.Module): #learn style modifications
#     def __init__(self, channels):
#         super().__init__()
#         self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1)
#         self.in1 = nn.InstanceNorm2d(channels, affine=True)
#         self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
#         self.in2 = nn.InstanceNorm2d(channels, affine=True)
#     def forward(self, x):
#         out = F.relu(self.in1(self.conv1(x)))
#         out = self.in2(self.conv2(out))
#         return out + x

class StylizedResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.in1   = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.in2   = nn.InstanceNorm2d(channels, affine=True)
        self.style_gate = nn.Sequential(nn.Conv2d(channels, channels, 1), nn.Sigmoid())
    def forward(self, x):
        out  = F.relu(self.in1(self.conv1(x)))
        out  = self.in2(self.conv2(out))
        gate = self.style_gate(out)
        return out * gate + x



In [14]:
class UpsampleConv(nn.Module):
    def __init__(self, in_c, out_c, kernel, upsample=None):
        super().__init__()
        self.upsample = upsample
        padding = kernel // 2
        self.conv  = nn.Conv2d(in_c, out_c, kernel, 1, padding)
        self.inorm = nn.InstanceNorm2d(out_c, affine=True)
    def forward(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=self.upsample, mode='nearest')
        return F.relu(self.inorm(self.conv(x)))

In [15]:
# class TransformerNet(nn.Module):
    
#     def __init__(self):
#         super().__init__()
#         self.conv1 = ConvLayer(3, 32, 9, 1)
#         self.conv2 = ConvLayer(32, 64, 3, 2)
#         self.conv3 = ConvLayer(64, 128, 3, 2)
#         self.res1 = ResidualBlock(128)
#         self.res2 = ResidualBlock(128)
#         self.res3 = ResidualBlock(128)
#         self.res4 = ResidualBlock(128)
#         self.res5 = ResidualBlock(128)
#         self.up1 = UpsampleConv(128, 64, 3, upsample=2)
#         self.up2 = UpsampleConv(64, 32, 3, upsample=2)
#         self.conv_out = nn.Conv2d(32, 3, 9, 1, 4)
        
#     def forward(self, x):
#         y = self.conv1(x)
#         y = self.conv2(y)
#         y = self.conv3(y)
#         y = self.res1(y)
#         y = self.res2(y)
#         y = self.res3(y)
#         y = self.res4(y)
#         y = self.res5(y)
#         y = self.up1(y)
#         y = self.up2(y)
#         y = self.conv_out(y)
#         return torch.sigmoid(y)

class TransformerNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = ConvLayer(3, 32, 9, 1)
        self.conv2 = ConvLayer(32, 64, 3, 2)
        self.conv3 = ConvLayer(64, 128, 3, 2)
        self.res1 = StylizedResidualBlock(128)
        self.res2 = StylizedResidualBlock(128)
        self.res3 = StylizedResidualBlock(128)
        self.res4 = StylizedResidualBlock(128)
        self.res5 = StylizedResidualBlock(128)
        self.up1 = UpsampleConv(128, 64, 3, upsample=2)
        self.up2 = UpsampleConv(64, 32, 3, upsample=2)
        self.conv_out = nn.Conv2d(32, 3, 9, 1, 4)
    def forward(self, x):
        y = self.conv1(x); y = self.conv2(y); y = self.conv3(y)
        y = self.res1(y);  y = self.res2(y);  y = self.res3(y); y = self.res4(y); y = self.res5(y)
        y = self.up1(y);   y = self.up2(y)
        y = self.conv_out(y)
        return torch.sigmoid(y)




In [16]:
model = TransformerNet().to(DEVICE)
opt   = optim.Adam(model.parameters(), lr=LR)

def tv_loss_fn(x):
    return torch.mean(torch.abs(x[:,:,:, :-1] - x[:,:,:,1:])) + \
           torch.mean(torch.abs(x[:,:, :-1,:] - x[:,:,1:,:]))

### Style Gram Precompute

In [None]:
def precompute_style_grams(style_dir, transform):
    files = [os.path.join(style_dir, f) for f in os.listdir(style_dir)
             if f.lower().endswith(('.jpg','.jpeg','.png'))]
    grams_list = []
    with torch.no_grad():
        for p in files:
            x = transform(exif_fix_and_open(p)).unsqueeze(0).to(DEVICE)
            feats = vgg_feat(normalize_for_vgg(x))
            gdict = {l: gram_matrix_batch(feats[l]).cpu() for l in feats.keys()}
            grams_list.append(gdict)
    print(f"Precomputed {len(grams_list)} style gram dicts.")
    return grams_list

style_dir = os.path.join(STYLE_ROOT, TARGET_CLASS)
style_grams = precompute_style_grams(style_dir, transform_img)

### Training

In [29]:
def train_single_stage(model, sampler, style_grams, num_epochs, stage_name="stage_single"):
    sampler.set_split('train')
    scaler = torch.cuda.amp.GradScaler() if (USE_AMP and DEVICE.type == "cuda") else None

    # define steps per epoch (bounded to data size but keep it reasonably large)
    n_train = len(sampler.train_content_files)
    steps_per_epoch = max(100, n_train // max(1, BATCH_SIZE))

    for epoch in range(1, num_epochs + 1):
        model.train()
        epoch_loss = 0.0
        ep_prefix = f"[{stage_name}] Epoch {epoch}/{num_epochs}"

        for step in range(1, steps_per_epoch + 1):
            content_batch, style_batch = sampler.sample_batch(BATCH_SIZE)
            content_batch = content_batch.to(DEVICE)
            # style_batch not needed directly since we use precomputed grams
            opt.zero_grad()

            if scaler is not None:
                with torch.cuda.amp.autocast():
                    output = model(content_batch)
                    c_norm, o_norm = normalize_for_vgg(content_batch), normalize_for_vgg(output)
                    c_feats, o_feats = vgg_feat(c_norm), vgg_feat(o_norm)

                    # Content loss
                    c_loss = sum(torch.mean((o_feats[l] - c_feats[l])**2) for l in content_layers)

                    # Style loss
                    Gs = random.choice(style_grams)
                    s_loss = 0.0
                    for l in style_layers:
                        Go = gram_matrix_batch(o_feats[l])
                        s_loss += style_weights.get(l, 1.0) * torch.mean((Go - Gs[l].to(DEVICE))**2)

                    # TV
                    tv = tv_loss_fn(output)
                    total_loss = CONTENT_WEIGHT * c_loss + STYLE_WEIGHT * s_loss + TV_WEIGHT * tv

                scaler.scale(total_loss).backward()
                scaler.step(opt)
                scaler.update()
            else:
                output = model(content_batch)
                c_norm, o_norm = normalize_for_vgg(content_batch), normalize_for_vgg(output)
                c_feats, o_feats = vgg_feat(c_norm), vgg_feat(o_norm)

                c_loss = sum(torch.mean((o_feats[l] - c_feats[l])**2) for l in content_layers)

                Gs = random.choice(style_grams)
                s_loss = 0.0
                for l in style_layers:
                    Go = gram_matrix_batch(o_feats[l])
                    s_loss += style_weights.get(l, 1.0) * torch.mean((Go - Gs[l].to(DEVICE))**2)

                tv = tv_loss_fn(output)
                total_loss = CONTENT_WEIGHT * c_loss + STYLE_WEIGHT * s_loss + TV_WEIGHT * tv
                total_loss.backward()
                opt.step()

            epoch_loss += total_loss.item()

            # tiny '==> ....' progress
            print_progress(f"{ep_prefix}", step, steps_per_epoch, every=10, end=False)

            # occasional sample save
            if step % 300 == 0 or step == steps_per_epoch:
                model.eval()
                with torch.no_grad():
                    sample_out = model(content_batch[:1]).cpu()
                    utils.save_image(sample_out, f"{SAMPLES_DIR}/{stage_name}_ep{epoch}_step{step}.png")
                model.train()

        avg_loss = epoch_loss / steps_per_epoch
        utils.save_image(output[:1].detach().cpu(), f"{SAMPLES_DIR}/{stage_name}_ep{epoch}_lastbatch.png")
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f"{stage_name}_epoch{epoch}.pth"))
        print_progress(f"{ep_prefix} | Avg loss {avg_loss:.4f}", steps_per_epoch, steps_per_epoch, end=True)
        print(f"{ep_prefix} | Avg loss {avg_loss:.4f}")

Training (no dataloader)â€¦
Epoch 1: ====................  20.0% | loss: 24.14405E1 S200/1000 | loss 26.4994
E1 S1000/1000 | loss 15.2456
[E1] checkpoint saved.
Epoch 2: ====................  20.0% | loss: 18.7341E2 S200/1000 | loss 14.3131
E2 S1000/1000 | loss 11.1061
[E2] checkpoint saved.
Epoch 3: ====................  20.0% | loss: 10.1141E3 S200/1000 | loss 11.5658
E3 S1000/1000 | loss 10.4699
[E3] checkpoint saved.
Epoch 4: ====................  20.0% | loss: 5.40036E4 S200/1000 | loss 11.5797
E4 S1000/1000 | loss 10.4822
[E4] checkpoint saved.
Epoch 5: ====................  20.0% | loss: 27.7538E5 S200/1000 | loss 10.9001
E5 S1000/1000 | loss 9.4419
[E5] checkpoint saved.
Epoch 6: ====................  20.0% | loss: 11.0864E6 S200/1000 | loss 11.3464
E6 S1000/1000 | loss 10.6915
[E6] checkpoint saved.


In [None]:
print("Starting single-stage training...")
train_single_stage(model, sampler, style_grams, NUM_EPOCHS, stage_name=f"stage_single_{IMG_SIZE}")
print("Training finished.")

#### Testing for 1 image

In [30]:
ckpt = os.path.join(CHECKPOINT_DIR, f"stage_single_{IMG_SIZE}_epoch{NUM_EPOCHS}.pth")
if os.path.exists(ckpt):
    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    print("Loaded checkpoint:", ckpt)
else:
    print("Checkpoint not found; using in-memory weights.")

model.eval()
test_img_path = "test_imgs/butterfly.jpg"
if not os.path.exists(test_img_path):
    print("Test image not found at", test_img_path)
else:
    img = exif_fix_and_open(test_img_path)
    tf  = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)),
                              transforms.CenterCrop(IMG_SIZE),
                              transforms.ToTensor()])
    content_tensor = tf(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = model(content_tensor)
    out_path = os.path.join(SAMPLES_DIR, "stylized_image.png")
    utils.save_image(out.cpu(), out_path)
    print("Saved stylized image to", out_path)

Saved stylized image to ./samples_nststyle/stylized_image.png
