In [1]:
# Source data train only
from src.data import source_train_loader, source_test_loader, target_train_loader, target_test_loader
from src.layers.torch_nn import Classifier

from src.layers.utils import freeze_layers
from src.eval import evaluate, evaluate_domain_cls
import torch
import torch.nn as nn
import torch.optim as optim
from src.layers.grl import grad_reverse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm 
import os
from itertools import cycle


from torchvision.models import ResNet18_Weights, resnet18
from src.layers.instance_model import InstancewiseVisualPrompt_v2
import torch.nn.functional as F

In [2]:
class BaseClassifier(nn.Module):
    def __init__(
        self,
        num_classes=10,
        imgsize=224,
        vr_blocks=2,
        attribute_layers=[5,6],
        patch_size=[16,32],
        attribute_channels=3,
    ):
        super(BaseClassifier, self).__init__()
        self.backbone = resnet18(ResNet18_Weights.IMAGENET1K_V1)
        self.backbone.fc = nn.Identity()

        assert len(attribute_layers) == len(patch_size) == vr_blocks

        self.visual_prompt = nn.ModuleList([
            InstancewiseVisualPrompt_v2(
                imgsize, attribute_layers[idx], patch_size[idx], attribute_channels, dropout_p=0.5
        ) for idx in range(vr_blocks)])
        self.classifier_head = Classifier(
            in_dim=512, hidden_dim=256, out_dim=num_classes, num_res_blocks=2, dropout=0.5
        )
    
    def forward(self, x, output_type="logits"):
        for layers in self.visual_prompt:
            x = layers(x)
        feat = self.backbone(x)

        if output_type == "feat":
            return feat
        elif output_type == "logits":
            return self.classifier_head(feat)
        else:
            print(f"Not implemented output type {output_type}") 
    
    def mc_feature(self, x, mc_samples):
        self.train()
        B = x.size(0)
        D = 512  # resnet18 feature dim
        feat_sum = torch.zeros(B, D, device=x.device)
        for _ in range(mc_samples):
            feat = self.forward(x, output_type="feat")  # dropout on
            feat_sum += feat
        return feat_sum / mc_samples

    def mc_logits(self, x, mc_samples = 4, tau = 1):
        self.train()
        B = x.size(0)
        D = 10  # resnet18 feature dim
        p_mean = torch.zeros(B, D, device=x.device)

        for _ in range(mc_samples):
            logits = self.forward(x, output_type="logits") / tau 
            p_mean += F.softmax(logits, dim = -1)
        
        p_mean /= mc_samples
        
        ent = -(p_mean*(p_mean+1e-8).log()).sum(-1)
        return p_mean, ent

In [3]:
class FCDiscriminator_img(nn.Module):
    def __init__(self, input_dim, ndf1=256, ndf2=128):
        super(FCDiscriminator_img, self).__init__()

        self.ln1 = nn.Linear(input_dim, ndf1)
        self.ln2 = nn.Linear(ndf1, ndf2)
        self.ln3 = nn.Linear(ndf2, ndf2)
        self.classifier = nn.Linear(ndf2, 2)

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x = self.ln1(x)
        x = self.leaky_relu(x)
        x = self.ln2(x)
        x = self.leaky_relu(x)
        x = self.ln3(x)
        x = self.leaky_relu(x)
        x = self.classifier(x)
        return x

In [4]:
epochs = 10

stu_model = BaseClassifier(vr_blocks=1, patch_size=[32], attribute_layers = [6])
# teacher model has two VR blocks for both source/target domains
tch_model = BaseClassifier(vr_blocks=1, patch_size=[32], attribute_layers = [6])
domain_classifier = FCDiscriminator_img(input_dim=512)
device = torch.device("cuda:0")
ckpt = torch.load("checkpoints/best_model_v12.pth")

stu_model.load_state_dict(ckpt)
stu_model = stu_model.to(device)
tch_model.load_state_dict(ckpt)
tch_model =tch_model.to(device)

