This code was written by Devorah Rotman 316472026 and Carmit Kaye 346038169

In [1]:
MONET_PATH = '../data/monet_jpg'
PHOTO_PATH = '../data/photo_jpg'

In [2]:
from tqdm import tqdm

import torch.nn.functional as F
import warnings
import os
import torch

from torchvision import models
from STGAN.Data_loader_STGAN import get_content_loader, get_style_loader
from STGAN.models_STGAN import Discriminator, Generator
from STGAN.utils import visualize_output, visualize_graphs

warnings.filterwarnings("ignore")

class Training:
    def __init__(self,
                 content_path,
                 style_path,
                 batch_size,
                 in_style,
                 channels,
                 epochs,
                 lr_gen,
                 #lr_style,
                 lr_dis,
                 ckpt_path,
                 device,
                 pretrained):

        self.device = device
        self.in_style = in_style
        self.content_loader = get_content_loader(content_path, batch_size)
        self.style_loader = get_style_loader(style_path, batch_size)
        self.content_iter = iter(self.content_loader)
        self.style_iter = iter(self.style_loader)

        self.dis = Discriminator(channels).to(device)
        self.gen = Generator(in_style, channels).to(device)
        #self.lambda_gp = 10

        #style_params = [self.gen.w]
        #gen_params = [p for n, p in self.gen.named_parameters() if not n.startswith("w.")]
        #style_params = list(self.gen.style_mapper.parameters())
        #gen_params = [p for n, p in self.gen.named_parameters() if not n.startswith("style_mapper")]
        self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr_gen, betas=(0.9, 0.999))
        self.dis_opt = torch.optim.Adam(self.dis.parameters(), lr=lr_dis, betas=(0.9, 0.999))
        #self.style_opt = torch.optim.Adam(self.gen.parameters(), lr=lr_style, betas=(0.9, 0.999))
        self.grow_rank = 0
        self.max_scale = len(channels)-1
        self.ckpt_path = ckpt_path if ckpt_path[-3:] == '.pt' else os.path.join(ckpt_path, 'last.pt')
        self.pretrained = pretrained
        self.lr_gen = lr_gen
        self.lr_dis = lr_dis
        #self.lr_style = lr_style
        self.epochs = epochs
        self.alpha = 0

        # Load pretrained VGG19s
        self.vgg = models.vgg19(pretrained=True).features[:21].to(device).eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.vgg_input = models.vgg19(pretrained=True).features.to(device).eval()
        for param in self.vgg_input.parameters():
            param.requires_grad = False
        self.content_loss_weight = 1  # Adjust to control content preservation strength

        if self.pretrained:
            self.load_ckpts_train()

    def compute_content_loss(self, real_img, generated_img):
        # Resize to 128x128 as expected by VGG
        real_img = F.interpolate(real_img, size=(128,128), mode='bilinear', align_corners=False)
        generated_img = F.interpolate(generated_img, size=(128, 128), mode='bilinear', align_corners=False)

        # Normalize for VGG
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
        gen_norm = (generated_img - mean) / std

        real_feat = self.vgg(real_img)
        gen_feat = self.vgg(gen_norm)

        return F.mse_loss(gen_feat, real_feat)

    def train_loop(self):
        #define everything needed to smoothen alpha
        trick = 1
        steps_per_epoch = len(self.content_loader)
        num_epochs = self.epochs[self.grow_rank]
        #total_steps = steps_per_epoch * num_epochs
        #alpha_transition_steps = total_steps // 2
        alpha_p = int(0.5*num_epochs)

        generator_loss_list = []
        discriminator_loss_list = []

        #current_step = 0
        freeze_d = (self.grow_rank >= 3)


        for epoch in range(num_epochs):
            dis_total_loss = 0
            gen_total_loss = 0
            content_total_loss = 0
            adv_total_loss = 0

            for i in tqdm(range(len(self.content_loader)), ascii=True, desc=f"Epoch {epoch+1} Training"):
                try:
                    content_imgs, _ = next(self.content_iter)
                except StopIteration:
                    self.content_iter = iter(self.content_loader)
                    content_imgs, _ = next(self.content_iter)

                try:
                    style_imgs, _ = next(self.style_iter)
                except StopIteration:
                    self.style_iter = iter(self.style_loader)
                    style_imgs, _ = next(self.style_iter)

                content_imgs = content_imgs.to(self.device)
                style_imgs = style_imgs.to(self.device)
                self.alpha = min(1, trick/alpha_p*len(self.content_loader)) if self.grow_rank >0 else 1
                #self.alpha = min(1.0, current_step / alpha_transition_steps) if self.grow_rank >0 else 1

                ##### Extract features from VGG layer 36 #####
                # Resize and normalize
                new_input = F.interpolate(content_imgs, size=(128, 128), mode='bilinear', align_corners=False)

                with torch.no_grad():
                    x = new_input
                    vgg_feat = self.vgg_input(x)
                #print(vgg_feat.shape)

                ##### Train Discriminator #####
                #if not freeze_d or trick > 1000:
                if not freeze_d or trick > 0:
                    for p in self.dis.parameters():
                        p.requires_grad = True
                else:
                    for p in self.dis.parameters():
                        p.requires_grad = False

                self.dis_opt.zero_grad()

                fake_imgs = self.gen(vgg_feat, alpha = self.alpha).detach()

                #make sure that the fake images and the real images inputted into the discriminator have the same size
                target_size = fake_imgs.shape[2:]  # (H, W)
                if style_imgs.shape[2:] != target_size:
                    style_imgs = F.interpolate(style_imgs, size=target_size, mode='bilinear', align_corners=False)

                real_pred = self.dis(style_imgs, alpha = self.alpha)
                fake_pred = self.dis(fake_imgs, alpha = self.alpha)

                #WGAN loss
                #fake_pred_loss = torch.mean(fake_pred)
                #real_pred_loss = torch.mean(-real_pred)
                #gradient_penalty = self.calculate_gradient_penalty(style_imgs, fake_imgs)
                #d_loss = fake_pred_loss + real_pred_loss + self.lambda_gp * gradient_penalty
                #end of WGAN loss

                #loss version 2
                fake_pred_loss = F.softplus(fake_pred).mean()
                real_pred_loss = F.softplus(-real_pred).mean()
                d_loss = fake_pred_loss + real_pred_loss
                #end of loss v2

                d_loss.backward()
                self.dis_opt.step()

                dis_total_loss += d_loss.item()

                ##### Train Generator #####
                #perc_content = 0.95 #fraction of the epochs used to train the content only
                style_lss_weight = 0.5*trick/(steps_per_epoch * num_epochs)
                for _ in range(2):

                    self.gen_opt.zero_grad()
                    fake_imgs = self.gen(vgg_feat, alpha = self.alpha)
                    fake_pred = self.dis(fake_imgs, alpha = self.alpha)
                    content_loss = self.compute_content_loss(content_imgs, fake_imgs)
                    c_loss = self.content_loss_weight * content_loss
                    adv_loss = F.softplus(-fake_pred).mean()
                    a_loss = style_lss_weight * adv_loss

                    full_loss = a_loss+c_loss
                    full_loss.backward()


                    self.gen_opt.step()
                    gen_total_loss +=  a_loss.item() + c_loss.item()
                    adv_total_loss += a_loss.item()
                    content_total_loss += c_loss.item()


                trick += 1
                #current_step += 1

            #log the gen and dis losses
            gen_total_loss = gen_total_loss / steps_per_epoch
            adv_total_loss = adv_total_loss / steps_per_epoch
            content_total_loss = content_total_loss / steps_per_epoch
            dis_total_loss = dis_total_loss / steps_per_epoch
            generator_loss_list.append(gen_total_loss)
            discriminator_loss_list.append(dis_total_loss)
            print(f"Epoch {epoch+1}:",
                  f"GPU Mem = {round(torch.cuda.memory_reserved()/1E9, 3)} GB |",
                  f"D loss = {round(dis_total_loss, 3)} |",
                  f"Content loss = {round(content_total_loss, 3)}|",
                  f"Adv loss = {round(adv_total_loss, 3)}|",
                  f"G loss = {round(gen_total_loss, 3)}")


            self.save_model()
        return generator_loss_list, discriminator_loss_list

    def train(self):
        while True:
            generator_loss_list, discriminator_loss_list = self.train_loop()
            print(f"the final samples from scale {self.grow_rank+1}")

            # Generate fake image using a random content image
            content_imgs, _ = next(iter(self.content_loader))
            content_imgs = content_imgs.to(self.device)

            # Preprocess for VGG
            new_input = F.interpolate(content_imgs, size=(128, 128), mode='bilinear', align_corners=False)

            with torch.no_grad():
                x = new_input
                vgg_feat = self.vgg_input(x)

            visualize_output(new_input, f"original_content_ims_{self.grow_rank}")
            cosine_distance = F.cosine_similarity(vgg_feat[0], vgg_feat[1]).mean()
            print(cosine_distance)

            fake_imgs = self.gen(vgg_feat, alpha=self.alpha)
            visualize_output(fake_imgs, self.grow_rank)
            visualize_graphs(generator_loss_list, discriminator_loss_list, self.grow_rank)
            if self.grow_rank+1 == self.max_scale:
                print("Maximum scale has been reached.")
                break
            continue_training = self.grow()

    def load_ckpts_train(self):
        ckpt = torch.load(self.ckpt_path, map_location=self.device)
        for i in range(ckpt['grow_rank']):
            self.grow()
        self.gen.load_state_dict(ckpt['generator'])
        self.dis.load_state_dict(ckpt['discriminator'])
        self.gen_opt.load_state_dict(ckpt['generator_opt'])
        self.dis_opt.load_state_dict(ckpt['discriminator_opt'])
        self.grow_rank = ckpt['grow_rank']
        del ckpt

    def grow(self):
        torch.cuda.empty_cache()
        self.gen.grow()
        self.dis.grow()
        self.gen.cuda()
        self.dis.cuda()
        self.grow_rank += 1
        #gen_params = [p for n, p in self.gen.named_parameters() if not n.startswith("w.")]
        #gen_params = [p for n, p in self.gen.named_parameters() if not n.startswith("style_mapper")]
        self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=self.lr_gen*0.7**self.grow_rank, betas=(0.9, 0.999))
        #self.style_opt = torch.optim.Adam(self.gen.parameters(), lr=self.lr_style*0.7**self.grow_rank, betas=(0.9, 0.999))
        self.dis_opt = torch.optim.Adam(self.dis.parameters(), lr=self.lr_dis*0.7**self.grow_rank, betas=(0.9, 0.999))

    def save_model(self):
        ckpt = {
            'generator': self.gen.state_dict(),
            'discriminator': self.dis.state_dict(),
            'generator_opt': self.gen_opt.state_dict(),
            'discriminator_opt': self.dis_opt.state_dict(),
            'grow_rank': self.grow_rank
        }
        torch.save(ckpt, self.ckpt_path)
        del ckpt


