In [1]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from tqdm import tqdm


In [2]:
myseed = 6666  # set a random seed for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

In [3]:
source_transform = transforms.Compose(
    [
        # Turn RGB to grayscale. (Bacause Canny do not support RGB images.)
        transforms.Grayscale(),
        # cv2 do not support skimage.Image, so we transform it to np.array,
        # and then adopt cv2.Canny algorithm.
        transforms.Lambda(lambda x: cv2.Canny(np.array(x), 170, 300)),
        # Transform np.array back to the skimage.Image.
        transforms.ToPILImage(),
        # 50% Horizontal Flip. (For Augmentation)
        transforms.RandomHorizontalFlip(),
        # Rotate +- 15 degrees. (For Augmentation), and filled with zero
        # if there's empty pixel after rotation.
        transforms.RandomRotation(15, fill=(0,)),
        # Transform to tensor for model inputs.
        transforms.ToTensor(),
    ]
)
target_transform = transforms.Compose(
    [
        # Turn RGB to grayscale.
        transforms.Grayscale(),
        # Resize: size of source data is 32x32, thus we need to
        #  enlarge the size of target data from 28x28 to 32x32。
        transforms.Resize((32, 32)),
        # 50% Horizontal Flip. (For Augmentation)
        transforms.RandomHorizontalFlip(),
        # Rotate +- 15 degrees. (For Augmentation), and filled with zero
        # if there's empty pixel after rotation.
        transforms.RandomRotation(15, fill=(0,)),
        # Transform to tensor for model inputs.
        transforms.ToTensor(),
    ]
)

source_dataset = ImageFolder("real_or_drawing/train_data", transform=source_transform)
target_dataset = ImageFolder("real_or_drawing/test_data", transform=target_transform)

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)


## Model

In [5]:
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        f = self.feature_extractor(x)
        y = self.fc(f.squeeze())

        return y, f

In [7]:
classifier = Classifier().cuda()

classifier.load_state_dict(torch.load("weights/DALN.bin"))

<All keys matched successfully>

## Pseudo Labeling

In [8]:
threshold = 0.98

pseudo_data, pseudo_label = torch.LongTensor([]).cuda(), torch.LongTensor([]).cuda()

softmax = nn.Softmax(dim=1)
classifier.eval()

for (data, label) in tqdm(test_dataloader, desc="Pseudo Labeling"):
    with torch.no_grad():
        data = data.cuda()
        label = label.cuda()
        c, _ = classifier(data)
        c = softmax(c)
        c, label = torch.max(c, 1)
        mask = c > threshold
        pseudo_data = torch.cat([pseudo_data, data[mask]], dim=0)
        pseudo_label = torch.cat([pseudo_label, label[mask]], dim=0)

print("\nPseudo-labeling finished, %d samples generated." % len(pseudo_data))

Pseudo Labeling: 100%|██████████| 782/782 [10:44<00:00,  1.21it/s]


Pseudo-labeling finished, 79525 samples generated.





In [9]:
print("Saving...")
print(f"pseudo_label: {pseudo_label.shape}")
print(f"pseudo_data: {pseudo_data.shape}")

np.save("DALN_pseudo_label.npy", pseudo_label.cpu().numpy())
np.save("DALN_pseudo_data.npy", pseudo_data.cpu().numpy())

Saving...
pseudo_label: torch.Size([79525])
pseudo_data: torch.Size([79525, 1, 32, 32])
