<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw2/p3_ModelA_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
from torch import optim
import torchvision.transforms as tr
from torch.utils.data import DataLoader

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#### Download zip file

In [None]:
!pip install gsutil
!gsutil cp /content/drive/MyDrive/NTU_DLCV/Hw2/hw2_data.zip /content/hw2_data.zip

In [None]:
!unzip /content/hw2_data.zip

#### Get cuda from GPU

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using: {device}")

Using: cuda


#### Construct Dataset

In [6]:
class MnistDataset(torch.utils.data.Dataset):
    def __init__(self, label_path: str, join_path: str, transform) -> None:
        self.transform = transform
        self.img_paths = []
        self.img_labels = []

        label_csv = pd.read_csv(label_path).values.tolist()

        for row in label_csv:
            self.img_paths.append(os.path.join(join_path, row[0]))
            self.img_labels.append(row[1])
        assert len(self.img_paths) == len(self.img_labels)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx) -> (torch.Tensor, int):
        img_path = self.img_paths[idx]
        img = Image.open(img_path).convert('RGB') # (28, 28, 3)
        img = self.transform(img)

        label = self.img_labels[idx]
        return img, label

In [7]:
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

source_train_ds = MnistDataset(
    label_path='/content/hw2_data/digits/mnistm/train.csv',
    join_path='/content/hw2_data/digits/mnistm/data',
    transform=tr.Compose([
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])
)

target_train_ds = MnistDataset(
    label_path='/content/hw2_data/digits/svhn/train.csv',
    join_path='/content/hw2_data/digits/svhn/data',
    transform=tr.Compose([
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])
)

target_val_ds = MnistDataset(
    label_path='/content/hw2_data/digits/svhn/val.csv',
    join_path='/content/hw2_data/digits/svhn/data',
    transform=tr.Compose([
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])
)

In [None]:
BATCH_SIZE = 512

source_train_loader = DataLoader(source_train_ds, BATCH_SIZE, shuffle=True, num_workers=4)
target_train_loader = DataLoader(target_train_ds, BATCH_SIZE, shuffle=True, num_workers=4)
target_val_loader = DataLoader(target_val_ds, BATCH_SIZE, shuffle=False, num_workers=4)

#### Domain-Adversarial Training of Neural Networks (DANN)