domain_classifier = domain_classifier.to(device)

layers = (
    [param for name, param in stu_model.named_parameters() if "visual_prompt" not in name] +
    [param for name, param in tch_model.named_parameters() if "visual_prompt" not in name] 
)

optimizer = optim.AdamW(layers, lr=0.01)
vr_layers = list(stu_model.visual_prompt.parameters()) + list(tch_model.visual_prompt.parameters()) + list(domain_classifier.parameters())
optimizer_vr = optim.AdamW(vr_layers, lr=0.001)
criterion_class = nn.CrossEntropyLoss()
criterion_domain = nn.CrossEntropyLoss()



In [5]:
def compute_alpha(u, t_u):
    # u: [B], entropy
    mask = (u <= t_u).float()
    w = torch.exp(-u) * mask
    # normalize within batch
    return w / (w.sum() + 1e-12) * u.numel()

In [6]:
log_dir = os.path.join("runs", "da_exp_v12_phase_2")
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir)

mc_samples = 4
tau = 0.8
t_u = 0.3
#training script
best_test_acc = 0
freeze_layers([stu_model.backbone, tch_model.backbone])
total_steps = epochs * len(source_train_loader)
for epoch in range(epochs):
    tgt_cycle = cycle(target_train_loader)
    stu_model.train()
    tch_model.train()
    domain_classifier.train()
    running_loss = 0.0

    pbar = tqdm(source_train_loader, total=len(source_train_loader), desc=f"Epoch {epoch+1}", ncols=100)
    
    for batch_idx, source_data in enumerate(pbar):
        pbar.set_description_str(f"Epoch {epoch+1}", refresh=True)
        target_data = next(tgt_cycle)
        current_step = epoch * len(source_train_loader) + batch_idx

        src_q_data, src_k_data, src_labels = source_data
        tgt_q_data, tgt_k_data, _ = target_data

        src_img = src_k_data.to(device)
        src_labels = src_labels.to(device)
        tgt_img = tgt_k_data.to(device)

        optimizer.zero_grad()
        optimizer_vr.zero_grad()

        p_s, u_s = tch_model.mc_logits(src_img, mc_samples, tau)
        p_t, u_t = stu_model.mc_logits(tgt_img, mc_samples, tau)

        loss_cls = criterion_class(p_s, src_labels)
        loss_uncertainty = (u_s.mean() - u_t.mean()).pow(2)
        del p_s, p_t
        torch.cuda.empty_cache()

        p = current_step / total_steps
        alpha = 2.0 / (1.0 + np.exp(-10 * p)) - 1.0

        # Domain loss
        f_s = tch_model.mc_feature(src_img, mc_samples)
        f_t = stu_model.mc_feature(tgt_img, mc_samples)
        f_t_rvs = grad_reverse(f_t, alpha)

        alpha_dis_s = compute_alpha(u_s, t_u)
        alpha_dis_t = compute_alpha(u_t, t_u)
        
        domain_src_logits = domain_classifier(f_t.detach())
        loss_dis_src_img = criterion_domain(domain_src_logits, torch.zeros(f_t.size(0), dtype=torch.long).to(device)) * alpha_dis_s

        domain_tgt_logits = domain_classifier(f_t_rvs.detach())
        loss_dis_tgt_img = criterion_domain(domain_tgt_logits, torch.ones(f_t_rvs.size(0), dtype=torch.long).to(device)) * alpha_dis_t

        loss_dis = loss_dis_src_img.mean() + loss_dis_tgt_img.mean()
        
        loss = loss_cls + 0.25*loss_uncertainty + loss_dis

        loss.backward()
        optimizer.step()
        optimizer_vr.step()

        # update_ema(tch_model.visual_prompt[0],stu_model.visual_prompt, decay=0.9996)

        writer.add_scalar("DA/Train Cls loss", loss_cls.item(), current_step)
        writer.add_scalar("DA/Train Dis loss", loss_dis.item(), current_step)
        writer.add_scalar("DA/Train Unt loss", loss_uncertainty.item(), current_step)
        writer.add_scalar("DA/Train BatchLoss", loss.item(), current_step)

    test_loss_src, test_accuracy_src = evaluate(tch_model, test_loader=source_test_loader, device=device)
    test_loss_tgt, test_accuracy_tgt = evaluate(tch_model, test_loader=target_test_loader, device=device)
    # test_loss_tgt_wo_vr, test_accuracy_tgt_wo_vr = evaluate_wo_vr(tch_model, test_loader=target_test_loader, device=device)
    test_loss_domain, test_acc_domain = evaluate_domain_cls(source_test_loader, target_test_loader, tch_model, stu_model, domain_classifier, device)

    writer.add_scalar("Source/Test EpochLoss", test_loss_src, epoch)
    writer.add_scalar("Source/Test Accuracy", test_accuracy_src, epoch)

    writer.add_scalar("Target/Test EpochLoss", test_loss_tgt, epoch)
    writer.add_scalar("Target/Test Accuracy", test_accuracy_tgt, epoch)

    writer.add_scalar("Domain/Test EpochLoss", test_loss_domain, epoch)
    writer.add_scalar("Domain/Test Accuracy", test_acc_domain, epoch)

    # writer.add_scalar("Target/Test EpochLoss (w/o VR)", test_loss_tgt_wo_vr, epoch)
    # writer.add_scalar("Target/Test Accuracy (w/o VR)", test_accuracy_tgt_wo_vr, epoch)    

    print(
        f"Epoch [{epoch + 1}/{epochs}] Test Loss Source: {test_loss_src:.4f}, Test Accuracy Source: {test_accuracy_src:.2f}%"
    )
    print(
        f"Epoch [{epoch + 1}/{epochs}] Test Loss Target: {test_loss_tgt:.4f}, Test Accuracy Target: {test_accuracy_tgt:.2f}%"
    )
    print(
        f"Epoch [{epoch + 1}/{epochs}] Test Loss Domain: {test_loss_domain:.4f}, Test Accuracy Domain: {test_acc_domain:.2f}%"
    )
    # print(
    #     f"Epoch [{epoch + 1}/{epochs}] Test Loss Target (w/o VR): {test_loss_tgt_wo_vr:.4f}, Test Accuracy Target (w/o VR): {test_accuracy_tgt_wo_vr:.2f}%"
    # )

    # Save the best model based on test accuracy.
    if  test_accuracy_src > best_test_acc:
        best_test_acc = test_accuracy_src
        best_checkpoint_path = os.path.join("checkpoints", "best_model_v7_1.pth")
        os.makedirs("checkpoints", exist_ok=True)
        torch.save(tch_model.state_dict(), best_checkpoint_path)
        print(
            f"Epoch [{epoch + 1}]: New best model saved with test accuracy: {test_accuracy_src:.2f}%"
        )

