In [17]:
import pandas as pd
from torchvision.datasets import ImageFolder
import re

# Root directory of the ImageNet-R copy; ImageFolder is pointed at it directly.
data_path = '/ix/akovashka/arr159/imagenet-r'
dataset = ImageFolder(root=data_path)

# File names are expected to end in "<attribute>_<digits>.jpg"; group 1 captures
# everything in the file name before the final "_<digits>.jpg".
pattern = re.compile(r"([^/]+)_(\d+)\.jpg$")

# Build all four columns in a single pass over the samples. The loop variable is
# deliberately not called `path`, so it cannot shadow the `os.path` alias that is
# imported under that name further down the notebook.
img_ids = []
attributes = []
gt_codes = []
ground_truth_classes = []
for img_path, label in dataset.samples:
    match = pattern.search(img_path)
    if match is None:
        # Fail loudly and say which file broke the naming assumption.
        raise ValueError(
            f"cannot extract attribute from image path: {img_path!r}"
        )
    parts = img_path.split('/')
    img_ids.append('/'.join(parts[-2:]))   # "<class dir>/<file name>"
    attributes.append(match.group(1))
    gt_codes.append(parts[-2])             # name of the containing directory
    ground_truth_classes.append(label)     # second element of each sample tuple

df = pd.DataFrame({
    'img_id': img_ids,
    'attribute': attributes,
    'gt_code': gt_codes,
    'gt': ground_truth_classes
})

In [7]:
# List the distinct attribute tags that were parsed from the image file names.
df['attribute'].unique()

array(['art', 'cartoon', 'deviantart', 'embroidery', 'graffiti',
       'graphic', 'misc', 'origami', 'painting', 'sculpture', 'sketch',
       'sticker', 'toy', 'videogame', 'tattoo'], dtype=object)

In [14]:
from os import path

# Read README.txt from the dataset root; the context manager closes the handle.
with open(path.join(data_path, 'README.txt')) as readme_file:
    mapping_raw = readme_file.readlines()

# Skip the first 13 lines (presumably the README's prose header -- confirm if the
# file changes). Each remaining line contributes its first two whitespace-separated
# tokens as a key -> value pair, e.g. 'n01443537' -> 'goldfish'.
mapping = {}
for line in mapping_raw[13:]:
    fields = line.split()
    if len(fields) >= 2:  # ignore blank or malformed lines instead of raising IndexError
        mapping[fields[0]] = fields[1]

In [19]:
# Save the image table as CSV. With to_csv's defaults the row index is written too,
# as the first (unnamed) column. The mock_data/ directory must already exist.
df.to_csv('mock_data/dataset.csv')

In [22]:
import json
# Persist the class-code -> class-name lookup as pretty-printed JSON.
with open('mock_data/mapping.json', mode='w') as mapping_file:
    json.dump(mapping, fp=mapping_file, indent=4)

In [31]:
import random

def corruption_fn(df, condition_fn, corruption_matrix, corruption_prob, rng=None):
    """Produce one prediction per row, randomly replacing some of them.

    A row is eligible for corruption when ``condition_fn(row)`` is truthy. An
    eligible row is corrupted with probability ``corruption_prob``; its new
    class index is then drawn from ``range(len(weights))`` where
    ``weights = corruption_matrix[row['gt']]``. Every other row keeps its
    current ``row['pred']`` value.

    Args:
        df: DataFrame providing the ``'gt'`` and ``'pred'`` columns, plus
            whatever columns ``condition_fn`` reads.
        condition_fn: Callable taking a row and returning a truthy value
            when the row may be corrupted.
        corruption_matrix: Indexable by a row's ``'gt'`` value; each entry is
            a sequence of sampling weights, one per candidate class index.
        corruption_prob: Probability that an eligible row is corrupted.
        rng: Optional random source exposing ``random()`` and ``choices()``
            (for example ``random.Random(seed)``). Defaults to the global
            ``random`` module, which reproduces the previous behaviour.

    Returns:
        A list with one prediction per row of ``df``, in row order.
    """
    source = random if rng is None else rng
    predictions = []
    for _, row in df.iterrows():
        # Short-circuit keeps the draw count unchanged: a random number is
        # consumed only for rows that satisfy the condition.
        if condition_fn(row) and source.random() < corruption_prob:
            weights = corruption_matrix[row['gt']]
            sampled_class = source.choices(range(len(weights)), weights=weights, k=1)[0]
            predictions.append(sampled_class)
        else:
            predictions.append(row['pred'])  # keep it the same
    return predictions

