<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw2/p3_ModelA_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
from torch import optim
import torchvision.transforms as tr
from torch.utils.data import DataLoader

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#### Download zip file

In [None]:
!pip install gsutil
!gsutil cp /content/drive/MyDrive/NTU_DLCV/Hw2/hw2_data.zip /content/hw2_data.zip

In [None]:
!unzip /content/hw2_data.zip

#### Get cuda from GPU

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using: {device}")

Using: cuda


#### Construct Dataset

In [6]:
class MnistDataset(torch.utils.data.Dataset):
    def __init__(self, label_path: str, join_path: str, transform) -> None:
        self.transform = transform
        self.img_paths = []
        self.img_labels = []

        label_csv = pd.read_csv(label_path).values.tolist()

        for row in label_csv:
            self.img_paths.append(os.path.join(join_path, row[0]))
            self.img_labels.append(row[1])
        assert len(self.img_paths) == len(self.img_labels)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx) -> (torch.Tensor, int):
        img_path = self.img_paths[idx]
        img = Image.open(img_path).convert('RGB') # (28, 28, 3)
        img = self.transform(img)

        label = self.img_labels[idx]
        return img, label

In [7]:
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

source_train_ds = MnistDataset(
    label_path='/content/hw2_data/digits/mnistm/train.csv',
    join_path='/content/hw2_data/digits/mnistm/data',
    transform=tr.Compose([
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])
)

target_train_ds = MnistDataset(
    label_path='/content/hw2_data/digits/svhn/train.csv',
    join_path='/content/hw2_data/digits/svhn/data',
    transform=tr.Compose([
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])
)

target_val_ds = MnistDataset(
    label_path='/content/hw2_data/digits/svhn/val.csv',
    join_path='/content/hw2_data/digits/svhn/data',
    transform=tr.Compose([
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])
)

In [10]:
BATCH_SIZE = 1024

source_train_loader = DataLoader(source_train_ds, BATCH_SIZE, shuffle=True, num_workers=4)
target_train_loader = DataLoader(target_train_ds, BATCH_SIZE, shuffle=True, num_workers=4)
target_val_loader = DataLoader(target_val_ds, BATCH_SIZE, shuffle=False, num_workers=4)

#### Domain-Adversarial Training of Neural Networks (DANN)