In [3]:
content_ims_folder = PHOTO_PATH
style_ims_folder = MONET_PATH
batch_size = 8
in_style = 512 #size of w
channels = [512, 512, 512, 256, 128, 64, 32] # layer channel sizes
lr_gen = 0.007
lr_dis = 0.00001
epochs = [6,10,10,18,20,24]
checkpoint_path = 'C:/Users/dnrot/OneDrive/Desktop/BGU MSc/HW_deb/DLIP/final_project/STGAN/checkpoints'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained = True



Trainer = Training(content_ims_folder, style_ims_folder, batch_size, in_style, channels, epochs, lr_gen, lr_dis, checkpoint_path, device, pretrained)
Trainer.train()

Epoch 1 Training: 100%|##########| 879/879 [06:22<00:00,  2.30it/s]


Epoch 1: GPU Mem = 1.424 GB | D loss = 0.031 | Content loss = 17.656| Adv loss = 0.108| G loss = 17.764


Epoch 2 Training: 100%|##########| 879/879 [06:53<00:00,  2.13it/s]


Epoch 2: GPU Mem = 1.424 GB | D loss = 0.022 | Content loss = 17.459| Adv loss = 0.333| G loss = 17.791


Epoch 3 Training: 100%|##########| 879/879 [06:46<00:00,  2.16it/s]


