In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import torch.nn.functional as F

In [2]:
class AlexNet_OH_DOM(nn.Module):
    """AlexNet-style CNN used as the domain classifier.

    Five conv layers with ReLU (3x3/stride-2 max-pooling after conv1, conv2 and
    conv5), followed by three fully connected layers that return raw logits.

    Attribute names (conv1..conv5, fc1..fc3) and their registration order define
    the ``state_dict`` keys, so they must stay unchanged for saved checkpoints to
    keep loading via ``load_state_dict``.

    Args:
        num_classes: number of logits produced by ``fc3``. Defaults to 4 (the
            original hard-coded value), so ``AlexNet_OH_DOM()`` is unchanged.

    Input: a (N, 3, H, W) float batch. ``fc1`` expects 6400 = 256 * 5 * 5 flattened
    features, which is what the conv/pool stack produces for 224x224 inputs.
    """

    def __init__(self, num_classes: int = 4):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=0)
        # One parameter-free pooling module, reused after conv1, conv2 and conv5.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)
        # Spatial size for a 224x224 input: 54 (conv1) -> 26 (pool) -> 26 (conv2)
        # -> 12 (pool) -> 12 (conv3-5) -> 5 (pool), hence 256 * 5 * 5 = 6400.
        self.fc1 = nn.Linear(in_features=6400, out_features=4096)
        self.fc2 = nn.Linear(in_features=4096, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised logits of shape (N, num_classes) for the batch ``x``."""
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.maxpool(F.relu(self.conv5(x)))
        # Keep the batch dimension and flatten the rest; unlike reshape(N, -1) this
        # also works for an empty batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
class AlexNet_OH_OBJ(nn.Module):
    """AlexNet-style CNN used as the object (category) classifier.

    Same architecture as the domain classifier, differing only in the width of
    the final layer. Attribute names (conv1..conv5, fc1..fc3) and their
    registration order define the ``state_dict`` keys, so they must stay
    unchanged for saved checkpoints to keep loading via ``load_state_dict``.

    Args:
        num_classes: number of logits produced by ``fc3``. Defaults to 65 (the
            original hard-coded value; presumably the OfficeHome categories), so
            ``AlexNet_OH_OBJ()`` is unchanged.

    Input: a (N, 3, H, W) float batch. ``fc1`` expects 6400 = 256 * 5 * 5 flattened
    features, which is what the conv/pool stack produces for 224x224 inputs.
    """

    def __init__(self, num_classes: int = 65):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=0)
        # One parameter-free pooling module, reused after conv1, conv2 and conv5.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)
        # Spatial size for a 224x224 input: 54 (conv1) -> 26 (pool) -> 26 (conv2)
        # -> 12 (pool) -> 12 (conv3-5) -> 5 (pool), hence 256 * 5 * 5 = 6400.
        self.fc1 = nn.Linear(in_features=6400, out_features=4096)
        self.fc2 = nn.Linear(in_features=4096, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised logits of shape (N, num_classes) for the batch ``x``."""
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.maxpool(F.relu(self.conv5(x)))
        # Keep the batch dimension and flatten the rest; unlike reshape(N, -1) this
        # also works for an empty batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    

In [3]:
# Domain classifier. Weights are mapped to CPU because the evaluation cell feeds the
# model CPU tensors (it never moves inputs or models to a GPU); without map_location a
# checkpoint saved from a CUDA model fails to load on a CPU-only machine.
# torch.load unpickles the file — only load checkpoints you trust.
# NOTE(review): the checkpoint name mentions "Resnet18" but the architecture is
# AlexNet_OH_DOM — confirm this is the intended file.
PACS_Domain = AlexNet_OH_DOM()
PACS_Domain.load_state_dict(torch.load("1a_Resnet18_dom.pth", map_location="cpu"))
PACS_Domain.eval()

AlexNet_OH_DOM(
  (conv1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=6400, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=4, bias=True)
)

In [4]:
# Object classifier. Weights are mapped to CPU because the evaluation cell feeds the
# model CPU tensors (it never moves inputs or models to a GPU); without map_location a
# checkpoint saved from a CUDA model fails to load on a CPU-only machine.
# torch.load unpickles the file — only load checkpoints you trust.
# NOTE(review): this object-model checkpoint is named "..._dom.pth" (and "Resnet18")
# although it is loaded into AlexNet_OH_OBJ — confirm it is the intended file.
# NOTE(review): the saved output of this cell shows fc3 with 7 outputs, while
# AlexNet_OH_OBJ defines 65 by default; that output looks stale — re-run to confirm.
PACS_Object = AlexNet_OH_OBJ()
PACS_Object.load_state_dict(torch.load("1b_Resnet18_dom.pth", map_location="cpu"))
PACS_Object.eval()

AlexNet_OH_OBJ(
  (conv1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=6400, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=7, bias=True)
)

In [3]:
# Preprocessing pipelines, keyed by split name.
# Both resize to 224x224: the AlexNet_OH_* models' fc1 expects 256 * 5 * 5 = 6400
# features, which is what their conv/pool stack yields for a 224x224 input.
# No normalisation step follows ToTensor.
# TODO(review): confirm this matches the preprocessing the checkpoints were trained with.
data_transforms = dict(
    # Augmented pipeline with random flips; not used by the cells shown here.
    train=transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
    ),
    # Deterministic pipeline used for the test datasets and the evaluation loop.
    validation=transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    ),
)