In [11]:
class GRL(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
         output = grad_output.neg() * ctx.alpha
         return output, None

class FeatureExtractor(nn.Module):
    def __init__(self) -> None:
        super(FeatureExtractor, self).__init__()
        self.extractor = nn.Sequential(
            nn.Conv2d(3, 64, 5), # (64, 24, 24)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2), # (64, 11, 11)

            nn.Conv2d(64, 64, 5), # (64, 7, 7)
            nn.BatchNorm2d(64),
            # nn.Dropout2d(),
            nn.ReLU(),
            nn.MaxPool2d(3, 2), # (64, 3, 3)

            nn.Conv2d(64, 128, 3), # (128, 1, 1)
        )

    def forward(self, x):
        features = self.extractor(x)
        features = features.view(-1, 128)
        return features

class Classifier(nn.Module):
    def __init__(self) -> None:
        super(Classifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(128, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            # nn.Dropout2d(),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Linear(128, 10),
        )

    def forward(self, features):
        class_label = self.classifier(features)
        return class_label

class DomainClassifier(nn.Module):
    def __init__(self) -> None:
        super(DomainClassifier, self).__init__()
        self.domainclassifier = nn.Sequential(
            nn.Linear(128, 256),
            # nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Linear(256, 1)
        )

    def forward(self, features, alpha):
        reversed_input = GRL.apply(features, alpha)
        x = self.domainclassifier(reversed_input)
        return x

#### Train

In [12]:
def model_mode(mode: str, models: list) -> None:
    for model in models:
        if mode == 'train':
            model.train()
        elif mode == 'eval':
            model.eval()

In [13]:
EPOCHS = 200

# Earlystopping
patience = 200
counter = 0
best_accuracy = np.inf

# setting parameters
lr = 1e-1
momentum = 0.9

feature_extractor = FeatureExtractor()
feature_extractor.to(device)
classifier = Classifier()
classifier.to(device)
domain_classifier = DomainClassifier()
domain_classifier.to(device)

clf_loss_fn = nn.CrossEntropyLoss().to(device)
domainclf_loss_fn = nn.BCEWithLogitsLoss().to(device) # 'Binary' Cross Entropy With Logits Loss

optimizer = optim.SGD(
    list(feature_extractor.parameters()) + list(classifier.parameters()) + list(domain_classifier.parameters()),
    lr=lr,
    momentum=momentum,
)

# logs dictionary
logs = {'class_label_loss': [], 'domain_label_loss': [], 'val_accuracy': []}

In [17]:
for epoch in range(EPOCHS):
    start_steps = epoch * len(source_train_loader)
    total_steps = EPOCHS * len(target_train_loader)

    for (source_x, source_y), (target_x, target_y) in tqdm(zip(source_train_loader, target_train_loader), total=len(source_train_loader)):
        model_mode('train', [feature_extractor, classifier, domain_classifier])

        source_x, source_y = source_x.to(device, non_blocking=True), source_y.to(device, non_blocking=True)
        target_x, target_y = target_x.to(device, non_blocking=True), target_y.to(device, non_blocking=True)

        # scheduling
        p = start_steps / total_steps
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        optimizer.param_groups[0]['lr'] = lr / (1. + 10 * p) ** 0.75
        optimizer.zero_grad()

        # features extractor
        source_features = feature_extractor(source_x)
        target_features = feature_extractor(target_x)

        # classifier
        # compute class label loss
        class_label = classifier(source_features)
        class_label_loss = clf_loss_fn(class_label, source_y)

        # doamin classifier
        # source domain label = 0
        source_domain_label = domain_classifier(source_features, alpha).squeeze() # (batch_size, 1) -> (batch_size)
        domain_label_source_loss = domainclf_loss_fn(
            source_domain_label,
            torch.zeros(source_y.shape[0], dtype=torch.float, device=device),
        )
        # target domain label = 1
        target_domain_label = domain_classifier(target_features, alpha).squeeze()
        domain_label_target_loss = domainclf_loss_fn(
            target_domain_label,
            torch.ones(target_y.shape[0], dtype=torch.float, device=device),
        )
        # compute domain total loss
        domain_label_loss = domain_label_source_loss + domain_label_target_loss

        total_loss = class_label_loss + domain_label_loss
        total_loss.backward()
        optimizer.step()

        # write logs
        logs['class_label_loss'].append(class_label_loss)
        logs['domain_label_loss'].append(domain_label_loss)


    model_mode('eval', [feature_extractor, classifier, domain_classifier])
    val_accuracy = 0
    with torch.no_grad():
        for target_val_x, target_val_y in tqdm(target_val_loader):
            target_val_x, target_val_y = target_val_x.to(device), target_val_y.cpu().numpy()

            features = feature_extractor(target_val_x)
            class_label = classifier(features)
            class_label = class_label.argmax(-1).cpu().numpy()

            val_accuracy += np.mean((class_label == target_val_y).astype(int))

    val_accuracy /= len(target_val_loader)

    # write logs
    logs['val_accuracy'].append(val_accuracy)

    print(f'EPOCH: {(epoch+1):04d} -> val_accuracy: {val_accuracy:.4f}, class_label_loss: {class_label_loss:.4f}, domain_label_loss: {domain_label_loss:.4f}')

    # chcek improvement
    if val_accuracy >= best_accuracy:
        counter = 0
        best_accuracy = val_accuracy

        torch.save(feature_extractor.state_dict(), f"/content/drive/MyDrive/NTU_DLCV/Hw2/p3_ckpt/{epoch+1}feature_extractor.pth")
        torch.save(classifier.state_dict(), f"/content/drive/MyDrive/NTU_DLCV/Hw2/p3_ckpt/{epoch+1}classifier.pth")
        torch.save(domain_classifier.state_dict(), f"/content/drive/MyDrive/NTU_DLCV/Hw2/p3_ckpt/{epoch+1}domain_classifier.pth")
        print("Model saved!")
    else:
        counter += 1
    if counter >= patience:
        print("Earlystop!")
        break

100%|██████████| 44/44 [00:43<00:00,  1.01it/s]
100%|██████████| 16/16 [00:09<00:00,  1.77it/s]

EPOCH: 0001 -> val_accuracy: 0.3182, class_label_loss: 0.1832, domain_label_loss: 0.6574



100%|██████████| 44/44 [00:38<00:00,  1.13it/s]
100%|██████████| 16/16 [00:05<00:00,  2.85it/s]

EPOCH: 0002 -> val_accuracy: 0.3791, class_label_loss: 0.1469, domain_label_loss: 1.0073



100%|██████████| 44/44 [00:38<00:00,  1.14it/s]
100%|██████████| 16/16 [00:05<00:00,  2.84it/s]

EPOCH: 0003 -> val_accuracy: 0.4029, class_label_loss: 0.1352, domain_label_loss: 1.2543



100%|██████████| 44/44 [00:37<00:00,  1.17it/s]
100%|██████████| 16/16 [00:08<00:00,  1.85it/s]

EPOCH: 0004 -> val_accuracy: 0.4074, class_label_loss: 0.1023, domain_label_loss: 1.5943



100%|██████████| 44/44 [00:36<00:00,  1.21it/s]
100%|██████████| 16/16 [00:07<00:00,  2.05it/s]

EPOCH: 0005 -> val_accuracy: 0.2978, class_label_loss: 0.0900, domain_label_loss: 1.4143



100%|██████████| 44/44 [00:37<00:00,  1.16it/s]
100%|██████████| 16/16 [00:05<00:00,  2.97it/s]

EPOCH: 0006 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:37<00:00,  1.16it/s]
100%|██████████| 16/16 [00:05<00:00,  2.82it/s]

EPOCH: 0007 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:38<00:00,  1.15it/s]
100%|██████████| 16/16 [00:07<00:00,  2.25it/s]

