In [None]:
import gc

import torch

# Free unreachable Python objects first so any CUDA tensors they held become
# releasable, then hand cached CUDA memory / IPC handles back to the driver.
gc.collect()
# torch.cuda.ipc_collect() initialises CUDA and raises on CPU-only hosts, so
# only touch the CUDA allocator when a GPU is actually present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

In [None]:
# !mkdir -p /content/train2014
# !mkdir -p /content/painted
# !wget http://images.cocodataset.org/zips/train2014.zip -O train2014.zip
# !unzip -q train2014.zip -d /content/train2014
!pip3 install TensorboardX

Collecting TensorboardX
  Downloading tensorboardx-2.6.4-py3-none-any.whl.metadata (6.2 kB)
Downloading tensorboardx-2.6.4-py3-none-any.whl (87 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/87.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: TensorboardX
Successfully installed TensorboardX-2.6.4


In [None]:
import os
import shutil
import random

# Copy a fixed-size random subset of the content images into a flat folder.
src_dir = "/content/train2014"
dst_dir = "/content/train_sample"
sample_size = 20000
seed = 0  # fixed so re-running the cell selects the same subset

os.makedirs(dst_dir, exist_ok=True)

# Collect all images from all subdirectories. Sorting removes the dependence on
# os.walk's filesystem-specific ordering, so the seeded sample is reproducible.
all_images = sorted(
    os.path.join(root, f)
    for root, _dirs, files in os.walk(src_dir)
    for f in files
    if f.lower().endswith((".jpg", ".jpeg", ".png"))
)

# Ensure sufficient images are available
if sample_size > len(all_images):
    print(f"Number of images is less than {sample_size}, using all available images: {len(all_images)}")
    sample_images = all_images
else:
    # Local RNG: reproducible without disturbing the global `random` state.
    sample_images = random.Random(seed).sample(all_images, sample_size)

# Copy images to the sample directory (flattened by file name)
for img_path in sample_images:
    shutil.copy(img_path, os.path.join(dst_dir, os.path.basename(img_path)))

print(f"Training sample created ({len(sample_images)} images) in: {dst_dir}")

Training sample created (20000 images) in: /content/train_sample


In [None]:
# 1. Install the Kaggle library
!pip install --quiet kaggle

# 2. Upload the kaggle.json file
from google.colab import files
files.upload()

# 3. Place the token in the correct directory and set permissions
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# 4. Download the Painter by Numbers (resized) dataset
!kaggle datasets download -d kovalevvyu/painter-by-numbers-resized --unzip -p /content/painter_dataset

# 5. List the content to verify
!ls /content/painter_dataset

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
18064.jpg   33552.jpg  49040.jpg  64529.jpg  80016.jpg	95505.jpg
18065.jpg   33553.jpg  49041.jpg  6452.jpg   80017.jpg	95506.jpg
18066.jpg   33554.jpg  49042.jpg  64530.jpg  80018.jpg	95507.jpg
18067.jpg   33555.jpg  49043.jpg  64531.jpg  80019.jpg	95508.jpg
18068.jpg   33556.jpg  49044.jpg  64532.jpg  8001.jpg	95509.jpg
18069.jpg   33557.jpg  49045.jpg  64533.jpg  80020.jpg	9550.jpg
1806.jpg    33558.jpg  49046.jpg  64534.jpg  80021.jpg	95510.jpg
18070.jpg   33559.jpg  49047.jpg  64535.jpg  80022.jpg	95511.jpg
18071.jpg   3355.jpg   49048.jpg  64536.jpg  80023.jpg	95512.jpg
18072.jpg   33560.jpg  49049.jpg  64537.jpg  80024.jpg	95513.jpg
18073.jpg   33561.jpg  4904.jpg   64538.jpg  80025.jpg	95514.jpg
18074.jpg   33562.jpg  49050.jpg  64539.jpg  80026.jpg	95515.jpg
18075.jpg   33563.jpg  49051.jpg  6453.jpg   80027.jpg	95516.jpg
18076.jpg   33564.jpg  49052.jpg  64540.jpg  80028.jpg	95517.jpg
18077.jpg   33565.jpg  4905

In [None]:
import os
import shutil
import random

# Path to the dataset downloaded from Kaggle
src_dir = "/content/painter_dataset"

# Destination directory for the sample
dst_dir = "/content/style_sample"

# Number of images desired in the sample
sample_size = 15000

# Fixed seed so re-running the cell selects the same subset
seed = 0

# Create the sample directory
os.makedirs(dst_dir, exist_ok=True)

# Collect all images; sorted so the seeded sample does not depend on the
# filesystem-specific order returned by os.walk.
all_images = sorted(
    os.path.join(root, f)
    for root, _dirs, files in os.walk(src_dir)
    for f in files
    if f.lower().endswith(('.jpg', '.jpeg', '.png'))
)

# random.sample raises ValueError when asked for more items than exist,
# so fall back to using every available image in that case.
if sample_size > len(all_images):
    print(f"Number of images is less than {sample_size}, using all available images: {len(all_images)}")
    sample_images = all_images
else:
    sample_images = random.Random(seed).sample(all_images, sample_size)

# Copy images to the new directory
for img_path in sample_images:
    shutil.copy(img_path, dst_dir)

print("Style sample created successfully in:", dst_dir)


Style sample created successfully in: /content/style_sample


In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils import data
import argparse
import os
import torch.backends.cudnn as cudnn
from PIL import Image, ImageFile
from tensorboardX import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

# ------------------------- UTILITY FUNCTIONS -------------------------
def calc_mean_std(feat, eps=1e-5):
    """Per-sample, per-channel mean and standard deviation of a feature map.

    Args:
        feat: tensor of shape (N, C, H, W); may be non-contiguous.
        eps: added to the (unbiased) variance before the square root so that
            callers dividing by the returned std never divide by zero.

    Returns:
        (mean, std), each of shape (N, C, 1, 1) and with the same dtype as ``feat``.

    The statistics are accumulated in at least float32: under mixed precision the
    encoder produces float16 features, and the variance of deep activations can
    exceed the float16 maximum (65504), which would otherwise yield an ``inf`` std.
    """
    size = feat.size()
    assert len(size) == 4, f"expected a 4-D (N, C, H, W) tensor, got shape {tuple(size)}"
    N, C = size[:2]
    acc_dtype = torch.promote_types(feat.dtype, torch.float32)
    # reshape (unlike view) also accepts non-contiguous inputs
    flat = feat.reshape(N, C, -1).to(acc_dtype)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1).to(feat.dtype)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1).to(feat.dtype)
    return feat_mean, feat_std

