In [1]:
import comet_ml
from config import API_KEY

import torch
import torch.nn as nn
import torch.nn.functional as F

import einops

torch.cuda.is_available()
torch.cuda.get_device_name(0)

'NVIDIA GeForce RTX 4070 SUPER'

In [None]:
from datasets import load_dataset
from torchvision.transforms import v2
from datasets import DatasetDict


transform = v2.Compose([
    v2.ToTensor(),
    v2.Resize((256, 256)),
])

def preprocess(example):
    example['image'] = example['image'].float() / 255.0
    example['image'] = transform(example['image'])
    return example


ds = load_dataset("Artificio/WikiArt_Full").with_format('torch')

train_test_split = ds["train"].train_test_split(test_size=0.15)
train_dataset = train_test_split['train']
test_dataset = train_test_split['test']

ds = {
    'train': train_dataset,
    'test': test_dataset
}

ds = DatasetDict(ds)



In [4]:
from torch.utils.data import DataLoader, random_split

train_loader = DataLoader(
        ds['train'],
        batch_size=80,
        num_workers=12,
        shuffle=True,
        pin_memory=True,
    )

test_loader = DataLoader(
        ds['test'],
        batch_size=80,
        num_workers=12,
        shuffle=False,
        pin_memory=True,
    )


In [5]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from torchmetrics.image import StructuralSimilarityIndexMeasure

class SSIMLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).cuda()

    def forward(self, x, y):
        ssim_score = self.ssim(x, y)
        return 1 - ssim_score
    

def log_images_to_comet(images, experiment, idx, epoch, tr_step):
    grid = make_grid(images, nrow=8, padding=2)
    
    np_grid = grid.permute(1, 2, 0).numpy()

    plt.figure(figsize=(15, 15))
    plt.imshow(np.clip(np_grid, 0, 1))
    plt.axis("off")
    experiment.log_figure(figure_name=f'{tr_step}:{idx}', figure=plt, step=epoch)
    plt.close()

def init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [6]:
import torch
import torch.nn.functional as F
from torchvision.models import vgg19
from torchvision import models

class CombinedLoss(nn.Module):
    def __init__(self, l1_weight=1.0, perceptual_weight=0.05, style_weight=0.0001):
        super(CombinedLoss, self).__init__()
        self.l1_weight = l1_weight
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight

        self.l1_loss = nn.L1Loss()
        vgg = vgg19(pretrained=True).features
        self.vgg_layers = vgg[:36].eval()      
        for param in self.vgg_layers.parameters():
            param.requires_grad = False

    def compute_feature_difference(self, prediction, target):
        feature_diff = 0
        for p_f, t_f in zip(prediction, target):
            p_f_norm = F.normalize(p_f, p=2, dim=1)
            t_f_norm = F.normalize(t_f, p=2, dim=1)
            feature_diff += torch.norm(p_f_norm - t_f_norm, p=2)

        return feature_diff

    def extract_features(self, x):
        features = []
        for i, layer in enumerate(self.vgg_layers):
            x = layer(x)
            if i in {3, 8, 17, 26, 35}:  # conv1_2, conv2_2, conv3_4, conv4_4, conv5_4
                features.append(x)
        return features
    
    def forward(self, prediction, target):
        self.vgg_layers.to(dtype=prediction.dtype, device=prediction.device)

        prediction_features = self.extract_features(prediction)
        target_features = self.extract_features(target)

        # L1 Loss
        l1_loss = self.l1_loss(prediction, target)

        # Perceptual Loss (Feature-Level MSE)
        perceptual_loss = sum(F.mse_loss(p_f, t_f)
                              for p_f, t_f in zip(prediction_features, target_features))

        # Style Loss (Gram Matrix MSE)
        style_loss = self.compute_feature_difference(prediction_features, target_features)


        l1 = self.l1_weight * l1_loss
        perceptual = self.perceptual_weight * perceptual_loss
        style = self.style_weight * style_loss

        return l1, perceptual, style

In [7]:
from tqdm import tqdm

from comet_ml.integration.pytorch import log_model

from torch.cuda.amp import autocast, GradScaler
from torchinfo import summary

from models import Encoder, Decoder, AutoEncoder

