In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.autograd import grad
from torch import nn, optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.utils import make_grid
from torchvision import transforms as torch_transforms

In [2]:
import sys
sys.path.append("./../")

from modules.dvae.model import DVAE
from modules.common_blocks import ResidualStack

In [None]:
# This function computes the accuracy on the test dataset
def compute_accuracy(dvae, clf, testloader, device):
    clf.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            latent = dvae.q_encode(images, hard=True)
            outputs = clf(latent)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [3]:
class CNNClassifier(nn.Module):
    def __init__(self,
                 n_classes,
                 embedding_dim,
                 num_blocks):
        super(CNNClassifier, self).__init__()
        
        self.resid = ResidualStack(
            in_channels=embedding_dim, 
            out_channels=embedding_dim, 
            num_residual_layers=num_blocks, 
            bias=True, 
            use_bn=True, 
            final_relu=False)
        
        channels_dims = [
            embedding_dim // 2,
            embedding_dim // 4,
        ]
        
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=embedding_dim,
                      out_channels=channels_dims[0],
                      kernel_size=3, padding=0),
            nn.BatchNorm2d(num_features=channels_dims[0]),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels_dims[0], 
                      out_channels=channels_dims[1], 
                      kernel_size=3, padding=0),
            nn.BatchNorm2d(num_features=channels_dims[1]),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels_dims[1], 
                      out_channels=n_classes, 
                      kernel_size=3, padding=0),
        )
        
        self.fc = nn.Linear(in_features=n_classes, out_features=n_classes)

    def forward(self, x):
        x = self.resid(x)
        x = self.conv(x)
        x = self.fc(x.squeeze())
        return x

In [35]:
class TrEncoderBlock(nn.Module):
    def __init__(self, n_features, n_attn_heads, n_hidden=64, dropout_prob=0.1):
        super(TrEncoderBlock, self).__init__()

        self.attn = nn.MultiheadAttention(n_features, n_attn_heads)
        self.ln1 = nn.LayerNorm(n_features)
        self.dropout1 = nn.Dropout(dropout_prob)

        self.mlp = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Dropout(dropout_prob),
            nn.GELU(),
            nn.Linear(n_hidden, n_features)
        )
        self.ln2 = nn.LayerNorm(n_features)
        self.dropout2 = nn.Dropout(dropout_prob)

    def forward(self, x, pad_mask=None, attn_mask=None):
        xn = self.ln1(x)
        dx, _ = self.attn(query=xn, key=xn, value=xn, 
                          key_padding_mask=pad_mask, 
                          attn_mask=attn_mask)
        x = x + self.dropout1(dx)
        
        xn = self.ln2(x)
        dx = self.mlp(xn)
        x = x + self.dropout2(dx)
        
        return x


class ViTClassifier(nn.Module):
    def __init__(self,
                 n_classes,
                 embedding_dim,
                 hidden_height,
                 hidden_width,
                 num_blocks,
                 n_attn_heads,
                 hidden_dim,
                 dropout_prob):
        super(ViTClassifier, self).__init__()

        
        num_latent_positions = hidden_height * hidden_width + 1
        self.pe = nn.Parameter(torch.randn(1, num_latent_positions, embedding_dim))
        
        self.lin_proj = nn.Linear(embedding_dim, embedding_dim)
        
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        self.tr_encoder_blocks = nn.ModuleList([
            TrEncoderBlock(n_features=embedding_dim,
                           n_attn_heads=n_attn_heads,
                           n_hidden=hidden_dim,
                           dropout_prob=dropout_prob)
            for _ in range(num_blocks)
        ])

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, n_classes),
        )

    def forward(self, img_latent):
        b, c, h, w = img_latent.size()
        x = img_latent.view(b, c, h * w).permute(0, 2, 1)  # -> b, h*w, c

        x = self.lin_proj(x)
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pe

        x = x.permute(1, 0, 2)  # -> h*w, b, c

        for i, block in enumerate(self.tr_encoder_blocks):
            x = block(x)

        #cls_input = x.mean(dim=0)
        cls_input = x[0, :, :]
            
        cls = self.mlp_head(cls_input).squeeze()

        return cls

In [36]:
class Config:
    DEVICE                      = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img_channels                = 1
    vocab_size                  = 128

    noise_dim                   = 100
    hidden_height               = 7
    hidden_width                = 7

    num_blocks                  = 8
    n_attn_heads                = 8
    hidden_dim                  = 256
    dropout_prob                = 0.1

    dvae_num_x2upsamples        = 2
    dvae_num_resids_downsample  = 3
    dvae_num_resids_bottleneck  = 4
    dvae_hidden_dim             = 256

    load_dvae_path              = "/m/home/home8/82/sukhoba1/data/Desktop/TA-VQVAE/models/dvae_M_mnist/"
    dvae_model_name             = "dvae_M_mnist"
    data_path                   = "/m/home/home8/82/sukhoba1/data/Desktop/TA-VQVAE/data/MNIST/"

    NUM_EPOCHS                  = 10
    BATCH_SIZE                  = 512
    LR                          = 0.001
    LR_gamma                    = 0.1
    step_LR_milestones          = [90]


