<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw2/p3_ModelA_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import os
import torch
import pandas as pd
from PIL import Image
import torch.nn as nn
from torch import optim
import torchvision.transforms as tr
from torch.utils.data import DataLoader

In [2]:
# Mount Google Drive at /content/drive (Colab-only API); the copy cell
# below reads the dataset zip from under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#### Copy the dataset zip file from Google Drive

In [None]:
# Install gsutil, then use it to copy hw2_data.zip from the mounted Drive
# folder to /content. Both paths are local filesystem paths.
!pip install gsutil
!gsutil cp /content/drive/MyDrive/NTU_DLCV/Hw2/hw2_data.zip /content/hw2_data.zip

In [None]:
# Extract the archive into the current working directory.
# NOTE(review): later cells read from /content/hw2_data/..., so this presumably
# runs with /content as the working directory -- confirm.
!unzip /content/hw2_data.zip

#### Select the compute device (CUDA GPU if available, otherwise CPU)

In [None]:
# Pick the device string used by the `.to(device)` calls further down:
# "cuda" when PyTorch can see a GPU, "cpu" otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using: {device}")

#### Construct Dataset

In [14]:
class MnistDataset(torch.utils.data.Dataset):
    """Image-classification dataset backed by a label CSV and an image folder.

    The CSV is read with ``pandas.read_csv`` (its first line is treated as the
    header). For every remaining row, column 0 is used as the image file name
    and column 1 as the label.

    Args:
        label_path: Path to the label CSV.
        join_path: Directory that each file name from the CSV is joined onto.
        transform: Callable applied to the RGB ``PIL.Image`` in ``__getitem__``.
    """

    def __init__(self, label_path: str, join_path: str, transform) -> None:
        self.transform = transform
        self.img_paths = []
        self.img_labels = []

        # ``.values.tolist()`` yields one ``[file_name, label, ...]`` list per
        # CSV row. Iterate over the rows themselves: looping over the cells of
        # a row and indexing into each cell would pick single characters out
        # of the file name and raise TypeError on the integer label.
        for row in pd.read_csv(label_path).values.tolist():
            self.img_paths.append(os.path.join(join_path, str(row[0])))
            self.img_labels.append(row[1])
        assert len(self.img_paths) == len(self.img_labels)

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.img_paths)

    def __getitem__(self, idx) -> (torch.Tensor, int):
        """Load image ``idx`` as RGB, apply ``transform`` and return it with its label."""
        img_path = self.img_paths[idx]
        # ``convert`` returns a fully loaded copy, so the underlying file can
        # be closed right away instead of being left open per sample.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        img = self.transform(img)

        label = self.img_labels[idx]
        return img, label

In [None]:
# Per-channel statistics passed to tr.Normalize for every dataset below.
# NOTE(review): these look like the commonly used ImageNet statistics --
# confirm they are the intended values for the digit images.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def _build_transform():
    """Return a fresh ToTensor -> Normalize(mean, std) pipeline."""
    return tr.Compose([
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ])


# Source domain: MNIST-M training split.
source_train_ds = MnistDataset(
    label_path='/content/hw2_data/digits/mnistm/train.csv',
    join_path='/content/hw2_data/digits/mnistm/data',
    transform=_build_transform(),
)

# Target domain: SVHN training split.
target_train_ds = MnistDataset(
    label_path='/content/hw2_data/digits/svhn/train.csv',
    join_path='/content/hw2_data/digits/svhn/data',
    transform=_build_transform(),
)

# Target domain: SVHN validation split (same image folder, different CSV).
target_val_ds = MnistDataset(
    label_path='/content/hw2_data/digits/svhn/val.csv',
    join_path='/content/hw2_data/digits/svhn/data',
    transform=_build_transform(),
)

In [None]:
BATCH_SIZE = 512