def mean_variance_norm(feat):
    """Standardise every channel of every sample to zero mean and unit std.

    Uses the per-(sample, channel) statistics from ``calc_mean_std`` and
    broadcasts them back over the spatial dimensions of ``feat``.
    """
    mu, sigma = calc_mean_std(feat)
    full_shape = feat.size()
    centred = feat - mu.expand(full_shape)
    return centred / sigma.expand(full_shape)

def _calc_feat_flatten_mean_std(feat):
    """Flatten 3D feature (C, H, W) and compute per-channel mean and std"""
    assert (feat.size()[0] == 3)
    assert (isinstance(feat, torch.FloatTensor))
    feat_flatten = feat.view(3, -1)
    mean = feat_flatten.mean(dim=-1, keepdim=True)
    std = feat_flatten.std(dim=-1, keepdim=True)
    return feat_flatten, mean, std

# ------------------------- DECODER NETWORK -------------------------
# Decoder: maps the 512-channel features produced by Transform back to a
# 3-channel image (the training loss compares its output directly with the input
# images). It mirrors the encoder in reverse: each stage is ReflectionPad2d +
# 3x3 Conv2d (+ ReLU), and the three nearest-neighbour 2x upsamplings undo the
# three max-pools that sit between relu1-1 and relu4-1 in the encoder.
# The final Conv2d has no activation, so the output range is not clamped.
# NOTE: each layer's position determines its state_dict key (e.g. "1.weight"),
# so decoder_iter_*.pth checkpoints only load into this exact sequence.
decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # relu4-1 scale -> relu3-1 scale
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # relu3-1 scale -> relu2-1 scale
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # relu2-1 scale -> full resolution
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),  # 3 output channels, no activation
)