EPOCH: 0008 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:36<00:00,  1.21it/s]
100%|██████████| 16/16 [00:09<00:00,  1.75it/s]

EPOCH: 0009 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:35<00:00,  1.23it/s]
100%|██████████| 16/16 [00:06<00:00,  2.50it/s]

EPOCH: 0010 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:38<00:00,  1.16it/s]
100%|██████████| 16/16 [00:05<00:00,  2.85it/s]

EPOCH: 0011 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:37<00:00,  1.16it/s]
100%|██████████| 16/16 [00:05<00:00,  2.72it/s]

EPOCH: 0012 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:36<00:00,  1.20it/s]
100%|██████████| 16/16 [00:08<00:00,  1.79it/s]

EPOCH: 0013 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:35<00:00,  1.25it/s]
100%|██████████| 16/16 [00:07<00:00,  2.11it/s]

EPOCH: 0014 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:37<00:00,  1.19it/s]
100%|██████████| 16/16 [00:05<00:00,  2.86it/s]

EPOCH: 0015 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:38<00:00,  1.16it/s]
100%|██████████| 16/16 [00:05<00:00,  2.94it/s]

EPOCH: 0016 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:37<00:00,  1.16it/s]
100%|██████████| 16/16 [00:07<00:00,  2.05it/s]

EPOCH: 0017 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:35<00:00,  1.24it/s]
100%|██████████| 16/16 [00:08<00:00,  1.91it/s]

EPOCH: 0018 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



100%|██████████| 44/44 [00:36<00:00,  1.21it/s]
100%|██████████| 16/16 [00:06<00:00,  2.54it/s]

EPOCH: 0019 -> val_accuracy: 0.0673, class_label_loss: nan, domain_label_loss: nan



 82%|████████▏ | 36/44 [00:34<00:07,  1.04it/s]


KeyboardInterrupt: ignored