In [66]:
# mlflow ui --port 6007 --backend-store-uri file:/share/lazy/will/ConstrastiveLoss/Logs
# watch -n 0.5 nvidia-smi

In [77]:
from torch import nn
from torchsummary import summary
import mlflow
from torchvision import transforms

In [171]:

class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if self.training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            torch.sum(embed_onehot_sum)
            torch.sum(embed_sum)

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        diff = (quantize.detach() - input).pow(2).mean()
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, stride=1, kernel_size=3, extra_layers=1, residual=True):
        super().__init__()
        self.residual=residual
        
        layers = [
            nn.Conv2d(in_channel, out_channel, stride=stride, kernel_size=kernel_size, padding=(kernel_size-1)//2),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)]
        
        extra_block = [
            nn.Conv2d(out_channel, out_channel, stride=1, kernel_size=3, padding=(3-1)//2),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)]

        layers.extend(extra_block)

        self.resblock = nn.Sequential(*layers)

    def forward(self, x):
        if self.residual:
            return x+self.resblock(x)
        else:
            return self.resblock(x)

class Encoder(nn.Module):
    def __init__(self, in_channel, channel, extra_layers, stride, kernel_size, residual, extra_residual_blocks, downsample):
        super().__init__()

        self.out_channels = channel

        blocks = [
            ResBlock(in_channel, channel, extra_layers=extra_layers, stride=stride, residual=residual),
            nn.ReLU(inplace=True)
        ]


        for i in range(extra_residual_blocks):
            blocks.append(ResBlock(in_channel=channel, out_channel=channel, extra_layers=extra_layers, residual=True))
            if (downsample=='Once') & (i==0):
                blocks.append(nn.MaxPool2d(2, 2))
            if (downsample=='Twice') & ((i==0) | (i==1)):
                blocks.append(nn.MaxPool2d(2, 2))

        self.encode = nn.Sequential(*blocks)

    def forward(self, input):
        return self.encode(input)


class Decoder(nn.Module):
    def __init__(self, channel, out_channel, extra_layers, extra_residual_blocks, upsample):
        super().__init__()

        blocks = []

        for i in range(extra_residual_blocks):
            blocks.append(ResBlock(in_channel=channel, out_channel=channel, extra_layers=extra_layers, residual=True))
            if (upsample=='Twice') & (i==0):
                blocks.append(nn.ConvTranspose2d(channel, channel, 2, 2))
                            
        blocks.append(nn.ConvTranspose2d(channel, out_channel, 2, 2))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)


class VQVAE(nn.Module):
    def __init__(
        self,
        in_channel=3,
        channel=128,
        n_res_block=2,
        n_res_channel=32,
        embed_dim=64,
        n_embed=512,
        decay=0.99,
    ):
        super().__init__()
        # Encoders, first one should have two rounds of downsampling, second should have one
        self.enc_b = Encoder(in_channel=in_channel, channel=channel, extra_layers=2, stride=2, kernel_size=5, residual=False, extra_residual_blocks=2, downsample='Once')
        self.enc_t = Encoder(in_channel=channel, channel=channel, extra_layers=3, stride=1, kernel_size=3, residual=False, extra_residual_blocks=2, downsample='Once')

        self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1)
        self.quantize_t = Quantize(embed_dim, n_embed)

        # Decoders,
#         self.dec_t = Decoder(embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2)
        self.dec_t = Decoder(embed_dim, embed_dim, channel, extra_residual_blocks = n_res_block, upsample='Once')
        self.quantize_conv_b = nn.Conv2d(embed_dim + channel, embed_dim, 1)
        self.quantize_b = Quantize(embed_dim, n_embed)
        self.upsample_t = nn.ConvTranspose2d(embed_dim, embed_dim, 4, stride=2, padding=1)
#         self.dec = Decoder(embed_dim + embed_dim, in_channel, channel, n_res_block, n_res_channel, stride=4)
        self.dec = Decoder(embed_dim + embed_dim, in_channel, extra_layers=2, extra_residual_blocks=2, upsample='Twice')

    def forward(self, input):
        quant_t, quant_b, diff, _, _ = self.encode(input)
        dec = self.decode(quant_t, quant_b)

        return dec, diff

    def encode(self, input):
        enc_b = self.enc_b(input)
        enc_t = self.enc_t(enc_b)

        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
        quant_t, diff_t, id_t = self.quantize_t(quant_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        diff_t = diff_t.unsqueeze(0)

        dec_t = self.dec_t(quant_t)
        enc_b = torch.cat([dec_t, enc_b], 1)

        quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1)
        quant_b, diff_b, id_b = self.quantize_b(quant_b)
        quant_b = quant_b.permute(0, 3, 1, 2)
        diff_b = diff_b.unsqueeze(0)

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b

    def decode(self, quant_t, quant_b):
        upsample_t = self.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)
        dec = self.dec(quant)

        return dec

    def decode_code(self, code_t, code_b):
        quant_t = self.quantize_t.embed_code(code_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        quant_b = self.quantize_b.embed_code(code_b)
        quant_b = quant_b.permute(0, 3, 1, 2)

        dec = self.decode(quant_t, quant_b)

        return dec

In [175]:
model = VQVAE().to('cuda:0')
# summary(model, (3, 256, 256))
model(torch.ones(1,3,256,256).to('cuda:0'))[0].shape

we add once!
we add once!


torch.Size([1, 3, 256, 256])

In [120]:
thing1 = Encoder(in_channel=3, channel=64, extra_layers=2, stride=2, kernel_size=5, residual=False, extra_residual_blocks=2, downsample='Once')
thing2 = Encoder(in_channel=thing1.out_channels, channel=64, extra_layers=3, stride=1, kernel_size=3, residual=False, extra_residual_blocks=2, downsample='Once')
thing3 = Decoder(64, 32, extra_layers=2, extra_residual_blocks=2, upsample='Twice')


In [155]:
summary(thing1.to('cuda:0'), (3, 256, 256))
summary(thing2.to('cuda:0'), (thing1.out_channels, 64, 64))
# summary(thing3.to('cuda:0'), (64, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           1,792
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
            Conv2d-4         [-1, 64, 128, 128]          36,928
       BatchNorm2d-5         [-1, 64, 128, 128]             128
              ReLU-6         [-1, 64, 128, 128]               0
          ResBlock-7         [-1, 64, 128, 128]               0
              ReLU-8         [-1, 64, 128, 128]               0
            Conv2d-9         [-1, 64, 128, 128]          36,928
      BatchNorm2d-10         [-1, 64, 128, 128]             128
             ReLU-11         [-1, 64, 128, 128]               0
           Conv2d-12         [-1, 64, 128, 128]          36,928
      BatchNorm2d-13         [-1, 64, 128, 128]             128
             ReLU-14         [-1, 64, 1

In [72]:
import argparse
import sys
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils

from tqdm import tqdm
import mlflow

from VQVAE import VQVAE
import distributed as dist

class Params(object):
    def __init__(self, batch_size, epochs, lr, size):
        self.size = batch_size
        self.epoch = epochs
        self.lr = lr
        self.size = size

args = Params(64, 1000, 4e-4, 128)

In [73]:
def train(epoch, loader, model, optimizer, device):
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (img, labels) in enumerate(loader):
        print(i)
        model.zero_grad()

        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        optimizer.step()

        mse_sum += recon_loss.item() * img.shape[0]
        mse_n += img.shape[0]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description(
                (
                    f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                    f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                    f"lr: {lr:.5f}"
                )
            )

        if i % 20 == 0:
            model.eval()

            sample = img[:sample_size]

            with torch.no_grad():
                out, _ = model(sample)

            utils.save_image(
                torch.cat([sample, out], 0),
                f"samples/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.jpg",
                nrow=sample_size,
                normalize=True,
                range=(-1, 1),
            )

            model.train()

        yield {'Latent Loss':latent_loss.item(), 'Average MSE':mse_sum/mse_n, 'Reconstruction Loss':recon_loss.item()}


In [75]:
device = "cuda:1"

mlflow.tracking.set_tracking_uri('file:/share/lazy/will/ConstrastiveLoss/Logs')

mlflow.set_experiment('Vector Quantized Variational Autoencoder')

transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
#         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

dataset = datasets.ImageFolder('/share/lazy/will/ConstrastiveLoss/Imgs/color_images/train/', transform=transform)
# sampler = dist.data_sampler(dataset, shuffle=True, distributed=False)
loader = DataLoader(dataset, batch_size=128, shuffle=False, pin_memory = True)

model = VQVAE().to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr)

run_name = 'VQVAE 100k images'
with mlflow.start_run(run_name = run_name) as run:

    for key, value in vars(args).items():
        mlflow.log_param(key, value)

    mlflow.log_param('Parameters', sum(p.numel() for p in model.parameters() if p.requires_grad))

    for epoch in range(args.epoch):
        print('We are on epoch number '+str(epoch))
        results = train(epoch, loader, model, optimizer, device)
        for Dict in results:
            for key, value in Dict.items():
                print(key, value)
                mlflow.log_metric(key, value, epoch)

            torch.save({
    'model':model.state_dict(),
    'optimizer':optimizer.state_dict(),
    }, 'run_stats.pyt')
    mlflow.log_artifact('run_stats.pyt')



  0%|          | 0/456 [00:00<?, ?it/s]

We are on epoch number 0


epoch: 1; mse: 0.26948; latent: 1.295; avg mse: 0.26948; lr: 0.00040:   0%|          | 0/456 [00:04<?, ?it/s]

0


epoch: 1; mse: 0.26948; latent: 1.295; avg mse: 0.26948; lr: 0.00040:   0%|          | 1/456 [00:04<32:04,  4.23s/it]

Latent Loss 1.2951692342758179
Average MSE 0.26947730779647827
Reconstruction Loss 0.26947730779647827


epoch: 1; mse: 0.24401; latent: 0.001; avg mse: 0.25674; lr: 0.00040:   0%|          | 1/456 [00:10<32:04,  4.23s/it]

1
Latent Loss 0.0010279725538566709
Average MSE 0.2567427307367325
Reconstruction Loss 0.2440081536769867


epoch: 1; mse: 0.23317; latent: 0.001; avg mse: 0.24888; lr: 0.00040:   0%|          | 2/456 [00:16<36:05,  4.77s/it]

2
Latent Loss 0.0013551212614402175
Average MSE 0.24888442953427634
Reconstruction Loss 0.233167827129364


epoch: 1; mse: 0.23317; latent: 0.001; avg mse: 0.24888; lr: 0.00040:   1%|          | 3/456 [00:19<48:46,  6.46s/it]


KeyboardInterrupt: 

In [142]:
from torch import distributed as dist
def get_world_size():
    if not torch.distributed.dist.is_available():
        return 1

    if not torch.distributed.dist.is_initialized():
        return 1

    return torch.distributed.dist.get_world_size()


def all_reduce(tensor, op=dist.ReduceOp.SUM):
    world_size = get_world_size()

    if world_size == 1:
        return tensor

    dist.all_reduce(tensor, op=op)

    return tensor

In [167]:
class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if self.training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            torch.sum(embed_onehot_sum)
            torch.sum(embed_sum)

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        diff = (quantize.detach() - input).pow(2).mean()
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))