Epoch 3: GPU Mem = 1.424 GB | D loss = 0.021 | Content loss = 17.351| Adv loss = 0.563| G loss = 17.914


Epoch 4 Training: 100%|##########| 879/879 [06:36<00:00,  2.22it/s]


Epoch 4: GPU Mem = 1.424 GB | D loss = 0.025 | Content loss = 17.273| Adv loss = 0.79| G loss = 18.062


Epoch 5 Training: 100%|##########| 879/879 [06:39<00:00,  2.20it/s]


Epoch 5: GPU Mem = 1.424 GB | D loss = 0.03 | Content loss = 17.203| Adv loss = 0.998| G loss = 18.201


Epoch 6 Training: 100%|##########| 879/879 [06:51<00:00,  2.14it/s]


Epoch 6: GPU Mem = 1.424 GB | D loss = 0.035 | Content loss = 17.149| Adv loss = 1.203| G loss = 18.352


Epoch 7 Training: 100%|##########| 879/879 [06:47<00:00,  2.16it/s]


Epoch 7: GPU Mem = 1.424 GB | D loss = 0.039 | Content loss = 17.103| Adv loss = 1.39| G loss = 18.494


Epoch 8 Training: 100%|##########| 879/879 [06:48<00:00,  2.15it/s]


Epoch 8: GPU Mem = 1.424 GB | D loss = 0.04 | Content loss = 17.06| Adv loss = 1.582| G loss = 18.642