In [35]:
import numpy as np
# Split 0: every image is eligible, and each label is replaced with probability 0.25.
random.seed(0)  # fixed seed so re-running this cell regenerates the same file
condition = lambda row: True
df['pred'] = df['gt']  # start from perfect predictions
# Row i holds the sampling weights used when an image of class i is corrupted:
# uniform over all classes, with a zero diagonal so the true class is never drawn.
corruption_matrix = np.ones((200,200))
np.fill_diagonal(corruption_matrix, 0)

predictions = corruption_fn(df, condition, corruption_matrix, corruption_prob=0.25)
# One predicted class index per line, in DataFrame row order.
with open('mock_data/pred_splits/split_0.txt', 'w') as f:
    f.writelines(f'{pred}\n' for pred in predictions)

In [44]:
# Split 1: sketches of geese are corrupted with probability 0.6, all other images
# with probability 0.25.
random.seed(1)  # fixed seed so re-running this cell regenerates the same predictions


def is_sketch_goose(row):
    """Return True when the row's attribute is 'sketch' and its class name is 'goose'."""
    return row['attribute']=='sketch' and mapping[row['gt_code']] == 'goose'


df['pred'] = df['gt']  # start from perfect predictions
# Uniform weights over all classes with a zero diagonal, so the true class is never
# drawn. corruption_fn only reads the matrix, so one instance serves both passes.
corruption_matrix = np.ones((200,200))
np.fill_diagonal(corruption_matrix, 0)

# Pass 1: only the sketch/goose rows are eligible.
condition = is_sketch_goose
predictions = corruption_fn(df, condition, corruption_matrix, corruption_prob=0.6)
df['pred'] = predictions  # pass 2 starts from the pass-1 output

# Pass 2: only the remaining rows are eligible; sketch/goose rows keep their pass-1 value.
# Note that `predictions` holds the final result, while df['pred'] still holds pass 1.
condition = lambda row: not is_sketch_goose(row)
predictions = corruption_fn(df, condition, corruption_matrix, corruption_prob=0.25)

In [45]:
# Share of images whose final mock prediction differs from the ground-truth label.
np.mean(np.asarray(predictions) != df['gt'].to_numpy())

0.25016666666666665

In [46]:

# Write the final split-1 predictions, one class index per line, in row order.
with open('mock_data/pred_splits/split_1.txt', 'w') as split_file:
    split_file.writelines(f'{pred}\n' for pred in predictions)

In [41]:
# Display the class-code -> class-name lookup parsed from README.txt.
mapping

{'n01443537': 'goldfish',
 'n01484850': 'great_white_shark',
 'n01494475': 'hammerhead',
 'n01498041': 'stingray',
 'n01514859': 'hen',
 'n01518878': 'ostrich',
 'n01531178': 'goldfinch',
 'n01534433': 'junco',
 'n01614925': 'bald_eagle',
 'n01616318': 'vulture',
 'n01630670': 'newt',
 'n01632777': 'axolotl',
 'n01644373': 'tree_frog',
 'n01677366': 'iguana',
 'n01694178': 'African_chameleon',
 'n01748264': 'cobra',
 'n01770393': 'scorpion',
 'n01774750': 'tarantula',
 'n01784675': 'centipede',
 'n01806143': 'peacock',
 'n01820546': 'lorikeet',
 'n01833805': 'hummingbird',
 'n01843383': 'toucan',
 'n01847000': 'duck',
 'n01855672': 'goose',
 'n01860187': 'black_swan',
 'n01882714': 'koala',
 'n01910747': 'jellyfish',
 'n01944390': 'snail',
 'n01983481': 'lobster',
 'n01986214': 'hermit_crab',
 'n02007558': 'flamingo',
 'n02009912': 'american_egret',
 'n02051845': 'pelican',
 'n02056570': 'king_penguin',
 'n02066245': 'grey_whale',
 'n02071294': 'killer_whale',
 'n02077923': 'sea_lion',