## Feed Forward Style Transfer

In [62]:
import os, random, time, platform
from pathlib import Path
from PIL import Image, ImageOps
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, models, utils
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [63]:
# Select the computation device: CUDA first, then Apple MPS, otherwise CPU.
# Mixed precision (USE_AMP) is only switched on for CUDA.
# MPS support is probed through `torch.backends.mps` rather than the deprecated
# `torch.has_mps` attribute; the getattr guard keeps this cell working on
# PyTorch builds that do not ship the `mps` backend module at all.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    USE_AMP = True
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
    USE_AMP = False
else:
    DEVICE = torch.device("cpu")
    USE_AMP = False

print("Device:", DEVICE, "USE_AMP:", USE_AMP)

Device: mps USE_AMP: False


  elif getattr(torch, "has_mps", False) and torch.backends.mps.is_available():


In [113]:
# Hyperparameters
# NOTE: IMG_SIZE and NUM_EPOCHS are not referenced by the visible training code,
# which uses the per-stage IMG_SIZE_STAGE*/NUM_EPOCHS_STAGE* settings instead.
IMG_SIZE, NUM_EPOCHS = 512, 12
BATCH_SIZE = 6   # images per training step
LR = 1e-4        # learning rate handed to the Adam optimizer

# Weights of the three terms that are summed into the training loss
CONTENT_WEIGHT, STYLE_WEIGHT, TV_WEIGHT = 1, 1e6, 1e-6   # TV term is for smoothness

In [114]:
# File paths: source datasets, split output root, and output folders
CONTENT_ROOT = "../Data/dataset/clean/animals_balanced"
STYLE_ROOT = "../Data/dataset/clean/origami_images"
SPLIT_ROOT = "../Data/dataset/split"
CHECKPOINT_DIR = "./checkpoints_nststyle"
SAMPLES_DIR = "./samples_nststyle"

TARGET_CLASS = "butterfly"  # single class (initial)

# Make sure the output folders exist
for out_dir in (SPLIT_ROOT, CHECKPOINT_DIR, SAMPLES_DIR):
    os.makedirs(out_dir, exist_ok=True)

# Pre-create SPLIT_ROOT/<domain>/<split>/<class> for every split/domain pair
for split_name in ("train", "val", "test"):
    for domain in ("content", "style"):
        os.makedirs(os.path.join(SPLIT_ROOT, domain, split_name, TARGET_CLASS), exist_ok=True)

In [115]:
# DataLoader settings: on macOS ("darwin") no worker processes are used;
# on every other platform use up to 6 workers while leaving two CPU cores free.
NUM_WORKERS = (
    0
    if platform.system().lower().startswith("darwin")
    else min(6, max(0, (os.cpu_count() or 2) - 2))
)

# Pinned host memory is only requested when running on CUDA
PIN_MEMORY = DEVICE.type == "cuda"
print("NUM_WORKERS:", NUM_WORKERS, "PIN_MEMORY:", PIN_MEMORY)

NUM_WORKERS: 0 PIN_MEMORY: False


#### VGG Layer Configs and Normalization

In [116]:
# Same layer tables as the NST script
def _vgg19_conv_layer_indices(convs_per_block=(2, 2, 4, 4, 4)):
    """Build the 'convB_K' -> index-string table used to pick VGG feature layers.

    Within a block consecutive convs sit 2 indices apart; the first conv of the
    next block sits 3 indices after the last conv of the current block.
    """
    table = {}
    position = 0
    for block_no, n_convs in enumerate(convs_per_block, start=1):
        for conv_no in range(1, n_convs + 1):
            table[f"conv{block_no}_{conv_no}"] = str(position)
            position += 2
        position += 1
    return table


LAYER_INDICES = _vgg19_conv_layer_indices()

# Per-layer style weights of the 'gatys' configuration (first conv of each block)
_GATYS_STYLE_WEIGHTS = dict(zip(
    [f"conv{block_no}_1" for block_no in range(1, 6)],
    [1.0, 0.8, 0.5, 0.3, 0.1],
))

LAYER_CONFIGS = {
    'gatys': {
        'content': ['conv4_2'],
        'style': list(_GATYS_STYLE_WEIGHTS),
        'style_weights': dict(_GATYS_STYLE_WEIGHTS),
    }
}
ACTIVE_LAYER_CONFIG = 'gatys'

