In [1]:
import torch
from torch import optim, nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
class UpConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, is_last=False, use_bn=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity() 
        self.act = nn.Tanh() if is_last else nn.ReLU()       

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.act(out)

        return out

class Generator(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embed = nn.Embedding(10, 100)
        # self.fc = nn.Linear(100*2, 1024*4*4)
        self.first_conv = UpConvBlock(100*2, 1024, 4, 1, 0)
        self.conv1 = UpConvBlock(1024, 512, 4, 2, 1, use_bn=True)
        self.conv2 = UpConvBlock(512, 256, 4, 2, 1, use_bn=True)
        self.conv3 = UpConvBlock(256, 128, 4, 2, 1, use_bn=True)
        self.conv4 = UpConvBlock(128, 1, 4, 2, 1, use_bn=False, is_last=True)

    def forward(self, x, y):
        x = torch.cat([x, self.embed(y)], dim=1)

        # out = self.fc(x)
        # out = out.view(-1, 1024, 4, 4)

        x = x.unsqueeze(2).unsqueeze(3)
        out = self.first_conv(x)
        
        out = self.conv1(out)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)

        return out
    
x = torch.zeros(10, 100)
y = torch.arange(10)
model = Generator()
output = model(x, y)
output.shape

torch.Size([10, 1, 64, 64])

In [4]:
class DownConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_bn=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.InstanceNorm2d(out_channels) if use_bn else nn.Identity()
        self.act = nn.LeakyReLU(2e-1)
        
    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.act(out)
        
        return out
    
class Discriminator(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embed = nn.Embedding(10, 1*64*64)
        self.conv1 = DownConvBlock(1*2, 128, 4, 2, 1, use_bn=False)
        self.conv2 = DownConvBlock(128, 256, 4, 2, 1, use_bn=True)
        self.conv3 = DownConvBlock(256, 512, 4, 2, 1, use_bn=True)
        self.conv4 = DownConvBlock(512, 1024, 4, 2, 1, use_bn=True)
        self.last_conv = nn.Conv2d(1024, 1, 4, 2, 0)

    def forward(self, x, y):
        x = torch.cat([x, self.embed(y).view(-1, 1, 64, 64)], dim=1)

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)

        out = self.last_conv(out)
        out = out.view(-1, 1)

        return out
    
x = torch.zeros(10, 1, 64, 64)
y = torch.arange(10)
model = Discriminator()
output = model(x, y)
output.shape

torch.Size([10, 1])

In [5]:
def init_weights(model):
    for m in model.modules():
        if type(m) in {
            nn.Linear,
            nn.Conv2d,
            nn.ConvTranspose2d,
            nn.BatchNorm2d,
            nn.Embedding,
        }:
            nn.init.normal_(m.weight, mean=0, std=2e-2)

In [6]:
model_d = Discriminator().to(device)
model_g = Generator().to(device)
init_weights(model_d), init_weights(model_g)

lr = 1e-4
beta1 = 0.5
beta2 = 0.9
optimizer_d = optim.Adam(model_d.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_g = optim.Adam(model_g.parameters(), lr=lr, betas=(beta1, beta2))

img_size = 64
num_channels = 1
transform = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(num_channels)], [0.5 for _ in range(num_channels)]
        )
    ]
)
train_data = datasets.MNIST(root='mnist/train', train=True, transform=transform, download=False)
val_data = datasets.MNIST(root='mnist/val', train=False, transform=transform, download=False)

noise_dim = 100
num_classes = 10
fixed_noise = torch.randn((num_classes, noise_dim)).to(device)
fixed_label = torch.arange(num_classes).to(device)

writer_fake = SummaryWriter('logs/fake')
writer_real = SummaryWriter('logs/real')

