<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw2/p3_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import os
import csv
import torch
import argparse
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
from torch import optim
import torchvision.transforms as tr
from torch.utils.data import DataLoader

#### Colab-only setup cells (delete before submission)

In [6]:
# Colab-only: mount Google Drive at /content/drive so files stored there can be read below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install gsutil, then use it to copy the dataset archive from the mounted Drive into /content.
# NOTE(review): both paths are local filesystem paths, so a plain `cp` would presumably do — confirm.
!pip install gsutil
!gsutil cp /content/drive/MyDrive/NTU_DLCV/Hw2/hw2_data.zip /content/hw2_data.zip

In [None]:
# Unpack the dataset archive into the current working directory.
!unzip /content/hw2_data.zip

#### Select device (CUDA GPU if available, otherwise CPU)

In [9]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using: {device}")

Using: cuda


In [37]:
# Script variant (kept for reference): both locations come from the command line.
#     parser = argparse.ArgumentParser()
#     parser.add_argument("path1")   # directory holding the test images
#     parser.add_argument("path2")   # directory that receives test_pred.csv
#     args = parser.parse_args()
#     testing_path = args.path1
#     out_path = os.path.join(args.path2, 'test_pred.csv')

# Notebook variant: fixed Colab locations.
testing_path = '/content/hw2_data/digits/usps/data'
out_path = os.path.join('/content/out', 'test_pred.csv')

In [None]:
!gdown https://drive.google.com/file/d/13wxXfHHKN01ab-iERBad07Xw3uqG_etQ/view?usp=share_link -O SVHN_304_classifier.pth
!gdown https://drive.google.com/file/d/1VGXAq7wswlgLfpZV30DoDKCvZtlNlTqA/view?usp=share_link -O SVHN_304feature_extractor.pth
!gdown https://drive.google.com/file/d/1pC_Q9JchlIvcFM_1XVrd31giLZlsDjHG/view?usp=share_link -O USPS_323classifier.pth
!gdown https://drive.google.com/file/d/1G8KDHrpenokrUfCTc3HmaoqevcxT1uUo/view?usp=share_link -O USPS_323feature_extractor.pth

#### Construct Dataset

In [31]:
class MnistDataset(torch.utils.data.Dataset):
    """Unlabelled image-folder dataset used for inference.

    Collects every ``.png`` file directly inside ``path`` (sorted by full path)
    and yields ``(transformed_image, file_name)`` pairs.
    """

    def __init__(self, path: str, transform) -> None:
        """
        Args:
            path: Directory that contains the ``.png`` images.
            transform: Callable applied to each image after it is converted to RGB.
        """
        self.transform = transform
        self.img_paths = sorted(
            os.path.join(path, name)
            for name in os.listdir(path)
            if name.endswith('.png')
        )

    def __len__(self) -> int:
        """Return the number of ``.png`` files found in the directory."""
        return len(self.img_paths)

    def __getitem__(self, idx) -> "tuple[torch.Tensor, str]":
        """Return ``(transformed_image, file_name)`` for the image at ``idx``.

        The second element is the bare file name (a string), not a class label.
        """
        img_path = self.img_paths[idx]
        # Force three channels regardless of the file's colour mode.
        img = Image.open(img_path).convert('RGB')
        img = self.transform(img)
        # basename (rather than splitting on '/') also handles OS-specific separators.
        return img, os.path.basename(img_path)

In [32]:
# Per-channel statistics used by Normalize after ToTensor().
# NOTE(review): these look like the usual ImageNet values; they presumably have to match
# the normalisation used when the checkpoints were trained — confirm.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# PIL image -> tensor -> channel-wise normalisation.
preprocess_steps = [
    tr.ToTensor(),
    tr.Normalize(mean=mean, std=std),
]
testing_ds = MnistDataset(path=testing_path, transform=tr.Compose(preprocess_steps))

In [33]:
BATCH_SIZE = 128

# No shuffling: batches follow the dataset's sorted file order.
testing_loader = DataLoader(
    dataset=testing_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)



#### Domain-Adversarial Training of Neural Networks (DANN)

In [34]:
class FeatureExtractor(nn.Module):
    """Convolutional backbone that maps an image batch to 128-dimensional features.

    The shape comments below assume 3x28x28 inputs.
    """

    def __init__(self) -> None:
        super().__init__()
        # The layer order fixes the Sequential indices, and with them the checkpoint keys.
        layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5),  # -> (64, 24, 24)
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2),  # -> (64, 11, 11)

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5),  # -> (64, 7, 7)
            nn.BatchNorm2d(num_features=64),
            nn.Dropout2d(),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),  # -> (64, 3, 3)

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),  # -> (128, 1, 1)
        ]
        self.extractor = nn.Sequential(*layers)

    def forward(self, x):
        """Run the backbone and reshape its output into rows of 128 features."""
        return self.extractor(x).view(-1, 128)

class Classifier(nn.Module):
    """Fully connected head that maps 128-d features to 10 class logits."""

    def __init__(self) -> None:
        super(Classifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(128, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            # Element-wise dropout: the activations here are 2-D (batch, features), and
            # nn.Dropout2d is deprecated for 2-D input. Dropout has no parameters, so the
            # Sequential indices and checkpoint keys are unchanged.
            nn.Dropout(),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Linear(128, 10),
        )

    def forward(self, features):
        """Return un-normalised class scores of shape (batch, 10)."""
        class_label = self.classifier(features)
        return class_label

#### Testing

In [None]:
# Build both networks on the target device (extractor first, then classifier).
feature_extractor = FeatureExtractor().to(device)
classifier = Classifier().to(device)

# Choose the checkpoint pair from the dataset name that appears in the test path.
# The USPS weights are read from the mounted Drive rather than the working directory.
if 'svhn' in testing_path:
    extractor_ckpt = 'SVHN_304feature_extractor.pth'
    classifier_ckpt = 'SVHN_304_classifier.pth'
else:
    extractor_ckpt = '/content/drive/MyDrive/NTU_DLCV/Hw2/p3_ckpt_USPS/USPS_323feature_extractor.pth'
    classifier_ckpt = '/content/drive/MyDrive/NTU_DLCV/Hw2/p3_ckpt_USPS/USPS_323classifier.pth'

# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
feature_extractor.load_state_dict(torch.load(extractor_ckpt, map_location=device))
classifier.load_state_dict(torch.load(classifier_ckpt, map_location=device))

# Switch both modules to evaluation mode before inference.
feature_extractor.eval()
classifier.eval()

In [40]:
class_labels = []
ids = []

# Predict one batch at a time; gradients are not needed for inference.
with torch.no_grad():
    for x, names in testing_loader:
        x = x.to(device)

        logits = classifier(feature_extractor(x))
        # The index of the largest logit is the predicted class.
        class_labels.extend(logits.argmax(-1).cpu().tolist())
        ids.extend(names)

# open() does not create missing parent directories, so make sure the target folder exists.
out_dir = os.path.dirname(out_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

# One header row, then one (image_name, label) row per image.
with open(out_path, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(('image_name', 'label'))
    writer.writerows(zip(ids, class_labels))