# Per-channel mean / std used to normalize images before they are fed to VGG
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD = [0.229, 0.224, 0.225]


#### Normalization for VGG and Feature Extraction

In [117]:
def exif_fix_and_open(path):
    """Open an image file, apply its EXIF orientation, and return it as RGB.

    The file is opened in a `with` block so the underlying handle is closed
    as soon as the pixel data has been converted, instead of staying open
    until garbage collection (this function is called once per sampled image).

    Args:
        path: path of the image file to read.

    Returns:
        A PIL image in "RGB" mode.
    """
    with Image.open(path) as img:
        # Both calls run while the file is still open; convert() yields the final image.
        return ImageOps.exif_transpose(img).convert("RGB")

#same as nst.py
def normalize_for_vgg(x):
    """Normalize an image batch with IMG_MEAN / IMG_STD.

    The statistics are created on the same device as `x` (not on the global
    DEVICE), so the function also works for tensors that live elsewhere,
    e.g. CPU tensors while DEVICE is an accelerator. They stay float32, as before.

    Args:
        x: tensor broadcastable against shape (1, 3, 1, 1), i.e. 3 channels
           in dimension 1 of a 4-D batch.

    Returns:
        (x - mean) / std
    """
    mean = torch.tensor(IMG_MEAN, device=x.device).view(1, 3, 1, 1)
    std = torch.tensor(IMG_STD, device=x.device).view(1, 3, 1, 1)
    return (x - mean) / std

# def extract_features_batch(x, layers, model):
#     x_vgg = normalize_for_vgg(x)
#     cur = x_vgg
#     features = {}
#     layers_to_extract = {LAYER_INDICES[name]: name for name in layers}
#     for idx, layer in model._modules.items():
#         cur = layer(cur)
#         if idx in layers_to_extract:
#             features[layers_to_extract[idx]] = cur
#     return features

def gram_matrix_batch(tensor):
    """Compute one normalized Gram matrix per batch element.

    Args:
        tensor: 4-D tensor of shape (b, c, h, w).

    Returns:
        Tensor of shape (b, c, c): F @ F^T / (c * h * w), where F is the
        input flattened to (b, c, h*w).
    """
    b, c, h, w = tensor.size()
    # reshape (not view) so non-contiguous inputs are accepted as well
    f = tensor.reshape(b, c, h * w)
    return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)


#### Data Splitting

In [118]:
def ensure_splits(class_name, val_frac=0.1, test_frac=0.05, seed=42):
    """Copy one class's content and style images into train/val/test folders.

    For each domain ("content" from CONTENT_ROOT, "style" from STYLE_ROOT) the
    image files of `class_name` are listed, shuffled, split by fraction and
    written as JPEG (quality 90) to SPLIT_ROOT/<domain>/<split>/<class_name>.
    Files that already exist at the destination are left untouched.

    Args:
        class_name: sub-folder name shared by the content and style roots.
        val_frac: fraction of the files assigned to the validation split.
        test_frac: fraction of the files assigned to the test split.
        seed: seed for the (global) `random` module.

    Raises:
        AssertionError: if the content or style source folder is missing.
    """
    # Seeds the global RNG; the sampling code later in the notebook draws from it too.
    random.seed(seed)

    sources = {
        "content": os.path.join(CONTENT_ROOT, class_name),
        "style": os.path.join(STYLE_ROOT, class_name),
    }
    assert os.path.isdir(sources["content"]), f"Missing content folder: {sources['content']}"
    assert os.path.isdir(sources["style"]), f"Missing style folder: {sources['style']}"

    def split_list(files):
        """Return (train, val, test); val and test are taken from the front of the list."""
        n = len(files)
        n_val = int(n * val_frac)
        n_test = int(n * test_frac)
        return files[n_val + n_test:], files[:n_val], files[n_val:n_val + n_test]

    def copy_split(src_folder, dst_folder, files):
        """Re-encode the given files from src_folder into dst_folder as JPEG."""
        os.makedirs(dst_folder, exist_ok=True)
        for f in files:
            dst = os.path.join(dst_folder, f)
            if os.path.exists(dst):
                continue
            # exif_fix_and_open bakes the EXIF orientation into the pixels; the
            # re-encoded copy is saved without passing EXIF data along.
            # NOTE: the destination keeps the source file name, so a ".png"
            # name ends up holding JPEG data.
            exif_fix_and_open(os.path.join(src_folder, f)).save(dst, "JPEG", quality=90)

    for domain, src in sources.items():  # for both sets of images
        # os.listdir returns names in arbitrary order; sort first so the seeded
        # shuffle yields the same split on every run and machine.
        # Existing destination files are never removed, so clear the split
        # folders before re-splitting with a different file set or ordering.
        files = sorted(f for f in os.listdir(src) if f.lower().endswith(('.jpg', '.jpeg', '.png')))

        print(f"{domain} - {class_name} - found {len(files)} files in {src}")

        random.shuffle(files)
        train, val, test = split_list(files)

        for split_name, flist in zip(["train", "val", "test"], [train, val, test]):
            out_dir = os.path.join(SPLIT_ROOT, domain, split_name, class_name)
            copy_split(src, out_dir, flist)

    print(f"Created train/val/test splits under {SPLIT_ROOT}/{class_name}")