# The two training loaders reshuffle their data; the validation loader keeps
# the dataset order. All three use 4 worker processes.
source_train_loader = DataLoader(
    source_train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
target_train_loader = DataLoader(
    target_train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
target_val_loader = DataLoader(
    target_val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

#### Domain-Adversarial Training of Neural Networks (DANN)

In [1]:
class GRL(torch.autograd.Function):
    """Gradient reversal: identity in the forward pass, ``-alpha * grad`` in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scaling factor for the backward pass; the data itself
        # is passed through unchanged (as a view of the input).
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient and scale it by alpha.
        # ``alpha`` is a plain (non-tensor) input, so it gets no gradient.
        reversed_grad = -grad_output * ctx.alpha
        return reversed_grad, None

class FeatureExtractor(nn.Module):
    """Convolutional backbone mapping an RGB image batch to 128-d feature vectors.

    Three 3x3 convolutions (3 -> 64 -> 64 -> 128 channels), each followed by
    batch norm and ReLU, with 2x2 max pooling after the first two. The final
    feature map is globally average-pooled so the output is always
    ``(batch, 128)``, which is what the ``nn.Linear(128, ...)`` heads consume.
    """

    def __init__(self) -> None:
        super().__init__()
        self.extractor = nn.Sequential(
            nn.Conv2d(3, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        # Parameter-free global pooling, kept outside ``extractor`` so the
        # state_dict keys of the convolutional stack are unchanged.
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return features of shape ``(batch, 128)`` for ``x`` of shape ``(batch, 3, H, W)``."""
        features = self.extractor(x)
        # The conv stack only yields a 1x1 map for roughly 18-21 px inputs;
        # e.g. a 28x28 image leaves 3x3. Reshaping with ``view(-1, 128)``
        # would then fold the spatial positions into the batch axis, so the
        # batch would no longer line up with the labels. Pool to 1x1 first
        # (a no-op when the map is already 1x1), then flatten per sample.
        features = self.pool(features)
        return torch.flatten(features, 1)

class Classifier(nn.Module):
    """Label head: maps ``(batch, 128)`` features to 10 class logits.

    Layout: Linear(128, 512) -> BatchNorm1d -> ReLU -> Linear(512, 128)
    -> BatchNorm1d -> ReLU -> Linear(128, 10).
    """

    def __init__(self) -> None:
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(128, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 128),
            # The activations here are 2-D (batch, 128), so this must be
            # BatchNorm1d; BatchNorm2d rejects anything that is not 4-D.
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Linear(128, 10)
        )

    def forward(self, features):
        """Return unnormalised class scores of shape ``(batch, 10)``."""
        class_label = self.classifier(features)
        return class_label

class DomainClassifier(nn.Module):
    """Domain head: two-way classifier fed through the gradient reversal layer.

    Layout: Linear(128, 256) -> ReLU -> Linear(256, 2), with no normalisation
    layer in between.
    """

    def __init__(self) -> None:
        super().__init__()
        layers = [
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        ]
        self.domainclassifier = nn.Sequential(*layers)

    def forward(self, features, alpha):
        """Pass ``features`` through ``GRL`` (scaled by ``alpha``), then through the MLP."""
        return self.domainclassifier(GRL.apply(features, alpha))

#### Train

In [None]:
EPOCHS = 200

# Build each network once and keep the instance. ``.to`` and ``.parameters``
# are instance methods, so calling them on the class itself fails, and an
# instance that is created but never bound to a name is lost immediately.
feature_extractor = FeatureExtractor().to(device)
label_classifier = Classifier().to(device)
domain_classifier = DomainClassifier().to(device)

# Cross-entropy for both heads: 10-way digit labels and 2-way domain labels.
clf_loss_fn = nn.CrossEntropyLoss().to(device)
domainclf_loss_fn = nn.CrossEntropyLoss().to(device)

# A single SGD optimizer updates all three networks together.
optimizer = optim.SGD(
    list(feature_extractor.parameters())
    + list(label_classifier.parameters())
    + list(domain_classifier.parameters()),
    lr=1e-3,
    momentum=0.9,
)