Epoch 1:   0%|                                                             | 0/1875 [00:00<?, ?it/s]

Epoch 1: 100%|██████████████████████████████████████████████████| 1875/1875 [13:51<00:00,  2.25it/s]


Epoch [1/10] Test Loss Source: 0.4142, Test Accuracy Source: 96.83%
Epoch [1/10] Test Loss Target: 11.9818, Test Accuracy Target: 46.14%
Epoch [1/10] Test Loss Domain: 0.6932, Test Accuracy Domain: 0.47%
Epoch [1]: New best model saved with test accuracy: 96.83%


Epoch 2: 100%|██████████████████████████████████████████████████| 1875/1875 [13:44<00:00,  2.27it/s]


Epoch [2/10] Test Loss Source: 0.4633, Test Accuracy Source: 96.88%
Epoch [2/10] Test Loss Target: 17.7915, Test Accuracy Target: 46.29%
Epoch [2/10] Test Loss Domain: 0.6932, Test Accuracy Domain: 0.44%
Epoch [2]: New best model saved with test accuracy: 96.88%


Epoch 3: 100%|██████████████████████████████████████████████████| 1875/1875 [13:42<00:00,  2.28it/s]


Epoch [3/10] Test Loss Source: 0.4553, Test Accuracy Source: 97.13%
Epoch [3/10] Test Loss Target: 10.5265, Test Accuracy Target: 50.37%
Epoch [3/10] Test Loss Domain: 0.6931, Test Accuracy Domain: 0.54%
Epoch [3]: New best model saved with test accuracy: 97.13%


