In [1]:
import argparse
import datetime
import pathlib
import json
import yaml
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from tqdm import tqdm
from ml_collections import ConfigDict

from core.dataset import COCODatasetWithID
from core.config import save_config
from core.model import Model
from core.metrics import AccuracyLogger, IndividualScoreLogger
import torchvision
import pickle
import timm
import random
import matplotlib.pyplot as plt

In [2]:
import yaml
from torchvision import models

In [3]:
# out_dir = 'v3_output/non_scannet_paired_augmented/'
# model_name = out_dir.split('/')[-2].split('_')[0]

# config_file = os.path.join(out_dir, 'config.yaml')

# config_file

# os.listdir(out_dir)

# checkpoint_file = os.path.join(out_dir, 'checkpoint_5.tar')

# with open(config_file) as f:
#     cfg = ConfigDict(yaml.load(f, Loader=yaml.Loader))

# checkpoint = torch.load(checkpoint_file, map_location="cpu")

# NUM_CLASSES=13

# cfg.pretrained = False

# if model_name == 'resnet':
#     # model_ft = models.resnet18(pretrained=False, num_classes = NUM_CLASSES)
#     model = models.resnet18(pretrained=False)
#     model.fc = nn.Linear(512, NUM_CLASSES)
#     simple_model = True
# elif model_name == 'squeezenet':
#     # model_ft = models.squeezenet1_1(pretrained=False, num_classes = NUM_CLASSES)
#     model = models.squeezenet1_1(pretrained=False)
#     model.classifier[1] = nn.Conv2d(512, NUM_CLASSES, kernel_size=(1,1), stride=(1,1))
#     simple_model = True
# elif model_name == 'densenet':
#     # model_ft = models.densenet121(pretrained=False, num_classes = NUM_CLASSES)
#     model = models.densenet121(pretrained=False)
#     model.classifier = nn.Linear(1024, NUM_CLASSES)
#     simple_model = True
# elif model_name == 'mobilenet':
#     # model_ft = models.mobilenet_v2(pretrained=False, num_classes = NUM_CLASSES)
#     model = models.mobilenet_v2(pretrained=False)
#     model.classifier[1] = torch.nn.Linear(in_features=model.classifier[1].in_features, out_features=NUM_CLASSES)
#     simple_model = True
# elif model_name == 'vit':
#     model = timm.create_model('vit_base_patch16_224', pretrained=False)
#     model.head = nn.Linear(768, NUM_CLASSES)
#     simple_model = True
# else:
#     model = Model.from_config(cfg)
#     simple_model = False
# missing_keys, unexpected_keys = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
# assert not missing_keys, "Checkpoint is missing keys required to initialize the model: {}".format(missing_keys)
# if len(unexpected_keys):
#     print("Checkpoint contains unexpected keys that were not used to initialize the model: ")
#     print(unexpected_keys)

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model.to(device);

# annotations_file = '../openrooms/annotation_files/scannet_handmade_test_set.json'

# imagedir = '../'

# testset = COCODatasetWithID(annotations_file, imagedir, (224,224), normalize_means=[0.485, 0.456, 0.406], normalize_stds=[0.229, 0.224, 0.225])
# dataloader = DataLoader(testset, batch_size=100, num_workers=1, shuffle=False, drop_last=False)

# class UnNormalize(object):
#     def __init__(self, mean, std):
#         self.mean = mean
#         self.std = std

#     def __call__(self, tensor):
#         """
#         Args:
#             tensor (Tensor): Normalized tensor image of size (C, H, W) to be un-normalized.
#         Returns:
#             Tensor: Un-normalized image (the input tensor, modified in place).
#         """
#         for t, m, s in zip(tensor, self.mean, self.std):
#             t.mul_(s).add_(m)
#             # The normalize code -> t.sub_(m).div_(s)
#         return tensor

# unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

# %matplotlib inline

# testset.idx2label

# def morph_labels(vec):
#     for i in range(len(vec)):
#         if vec[i].item() == 1:
#             vec[i] = 0
#         if vec[i].item() == 6:
#             vec[i] = 3
#     return vec

# from tqdm import tqdm

# corrects = 0
# totals = 0
# model.eval() # set eval mode
# all_matches = []

# with torch.no_grad():
#     for i, (context_images, target_images, bbox, labels_cpu, annotation_ids) in enumerate(tqdm(dataloader, desc="Test Batches", leave=True)):
#         black_images = 0
#         context_images = context_images.to(device)
#         target_images = target_images.to(device)
#         bbox = bbox.to(device)
#         labels = labels_cpu.to(device) # keep a copy of labels on cpu to avoid unnecessary transfer back to cpu later
#         if simple_model:
#             output = model(target_images)
#         else:
#             output = model(context_images, target_images, bbox) 
#         _, predictions = torch.max(output.detach(), 1) # choose idx with maximum score as prediction
# #         morphed_labels = morph_labels(labels)
# #         morphed_predictions = morph_labels(predictions)
# #         match = morphed_predictions == morphed_labels
#         match = predictions == labels
#         all_matches.extend(match)
#         totals += len(match)
#         corrects += torch.sum(match).item()
#         for i in range(len(match)):
#             if not match[i].item():
#                 test = unorm(target_images[i])
#                 if torch.sum(test).item() < 1000:
#                     black_images += 1
#         totals = totals - black_images
# print(corrects/totals)

In [4]:
# Backbones and transformation types whose transfer runs are compared below.
# Each (architecture, transformation) pair names one results folder.
architectures = [
    'resnet',
    'vit',
    'mobilenet',
    'squeezenet',
    'densenet',
]
transformations = [
    'light',
    'viewpoint',
    'material',
]

In [5]:
import os
import json

In [6]:
# Build a nested table perfs[architecture][transformation] holding the
# 'total_accuracy' value read from each run's accuracy JSON, formatted as a
# two-decimal string.
perfs = {}
for arch in architectures:
    arch_scores = perfs[arch] = {}
    for trans in transformations:
        # One results folder per (architecture, transformation) pair.
        run_dir = 'v3_output/%s_%s_transfer_scratch' % (arch, trans)
        accuracy_path = os.path.join(run_dir, 'scannet_handmade_accuracies.json')
        with open(accuracy_path, 'r') as handle:
            total_accuracy = json.load(handle)['total_accuracy']
        arch_scores[trans] = "%0.02f" % total_accuracy

In [7]:
# Show the collected accuracy table (architecture -> transformation -> accuracy string).
perfs

{'resnet': {'light': '0.18', 'viewpoint': '0.17', 'material': '0.07'},
 'vit': {'light': '0.13', 'viewpoint': '0.08', 'material': '0.08'},
 'mobilenet': {'light': '0.22', 'viewpoint': '0.14', 'material': '0.22'},
 'squeezenet': {'light': '0.16', 'viewpoint': '0.15', 'material': '0.22'},
 'densenet': {'light': '0.25', 'viewpoint': '0.14', 'material': '0.23'}}