# Test datasets, both using the deterministic 'validation' transform. Each dataset's
# `class_to_idx` (sub-directory name -> label index) is the label map the evaluation
# cell uses for the corresponding classifier.
image_datasets = {
    split: datasets.ImageFolder(root, data_transforms['validation'])
    for split, root in [
        ('test_object', '/home/jlw2247/vondrick_2/OfficeHome_Test_Object'),
        ('test_domain', '/home/jlw2247/vondrick_2/OfficeHome_Test_Domain'),
    ]
}

# Batched, unshuffled loaders over the same datasets (not used by the cells shown here).
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
    )
    for split, dataset in image_datasets.items()
}



In [7]:
import glob
import os
import skimage.io as si
from PIL import Image
from tqdm import tqdm

# Joint evaluation over the raw OfficeHome tree, laid out as <root>/<domain>/<class>/<image>.
# Each image is classified by both models. `correct` counts images where the object AND
# the domain prediction are right; `correct_obj` and `correct_domain` count each alone.
# NOTE(review): this walks OfficeHomeDataset_10072016 (the full tree), while the label
# maps come from the OfficeHome_Test_* folders — confirm this is the intended split,
# otherwise images used for training may be scored as well.
OFFICEHOME_ROOT = "/home/jlw2247/vondrick_2/OfficeHomeDataset_10072016"


def _subdirs(path):
    """Return the sorted sub-directory paths of `path`; plain files are skipped."""
    return sorted(p for p in glob.glob(os.path.join(path, "*")) if os.path.isdir(p))


# Loop invariants: label maps (sub-directory name -> index) and the eval transform.
object_class_to_idx = image_datasets['test_object'].class_to_idx
domain_class_to_idx = image_datasets['test_domain'].class_to_idx
eval_transform = data_transforms['validation']

correct, total = 0, 0
correct_obj = 0
correct_domain = 0

with torch.no_grad():
    domains = [os.path.basename(d) for d in _subdirs(OFFICEHOME_ROOT)]
    for domain in domains:
        # The domain label is the same for every class folder inside this domain.
        domainLabel = domain_class_to_idx[domain]
        for className_ in tqdm(_subdirs(os.path.join(OFFICEHOME_ROOT, domain)), desc=domain):
            className = os.path.basename(className_)
            classLabel = object_class_to_idx[className]

            for image in sorted(glob.glob(os.path.join(className_, '*'))):
                # Force 3-channel RGB, as ImageFolder's default PIL loader does, so that
                # grayscale/palette/RGBA files match conv1's 3 input channels; the
                # context manager closes the underlying file handle.
                with Image.open(image) as pil_img:
                    img = eval_transform(pil_img.convert('RGB')).unsqueeze(0)

                # Batch of one: argmax over the class dimension gives the predicted index.
                predicted_objectLabel = PACS_Object(img).argmax(dim=1).item()
                predicted_domainLabel = PACS_Domain(img).argmax(dim=1).item()

                object_ok = predicted_objectLabel == classLabel
                domain_ok = predicted_domainLabel == domainLabel
                correct_obj += int(object_ok)
                correct_domain += int(domain_ok)
                correct += int(object_ok and domain_ok)
                total += 1

if total == 0:
    raise RuntimeError(f"No images found under {OFFICEHOME_ROOT!r}")

print(correct, total, correct / total)                 # joint: object AND domain correct
print(correct_domain, total, correct_domain / total)   # domain accuracy
print(correct_obj, total, correct_obj / total)         # object accuracy

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:45<00:00,  6.43s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:47<00:00,  6.82s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:34<00:00,  4.89s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:16<00:00, 10.97s/it]

1536 9991 0.15373836452807527
9518 9991 0.9526573916524872
1641 9991 0.16424782304073665