if __name__ == '__main__':

    lambda_ssim = 0.2
    lambda_mse = 0.8


    latent_width = 1024
    scaler = GradScaler()
    encoder = Encoder(latent_width)
    decoder = Decoder(latent_width)
    model = AutoEncoder(encoder=encoder, decoder=decoder)
    model.apply(init_weights)
    model.cuda()

    comet_experiment = comet_ml.Experiment(api_key=API_KEY, project_name='UN_latent')
    comet_experiment.log_parameters(
        {
            'batch_size': train_loader.batch_size,
            'train_size': ds['train'].num_rows,
            'test_size': ds['test'].num_rows,
        }
    )

    summ = summary(model, (1,3,256,256), device='cuda',depth=5)
    comet_experiment.set_model_graph(f'{model.__repr__()}\n{summ}')

    num_epochs = 50

    loss_fn = CombinedLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

    comet_experiment.log_parameter('num_epochs', num_epochs)

    for epoch in range(num_epochs):
        comet_experiment.set_epoch(epoch)
                    

        model.train()
        with comet_experiment.train() as train:
            for idx, batch in tqdm(enumerate(train_loader), desc=f'TRAIN_{epoch}'):
                comet_experiment.set_step(idx + epoch * len(train_loader))

                optimizer.zero_grad()
                images = batch['image'] / 255.0
                labels = batch['image']
                images = images.cuda()
                with torch.autocast(device_type='cuda'):
                    predictions, latents = model(images)
                    l1, per, style = loss_fn(predictions, images)
                    loss = l1 + per + style
                scaler.scale(loss.float()).backward()
                scaler.step(optimizer)
                scaler.update()


                comet_experiment.log_metric('l1_loss', l1)
                comet_experiment.log_metric('per_loss', per)
                comet_experiment.log_metric('style_loss', style)
        
        model.eval()
        with comet_experiment.validate() as test, torch.no_grad() as nograd:
            for idx, batch in tqdm(enumerate(test_loader), desc=f'TEST_{epoch}'):
                comet_experiment.set_step(idx + epoch * len(test_loader))

                images = batch['image'] / 255.0
                images = images.cuda()
                with torch.autocast(device_type='cuda'):
                    predictions, latents = model(images)
                    l1, per, style = loss_fn(predictions, images)
                    loss = l1 + per + style

                comet_experiment.log_metric('l1_loss', l1)
                comet_experiment.log_metric('per_loss', per)
                comet_experiment.log_metric('style_loss', style)

                if idx < 2:
                    concatenated = torch.cat([images, predictions], dim=3).cpu()
                    log_images_to_comet(concatenated, comet_experiment, idx, epoch, 'TEST')


        if (epoch + 1) % 10 == 0:
            torch.save(encoder.state_dict(), f"models/encoder_{epoch}.pth")
            torch.save(decoder.state_dict(), f"models/decoder_{epoch}.pth")
            torch.save(model.state_dict(), f"models/model_{epoch}.pth")


    log_model(comet_experiment, model, model_name="AutoEncoder")
    comet_experiment.end()
        

  scaler = GradScaler()
[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/foxtold/un-latent/4683bdd15fd64b889cc6e7de4e0c3b40

TRAIN_0: 1098it [09:14,  1.98it/s]
TEST_0: 194it [00:49,  3.89it/s]
TRAIN_1: 1098it [08:40,  2.11it/s]
TEST_1: 194it [00:47,  4.09it/s]
TRAIN_2: 1098it [08:53,  2.06it/s]
TEST_2: 194it [00:47,  4.05it/s]
TRAIN_3: 1098it [08:55,  2.05it/s]
TEST_3: 194it [00:49,  3.94it/s]
TRAIN_4: 1098it [08:54,  2.05it/s]
TEST_4: 194it [00:50,  3.86it/s]
TRAIN_5: 1098it [09:01,  2.03it/s]
TEST_5: 194it [00:47,  4.07it/s]
TRAIN_6: 1098it [08:36,  2.13it/s]
TEST_6: 194it [00:47,  4.07it/s]
TRAIN_7: 1098it [08:46,  2.09it/s]
TEST_7: 194it [00:49,  3.94it/s]
TRAIN_8: 1098it [09:05,  2.01it/s]
TEST_8: 194it [00:49,  3.93it/s]
TRAIN_9: 1098it [08:40,  2.11it/s]
TEST_9: 194it [00:48,  4.02it/s]
TRAIN_10: 195it [01:33,  2.08it/s]


KeyboardInterrupt: 