In [None]:
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms, models

import matplotlib.pyplot as plt

from models import get_model

from sklearn.svm import LinearSVC

from metrics import averageMeter

import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(3)

from sklearn.linear_model import LinearRegression

In [None]:
# Standard ImageNet normalisation statistics (per RGB channel).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic eval-time preprocessing: resize the shorter side to 256,
# center-crop 224x224, convert to a tensor and normalise.
# NOTE(review): no later cell in this notebook uses `transform`; the loaders are
# built with `prep.image_test(...)` instead — confirm whether this is still needed.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [None]:
def wrap_array(n_samples, n_bins):
    """Average a 1-D sequence over ``n_bins`` consecutive, (near-)equal-size bins.

    The plotting cells use it to smooth per-class statistics (sample counts,
    accuracies) that were sorted by class frequency. Each plotted point is then
    the mean of a contiguous group of classes.

    Args:
        n_samples: sequence of numbers (flattened to 1-D).
        n_bins: number of bins; must satisfy ``1 <= n_bins <= len(n_samples)``.

    Returns:
        float ndarray of shape ``(n_bins,)`` with the mean of each bin. When the
        length is a multiple of ``n_bins``, every bin has the same size, which
        matches the previous reshape-based version exactly. Otherwise the first
        ``len % n_bins`` bins get one extra element (``np.array_split``
        semantics); previously this case raised a reshape ValueError.

    Raises:
        ValueError: if ``n_bins`` is outside ``[1, len(n_samples)]``.
    """
    values = np.asarray(n_samples, dtype=float).ravel()
    n_bins = int(n_bins)
    if not 1 <= n_bins <= values.size:
        raise ValueError("n_bins must be in [1, %d], got %d" % (values.size, n_bins))
    return np.array([chunk.mean() for chunk in np.array_split(values, n_bins)])

## Load the model and the source/target data

In [None]:
# Config passed to get_model() for the feature extractor (ResNet-50, pretrained weights).
fe = dict(arch="resnet50", pretrained=True)
# Config passed to get_model() for the classifier head ("mlpcls", feat_size [2048, 256], 345 classes).
cls = dict(arch="mlpcls", nonlinear="none", feat_size=[2048, 256], n_class=345)

In [None]:
# Inspect GPU status; the import cell restricts this process to GPU "3" via CUDA_VISIBLE_DEVICES.
!nvidia-smi

In [None]:
# Build the feature extractor and the classifier head from the `fe` / `cls`
# configs (in that order) and move both to the GPU.
# NOTE(review): no later cell here uses model_fe / model_cls; evaluation runs on `model`.
model_fe, model_cls = [get_model(spec, verbose=False).cuda() for spec in (fe, cls)]

In [None]:
def percls_accuracy(all_pred, all_label, num_class=0):
    """Compute per-class accuracy in percent.

    Args:
        all_pred: sequence of predicted class indices.
        all_label: sequence of ground-truth class indices, aligned with ``all_pred``.
        num_class: number of classes. ``0`` (default) infers it as
            ``max(all_label) + 1``, so that slot ``i`` always corresponds to class
            ``i`` even when some class ids never occur.

    Returns:
        float ndarray of length ``num_class``. Entry ``i`` is the accuracy (0-100)
        on the samples whose label is ``i``, or ``-1`` when class ``i`` has no samples.

    Raises:
        ValueError: if ``all_pred`` and ``all_label`` have different shapes.
    """
    all_pred = np.asarray(all_pred)
    all_label = np.asarray(all_label)
    if all_pred.shape != all_label.shape:
        raise ValueError("all_pred and all_label must have the same shape, got %s and %s"
                         % (all_pred.shape, all_label.shape))

    if num_class == 0:
        # len(set(labels)) undercounts when some ids are missing (e.g. {0, 2} -> 2),
        # which silently dropped the highest class ids; use the largest id instead.
        num_class = max(int(all_label.max()) + 1, 0) if all_label.size else 0

    # -1 marks classes without any samples.
    cls_acc = -np.ones(num_class)
    for i in range(num_class):
        mask = all_label == i
        if mask.any():
            cls_acc[i] = (all_pred[mask] == i).mean() * 100.0

    return cls_acc

In [None]:
# Adaptation pair to evaluate (DomainNet domain names used in the checkpoint paths).
source, target = "sketch", "painting"

# CBS=True  -> MemSAC checkpoint, network built with batch_norm=True.
# CBS=False -> CDAN baseline checkpoint.
# NOTE(review): `fine_net` is not imported by this notebook's import cell;
# TODO import it from the project module that defines it.
CBS = True
if CBS:
    model = fine_net(345, batch_norm=True)
    ckpt_path = ("snapshot/domainNet_full_ablation/MemSAC_%s%s_QS_48000_BS_32_tau_0-07_lambda_0_CAS/best_model.pth.tar"
                 % (source, target))