CONFIG = Config()

In [37]:
data_transforms = torch_transforms.Compose([
    torch_transforms.RandomRotation(10),
    torch_transforms.ToTensor()
])

trainset = datasets.MNIST(
    CONFIG.data_path,
    train=True,
    download=False,
    transform=data_transforms)

train_loader = DataLoader(
    trainset,
    batch_size=CONFIG.BATCH_SIZE,
    shuffle=True)


testset = datasets.MNIST(
    CONFIG.data_path,
    train=False,
    download=False,
    transform=data_transforms)

test_loader = DataLoader(
    testset,
    batch_size=CONFIG.BATCH_SIZE,
    shuffle=True)

In [39]:
dvae = DVAE(
    in_channels=CONFIG.img_channels,
    vocab_size=CONFIG.vocab_size,
    num_x2downsamples=CONFIG.dvae_num_x2upsamples,
    num_resids_downsample=CONFIG.dvae_num_resids_downsample,
    num_resids_bottleneck=CONFIG.dvae_num_resids_bottleneck,
    hidden_dim=CONFIG.dvae_hidden_dim)

dvae.eval()
dvae.load_model(
    root_path=CONFIG.load_dvae_path,
    model_name=CONFIG.dvae_model_name)
dvae.to(CONFIG.DEVICE)


# clf = CNNClassifier(
#     n_classes=10,
#     embedding_dim=CONFIG.vocab_size,
#     num_blocks=CONFIG.num_blocks)

clf = ViTClassifier(
    n_classes=10,
    embedding_dim=CONFIG.vocab_size,
    hidden_height=CONFIG.hidden_height,
    hidden_width=CONFIG.hidden_width,
    num_blocks=CONFIG.num_blocks,
    n_attn_heads=CONFIG.n_attn_heads,
    hidden_dim=CONFIG.hidden_dim,
    dropout_prob=CONFIG.dropout_prob)

clf.train()
clf.to(CONFIG.DEVICE)

pass


In [43]:
print("Device in use: {}".format(CONFIG.DEVICE))

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(clf.parameters(), lr=CONFIG.LR)

iteration = 0
for epoch in range(CONFIG.NUM_EPOCHS):
    for x, label in train_loader:
        
        label = label.to(CONFIG.DEVICE)
        x = x.to(CONFIG.DEVICE)
        
        with torch.no_grad():
            latent = dvae.q_encode(x, hard=True)
            #latent = dvae.sm_encode(x)
        
        optimizer.zero_grad()
        pred = clf(latent)
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()

        iteration += 1
        
        if iteration % 55 == 0:
            acc = compute_accuracy(dvae, clf, test_loader, device=CONFIG.DEVICE)
            print("Epoch: {} Iter: {} Loss: {} Test Accuracy: {}".format(
                epoch, iteration, loss.item(), acc))

Device in use: cuda
Epoch: 0 Iter: 55 Loss: 2.302194833755493 Test Accuracy: 0.1016
Epoch: 0 Iter: 110 Loss: 1.7957141399383545 Test Accuracy: 0.3501
Epoch: 1 Iter: 165 Loss: 0.8636090755462646 Test Accuracy: 0.6912
Epoch: 1 Iter: 220 Loss: 0.46149003505706787 Test Accuracy: 0.862
Epoch: 2 Iter: 275 Loss: 0.2852141261100769 Test Accuracy: 0.9076
Epoch: 2 Iter: 330 Loss: 0.21170789003372192 Test Accuracy: 0.9173
Epoch: 3 Iter: 385 Loss: 0.24467302858829498 Test Accuracy: 0.9371
Epoch: 3 Iter: 440 Loss: 0.14906670153141022 Test Accuracy: 0.9388
Epoch: 4 Iter: 495 Loss: 0.2537703514099121 Test Accuracy: 0.9465
Epoch: 4 Iter: 550 Loss: 0.14439043402671814 Test Accuracy: 0.9453
Epoch: 5 Iter: 605 Loss: 0.1485423892736435 Test Accuracy: 0.9463
Epoch: 5 Iter: 660 Loss: 0.1970188170671463 Test Accuracy: 0.9526
Epoch: 6 Iter: 715 Loss: 0.0956149473786354 Test Accuracy: 0.9523
Epoch: 6 Iter: 770 Loss: 0.12703555822372437 Test Accuracy: 0.9584
Epoch: 6 Iter: 825 Loss: 0.1273706555366516 Test Accu