ensure_splits(TARGET_CLASS)

content - butterfly - found 1160 files in ../Data/dataset/clean/animals_balanced/butterfly
style - butterfly - found 116 files in ../Data/dataset/clean/origami_images/butterfly
Created train/val/test splits under ../Data/dataset/split/butterfly


#### Data Sampling

In [119]:
# Progressive training to limit GPU memory: a low-resolution stage, then a high-resolution one
IMG_SIZE_STAGE1, IMG_SIZE_STAGE2 = 256, 512
NUM_EPOCHS_STAGE1, NUM_EPOCHS_STAGE2 = 8, 3

def get_transforms(img_size):
    """Return the preprocessing pipeline for square images of side `img_size`.

    Steps: resize to (img_size, img_size), center-crop to img_size, convert to tensor.
    """
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

# Content and style images share one pipeline object per stage
content_transform_stage1 = style_transform_stage1 = get_transforms(IMG_SIZE_STAGE1)
content_transform_stage2 = style_transform_stage2 = get_transforms(IMG_SIZE_STAGE2)

In [120]:
class _ImagePathDataset(Dataset):
    """Dataset over a list of image paths; each item is the transformed RGB image."""

    def __init__(self, files, transform):
        self.files = files
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        return self.transform(exif_fix_and_open(self.files[idx]))


