# Drawn Apart

Solution author: Walnit

# Dataset

In [None]:
# ! kaggle competitions download -c drawn-apart-aicc-round-3
# ! 7z x drawn-apart-aicc-round-3.zip
# ! mv task_data/task_data/* task_data/

# Author's Solution

This solution builds on *M3SDA* (Moment Matching for Multi-Source Domain Adaptation), the approach proposed in the paper that introduced DomainNet, the dataset this challenge draws its data from.

The gist of the idea is to exploit the fact that we are given **multiple sources**. We fine-tune a feature extractor to *align the features of multiple domains*, and classify using those aligned features.

Note: this solution aims to illustrate M3SDA. It has not been tuned extensively for maximum performance. Please experiment with the code below to see whether you can improve the performance further!

### Loading the dataset

In [3]:
import torchvision
from torchvision.datasets import ImageFolder

In [4]:
cartoon_ds = ImageFolder(root="/kaggle/input/drawn-apart-aicc-round-3/task_data/task_data/cartoon")
photo_ds = ImageFolder(root="/kaggle/input/drawn-apart-aicc-round-3/task_data/task_data/photograph")
sketch_ds = ImageFolder(root="/kaggle/input/drawn-apart-aicc-round-3/task_data/task_data/sketch")

We'll stick to using ResNet34 for the purposes of illustrating M3SDA. You may experiment with different backbones.  

We also split the datasets instead of combining them as in the baseline.

In [5]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import v2

# Define the transformations for the images
transform = v2.Compose([
    v2.ToImage(),
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


cartoon_ds.transform = transform
photo_ds.transform = transform
sketch_ds.transform = transform

cartoon_loader = DataLoader(cartoon_ds, batch_size=64, shuffle=True, num_workers=2, persistent_workers=True)
photo_loader = DataLoader(photo_ds, batch_size=64, shuffle=True, num_workers=2, persistent_workers=True)
sketch_loader = DataLoader(sketch_ds, batch_size=64, shuffle=True, num_workers=2, persistent_workers=True)

M3SDA trains somewhat like a GAN: each iteration alternates between several stages with opposing objectives. Training therefore takes a bit longer, but the approach is theoretically grounded.

Figure 3 of the M3SDA paper illustrates the architecture best. The only models involved are a feature extractor and *two* classifiers (I'll explain in a minute!). To keep things simple, I'll use the ResNet34 from the baseline, and a single Linear layer as each classifier.

In [6]:
import torchvision.models as models

backbone = models.resnet34(weights='DEFAULT')
backbone.fc = nn.Identity()

num_classes = len(cartoon_ds.classes)
classifier_1 = nn.Linear(512, num_classes)
classifier_2 = nn.Linear(512, num_classes)

device = "cuda" if torch.cuda.is_available() else "cpu"
backbone = backbone.to(device)
classifier_1 = classifier_1.to(device)
classifier_2 = classifier_2.to(device)

We set up the optimizers here. Weight decay follows the value recommended in the paper, but feel free to experiment to see whether changing it boosts the score.

In [7]:
optim_backbone = torch.optim.AdamW(backbone.parameters(), 3e-4, weight_decay=5e-4)
optim_c1 = torch.optim.AdamW(classifier_1.parameters(), 3e-4, weight_decay=5e-4)
optim_c2 = torch.optim.AdamW(classifier_2.parameters(), 3e-4, weight_decay=5e-4)

From here, let's walk through one iteration of the algorithm to illustrate how it works.  

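We start by sampling one batch of data from each domain, along with the labels where we have them. Cartoon and photograph are the labelled source domains; the sketch batch is treated as unlabelled.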

In [8]:
sample_batch_cartoon = next(iter(cartoon_loader))
sample_batch_photo = next(iter(photo_loader))
sample_batch_sketch = next(iter(sketch_loader))

sample_cartoon_imgs = sample_batch_cartoon[0].to(device)
sample_cartoon_labels = sample_batch_cartoon[1].to(device)
sample_photo_imgs = sample_batch_photo[0].to(device)
sample_photo_labels = sample_batch_photo[1].to(device)
sample_sketch_imgs = sample_batch_sketch[0].to(device)

For all domains, extract their features using the feature extractor.

In [9]:
feat_cartoon = backbone(sample_cartoon_imgs)
feat_photo = backbone(sample_photo_imgs)
feat_sketch = backbone(sample_sketch_imgs)

Then, the classifiers we defined earlier predict on *the domains we have labels for*.

In [10]:
c1_cartoon = classifier_1(feat_cartoon)
c2_cartoon = classifier_2(feat_cartoon)
c1_photo = classifier_1(feat_photo)
c2_photo = classifier_2(feat_photo)

Now, we calculate the loss for this first phase of the algorithm.

The objective for this phase is rather simple: perform supervised learning on the labelled domains, using the usual cross-entropy loss. However, the secret sauce is the mysterious `msda_regulizer` below. What is it, and how does it make this approach different?

First, we must understand what *moments* are. In statistics, moments summarise the shape of a distribution. More specifically, the first moment is the mean, the second central moment is the variance, and the third central moment (once normalised) gives the skewness.

To encourage domain-invariant features, we first center each domain's features by subtracting their batch mean. Next, we penalise the Euclidean distance between the centered features of each pair of domains, to encourage them to be similar across domains. Finally, we repeat this for the higher-order moments (the call below goes up to the fifth), comparing the per-dimension batch averages of the centered features raised to the k-th power. The result is a cost function that pushes features from different domains to share a similar distribution.

Mathematically inclined folks should read the proof in the paper for a clearer, more accurate description.
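
As a compact summary, here is what the implementation below computes. This describes this notebook's code, not necessarily the exact formula in the paper. The notation is introduced here for illustration: $\tilde{X}_a$ is the centered feature batch of domain $a$ (one row per sample), $\tilde{x}_{a,i}$ is its $i$-th row, $n$ is the batch size, $K$ is the highest moment order (`belta_moment`, set to 5 in the call below), powers are taken element-wise, and $a < b$ ranges over the three pairs of domains.

```latex
R(\tilde{X}_1, \tilde{X}_2, \tilde{X}_3)
  = \sum_{a < b} \bigl\lVert \tilde{X}_a - \tilde{X}_b \bigr\rVert_F
  + \sum_{k=2}^{K} \sum_{a < b}
    \Bigl\lVert \frac{1}{n} \sum_{i=1}^{n} \tilde{x}_{a,i}^{\,k}
              - \frac{1}{n} \sum_{i=1}^{n} \tilde{x}_{b,i}^{\,k} \Bigr\rVert_2
```

The first sum compares the centered batches entry by entry, which is why all three batches need the same size. The second sum compares per-dimension moments of order 2 to $K$.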

In [11]:
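def euclidean(x1, x2):
    # L2 distance between two tensors of the same shape.
    return ((x1 - x2) ** 2).sum().sqrt()

def k_moment(feat_1, feat_2, feat_3, k):
    # k-th moment of each (already centered) feature dimension, averaged over the batch.
    feat_1 = (feat_1 ** k).mean(0)
    feat_2 = (feat_2 ** k).mean(0)
    feat_3 = (feat_3 ** k).mean(0)
    # Sum of the pairwise distances between the three domains.
    return euclidean(feat_1, feat_2) + euclidean(feat_1, feat_3) + euclidean(feat_2, feat_3)

def msda_regulizer(feat_1, feat_2, feat_3, belta_moment):
    # Center each domain's features around its own batch mean.
    feat_1_mean = feat_1.mean(0)
    feat_2_mean = feat_2.mean(0)
    feat_3_mean = feat_3.mean(0)

    feat_1 = feat_1 - feat_1_mean
    feat_2 = feat_2 - feat_2_mean
    feat_3 = feat_3 - feat_3_mean

    # First term: pairwise distances between the centered batches themselves.
    # The batches are compared element-wise, so all three need the same batch size.
    moment1 = euclidean(feat_1, feat_2) + euclidean(feat_1, feat_3) + euclidean(feat_2, feat_3)
    reg_info = moment1

    # Add the moments of order 2 up to belta_moment.
    for i in range(belta_moment - 1):
        reg_info += k_moment(feat_1, feat_2, feat_3, i + 2)

    return reg_info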

In [12]:
loss_msda = 0.0002*msda_regulizer(feat_cartoon, feat_photo, feat_sketch, 5) # NOTE HPARAM HERE, PLEASE EXPERIMENT WITH THIS!
loss_c1 = nn.functional.cross_entropy(c1_cartoon, sample_cartoon_labels) + nn.functional.cross_entropy(c1_photo, sample_photo_labels)
loss_c2 = nn.functional.cross_entropy(c2_cartoon, sample_cartoon_labels) + nn.functional.cross_entropy(c2_photo, sample_photo_labels)

loss = loss_c1 + loss_c2 + loss_msda
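
To see why matching higher-order moments adds information beyond matching means, here is a small standalone sketch. It uses random tensors as hypothetical "domains" and a helper named `moment_gap`, which is introduced only for this illustration and is a simplified two-domain version of the idea, not the notebook's regularizer.

```python
import torch

torch.manual_seed(0)

def moment_gap(x, y, k):
    """L2 distance between the k-th moments of two centered batches."""
    x = x - x.mean(0)
    y = y - y.mean(0)
    return ((x ** k).mean(0) - (y ** k).mean(0)).pow(2).sum().sqrt().item()

# Hypothetical feature batches: all have (roughly) zero mean.
narrow = torch.randn(256, 8)        # standard deviation around 1
wide = 3.0 * torch.randn(256, 8)    # standard deviation around 3
same = torch.randn(256, 8)          # same distribution as `narrow`

gap_mismatched = moment_gap(narrow, wide, k=2)
gap_matched = moment_gap(narrow, same, k=2)
print(f"k=2 gap, different spread:  {gap_mismatched:.3f}")
print(f"k=2 gap, same distribution: {gap_matched:.3f}")
```

Both pairs have nearly identical means, so a mean-only comparison could not tell them apart. The second-moment gap is much larger for the pair with different spread.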

And with the loss calculated, all we have to do is backpropagate and take an optimizer step! Here, we update all three models.

In [13]:
loss.backward()

optim_backbone.step()
optim_c1.step()
optim_c2.step()
optim_backbone.zero_grad()
optim_c1.zero_grad()
optim_c2.zero_grad()

Now, we enter phase two of the algorithm. The first step is to repeat exactly the same computation as above to get our *supervised loss*, so called because its cross-entropy terms are calculated only from the labelled data.

In [14]:
feat_cartoon = backbone(sample_cartoon_imgs)
feat_photo = backbone(sample_photo_imgs)
feat_sketch = backbone(sample_sketch_imgs)
c1_cartoon = classifier_1(feat_cartoon)
c2_cartoon = classifier_2(feat_cartoon)
c1_photo = classifier_1(feat_photo)
c2_photo = classifier_2(feat_photo)
loss_msda = 0.0002*msda_regulizer(feat_cartoon, feat_photo, feat_sketch, 5) # NOTE HPARAM HERE
loss_c1 = nn.functional.cross_entropy(c1_cartoon, sample_cartoon_labels) + nn.functional.cross_entropy(c1_photo, sample_photo_labels)
loss_c2 = nn.functional.cross_entropy(c2_cartoon, sample_cartoon_labels) + nn.functional.cross_entropy(c2_photo, sample_photo_labels)

loss_s = loss_c1 + loss_c2 + loss_msda

The second step is something new: we classify the features of the unlabelled data with each of the two classifiers separately. We then calculate the discrepancy between their predictions, here the mean absolute difference between their softmax outputs. We call this the *discrepancy loss*.

In [15]:
c1_sketch = classifier_1(feat_sketch)
c2_sketch = classifier_2(feat_sketch)
loss_d = torch.mean(torch.abs(nn.functional.softmax(c1_sketch, dim=-1) - nn.functional.softmax(c2_sketch, dim=-1)))
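
To get a feel for the scale of this loss, here is a small standalone sketch with made-up logits. The tensors and the helper name `discrepancy` are hypothetical and not taken from the notebook's models.

```python
import torch
import torch.nn.functional as F

def discrepancy(logits_1, logits_2):
    # Mean absolute difference between the two softmax distributions.
    return torch.mean(torch.abs(F.softmax(logits_1, dim=-1) - F.softmax(logits_2, dim=-1)))

# Hypothetical logits for 2 samples and 3 classes.
agree = torch.tensor([[4.0, 0.0, 0.0], [0.0, 4.0, 0.0]])
disagree = torch.tensor([[0.0, 4.0, 0.0], [0.0, 0.0, 4.0]])

d_same = discrepancy(agree, agree).item()     # identical predictions
d_diff = discrepancy(agree, disagree).item()  # confident but conflicting predictions
print(d_same, d_diff)
```

Identical predictions give a loss of exactly zero. Because each softmax output sums to one, the loss can never exceed `2 / num_classes`, so its raw magnitude shrinks as the number of classes grows. That is worth keeping in mind when weighing it against the supervised loss.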

Finally, we *maximize* the discrepancy loss instead of minimizing it (note the `-` instead of a `+` in the loss below).

Why? It turns out that if you only train to align the domains, i.e. only do phase 1, the model never learns where the class boundaries fall in the target domain. This causes target features that are similar in nature, or shared between several classes, to be misclassified.

Thus, in this step, we maximize the discrepancy between the two classifiers. In an ideal world, they should have the same outputs, yet they vary due to stochasticity. This is perfect for us, as it highlights the confusion present in the feature extractor, which is something that we will address in the third phase.  

In line with this explanation, we only optimize the two classifiers and leave the weights of the feature extractor untouched. Combining the supervised and discrepancy losses ensures that the classifiers do not catastrophically forget everything they learned from the labelled domains.

In [16]:
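loss = loss_s - loss_d
loss.backward()
# backward() also fills the backbone's gradients. Clear them right away so they
# do not leak into the phase 3 update; only the classifiers are stepped here.
optim_backbone.zero_grad()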

optim_c1.step()
optim_c2.step()
optim_c1.zero_grad()
optim_c2.zero_grad()
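
Note that `loss.backward()` above still computes gradients for the backbone, even though only the classifier optimizers take a step. One alternative, sketched below with tiny stand-in modules rather than the notebook's models, is to extract the features under `torch.no_grad()` so that the backbone never receives a gradient in this phase. The layer sizes are hypothetical, and only the discrepancy term is shown.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the notebook's models (hypothetical sizes).
backbone = nn.Linear(16, 8)
classifier_1 = nn.Linear(8, 3)
classifier_2 = nn.Linear(8, 3)
unlabelled = torch.randn(4, 16)

# Extract features without building a graph through the backbone.
with torch.no_grad():
    feats = backbone(unlabelled)

p1 = torch.softmax(classifier_1(feats), dim=-1)
p2 = torch.softmax(classifier_2(feats), dim=-1)
loss_d = torch.mean(torch.abs(p1 - p2))

# Maximizing the discrepancy means minimizing its negative.
(-loss_d).backward()

print(backbone.weight.grad)                   # None: the backbone received no gradient
print(classifier_1.weight.grad is not None)   # the classifiers did
```

With the real ResNet backbone in train mode, BatchNorm running statistics would still update during this forward pass, so this is not a complete freeze. It only guarantees that no backbone gradients are produced.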

We're finally at phase 3, which is simple: minimize this discrepancy by optimizing the feature extractor. It's as easy as keeping the classifiers fixed and training the feature extractor to minimize the discrepancy loss.

Note that this step takes longer to converge during training. Thus, we run it multiple times in one pass of the algorithm (four times in the training loop below).

In [17]:
feat_sketch = backbone(sample_sketch_imgs)
c1_sketch = classifier_1(feat_sketch)
c2_sketch = classifier_2(feat_sketch)

loss_d = torch.mean(torch.abs(nn.functional.softmax(c1_sketch, dim=-1) - nn.functional.softmax(c2_sketch, dim=-1)))
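loss_d.backward()
# Only the backbone is updated here; drop the classifier gradients that backward() produced.
optim_c1.zero_grad()
optim_c2.zero_grad()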
optim_backbone.step()
optim_backbone.zero_grad()
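
The repeated version of this phase can be sketched as follows. This is a minimal standalone illustration with tiny stand-in modules and plain SGD; the sizes, the optimizer choice, and the name `num_extractor_steps` are assumptions made for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the notebook's models (hypothetical sizes).
backbone = nn.Linear(16, 8)
classifier_1 = nn.Linear(8, 3)
classifier_2 = nn.Linear(8, 3)
optim_backbone = torch.optim.SGD(backbone.parameters(), lr=0.1)
unlabelled = torch.randn(32, 16)

num_extractor_steps = 4  # hypothetical name: how often phase 3 runs per batch
c1_before = classifier_1.weight.detach().clone()
backbone_before = backbone.weight.detach().clone()

for _ in range(num_extractor_steps):
    feats = backbone(unlabelled)
    p1 = torch.softmax(classifier_1(feats), dim=-1)
    p2 = torch.softmax(classifier_2(feats), dim=-1)
    loss_d = torch.mean(torch.abs(p1 - p2))
    loss_d.backward()
    optim_backbone.step()  # only the backbone's optimizer takes a step
    # Reset every gradient so nothing carries over to the next step or phase.
    for module in (backbone, classifier_1, classifier_2):
        module.zero_grad()

classifiers_unchanged = torch.equal(c1_before, classifier_1.weight)
backbone_changed = not torch.equal(backbone_before, backbone.weight)
print(classifiers_unchanged, backbone_changed)
```

The classifiers are "frozen" here simply because their optimizers never step. Their gradients are still computed, which is why they are cleared on every iteration.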

And... we're done! Let's clean our instance up a bit and ready the training loop.

In [18]:
import gc

cartoon_loader = DataLoader(cartoon_ds, batch_size=256, shuffle=True, num_workers=1, persistent_workers=True, drop_last=True, pin_memory=True)
photo_loader = DataLoader(photo_ds, batch_size=256, shuffle=True, num_workers=1, persistent_workers=True, drop_last=True, pin_memory=True)
sketch_loader = DataLoader(sketch_ds, batch_size=256, shuffle=True, num_workers=1, persistent_workers=True, drop_last=True, pin_memory=True)

gc.collect()
torch.cuda.empty_cache()

Since this is a multi-step training algorithm, performance is bound to vary. Thus, it is important that we keep track of the best model during training, by testing it against the validation dataset every epoch.

In [19]:
import os
from PIL import Image
from torch.utils.data import Dataset

class SketchTestDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(root_dir) if f.endswith(('.jpg', '.jpeg', '.png'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        filename = self.image_files[idx]
        img_path = os.path.join(self.root_dir, filename)
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, filename

In [20]:
val_dataset = SketchTestDataset(root_dir='/kaggle/input/drawn-apart-aicc-round-3/task_data/task_data/sketch_val', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=1, persistent_workers=True)

test_dataset = SketchTestDataset(root_dir='/kaggle/input/drawn-apart-aicc-round-3/task_data/task_data/sketch_test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, persistent_workers=True)

print(f"Number of images in val dataset: {len(val_dataset)}")
print(f"Number of images in test dataset: {len(test_dataset)}")

Number of images in val dataset: 200
Number of images in test dataset: 4624


In [21]:
import pandas as pd
from sklearn.metrics import f1_score

val_solution_df = pd.read_csv('/kaggle/input/drawn-apart-aicc-round-3/task_data/task_data/val.csv')

Now, the below code looks like an absolute behemoth. However, it's just the three steps above concatenated, as well as some validation code.  

We run it for 30 epochs, as that takes approximately an hour: not too long, but still effective.

In [None]:
from tqdm.auto import tqdm
import time

num_epochs = 30
num_steps_per_epoch = min(len(sketch_loader), len(cartoon_loader), len(photo_loader))
scaler = torch.amp.GradScaler()
best_score = 0

for epoch in range(num_epochs):
    cartoon_iter = iter(cartoon_loader)
    photo_iter = iter(photo_loader)
    sketch_iter = iter(sketch_loader)
    print("Epoch", epoch)
    backbone.train()
    classifier_1.train()
    classifier_2.train()
    
    for step in tqdm(range(num_steps_per_epoch)):
        start = time.time()
        sample_batch_cartoon = next(cartoon_iter)
        sample_batch_photo = next(photo_iter)
        sample_batch_sketch = next(sketch_iter)
        sample_cartoon_imgs = sample_batch_cartoon[0].to(device, non_blocking=True)
        sample_cartoon_labels = sample_batch_cartoon[1].to(device, non_blocking=True)
        sample_photo_imgs = sample_batch_photo[0].to(device, non_blocking=True)
        sample_photo_labels = sample_batch_photo[1].to(device, non_blocking=True)
        sample_sketch_imgs = sample_batch_sketch[0].to(device, non_blocking=True)

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            feat_cartoon = backbone(sample_cartoon_imgs)
            feat_photo = backbone(sample_photo_imgs)
            feat_sketch = backbone(sample_sketch_imgs)

            c1_cartoon = classifier_1(feat_cartoon)
            c2_cartoon = classifier_2(feat_cartoon)
            c1_photo = classifier_1(feat_photo)
            c2_photo = classifier_2(feat_photo)
            c1_sketch = classifier_1(feat_sketch)
            c2_sketch = classifier_2(feat_sketch)

            loss_msda = 0.0002*msda_regulizer(feat_cartoon, feat_photo, feat_sketch, 5) # NOTE HPARAM HERE
            loss_c1 = nn.functional.cross_entropy(c1_cartoon, sample_cartoon_labels) + nn.functional.cross_entropy(c1_photo, sample_photo_labels)
            loss_c2 = nn.functional.cross_entropy(c2_cartoon, sample_cartoon_labels) + nn.functional.cross_entropy(c2_photo, sample_photo_labels)

            loss = loss_c1 + loss_c2 + loss_msda

        scaler.scale(loss).backward()
        scaler.step(optim_backbone)
        scaler.step(optim_c1)
        scaler.step(optim_c2)
        scaler.update()
        optim_backbone.zero_grad()
        optim_c1.zero_grad()
        optim_c2.zero_grad()

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            feat_cartoon = backbone(sample_cartoon_imgs)
            feat_photo = backbone(sample_photo_imgs)
            feat_sketch = backbone(sample_sketch_imgs)
            c1_cartoon = classifier_1(feat_cartoon)
            c2_cartoon = classifier_2(feat_cartoon)
            c1_photo = classifier_1(feat_photo)
            c2_photo = classifier_2(feat_photo)
            c1_sketch = classifier_1(feat_sketch)
            c2_sketch = classifier_2(feat_sketch)
            loss_msda = 0.0002*msda_regulizer(feat_cartoon, feat_photo, feat_sketch, 5) # NOTE HPARAM HERE
            loss_c1 = nn.functional.cross_entropy(c1_cartoon, sample_cartoon_labels) + nn.functional.cross_entropy(c1_photo, sample_photo_labels)
            loss_c2 = nn.functional.cross_entropy(c2_cartoon, sample_cartoon_labels) + nn.functional.cross_entropy(c2_photo, sample_photo_labels)

            loss_s = loss_c1 + loss_c2 + loss_msda
            loss_d = torch.mean(torch.abs(nn.functional.softmax(c1_sketch, dim=-1) - nn.functional.softmax(c2_sketch, dim=-1)))

            loss = loss_s - loss_d

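        scaler.scale(loss).backward()
        scaler.step(optim_c1)
        scaler.step(optim_c2)
        scaler.update()
        optim_c1.zero_grad()
        optim_c2.zero_grad()
        # backward() also filled the backbone's gradients; clear them so they
        # do not leak into the first phase 3 update.
        optim_backbone.zero_grad()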

        for i in range(4):
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                feat_sketch = backbone(sample_sketch_imgs)
                c1_sketch = classifier_1(feat_sketch)
                c2_sketch = classifier_2(feat_sketch)

                loss_d = torch.mean(torch.abs(nn.functional.softmax(c1_sketch, dim=-1) - nn.functional.softmax(c2_sketch, dim=-1)))
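            scaler.scale(loss_d).backward()
            scaler.step(optim_backbone)
            scaler.update()
            optim_backbone.zero_grad()
            # Clear the classifier gradients produced by backward() as well.
            optim_c1.zero_grad()
            optim_c2.zero_grad()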

    backbone.eval()
    classifier_1.eval()
    classifier_2.eval()
    all_predictions_1 = []
    all_predictions_2 = []
    all_filenames = []
    
    with torch.no_grad():
        for inputs, filenames in tqdm(val_loader, desc="Predicting on Validation Data"):
            inputs = inputs.to(device)
            features = backbone(inputs)
            outputs_1 = classifier_1(features)
            outputs_2 = classifier_2(features)
            _, predicted_indices = torch.max(outputs_1, 1)
            all_predictions_1.extend(predicted_indices.cpu().numpy())
            _, predicted_indices = torch.max(outputs_2, 1)
            all_predictions_2.extend(predicted_indices.cpu().numpy())
            all_filenames.extend(filenames)

    class_names = cartoon_ds.classes
    predicted_class_names_1 = [class_names[idx] for idx in all_predictions_1]
    predicted_class_names_2 = [class_names[idx] for idx in all_predictions_2]
    
    val_submission_df = pd.DataFrame({
        'filename': all_filenames,
        'class_name': predicted_class_names_1
    })
    merged_df = pd.merge(val_submission_df, val_solution_df, on='filename', suffixes=('_predicted', '_true'))
    y_true = merged_df['class_name_true']
    y_pred = merged_df['class_name_predicted']
    score_1 = f1_score(y_true, y_pred, average='weighted')
    
    val_submission_df = pd.DataFrame({
        'filename': all_filenames,
        'class_name': predicted_class_names_2
    })
    merged_df = pd.merge(val_submission_df, val_solution_df, on='filename', suffixes=('_predicted', '_true'))
    y_true = merged_df['class_name_true']
    y_pred = merged_df['class_name_predicted']
    score_2 = f1_score(y_true, y_pred, average='weighted')

    score = max(score_1, score_2)
    
    print(f"F1 Score: {score_1, score_2}")
    if score > best_score:
        print("New best score! Saving model.")
        best_score = score
        torch.save({
            "backbone": backbone.state_dict(),
            "classifier_1": classifier_1.state_dict(),
            "classifier_2": classifier_2.state_dict()
        }, "best_models.tar")


Finally, we just have to predict on the test dataset. Load the weights back, and remember to check which of the two classifiers is more effective on the validation set!

In [None]:
weights = torch.load("best_models.tar", map_location=device)

In [None]:
backbone.load_state_dict(weights["backbone"])
classifier_1.load_state_dict(weights["classifier_1"])
classifier_2.load_state_dict(weights["classifier_2"])

In [None]:
backbone.eval()
classifier_1.eval()
all_predictions = []
all_filenames = []

with torch.no_grad():
    for inputs, filenames in tqdm(val_loader, desc="Predicting on Validation Data"):
        inputs = inputs.to(device)
        outputs = classifier_1(backbone(inputs))
        _, predicted_indices = torch.max(outputs, 1)
        all_predictions.extend(predicted_indices.cpu().numpy())
        all_filenames.extend(filenames)

print("Predictions completed.")

In [None]:
backbone.eval()
classifier_2.eval()
all_predictions = []
all_filenames = []

with torch.no_grad():
    for inputs, filenames in tqdm(val_loader, desc="Predicting on Validation Data"):
        inputs = inputs.to(device)
        outputs = classifier_2(backbone(inputs))
        _, predicted_indices = torch.max(outputs, 1)
        all_predictions.extend(predicted_indices.cpu().numpy())
        all_filenames.extend(filenames)
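
The two cells above collect validation predictions but stop short of scoring them. Here is a minimal sketch of how each set of predictions could be scored, mirroring the merge and `f1_score` logic from the training loop. The helper name `weighted_f1` and the toy data are hypothetical; in the notebook you would pass `all_filenames`, the predicted class names, and `val_solution_df`.

```python
import pandas as pd
from sklearn.metrics import f1_score

def weighted_f1(filenames, predicted_names, solution_df):
    """Weighted F1 of predictions against a solution frame with
    'filename' and 'class_name' columns (the layout the training loop assumes)."""
    pred_df = pd.DataFrame({"filename": filenames, "class_name": predicted_names})
    merged = pd.merge(pred_df, solution_df, on="filename", suffixes=("_predicted", "_true"))
    return f1_score(merged["class_name_true"], merged["class_name_predicted"], average="weighted")

# Toy data standing in for val.csv and the predictions of two classifiers.
solution = pd.DataFrame({
    "filename": ["a.png", "b.png", "c.png", "d.png"],
    "class_name": ["cat", "dog", "cat", "dog"],
})
names = ["a.png", "b.png", "c.png", "d.png"]
perfect = weighted_f1(names, ["cat", "dog", "cat", "dog"], solution)
half = weighted_f1(names, ["cat", "cat", "dog", "dog"], solution)
print(perfect, half)
```

Running the helper once per classifier and keeping the one with the higher score gives the choice that the next step relies on.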

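After choosing the classifier, we generate the submission CSV. The cell below uses `classifier_1`; swap in `classifier_2` if it scored higher on the validation set.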

In [None]:
all_predictions = []
all_filenames = []

with torch.no_grad():
    for inputs, filenames in tqdm(test_loader, desc="Predicting on Test Data"):
        inputs = inputs.to(device)
        outputs = classifier_1(backbone(inputs))
        _, predicted_indices = torch.max(outputs, 1)
        all_predictions.extend(predicted_indices.cpu().numpy())
        all_filenames.extend(filenames)

print("Predictions completed.")

In [None]:
class_names = cartoon_ds.classes
predicted_class_names = [class_names[idx] for idx in all_predictions]

assert len(all_filenames) == len(predicted_class_names), "Mismatch in lengths of filenames and predicted class names!"

submission_df = pd.DataFrame({
    'filename': all_filenames,
    'class_name': predicted_class_names
})

submission_df.to_csv('submission.csv', index=False)

print("submission.csv created successfully.")
print(submission_df.head())

And... we're done! 

This code was tested and run on Kaggle's P100 GPU.