In [1]:
import os
import time

import numpy as np
import cv2
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from rich.progress import (
    Progress,
    TextColumn,
    BarColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)

In [2]:
# Single seed shared by every random number generator seeded below, so that
# repeated runs of the notebook start from the same state.
myseed = 6666

# Seed NumPy and PyTorch (CPU, plus every visible GPU when CUDA is present).
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

# Disable cuDNN autotuning and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
def _canny_edges(image):
    """Return the Canny edge map (thresholds 170 / 300) of a grayscale image.

    cv2 works on NumPy arrays rather than PIL images, so the image is
    converted with ``np.array`` before edge detection.
    """
    return cv2.Canny(np.array(image), 170, 300)


# Source-domain pipeline, applied in order:
#   1. Grayscale            - Canny does not accept RGB input.
#   2. Canny edge detection - see `_canny_edges`.
#   3. ToPILImage           - back to a PIL image for the torchvision augmentations.
#   4. RandomHorizontalFlip - 50% horizontal flip (augmentation).
#   5. RandomRotation       - rotate by up to +-15 degrees (augmentation); pixels
#                             uncovered by the rotation are filled with zero.
#   6. ToTensor             - tensor for the model input.
_source_steps = [
    transforms.Grayscale(),
    transforms.Lambda(_canny_edges),
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15, fill=(0,)),
    transforms.ToTensor(),
]
source_transform = transforms.Compose(_source_steps)
# Target-domain pipeline, applied in order:
#   1. Grayscale            - turn RGB into a single-channel image.
#   2. Resize to 32x32      - the source data is 32x32, so the 28x28 target
#                             images are enlarged to match.
#   3. RandomHorizontalFlip - 50% horizontal flip (augmentation).
#   4. RandomRotation       - rotate by up to +-15 degrees (augmentation); pixels
#                             uncovered by the rotation are filled with zero.
#   5. ToTensor             - tensor for the model input.
_target_steps = [
    transforms.Grayscale(),
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15, fill=(0,)),
    transforms.ToTensor(),
]
target_transform = transforms.Compose(_target_steps)

# Training data: source-domain images (their labels are used for the
# classification loss) and target-domain images (their labels are ignored
# during training).
source_dataset = ImageFolder("real_or_drawing/train_data", transform=source_transform)
target_dataset = ImageFolder("real_or_drawing/test_data", transform=target_transform)

# Inference data: the same target images, but WITHOUT the random flip/rotation
# used for training. Re-using `target_dataset` for prediction would randomly
# perturb every test image, making the submission noisy and non-reproducible.
test_transform = transforms.Compose(
    [
        # Turn RGB to grayscale.
        transforms.Grayscale(),
        # Match the 32x32 input size the model is trained on.
        transforms.Resize((32, 32)),
        # Transform to tensor for model inputs.
        transforms.ToTensor(),
    ]
)
test_dataset = ImageFolder("real_or_drawing/test_data", transform=test_transform)

# Shuffled loaders for training; the test loader keeps dataset order so that
# prediction i corresponds to sample i of `test_dataset`.
source_dataloader = DataLoader(source_dataset, batch_size=20, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=20, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)


## Models

In [4]:
# grl
from torch.autograd import Function