class SingleClassPairedSampler:
    """Serves (content batch, style batch) pairs for one class from the split folders.

    Two modes are supported:
      * DataLoader mode (`use_dataloader=True`): batches come from two shuffled
        DataLoaders built with the global BATCH_SIZE / NUM_WORKERS / PIN_MEMORY.
      * Fallback mode: files are drawn with `random.choices` and loaded inline.

    In both modes images go through `self.transform`; when it is None the
    stage-1 transforms are used.
    """

    _IMAGE_EXTS = ('.jpg', '.jpeg', '.png')

    def __init__(self, split_root, class_name, transform=None, use_dataloader=(NUM_WORKERS>0)):
        self.split_root = split_root
        self.class_name = class_name
        self._build_index()
        self.split = 'train'
        self.transform = transform
        self._use_dataloader = use_dataloader
        self._init_dataloaders_if_needed()

    def _build_index(self):
        """Store sorted file lists as `<split>_content_files` / `<split>_style_files`."""
        for split in ('train', 'val', 'test'):
            for domain in ('content', 'style'):
                folder = os.path.join(self.split_root, domain, split, self.class_name)
                files = sorted(
                    os.path.join(folder, f) for f in os.listdir(folder)
                    if f.lower().endswith(self._IMAGE_EXTS)
                )
                setattr(self, f"{split}_{domain}_files", files)

    def _make_loaders(self, split):
        """(Re)build the content/style DataLoaders and their iterators for `split`."""
        c_ds = _ImagePathDataset(getattr(self, f"{split}_content_files"),
                                 self.transform or content_transform_stage1)
        s_ds = _ImagePathDataset(getattr(self, f"{split}_style_files"),
                                 self.transform or style_transform_stage1)

        loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True,
                             num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
        self._content_loader = DataLoader(c_ds, **loader_kwargs)
        self._style_loader = DataLoader(s_ds, **loader_kwargs)

        self._content_iter = iter(self._content_loader)
        self._style_iter = iter(self._style_loader)

    def _init_dataloaders_if_needed(self):
        """Build loaders for the current split, or clear them in fallback mode."""
        if not self._use_dataloader:
            self._content_loader = None
            self._style_loader = None
            return
        # Uses self.split (set to 'train' by __init__) so loaders and split never disagree.
        self._make_loaders(self.split)

    def set_split(self, split):
        """Switch between 'train' / 'val' / 'test'; rebuilds the loaders in DataLoader mode."""
        self.split = split
        if self._use_dataloader:
            self._make_loaders(split)

    def _next_from(self, prefix):
        """Return the next batch of `<prefix>_iter`, restarting it once it is exhausted."""
        try:
            return next(getattr(self, f"{prefix}_iter"))
        except StopIteration:
            # Only exhaustion restarts the iterator; genuine loading errors propagate.
            fresh = iter(getattr(self, f"{prefix}_loader"))
            setattr(self, f"{prefix}_iter", fresh)
            return next(fresh)

    def sample_batch(self, batch_size):
        """Return a (content, style) pair of image batches.

        Args:
            batch_size: number of images per batch in fallback mode. In
                DataLoader mode the loaders' own batch size (BATCH_SIZE) applies
                and this argument is not used.
        """
        if self._use_dataloader:
            content = self._next_from("_content")
            style = self._next_from("_style")
            return content, style

        # Fallback without DataLoader: sample paths with replacement and load inline
        c_files = getattr(self, f"{self.split}_content_files")
        s_files = getattr(self, f"{self.split}_style_files")

        paths_c = random.choices(c_files, k=batch_size)
        paths_s = random.choices(s_files, k=batch_size)

        # Honor the configured transform (e.g. the stage-2 resolution), exactly
        # like the DataLoader path does; stage-1 transforms remain the default.
        c_tf = self.transform or content_transform_stage1
        s_tf = self.transform or style_transform_stage1

        c_imgs = [c_tf(exif_fix_and_open(p)) for p in paths_c]
        s_imgs = [s_tf(exif_fix_and_open(p)) for p in paths_s]

        return torch.stack(c_imgs), torch.stack(s_imgs)

sampler = SingleClassPairedSampler(SPLIT_ROOT, TARGET_CLASS, transform=content_transform_stage1)

print("Sampler ready, train content files:", len(sampler.train_content_files))
print("Sampler ready, train style files:", len(sampler.train_style_files))


Sampler ready, train content files: 986
Sampler ready, train style files: 100


#### Load Pre-Trained VGG

In [121]:
# Load the pretrained VGG-19 feature stack (same as in nst.py) and freeze it.
# NOTE(review): newer torchvision releases reportedly deprecate `pretrained=True`
# in favor of a `weights=` argument — confirm against the installed version.
vgg = models.vgg19(pretrained=True).features
vgg = vgg.to(DEVICE).eval()
for param in vgg.parameters():
    param.requires_grad = False
    
class VGGFeatureExtractor(nn.Module):
    """Runs a sequential network and collects the outputs of selected layers.

    Args:
        vgg: module whose `_modules` keys are integer strings ("0", "1", ...).
        layer_indices: mapping of feature name -> index string of the layer
            whose output should be returned under that name.

    NOTE(review): if the wrapped network applies in-place activations right
    after a captured layer, the stored tensor presumably ends up holding the
    activated values — confirm against the layers of the network passed in.
    """

    def __init__(self, vgg, layer_indices):
        super().__init__()
        self.vgg = vgg
        # integer layer position -> feature name
        self.idx_to_name = {int(pos): label for label, pos in layer_indices.items()}

    def forward(self, x):
        """Return {feature name: layer output}; every layer of the network is executed."""
        collected = {}
        activation = x
        for key, module in self.vgg._modules.items():
            activation = module(activation)
            label = self.idx_to_name.get(int(key))
            if label is not None:
                collected[label] = activation
        return collected

# Wrap the frozen VGG so the layers named in LAYER_INDICES can be read out by name
vgg_feat = VGGFeatureExtractor(vgg, LAYER_INDICES)
vgg_feat = vgg_feat.to(DEVICE)
vgg_feat.eval()

print("VGG feature extractor ready.")

VGG feature extractor ready.


#### Transformer Network

