In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy import ndimage

In [4]:
class BasicBlock(nn.Module):
    """Two 3D convolutions that halve the last two spatial axes.

    Input is ``B x in_channels x H x L x W``; output is
    ``B x out_channels x H x L/2 x W/2``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by both convolutions.
        batch_norm: when true, a ``BatchNorm3d`` is applied to the input
            before the convolutions.
    """

    def __init__(self, in_channels, out_channels, batch_norm=False):
        super().__init__()

        self.batch_norm = nn.BatchNorm3d(in_channels) if batch_norm else None

        self.conv = nn.Sequential(                                                          # B I   H   L   W
            nn.Conv3d(in_channels, out_channels, 3, 1, padding=1),                          # B O   H   L   W
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(out_channels, out_channels, (3, 2, 2), (1, 2, 2), padding=(1, 0, 0)), # B O   H   L/2 W/2
            nn.LeakyReLU(0.2, inplace=True),
            # Dropout must NOT be in-place here: the in-place LeakyReLU above keeps
            # its output for the backward pass, and overwriting that tensor makes
            # autograd fail with "modified by an inplace operation".
            nn.Dropout3d(0.2),
        )

    def forward(self, x):
        """Map ``x`` (B I H L W) to a tensor of shape B O H L/2 W/2."""
        # Compare against None explicitly instead of relying on module truthiness.
        if self.batch_norm is not None:
            x = self.batch_norm(x)      # B I   H   L   W

        return self.conv(x)             # B O   H   L/2 W/2

In [5]:
class ResidualBlock(nn.Module):
    """Two stacked ``BasicBlock`` layers plus a pooled shortcut branch.

    Input is ``B x channels[0] x H x L x W``; both branches produce
    ``B x channels[2] x H x L/4 x W/4`` and are summed.

    Args:
        channels: sequence of exactly three ints - input, hidden and output
            channel counts.
        batch_norms: sequence of exactly two flags, one per ``BasicBlock``,
            enabling batch normalisation on that block's input.

    Raises:
        ValueError: if ``channels`` does not have 3 entries or ``batch_norms``
            does not have 2 entries.
    """

    def __init__(self, channels, batch_norms=(False, True)):
        super().__init__()

        # ValueError is a subclass of Exception, so callers catching the
        # previous bare Exception keep working.
        if len(channels) != 3:
            raise ValueError('You should pass 3 channels as in, hidden and out channels')

        if len(batch_norms) != 2:
            raise ValueError('You should pass 2 batch_norms for 2 basic layers')

        self.conv = nn.Sequential(                                  # B I   H   L   W
            BasicBlock(channels[0], channels[1], batch_norms[0]),   # B M   H   L/2 W/2
            BasicBlock(channels[1], channels[2], batch_norms[1]),   # B O   H   L/4 W/4
        )

        self.downsample = nn.Sequential(                            # B I   H   L   W
            nn.Conv3d(channels[0], channels[2], 3, 1, padding=1),   # B O   H   L   W
            nn.LeakyReLU(0.2, inplace=True),
            nn.AvgPool3d((3, 4, 4), (1, 4, 4), padding=(1, 0, 0)),  # B O   H   L/4 W/4
        )

    def forward(self, x):
        """Return the sum of the convolutional branch and the pooled shortcut."""
        # x     B   I   H   L   W

        out = self.conv(x)                  # B O   H   L/4 W/4
        downsampled = self.downsample(x)    # B O   H   L/4 W/4

        # Out-of-place add: the branch output may be a tensor that an in-place
        # activation saved for backward (dropout is the identity in eval mode),
        # so mutating it with `+=` would break gradient computation.
        return out + downsampled            # B O   H   L/4 W/4