# ------------------------- VGG ENCODER -------------------------
# VGG-19-style encoder (16 3x3 convolutions with reflection padding). Weights are
# loaded from args.vgg with vgg.load_state_dict(...), which is strict by default,
# so this layer list and its order must match that checkpoint exactly.
# The model cell keeps only the first 44 children (through relu5-1), and Net
# slices them into five frozen stages:
#   enc_1 = [0:4]   -> relu1-1     enc_2 = [4:11]  -> relu2-1
#   enc_3 = [11:18] -> relu3-1     enc_4 = [18:31] -> relu4-1
#   enc_5 = [31:44] -> relu5-1
# Every MaxPool2d uses ceil_mode=True, so odd spatial sizes are rounded up.
vgg = nn.Sequential(
    # 1x1 conv with no activation after it; presumably applies the checkpoint's
    # input pre-processing (channel reordering / mean subtraction) -- TODO confirm.
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1 (last layer kept after truncation to [:44])
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)

# ------------------------- SELF-ATTENTION NETWORK (SANet) -------------------------
class SANet(nn.Module):
    """Style-attentional module: re-composes style features according to the
    attention between mean/variance-normalised content and style features.

    Args:
        in_planes: channel count C of both inputs (all convolutions are 1x1, C -> C).
    """
    def __init__(self, in_planes):
        super(SANet, self).__init__()
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))  # query projection (content)
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))  # key projection (style)
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))  # value projection (style)
        self.sm = nn.Softmax(dim=-1)
        self.out_conv = nn.Conv2d(in_planes, in_planes, (1, 1))

    def forward(self, content, style):
        """Attend from every content position to every style position.

        Args:
            content: (B, C, Hc, Wc) feature map.
            style: (B, C, Hs, Ws) feature map; Hs/Ws may differ from Hc/Wc.

        Returns:
            (B, C, Hc, Wc): projected attention output plus the content residual.
        """
        F = self.f(mean_variance_norm(content))
        G = self.g(mean_variance_norm(style))
        H = self.h(style)
        b, c, h, w = content.size()
        # Each tensor is flattened over its OWN spatial extent, so content and
        # style feature maps are allowed to have different sizes.
        F = F.flatten(2).permute(0, 2, 1)     # (B, Hc*Wc, C)
        G = G.flatten(2)                      # (B, C, Hs*Ws)
        S = self.sm(torch.bmm(F, G))          # (B, Hc*Wc, Hs*Ws), softmax over style positions
        H = H.flatten(2)                      # (B, C, Hs*Ws)
        O = torch.bmm(H, S.permute(0, 2, 1))  # (B, C, Hc*Wc)
        O = O.view(b, c, h, w)
        return self.out_conv(O) + content