In [122]:
class ConvLayer(nn.Module):
    """Conv2d -> InstanceNorm2d (affine) -> ReLU.

    The convolution is padded with `kernel // 2` on each side.
    """

    def __init__(self, in_c, out_c, kernel, stride):
        super().__init__()
        self.conv = nn.Conv2d(
            in_c,
            out_c,
            kernel_size=kernel,
            stride=stride,
            padding=kernel // 2,
        )
        self.inorm = nn.InstanceNorm2d(out_c, affine=True)

    def forward(self, x):
        normalized = self.inorm(self.conv(x))
        return torch.relu(normalized)



In [123]:
# class ResidualBlock(nn.Module): #learn style modifications
#     def __init__(self, channels):
#         super().__init__()
#         self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1)
#         self.in1 = nn.InstanceNorm2d(channels, affine=True)
#         self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
#         self.in2 = nn.InstanceNorm2d(channels, affine=True)
#     def forward(self, x):
#         out = F.relu(self.in1(self.conv1(x)))
#         out = self.in2(self.conv2(out))
#         return out + x

class StylizedResidualBlock(nn.Module):
    """Residual block whose residual branch is scaled by a learned sigmoid gate.

    Branch: conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm.
    Gate:   conv1x1 -> Sigmoid, computed from the branch output.
    Output: branch * gate + input.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

        # Style gate intended to enhance stylized contrast/edges
        self.style_gate = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        hidden = F.relu(self.in1(self.conv1(x)))
        branch = self.in2(self.conv2(hidden))
        # Modulate the residual branch with the gate, then add the skip connection
        return branch * self.style_gate(branch) + x


In [124]:
class UpsampleConv(nn.Module):
    """Optional nearest-neighbour upsampling followed by Conv2d -> InstanceNorm2d -> ReLU.

    Args:
        in_c, out_c: input / output channel counts.
        kernel: convolution kernel size (stride 1, padding `kernel // 2`).
        upsample: scale factor applied before the convolution; skipped when falsy.
    """

    def __init__(self, in_c, out_c, kernel, upsample=None):
        super().__init__()
        self.upsample = upsample
        self.conv = nn.Conv2d(in_c, out_c, kernel_size=kernel, stride=1, padding=kernel // 2)
        self.inorm = nn.InstanceNorm2d(out_c, affine=True)

    def forward(self, x):
        scaled = (
            F.interpolate(x, scale_factor=self.upsample, mode='nearest')
            if self.upsample
            else x
        )
        return F.relu(self.inorm(self.conv(scaled)))

In [125]:
# class TransformerNet(nn.Module):
    
#     def __init__(self):
#         super().__init__()
#         self.conv1 = ConvLayer(3, 32, 9, 1)
#         self.conv2 = ConvLayer(32, 64, 3, 2)
#         self.conv3 = ConvLayer(64, 128, 3, 2)
#         self.res1 = ResidualBlock(128)
#         self.res2 = ResidualBlock(128)
#         self.res3 = ResidualBlock(128)
#         self.res4 = ResidualBlock(128)
#         self.res5 = ResidualBlock(128)
#         self.up1 = UpsampleConv(128, 64, 3, upsample=2)
#         self.up2 = UpsampleConv(64, 32, 3, upsample=2)
#         self.conv_out = nn.Conv2d(32, 3, 9, 1, 4)
        
#     def forward(self, x):
#         y = self.conv1(x)
#         y = self.conv2(y)
#         y = self.conv3(y)
#         y = self.res1(y)
#         y = self.res2(y)
#         y = self.res3(y)
#         y = self.res4(y)
#         y = self.res5(y)
#         y = self.up1(y)
#         y = self.up2(y)
#         y = self.conv_out(y)
#         return torch.sigmoid(y)

class TransformerNet(nn.Module):
    """Feed-forward stylization network.

    Three ConvLayers (the last two with stride 2), five StylizedResidualBlocks
    at 128 channels, two x2 UpsampleConvs, and a final 9x9 convolution to
    3 channels whose output is passed through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Downsampling convolutions
        self.conv1 = ConvLayer(3, 32, kernel=9, stride=1)
        self.conv2 = ConvLayer(32, 64, kernel=3, stride=2)
        self.conv3 = ConvLayer(64, 128, kernel=3, stride=2)

        # Stylized residual blocks res1 .. res5
        for block_no in range(1, 6):
            setattr(self, f"res{block_no}", StylizedResidualBlock(128))

        # Upsampling convolutions and output projection
        self.up1 = UpsampleConv(128, 64, kernel=3, upsample=2)
        self.up2 = UpsampleConv(64, 32, kernel=3, upsample=2)
        self.conv_out = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        pipeline = (
            self.conv1, self.conv2, self.conv3,
            self.res1, self.res2, self.res3, self.res4, self.res5,
            self.up1, self.up2,
            self.conv_out,
        )
        y = x
        for stage in pipeline:
            y = stage(y)
        return torch.sigmoid(y)