In [7]:
class Trainer:
    def __init__(
            self,
            optimizer_d,
            optimizer_g,
            model_d,
            model_g,
            penalty_coeff,
            device = device
    ):
        self.optimizer_d = optimizer_d
        self.optimizer_g = optimizer_g
        self.model_d = model_d
        self.model_g = model_g
        self.penalty_coeff = penalty_coeff
        self.device = device


    def calc_grad_penalty(self, real, fake, label):
        batch_size, num_channels, height, width = real.shape
        epsilon = torch.rand((batch_size, 1, 1, 1)).repeat(1, num_channels, height, width).to(self.device)

        joint_distribution = epsilon*real + (1-epsilon)*fake
        critic_term = self.model_d(joint_distribution, label)

        gradient = torch.autograd.grad(
            outputs = critic_term,
            inputs = joint_distribution,
            grad_outputs = torch.ones_like(critic_term),
            retain_graph = True,
            create_graph = True,
        )[0].view(batch_size, -1)

        l2_norm = torch.norm(gradient, p=2, dim=1)
        grad_penalty = torch.mean((l2_norm - 1)**2)        

        return grad_penalty


    def calc_disc_loss(self, real, label, noise, is_train):
        with torch.set_grad_enabled(is_train):
            real_arg = self.model_d(real, label)

            fake = self.model_g(noise, label)
            fake_arg = self.model_d(fake, label)

            grad_penalty = self.calc_grad_penalty(real, fake, label)

            loss_d = (torch.mean(fake_arg) - torch.mean(real_arg)) + self.penalty_coeff*grad_penalty

        return loss_d


    def calc_gen_loss(self, noise, label, is_train):
        with torch.set_grad_enabled(is_train):
            fake = self.model_g(noise, label)
            fake_arg = self.model_d(fake, label)
            loss_g = - torch.mean(fake_arg)

        return loss_g


    def calc_metrics(self, metrics_dict, train_loader, val_loader):
        self.model_d.eval(), self.model_g.eval()

        with torch.no_grad():
            final_str = ''
            loaders_list = [('Train', train_loader), ('Val', val_loader)]
            
            if metrics_dict == None:
                metrics_dict = {'Train': {'DiscLoss': [], 'GenLoss': []}, 'Val': {'DiscLoss': [], 'GenLoss': []}}

            for name, loader in loaders_list:
                len_data = 0
                total_loss_d = 0
                total_loss_g = 0

                for real, label in loader:
                    real, label = real.to(self.device), label.to(self.device)

                    batch_size = real.shape[0]
                    len_data += batch_size

                    noise = torch.randn((batch_size, noise_dim)).to(self.device)

                    loss_d = self.calc_disc_loss(real, label, noise, is_train=False)
                    total_loss_d += loss_d

                    loss_g = self.calc_gen_loss(noise, label, is_train=False)
                    total_loss_g += loss_g
                
                disc_loss = total_loss_d/len_data
                gen_loss = total_loss_g/len_data

                final_str += ' -- {} Disc Loss: {:.5f} -- {} Gen Loss: {:.5f}'.format(name, disc_loss, name, gen_loss)
                
                metrics_dict[name]['DiscLoss'].append(disc_loss.item())
                metrics_dict[name]['GenLoss'].append(gen_loss.item()) 

        self.model_d.train(), self.model_g.train()

        return final_str, metrics_dict


    def visualize_tensorboard(self, fixed_noise, fixed_label, epoch):
        with torch.no_grad():
            fake = self.model_g(fixed_noise, fixed_label)
            fake_images = make_grid(fake, nrow=5, normalize=True)
            writer_fake.add_image('Fake', fake_images, global_step=epoch)

        return None
    

    def fit(self, n_epochs, n_disc, train_loader, val_loader):
        self.model_d.train(), self.model_g.train()

        metrics_dict = None
        steps = 1
        for epoch in range(1, n_epochs+1):
            for batch_idx, (real, label) in enumerate(train_loader):
                real, label = real.to(self.device), label.to(self.device)

                batch_size = real.shape[0]

                for _ in range(n_disc):
                    noise = torch.randn(batch_size, noise_dim).to(self.device)
                    loss_d = self.calc_disc_loss(real, label, noise, is_train=True)

                    self.optimizer_d.zero_grad()
                    loss_d.backward()
                    self.optimizer_d.step()
                
                noise = torch.randn(batch_size, noise_dim).to(self.device)
                loss_g = self.calc_gen_loss(noise, label, is_train=True)

                self.optimizer_g.zero_grad()
                loss_g.backward()
                self.optimizer_g.step()

                if batch_idx % 20 == 0:
                    print(f'Epoch: {epoch:2d}/{n_epochs} -- Batch: {batch_idx+1:3d}/{len(train_loader)}' + f' -- Train Disc Loss: {loss_d:.4f} -- Train Gen Loss: {loss_g:.4f}')
                    self.visualize_tensorboard(fixed_noise, fixed_label, steps)
                    steps += 1
            
            # if epoch == 1 or epoch%2 == 0:
            #     final_str, metrics_dict = self.calc_metrics(metrics_dict, train_loader, val_loader)
            #     print('Epoch: {:2d}'.format(epoch) + final_str)
            
        self.metrics_dict = metrics_dict

        return None

In [8]:
batch_size = 128
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [9]:
trainer = Trainer(optimizer_d, optimizer_g, model_d, model_g, penalty_coeff=10, device=device)
trainer.fit(2, 2, train_loader, val_loader)

Epoch:  1/2 -- Batch:   1/469 -- Train Disc Loss: 550.3231 -- Train Gen Loss: 2.2336
Epoch:  1/2 -- Batch:  21/469 -- Train Disc Loss: -49.2235 -- Train Gen Loss: 45.8081
Epoch:  1/2 -- Batch:  41/469 -- Train Disc Loss: -100.4171 -- Train Gen Loss: 75.9870
Epoch:  1/2 -- Batch:  61/469 -- Train Disc Loss: -104.1197 -- Train Gen Loss: 83.9225
Epoch:  1/2 -- Batch:  81/469 -- Train Disc Loss: -102.6078 -- Train Gen Loss: 87.2601
Epoch:  1/2 -- Batch: 101/469 -- Train Disc Loss: -98.0380 -- Train Gen Loss: 89.7384
Epoch:  1/2 -- Batch: 121/469 -- Train Disc Loss: -92.2195 -- Train Gen Loss: 87.9884
Epoch:  1/2 -- Batch: 141/469 -- Train Disc Loss: -88.5782 -- Train Gen Loss: 89.5317