class ResBlock(nn.Module):
    def __init__(self, in_channel, channel):
        super().__init__()

        self.conv = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channel, channel, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, in_channel, 1),
        )

    def forward(self, input):
        out = self.conv(input)
        out += input

        return out


class Encoder(nn.Module):
    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
        super().__init__()

        if stride == 4:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, 3, padding=1),
            ]

        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 3, padding=1),
            ]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.ReLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)


class Decoder(nn.Module):
    def __init__(
        self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
    ):
        super().__init__()

        blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.ReLU(inplace=True))

        if stride == 4:
            blocks.extend(
                [
                    nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
                    nn.ReLU(inplace=True),
                    nn.ConvTranspose2d(
                        channel // 2, out_channel, 4, stride=2, padding=1
                    ),
                ]
            )

        elif stride == 2:
            blocks.append(
                nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
            )

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)


class VQVAE(nn.Module):
    def __init__(
        self,
        in_channel=3,
        channel=128,
        n_res_block=2,
        n_res_channel=32,
        embed_dim=64,
        n_embed=512,
        decay=0.99,
    ):
        super().__init__()

        self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
        self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
        self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1)
        self.quantize_t = Quantize(embed_dim, n_embed)
        self.dec_t = Decoder(
            embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2
        )
        self.quantize_conv_b = nn.Conv2d(embed_dim + channel, embed_dim, 1)
        self.quantize_b = Quantize(embed_dim, n_embed)
        self.upsample_t = nn.ConvTranspose2d(
            embed_dim, embed_dim, 4, stride=2, padding=1
        )
        self.dec = Decoder(
            embed_dim + embed_dim,
            in_channel,
            channel,
            n_res_block,
            n_res_channel,
            stride=4,
        )

    def forward(self, input):
        quant_t, quant_b, diff, _, _ = self.encode(input)
        dec = self.decode(quant_t, quant_b)

        return dec, diff

    def encode(self, input):
        enc_b = self.enc_b(input)
        enc_t = self.enc_t(enc_b)

        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
        quant_t, diff_t, id_t = self.quantize_t(quant_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        diff_t = diff_t.unsqueeze(0)

        dec_t = self.dec_t(quant_t)
        enc_b = torch.cat([dec_t, enc_b], 1)

        quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1)
        quant_b, diff_b, id_b = self.quantize_b(quant_b)
        quant_b = quant_b.permute(0, 3, 1, 2)
        diff_b = diff_b.unsqueeze(0)

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b

    def decode(self, quant_t, quant_b):
        upsample_t = self.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)
        dec = self.dec(quant)

        return dec

    def decode_code(self, code_t, code_b):
        quant_t = self.quantize_t.embed_code(code_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        quant_b = self.quantize_b.embed_code(code_b)
        quant_b = quant_b.permute(0, 3, 1, 2)

        dec = self.decode(quant_t, quant_b)

        return dec

In [168]:
model = VQVAE().to('cuda:0')
summary(model, (3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           3,136
              ReLU-2         [-1, 64, 128, 128]               0
            Conv2d-3          [-1, 128, 64, 64]         131,200
              ReLU-4          [-1, 128, 64, 64]               0
            Conv2d-5          [-1, 128, 64, 64]         147,584
              ReLU-6          [-1, 128, 64, 64]               0
            Conv2d-7           [-1, 32, 64, 64]          36,896
              ReLU-8           [-1, 32, 64, 64]               0
            Conv2d-9          [-1, 128, 64, 64]           4,224
         ResBlock-10          [-1, 128, 64, 64]               0
             ReLU-11          [-1, 128, 64, 64]               0
           Conv2d-12           [-1, 32, 64, 64]          36,896
             ReLU-13           [-1, 32, 64, 64]               0
           Conv2d-14          [-1, 128,

TypeError: can't multiply sequence by non-int of type 'list'