Epoch 4: 100%|██████████████████████████████████████████████████| 1875/1875 [13:32<00:00,  2.31it/s]


Epoch [4/10] Test Loss Source: 0.4478, Test Accuracy Source: 96.85%
Epoch [4/10] Test Loss Target: 17.6882, Test Accuracy Target: 44.59%
Epoch [4/10] Test Loss Domain: 0.6930, Test Accuracy Domain: 0.62%


Epoch 5: 100%|██████████████████████████████████████████████████| 1875/1875 [13:34<00:00,  2.30it/s]


Epoch [5/10] Test Loss Source: 0.4047, Test Accuracy Source: 97.34%
Epoch [5/10] Test Loss Target: 12.1314, Test Accuracy Target: 53.26%
Epoch [5/10] Test Loss Domain: 0.6930, Test Accuracy Domain: 0.54%
Epoch [5]: New best model saved with test accuracy: 97.34%


Epoch 6: 100%|██████████████████████████████████████████████████| 1875/1875 [13:35<00:00,  2.30it/s]


Epoch [6/10] Test Loss Source: 0.4345, Test Accuracy Source: 97.17%
Epoch [6/10] Test Loss Target: 14.9701, Test Accuracy Target: 49.98%
Epoch [6/10] Test Loss Domain: 0.6925, Test Accuracy Domain: 0.81%


Epoch 7: 100%|██████████████████████████████████████████████████| 1875/1875 [13:27<00:00,  2.32it/s]


Epoch [7/10] Test Loss Source: 0.4042, Test Accuracy Source: 97.35%
Epoch [7/10] Test Loss Target: 9.0155, Test Accuracy Target: 57.40%
Epoch [7/10] Test Loss Domain: 0.6928, Test Accuracy Domain: 0.81%
Epoch [7]: New best model saved with test accuracy: 97.35%


Epoch 8: 100%|██████████████████████████████████████████████████| 1875/1875 [13:29<00:00,  2.32it/s]


Epoch [8/10] Test Loss Source: 0.4392, Test Accuracy Source: 97.11%
Epoch [8/10] Test Loss Target: 11.2610, Test Accuracy Target: 47.14%
Epoch [8/10] Test Loss Domain: 0.6932, Test Accuracy Domain: 0.32%


Epoch 9: 100%|██████████████████████████████████████████████████| 1875/1875 [13:41<00:00,  2.28it/s]


Epoch [9/10] Test Loss Source: 0.5238, Test Accuracy Source: 97.11%
Epoch [9/10] Test Loss Target: 19.9797, Test Accuracy Target: 44.64%
Epoch [9/10] Test Loss Domain: 0.6931, Test Accuracy Domain: 0.67%


Epoch 10: 100%|█████████████████████████████████████████████████| 1875/1875 [13:41<00:00,  2.28it/s]


Epoch [10/10] Test Loss Source: 0.3570, Test Accuracy Source: 97.70%
Epoch [10/10] Test Loss Target: 13.3914, Test Accuracy Target: 54.81%
Epoch [10/10] Test Loss Domain: 0.6931, Test Accuracy Domain: 0.62%
Epoch [10]: New best model saved with test accuracy: 97.70%


In [None]:
#  TODO:
#  [] make a simpler FD dis / move the layer of dis to be same as visual prompt
#  [] update training style, following adaptive teacher with branches 
#  [] log the model's uncertainty ?  