In [126]:
def precompute_style_grams(style_dir, transform, batch_size=8, layers=None):
    """Compute Gram matrices of VGG features for every image in `style_dir`.

    Args:
        style_dir: folder containing the style images (.jpg / .jpeg / .png).
        transform: preprocessing applied to each opened image.
        batch_size: number of images pushed through VGG at once.
        layers: optional iterable of feature names to keep. When None
            (default) a Gram matrix is stored for every layer returned by
            `vgg_feat`; pass only the style layers to save memory.

    Returns:
        A list with one dict per image, mapping layer name to a CPU tensor
        of shape (1, c, c). Entries follow the sorted file names.
    """
    # os.listdir returns names in arbitrary order; sort so the list order (and
    # therefore any seeded random pick from it) is reproducible.
    files = sorted(
        os.path.join(style_dir, f) for f in os.listdir(style_dir)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

    class _DS(Dataset):
        def __init__(self, files, transform):
            self.files = files
            self.transform = transform

        def __len__(self): return len(self.files)
        def __getitem__(self, idx):
            return self.transform(exif_fix_and_open(self.files[idx]))

    ds = _DS(files, transform)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=max(0, NUM_WORKERS//2))
    grams_list = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(DEVICE)
            feats = vgg_feat(normalize_for_vgg(batch))
            wanted = list(feats.keys()) if layers is None else list(layers)

            # One dict per image; slicing i:i+1 keeps the leading batch dimension
            for i in range(batch.size(0)):
                gdict = {l: gram_matrix_batch(feats[l][i:i+1]).cpu() for l in wanted}
                grams_list.append(gdict)

    print(f"Precomputed {len(grams_list)} style gram dicts.")
    return grams_list

style_dir = os.path.join(STYLE_ROOT, TARGET_CLASS)

style_grams_stage1 = precompute_style_grams(style_dir, style_transform_stage1, batch_size=8)


Precomputed 116 style gram dicts.


In [127]:
# Build the stylization network on DEVICE and its Adam optimizer
model = TransformerNet()
model = model.to(DEVICE)
opt = optim.Adam(params=model.parameters(), lr=LR)

In [128]:
def tv_loss_fn(x):
    """Total-variation penalty of a 4-D image batch.

    Returns the mean absolute difference between horizontally adjacent pixels
    plus the mean absolute difference between vertically adjacent pixels.
    """
    horizontal = (x[:, :, :, :-1] - x[:, :, :, 1:]).abs().mean()
    vertical = (x[:, :, :-1, :] - x[:, :, 1:, :]).abs().mean()
    return horizontal + vertical

#train
def _stylization_loss(model, content_batch, style_grams, content_layers, style_layers, style_weights):
    """Stylize one content batch and return the weighted content + style + TV loss."""
    output = model(content_batch)  # stylized image

    c_norm, o_norm = normalize_for_vgg(content_batch), normalize_for_vgg(output)
    c_feats, o_feats = vgg_feat(c_norm), vgg_feat(o_norm)

    # Content loss: MSE between output and content features
    c_loss = sum(torch.mean((o_feats[l] - c_feats[l])**2) for l in content_layers)

    # Style loss: MSE between output Grams and the Grams of one randomly chosen style image
    # NOTE(review): when called under autocast the Gram product presumably runs
    # in reduced precision; check for overflow on CUDA — confirm.
    Gs = random.choice(style_grams)
    s_loss = 0.0
    for l in style_layers:
        Go = gram_matrix_batch(o_feats[l])
        s_loss += style_weights.get(l, 1.0) * torch.mean((Go - Gs[l].to(DEVICE))**2)

    tv = tv_loss_fn(output)  # smoothness term
    return CONTENT_WEIGHT * c_loss + STYLE_WEIGHT * s_loss + TV_WEIGHT * tv


def train_stage(model, sampler, style_grams, transform, num_epochs, stage_name, optimizer=None):
    """Train `model` for one resolution stage.

    Each epoch runs max(100, n_train // BATCH_SIZE) steps, saves a sample image
    to SAMPLES_DIR every 300 steps, and writes a checkpoint to CHECKPOINT_DIR
    at the end of the epoch.

    Args:
        model: stylization network to train.
        sampler: object providing `transform`, `set_split()` and `sample_batch()`.
        style_grams: list of per-image dicts {layer name: Gram tensor}.
        transform: image transform the sampler should use for this stage.
        num_epochs: number of epochs to run.
        stage_name: prefix for progress messages, sample and checkpoint files.
        optimizer: optimizer to step; defaults to the notebook-level `opt`.

    Raises:
        ValueError: if `style_grams` is empty.
    """
    if not style_grams:
        raise ValueError("style_grams is empty; precompute style Gram matrices first")
    if optimizer is None:
        optimizer = opt  # notebook-level optimizer created next to `model`

    # Resolve the layer selection from the active config instead of relying on
    # globals that no cell defines.
    layer_cfg = LAYER_CONFIGS[ACTIVE_LAYER_CONFIG]
    content_layers = layer_cfg['content']
    style_layers = layer_cfg['style']
    style_weights = layer_cfg.get('style_weights', {})

    # set_split() rebuilds the loaders with the new transform when they are in use
    sampler.transform = transform
    sampler.set_split('train')

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(SAMPLES_DIR, exist_ok=True)

    # Gradient scaling / autocast only on CUDA with AMP enabled
    scaler = torch.cuda.amp.GradScaler() if (USE_AMP and DEVICE.type == "cuda") else None

    # Number of steps per epoch
    n_train = len(getattr(sampler, "train_content_files", []))
    steps_per_epoch = max(100, n_train // BATCH_SIZE)

    for epoch in range(1, num_epochs + 1):
        model.train()
        epoch_loss = 0.0

        pbar = tqdm(range(steps_per_epoch), desc=f"{stage_name} Epoch {epoch}")  # progress bar

        for step in pbar:
            # The style batch is not needed: the style loss uses the precomputed Grams
            content_batch, _ = sampler.sample_batch(BATCH_SIZE)
            content_batch = content_batch.to(DEVICE)

            optimizer.zero_grad()

            if scaler is not None:
                with torch.cuda.amp.autocast():
                    total_loss = _stylization_loss(
                        model, content_batch, style_grams,
                        content_layers, style_layers, style_weights,
                    )
                # Backprop with scaler
                scaler.scale(total_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                total_loss = _stylization_loss(
                    model, content_batch, style_grams,
                    content_layers, style_layers, style_weights,
                )
                total_loss.backward()
                optimizer.step()

            loss_value = total_loss.item()
            epoch_loss += loss_value

            if step % 50 == 0:
                pbar.set_description(f"{stage_name} E{epoch} S{step} Loss {loss_value:.4f}")

            # Save a sample stylization of the first image in the batch
            if step % 300 == 0:
                model.eval()

                with torch.no_grad():
                    sample_out = model(content_batch[:1]).cpu()
                    utils.save_image(sample_out, f"{SAMPLES_DIR}/{stage_name}_ep{epoch}_step{step}.png")

                model.train()

        # Save checkpoint and print epoch stats
        avg_loss = epoch_loss / steps_per_epoch

        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f"{stage_name}_epoch{epoch}.pth"))

        print(f"[{stage_name}] Epoch {epoch} done | Avg loss {avg_loss:.4f}")

#### Training

In [129]:
# Stage 1: low-resolution training
sampler = SingleClassPairedSampler(
    SPLIT_ROOT,
    TARGET_CLASS,
    transform=content_transform_stage1,
    use_dataloader=(NUM_WORKERS > 0),
)
style_grams = style_grams_stage1  # precomputed in an earlier cell

print("Starting Stage 1 (low-res) training")
train_stage(
    model,
    sampler,
    style_grams,
    content_transform_stage1,
    NUM_EPOCHS_STAGE1,
    stage_name=f"stage1_{IMG_SIZE_STAGE1}",
)

# Stage 2: high-resolution fine-tuning with Gram matrices recomputed at that resolution
print("Precomputing style grams for stage2 (high-res)...")
style_grams_stage2 = precompute_style_grams(style_dir, style_transform_stage2, batch_size=4)

sampler = SingleClassPairedSampler(
    SPLIT_ROOT,
    TARGET_CLASS,
    transform=content_transform_stage2,
    use_dataloader=(NUM_WORKERS > 0),
)
print("Starting Stage 2 (high-res) fine-tune")
train_stage(
    model,
    sampler,
    style_grams_stage2,
    content_transform_stage2,
    NUM_EPOCHS_STAGE2,
    stage_name=f"stage2_{IMG_SIZE_STAGE2}",
)


Starting Stage 1 (low-res) training


stage1_256 Epoch 1:   0%|          | 0/164 [00:00<?, ?it/s]

[stage1_256] Epoch 1 done | Avg loss 109.1278


stage1_256 Epoch 2:   0%|          | 0/164 [00:00<?, ?it/s]

[stage1_256] Epoch 2 done | Avg loss 82.5825


stage1_256 Epoch 3:   0%|          | 0/164 [00:00<?, ?it/s]

[stage1_256] Epoch 3 done | Avg loss 85.0462


stage1_256 Epoch 4:   0%|          | 0/164 [00:00<?, ?it/s]

[stage1_256] Epoch 4 done | Avg loss 98.4891


stage1_256 Epoch 5:   0%|          | 0/164 [00:00<?, ?it/s]

[stage1_256] Epoch 5 done | Avg loss 86.0847


stage1_256 Epoch 6:   0%|          | 0/164 [00:00<?, ?it/s]

[stage1_256] Epoch 6 done | Avg loss 91.3243


stage1_256 Epoch 7:   0%|          | 0/164 [00:00<?, ?it/s]

[stage1_256] Epoch 7 done | Avg loss 83.1822


stage1_256 Epoch 8:   0%|          | 0/164 [00:00<?, ?it/s]

[stage1_256] Epoch 8 done | Avg loss 83.7357
Precomputing style grams for stage2 (high-res)...
Precomputed 116 style gram dicts.
Starting Stage 2 (high-res) fine-tune


stage2_512 Epoch 1:   0%|          | 0/164 [00:00<?, ?it/s]

[stage2_512] Epoch 1 done | Avg loss 58.6360


stage2_512 Epoch 2:   0%|          | 0/164 [00:00<?, ?it/s]

[stage2_512] Epoch 2 done | Avg loss 60.5667


stage2_512 Epoch 3:   0%|          | 0/164 [00:00<?, ?it/s]

[stage2_512] Epoch 3 done | Avg loss 46.4449


#### Testing for 1 image

In [130]:
# Rebuild the network and restore the final stage-2 weights (stylized residual version)
checkpoint_path = f"{CHECKPOINT_DIR}/stage2_512_epoch3.pth"
model = TransformerNet().to(DEVICE)
state = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(state)
model.eval()
print("Loaded stylized residual checkpoint successfully!")

✅ Loaded stylized residual checkpoint successfully!


In [131]:
# Stylize a single test image with the final stage-2 checkpoint
ckpt = os.path.join(CHECKPOINT_DIR, f"stage2_{IMG_SIZE_STAGE2}_epoch{NUM_EPOCHS_STAGE2}.pth")

if os.path.exists(ckpt):
    print("Loading checkpoint:", ckpt)
    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))

else:
    print("Checkpoint not found; using model in memory (may be already trained).")

model.eval()
test_img_path = "test_imgs/butterfly.jpg"

if not os.path.exists(test_img_path):
    print("Test image not found at", test_img_path)
else:
    img = exif_fix_and_open(test_img_path)

    # Build the inference transform locally; the name `get_transforms` is left
    # untouched so the helper function stays usable after this cell has run.
    tf = transforms.Compose([
        transforms.Resize((IMG_SIZE_STAGE2, IMG_SIZE_STAGE2)),
        transforms.ToTensor(),
    ])
    content_tensor = tf(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        out = model(content_tensor)

    out_path = os.path.join(SAMPLES_DIR, "stylized_image.png")
    utils.save_image(out.cpu(), out_path)

    print("Saved stylized image to", out_path)

Loading checkpoint: ./checkpoints_nststyle/stage2_512_epoch3.pth
Saved stylized image to ./samples_nststyle/stylized_image.png