In [9]:
class GRL(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
         output = grad_output.neg() * ctx.alpha
         return output, None

class FeatureExtractor(nn.Module):
    def __init__(self) -> None:
        super(FeatureExtractor, self).__init__()
        self.extractor = nn.Sequential(
            nn.Conv2d(3, 64, 5), # (64, 24, 24)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2), # (64, 11, 11)

            nn.Conv2d(64, 64, 5), # (64, 7, 7)
            nn.BatchNorm2d(64),
            # nn.Dropout2d(),
            nn.ReLU(),
            nn.MaxPool2d(3, 2), # (64, 3, 3)

            nn.Conv2d(64, 128, 3), # (128, 1, 1)
        )

    def forward(self, x):
        features = self.extractor(x)
        features = features.view(-1, 128)
        return features

class Classifier(nn.Module):
    def __init__(self) -> None:
        super(Classifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(128, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            # nn.Dropout2d(),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Linear(128, 10),
        )

    def forward(self, features):
        class_label = self.classifier(features)
        return class_label

class DomainClassifier(nn.Module):
    def __init__(self) -> None:
        super(DomainClassifier, self).__init__()
        self.domainclassifier = nn.Sequential(
            nn.Linear(128, 256),
            # nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Linear(256, 1),
        )

    def forward(self, features, alpha):
        reversed_input = GRL.apply(features, alpha)
        x = self.domainclassifier(reversed_input)
        return x

#### Train

In [10]:
def model_mode(mode: str, models: list) -> None:
    for model in models:
        if mode == 'train':
            model.train()
        elif mode == 'eval':
            model.eval()

In [17]:
EPOCHS = 200

# Earlystopping
patience = 200
counter = 0
best_accuracy = 0.35

# setting parameters
lr = 1e-3
momentum = 0.9

feature_extractor = FeatureExtractor()
feature_extractor.to(device)
classifier = Classifier()
classifier.to(device)
domain_classifier = DomainClassifier()
domain_classifier.to(device)

clf_loss_fn = nn.CrossEntropyLoss()
domainclf_loss_fn = nn.BCEWithLogitsLoss() # 'Binary' Cross Entropy With Logits Loss

optimizer = optim.SGD(
    list(feature_extractor.parameters()) + list(classifier.parameters()) + list(domain_classifier.parameters()),
    lr=lr,
    momentum=momentum,
)

# logs dictionary
logs = {'class_label_loss': [], 'domain_label_loss': [], 'val_accuracy': []}

In [None]:
for epoch in range(EPOCHS):
    # start_steps = epoch * len(source_train_loader)
    total_steps = EPOCHS * len(target_train_loader)

    for (source_x, source_y), (target_x, target_y) in tqdm(zip(source_train_loader, target_train_loader), total=len(source_train_loader)):
        model_mode('train', [feature_extractor, classifier, domain_classifier])

        source_x, source_y = source_x.to(device, non_blocking=True), source_y.to(device, non_blocking=True)
        target_x, target_y = target_x.to(device, non_blocking=True), target_y.to(device, non_blocking=True)

        # scheduling
        p = epoch / total_steps
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        optimizer.param_groups[0]['lr'] = lr / (1. + 10 * p) ** 0.75
        optimizer.zero_grad()

        # features extractor
        source_features = feature_extractor(source_x)
        target_features = feature_extractor(target_x)

        # classifier
        # compute class label loss
        class_label = classifier(source_features)
        class_label_loss = clf_loss_fn(class_label, source_y)

        # doamin classifier
        # source domain label = 0
        source_domain_label = domain_classifier(source_features, alpha).squeeze() # (batch_size, 1) -> (batch_size)
        domain_label_source_loss = domainclf_loss_fn(
            source_domain_label,
            torch.zeros(source_y.shape[0], dtype=torch.float, device=device),
        )
        # target domain label = 1
        target_domain_label = domain_classifier(target_features, alpha).squeeze()
        domain_label_target_loss = domainclf_loss_fn(
            target_domain_label,
            torch.ones(target_y.shape[0], dtype=torch.float, device=device),
        )
        # compute domain total loss
        domain_label_loss = domain_label_source_loss + domain_label_target_loss

        total_loss = class_label_loss + domain_label_loss
        total_loss.backward()
        optimizer.step()

        # write logs
        logs['class_label_loss'].append(class_label_loss)
        logs['domain_label_loss'].append(domain_label_loss)


    model_mode('eval', [feature_extractor, classifier, domain_classifier])
    val_accuracy = 0
    with torch.no_grad():
        for target_val_x, target_val_y in tqdm(target_val_loader):
            target_val_x, target_val_y = target_val_x.to(device), target_val_y.cpu().numpy()

            features = feature_extractor(target_val_x)
            class_label = classifier(features)
            class_label = class_label.argmax(-1).cpu().numpy()

            val_accuracy += np.mean((class_label == target_val_y).astype(int))

    val_accuracy /= len(target_val_loader)

    # write logs
    logs['val_accuracy'].append(val_accuracy)

    print('\n', f'EPOCH: {(epoch+1):04d} -> val_accuracy: {val_accuracy:.4f}, class_label_loss: {class_label_loss:.4f}, domain_label_loss: {domain_label_loss:.4f}')

    # chcek improvement
    if val_accuracy > best_accuracy:
        counter = 0
        best_accuracy = val_accuracy

        torch.save(feature_extractor.state_dict(), f"/content/drive/MyDrive/NTU_DLCV/Hw2/p3_ckpt/{epoch+1}feature_extractor.pth")
        torch.save(classifier.state_dict(), f"/content/drive/MyDrive/NTU_DLCV/Hw2/p3_ckpt/{epoch+1}classifier.pth")
        torch.save(domain_classifier.state_dict(), f"/content/drive/MyDrive/NTU_DLCV/Hw2/p3_ckpt/{epoch+1}domain_classifier.pth")
        print("--------------------Model saved--------------------")
    else:
        counter += 1
    if counter >= patience:
        print("--------------------Earlystop--------------------")
        break

100%|██████████| 88/88 [00:39<00:00,  2.25it/s]
100%|██████████| 32/32 [00:07<00:00,  4.41it/s]


 EPOCH: 0001 -> val_accuracy: 0.2283, class_label_loss: 1.3647, domain_label_loss: 1.2820



100%|██████████| 88/88 [00:37<00:00,  2.32it/s]
100%|██████████| 32/32 [00:08<00:00,  3.96it/s]



 EPOCH: 0002 -> val_accuracy: 0.2851, class_label_loss: 0.8574, domain_label_loss: 1.1560


100%|██████████| 88/88 [00:38<00:00,  2.29it/s]
100%|██████████| 32/32 [00:05<00:00,  5.59it/s]


 EPOCH: 0003 -> val_accuracy: 0.2987, class_label_loss: 0.6437, domain_label_loss: 1.0453



100%|██████████| 88/88 [00:38<00:00,  2.27it/s]
100%|██████████| 32/32 [00:09<00:00,  3.48it/s]


 EPOCH: 0004 -> val_accuracy: 0.3166, class_label_loss: 0.4514, domain_label_loss: 0.8670



100%|██████████| 88/88 [00:38<00:00,  2.31it/s]
100%|██████████| 32/32 [00:06<00:00,  5.33it/s]


 EPOCH: 0005 -> val_accuracy: 0.3290, class_label_loss: 0.4353, domain_label_loss: 0.7198



100%|██████████| 88/88 [00:38<00:00,  2.30it/s]
100%|██████████| 32/32 [00:06<00:00,  4.90it/s]


 EPOCH: 0006 -> val_accuracy: 0.3369, class_label_loss: 0.3146, domain_label_loss: 0.5646



100%|██████████| 88/88 [00:37<00:00,  2.36it/s]
100%|██████████| 32/32 [00:08<00:00,  3.82it/s]


 EPOCH: 0007 -> val_accuracy: 0.3390, class_label_loss: 0.2742, domain_label_loss: 0.4730



100%|██████████| 88/88 [00:37<00:00,  2.33it/s]
100%|██████████| 32/32 [00:05<00:00,  5.51it/s]


 EPOCH: 0008 -> val_accuracy: 0.3396, class_label_loss: 0.2719, domain_label_loss: 0.4360



100%|██████████| 88/88 [00:38<00:00,  2.31it/s]
100%|██████████| 32/32 [00:08<00:00,  4.00it/s]


 EPOCH: 0009 -> val_accuracy: 0.3318, class_label_loss: 0.2117, domain_label_loss: 0.3793



100%|██████████| 88/88 [00:36<00:00,  2.42it/s]
100%|██████████| 32/32 [00:07<00:00,  4.30it/s]


 EPOCH: 0010 -> val_accuracy: 0.3442, class_label_loss: 0.1538, domain_label_loss: 0.3568



100%|██████████| 88/88 [00:37<00:00,  2.34it/s]
100%|██████████| 32/32 [00:05<00:00,  5.59it/s]


 EPOCH: 0011 -> val_accuracy: 0.3367, class_label_loss: 0.1876, domain_label_loss: 0.3455



100%|██████████| 88/88 [00:37<00:00,  2.32it/s]
100%|██████████| 32/32 [00:06<00:00,  4.96it/s]


 EPOCH: 0012 -> val_accuracy: 0.3443, class_label_loss: 0.1730, domain_label_loss: 0.3287



100%|██████████| 88/88 [00:36<00:00,  2.44it/s]
100%|██████████| 32/32 [00:08<00:00,  3.67it/s]


 EPOCH: 0013 -> val_accuracy: 0.3452, class_label_loss: 0.1264, domain_label_loss: 0.2976



100%|██████████| 88/88 [00:37<00:00,  2.35it/s]
100%|██████████| 32/32 [00:05<00:00,  5.58it/s]


 EPOCH: 0014 -> val_accuracy: 0.3453, class_label_loss: 0.1398, domain_label_loss: 0.2961



100%|██████████| 88/88 [00:37<00:00,  2.34it/s]
100%|██████████| 32/32 [00:06<00:00,  5.07it/s]


 EPOCH: 0015 -> val_accuracy: 0.3342, class_label_loss: 0.1155, domain_label_loss: 0.3329



100%|██████████| 88/88 [00:36<00:00,  2.42it/s]
100%|██████████| 32/32 [00:08<00:00,  3.61it/s]


 EPOCH: 0016 -> val_accuracy: 0.3482, class_label_loss: 0.1266, domain_label_loss: 0.3309



100%|██████████| 88/88 [00:36<00:00,  2.39it/s]
100%|██████████| 32/32 [00:05<00:00,  5.35it/s]


 EPOCH: 0017 -> val_accuracy: 0.3486, class_label_loss: 0.1378, domain_label_loss: 0.3630



100%|██████████| 88/88 [00:38<00:00,  2.31it/s]
100%|██████████| 32/32 [00:06<00:00,  5.02it/s]


 EPOCH: 0018 -> val_accuracy: 0.3577, class_label_loss: 0.1197, domain_label_loss: 0.3619
--------------------Model saved--------------------



100%|██████████| 88/88 [00:36<00:00,  2.44it/s]
100%|██████████| 32/32 [00:09<00:00,  3.50it/s]


 EPOCH: 0019 -> val_accuracy: 0.3509, class_label_loss: 0.1430, domain_label_loss: 0.3664



100%|██████████| 88/88 [00:37<00:00,  2.32it/s]
100%|██████████| 32/32 [00:05<00:00,  5.76it/s]


 EPOCH: 0020 -> val_accuracy: 0.3397, class_label_loss: 0.1276, domain_label_loss: 0.4445



100%|██████████| 88/88 [00:37<00:00,  2.36it/s]
100%|██████████| 32/32 [00:06<00:00,  5.31it/s]


 EPOCH: 0021 -> val_accuracy: 0.3428, class_label_loss: 0.1065, domain_label_loss: 0.4431



100%|██████████| 88/88 [00:35<00:00,  2.47it/s]
100%|██████████| 32/32 [00:09<00:00,  3.52it/s]


 EPOCH: 0022 -> val_accuracy: 0.3434, class_label_loss: 0.1000, domain_label_loss: 0.6715



100%|██████████| 88/88 [00:36<00:00,  2.41it/s]
100%|██████████| 32/32 [00:06<00:00,  5.22it/s]


 EPOCH: 0023 -> val_accuracy: 0.3162, class_label_loss: 0.0681, domain_label_loss: 0.7731



100%|██████████| 88/88 [00:37<00:00,  2.36it/s]
100%|██████████| 32/32 [00:05<00:00,  5.61it/s]


 EPOCH: 0024 -> val_accuracy: 0.3235, class_label_loss: 0.0724, domain_label_loss: 0.8842



100%|██████████| 88/88 [00:37<00:00,  2.36it/s]
100%|██████████| 32/32 [00:08<00:00,  3.59it/s]


 EPOCH: 0025 -> val_accuracy: 0.3466, class_label_loss: 0.0779, domain_label_loss: 0.8637



100%|██████████| 88/88 [00:36<00:00,  2.42it/s]
100%|██████████| 32/32 [00:06<00:00,  5.04it/s]


 EPOCH: 0026 -> val_accuracy: 0.3453, class_label_loss: 0.0383, domain_label_loss: 0.9277



100%|██████████| 88/88 [00:37<00:00,  2.32it/s]
100%|██████████| 32/32 [00:05<00:00,  5.63it/s]


 EPOCH: 0027 -> val_accuracy: 0.3464, class_label_loss: 0.0934, domain_label_loss: 0.9837



100%|██████████| 88/88 [00:37<00:00,  2.36it/s]
100%|██████████| 32/32 [00:08<00:00,  3.97it/s]


 EPOCH: 0028 -> val_accuracy: 0.3349, class_label_loss: 0.0783, domain_label_loss: 0.9699



100%|██████████| 88/88 [00:35<00:00,  2.49it/s]
100%|██████████| 32/32 [00:07<00:00,  4.06it/s]


 EPOCH: 0029 -> val_accuracy: 0.3447, class_label_loss: 0.0568, domain_label_loss: 0.9368



100%|██████████| 88/88 [00:37<00:00,  2.37it/s]
100%|██████████| 32/32 [00:05<00:00,  5.55it/s]


 EPOCH: 0030 -> val_accuracy: 0.3560, class_label_loss: 0.0452, domain_label_loss: 0.9056



100%|██████████| 88/88 [00:37<00:00,  2.35it/s]
100%|██████████| 32/32 [00:07<00:00,  4.36it/s]


 EPOCH: 0031 -> val_accuracy: 0.3588, class_label_loss: 0.0372, domain_label_loss: 0.9141
--------------------Model saved--------------------



100%|██████████| 88/88 [00:35<00:00,  2.48it/s]
100%|██████████| 32/32 [00:08<00:00,  4.00it/s]


 EPOCH: 0032 -> val_accuracy: 0.3413, class_label_loss: 0.0365, domain_label_loss: 0.8550



100%|██████████| 88/88 [00:37<00:00,  2.35it/s]
100%|██████████| 32/32 [00:05<00:00,  5.41it/s]



 EPOCH: 0033 -> val_accuracy: 0.3612, class_label_loss: 0.0474, domain_label_loss: 0.8582
--------------------Model saved--------------------


100%|██████████| 88/88 [00:37<00:00,  2.34it/s]
100%|██████████| 32/32 [00:06<00:00,  4.70it/s]


 EPOCH: 0034 -> val_accuracy: 0.3472, class_label_loss: 0.0323, domain_label_loss: 0.8646



100%|██████████| 88/88 [00:35<00:00,  2.51it/s]
100%|██████████| 32/32 [00:08<00:00,  3.71it/s]


 EPOCH: 0035 -> val_accuracy: 0.3714, class_label_loss: 0.0446, domain_label_loss: 0.8859
--------------------Model saved--------------------



100%|██████████| 88/88 [00:37<00:00,  2.36it/s]
100%|██████████| 32/32 [00:05<00:00,  5.53it/s]


 EPOCH: 0036 -> val_accuracy: 0.3675, class_label_loss: 0.0597, domain_label_loss: 0.9473



100%|██████████| 88/88 [00:37<00:00,  2.37it/s]
100%|██████████| 32/32 [00:06<00:00,  5.23it/s]


 EPOCH: 0037 -> val_accuracy: 0.3549, class_label_loss: 0.0338, domain_label_loss: 0.9130



100%|██████████| 88/88 [00:36<00:00,  2.43it/s]
100%|██████████| 32/32 [00:09<00:00,  3.49it/s]


 EPOCH: 0038 -> val_accuracy: 0.3502, class_label_loss: 0.0394, domain_label_loss: 0.9499



100%|██████████| 88/88 [00:36<00:00,  2.43it/s]
100%|██████████| 32/32 [00:06<00:00,  5.17it/s]


 EPOCH: 0039 -> val_accuracy: 0.3578, class_label_loss: 0.0489, domain_label_loss: 0.9291



100%|██████████| 88/88 [00:37<00:00,  2.36it/s]
100%|██████████| 32/32 [00:05<00:00,  5.45it/s]


 EPOCH: 0040 -> val_accuracy: 0.3492, class_label_loss: 0.0356, domain_label_loss: 0.8930



100%|██████████| 88/88 [00:36<00:00,  2.43it/s]
100%|██████████| 32/32 [00:09<00:00,  3.53it/s]


 EPOCH: 0041 -> val_accuracy: 0.3228, class_label_loss: 0.0496, domain_label_loss: 0.9133



100%|██████████| 88/88 [00:37<00:00,  2.37it/s]
100%|██████████| 32/32 [00:05<00:00,  5.44it/s]


 EPOCH: 0042 -> val_accuracy: 0.3630, class_label_loss: 0.0265, domain_label_loss: 0.8913



100%|██████████| 88/88 [00:37<00:00,  2.38it/s]
100%|██████████| 32/32 [00:06<00:00,  4.95it/s]


 EPOCH: 0043 -> val_accuracy: 0.3414, class_label_loss: 0.0223, domain_label_loss: 0.8913



100%|██████████| 88/88 [00:35<00:00,  2.48it/s]
100%|██████████| 32/32 [00:08<00:00,  3.67it/s]


 EPOCH: 0044 -> val_accuracy: 0.3552, class_label_loss: 0.0289, domain_label_loss: 0.8452



100%|██████████| 88/88 [00:36<00:00,  2.38it/s]
100%|██████████| 32/32 [00:05<00:00,  5.57it/s]


 EPOCH: 0045 -> val_accuracy: 0.3496, class_label_loss: 0.0275, domain_label_loss: 0.8223



100%|██████████| 88/88 [00:37<00:00,  2.34it/s]
100%|██████████| 32/32 [00:06<00:00,  4.98it/s]


 EPOCH: 0046 -> val_accuracy: 0.3539, class_label_loss: 0.0270, domain_label_loss: 0.8635



100%|██████████| 88/88 [00:35<00:00,  2.45it/s]
100%|██████████| 32/32 [00:08<00:00,  3.66it/s]


 EPOCH: 0047 -> val_accuracy: 0.3610, class_label_loss: 0.0335, domain_label_loss: 0.8421



100%|██████████| 88/88 [00:38<00:00,  2.31it/s]
100%|██████████| 32/32 [00:05<00:00,  5.39it/s]


 EPOCH: 0048 -> val_accuracy: 0.3540, class_label_loss: 0.0210, domain_label_loss: 0.7991



100%|██████████| 88/88 [00:39<00:00,  2.22it/s]
100%|██████████| 32/32 [00:09<00:00,  3.46it/s]


 EPOCH: 0049 -> val_accuracy: 0.3460, class_label_loss: 0.0299, domain_label_loss: 0.8488



100%|██████████| 88/88 [00:37<00:00,  2.37it/s]
100%|██████████| 32/32 [00:06<00:00,  4.92it/s]


 EPOCH: 0050 -> val_accuracy: 0.3524, class_label_loss: 0.0165, domain_label_loss: 0.8270



100%|██████████| 88/88 [00:38<00:00,  2.31it/s]
100%|██████████| 32/32 [00:05<00:00,  5.35it/s]


 EPOCH: 0051 -> val_accuracy: 0.3468, class_label_loss: 0.0134, domain_label_loss: 0.8244



100%|██████████| 88/88 [00:36<00:00,  2.41it/s]
100%|██████████| 32/32 [00:09<00:00,  3.47it/s]


 EPOCH: 0052 -> val_accuracy: 0.3374, class_label_loss: 0.0187, domain_label_loss: 0.8008



100%|██████████| 88/88 [00:37<00:00,  2.34it/s]
100%|██████████| 32/32 [00:05<00:00,  5.42it/s]


 EPOCH: 0053 -> val_accuracy: 0.3465, class_label_loss: 0.0227, domain_label_loss: 0.8503



100%|██████████| 88/88 [00:37<00:00,  2.34it/s]
100%|██████████| 32/32 [00:06<00:00,  4.91it/s]


 EPOCH: 0054 -> val_accuracy: 0.3489, class_label_loss: 0.0445, domain_label_loss: 0.8124



100%|██████████| 88/88 [00:36<00:00,  2.43it/s]
100%|██████████| 32/32 [00:08<00:00,  3.67it/s]


 EPOCH: 0055 -> val_accuracy: 0.3279, class_label_loss: 0.0099, domain_label_loss: 0.8921



100%|██████████| 88/88 [00:37<00:00,  2.38it/s]
100%|██████████| 32/32 [00:05<00:00,  5.37it/s]


 EPOCH: 0056 -> val_accuracy: 0.3474, class_label_loss: 0.0206, domain_label_loss: 0.8423



100%|██████████| 88/88 [00:37<00:00,  2.35it/s]
100%|██████████| 32/32 [00:06<00:00,  5.24it/s]


 EPOCH: 0057 -> val_accuracy: 0.3515, class_label_loss: 0.0152, domain_label_loss: 0.9036



100%|██████████| 88/88 [00:36<00:00,  2.41it/s]
100%|██████████| 32/32 [00:08<00:00,  3.72it/s]


 EPOCH: 0058 -> val_accuracy: 0.3527, class_label_loss: 0.0157, domain_label_loss: 0.8707



100%|██████████| 88/88 [00:37<00:00,  2.36it/s]
100%|██████████| 32/32 [00:06<00:00,  5.30it/s]


 EPOCH: 0059 -> val_accuracy: 0.3431, class_label_loss: 0.0216, domain_label_loss: 0.9049



100%|██████████| 88/88 [00:39<00:00,  2.25it/s]
100%|██████████| 32/32 [00:07<00:00,  4.12it/s]


 EPOCH: 0060 -> val_accuracy: 0.3542, class_label_loss: 0.0158, domain_label_loss: 0.8995



100%|██████████| 88/88 [00:39<00:00,  2.23it/s]
100%|██████████| 32/32 [00:07<00:00,  4.03it/s]


 EPOCH: 0061 -> val_accuracy: 0.3526, class_label_loss: 0.0171, domain_label_loss: 0.9149



 23%|██▎       | 20/88 [00:09<00:32,  2.07it/s]