class GradientReverseFunction(Function):
    """Autograd function that is the identity going forward and reverses
    (and scales) the gradient going backward.

    ``GradientReverseFunction.apply(x, coeff)`` returns a tensor equal to ``x``;
    during back-propagation the incoming gradient is multiplied by ``-coeff``.
    """

    @staticmethod
    def forward(ctx, x, coeff=1.0):
        """Return a copy of ``x`` and remember ``coeff`` for the backward pass.

        Args:
            ctx: autograd context used to carry ``coeff`` to ``backward``.
            x: input tensor.
            coeff: scale applied to the reversed gradient. Defaults to 1.0 so
                the function can be applied to a tensor alone (as
                ``GradientReverseLayer`` does when called with one argument).
        """
        ctx.coeff = coeff
        # Multiply by 1.0 so the result is a new tensor rather than `x` itself.
        output = x * 1.0
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-coeff * grad_output`` for ``x`` and ``None`` for ``coeff``."""
        return grad_output.neg() * ctx.coeff, None


class GradientReverseLayer(nn.Module):
    """Module wrapper around ``GradientReverseFunction``.

    The layer holds no parameters or state of its own; every positional
    argument given to the layer is passed straight to
    ``GradientReverseFunction.apply``.
    """

    def forward(self, *args):
        """Apply ``GradientReverseFunction`` to the given positional arguments."""
        return GradientReverseFunction.apply(*args)


class WarmStartGradientReverseLayer(nn.Module):
    """Gradient-reversal layer whose reversal strength is scheduled by an
    internal iteration counter.

    The coefficient handed to ``GradientReverseFunction`` is

        2 * (hi - lo) / (1 + exp(-alpha * iter_num / max_iters)) - (hi - lo) + lo

    which equals ``lo`` at ``iter_num == 0``.

    Args:
        alpha: steepness of the schedule.
        lo: coefficient value at iteration 0.
        hi: upper value used in the schedule formula.
        max_iters: divisor applied to the iteration counter.
        auto_step: when True, every ``forward`` call advances the counter.
    """

    def __init__(self, alpha=1.0, lo=0.0, hi=1.0, max_iters=1000.0, auto_step=True):
        super().__init__()
        self.alpha = alpha
        self.lo = lo
        self.hi = hi
        self.max_iters = max_iters
        self.auto_step = auto_step
        # Number of `step()` calls so far; drives the schedule.
        self.iter_num = 0

    def forward(self, x):
        """Reverse the gradient of ``x`` using the current scheduled coefficient."""
        span = self.hi - self.lo
        progress = self.alpha * self.iter_num / self.max_iters
        coeff = np.float64(2.0 * span / (1.0 + np.exp(-progress)) - span + self.lo)

        # The coefficient is computed before the counter moves, so the first
        # call uses iter_num == 0.
        if self.auto_step:
            self.step()

        return GradientReverseFunction.apply(x, coeff)

    def step(self):
        """Advance the schedule by one iteration."""
        self.iter_num += 1


In [5]:
# nwd
class NuclearWassersteinDiscrepancy(nn.Module):
    """Nuclear-norm discrepancy between source and target predictions.

    Features pass through a warm-start gradient-reversal layer and then through
    ``classifier``; the resulting logits are split in half along the batch
    dimension (first half = source, second half = target) and compared via the
    nuclear norm of their softmax matrices.

    Args:
        classifier: module mapping flattened features ``(N, D)`` to logits.
    """

    def __init__(self, classifier):
        super(NuclearWassersteinDiscrepancy, self).__init__()
        self.grl = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=1.0, max_iters=1000.0, auto_step=True)
        self.classifier = classifier

    @staticmethod
    def n_discrepancy(y_s, y_t):
        """Return ``(||softmax(y_s)||_* - ||softmax(y_t)||_*) / len(y_t)``.

        Args:
            y_s: source logits, shape ``(N_s, C)``.
            y_t: target logits, shape ``(N_t, C)``.
        """
        pre_s, pre_t = F.softmax(y_s, dim=1), F.softmax(y_t, dim=1)
        loss = (-torch.norm(pre_t, 'nuc') + torch.norm(pre_s, 'nuc')) / y_t.shape[0]
        return loss

    def forward(self, f):
        """Compute the discrepancy for a batch of stacked features.

        Args:
            f: features with the source samples stacked before the target
                samples along dim 0. The two halves must be the same size,
                because the logits are split with ``chunk(2, dim=0)``.

        Returns:
            Scalar tensor with the discrepancy value.
        """
        f_grl = self.grl(f)
        # flatten(1) keeps the batch dimension and collapses everything else.
        # A bare squeeze() would also drop the batch or channel dimension
        # whenever either happens to have size 1.
        y = self.classifier(f_grl.flatten(1))
        y_s, y_t = y.chunk(2, dim=0)

        loss = self.n_discrepancy(y_s, y_t)
        return loss


In [6]:
class Classifier(nn.Module):
    """CNN feature extractor followed by a 10-way fully connected head.

    The extractor is five ``Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2)`` stages
    with 64, 128, 256, 256 and 512 output channels. Each stage halves the
    spatial size, so a 1x32x32 input yields a ``(N, 512, 1, 1)`` feature map,
    which is what the 512-input ``fc`` head expects.
    """

    def __init__(self):
        super(Classifier, self).__init__()

        # (in_channels, out_channels) of each convolutional stage. Layers are
        # appended in the same order as a hand-written Sequential, so the
        # state_dict keys (feature_extractor.0, .1, .4, ...) are unchanged.
        stage_channels = [(1, 64), (64, 128), (128, 256), (256, 256), (256, 512)]
        layers = []
        for in_channels, out_channels in stage_channels:
            layers += [
                nn.Conv2d(in_channels, out_channels, 3, 1, 1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
        self.feature_extractor = nn.Sequential(*layers)

        self.fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return ``(logits, features)`` for a batch of images.

        Args:
            x: image batch of shape ``(N, 1, H, W)``; the ``fc`` head requires
                the extractor output to flatten to 512 values per sample
                (true for 32x32 inputs).

        Returns:
            y: class logits, shape ``(N, 10)``.
            f: un-flattened extractor output, shape ``(N, 512, H/32, W/32)``.
        """
        f = self.feature_extractor(x)
        # flatten(1) always keeps the batch dimension. squeeze() would drop it
        # for a batch of one and hand 1-D logits to the caller.
        y = self.fc(f.flatten(1))

        return y, f


## Model, loss, and optimizer setup

In [7]:
# Network under training, placed on the GPU.
classifier = Classifier()
classifier = classifier.cuda()

# The discrepancy module scores features with the classifier's own `fc` head,
# so it shares those weights instead of owning a separate copy.
nwd = NuclearWassersteinDiscrepancy(classifier.fc)
nwd = nwd.cuda()