else:
    model = fine_net(345)
    ckpt_path = ("snapshot/domainNet_full/CDAN/CDAN_%s%s_QS_48000_BS_32_tau_0-007_lambda_0/best_model.pth.tar"
                 % (source, target))

# Load onto the CPU first. Without map_location, tensors are restored onto the
# device they were saved from, which may not be visible under CUDA_VISIBLE_DEVICES.
pretrained = torch.load(ckpt_path, map_location="cpu")

# Drop a leading "module." (typically added when a model is wrapped in
# nn.DataParallel). The previous k.partition("module.")[-1] returned "" for any key
# WITHOUT the prefix, collapsing all such keys into one.
prefix = "module."
pretrained = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in pretrained.items()}

model.load_state_dict(pretrained, strict=True)
model = model.cuda().eval()

In [None]:
# Root directory passed to ImageList together with the list-file entries.
DATA_ROOT = "/newfoundland/tarun/datasets/Adaptation/visDA/"
nclasses = 345


def _make_test_loader(domain):
    """Build an ordered (shuffle=False) DataLoader over ``./data/visDA_full/<domain>_test.txt``.

    The list file's lines are passed to ``ImageList`` unchanged, using the
    ``prep.image_test(resize_size=256, crop_size=224)`` preprocessing. The list
    file is read inside a ``with`` block so its handle is closed; previously it
    was leaked.
    NOTE(review): `ImageList` and `prep` are not imported by this notebook's
    import cell — TODO add the project imports.
    """
    with open("./data/visDA_full/%s_test.txt" % domain) as list_file:
        image_list = list_file.readlines()
    dataset = ImageList(DATA_ROOT, image_list, transform=prep.image_test(resize_size=256, crop_size=224))
    return torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=16, drop_last=False)


dataset_loader_target = _make_test_loader(target)
dataset_loader_source = _make_test_loader(source)

# dataset_list = ImageList("/newfoundland/tarun/datasets/Adaptation/OfficeHome/Dataset10072016/", open("./data/officeHome/Product.txt").readlines(), transform=prep.image_test(resize_size=256, crop_size=224))
# dataset_loader_source = torch.utils.data.DataLoader(dataset_list, batch_size=64, shuffle=False, num_workers=16, drop_last=False)
# nclasses=65

# dataset_list = ImageList("/newfoundland/tarun/datasets/birds/", open("./data/cub200/cub200_2011.txt").readlines(), transform=prep.image_test(resize_size=256, crop_size=224))
# dataset_loader_source = torch.utils.data.DataLoader(dataset_list, batch_size=64, shuffle=False, num_workers=16, drop_last=False)
# nclasses=200

In [None]:
# Evaluate `model` on the target test split, collecting per-sample predictions
# and labels for the per-class analysis below.
# Correct/total are counted directly. The cell previously used `AverageMeter`, but
# the import cell only imports `averageMeter`, so it raised a NameError.
all_preds = []
all_labels = []
n_correct = 0
n_total = 0
n_batches = len(dataset_loader_target)
with torch.no_grad():
    for idx, (image, labels) in enumerate(dataset_loader_target):
        print("{}/{}".format(idx + 1, n_batches), end="\r")
        # Only the images go to the GPU; predictions come back for CPU-side comparison.
        predictions = model(image.cuda()).argmax(1).cpu()
        labels = labels.cpu()
        n_correct += int((predictions == labels).sum())
        n_total += len(labels)
        all_preds.extend(predictions.tolist())
        all_labels.extend(labels.tolist())

test_acc = n_correct / n_total if n_total else float("nan")
print_str = "\nCorrect Predictions: {}/{}".format(n_correct, n_total)
print_str1 = '\ntest_acc:{:.4f}'.format(test_acc)
print(print_str + print_str1)

In [None]:
# Map class id -> target test accuracy (%). Classes with no target test samples
# get percls_accuracy's -1 sentinel.
classwise_accuracy = dict(enumerate(percls_accuracy(all_preds, all_labels, 345)))

In [None]:
# Collect the ground-truth labels of the source test split; the images are ignored.
# NOTE(review): this still decodes every source image just to read labels —
# reading them from the list file would be much cheaper.
all_labels_source = []
n_source_batches = len(dataset_loader_source)
for batch_idx, (_, batch_labels) in enumerate(dataset_loader_source, start=1):
    print("{}/{}".format(batch_idx, n_source_batches), end="\r")
    all_labels_source += batch_labels.numpy().tolist()

