## Model. Resnet18

In [24]:
import torch.nn as nn
import time

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes, grayscale):
        self.inplanes = 64
        if grayscale:
            in_dim = 1
        else:
            in_dim = 3
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2. / n)**.5)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # because MNIST is already 1x1 here:
        # disable avg pooling
        #x = self.avgpool(x)
        
        x = x.view(x.size(0), -1)
        logits = self.fc(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas
    
    def get_representation(self, x):
#         self.eval()
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
#         x = self.layer4(x)
        x = x.view(x.size(0), -1)
        return x



def resnet18(num_classes):
    """Constructs a ResNet-18 model."""
    model = ResNet(block=BasicBlock, 
                   layers=[2, 2, 2, 2],
                   num_classes=num_classes,
                   grayscale=True)
    return model


In [8]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
import numpy as np

train_dataset = MNIST(root='MNIST', download=True, train=True, transform=transforms.Compose([transforms.ToTensor()]))
val_dataset = MNIST(root='MNIST', download=True, train=False, transform=transforms.Compose([transforms.ToTensor()]))

batch_size = 64
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)



In [9]:
def compute_accuracy(model, data_loader, device):
    correct_pred, num_examples = 0, 0
    for i, (features, targets) in enumerate(data_loader):
            
        features = features.to(device)
        targets = targets.to(device)

        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100
    
def train(model, optimizer, train_loader,  n_epochs=10):
    start_time = time.time()
    for epoch in range(n_epochs):

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            features = features.to(device)
            targets = targets.to(device)

            ### FORWARD AND BACK PROP
            logits, probas = model(features)
            cost = F.cross_entropy(logits, targets)
            optimizer.zero_grad()

            cost.backward()

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 50:
                print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' 
                       %(epoch+1, n_epochs, batch_idx, 
                         len(train_loader), cost))

        model.eval()
        with torch.set_grad_enabled(False): # save memory during inference
            print('Epoch: %03d/%03d | Train: %.3f%%' % (
                  epoch+1, n_epochs, 
                  compute_accuracy(model, train_loader, device=device)))

        print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

    print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))


In [27]:
device = 'cuda'
model = resnet18(10)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train(model, optimizer, train_dataloader, n_epochs=2)

with torch.set_grad_enabled(False): # save memory during inference
    print('Test accuracy: %.2f%%' % (compute_accuracy(model, val_dataloader, device=device)))
    

Epoch: 001/002 | Batch 0000/0938 | Cost: 2.4674
Epoch: 001/002 | Batch 0050/0938 | Cost: 0.2613
Epoch: 001/002 | Batch 0100/0938 | Cost: 0.2263
Epoch: 001/002 | Batch 0150/0938 | Cost: 0.1247
Epoch: 001/002 | Batch 0200/0938 | Cost: 0.1589
Epoch: 001/002 | Batch 0250/0938 | Cost: 0.0675
Epoch: 001/002 | Batch 0300/0938 | Cost: 0.0403
Epoch: 001/002 | Batch 0350/0938 | Cost: 0.1786
Epoch: 001/002 | Batch 0400/0938 | Cost: 0.1176
Epoch: 001/002 | Batch 0450/0938 | Cost: 0.0367
Epoch: 001/002 | Batch 0500/0938 | Cost: 0.0800
Epoch: 001/002 | Batch 0550/0938 | Cost: 0.0328
Epoch: 001/002 | Batch 0600/0938 | Cost: 0.1290
Epoch: 001/002 | Batch 0650/0938 | Cost: 0.0583
Epoch: 001/002 | Batch 0700/0938 | Cost: 0.1797
Epoch: 001/002 | Batch 0750/0938 | Cost: 0.1221
Epoch: 001/002 | Batch 0800/0938 | Cost: 0.0414
Epoch: 001/002 | Batch 0850/0938 | Cost: 0.0158
Epoch: 001/002 | Batch 0900/0938 | Cost: 0.1115
Epoch: 001/002 | Train: 98.133%
Time elapsed: 0.29 min
Epoch: 002/002 | Batch 0000/0938 