# ------------------------- TRANSFORM MODULE -------------------------
class Transform(nn.Module):
    """Fuse SANet outputs computed at the relu4_1 and relu5_1 feature scales.

    One SANet runs per scale. The relu5_1 result is upsampled 2x and added to the
    relu4_1 result, and the sum is mixed by a reflection-padded 3x3 convolution.

    Args:
        in_planes: channel count of the relu4_1 / relu5_1 features.
    """
    def __init__(self, in_planes):
        super(Transform, self).__init__()
        self.sanet4_1 = SANet(in_planes=in_planes)
        self.sanet5_1 = SANet(in_planes=in_planes)
        self.upsample5_1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.merge_conv_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.merge_conv = nn.Conv2d(in_planes, in_planes, (3, 3))

    def forward(self, content4_1, style4_1, content5_1, style5_1):
        """Return fused features with the spatial size of ``content4_1``."""
        fused4 = self.sanet4_1(content4_1, style4_1)
        fused5 = self.upsample5_1(self.sanet5_1(content5_1, style5_1))
        # The encoder max-pools with ceil_mode=True, so when the input side is not
        # a multiple of 16 the relu5_1 map is ceil(H4/2) tall and the 2x upsample
        # overshoots relu4_1 by one row/column. Crop the overshoot (it comes from
        # the partial pooling window at the border); resize for any other mismatch.
        target_h, target_w = fused4.shape[-2:]
        if fused5.shape[-2:] != fused4.shape[-2:]:
            if fused5.size(-2) >= target_h and fused5.size(-1) >= target_w:
                fused5 = fused5[..., :target_h, :target_w]
            else:
                fused5 = nn.functional.interpolate(
                    fused5, size=(target_h, target_w), mode='nearest')
        return self.merge_conv(self.merge_conv_pad(fused4 + fused5))

# ------------------------- FULL NETWORK -------------------------
class Net(nn.Module):
    """Full style transfer network: frozen encoder + SANet Transform + decoder.

    Args:
        encoder: ``nn.Sequential`` whose first 44 children are split into five
            frozen stages ending at relu1-1, relu2-1, relu3-1, relu4-1, relu5-1.
        decoder: ``nn.Sequential`` mapping 512-channel features back to an image.
        start_iter: iteration to resume from. When > 0, transform and decoder
            weights are loaded from ``transformer_iter_{start_iter}.pth`` and
            ``decoder_iter_{start_iter}.pth``. Floats such as 5000.0 are accepted
            and truncated to int, because checkpoints are saved with integer names.
        ckpt_dir: directory holding those checkpoints (default: current working
            directory, which is where relative names resolved before).
    """
    def __init__(self, encoder, decoder, start_iter, ckpt_dir='.'):
        super(Net, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])     # input   -> relu1-1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])   # relu1-1 -> relu2-1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2-1 -> relu3-1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3-1 -> relu4-1
        self.enc_5 = nn.Sequential(*enc_layers[31:44])  # relu4-1 -> relu5-1
        self.transform = Transform(in_planes=512)
        self.decoder = decoder
        start_iter = int(start_iter)
        if start_iter > 0:
            self.transform.load_state_dict(torch.load(
                os.path.join(ckpt_dir, f'transformer_iter_{start_iter}.pth'), map_location='cpu'))
            self.decoder.load_state_dict(torch.load(
                os.path.join(ckpt_dir, f'decoder_iter_{start_iter}.pth'), map_location='cpu'))
        self.mse_loss = nn.MSELoss()
        # freeze encoder parameters: only transform and decoder are trained
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

    def encode_with_intermediate(self, input):
        """Run ``input`` through enc_1..enc_5 and return the five stage outputs
        (relu1-1, relu2-1, relu3-1, relu4-1, relu5-1) as a list."""
        results = [input]
        for i in range(5):
            func = getattr(self, f'enc_{i + 1}')
            results.append(func(results[-1]))
        return results[1:]

    def calc_content_loss(self, input, target, norm=False):
        """MSE between two tensors, optionally after per-channel mean/variance normalisation."""
        if not norm:
            return self.mse_loss(input, target)
        else:
            return self.mse_loss(mean_variance_norm(input), mean_variance_norm(target))

    def calc_style_loss(self, input, target):
        """MSE between the per-channel means plus MSE between the per-channel stds."""
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + self.mse_loss(input_std, target_std)

    def forward(self, content, style):
        """Stylise ``content`` with ``style`` and compute the training losses.

        Returns:
            (loss_c, loss_s, l_identity1, l_identity2), all unweighted; the
            training loop applies the loss weights.
        """
        style_feats = self.encode_with_intermediate(style)
        content_feats = self.encode_with_intermediate(content)
        # indices 3 and 4 are the relu4-1 and relu5-1 features
        stylized = self.transform(content_feats[3], style_feats[3], content_feats[4], style_feats[4])
        g_t = self.decoder(stylized)
        g_t_feats = self.encode_with_intermediate(g_t)

        # content loss on normalised relu4-1 / relu5-1 features
        loss_c = self.calc_content_loss(g_t_feats[3], content_feats[3], norm=True) + \
                 self.calc_content_loss(g_t_feats[4], content_feats[4], norm=True)

        # style loss: mean/std matching at all five encoder stages
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for i in range(1, 5):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])

        # identity losses: stylising an image with itself should reproduce it,
        # both in pixel space (l_identity1) and in feature space (l_identity2)
        Icc = self.decoder(self.transform(content_feats[3], content_feats[3], content_feats[4], content_feats[4]))
        Iss = self.decoder(self.transform(style_feats[3], style_feats[3], style_feats[4], style_feats[4]))
        l_identity1 = self.calc_content_loss(Icc, content) + self.calc_content_loss(Iss, style)
        Fcc = self.encode_with_intermediate(Icc)
        Fss = self.encode_with_intermediate(Iss)
        l_identity2 = self.calc_content_loss(Fcc[0], content_feats[0]) + self.calc_content_loss(Fss[0], style_feats[0])
        for i in range(1, 5):
            l_identity2 += self.calc_content_loss(Fcc[i], content_feats[i]) + self.calc_content_loss(Fss[i], style_feats[i])

        return loss_c, loss_s, l_identity1, l_identity2