In [None]:
# {class_id: number of samples} for the target and source test splits
# (only classes that actually occur in each split appear as keys).
n_samples_per_class_target, n_samples_per_class_source = (
    dict(zip(*np.unique(split_labels, return_counts=True)))
    for split_labels in (all_labels, all_labels_source)
)

## Target per-class accuracy vs. class frequency in the source split

In [None]:
# Rank classes by their number of *source* test samples (most frequent first),
# then average both the counts and the target per-class accuracies over 69
# consecutive bins of that ranking.
# NOTE(review): classwise_accuracy is -1 for classes with no target test samples
# (percls_accuracy sentinel); any such class is averaged into its bin as -1 here.
# TODO decide whether those classes should be excluded.
sorted_sample = sorted(
    n_samples_per_class_source.items(),
    key=lambda class_count: class_count[1],
    reverse=True,
)
keyset = [class_id for class_id, _ in sorted_sample]
n_samples = wrap_array([count for _, count in sorted_sample], n_bins=69)
accuracy = wrap_array([classwise_accuracy[class_id] for class_id in keyset], n_bins=69)

# R^2 of a linear fit of bin accuracy on bin sample count.
n_samples_col = np.asarray(n_samples).reshape(-1, 1)
reg = LinearRegression().fit(n_samples_col, accuracy)
print(reg.score(n_samples_col, accuracy))

In [None]:
# Class-frequency profile of the source split: one point per bin of classes
# ranked by source sample count (the binned `n_samples` from the cell above).
# The x position is the rank of the first class in each bin. The bin width is
# derived from the class count instead of a hard-coded 5.
bin_width = len(n_samples_per_class_source) // len(n_samples)
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(np.arange(len(n_samples)) * bin_width, n_samples)

ax.tick_params(axis="both", labelsize=16)
ax.set_xlabel("Class Id", fontsize=20)
ax.set_ylabel("#samples in class", fontsize=20)
plt.show()

In [None]:
# Binned target accuracy against binned source sample count.
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(n_samples, accuracy)
# Reverse x so the most frequent classes are on the left.
ax.invert_xaxis()

ax.tick_params(axis="both", labelsize=16)
ax.set_xlabel("#samples in class", fontsize=20)
ax.set_ylabel("Class Accuracy", fontsize=20)
ax.set_ylim(0, 60)

plt.show()

In [None]:
# Display the per-bin mean target accuracies from the source-ranked binning cell.
accuracy

## Target per-class accuracy vs. class frequency in the target split

In [None]:
# Rank classes by their number of *target* test samples (most frequent first),
# then average both the counts and the target per-class accuracies over 43
# consecutive bins of that ranking.
# NOTE(review): 43 bins divides 344 classes evenly (8 per bin) but not 345;
# presumably chosen for the number of classes present in the target split — confirm.
sorted_sample = sorted(
    n_samples_per_class_target.items(),
    key=lambda class_count: class_count[1],
    reverse=True,
)
keyset = [class_id for class_id, _ in sorted_sample]
n_samples = wrap_array([count for _, count in sorted_sample], n_bins=43)
accuracy = wrap_array([classwise_accuracy[class_id] for class_id in keyset], n_bins=43)

# R^2 of a linear fit of bin accuracy on bin sample count.
n_samples_col = np.asarray(n_samples).reshape(-1, 1)
reg = LinearRegression().fit(n_samples_col, accuracy)
print(reg.score(n_samples_col, accuracy))

In [None]:
# Class-frequency profile of the target split: one point per bin of classes
# ranked by target sample count (the binned `n_samples` from the cell above).
# The x position is the rank of the first class in each bin. The bin width is
# derived from the class count. It was previously a hard-coded 5, copied from the
# source plot, even though this cell's data is binned with n_bins=43.
bin_width = len(n_samples_per_class_target) // len(n_samples)
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(np.arange(len(n_samples)) * bin_width, n_samples)

ax.tick_params(axis="both", labelsize=16)
ax.set_xlabel("Class Id", fontsize=20)
ax.set_ylabel("#samples in class", fontsize=20)
plt.show()

In [None]:
# Binned target accuracy against binned target sample count.
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(n_samples, accuracy)
# Reverse x so the most frequent classes are on the left.
ax.invert_xaxis()

ax.tick_params(axis="both", labelsize=16)
ax.set_xlabel("#samples in class", fontsize=20)
ax.set_ylabel("Class Accuracy", fontsize=20)

ax.set_ylim(0, 54)
plt.show()