In [28]:
torch.save(model.state_dict(), 'resnet18_fid_mnist.pth')

In [25]:
import torch.nn.functional as F
from models import UNet
device = 'cuda'
model = UNet(1, 32, (1, 2, 4), time_emb_dim=16)
T = 1000
from functools import partial
from schedule import noising_sch
import torch.nn as nn

noising_sch = partial(noising_sch, T=T)

from beta import KL, sufficient_stats, alpha_beta, get_dist

class BetaUnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = UNet(1, 32, (1, 2, 4), time_emb_dim=16)

    def forward(self, suff_stats, t):
        return F.sigmoid(self.model(suff_stats, t))
    

    def generate_alpha_normed(model, filename='beta_ssddpm_samples.png'):
        model.eval()
        x_t = torch.rand([batch_size, 1, 28, 28]).to(device)
        suff_stats = sufficient_stats(x_t, torch.tensor([T], device=device).repeat(batch_size))
        alphas = noising_sch(torch.tensor([T], device=device).repeat(batch_size))

        suff_stats_normed = suff_stats / alphas.reshape([-1, 1, 1, 1]).repeat([1, 1, 28, 28])

        to_plot = []

        samples_history = []

        for t in range(T, 1, -1):
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x_0 = model(suff_stats_normed, t_batch)

            mu = noising_sch(t_batch)
            alphas += mu

            dist = get_dist(mu, x_0)
            x_t = dist.sample()
            samples_history.append(x_t)

            alphas += mu

            suff_stats += sufficient_stats(x_t, torch.tensor([t - 1], device=device).repeat(batch_size))

            suff_stats_normed = suff_stats / alphas.reshape([-1, 1, 1, 1]).repeat([1, 1, 28, 28])
            to_plot.append(suff_stats_normed[:4, :, 10, 10].squeeze())

        t_batch = torch.tensor([1], device=device).repeat(batch_size)
        x_0 = model(suff_stats_normed, t_batch)

        return x_0

gen_model = BetaUnet().to(device)
gen_model.load_state_dict(torch.load('beta_ddpm_the_best.pth'))


samples = gen_model.generate_alpha_normed()

In [26]:
from tqdm import tqdm
from scipy.linalg import sqrtm

def compute_fid_mnist(dataloader, gen_model, n_fake_samples=10000):
    device = 'cuda'
    model = resnet18(10)
    model.load_state_dict(torch.load('resnet18_fid_mnist.pth'))
    
    model.to(device)
    model.eval()
    
    true_dist = []
        
    for batch, _ in tqdm(iter(dataloader)):
        
        batch = batch.to(device)
        
        with torch.no_grad():
            representation = model.get_representation(batch).cpu()
            true_dist.append(representation)
        
    true_dist = torch.cat(true_dist, dim=0)
    mu_true = torch.mean(true_dist, dim=0)
    cov_true = torch.cov(true_dist.T)
    
    batch_size = 256
    gen_model.eval()
    
    fake_dist = []
    
    for i in tqdm(range(n_fake_samples // batch_size)):
        
        with torch.no_grad():
            batch = gen_model.generate_alpha_normed()
            representation = model.get_representation(batch).cpu()
            fake_dist.append(representation)
    
    fake_dist = torch.cat(fake_dist, dim=0)
    
    mu_fake = torch.mean(fake_dist, dim=0)
    cov_fake = torch.cov(fake_dist.T)
    
    conv_sqrtm = 2 * sqrtm(cov_true @ cov_fake)
    
    fid = torch.norm(mu_true - mu_fake) ** 2 + torch.trace(cov_true + cov_fake - conv_sqrtm)
    
    return fid

compute_fid_mnist(train_dataloader, gen_model)

100%|██████████████████████████████████████████████████████████████████████████| 938/938 [00:06<00:00, 152.45it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 39/39 [06:04<00:00,  9.36s/it]


tensor(71.3666-0.1368j)