# ------------------------- DATA SAMPLERS -------------------------
def InfiniteSampler(n, rng=None):
    """Yield dataset indices forever, reshuffling after every full pass.

    Each consecutive run of ``n`` yielded indices is a permutation of
    ``range(n)``, so every item is visited exactly once per pass, including
    the first pass.

    Args:
        n: number of items in the dataset; must be positive.
        rng: optional object with a ``permutation(n)`` method, for example
            ``np.random.default_rng(seed)``. Defaults to NumPy's global RNG, so
            ``np.random.seed(...)`` makes the order reproducible.

    Raises:
        ValueError: on the first ``next()`` if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"InfiniteSampler needs a non-empty dataset, got n={n}")
    if rng is None:
        rng = np.random
    while True:
        # A fresh permutation per pass. The global RNG is deliberately not
        # reseeded from OS entropy here, so a seeded run stays reproducible.
        for idx in rng.permutation(n):
            yield idx

class InfiniteSamplerWrapper(data.sampler.Sampler):
    """DataLoader sampler that streams shuffled indices without end.

    Only the dataset length is recorded; iteration delegates to
    ``InfiniteSampler`` and ``len()`` reports 2**31 as a stand-in for infinity.
    """
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        index_stream = InfiniteSampler(self.num_samples)
        return iter(index_stream)

    def __len__(self):
        return 1 << 31

# ------------------------- ENVIRONMENT SETUP -------------------------
# PIL: disable the decompression-bomb size limit and tolerate truncated files,
# so unusually large or partially written images still decode.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Every input is resized to the same fixed size, so cuDNN autotuning pays off.
cudnn.benchmark = True

# ------------------------- DATA TRANSFORM -------------------------
def train_transform():
    """Build the preprocessing pipeline shared by content and style images.

    Each image is resized to exactly 512x512 (aspect ratio is not preserved)
    and converted to a float tensor in [0, 1].
    """
    return transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ])

# ------------------------- DATASET -------------------------
class FlatFolderDataset(data.Dataset):
    """Unlabeled image dataset built from every image file under ``root``.

    Despite the name, sub-directories are walked recursively; all images are
    treated as one flat collection. Paths are sorted so the index -> file mapping
    is stable across runs and filesystems.

    Args:
        root: directory to search.
        transform: callable applied to each RGB ``PIL.Image`` in ``__getitem__``.

    Raises:
        ValueError: if no image files are found under ``root`` (an empty dataset
            would otherwise only fail later, inside the infinite sampler).
    """
    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        self.paths = sorted(
            os.path.join(subdir, file)
            for subdir, _dirs, files in os.walk(self.root)
            for file in files
            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))
        )
        if not self.paths:
            raise ValueError(f"No images found under {root!r}")
        self.transform = transform

    def __getitem__(self, index):
        path = self.paths[index]
        # The context manager closes the file handle even if decoding fails;
        # convert() returns an independent in-memory image.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        return self.transform(img)

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'FlatFolderDataset'

# ------------------------- HELPER FUNCTIONS -------------------------
def adjust_learning_rate(optimizer, iteration_count, base_lr=None, lr_decay=None):
    """Set every param group's lr using inverse-time decay.

        lr = base_lr / (1 + lr_decay * iteration_count)

    Args:
        optimizer: object with a ``param_groups`` list of dicts (e.g. a torch optimizer).
        iteration_count: current (0-based) training iteration.
        base_lr: initial learning rate; defaults to the global ``args.lr``.
        lr_decay: decay coefficient; defaults to the global ``args.lr_decay``.

    Returns:
        The learning rate that was applied.
    """
    if base_lr is None:
        base_lr = args.lr
    if lr_decay is None:
        lr_decay = args.lr_decay
    lr = base_lr / (1.0 + lr_decay * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# ------------------------- ARGUMENTS -------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--content_dir', type=str, default='/content/train_sample',
                    help='folder of content images (searched recursively)')
parser.add_argument('--style_dir', type=str, default='/content/style_sample',
                    help='folder of style images (searched recursively)')
parser.add_argument('--vgg', type=str, default='/content/vgg_normalised.pth',
                    help='pretrained encoder weights')
parser.add_argument('--save_dir', default='./experiments',
                    help='where *_iter_<n>.pth checkpoints are written')
parser.add_argument('--log_dir', default='./logs', help='TensorBoard log directory')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--lr_decay', type=float, default=5e-5,
                    help='inverse-time decay: lr / (1 + lr_decay * iter)')
parser.add_argument('--max_iter', type=int, default=15000)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--style_weight', type=float, default=3.0)
parser.add_argument('--content_weight', type=float, default=1.0)
parser.add_argument('--n_threads', type=int, default=2, help='DataLoader worker processes')
parser.add_argument('--save_model_interval', type=int, default=5000)
# Iteration to resume from. Checkpoints are named *_iter_<int>.pth, so this must
# be an int: a float would format as e.g. "decoder_iter_5000.0.pth".
parser.add_argument('--start_iter', type=int, default=0)
# Empty argv: in a notebook, use the defaults above instead of the kernel's argv.
args = parser.parse_args([])

# ------------------------- DEVICE -------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)

# ------------------------- TRANSFORMS -------------------------
# Content and style share the same preprocessing, but each gets its own
# pipeline object.
content_tf, style_tf = train_transform(), train_transform()

# ------------------------- DATASET AND DATALOADER -------------------------
def _infinite_loader(dataset):
    """DataLoader over ``dataset`` that draws shuffled batches forever."""
    return data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=InfiniteSamplerWrapper(dataset),
        num_workers=args.n_threads,
        pin_memory=True,
    )


content_dataset = FlatFolderDataset(args.content_dir, content_tf)
style_dataset = FlatFolderDataset(args.style_dir, style_tf)

content_loader = _infinite_loader(content_dataset)
style_loader = _infinite_loader(style_dataset)

# Long-lived iterators: the infinite sampler means next() never raises StopIteration.
content_iter, style_iter = iter(content_loader), iter(style_loader)

# ------------------------- MODEL AND OPTIMIZER -------------------------
# Load pretrained encoder weights, then keep layers 0-43 (through relu5-1).
vgg.load_state_dict(torch.load(args.vgg, map_location='cpu'))
vgg = nn.Sequential(*list(vgg.children())[:44])

# Checkpoints are written as *_iter_<int>.pth, so normalise a float start_iter.
start_iter = int(args.start_iter)

# Build the network without asking Net to resume (its loader looks in the
# current directory), then restore weights from args.save_dir, where the
# training loop actually writes checkpoints.
network = Net(vgg, decoder, 0)
if start_iter > 0:
    network.decoder.load_state_dict(torch.load(
        os.path.join(args.save_dir, f'decoder_iter_{start_iter}.pth'), map_location='cpu'))
    network.transform.load_state_dict(torch.load(
        os.path.join(args.save_dir, f'transformer_iter_{start_iter}.pth'), map_location='cpu'))
network.train()
network.to(device)

# Only the decoder and the SANet transform are trained; the encoder is frozen.
optimizer = torch.optim.Adam([
    {'params': network.decoder.parameters()},
    {'params': network.transform.parameters()}
], lr=args.lr)
if start_iter > 0:
    # Restore Adam's moment estimates; load_state_dict moves them to the params' device.
    optimizer.load_state_dict(torch.load(
        os.path.join(args.save_dir, f'optimizer_iter_{start_iter}.pth'), map_location='cpu'))

# AMP loss scaler (device-agnostic torch.amp API); scaling is only needed and
# supported on CUDA, so it is a pass-through on CPU.
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))

writer = SummaryWriter(log_dir=args.log_dir)

# ------------------------- TRAINING LOOP -------------------------
def _save_state_dict_on_cpu(module, path):
    """Save ``module.state_dict()`` to ``path`` with every tensor moved to CPU.

    Values are replaced inside the state_dict object returned by PyTorch, rather
    than copied into a new dict, so anything PyTorch attaches to it is kept.
    """
    state_dict = module.state_dict()
    for key in state_dict.keys():
        state_dict[key] = state_dict[key].to(torch.device('cpu'))
    torch.save(state_dict, path)


# Mixed precision is used on CUDA only (torch.cuda.amp.autocast had no effect on CPU).
use_amp = device.type == 'cuda'

try:
    for i in tqdm(range(int(args.start_iter), args.max_iter)):
        adjust_learning_rate(optimizer, iteration_count=i)

        # get next batch (infinite sampler prevents StopIteration)
        content_images = next(content_iter).to(device, non_blocking=True)
        style_images = next(style_iter).to(device, non_blocking=True)

        # forward + backward pass with mixed precision
        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            loss_c, loss_s, l_identity1, l_identity2 = network(content_images, style_images)
            loss_c = args.content_weight * loss_c
            loss_s = args.style_weight * loss_s
            # identity-loss weights (50 for pixel space, 1 for feature space) are fixed here
            loss = loss_c + loss_s + l_identity1 * 50 + l_identity2 * 1

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        step = i + 1

        # logging
        writer.add_scalar('loss/total', loss.item(), step)
        writer.add_scalar('loss/content', loss_c.item(), step)
        writer.add_scalar('loss/style', loss_s.item(), step)
        writer.add_scalar('loss/identity1', l_identity1.item(), step)
        writer.add_scalar('loss/identity2', l_identity2.item(), step)

        # save checkpoints (file names use the 1-based iteration count)
        if step % args.save_model_interval == 0 or step == args.max_iter:
            os.makedirs(args.save_dir, exist_ok=True)
            _save_state_dict_on_cpu(network.decoder,
                                    os.path.join(args.save_dir, f'decoder_iter_{step}.pth'))
            _save_state_dict_on_cpu(network.transform,
                                    os.path.join(args.save_dir, f'transformer_iter_{step}.pth'))
            torch.save(optimizer.state_dict(),
                       os.path.join(args.save_dir, f'optimizer_iter_{step}.pth'))
finally:
    # Flush and close TensorBoard logs even if training is interrupted.
    writer.close()







Device: cuda


  scaler = GradScaler()  # AMP scaler
  with autocast():
100%|██████████| 15000/15000 [1:21:57<00:00,  3.05it/s]
