In [20]:
from collections import defaultdict
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, sampler
import lpips
import numpy as np
from tqdm import tqdm
from torchsummary import summary
from tensorboardX import SummaryWriter

from src import (
    Encoder,
    Decoder,
    NormNoiseQuantization,
    AEModel,
    GoogleDataset
)

In [22]:
# Single device used for the model, LPIPS and every batch below (CPU here).
device = torch.device(type='cpu')

# models

In [23]:
# Layer-by-layer shape/parameter summary of the encoder for one 3x512x512 input.
model = Encoder()
summary(model, input_size=(3, 512, 512), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 510, 510]             448
         LeakyReLU-2         [-1, 16, 510, 510]               0
       BatchNorm2d-3         [-1, 16, 510, 510]              32
         Dropout2d-4         [-1, 16, 510, 510]               0
         MaxPool2d-5         [-1, 16, 170, 170]               0
            Conv2d-6         [-1, 32, 168, 168]           4,640
         LeakyReLU-7         [-1, 32, 168, 168]               0
       BatchNorm2d-8         [-1, 32, 168, 168]              64
         Dropout2d-9         [-1, 32, 168, 168]               0
        MaxPool2d-10           [-1, 32, 56, 56]               0
           Conv2d-11           [-1, 64, 52, 52]          51,264
        LeakyReLU-12           [-1, 64, 52, 52]               0
      BatchNorm2d-13           [-1, 64, 52, 52]             128
        Dropout2d-14           [-1, 64,

In [4]:
# Layer-by-layer shape/parameter summary of the decoder for one (256, 1, 1) input.
model = Decoder()
summary(model, input_size=(256, 1, 1), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1            [-1, 128, 2, 2]         131,200
              GELU-2            [-1, 128, 2, 2]               0
       BatchNorm2d-3            [-1, 128, 2, 2]             256
   ConvTranspose2d-4             [-1, 64, 4, 4]          73,792
              GELU-5             [-1, 64, 4, 4]               0
       BatchNorm2d-6             [-1, 64, 4, 4]             128
UpsamplingBilinear2d-7           [-1, 64, 16, 16]               0
   ConvTranspose2d-8           [-1, 64, 16, 16]          36,928
              GELU-9           [-1, 64, 16, 16]               0
      BatchNorm2d-10           [-1, 64, 16, 16]             128
UpsamplingBilinear2d-11           [-1, 64, 64, 64]               0
  ConvTranspose2d-12           [-1, 32, 64, 64]          18,464
             GELU-13           [-1, 32, 64, 64]               0
      BatchNorm2d-14           [-1

# metrics

In [8]:
# MSE
def mse_loss(result, target):
    """Mean squared error between ``result`` and ``target``, averaged over every element."""
    return F.mse_loss(input=result, target=target, reduction='mean')

# PSNR
def psnr(result, target, max_val=None):
    """Peak signal-to-noise ratio in decibels: 10 * log10(MAX^2 / MSE).

    Args:
        result: predicted tensor.
        target: reference tensor with the same shape as ``result``.
        max_val: peak signal value MAX. Defaults to ``torch.max(result)``,
            the peak this notebook has always used.

    Returns:
        0-dim tensor with the PSNR in dB (``inf`` when the MSE is zero).
    """
    # Same value as mse_loss(result, target).
    mse = F.mse_loss(result, target)
    if max_val is None:
        peak = torch.max(result)
    else:
        peak = torch.as_tensor(max_val, dtype=mse.dtype, device=mse.device)
    # MAX^2 belongs inside the log; multiplying the dB value by MAX^2
    # (as the earlier formula did) does not give PSNR.
    return 10 * torch.log10(peak ** 2 / mse)

# Intermediate vector entropy
def latent_entropy_aprox(result):
    """Approximate entropy (in nats) of a latent tensor via a softmax over dim 1.

    The values along dim 1 at each position are turned into a probability
    distribution with softmax. The Shannon entropy -sum(p * log p) of that
    distribution is then averaged over all remaining positions.

    Args:
        result: latent tensor whose distribution axis is dim 1
            (presumably channels of an (N, C, H, W) tensor -- TODO confirm).

    Returns:
        0-dim tensor: mean entropy, in [0, log(result.size(1))].
    """
    # log_softmax is numerically safer than log(softmax(...)).
    log_probs = F.log_softmax(result, dim=1)
    # F.cross_entropy expects *logits*; passing it probabilities applies softmax
    # twice and yields ~log(C) regardless of the latent, so compute H(p) directly.
    return -(log_probs.exp() * log_probs).sum(dim=1).mean()


def normalize_img(img: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Min-max rescale a tensor to [0, 1], e.g. for TensorBoard image logging.

    Args:
        img: tensor to rescale. It is NOT modified; a new tensor is returned.
        eps: lower bound on the divisor, so a constant image maps to zeros
            instead of NaN (0 / 0).

    Returns:
        A new tensor with the same shape as ``img``: its minimum maps to 0
        and its maximum to 1 (all zeros when ``img`` is constant).
    """
    # Out-of-place ops: callers pass views such as pred[0] / batch[0], which
    # the old in-place `-=` / `/=` silently overwrote.
    shifted = img - img.min()
    return shifted / shifted.max().clamp_min(eps)

# train

In [9]:
d = '/data/ucheba/master_2sem/archive/'
ds = GoogleDataset(csv_file=d + 'train.csv', image_dir=d, batch_size=6)

# Seeded shuffle -> the train/validation split is identical on every run.
# This notebook resumes training from saved checkpoints, so an unseeded split
# would move previously-trained-on images into the validation set each restart.
# NOTE(review): checkpoints produced before seeding used a different random split.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.9

indices = np.random.default_rng(SPLIT_SEED).permutation(len(ds)).tolist()
split_ix = int(TRAIN_FRACTION * len(ds))
train_indices, val_indices = indices[:split_ix], indices[split_ix:]
train_sampler = sampler.SubsetRandomSampler(train_indices)
# `test_sampler` / `test` hold the held-out validation indices.
test_sampler = sampler.SubsetRandomSampler(val_indices)

train = DataLoader(ds, batch_size=ds.batch_size, num_workers=8, sampler=train_sampler)
test = DataLoader(ds, batch_size=ds.batch_size, num_workers=8, sampler=test_sampler)

In [10]:
quantizator = NormNoiseQuantization()
loss_fn_alex = lpips.LPIPS(net='alex', verbose=False)
loss_fn_alex.to(device)

# Resume the autoencoder from a saved checkpoint. Passing the path (instead of
# an `open(...)` file object that was never closed) lets torch.load manage the file.
CHECKPOINT_PATH = './weights/ae_step58200.pt'
model = AEModel()
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.to(device)
# NOTE(review): only model weights are restored; Adam's moment estimates start from scratch.
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

# Append to an existing TensorBoard run so its curves continue.
writer = SummaryWriter(logdir='runs/Jun24_01-39-04_mi/')
# Global step offset carried over from earlier sessions.
# NOTE(review): 8775 + 56238 = 65013, which does not match the checkpoint's
# step (58200) -- confirm which offset is intended.
train_step = 8775 + 56238
test_step = 0

In [None]:
for epoch in range(10):
    # Make training mode explicit in case another cell switched the model to eval.
    model.train()

    for batch in tqdm(train, total=len(train)):
        batch = batch.to(device)

        # Forward pass: encode -> quantize -> decode.
        state = model.encode(batch)
        quantized = quantizator(state)
        pred = model.decode(quantized)

        # Pixel-wise MSE plus LPIPS (AlexNet) perceptual loss.
        # NOTE(review): LPIPS expects inputs in [-1, 1] unless called with
        # normalize=True for [0, 1] images -- confirm GoogleDataset's range.
        loss_mse = mse_loss(pred, batch)
        loss_content = loss_fn_alex(pred, batch).mean()

        loss = loss_mse + loss_content
        loss.backward()
        opt.step()
        opt.zero_grad()

        with torch.no_grad():
            loss_psnr = psnr(pred, batch)
            loss_entropy = latent_entropy_aprox(quantized)

            writer.add_scalar('train/mse', loss_mse.item(), train_step)
            writer.add_scalar('train/psnr', loss_psnr.item(), train_step)
            writer.add_scalar('train/entropy', loss_entropy.item(), train_step)
            # Tag keeps its existing spelling so the curve continues in the
            # resumed TensorBoard run.
            writer.add_scalar('train/conent_loss', loss_content.item(), train_step)
            writer.add_scalar('train/loss', loss.item(), train_step)
            if train_step % 10 == 0:
                writer.add_image('train/pred', normalize_img(pred[0]), train_step)
                writer.add_image('train/target', normalize_img(batch[0]), train_step)

            if train_step % 5000 == 0:
                # Save CPU tensors so the checkpoint does not depend on `device`.
                model.cpu()
                torch.save(model.state_dict(), f'weights/ae_step{train_step}.pt')
                model.to(device)
            train_step += 1


100%|██████████| 19880/19880 [3:37:21<00:00,  1.52it/s]  
 66%|██████▌   | 13129/19880 [2:23:32<1:17:31,  1.45it/s]

In [None]:
# Show the global training step reached so far (the step index used for TensorBoard logging and checkpoint names).
train_step