In [33]:
class MaskDiscriminator(nn.Module):
    """Discriminator mapping a ``B x 2 x H x 64 x 64`` tensor to ``B`` probabilities.

    Three ``ResidualBlock`` stages shrink the last two axes from 64x64 to 1x1
    while growing the channels to 128. The resulting length-``H`` sequence of
    128-feature vectors is read by a two-layer LSTM, and the final hidden and
    cell states of both layers (4 x 128 = 512 features) feed an MLP that ends
    in a sigmoid.
    """

    def __init__(self):
        super().__init__()

        self.conv = nn.Sequential(                          # B 2   H   64  64
            ResidualBlock((2, 4, 8), (False, True)),        # B 8   H   16  16
            ResidualBlock((8, 16, 32), (False, True)),      # B 32  H   4   4
            ResidualBlock((32, 64, 128), (False, True)),    # B 128 H   1   1
        )

        # H B   128
        self.lstm = nn.LSTM(128, 128, num_layers=2)         # 2 B   128 + 2 B   128 =>  B   512

        self.decider = nn.Sequential(                       # B 512
            nn.Linear(512, 256),                            # B 256
            nn.LeakyReLU(0.2, inplace=True),
            # Dropout is out-of-place on purpose: the in-place LeakyReLU keeps its
            # output for backward, and overwriting it breaks autograd in training.
            nn.Dropout(0.5),
            nn.Linear(256, 128),                            # B 128
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, 1),                              # B 1
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return one probability per sample, shape ``(B,)``, for any B >= 1 and H >= 1."""
        # x     B   2   H   64  64

        x = self.conv(x)                    # B 128 H   1   1
        # Squeeze only the two trailing size-1 axes; an argument-less squeeze()
        # would also drop the batch axis when B == 1 (or H when H == 1).
        x = x.squeeze(-1).squeeze(-1)       # B 128 H
        x = x.permute(2, 0, 1)              # H B   128
        _, (h, c) = self.lstm(x)            # 2 B   128 + 2 B   128
        out = torch.cat((h, c))             # 4 B   128
        out = out.permute(1, 0, 2)          # B 4   128
        out = out.reshape(-1, 512)          # B 512
        out = self.decider(out)             # B 1
        out = out.squeeze(-1)               # B  (keeps the batch axis even when B == 1)

        return out

In [39]:
def get_batches(X, y=None, batch_size=128, shuffle=True):
    """Generate consecutive mini-batches over the first axis of ``X``.

    Args:
        X: indexable array exposing ``shape``; batches are slices along axis 0.
        y: optional array with the same first-axis length as ``X``.
        batch_size: maximum number of rows per batch; the last batch may be
            shorter.
        shuffle: when true, rows are reordered with ``np.random.permutation``
            (the same order is applied to ``y``) before batching.

    Yields:
        ``(batch_index, X_batch, y_batch)`` when ``y`` is given, otherwise
        ``(batch_index, X_batch)``.
    """
    total = X.shape[0]
    with_targets = y is not None

    if with_targets:
        assert total == y.shape[0]

    num_batches = int(np.ceil(total * 1.0 / batch_size))

    if shuffle:
        order = np.random.permutation(total)
        X = X[order]
        if with_targets:
            y = y[order]

    for batch in range(num_batches):
        lo = batch * batch_size
        hi = min(lo + batch_size, total)
        if with_targets:
            yield (batch, X[lo:hi], y[lo:hi])
        else:
            yield (batch, X[lo:hi])

In [51]:
def train_discriminator(model, optimizer, X, y, batch_size):
    """Run one training epoch of ``model`` on ``(X, y)`` with binary cross-entropy.

    Args:
        model: module whose output for a batch is compared against the targets
            with ``F.binary_cross_entropy``.
        optimizer: optimizer stepping ``model``'s parameters.
        X: inputs, batched along the first axis (B 2 H 64 64 per batch).
        y: targets, one per input (B per batch).
        batch_size: number of samples per optimisation step.

    Returns:
        The mean of the per-batch losses as a float, or ``nan`` when ``X`` is
        empty and no batch was processed.
    """
    # Follow the device the model lives on instead of assuming CUDA; a model
    # without parameters falls back to the previous hard-coded default.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    model.train()

    # Drop gradients left over from earlier backward passes so they cannot
    # leak into the first optimizer step of this epoch.
    optimizer.zero_grad()

    epoch_loss = 0.0
    num_batches = 0

    for step, b_X, b_y in get_batches(X, y, batch_size=batch_size):
        # b_X   B   2   H   64  64
        # b_y   B
        images = torch.tensor(b_X, device=device)
        target = torch.tensor(b_y, device=device)

        prediction = model(images)

        loss = F.binary_cross_entropy(prediction, target)

        epoch_loss += float(loss)
        num_batches = step + 1

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        if num_batches % 2 == 0:
            print(f'[Train] Iteration {num_batches:3d} - loss: {epoch_loss / num_batches:.2e}')

    if num_batches == 0:
        # No data: report "undefined" rather than crash or pretend a perfect loss.
        return float('nan')

    return epoch_loss / num_batches

In [52]:
def evaluate_discriminator(model, X, y, batch_size):
    """Compute the mean binary cross-entropy of ``model`` on ``(X, y)``.

    The model is switched to eval mode and no gradients are recorded. Batches
    are taken in order, without shuffling.

    Args:
        model: module whose output for a batch is compared against the targets
            with ``F.binary_cross_entropy``.
        X: inputs, batched along the first axis.
        y: targets, one per input.
        batch_size: number of samples per forward pass.

    Returns:
        The mean of the per-batch losses as a float, or ``nan`` when ``X`` is
        empty and no batch was processed.
    """
    # Follow the device the model lives on instead of assuming CUDA; a model
    # without parameters falls back to the previous hard-coded default.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    epoch_loss = 0.0
    num_batches = 0

    model.eval()

    with torch.no_grad():
        # get_batches yields (index, inputs, targets) when y is supplied, so all
        # three values have to be unpacked here.
        for step, b_X, b_y in get_batches(X, y, batch_size=batch_size, shuffle=False):
            images = torch.tensor(b_X, device=device)
            target = torch.tensor(b_y, device=device)

            prediction = model(images)

            loss = F.binary_cross_entropy(prediction, target)

            epoch_loss += float(loss)
            num_batches = step + 1

            if num_batches % 2 == 0:
                print(f'[Valid] Iteration {num_batches:3d} - loss: {epoch_loss / num_batches:.2e}')

    if num_batches == 0:
        # No data: report "undefined" rather than crash or pretend a perfect loss.
        return float('nan')

    return epoch_loss / num_batches