# Classification loss applied to the source-domain logits.
criterion = nn.CrossEntropyLoss()

# SGD over the classifier's parameters.
optimizer = optim.SGD(
    classifier.parameters(),
    lr=0.005,
    momentum=0.9,
    weight_decay=1e-3,
)


## Start training

In [8]:
def train_epoch(source_dataloader, target_dataloader, progress, lamb):
    """Run one epoch of joint classification + discrepancy training.

    Uses the module-level ``classifier``, ``nwd``, ``criterion`` and
    ``optimizer``. Source and target batches are drawn in lock-step; the epoch
    ends when the shorter of the two loaders is exhausted.

    Args:
        source_dataloader: yields ``(images, labels)`` from the source domain.
        target_dataloader: yields ``(images, labels)`` from the target domain;
            the labels are ignored.
        progress: progress-bar object providing ``add_task``, ``advance`` and
            ``remove_task``.
        lamb: weight of the discrepancy term in the total loss.

    Returns:
        Mean of ``cls_loss + lamb * nwd_loss`` over the processed batches.

    Raises:
        ValueError: if either dataloader has no batches.
    """
    # zip() below stops at the shorter loader, so that is the real epoch length.
    num_batches = min(len(source_dataloader), len(target_dataloader))
    if num_batches == 0:
        raise ValueError("train_epoch needs at least one batch from both dataloaders")

    batch_tqdm = progress.add_task(description="batch_progress", total=num_batches)

    # Make sure BatchNorm is in training mode even if the model was switched
    # to eval() earlier (e.g. the inference cell was run before this one).
    classifier.train()

    total_loss = 0.0
    batches_seen = 0

    for (source_data, source_label), (target_data, _) in zip(source_dataloader, target_dataloader):
        # `nwd` splits the stacked batch exactly in half, so both domains must
        # contribute the same number of samples; trim the larger one if a
        # final partial batch makes them differ.
        pair_size = min(source_data.shape[0], target_data.shape[0])
        source_data = source_data[:pair_size].cuda()
        source_label = source_label[:pair_size].cuda()
        target_data = target_data[:pair_size].cuda()

        # Source samples first, target samples second.
        mixed_data = torch.cat([source_data, target_data], dim=0)

        # Clear gradients before the forward/backward pass so stale gradients
        # can never leak into this step.
        optimizer.zero_grad()

        y, f = classifier(mixed_data)
        y_s = y[:pair_size, :]

        cls_loss = criterion(y_s, source_label)
        nwd_loss = -nwd(f)
        loss = cls_loss + lamb * nwd_loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        batches_seen += 1
        progress.advance(batch_tqdm, advance=1)

    progress.remove_task(batch_tqdm)
    return total_loss / batches_seen


num_epochs = 500
best_loss = 1e9

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
checkpoint_path = "weights/DALN.bin"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

with Progress(
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
    TimeRemainingColumn(),
    TimeElapsedColumn(),
) as progress:
    epoch_tqdm = progress.add_task(description="epoch progress", total=num_epochs)
    for epoch in range(num_epochs):
        # Weight of the discrepancy term: 0 at epoch 0, rising along a sigmoid
        # towards 1 as training progresses.
        lamb = (2 / (1 + np.exp(-10 * (epoch) / num_epochs))) - 1
        train_loss = train_epoch(source_dataloader, target_dataloader, progress, lamb=lamb)

        progress.advance(epoch_tqdm, advance=1)

        # NOTE(review): the checkpoint is selected by training loss, which
        # includes the lamb-weighted discrepancy term; since lamb changes every
        # epoch, losses from different epochs are not strictly comparable.
        if train_loss < best_loss:
            best_loss = train_loss
            torch.save(classifier.state_dict(), checkpoint_path)
            print("epoch {:>3d}: train loss: {:6.4f}".format(epoch+1, train_loss))


Output()

## Inference

In [9]:
# Per-batch arrays of predicted class indices, concatenated after the loop.
result = []
classifier.eval()

# NOTE(review): predictions use the weights currently held in `classifier`
# (the state after the last training epoch), not the checkpoint saved to
# disk during training. Confirm this is intended.
with Progress(
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
    TimeRemainingColumn(),
    TimeElapsedColumn(),
) as progress:
    test_tqdm = progress.add_task(description="inference progress", total=len(test_dataloader))
    # No gradients are needed for prediction; disabling autograd avoids
    # building a graph for every batch and saves memory.
    with torch.no_grad():
        for test_data, _ in test_dataloader:
            test_data = test_data.cuda()

            class_logits, _ = classifier(test_data)

            x = torch.argmax(class_logits, dim=1).cpu().detach().numpy()
            result.append(x)
            progress.advance(test_tqdm)

result = np.concatenate(result)

# Generate your submission: one row per test sample, ids in loader order.
df = pd.DataFrame({"id": np.arange(0, len(result)), "label": result})
df.to_csv("DALN_submission.csv", index=False)


Output()