Epoch 9 Training: 100%|##########| 879/879 [06:43<00:00,  2.18it/s]


Epoch 9: GPU Mem = 1.424 GB | D loss = 0.042 | Content loss = 17.026| Adv loss = 1.767| G loss = 18.793


Epoch 10 Training: 100%|##########| 879/879 [06:49<00:00,  2.14it/s]


Epoch 10: GPU Mem = 1.424 GB | D loss = 0.044 | Content loss = 16.994| Adv loss = 1.955| G loss = 18.949


Epoch 11 Training: 100%|##########| 879/879 [06:48<00:00,  2.15it/s]


Epoch 11: GPU Mem = 1.424 GB | D loss = 0.045 | Content loss = 16.968| Adv loss = 2.129| G loss = 19.097


Epoch 12 Training: 100%|##########| 879/879 [06:11<00:00,  2.37it/s]


Epoch 12: GPU Mem = 1.424 GB | D loss = 0.048 | Content loss = 16.973| Adv loss = 2.315| G loss = 19.288


Epoch 13 Training: 100%|##########| 879/879 [06:30<00:00,  2.25it/s]


Epoch 13: GPU Mem = 1.424 GB | D loss = 0.051 | Content loss = 16.951| Adv loss = 2.486| G loss = 19.437


Epoch 14 Training: 100%|##########| 879/879 [06:47<00:00,  2.16it/s]


Epoch 14: GPU Mem = 1.424 GB | D loss = 0.057 | Content loss = 16.944| Adv loss = 2.649| G loss = 19.593


Epoch 15 Training: 100%|##########| 879/879 [05:35<00:00,  2.62it/s]


Epoch 15: GPU Mem = 1.424 GB | D loss = 0.06 | Content loss = 16.956| Adv loss = 2.783| G loss = 19.74


Epoch 16 Training: 100%|##########| 879/879 [05:53<00:00,  2.48it/s]


Epoch 16: GPU Mem = 1.424 GB | D loss = 0.063 | Content loss = 16.977| Adv loss = 2.925| G loss = 19.902


Epoch 17 Training: 100%|##########| 879/879 [05:44<00:00,  2.55it/s]


Epoch 17: GPU Mem = 1.424 GB | D loss = 0.067 | Content loss = 17.002| Adv loss = 3.05| G loss = 20.052


Epoch 18 Training: 100%|##########| 879/879 [05:34<00:00,  2.63it/s]


Epoch 18: GPU Mem = 1.424 GB | D loss = 0.072 | Content loss = 17.044| Adv loss = 3.181| G loss = 20.225


Epoch 19 Training: 100%|##########| 879/879 [05:32<00:00,  2.65it/s]


Epoch 19: GPU Mem = 1.424 GB | D loss = 0.071 | Content loss = 17.114| Adv loss = 3.332| G loss = 20.446


Epoch 20 Training: 100%|##########| 879/879 [05:32<00:00,  2.64it/s]


Epoch 20: GPU Mem = 1.424 GB | D loss = 0.077 | Content loss = 17.177| Adv loss = 3.43| G loss = 20.607


Epoch 21 Training: 100%|##########| 879/879 [05:28<00:00,  2.67it/s]


Epoch 21: GPU Mem = 1.424 GB | D loss = 0.082 | Content loss = 17.264| Adv loss = 3.578| G loss = 20.842


Epoch 22 Training: 100%|##########| 879/879 [05:24<00:00,  2.71it/s]


Epoch 22: GPU Mem = 1.424 GB | D loss = 0.086 | Content loss = 17.347| Adv loss = 3.717| G loss = 21.065


Epoch 23 Training: 100%|##########| 879/879 [05:19<00:00,  2.75it/s]


Epoch 23: GPU Mem = 1.424 GB | D loss = 0.098 | Content loss = 17.446| Adv loss = 3.842| G loss = 21.288


Epoch 24 Training: 100%|##########| 879/879 [05:19<00:00,  2.75it/s]


Epoch 24: GPU Mem = 1.424 GB | D loss = 0.098 | Content loss = 17.561| Adv loss = 3.999| G loss = 21.56
the final samples from scale 6
tensor(0.0795, device='cuda:0')
Maximum scale has been reached.
